Stochastic Variational Inference (SVI)
- class SVI(model, guide, optim, loss, **static_kwargs)
Bases: object
Stochastic Variational Inference given an ELBO loss objective.
References
SVI Part I: An Introduction to Stochastic Variational Inference in Pyro (http://pyro.ai/examples/svi_part_i.html)
Example:
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import Predictive, SVI, Trace_ELBO
>>> def model(data):
...     f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
...     with numpyro.plate("N", data.shape[0]):
...         numpyro.sample("obs", dist.Bernoulli(f), obs=data)
>>> def guide(data):
...     alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
...                            constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> optimizer = numpyro.optim.Adam(step_size=0.0005)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 2000, data)
>>> params = svi_result.params
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
>>> # get posterior samples
>>> predictive = Predictive(guide, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data)
- Parameters
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide (recognition network).
optim – An instance of _NumpyroOptim, a jax.experimental.optimizers.Optimizer (jax.example_libraries.optimizers in newer JAX releases) or an Optax GradientTransformation. If you pass an Optax optimizer, it will automatically be wrapped using numpyro.contrib.optim.optax_to_numpyro().
>>> from optax import adam, chain, clip
>>> svi = SVI(model, guide, chain(clip(10.0), adam(1e-3)), loss=Trace_ELBO())
loss – ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
static_kwargs – static arguments for the model / guide, i.e. arguments that remain constant during fitting.
- Returns
an SVI object; fitting is driven through its init(), update(), stable_update(), run() and evaluate() methods documented below.
- init(rng_key, *args, **kwargs)
Gets the initial SVI state.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
the initial SVIState
- get_params(svi_state)
Gets values at param sites of the model and guide.
- Parameters
svi_state – current state of SVI.
- Returns
a dict mapping each param site name to its current (constrained) value
- update(svi_state, *args, **kwargs)
Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer.
- Parameters
svi_state – current state of SVI.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
tuple of (svi_state, loss).
- stable_update(svi_state, *args, **kwargs)
Similar to update(), but returns the current state unchanged if the loss or the new state contains invalid (NaN or infinite) values.
- Parameters
svi_state – current state of SVI.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
tuple of (svi_state, loss).
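The documented contract of stable_update() (keep the current state when anything is non-finite) can be sketched without NumPyro. The helper `stable_step` and the toy `gradient_step` below are hypothetical and for illustration only; they are not NumPyro's implementation:

```python
import numpy as np

def stable_step(state, step_fn):
    """Hypothetical helper: apply step_fn, but keep `state` if the loss or any
    value in the proposed state is NaN or infinite."""
    loss, new_state = step_fn(state)
    finite = np.isfinite(loss) and all(np.all(np.isfinite(v)) for v in new_state.values())
    return (new_state if finite else state), loss

def gradient_step(state, lr):
    # One gradient-descent step on the loss (x - 2)^2.
    x = state["x"]
    loss = (x - 2.0) ** 2
    grad = 2.0 * (x - 2.0)
    return loss, {"x": x - lr * grad}

state0 = {"x": 0.0}
state1, loss1 = stable_step(state0, lambda s: gradient_step(s, 0.1))     # accepted: x moves to 0.4
state2, loss2 = stable_step(state1, lambda s: gradient_step(s, np.inf))  # rejected: proposed x is inf
print(state1, state2 is state1)
```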
- run(rng_key, num_steps, *args, progress_bar=True, stable_update=False, **kwargs)
(EXPERIMENTAL INTERFACE) Run SVI with num_steps iterations, then return the optimized parameters and the stacked losses at every step. If num_steps is large, setting progress_bar=False can make the run faster.
Note
For a more complex training process (e.g. one that requires early stopping, epoch-based training, or varying args/kwargs), we recommend using the more flexible methods init(), update() and evaluate() to customize your training procedure.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
num_steps (int) – the number of optimization steps.
args – arguments to the model / guide
progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
stable_update (bool) – Whether to use stable_update() to update the state. Defaults to False.
kwargs – keyword arguments to the model / guide
- Returns
a namedtuple with fields params and losses, where params holds the optimized values at numpyro.param sites and losses holds the loss recorded at every step.
- Return type
SVIRunResult
- evaluate(svi_state, *args, **kwargs)
Evaluate the ELBO loss at the current parameter values without taking an optimization step (possibly on a batch / minibatch of data).
- Parameters
svi_state – current state of SVI.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide.
- Returns
the ELBO loss evaluated at the current parameter values (held within svi_state.optim_state).
ELBO
- class ELBO(num_particles=1)
Bases: object
Base class for all ELBO objectives.
Subclasses should implement either loss() or loss_with_mutable_state().
- Parameters
num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.
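To make the role of num_particles concrete, here is a NumPy sketch of the Monte Carlo estimate of the negative ELBO, -E_q[log p(x, z) - log q(z)], averaged over num_particles draws from the guide. The conjugate toy model (z ~ Normal(0, 1), x | z ~ Normal(z, 1)) is an assumption chosen because its evidence is known in closed form; this is not NumPyro's code:

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    # Log density of a univariate Normal(loc, scale).
    return -0.5 * np.log(2.0 * np.pi) - np.log(scale) - 0.5 * ((x - loc) / scale) ** 2

def neg_elbo(x, q_loc, q_scale, num_particles, rng):
    # Draw num_particles samples z_i ~ q(z) = Normal(q_loc, q_scale).
    z = q_loc + q_scale * rng.standard_normal(num_particles)
    # Toy model (assumption): z ~ Normal(0, 1), x | z ~ Normal(z, 1).
    log_joint = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)
    log_q = normal_logpdf(z, q_loc, q_scale)
    # Monte Carlo estimate of -ELBO = -E_q[log p(x, z) - log q(z)].
    return -np.mean(log_joint - log_q)

rng = np.random.default_rng(0)
x = 1.3
exact = -normal_logpdf(x, 0.0, np.sqrt(2.0))  # -log p(x), since p(x) = Normal(0, sqrt(2))

# With q equal to the exact posterior Normal(x / 2, sqrt(1 / 2)), every particle yields -log p(x).
at_posterior = neg_elbo(x, x / 2, np.sqrt(0.5), num_particles=1, rng=rng)
# With a mismatched q, the estimate exceeds -log p(x) by KL(q || p(z | x)) on average.
mismatched = neg_elbo(x, 0.0, 1.0, num_particles=100_000, rng=rng)
print(exact, at_posterior, mismatched)
```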
- can_infer_discrete = False
- loss(rng_key, param_map, model, guide, *args, **kwargs)
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
negative of the Evidence Lower Bound (ELBO) to be minimized.
- loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)
Like loss(), but also updates and returns the mutable state, which stores the values at mutable() sites.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
a tuple of ELBO loss and the mutable state
Trace_ELBO
- class Trace_ELBO(num_particles=1)
Bases: numpyro.infer.elbo.ELBO
A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide.
This is the most basic implementation of the Evidence Lower Bound, which is the fundamental objective in Variational Inference. This implementation has various limitations (for example it only supports random variables with reparameterized samplers) but can be used as a template to build more sophisticated loss objectives.
For more details, refer to http://pyro.ai/examples/svi_part_i.html.
References:
[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber
[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei
- Parameters
num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.
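Trace_ELBO relies on reparameterized samplers because gradients are taken through the samples themselves. A NumPy sketch of that idea, assuming a Normal guide and the toy objective f(z) = z**2 (whose exact expectation mu**2 + sigma**2 gives known gradients); it illustrates the reparameterization trick, not NumPyro's internals:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.7

# Reparameterized sampling: z = mu + sigma * eps with eps ~ Normal(0, 1),
# so z is a differentiable function of (mu, sigma).
eps = rng.standard_normal(200_000)
z = mu + sigma * eps

# Objective f(z) = z**2, so E[f(z)] = mu**2 + sigma**2.
df_dz = 2.0 * z
grad_mu = np.mean(df_dz * 1.0)     # chain rule with dz/dmu = 1
grad_sigma = np.mean(df_dz * eps)  # chain rule with dz/dsigma = eps

print(grad_mu, 2 * mu)        # Monte Carlo vs exact gradient w.r.t. mu
print(grad_sigma, 2 * sigma)  # Monte Carlo vs exact gradient w.r.t. sigma
```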
- loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)
Like loss(), but also updates and returns the mutable state, which stores the values at mutable() sites.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
a tuple of ELBO loss and the mutable state
TraceGraph_ELBO
- class TraceGraph_ELBO(num_particles=1)
Bases: numpyro.infer.elbo.ELBO
A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide. Where possible, conditional dependency information as recorded in the trace is used to reduce the variance of the gradient estimator. In particular, two kinds of conditional dependency information are used to reduce variance:
- the sequential order of samples (z is sampled after y => y does not depend on z)
- plate generators
References
[1] Gradient Estimation Using Stochastic Computation Graphs, John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel
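Why dependency information helps: in a score-function estimator for a non-reparameterizable site, cost terms that the site cannot influence add variance without adding signal. A generic NumPy sketch, assuming a Bernoulli latent z with logit theta and an independent Normal latent y; it illustrates the variance-reduction principle, not TraceGraph_ELBO's actual code:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
theta = 0.3
p = 1.0 / (1.0 + np.exp(-theta))  # Bernoulli probability sigmoid(theta)

# Non-reparameterizable latent z ~ Bernoulli(p) and an unrelated latent y ~ Normal(0, 1).
z = (rng.random(n) < p).astype(float)
y = rng.standard_normal(n)

f = 3.0 * z         # cost term that depends on z
g = 5.0 * y + 10.0  # cost term that does not depend on z

score = z - p  # d/dtheta of log Bernoulli(z | sigmoid(theta))

naive = (f + g) * score  # multiplies the score by the whole cost
aware = f * score        # keeps only the cost terms z can influence

exact = p * (1.0 - p) * 3.0  # exact d/dtheta of E[f(z)]; E[g] does not depend on theta
print(exact, naive.mean(), aware.mean())  # both estimators are unbiased
print(naive.var(), aware.var())           # but the dependency-aware one has lower variance
```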
- can_infer_discrete = True
- loss(rng_key, param_map, model, guide, *args, **kwargs)
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
negative of the Evidence Lower Bound (ELBO) to be minimized.
TraceMeanField_ELBO
- class TraceMeanField_ELBO(num_particles=1)
Bases: numpyro.infer.elbo.ELBO
A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in NumPyro that uses analytic KL divergences when those are available.
Warning
This estimator may give incorrect results if the mean-field condition is not satisfied. The mean-field condition is a sufficient but not necessary condition for this estimator to be correct. The precise condition is that for every latent variable z in the guide, its parents in the model must not include any latent variables that are descendants of z in the guide. Here ‘parents in the model’ and ‘descendants in the guide’ are with respect to the corresponding (statistical) dependency structure. For example, this condition is always satisfied if the model and guide have identical dependency structures.
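What "analytic KL divergence" buys: for a Normal guide and a Normal prior, the KL term has a closed form, whereas a sample-based estimator only approximates it. A NumPy sketch comparing the two, assuming illustrative guide and prior parameters; it is not NumPyro's implementation:

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    return -0.5 * np.log(2.0 * np.pi) - np.log(scale) - 0.5 * ((x - loc) / scale) ** 2

def kl_normal(m1, s1, m2, s2):
    # Closed-form KL(Normal(m1, s1) || Normal(m2, s2)).
    return np.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2.0 * s2 ** 2) - 0.5

rng = np.random.default_rng(0)
q_loc, q_scale = 0.5, 0.8  # guide
p_loc, p_scale = 0.0, 1.0  # prior

analytic = kl_normal(q_loc, q_scale, p_loc, p_scale)
# Monte Carlo estimate E_q[log q(z) - log p(z)] using samples from the guide.
z = q_loc + q_scale * rng.standard_normal(200_000)
monte_carlo = np.mean(normal_logpdf(z, q_loc, q_scale) - normal_logpdf(z, p_loc, p_scale))
print(analytic, monte_carlo)
```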
- loss_with_mutable_state(rng_key, param_map, model, guide, *args, **kwargs)
Like loss(), but also updates and returns the mutable state, which stores the values at mutable() sites.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
a tuple of ELBO loss and the mutable state
RenyiELBO
- class RenyiELBO(alpha=0, num_particles=2)
Bases: numpyro.infer.elbo.ELBO
An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1]. In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset, \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].
Note
Setting \(\alpha < 1\) gives a tighter bound than the usual ELBO.
- Parameters
alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.
num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.
References:
[1] Renyi Divergence Variational Inference, Yingzhen Li, Richard E. Turner
[2] Importance Weighted Autoencoders, Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
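The Monte Carlo Renyi objective from reference [1] is \(\frac{1}{1-\alpha}\log\frac{1}{K}\sum_i w_i^{1-\alpha}\) with importance weights \(w_i = p(x, z_i)/q(z_i)\). A NumPy sketch, assuming random stand-in log-weights rather than weights from a real model, shows the ordering behind the note above: \(\alpha = 0\) (the IWAE bound) sits above \(\alpha = 0.5\), which sits above the ELBO estimate (the \(\alpha \to 1\) limit), on the same samples. This is an illustration, not NumPyro's code:

```python
import numpy as np

def renyi_bound(log_w, alpha):
    """Monte Carlo Renyi bound 1/(1 - alpha) * log mean(w_i ** (1 - alpha)),
    where log_w[i] = log p(x, z_i) - log q(z_i). Requires alpha != 1."""
    a = (1.0 - alpha) * log_w
    m = np.max(a)  # log-mean-exp with max subtraction for numerical stability
    return (m + np.log(np.mean(np.exp(a - m)))) / (1.0 - alpha)

rng = np.random.default_rng(0)
log_w = rng.normal(loc=-5.0, scale=2.0, size=16)  # stand-in log importance weights

iwae = renyi_bound(log_w, alpha=0.0)  # alpha = 0: importance weighted bound
renyi_half = renyi_bound(log_w, alpha=0.5)
elbo = np.mean(log_w)  # the alpha -> 1 limit: the standard ELBO estimate
print(iwae, renyi_half, elbo)
```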
- loss(rng_key, param_map, model, guide, *args, **kwargs)
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
- Parameters
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide.
args – arguments to the model / guide (these can possibly vary during the course of fitting).
kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
- Returns
negative of the Evidence Lower Bound (ELBO) to be minimized.