Stochastic Variational Inference (SVI)

class SVI(model, guide, optim, loss, **static_kwargs)

Bases: object

Stochastic Variational Inference given an ELBO loss objective.

References

  1. SVI Part I: An Introduction to Stochastic Variational Inference in Pyro (http://pyro.ai/examples/svi_part_i.html)

Example:

>>> from jax import lax, random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO

>>> def model(data):
...     f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
...     with numpyro.plate("N", data.shape[0]):
...         numpyro.sample("obs", dist.Bernoulli(f), obs=data)

>>> def guide(data):
...     alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 15., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> optimizer = numpyro.optim.Adam(step_size=0.0005)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> init_state = svi.init(random.PRNGKey(0), data)
>>> state = lax.fori_loop(0, 2000, lambda i, state: svi.update(state, data)[0], init_state)
>>> # or to collect losses during the loop
>>> # state, losses = lax.scan(lambda state, i: svi.update(state, data), init_state, jnp.arange(2000))
>>> params = svi.get_params(state)
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
Parameters:
  • model – Python callable with NumPyro primitives for the model.
  • guide – Python callable with NumPyro primitives for the guide (recognition network).
  • optim – an instance of _NumpyroOptim (e.g. numpyro.optim.Adam).
  • loss – ELBO loss object, i.e. the negative Evidence Lower Bound, to minimize (e.g. Trace_ELBO()).
  • static_kwargs – static keyword arguments for the model / guide, i.e. arguments that remain constant during fitting.
Returns:

an SVI instance; fitting is driven through its init, update, evaluate, and get_params methods.

init(rng_key, *args, **kwargs)
Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.
  • args – arguments to the model / guide (these can possibly vary during the course of fitting).
  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns:

the initial SVIState, to be passed to update, evaluate, and get_params.

get_params(svi_state)

Gets the values at the param sites of the model and guide, mapped to their constrained domains.

Parameters:
  • svi_state – current state of SVI.
Returns:

dict of parameter values keyed by param site name.

update(svi_state, *args, **kwargs)

Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer.

Parameters:
  • svi_state – current state of SVI.
  • args – arguments to the model / guide (these can possibly vary during the course of fitting).
  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns:

tuple of (svi_state, loss).

evaluate(svi_state, *args, **kwargs)

Evaluate the ELBO loss at the current parameter values without taking an optimization step (possibly on a batch / minibatch of data).

Parameters:
  • svi_state – current state of SVI.
  • args – arguments to the model / guide (these can possibly vary during the course of fitting).
  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns:

the ELBO loss given the current parameter values (held within svi_state.optim_state).

ELBO

class ELBO(num_particles=1)

Bases: numpyro.infer.elbo.Trace_ELBO

RenyiELBO

class RenyiELBO(alpha=0, num_particles=2)

Bases: numpyro.infer.elbo.Trace_ELBO

An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1]. In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].

Note

Setting \(\alpha < 1\) gives a better bound than the usual ELBO.

Parameters:
  • alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.
  • num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.

References:

  1. Renyi Divergence Variational Inference, Yingzhen Li, Richard E. Turner
  2. Importance Weighted Autoencoders, Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
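
To make the role of \(\alpha\) concrete, here is a NumPy sketch of the Monte Carlo Renyi bound computed from \(K\) log-weights \(\log w_k = \log p(x, z_k) - \log q(z_k)\), following the definition in [1]: \(\hat{\mathcal{L}}_\alpha = \frac{1}{1-\alpha}\log\frac{1}{K}\sum_{k=1}^{K} w_k^{1-\alpha}\). The log-weights below are synthetic placeholders rather than the output of a model, and the helper names are hypothetical.

```python
import numpy as np


def logsumexp(a):
    # Numerically stable log(sum(exp(a))) for a 1-D array.
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def renyi_bound(log_w, alpha):
    """Monte Carlo Renyi bound from K log-weights log p(x, z_k) - log q(z_k), z_k ~ q."""
    if alpha == 1:
        raise ValueError("alpha must differ from 1; the alpha -> 1 limit is the usual ELBO")
    k = log_w.shape[0]
    return (logsumexp((1.0 - alpha) * log_w) - np.log(k)) / (1.0 - alpha)


rng = np.random.default_rng(0)
log_w = rng.normal(loc=-5.0, scale=2.0, size=16)  # synthetic log-weights (hypothetical)

elbo_estimate = log_w.mean()                               # standard ELBO estimate
iwae_estimate = logsumexp(log_w) - np.log(log_w.shape[0])  # importance-weighted estimate

for alpha in (-0.5, 0.0, 0.5, 0.9):
    print(f"alpha={alpha:+.1f}  bound={renyi_bound(log_w, alpha):.4f}")
print(f"ELBO average        {elbo_estimate:.4f}")
```

For a fixed set of samples the estimate is non-increasing in \(\alpha\) (a power-mean inequality), so any \(\alpha < 1\) gives a value at least as large as the plain ELBO average, and \(\alpha = 0\) reproduces the importance-weighted estimate. Values \(\alpha < 0\) increase it further but no longer guarantee a lower bound.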

loss(rng_key, param_map, model, guide, *args, **kwargs)

Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.
  • param_map (dict) – dictionary of current parameter values keyed by site name.
  • model – Python callable with NumPyro primitives for the model.
  • guide – Python callable with NumPyro primitives for the guide.
  • args – arguments to the model / guide (these can possibly vary during the course of fitting).
  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns:

negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.

Trace_ELBO

class Trace_ELBO(num_particles=1)

Bases: object

A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide.

This is the most basic implementation of the Evidence Lower Bound, which is the fundamental objective in Variational Inference. This implementation has various limitations (for example, it only supports random variables with reparameterized samplers) but can be used as a template to build more sophisticated loss objectives.

For more details, refer to http://pyro.ai/examples/svi_part_i.html.

References:

  1. Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber
  2. Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei
Parameters:
  • num_particles – The number of particles/samples used to form the ELBO (gradient) estimators. Default is 1.
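
A NumPy sketch of the reparameterized estimator that num_particles averages, for an illustrative toy model (z ~ Normal(0, 1), x | z ~ Normal(z, 1)) with a Normal(loc, scale) guide. It is not NumPyro's internal code; it only shows that more particles shrink the spread of the estimate (roughly by the square root of num_particles) without changing its expectation.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def normal_logpdf(v, loc, scale):
    return -0.5 * LOG_2PI - np.log(scale) - 0.5 * ((v - loc) / scale) ** 2


def elbo_estimate(rng, x, loc, scale, num_particles):
    # Reparameterization: z = loc + scale * eps with eps ~ Normal(0, 1).
    eps = rng.standard_normal(num_particles)
    z = loc + scale * eps
    log_joint = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)  # log p(z) + log p(x | z)
    log_q = normal_logpdf(z, loc, scale)
    return np.mean(log_joint - log_q)  # average over particles


def exact_elbo(x, loc, scale):
    # Closed form: E_q[log p(z)] + E_q[log p(x | z)] + entropy of q.
    return (-LOG_2PI - 0.5 * (loc**2 + scale**2) - 0.5 * ((x - loc) ** 2 + scale**2)
            + 0.5 * (LOG_2PI + 1.0) + np.log(scale))


rng = np.random.default_rng(0)
x, loc, scale = 3.0, 0.0, 1.0
one = np.array([elbo_estimate(rng, x, loc, scale, 1) for _ in range(2000)])
hundred = np.array([elbo_estimate(rng, x, loc, scale, 100) for _ in range(2000)])
print(exact_elbo(x, loc, scale), one.mean(), hundred.mean())
print(one.std(), hundred.std())

# With the guide equal to the exact posterior Normal(x / 2, sqrt(1/2)), every particle
# yields log p(x), so the estimate has zero variance.
at_posterior = elbo_estimate(rng, x, x / 2, np.sqrt(0.5), 5)
```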

loss(rng_key, param_map, model, guide, *args, **kwargs)

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.
  • param_map (dict) – dictionary of current parameter values keyed by site name.
  • model – Python callable with NumPyro primitives for the model.
  • guide – Python callable with NumPyro primitives for the guide.
  • args – arguments to the model / guide (these can possibly vary during the course of fitting).
  • kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns:

negative of the Evidence Lower Bound (ELBO) to be minimized.