Automatic Guide Generation

We provide a brief overview of the automatically generated guides available in NumPyro:

  • AutoNormal and AutoDiagonalNormal are our basic mean-field guides. If the latent space is non-Euclidean (due to, e.g., a positivity constraint on one of the sample sites), an appropriate bijective transformation is automatically used under the hood to map from the unconstrained space (where the Normal variational distribution is defined) to the corresponding constrained space (note this is true for all automatic guides). These guides are a great place to start when trying to get variational inference to work on a model you are developing.

  • AutoMultivariateNormal and AutoLowRankMultivariateNormal also construct Normal variational distributions but offer more flexibility, as they can capture correlations in the posterior. Note that these guides may be difficult to fit in the high-dimensional setting.

  • AutoDelta is used for computing point estimates via MAP (maximum a posteriori) estimation.

  • AutoBNAFNormal and AutoIAFNormal offer flexible variational distributions parameterized by normalizing flows.

  • AutoDAIS is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.

  • AutoSurrogateLikelihoodDAIS is a powerful variational inference algorithm that leverages HMC and that supports data subsampling.

  • AutoSemiDAIS constructs a posterior approximation like AutoDAIS for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.

  • AutoLaplaceApproximation can be used to compute a Laplace approximation.

  • AutoGuideList can be used to combine multiple automatic guides.

AutoGuide

class AutoGuide(model, *, prefix='auto', init_loc_fn=init_to_uniform, create_plates=None)

Bases: ABC

Base class for automatic guides.

Derived classes must implement the __call__() method.

Parameters:
  • model (callable) – a NumPyro model

  • prefix (str) – a prefix prepended to the names of all internal param sites

  • init_loc_fn (callable) – A per-site initialization function. See the Initialization Strategies section for available functions.

  • create_plates (callable) – An optional function that takes the same *args, **kwargs as model() and returns a numpyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.

abstract sample_posterior(rng_key, params, *, sample_shape=())

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

quantiles(params, quantiles)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

  • quantiles (list) – A list of requested quantiles between 0 and 1.

Returns:

A dict mapping sample site name to an array of quantile values.

Return type:

dict

AutoGuideList

class AutoGuideList(model, *, prefix='auto', init_loc_fn=init_to_uniform, create_plates=None)

Bases: AutoGuide

Container class to combine multiple automatic guides.

Example usage:

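from jax import random
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal

# `model`, `data`, `labels` and `optim` are assumed to be defined elsewhere;
# "coefs" is one of the latent sample sites of `model`.
rng_key_init = random.PRNGKey(0)
guide = AutoGuideList(model)
# Normal guide for every latent site except "coefs"
guide.append(
    AutoNormal(
        numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"])
    )
)
# point estimate (AutoDelta) for "coefs" only
guide.append(
    AutoDelta(
        numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"])
    )
)
svi = SVI(model, guide, optim, Trace_ELBO())
svi_state = svi.init(rng_key_init, data, labels)
params = svi.get_params(svi_state)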
Parameters:

model (callable) – a NumPyro model

append(part)

Add an automatic or custom guide for part of the model. The guide should have been created by blocking the model to restrict to a subset of sample sites. No two parts should operate on any one sample site.

Parameters:

part (AutoGuide) – a partial guide to add

sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

quantiles(params, quantiles)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

  • quantiles (list) – A list of requested quantiles between 0 and 1.

Returns:

A dict mapping sample site name to an array of quantile values.

Return type:

dict

AutoContinuous

class AutoContinuous(model, *, prefix='auto', init_loc_fn=init_to_uniform, create_plates=None)

Bases: AutoGuide

Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].

Each derived class implements its own _get_posterior() method.

Assumes model structure and latent dimension are fixed, and all latent variables are continuous.

Reference:

  1. Automatic Differentiation Variational Inference, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei

Parameters:
  • model (callable) – A NumPyro model.

  • prefix (str) – a prefix prepended to the names of all internal param sites.

  • init_loc_fn (callable) – A per-site initialization function. See the Initialization Strategies section for available functions.

get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

get_transform(params)

Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.

Parameters:

params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

Returns:

the transform of posterior distribution

Return type:

Transform

get_posterior(params)

Returns the posterior distribution.

Parameters:

params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

sample_posterior(rng_key, params, *, sample_shape=())

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

AutoBNAFNormal

class AutoBNAFNormal(model, *, prefix='auto', init_loc_fn=init_to_uniform, num_flows=1, hidden_factors=[8, 8])

Bases: AutoContinuous

This implementation of AutoContinuous uses a diagonal Normal distribution transformed via a BlockNeuralAutoregressiveTransform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[50, 50], ...)
svi = SVI(model, guide, ...)

References

  1. Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz

Parameters:
  • model (callable) – a generative model.

  • prefix (str) – a prefix prepended to the names of all internal param sites.

  • init_loc_fn (callable) – A per-site initialization function.

  • num_flows (int) – the number of flows to be used, defaults to 1.

  • hidden_factors (list) – Hidden layer i has hidden_factors[i] hidden units per input dimension. This corresponds to both \(a\) and \(b\) in reference [1]. The elements of hidden_factors must be integers.

get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

AutoDiagonalNormal

class AutoDiagonalNormal(model, *, prefix='auto', init_loc_fn=init_to_uniform, init_scale=0.1)

Bases: AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoDiagonalNormal(model, ...)
svi = SVI(model, guide, ...)

scale_constraint = SoftplusPositive(lower_bound=0.0)

get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

get_transform(params)

Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.

Parameters:

params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

Returns:

the transform of posterior distribution

Return type:

Transform

get_posterior(params)

Returns a diagonal Normal posterior distribution.

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

quantiles(params, quantiles)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

  • quantiles (list) – A list of requested quantiles between 0 and 1.

Returns:

A dict mapping sample site name to an array of quantile values.

Return type:

dict

AutoMultivariateNormal

class AutoMultivariateNormal(model, *, prefix='auto', init_loc_fn=init_to_uniform, init_scale=0.1)

Bases: AutoContinuous

This implementation of AutoContinuous uses a MultivariateNormal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoMultivariateNormal(model, ...)
svi = SVI(model, guide, ...)

scale_tril_constraint = ScaledUnitLowerCholesky()

get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

get_transform(params)

Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.

Parameters:

params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

Returns:

the transform of posterior distribution

Return type:

Transform

get_posterior(params)

Returns a multivariate Normal posterior distribution.

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

quantiles(params, quantiles)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

  • quantiles (list) – A list of requested quantiles between 0 and 1.

Returns:

A dict mapping sample site name to an array of quantile values.

Return type:

dict

AutoIAFNormal

class AutoIAFNormal(model, *, prefix='auto', init_loc_fn=init_to_uniform, num_flows=3, hidden_dims=None, skip_connections=False, nonlinearity=stax.Elu)

Bases: AutoContinuous

This implementation of AutoContinuous uses a diagonal Normal distribution transformed via an InverseAutoregressiveTransform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – a generative model.

  • prefix (str) – a prefix prepended to the names of all internal param sites.

  • init_loc_fn (callable) – A per-site initialization function.

  • num_flows (int) – the number of flows to be used, defaults to 3.

  • hidden_dims (list) – the dimensionality of the hidden units per layer. Defaults to [latent_dim, latent_dim].

  • skip_connections (bool) – whether to add skip connections from the input to the output of each flow. Defaults to False.

  • nonlinearity (callable) – the nonlinearity to use in the feedforward network. Defaults to jax.example_libraries.stax.Elu.

get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

AutoLaplaceApproximation

class AutoLaplaceApproximation(model, *, prefix='auto', init_loc_fn=init_to_uniform, create_plates=None, hessian_fn=None)

Bases: AutoContinuous

Laplace approximation (quadratic approximation) approximates the posterior \(p(z \mid x)\) by a multivariate normal distribution in the unconstrained space; equivalently, it approximates \(\log p(z \mid x)\) by a quadratic. Under the hood, it uses Delta distributions to construct a MAP (i.e. point estimate) guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of \(z\).
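Written out, and assuming \(z\) denotes the unconstrained latent vector with \(p(x, z)\) the joint density expressed in those coordinates, the approximation described above is:

```latex
\hat{z} = \arg\max_{z}\, \log p(x, z), \qquad
H = -\nabla_z^2 \log p(x, z)\,\Big|_{z = \hat{z}}, \qquad
q(z) = \mathcal{N}\!\left(z \mid \hat{z},\, H^{-1}\right)
```

Here \(H\) is the precision matrix of the approximation, which is why a hessian_fn that adds a small multiple of the identity keeps \(H\) positive definite.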

Usage:

guide = AutoLaplaceApproximation(model, ...)
svi = SVI(model, guide, ...)
Parameters:

hessian_fn (callable) – EXPERIMENTAL a function that takes a function f and a vector x and returns the Hessian of f at x. By default, we use lambda f, x: jax.hessian(f)(x). Other alternatives can be lambda f, x: jax.jacobian(jax.jacobian(f))(x) or lambda f, x: jax.hessian(f)(x) + 1e-3 * jnp.eye(x.shape[0]). The latter example is helpful when the Hessian of f at x is not positive definite. Note that the output Hessian is the precision matrix of the Laplace approximation.

get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

get_transform(params)

Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.

Parameters:

params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

Returns:

the transform of posterior distribution

Return type:

Transform

get_posterior(params)

Returns a multivariate Normal posterior distribution.

sample_posterior(rng_key, params, *, sample_shape=())

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

quantiles(params, quantiles)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

  • quantiles (list) – A list of requested quantiles between 0 and 1.

Returns:

A dict mapping sample site name to an array of quantile values.

Return type:

dict

AutoLowRankMultivariateNormal

class AutoLowRankMultivariateNormal(model, *, prefix='auto', init_loc_fn=init_to_uniform, init_scale=0.1, rank=None)

Bases: AutoContinuous

This implementation of AutoContinuous uses a LowRankMultivariateNormal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoLowRankMultivariateNormal(model, rank=2, ...)
svi = SVI(model, guide, ...)

scale_constraint = SoftplusPositive(lower_bound=0.0)

get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

get_transform(params)

Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.

Parameters:

params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

Returns:

the transform of posterior distribution

Return type:

Transform

get_posterior(params)

Returns a low-rank multivariate Normal posterior distribution.

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

quantiles(params, quantiles)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

  • quantiles (list) – A list of requested quantiles between 0 and 1.

Returns:

A dict mapping sample site name to an array of quantile values.

Return type:

dict

AutoNormal

class AutoNormal(model, *, prefix='auto', init_loc_fn=init_to_uniform, init_scale=0.1, create_plates=None)

Bases: AutoGuide

This implementation of AutoGuide uses Normal distributions to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

This should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for mean-field ELBO.

Usage:

guide = AutoNormal(model)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – A NumPyro model.

  • prefix (str) – a prefix prepended to the names of all internal param sites.

  • init_loc_fn (callable) – A per-site initialization function. See the Initialization Strategies section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

  • create_plates (callable) – An optional function that takes the same *args, **kwargs as model() and returns a numpyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.

scale_constraint = SoftplusPositive(lower_bound=0.0)

sample_posterior(rng_key, params, *, sample_shape=())

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

quantiles(params, quantiles)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles(params, [0.05, 0.5, 0.95]))
Parameters:
  • params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

  • quantiles (list) – A list of requested quantiles between 0 and 1.

Returns:

A dict mapping sample site name to an array of quantile values.

Return type:

dict

AutoDelta

class AutoDelta(model, *, prefix='auto', init_loc_fn=init_to_median, create_plates=None)

Bases: AutoGuide

This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Note

This class does MAP inference in constrained space.

Usage:

guide = AutoDelta(model)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – A NumPyro model.

  • prefix (str) – a prefix prepended to the names of all internal param sites.

  • init_loc_fn (callable) – A per-site initialization function. See the Initialization Strategies section for available functions.

  • create_plates (callable) – An optional function that takes the same *args, **kwargs as model() and returns a numpyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.

sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

median(params)

Returns the posterior median value of each latent variable.

Parameters:

params (dict) – A dict containing parameter values. The parameters can be obtained using the get_params() method of SVI.

Returns:

A dict mapping sample site name to median value.

Return type:

dict

AutoDAIS

class AutoDAIS(model, *, K=4, base_dist='diagonal', eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix='auto', init_loc_fn=init_to_uniform, init_scale=0.1)

Bases: AutoContinuous

This implementation of AutoDAIS uses Differentiable Annealed Importance Sampling (DAIS) [1, 2] to construct a guide over the entire latent space. Samples from the variational distribution (i.e. guide) are generated using a combination of (uncorrected) Hamiltonian Monte Carlo and Annealed Importance Sampling. The same algorithm is called Uncorrected Hamiltonian Annealing in [1].

Note that AutoDAIS cannot be used in conjunction with data subsampling.

Reference:

  1. MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke

  2. Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse

Usage:

guide = AutoDAIS(model)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – A NumPyro model.

  • prefix (str) – A prefix prepended to the names of all internal param sites.

  • K (int) – A positive integer that controls the number of HMC steps used. Defaults to 4.

  • base_dist (str) – Controls whether the base Normal variational distribution is parameterized by a “diagonal” covariance matrix or a full-rank covariance matrix parameterized by a lower-triangular “cholesky” factor. Defaults to “diagonal”.

  • eta_init (float) – The initial value of the step size used in HMC. Defaults to 0.01.

  • eta_max (float) – The maximum value of the learnable step size used in HMC. Defaults to 0.1.

  • gamma_init (float) – The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9.

  • init_loc_fn (callable) – A per-site initialization function. See the Initialization Strategies section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of the base variational distribution for each (unconstrained transformed) latent variable. Defaults to 0.1.

sample_posterior(rng_key, params, *, sample_shape=())

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

AutoSemiDAIS

class AutoSemiDAIS(model, local_model, global_guide=None, local_guide=None, *, prefix='auto', K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, init_scale=0.1, subsample_plate=None, use_global_dais_params=False)

Bases: AutoGuide

This implementation of AutoSemiDAIS [1] combines a parametric variational distribution over global latent variables with Differentiable Annealed Importance Sampling (DAIS) [2, 3] to infer local latent variables. Unlike AutoDAIS this guide can be used in conjunction with data subsampling. Note that the resulting ELBO can be understood as a particular realization of a ‘locally enhanced bound’ as described in reference [4].

References:

  1. Surrogate Likelihoods for Variational Annealed Importance Sampling, Martin Jankowiak, Du Phan

  2. MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke

  3. Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse

  4. Variational Inference with Locally Enhanced Bounds for Hierarchical Models, Tomas Geffner, Justin Domke

Usage:

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoNormal, AutoSemiDAIS

def global_model():
    return numpyro.sample("theta", dist.Normal(0, 1))

def local_model(theta):
    with numpyro.plate("data", 8, subsample_size=2):
        tau = numpyro.sample("tau", dist.Gamma(5.0, 5.0))
        numpyro.sample("obs", dist.Normal(0.0, tau), obs=jnp.ones(2))

model = lambda: local_model(global_model())
global_guide = AutoNormal(global_model)
guide = AutoSemiDAIS(model, local_model, global_guide, K=4)
svi = SVI(model, guide, ...)

# sample posterior for particular data subset {3, 7};
# `params` are the trained parameters, e.g. svi.get_params(svi_state)
with handlers.substitute(data={"data": jnp.array([3, 7])}):
    samples = guide.sample_posterior(random.PRNGKey(1), params)
Parameters:
  • model (callable) – A NumPyro model with global and local latent variables.

  • local_model (callable) – The portion of model that includes the local latent variables only. local_model should accept as its arguments the value returned by the global model (the part of model containing only the global latent variables).

  • global_guide (callable) – A guide for the global latent variables, e.g. an autoguide. The return type should be a dictionary of latent sample site names and corresponding samples. If there is no global variable in the model, set this to None.

  • local_guide (callable) – An optional guide for specifying the DAIS base distribution for local latent variables.

  • prefix (str) – A prefix prepended to the names of all internal sites.

  • K (int) – A positive integer that controls the number of HMC steps used. Defaults to 4.

  • eta_init (float) – The initial value of the step size used in HMC. Defaults to 0.01.

  • eta_max (float) – The maximum value of the learnable step size used in HMC. Defaults to 0.1.

  • gamma_init (float) – The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9.

  • init_scale (float) – Initial scale for the standard deviation of the variational distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1.

  • subsample_plate (str) – Optional name of the subsample plate site. This is required when the model has a subsample plate without subsample_size specified or the model has a subsample plate with subsample_size equal to the plate size.

  • use_global_dais_params (bool) – Whether parameters controlling DAIS dynamics (HMC step size, HMC mass matrix, etc.) should be global (i.e. common to all data points in the subsample plate) or local (i.e. each data point in the subsample plate has individual parameters). Note that we do not use global parameters for the base distribution.

sample_posterior(rng_key, params, *args, sample_shape=(), **kwargs)

Generate samples from the approximate posterior over the latent sites in the model.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to be used to draw samples.

  • params (dict) – Current parameters of model and autoguide. The parameters can be obtained using the get_params() method of SVI.

  • sample_shape (tuple) – sample shape of each latent site, defaults to ().

Returns:

a dict containing samples drawn from this guide.

Return type:

dict

AutoSurrogateLikelihoodDAIS

class AutoSurrogateLikelihoodDAIS(model, surrogate_model, *, K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix='auto', base_dist='diagonal', init_loc_fn=init_to_uniform, init_scale=0.1)

Bases: AutoDAIS

This implementation of AutoSurrogateLikelihoodDAIS provides a mini-batchable family of variational distributions as described in [1]. It combines a user-provided surrogate likelihood with Differentiable Annealed Importance Sampling (DAIS) [2, 3]. It is not applicable to models with local latent variables (see AutoSemiDAIS), but unlike AutoDAIS, it can be used in conjunction with data subsampling.

Reference:

  1. Surrogate Likelihoods for Variational Annealed Importance Sampling, Martin Jankowiak, Du Phan

  2. MCMC Variational Inference via Uncorrected Hamiltonian Annealing, Tomas Geffner, Justin Domke

  3. Differentiable Annealed Importance Sampling and the Perils of Gradient Noise, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse

Usage:

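import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoSurrogateLikelihoodDAIS

# logistic regression model for data {X, Y}
def model(X, Y):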
    theta = numpyro.sample(
        "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
    )
    with numpyro.plate("N", 100, subsample_size=10):
        X_batch = numpyro.subsample(X, event_dim=1)
        Y_batch = numpyro.subsample(Y, event_dim=0)
        numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

# surrogate model defined by prior and surrogate likelihood.
# a convenient choice for specifying the latter is to compute the likelihood on
# a randomly chosen data subset (here {X_surr, Y_surr} of size 20) and then use
# handlers.scale to scale the log likelihood by a vector of learnable weights.
def surrogate_model(X_surr, Y_surr):
    theta = numpyro.sample(
        "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
    )
    omegas = numpyro.param(
        "omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive
    )
    with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas):
        numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr)

guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – A NumPyro model.

  • surrogate_model (callable) – A NumPyro model that is used as a surrogate model for guiding the HMC dynamics that define the variational distribution. In particular, surrogate_model should contain the same prior as model but a cheap-to-evaluate parametric ansatz for the likelihood. A simple ansatz for the latter involves computing the likelihood for a fixed subset of the data and scaling the resulting log likelihood by a learnable vector of positive weights. See the usage example above.

  • prefix (str) – A prefix prepended to the names of all internal param sites.

  • K (int) – A positive integer that controls the number of HMC steps used. Defaults to 4.

  • base_dist (str) – Controls whether the base Normal variational distribution is parameterized by a “diagonal” covariance matrix or a full-rank covariance matrix parameterized by a lower-triangular “cholesky” factor. Defaults to “diagonal”.

  • eta_init (float) – The initial value of the step size used in HMC. Defaults to 0.01.

  • eta_max (float) – The maximum value of the learnable step size used in HMC. Defaults to 0.1.

  • gamma_init (float) – The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9.

  • init_loc_fn (callable) – A per-site initialization function. See the Initialization Strategies section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of the base variational distribution for each (unconstrained transformed) latent variable. Defaults to 0.1.