Stochastic Variational Inference (SVI)

svi(model, guide, loss, optim_init, optim_update, get_params, **kwargs)[source]

Stochastic Variational Inference given an ELBo loss objective.

Parameters:
  • model – Python callable with Pyro primitives for the model.
  • guide – Python callable with Pyro primitives for the guide (recognition network).
  • loss – ELBo loss, i.e. negative Evidence Lower Bound, to minimize.
  • optim_init – initialization function returned by a JAX optimizer. See jax.experimental.optimizers.
  • optim_update – update function for the optimizer.
  • get_params – function to get the current parameter values given the optimizer state.
  • **kwargs – static arguments for the model / guide, i.e. arguments that remain constant during fitting.
Returns:

tuple of (init_fn, update_fn, evaluate).
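The optim_init, optim_update and get_params arguments follow the three-function convention used by JAX optimizers: one function builds the optimizer state, one advances it given an iteration index and gradients, and one reads the parameters back out. The sketch below imitates that contract in plain Python on a scalar problem. It is only an illustration: the sgd helper and the dict-based state are hypothetical stand-ins, not the JAX implementation.

```python
# Toy stand-in for the (init, update, get_params) triple returned by a JAX
# optimizer. Pure Python, scalar parameters only; `sgd` is a hypothetical helper.

def sgd(step_size):
    def optim_init(params):
        # In this toy, the optimizer state is just a copy of the parameter dict.
        return dict(params)

    def optim_update(i, grads, opt_state):
        # `i` is the iteration index. Plain SGD ignores it; a step-size
        # schedule would use it.
        return {k: v - step_size * grads[k] for k, v in opt_state.items()}

    def get_params(opt_state):
        # Recover the current parameter values from the optimizer state.
        return dict(opt_state)

    return optim_init, optim_update, get_params


optim_init, optim_update, get_params = sgd(step_size=0.1)

# Minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
opt_state = optim_init({"w": 0.0})
for i in range(100):
    w = get_params(opt_state)["w"]
    opt_state = optim_update(i, {"w": 2.0 * (w - 3.0)}, opt_state)

final_w = get_params(opt_state)["w"]
```

The parameters never live outside the optimizer state, which is why svi needs get_params in addition to the init and update functions.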

init_fn(rng, model_args=(), guide_args=(), params=None)
Parameters:
  • rng (jax.random.PRNGKey) – random number generator seed.
  • model_args (tuple) – arguments to the model (these can possibly vary during the course of fitting).
  • guide_args (tuple) – arguments to the guide (these can possibly vary during the course of fitting).
  • params (dict) – initial parameter values to condition on. This can be useful forx
Returns:

initial optimizer state.

update_fn(i, opt_state, rng, model_args=(), guide_args=())

Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer.

Parameters:
  • i (int) – index of the current iteration within the epoch, passed as an argument to the optimizer’s update function.
  • opt_state – current optimizer state.
  • rng (jax.random.PRNGKey) – random number generator seed.
  • model_args (tuple) – dynamic arguments to the model.
  • guide_args (tuple) – dynamic arguments to the guide.
Returns:

tuple of (loss_val, opt_state, rng).

evaluate(opt_state, rng, model_args=(), guide_args=())

Evaluate the ELBo loss at the current parameter values (possibly on a batch / minibatch of data). Unlike update_fn, this does not return a new optimizer state.

Parameters:
  • opt_state – current optimizer state.
  • rng (jax.random.PRNGKey) – random number generator seed.
  • model_args (tuple) – arguments to the model (these can possibly vary during the course of fitting).
  • guide_args (tuple) – arguments to the guide (these can possibly vary during the course of fitting).
Returns:

ELBo loss evaluated at the current parameter values (held within opt_state).
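The three functions are meant to be used together: init_fn produces the optimizer state, update_fn is called in a loop while threading opt_state and rng through each call, and evaluate reports the loss. The sketch below shows that calling pattern with toy stand-ins that copy the documented signatures. Everything inside make_toy_svi is hypothetical (a noisy quadratic loss, an integer in place of a jax.random.PRNGKey, and rng + 1 in place of key splitting); it is not the library's implementation.

```python
import random


def make_toy_svi(step_size=0.1):
    """Hypothetical stand-in returning functions with the documented signatures."""

    def noisy_loss_and_grad(mu, rng, data):
        # One noisy draw around mu, seeded by rng, mimics a stochastic loss estimate.
        eps = random.Random(rng).gauss(0.0, 1.0)
        z = mu + 0.1 * eps
        target = sum(data) / len(data)
        loss = (z - target) ** 2
        grad = 2.0 * (z - target)  # d loss / d mu, since dz/dmu = 1
        return loss, grad

    def init_fn(rng, model_args=(), guide_args=(), params=None):
        # The toy state is a plain dict; rng is accepted but unused here.
        return dict(params) if params is not None else {"mu": 0.0}

    def update_fn(i, opt_state, rng, model_args=(), guide_args=()):
        (data,) = model_args
        loss, grad = noisy_loss_and_grad(opt_state["mu"], rng, data)
        new_state = {"mu": opt_state["mu"] - step_size * grad}
        return loss, new_state, rng + 1  # rng + 1 stands in for a fresh key

    def evaluate(opt_state, rng, model_args=(), guide_args=()):
        (data,) = model_args
        loss, _ = noisy_loss_and_grad(opt_state["mu"], rng, data)
        return loss

    return init_fn, update_fn, evaluate


init_fn, update_fn, evaluate = make_toy_svi()
data = [1.8, 2.1, 2.0, 2.1]  # the "dynamic" model argument (mean 2.0)

rng = 0
opt_state = init_fn(rng, model_args=(data,))
initial_loss = evaluate(opt_state, rng, model_args=(data,))

for i in range(200):
    # Each step returns the loss, the new optimizer state and a new rng.
    loss_val, opt_state, rng = update_fn(i, opt_state, rng, model_args=(data,))

state_before = dict(opt_state)
final_loss = evaluate(opt_state, rng, model_args=(data,))
```

Note that the loop rebinds both opt_state and rng on every iteration, while evaluate leaves the state it is given unchanged.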

ELBo

elbo(param_map, model, guide, model_args, guide_args, kwargs)[source]

This is the most basic implementation of the Evidence Lower Bound, the fundamental objective in Variational Inference. It has several limitations (for example, it only supports random variables with reparameterized samplers), but it can serve as a template for building more sophisticated loss objectives.

For more details, refer to http://pyro.ai/examples/svi_part_i.html.

Parameters:
  • param_map (dict) – dictionary of current parameter values keyed by site name.
  • model – Python callable with Pyro primitives for the model.
  • guide – Python callable with Pyro primitives for the guide (recognition network).
  • model_args (tuple) – arguments to the model (these can possibly vary during the course of fitting).
  • guide_args (tuple) – arguments to the guide (these can possibly vary during the course of fitting).
  • kwargs (dict) – static keyword arguments to the model / guide.
Returns:

negative of the Evidence Lower Bound (ELBo) to be minimized.
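To make the returned quantity concrete, the following sketch computes a single-sample estimate of the negative ELBo for a toy conjugate Gaussian model, using a reparameterized sample from the guide. The model, the neg_elbo helper and its arguments are assumptions chosen for illustration; the helper only borrows the param_map idea and does not reproduce the elbo signature or implementation documented above.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    # Log density of Normal(loc, scale) at x.
    return (-0.5 * math.log(2.0 * math.pi) - math.log(scale)
            - 0.5 * ((x - loc) / scale) ** 2)


def neg_elbo(param_map, x_obs, eps):
    """Single-sample estimate of the negative ELBo for a toy Gaussian model.

    Model: z ~ Normal(0, 1), x ~ Normal(z, 1).
    Guide: z ~ Normal(loc, scale), with loc and scale taken from param_map.
    """
    loc, scale = param_map["loc"], param_map["scale"]
    z = loc + scale * eps  # reparameterized sample: noise eps ~ Normal(0, 1)
    log_joint = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x_obs, z, 1.0)
    log_guide = normal_logpdf(z, loc, scale)
    return -(log_joint - log_guide)


x_obs = 1.5

# With the exact posterior Normal(x/2, sqrt(1/2)) as the guide, every sample
# gives exactly -log p(x), where p(x) = Normal(0, sqrt(2)) is the evidence.
eps = random.Random(0).gauss(0.0, 1.0)
exact = {"loc": x_obs / 2.0, "scale": math.sqrt(0.5)}
loss_exact = neg_elbo(exact, x_obs, eps)
neg_log_evidence = -normal_logpdf(x_obs, 0.0, math.sqrt(2.0))

# A mismatched guide yields a larger loss on average (the bound is looser).
rng = random.Random(1)
poor = {"loc": 0.0, "scale": 1.0}
n_draws = 5000
loss_poor = sum(neg_elbo(poor, x_obs, rng.gauss(0.0, 1.0))
                for _ in range(n_draws)) / n_draws
```

Because the sample z is written as a deterministic function of loc, scale and independent noise, the loss is differentiable in the guide parameters, which is what the reparameterized-sampler restriction refers to.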