Automatic Guide Generation
We provide a brief overview of the automatically generated guides available in NumPyro:
AutoNormal and AutoDiagonalNormal are our basic mean-field guides. If the latent space is non-Euclidean (due to e.g. a positivity constraint on one of the sample sites), an appropriate bijective transformation is automatically used under the hood to map between the unconstrained space (where the Normal variational distribution is defined) and the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing.
AutoMultivariateNormal and AutoLowRankMultivariateNormal also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting.
AutoDelta is used for computing point estimates via MAP (maximum a posteriori estimation). See here for example usage.
AutoBNAFNormal and AutoIAFNormal offer flexible variational distributions parameterized by normalizing flows.
AutoDAIS is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
AutoSurrogateLikelihoodDAIS is a powerful variational inference algorithm that leverages HMC and that supports data subsampling.
AutoSemiDAIS constructs a posterior approximation like AutoDAIS for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
AutoLaplaceApproximation can be used to compute a Laplace approximation.
AutoGuideList can be used to combine multiple automatic guides.
AutoGuide
- class AutoGuide(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None)[source]
Bases:
ABC
Base class for automatic guides.
Derived classes must implement the __call__() method.
- Parameters:
model (callable) – a NumPyro model
prefix (str) – a prefix included in the names of all internal param sites
init_loc_fn (callable) – A per-site initialization function. See Initialization Strategies section for available functions.
create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a numpyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- abstract sample_posterior(rng_key, params, *, sample_shape=())[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
- median(params)[source]
Returns the posterior median value of each latent variable.
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
- Returns:
A dict mapping sample site name to median value.
- Return type:
dict
- quantiles(params, quantiles)[source]
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
quantiles (list) – A list of requested quantiles between 0 and 1.
- Returns:
A dict mapping sample site name to an array of quantile values.
- Return type:
dict
AutoGuideList
- class AutoGuideList(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None)[source]
Bases:
AutoGuide
Container class to combine multiple automatic guides.
Example usage:
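from jax import random
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal

# model, optim, data and labels are assumed to be defined elsewhere
rng_key_init = random.PRNGKey(0)
guide = AutoGuideList(model)
# mean-field Normal guide for every site except "coefs"
guide.append(AutoNormal(numpyro.handlers.block(model, hide=["coefs"])))
# point-estimate (Delta) guide for "coefs" only
guide.append(AutoDelta(numpyro.handlers.block(model, expose=["coefs"])))
svi = SVI(model, guide, optim, Trace_ELBO())
svi_state = svi.init(rng_key_init, data, labels)
params = svi.get_params(svi_state)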
- Parameters:
model (callable) – a NumPyro model
- append(part)[source]
Add an automatic or custom guide for part of the model. The guide should have been created by blocking the model to restrict to a subset of sample sites. No two parts should operate on any one sample site.
- Parameters:
part (AutoGuide) – a partial guide to add
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
- median(params)[source]
Returns the posterior median value of each latent variable.
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
- Returns:
A dict mapping sample site name to median value.
- Return type:
dict
- quantiles(params, quantiles)[source]
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
quantiles (list) – A list of requested quantiles between 0 and 1.
- Returns:
A dict mapping sample site name to an array of quantile values.
- Return type:
dict
AutoContinuous
- class AutoContinuous(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None)[source]
Bases:
AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
Each derived class implements its own _get_posterior() method.
Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
Reference:
[1] Automatic Differentiation Variational Inference, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
- Parameters:
model (callable) – A NumPyro model.
prefix (str) – a prefix included in the names of all internal param sites.
init_loc_fn (callable) – A per-site initialization function. See Initialization Strategies section for available functions.
- get_base_dist()[source]
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.
- get_transform(params)[source]
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.
- Parameters:
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
- Returns:
the transform of the posterior distribution
- Return type:
Transform
- get_posterior(params)[source]
Returns the posterior distribution.
- Parameters:
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
- sample_posterior(rng_key, params, *, sample_shape=())[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
AutoBNAFNormal
- class AutoBNAFNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, num_flows=1, hidden_factors=[8, 8])[source]
Bases:
AutoContinuous
This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via a BlockNeuralAutoregressiveTransform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[50, 50], ...)
svi = SVI(model, guide, ...)
References
[1] Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
- Parameters:
model (callable) – a generative model.
prefix (str) – a prefix included in the names of all internal param sites.
init_loc_fn (callable) – A per-site initialization function.
num_flows (int) – the number of flows to be used, defaults to 1.
hidden_factors (list) – Hidden layer i has hidden_factors[i] hidden units per input dimension. This corresponds to both \(a\) and \(b\) in reference [1]. The elements of hidden_factors must be integers.
- get_base_dist()[source]
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.
AutoDiagonalNormal
- class AutoDiagonalNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
Bases:
AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoDiagonalNormal(model, ...)
svi = SVI(model, guide, ...)
- scale_constraint = SoftplusPositive(lower_bound=0.0)
- get_base_dist()[source]
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.
- get_transform(params)[source]
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.
- Parameters:
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
- Returns:
the transform of the posterior distribution
- Return type:
Transform
- median(params)[source]
Returns the posterior median value of each latent variable.
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
- Returns:
A dict mapping sample site name to median value.
- Return type:
dict
- quantiles(params, quantiles)[source]
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
quantiles (list) – A list of requested quantiles between 0 and 1.
- Returns:
A dict mapping sample site name to an array of quantile values.
- Return type:
dict
AutoMultivariateNormal
- class AutoMultivariateNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
Bases:
AutoContinuous
This implementation of AutoContinuous uses a MultivariateNormal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoMultivariateNormal(model, ...)
svi = SVI(model, guide, ...)
- scale_tril_constraint = ScaledUnitLowerCholesky()
- get_base_dist()[source]
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.
- get_transform(params)[source]
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.
- Parameters:
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
- Returns:
the transform of the posterior distribution
- Return type:
Transform
- median(params)[source]
Returns the posterior median value of each latent variable.
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
- Returns:
A dict mapping sample site name to median value.
- Return type:
dict
- quantiles(params, quantiles)[source]
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
quantiles (list) – A list of requested quantiles between 0 and 1.
- Returns:
A dict mapping sample site name to an array of quantile values.
- Return type:
dict
AutoIAFNormal
- class AutoIAFNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, num_flows=3, hidden_dims=None, skip_connections=False, nonlinearity=(<function elementwise.<locals>.<lambda>>, <function elementwise.<locals>.<lambda>>))[source]
Bases:
AutoContinuous
This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an InverseAutoregressiveTransform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...)
svi = SVI(model, guide, ...)
- Parameters:
model (callable) – a generative model.
prefix (str) – a prefix included in the names of all internal param sites.
init_loc_fn (callable) – A per-site initialization function.
num_flows (int) – the number of flows to be used, defaults to 3.
hidden_dims (list) – the dimensionality of the hidden units per layer. Defaults to [latent_dim, latent_dim].
skip_connections (bool) – whether to add skip connections from the input to the output of each flow. Defaults to False.
nonlinearity (callable) – the nonlinearity to use in the feedforward network. Defaults to jax.example_libraries.stax.Elu().
- get_base_dist()[source]
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.
AutoLaplaceApproximation
- class AutoLaplaceApproximation(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, create_plates=None, hessian_fn=None)[source]
Bases:
AutoContinuous
Laplace approximation (quadratic approximation) approximates the log posterior \(\log p(z | x)\) by a quadratic function, which is equivalent to approximating the posterior by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP (i.e. point estimate) guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
Usage:
guide = AutoLaplaceApproximation(model, ...)
svi = SVI(model, guide, ...)
- Parameters:
hessian_fn (callable) – EXPERIMENTAL a function that takes a function f and a vector x and returns the Hessian of f at x. By default, we use lambda f, x: jax.hessian(f)(x). Other alternatives can be lambda f, x: jax.jacobian(jax.jacobian(f))(x) or lambda f, x: jax.hessian(f)(x) + 1e-3 * jnp.eye(x.shape[0]). The latter example is helpful when the Hessian of f at x is not positive definite. Note that the output Hessian is the precision matrix of the Laplace approximation.
- get_base_dist()[source]
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.
- get_transform(params)[source]
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.
- Parameters:
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
- Returns:
the transform of the posterior distribution
- Return type:
Transform
- sample_posterior(rng_key, params, *, sample_shape=())[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
- median(params)[source]
Returns the posterior median value of each latent variable.
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
- Returns:
A dict mapping sample site name to median value.
- Return type:
dict
- quantiles(params, quantiles)[source]
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
quantiles (list) – A list of requested quantiles between 0 and 1.
- Returns:
A dict mapping sample site name to an array of quantile values.
- Return type:
dict
AutoLowRankMultivariateNormal
- class AutoLowRankMultivariateNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1, rank=None)[source]
Bases:
AutoContinuous
This implementation of AutoContinuous uses a LowRankMultivariateNormal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoLowRankMultivariateNormal(model, rank=2, ...)
svi = SVI(model, guide, ...)
- scale_constraint = SoftplusPositive(lower_bound=0.0)
- get_base_dist()[source]
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.
- get_transform(params)[source]
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.
- Parameters:
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
- Returns:
the transform of the posterior distribution
- Return type:
Transform
- median(params)[source]
Returns the posterior median value of each latent variable.
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
- Returns:
A dict mapping sample site name to median value.
- Return type:
dict
- quantiles(params, quantiles)[source]
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
quantiles (list) – A list of requested quantiles between 0 and 1.
- Returns:
A dict mapping sample site name to an array of quantile values.
- Return type:
dict
AutoNormal
- class AutoNormal(model, *, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1, create_plates=None)[source]
Bases:
AutoGuide
This implementation of AutoGuide uses Normal distributions to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
This should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for mean field ELBO.
Usage:
guide = AutoNormal(model)
svi = SVI(model, guide, ...)
- Parameters:
model (callable) – A NumPyro model.
prefix (str) – a prefix included in the names of all internal param sites.
init_loc_fn (callable) – A per-site initialization function. See Initialization Strategies section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a numpyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- scale_constraint = SoftplusPositive(lower_bound=0.0)
- sample_posterior(rng_key, params, *, sample_shape=())[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
- median(params)[source]
Returns the posterior median value of each latent variable.
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
- Returns:
A dict mapping sample site name to median value.
- Return type:
dict
- quantiles(params, quantiles)[source]
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
- Parameters:
params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method from SVI.
quantiles (list) – A list of requested quantiles between 0 and 1.
- Returns:
A dict mapping sample site name to an array of quantile values.
- Return type:
dict
AutoDelta
- class AutoDelta(model, *, prefix='auto', init_loc_fn=<function init_to_median>, create_plates=None)[source]
Bases:
AutoGuide
This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Note
This class does MAP inference in constrained space.
Usage:
guide = AutoDelta(model)
svi = SVI(model, guide, ...)
- Parameters:
model (callable) – A NumPyro model.
prefix (str) – a prefix included in the names of all internal param sites.
init_loc_fn (callable) – A per-site initialization function. See Initialization Strategies section for available functions.
create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a numpyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
AutoDAIS
- class AutoDAIS(model, *, K=4, base_dist='diagonal', eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix='auto', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
Bases:
AutoContinuous
This implementation of AutoDAIS uses Differentiable Annealed Importance Sampling (DAIS) [1, 2] to construct a guide over the entire latent space. Samples from the variational distribution (i.e. guide) are generated using a combination of (uncorrected) Hamiltonian Monte Carlo and Annealed Importance Sampling. The same algorithm is called Uncorrected Hamiltonian Annealing in [1].
Note that AutoDAIS cannot be used in conjunction with data subsampling.
References:
[1] MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke
[2] Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
Usage:
guide = AutoDAIS(model)
svi = SVI(model, guide, ...)
- Parameters:
model (callable) – A NumPyro model.
prefix (str) – A prefix included in the names of all internal param sites.
K (int) – A positive integer that controls the number of HMC steps used. Defaults to 4.
base_dist (str) – Controls whether the base Normal variational distribution is parameterized by a “diagonal” covariance matrix or a full-rank covariance matrix parameterized by a lower-triangular “cholesky” factor. Defaults to “diagonal”.
eta_init (float) – The initial value of the step size used in HMC. Defaults to 0.01.
eta_max (float) – The maximum value of the learnable step size used in HMC. Defaults to 0.1.
gamma_init (float) – The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9.
init_loc_fn (callable) – A per-site initialization function. See Initialization Strategies section for available functions.
init_scale (float) – Initial scale for the standard deviation of the base variational distribution for each (unconstrained transformed) latent variable. Defaults to 0.1.
- sample_posterior(rng_key, params, *, sample_shape=())[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
AutoSemiDAIS
- class AutoSemiDAIS(model, local_model, global_guide=None, local_guide=None, *, prefix='auto', K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, init_scale=0.1, subsample_plate=None, use_global_dais_params=False)[source]
Bases:
AutoGuide
This implementation of AutoSemiDAIS [1] combines a parametric variational distribution over global latent variables with Differentiable Annealed Importance Sampling (DAIS) [2, 3] to infer local latent variables. Unlike AutoDAIS, this guide can be used in conjunction with data subsampling. Note that the resulting ELBO can be understood as a particular realization of a ‘locally enhanced bound’ as described in reference [4].
References:
[1] Surrogate Likelihoods for Variational Annealed Importance Sampling, Martin Jankowiak, Du Phan
[2] MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke
[3] Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
[4] Variational Inference with Locally Enhanced Bounds for Hierarchical Models, Tomas Geffner, Justin Domke
Usage:
def global_model():
    return numpyro.sample("theta", dist.Normal(0, 1))

def local_model(theta):
    with numpyro.plate("data", 8, subsample_size=2):
        tau = numpyro.sample("tau", dist.Gamma(5.0, 5.0))
        numpyro.sample("obs", dist.Normal(0.0, tau), obs=jnp.ones(2))

model = lambda: local_model(global_model())
global_guide = AutoNormal(global_model)
guide = AutoSemiDAIS(model, local_model, global_guide, K=4)
svi = SVI(model, guide, ...)

# sample posterior for particular data subset {3, 7}
with handlers.substitute(data={"data": jnp.array([3, 7])}):
    samples = guide.sample_posterior(random.PRNGKey(1), params)
- Parameters:
model (callable) – A NumPyro model with global and local latent variables.
local_model (callable) – The portion of model that includes the local latent variables only. The signature of local_model should be the return type of the global model with global latent variables only.
global_guide (callable) – A guide for the global latent variables, e.g. an autoguide. The return type should be a dictionary of latent sample sites names and corresponding samples. If there is no global variable in the model, we can set this to None.
local_guide (callable) – An optional guide for specifying the DAIS base distribution for local latent variables.
prefix (str) – A prefix included in the names of all internal sites.
K (int) – A positive integer that controls the number of HMC steps used. Defaults to 4.
eta_init (float) – The initial value of the step size used in HMC. Defaults to 0.01.
eta_max (float) – The maximum value of the learnable step size used in HMC. Defaults to 0.1.
gamma_init (float) – The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9.
init_scale (float) – Initial scale for the standard deviation of the variational distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1.
subsample_plate (str) – Optional name of the subsample plate site. This is required when the model has a subsample plate without subsample_size specified or the model has a subsample plate with subsample_size equal to the plate size.
use_global_dais_params (bool) – Whether parameters controlling DAIS dynamic (HMC step size, HMC mass matrix, etc.) should be global (i.e. common to all data points in the subsample plate) or local (i.e. each data point in the subsample plate has individual parameters). Note that we do not use global parameters for the base distribution.
- sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)[source]
Generate samples from the approximate posterior over the latent sites in the model.
- Parameters:
rng_key (jax.random.PRNGKey) – random key to be used to draw samples.
params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method from SVI.
sample_shape (tuple) – sample shape of each latent site, defaults to ().
- Returns:
a dict containing samples drawn from this guide.
- Return type:
dict
AutoSurrogateLikelihoodDAIS
- class AutoSurrogateLikelihoodDAIS(model, surrogate_model, *, K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix='auto', base_dist='diagonal', init_loc_fn=<function init_to_uniform>, init_scale=0.1)[source]
Bases:
AutoDAIS
This implementation of AutoSurrogateLikelihoodDAIS provides a mini-batchable family of variational distributions as described in [1]. It combines a user-provided surrogate likelihood with Differentiable Annealed Importance Sampling (DAIS) [2, 3]. It is not applicable to models with local latent variables (see AutoSemiDAIS), but unlike AutoDAIS, it can be used in conjunction with data subsampling.
References:
[1] Surrogate likelihoods for variational annealed importance sampling, Martin Jankowiak, Du Phan
[2] MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke
[3] Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
Usage:
# logistic regression model for data {X, Y}
def model(X, Y):
    theta = numpyro.sample(
        "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
    )
    with numpyro.plate("N", 100, subsample_size=10):
        X_batch = numpyro.subsample(X, event_dim=1)
        Y_batch = numpyro.subsample(Y, event_dim=0)
        numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

# surrogate model defined by prior and surrogate likelihood.
# a convenient choice for specifying the latter is to compute the likelihood on
# a randomly chosen data subset (here {X_surr, Y_surr} of size 20) and then use
# handlers.scale to scale the log likelihood by a vector of learnable weights.
def surrogate_model(X_surr, Y_surr):
    theta = numpyro.sample(
        "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
    )
    omegas = numpyro.param(
        "omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive
    )
    with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas):
        numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr)

guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model)
svi = SVI(model, guide, ...)
- Parameters:
model (callable) – A NumPyro model.
surrogate_model (callable) – A NumPyro model that is used as a surrogate model for guiding the HMC dynamics that define the variational distribution. In particular surrogate_model should contain the same prior as model but should contain a cheap-to-evaluate parametric ansatz for the likelihood. A simple ansatz for the latter involves computing the likelihood for a fixed subset of the data and scaling the resulting log likelihood by a learnable vector of positive weights. See the usage example above.
prefix (str) – A prefix included in the names of all internal param sites.
K (int) – A positive integer that controls the number of HMC steps used. Defaults to 4.
base_dist (str) – Controls whether the base Normal variational distribution is parameterized by a “diagonal” covariance matrix or a full-rank covariance matrix parameterized by a lower-triangular “cholesky” factor. Defaults to “diagonal”.
eta_init (float) – The initial value of the step size used in HMC. Defaults to 0.01.
eta_max (float) – The maximum value of the learnable step size used in HMC. Defaults to 0.1.
gamma_init (float) – The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9.
init_loc_fn (callable) – A per-site initialization function. See Initialization Strategies section for available functions.
init_scale (float) – Initial scale for the standard deviation of the base variational distribution for each (unconstrained transformed) latent variable. Defaults to 0.1.