Source code for numpyro.infer.autoguide

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Adapted from pyro.infer.autoguide
from abc import ABC, abstractmethod
from contextlib import ExitStack
from functools import partial
import warnings

import numpy as np

import jax
from jax import grad, hessian, lax, random
from jax.tree_util import tree_map

from numpyro.util import _versiontuple, find_stack_level

if _versiontuple(jax.__version__) >= (0, 2, 25):
    from jax.example_libraries import stax
else:
    from jax.experimental import stax

import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.flows import (
    BlockNeuralAutoregressiveTransform,
    InverseAutoregressiveTransform,
)
from numpyro.distributions.transforms import (
    AffineTransform,
    ComposeTransform,
    IndependentTransform,
    LowerCholeskyAffine,
    PermuteTransform,
    UnpackTransform,
    biject_to,
)
from numpyro.distributions.util import (
    cholesky_of_inverse,
    periodic_repeat,
    sum_rightmost,
)
from numpyro.infer.elbo import Trace_ELBO
from numpyro.infer.initialization import init_to_median, init_to_uniform
from numpyro.infer.util import helpful_support_errors, initialize_model
from numpyro.nn.auto_reg_nn import AutoregressiveNN
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN
from numpyro.util import not_jax_tracer

__all__ = [
    "AutoContinuous",
    "AutoGuide",
    "AutoDAIS",
    "AutoDiagonalNormal",
    "AutoLaplaceApproximation",
    "AutoLowRankMultivariateNormal",
    "AutoNormal",
    "AutoMultivariateNormal",
    "AutoBNAFNormal",
    "AutoIAFNormal",
    "AutoDelta",
    "AutoSemiDAIS",
    "AutoSurrogateLikelihoodDAIS",
]


[docs]class AutoGuide(ABC): """ Base class for automatic guides. Derived classes must implement the :meth:`__call__` method. :param callable model: a pyro model :param str prefix: a prefix that will be prefixed to all param internal sites :param callable init_loc_fn: A per-site initialization function. See :ref:`init_strategy` section for available functions. :param callable create_plates: An optional function inputing the same ``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate` or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling. """ def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None ): self.model = model self.prefix = prefix self.init_loc_fn = init_loc_fn self.create_plates = create_plates self.prototype_trace = None self._prototype_frames = {} self._prototype_frame_full_sizes = {} def _create_plates(self, *args, **kwargs): if self.create_plates is None: self.plates = {} else: plates = self.create_plates(*args, **kwargs) if isinstance(plates, numpyro.plate): plates = [plates] assert all( isinstance(p, numpyro.plate) for p in plates ), "create_plates() returned a non-plate" self.plates = {p.name: p for p in plates} for name, frame in sorted(self._prototype_frames.items()): if name not in self.plates: full_size = self._prototype_frame_full_sizes[name] self.plates[name] = numpyro.plate( name, full_size, dim=frame.dim, subsample_size=frame.size ) return self.plates def __getstate__(self): state = self.__dict__.copy() state.pop("plates", None) return state @abstractmethod def __call__(self, *args, **kwargs): """ A guide with the same ``*args, **kwargs`` as the base ``model``. :return: A dict mapping sample site name to sampled value. :rtype: dict """ raise NotImplementedError
[docs] @abstractmethod def sample_posterior(self, rng_key, params, sample_shape=()): """ Generate samples from the approximate posterior over the latent sites in the model. :param jax.random.PRNGKey rng_key: random key to be used draw samples. :param dict params: Current parameters of model and autoguide. The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` method from :class:`~numpyro.infer.svi.SVI`. :param tuple sample_shape: sample shape of each latent site, defaults to (). :return: a dict containing samples drawn the this guide. :rtype: dict """ raise NotImplementedError
def _setup_prototype(self, *args, **kwargs): rng_key = numpyro.prng_key() with handlers.block(): ( init_params, self._potential_fn_gen, postprocess_fn_gen, self.prototype_trace, ) = initialize_model( rng_key, self.model, init_strategy=self.init_loc_fn, dynamic_args=True, model_args=args, model_kwargs=kwargs, ) self._potential_fn = self._potential_fn_gen(*args, **kwargs) postprocess_fn = postprocess_fn_gen(*args, **kwargs) # We apply a fixed seed just in case postprocess_fn requires # a random key to generate subsample indices. It does not matter # because we only collect deterministic sites. self._postprocess_fn = handlers.seed(postprocess_fn, rng_seed=0) self._init_locs = init_params[0] self._prototype_frames = {} self._prototype_plate_sizes = {} for name, site in self.prototype_trace.items(): if site["type"] == "sample": if not site["is_observed"] and site["fn"].support.is_discrete: # raise support errors early for discrete sites with helpful_support_errors(site): biject_to(site["fn"].support) for frame in site["cond_indep_stack"]: if frame.name in self._prototype_frames: assert ( frame == self._prototype_frames[frame.name] ), f"The plate {frame.name} has inconsistent dim or size. Please check your model again." else: self._prototype_frames[frame.name] = frame elif site["type"] == "plate": self._prototype_frame_full_sizes[name] = site["args"][0]
[docs] def median(self, params): """ Returns the posterior median value of each latent variable. :param dict params: A dict containing parameter values. The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` method from :class:`~numpyro.infer.svi.SVI`. :return: A dict mapping sample site name to median value. :rtype: dict """ raise NotImplementedError
[docs] def quantiles(self, params, quantiles): """ Returns posterior quantiles each latent variable. Example:: print(guide.quantiles(params, [0.05, 0.5, 0.95])) :param dict params: A dict containing parameter values. The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` method from :class:`~numpyro.infer.svi.SVI`. :param list quantiles: A list of requested quantiles between 0 and 1. :return: A dict mapping sample site name to an array of quantile values. :rtype: dict """ raise NotImplementedError
[docs]class AutoNormal(AutoGuide): """ This implementation of :class:`AutoGuide` uses Normal distributions to construct a guide over the entire latent space. The guide does not depend on the model's ``*args, **kwargs``. This should be equivalent to :class:`AutoDiagonalNormal` , but with more convenient site names and with better support for mean field ELBO. Usage:: guide = AutoNormal(model) svi = SVI(model, guide, ...) :param callable model: A NumPyro model. :param str prefix: a prefix that will be prefixed to all param internal sites. :param callable init_loc_fn: A per-site initialization function. See :ref:`init_strategy` section for available functions. :param float init_scale: Initial scale for the standard deviation of each (unconstrained transformed) latent variable. :param callable create_plates: An optional function inputing the same ``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate` or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling. """ scale_constraint = constraints.softplus_positive def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, init_scale=0.1, create_plates=None, ): self._init_scale = init_scale self._event_dims = {} super().__init__( model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates ) def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) for name, site in self.prototype_trace.items(): if site["type"] != "sample" or site["is_observed"]: continue event_dim = ( site["fn"].event_dim + jnp.ndim(self._init_locs[name]) - jnp.ndim(site["value"]) ) self._event_dims[name] = event_dim # If subsampling, repeat init_value to full size. for frame in site["cond_indep_stack"]: full_size = self._prototype_frame_full_sizes[frame.name] if full_size != frame.size: dim = frame.dim - event_dim self._init_locs[name] = periodic_repeat( self._init_locs[name], full_size, dim ) def __call__(self, *args, **kwargs): if self.prototype_trace is None: # run model to inspect the model structure self._setup_prototype(*args, **kwargs) plates = self._create_plates(*args, **kwargs) result = {} for name, site in self.prototype_trace.items(): if site["type"] != "sample" or site["is_observed"]: continue event_dim = self._event_dims[name] init_loc = self._init_locs[name] with ExitStack() as stack: for frame in site["cond_indep_stack"]: stack.enter_context(plates[frame.name]) site_loc = numpyro.param( "{}_{}_loc".format(name, self.prefix), init_loc, event_dim=event_dim ) site_scale = numpyro.param( "{}_{}_scale".format(name, self.prefix), jnp.full(jnp.shape(init_loc), self._init_scale), constraint=self.scale_constraint, event_dim=event_dim, ) site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim) if site["fn"].support is constraints.real or ( isinstance(site["fn"].support, constraints.independent) and site["fn"].support.base_constraint is constraints.real ): result[name] = numpyro.sample(name, site_fn) else: with helpful_support_errors(site): transform = biject_to(site["fn"].support) guide_dist = dist.TransformedDistribution(site_fn, transform) result[name] = numpyro.sample(name, guide_dist) return result def _constrain(self, latent_samples): name = list(latent_samples)[0] sample_shape = jnp.shape(latent_samples[name])[ : jnp.ndim(latent_samples[name]) - jnp.ndim(self._init_locs[name]) ] if sample_shape: flatten_samples = tree_map( lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[len(sample_shape) :]), latent_samples, ) contrained_samples = lax.map(self._postprocess_fn, flatten_samples) return tree_map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), contrained_samples, ) else: return self._postprocess_fn(latent_samples)
[docs] def sample_posterior(self, rng_key, params, sample_shape=()): locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs} scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs} with handlers.seed(rng_seed=rng_key): latent_samples = {} for k in locs: latent_samples[k] = numpyro.sample( k, dist.Normal(locs[k], scales[k]).expand_by(sample_shape) ) return self._constrain(latent_samples)
[docs] def median(self, params): locs = { k: params["{}_{}_loc".format(k, self.prefix)] for k, v in self._init_locs.items() } return self._constrain(locs)
[docs] def quantiles(self, params, quantiles): quantiles = jnp.array(quantiles) locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs} scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs} latent = { k: dist.Normal(locs[k], scales[k]).icdf( quantiles.reshape((-1,) + (1,) * jnp.ndim(locs[k])) ) for k in locs } return self._constrain(latent)
[docs]class AutoDelta(AutoGuide): """ This implementation of :class:`AutoGuide` uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model's ``*args, **kwargs``. .. note:: This class does MAP inference in constrained space. Usage:: guide = AutoDelta(model) svi = SVI(model, guide, ...) :param callable model: A NumPyro model. :param str prefix: a prefix that will be prefixed to all param internal sites. :param callable init_loc_fn: A per-site initialization function. See :ref:`init_strategy` section for available functions. :param callable create_plates: An optional function inputing the same ``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate` or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling. """ def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_median, create_plates=None ): self._event_dims = {} super().__init__( model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates ) def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) with numpyro.handlers.block(): self._init_locs = { k: v for k, v in self._postprocess_fn(self._init_locs).items() if k in self._init_locs } for name, site in self.prototype_trace.items(): if site["type"] != "sample" or site["is_observed"]: continue event_dim = site["fn"].event_dim self._event_dims[name] = event_dim # If subsampling, repeat init_value to full size. for frame in site["cond_indep_stack"]: full_size = self._prototype_frame_full_sizes[frame.name] if full_size != frame.size: dim = frame.dim - event_dim self._init_locs[name] = periodic_repeat( self._init_locs[name], full_size, dim ) def __call__(self, *args, **kwargs): if self.prototype_trace is None: # run model to inspect the model structure self._setup_prototype(*args, **kwargs) plates = self._create_plates(*args, **kwargs) result = {} for name, site in self.prototype_trace.items(): if site["type"] != "sample" or site["is_observed"]: continue event_dim = self._event_dims[name] init_loc = self._init_locs[name] with ExitStack() as stack: for frame in site["cond_indep_stack"]: stack.enter_context(plates[frame.name]) site_loc = numpyro.param( "{}_{}_loc".format(name, self.prefix), init_loc, constraint=site["fn"].support, event_dim=event_dim, ) site_fn = dist.Delta(site_loc).to_event(event_dim) result[name] = numpyro.sample(name, site_fn) return result
[docs] def sample_posterior(self, rng_key, params, sample_shape=()): locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs} latent_samples = { k: jnp.broadcast_to(v, sample_shape + jnp.shape(v)) for k, v in locs.items() } return latent_samples
[docs] def median(self, params): locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs} return locs
def _unravel_dict(x_flat, shape_dict): """Return `x` from the flatten version `x_flat`. Shape information of each item in `x` is defined in `shape_dict`. """ assert jnp.ndim(x_flat) == 1 assert isinstance(shape_dict, dict) x = {} curr_pos = next_pos = 0 for name, shape in shape_dict.items(): next_pos = curr_pos + int(np.prod(shape)) x[name] = x_flat[curr_pos:next_pos].reshape(shape) curr_pos = next_pos assert next_pos == x_flat.shape[0] return x def _ravel_dict(x): """Return the flatten version of `x` and shapes of each item in `x`.""" assert isinstance(x, dict) shape_dict = {} x_flat = [] for name, value in x.items(): shape_dict[name] = jnp.shape(value) x_flat.append(jnp.reshape(value, -1)) x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,)) return x_flat, shape_dict
[docs]class AutoContinuous(AutoGuide): """ Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1]. Each derived class implements its own :meth:`_get_posterior` method. Assumes model structure and latent dimension are fixed, and all latent variables are continuous. **Reference:** 1. *Automatic Differentiation Variational Inference*, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei :param callable model: A NumPyro model. :param str prefix: a prefix that will be prefixed to all param internal sites. :param callable init_loc_fn: A per-site initialization function. See :ref:`init_strategy` section for available functions. """ def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) self._init_latent, shape_dict = _ravel_dict(self._init_locs) unpack_latent = partial(_unravel_dict, shape_dict=shape_dict) # this is to match the behavior of Pyro, where we can apply # unpack_latent for a batch of samples self._unpack_latent = UnpackTransform(unpack_latent) self.latent_dim = jnp.size(self._init_latent) if self.latent_dim == 0: raise RuntimeError( "{} found no latent variables; Use an empty guide instead".format( type(self).__name__ ) ) for site in self.prototype_trace.values(): if site["type"] == "sample" and not site["is_observed"]: for frame in site["cond_indep_stack"]: if frame.size != self._prototype_frame_full_sizes[frame.name]: raise ValueError( "AutoContinuous guide does not support" " local latent variables." ) @abstractmethod def _get_posterior(self): raise NotImplementedError def _sample_latent(self, *args, **kwargs): sample_shape = kwargs.pop("sample_shape", ()) posterior = self._get_posterior() return numpyro.sample( "_{}_latent".format(self.prefix), posterior.expand_by(sample_shape), infer={"is_auxiliary": True}, ) def __call__(self, *args, **kwargs): if self.prototype_trace is None: # run model to inspect the model structure self._setup_prototype(*args, **kwargs) latent = self._sample_latent(*args, **kwargs) # unpack continuous latent samples result = {} for name, unconstrained_value in self._unpack_latent(latent).items(): site = self.prototype_trace[name] with helpful_support_errors(site): transform = biject_to(site["fn"].support) value = transform(unconstrained_value) event_ndim = site["fn"].event_dim if numpyro.get_mask() is False: log_density = 0.0 else: log_density = -transform.log_abs_det_jacobian( unconstrained_value, value ) log_density = sum_rightmost( log_density, jnp.ndim(log_density) - jnp.ndim(value) + event_ndim ) delta_dist = dist.Delta( value, log_density=log_density, event_dim=event_ndim ) result[name] = numpyro.sample(name, delta_dist) return result def _unpack_and_constrain(self, latent_sample, params): def unpack_single_latent(latent): unpacked_samples = self._unpack_latent(latent) # XXX: we need to add param here to be able to replay model unpacked_samples.update( { k: v for k, v in params.items() if k in self.prototype_trace and self.prototype_trace[k]["type"] == "param" } ) samples = self._postprocess_fn(unpacked_samples) # filter out param sites return { k: v for k, v in samples.items() if k in self.prototype_trace and self.prototype_trace[k]["type"] != "param" } sample_shape = jnp.shape(latent_sample)[:-1] if sample_shape: latent_sample = jnp.reshape( latent_sample, (-1, jnp.shape(latent_sample)[-1]) ) unpacked_samples = lax.map(unpack_single_latent, latent_sample) return tree_map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), unpacked_samples, ) else: return unpack_single_latent(latent_sample)
[docs] def get_base_dist(self): """ Returns the base distribution of the posterior when reparameterized as a :class:`~numpyro.distributions.distribution.TransformedDistribution`. This should not depend on the model's `*args, **kwargs`. """ raise NotImplementedError
[docs] def get_transform(self, params): """ Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior. :param dict params: Current parameters of model and autoguide. The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` method from :class:`~numpyro.infer.svi.SVI`. :return: the transform of posterior distribution :rtype: :class:`~numpyro.distributions.transforms.Transform` """ posterior = handlers.substitute(self._get_posterior, params)() assert isinstance( posterior, dist.TransformedDistribution ), "posterior is not a transformed distribution" if len(posterior.transforms) > 0: return ComposeTransform(posterior.transforms) else: return posterior.transforms[0]
[docs] def get_posterior(self, params): """ Returns the posterior distribution. :param dict params: Current parameters of model and autoguide. The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` method from :class:`~numpyro.infer.svi.SVI`. """ base_dist = self.get_base_dist() transform = self.get_transform(params) return dist.TransformedDistribution(base_dist, transform)
[docs] def sample_posterior(self, rng_key, params, sample_shape=()): latent_sample = handlers.substitute( handlers.seed(self._sample_latent, rng_key), params )(sample_shape=sample_shape) return self._unpack_and_constrain(latent_sample, params)
[docs]class AutoDAIS(AutoContinuous): """ This implementation of :class:`AutoDAIS` uses Differentiable Annealed Importance Sampling (DAIS) [1, 2] to construct a guide over the entire latent space. Samples from the variational distribution (i.e. guide) are generated using a combination of (uncorrected) Hamiltonian Monte Carlo and Annealed Importance Sampling. The same algorithm is called Uncorrected Hamiltonian Annealing in [1]. Note that AutoDAIS cannot be used in conjunction with data subsampling. **Reference:** 1. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*, Tomas Geffner, Justin Domke 2. *Differentiable Annealed Importance Sampling and the Perils of Gradient Noise*, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse Usage:: guide = AutoDAIS(model) svi = SVI(model, guide, ...) :param callable model: A NumPyro model. :param str prefix: A prefix that will be prefixed to all param internal sites. :param int K: A positive integer that controls the number of HMC steps used. Defaults to 4. :param str base_dist: Controls whether the base Normal variational distribution is parameterized by a "diagonal" covariance matrix or a full-rank covariance matrix parameterized by a lower-triangular "cholesky" factor. Defaults to "diagonal". :param float eta_init: The initial value of the step size used in HMC. Defaults to 0.01. :param float eta_max: The maximum value of the learnable step size used in HMC. Defaults to 0.1. :param float gamma_init: The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9. :param callable init_loc_fn: A per-site initialization function. See :ref:`init_strategy` section for available functions. :param float init_scale: Initial scale for the standard deviation of the base variational distribution for each (unconstrained transformed) latent variable. Defaults to 0.1. """ def __init__( self, model, *, K=4, base_dist="diagonal", eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix="auto", init_loc_fn=init_to_uniform, init_scale=0.1, ): if K < 1: raise ValueError("K must satisfy K >= 1 (got K = {})".format(K)) if base_dist not in ["diagonal", "cholesky"]: raise ValueError('base_dist must be one of "diagonal" or "cholesky".') if eta_init <= 0.0 or eta_init >= eta_max: raise ValueError( "eta_init must be positive and satisfy eta_init < eta_max." ) if eta_max <= 0.0: raise ValueError("eta_max must be positive.") if gamma_init <= 0.0 or gamma_init >= 1.0: raise ValueError("gamma_init must be in the open interval (0, 1).") if init_scale <= 0.0: raise ValueError("init_scale must be positive.") self.eta_init = eta_init self.eta_max = eta_max self.gamma_init = gamma_init self.K = K self.base_dist = base_dist self._init_scale = init_scale super().__init__(model, prefix=prefix, init_loc_fn=init_loc_fn) def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) for name, site in self.prototype_trace.items(): if ( site["type"] == "plate" and isinstance(site["args"][1], int) and site["args"][0] > site["args"][1] ): raise NotImplementedError( "AutoDAIS cannot be used in conjunction with data subsampling." ) def _get_posterior(self): raise NotImplementedError def _sample_latent(self, *args, **kwargs): def log_density(x): x_unpack = self._unpack_latent(x) with numpyro.handlers.block(): return -self._potential_fn(x_unpack) eta0 = numpyro.param( "{}_eta0".format(self.prefix), self.eta_init, constraint=constraints.interval(0, self.eta_max), ) eta_coeff = numpyro.param("{}_eta_coeff".format(self.prefix), 0.00) gamma = numpyro.param( "{}_gamma".format(self.prefix), self.gamma_init, constraint=constraints.interval(0, 1), ) betas = numpyro.param( "{}_beta_increments".format(self.prefix), jnp.ones(self.K), constraint=constraints.positive, ) betas = jnp.cumsum(betas) betas = betas / betas[-1] # K-dimensional with betas[-1] = 1 mass_matrix = numpyro.param( "{}_mass_matrix".format(self.prefix), jnp.ones(self.latent_dim), constraint=constraints.positive, ) inv_mass_matrix = 0.5 / mass_matrix init_z_loc = numpyro.param( "{}_z_0_loc".format(self.prefix), self._init_latent, ) if self.base_dist == "diagonal": init_z_scale = numpyro.param( "{}_z_0_scale".format(self.prefix), jnp.full(self.latent_dim, self._init_scale), constraint=constraints.positive, ) base_z_dist = dist.Normal(init_z_loc, init_z_scale).to_event() elif self.base_dist == "cholesky": scale_tril = numpyro.param( "{}_z_0_scale_tril".format(self.prefix), jnp.identity(self.latent_dim) * self._init_scale, constraint=constraints.scaled_unit_lower_cholesky, ) base_z_dist = dist.MultivariateNormal(init_z_loc, scale_tril=scale_tril) z_0 = numpyro.sample( "{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True}, ) momentum_dist = dist.Normal(0, mass_matrix).to_event() eps = numpyro.sample( "{}_momentum".format(self.prefix), momentum_dist.expand((self.K,)).to_event().mask(False), infer={"is_auxiliary": True}, ) def scan_body(carry, eps_beta): eps, beta = eps_beta eta = eta0 + eta_coeff * beta eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max) z_prev, v_prev, log_factor = carry z_half = z_prev + v_prev * eta * inv_mass_matrix q_grad = (1.0 - beta) * grad(base_z_dist.log_prob)(z_half) p_grad = beta * grad(log_density)(z_half) v_hat = v_prev + eta * (q_grad + p_grad) z = z_half + v_hat * eta * inv_mass_matrix v = gamma * v_hat + jnp.sqrt(1 - gamma**2) * eps delta_ke = momentum_dist.log_prob(v_prev) - momentum_dist.log_prob(v_hat) log_factor = log_factor + delta_ke return (z, v, log_factor), None v_0 = eps[-1] # note the return value of scan doesn't depend on eps[-1] (z, _, log_factor), _ = jax.lax.scan(scan_body, (z_0, v_0, 0.0), (eps, betas)) numpyro.factor("{}_factor".format(self.prefix), log_factor) return z
[docs] def sample_posterior(self, rng_key, params, sample_shape=()): def _single_sample(_rng_key): latent_sample = handlers.substitute( handlers.seed(self._sample_latent, _rng_key), params )(sample_shape=()) return self._unpack_and_constrain(latent_sample, params) if sample_shape: rng_key = random.split(rng_key, int(np.prod(sample_shape))) samples = lax.map(_single_sample, rng_key) return tree_map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), samples, ) else: return _single_sample(rng_key)
[docs]class AutoSurrogateLikelihoodDAIS(AutoDAIS): """ This implementation of :class:`AutoSurrogateLikelihoodDAIS` provides a mini-batchable family of variational distributions as described in [1]. It combines a user-provided surrogate likelihood with Differentiable Annealed Importance Sampling (DAIS) [2, 3]. It is not applicable to models with local latent variables (see :class:`AutoSemiDAIS`), but unlike :class:`AutoDAIS`, it *can* be used in conjunction with data subsampling. **Reference:** 1. *Surrogate likelihoods for variational annealed importance sampling*, Martin Jankowiak, Du Phan 2. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*, Tomas Geffner, Justin Domke 3. *Differentiable Annealed Importance Sampling and the Perils of Gradient Noise*, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse Usage:: # logistic regression model for data {X, Y} def model(X, Y): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) ) with numpyro.plate("N", 100, subsample_size=10): X_batch = numpyro.subsample(X, event_dim=1) Y_batch = numpyro.subsample(Y, event_dim=0) numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch) # surrogate model defined by prior and surrogate likelihood. # a convenient choice for specifying the latter is to compute the likelihood on # a randomly chosen data subset (here {X_surr, Y_surr} of size 20) and then use # handlers.scale to scale the log likelihood by a vector of learnable weights. def surrogate_model(X_surr, Y_surr): theta = numpyro.sample( "theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1) ) omegas = numpyro.param( "omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive ) with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas): numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr) guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model) svi = SVI(model, guide, ...) :param callable model: A NumPyro model. :param callable surrogate_model: A NumPyro model that is used as a surrogate model for guiding the HMC dynamics that define the variational distribution. In particular `surrogate_model` should contain the same prior as `model` but should contain a cheap-to-evaluate parametric ansatz for the likelihood. A simple ansatz for the latter involves computing the likelihood for a fixed subset of the data and scaling the resulting log likelihood by a learnable vector of positive weights. See the usage example above. :param str prefix: A prefix that will be prefixed to all param internal sites. :param int K: A positive integer that controls the number of HMC steps used. Defaults to 4. :param str base_dist: Controls whether the base Normal variational distribution is parameterized by a "diagonal" covariance matrix or a full-rank covariance matrix parameterized by a lower-triangular "cholesky" factor. Defaults to "diagonal". :param float eta_init: The initial value of the step size used in HMC. Defaults to 0.01. :param float eta_max: The maximum value of the learnable step size used in HMC. Defaults to 0.1. :param float gamma_init: The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9. :param callable init_loc_fn: A per-site initialization function. See :ref:`init_strategy` section for available functions. :param float init_scale: Initial scale for the standard deviation of the base variational distribution for each (unconstrained transformed) latent variable. Defaults to 0.1. """ def __init__( self, model, surrogate_model, *, K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, prefix="auto", base_dist="diagonal", init_loc_fn=init_to_uniform, init_scale=0.1, ): super().__init__( model, K=K, eta_init=eta_init, eta_max=eta_max, gamma_init=gamma_init, prefix=prefix, init_loc_fn=init_loc_fn, init_scale=init_scale, base_dist=base_dist, ) self.surrogate_model = surrogate_model def _setup_prototype(self, *args, **kwargs): AutoContinuous._setup_prototype(self, *args, **kwargs) rng_key = numpyro.prng_key() with numpyro.handlers.block(): (_, self._surrogate_potential_fn, _, _) = initialize_model( rng_key, self.surrogate_model, init_strategy=self.init_loc_fn, dynamic_args=False, model_args=(), model_kwargs={}, ) def _sample_latent(self, *args, **kwargs): def blocked_surrogate_model(x): x_unpack = self._unpack_latent(x) with numpyro.handlers.block(expose_types=["param"]): return -self._surrogate_potential_fn(x_unpack) eta0 = numpyro.param( "{}_eta0".format(self.prefix), self.eta_init, constraint=constraints.interval(0, self.eta_max), ) eta_coeff = numpyro.param("{}_eta_coeff".format(self.prefix), 0.0) gamma = numpyro.param( "{}_gamma".format(self.prefix), self.gamma_init, constraint=constraints.interval(0, 1), ) betas = numpyro.param( "{}_beta_increments".format(self.prefix), jnp.ones(self.K), constraint=constraints.positive, ) betas = jnp.cumsum(betas) betas = betas / betas[-1] # K-dimensional with betas[-1] = 1 mass_matrix = numpyro.param( "{}_mass_matrix".format(self.prefix), jnp.ones(self.latent_dim), constraint=constraints.positive, ) inv_mass_matrix = 0.5 / mass_matrix init_z_loc = numpyro.param("{}_z_0_loc".format(self.prefix), self._init_latent) if self.base_dist == "diagonal": init_z_scale = numpyro.param( "{}_z_0_scale".format(self.prefix), jnp.full(self.latent_dim, self._init_scale), constraint=constraints.positive, ) base_z_dist = dist.Normal(init_z_loc, init_z_scale).to_event() else: scale_tril = numpyro.param( "{}_scale_tril".format(self.prefix), jnp.identity(self.latent_dim) * self._init_scale, constraint=constraints.scaled_unit_lower_cholesky, ) base_z_dist = dist.MultivariateNormal(init_z_loc, scale_tril=scale_tril) z_0 = numpyro.sample( "{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True} ) base_z_dist_log_prob = base_z_dist.log_prob momentum_dist = dist.Normal(0, mass_matrix).to_event() eps = numpyro.sample( "{}_momentum".format(self.prefix), momentum_dist.expand((self.K,)).to_event().mask(False), infer={"is_auxiliary": True}, ) def scan_body(carry, eps_beta): eps, beta = eps_beta eta = eta0 + eta_coeff * beta eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max) z_prev, v_prev, log_factor = carry z_half = z_prev + v_prev * eta * inv_mass_matrix q_grad = (1.0 - beta) * grad(base_z_dist_log_prob)(z_half) p_grad = beta * grad(blocked_surrogate_model)(z_half) v_hat = v_prev + eta * (q_grad + p_grad) z = z_half + v_hat * eta * inv_mass_matrix v = gamma * v_hat + jnp.sqrt(1 - gamma**2) * eps delta_ke = momentum_dist.log_prob(v_prev) - momentum_dist.log_prob(v_hat) log_factor = log_factor + delta_ke return (z, v, log_factor), None v_0 = eps[-1] # note the return value of scan doesn't depend on eps[-1] (z, _, log_factor), _ = jax.lax.scan(scan_body, (z_0, v_0, 0.0), (eps, betas)) numpyro.factor("{}_factor".format(self.prefix), log_factor) return z
def _subsample_model(model, *args, **kwargs): data = kwargs.pop("_subsample_idx", {}) with handlers.substitute(data=data): return model(*args, **kwargs)
[docs]class AutoSemiDAIS(AutoGuide): r""" This implementation of :class:`AutoSemiDAIS` [1] combines a parametric variational distribution over global latent variables with Differentiable Annealed Importance Sampling (DAIS) [2, 3] to infer local latent variables. Unlike :class:`AutoDAIS` this guide can be used in conjunction with data subsampling. Note that the resulting ELBO can be understood as a particular realization of a 'locally enhanced bound' as described in reference [4]. **References:** 1. *Surrogate Likelihoods for Variational Annealed Importance Sampling*, Martin Jankowiak, Du Phan 2. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*, Tomas Geffner, Justin Domke 3. *Differentiable Annealed Importance Sampling and the Perils of Gradient Noise*, Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse 4. *Variational Inference with Locally Enhanced Bounds for Hierarchical Models*, Tomas Geffner, Justin Domke Usage:: def global_model(): return numpyro.sample("theta", dist.Normal(0, 1)) def local_model(theta): with numpyro.plate("data", 8, subsample_size=2): tau = numpyro.sample("tau", dist.Gamma(5.0, 5.0)) numpyro.sample("obs", dist.Normal(0.0, tau), obs=jnp.ones(2)) model = lambda: local_model(global_model()) base_guide = AutoNormal(global_model) guide = AutoSemiDAIS(model, local_model, base_guide, K=4) svi = SVI(model, guide, ...) # sample posterior for particular data subset {3, 7} with handlers.substitute(data={"data": jnp.array([3, 7])}): samples = guide.sample_posterior(random.PRNGKey(1), params) :param callable model: A NumPyro model with global and local latent variables. :param callable local_model: The portion of `model` that includes the local latent variables only. The signature of `local_model` should be the return type of the global model with global latent variables only. :param callable base_guide: A guide for the global latent variables, e.g. an autoguide. The return type should be a dictionary of latent sample sites names and corresponding samples. :param str prefix: A prefix that will be prefixed to all internal sites. :param int K: A positive integer that controls the number of HMC steps used. Defaults to 4. :param float eta_init: The initial value of the step size used in HMC. Defaults to 0.01. :param float eta_max: The maximum value of the learnable step size used in HMC. Defaults to 0.1. :param float gamma_init: The initial value of the learnable damping factor used during partial momentum refreshments in HMC. Defaults to 0.9. :param float init_scale: Initial scale for the standard deviation of the variational distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1. """ def __init__( self, model, local_model, base_guide, *, prefix="auto", K=4, eta_init=0.01, eta_max=0.1, gamma_init=0.9, init_scale=0.1, ): # init_loc_fn is only used to inspect the model. super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform) if K < 1: raise ValueError("K must satisfy K >= 1 (got K = {})".format(K)) if eta_init <= 0.0 or eta_init >= eta_max: raise ValueError( "eta_init must be positive and satisfy eta_init < eta_max." ) if eta_max <= 0.0: raise ValueError("eta_max must be positive.") if gamma_init <= 0.0 or gamma_init >= 1.0: raise ValueError("gamma_init must be in the open interval (0, 1).") if init_scale <= 0.0: raise ValueError("init_scale must be positive.") self.local_model = local_model self.base_guide = base_guide self.eta_init = eta_init self.eta_max = eta_max self.gamma_init = gamma_init self.K = K self.init_scale = init_scale def _setup_prototype(self, *args, **kwargs): super()._setup_prototype(*args, **kwargs) # extract global/local/local_dim/plates subsample_plates = { name: site for name, site in self.prototype_trace.items() if site["type"] == "plate" and isinstance(site["args"][1], int) and site["args"][0] > site["args"][1] } num_plates = len(subsample_plates) assert ( num_plates == 1 ), f"AutoSemiDAIS assumes that the model contains exactly 1 plate with data subsampling but got {num_plates}." plate_name = list(subsample_plates.keys())[0] local_vars = [] subsample_axes = {} plate_dim = None for name, site in self.prototype_trace.items(): if site["type"] == "sample" and not site["is_observed"]: for frame in site["cond_indep_stack"]: if frame.name == plate_name: if plate_dim is None: plate_dim = frame.dim local_vars.append(name) subsample_axes[name] = plate_dim - site["fn"].event_dim break if len(local_vars) == 0: raise RuntimeError( "There are no local variables in the `{plate_name}` plate." " AutoSemiDAIS is appropriate for models with local variables." ) local_init_locs = { name: value for name, value in self._init_locs.items() if name in local_vars } one_sample = { k: jnp.take(v, 0, axis=subsample_axes[k]) for k, v in local_init_locs.items() } _, shape_dict = _ravel_dict(one_sample) local_init_latent = jax.vmap( lambda x: _ravel_dict(x)[0], in_axes=(subsample_axes,) )(local_init_locs) unpack_latent = partial(_unravel_dict, shape_dict=shape_dict) # this is to match the behavior of Pyro, where we can apply # unpack_latent for a batch of samples self._unpack_local_latent = jax.vmap( UnpackTransform(unpack_latent), out_axes=subsample_axes ) plate_full_size, plate_subsample_size = subsample_plates[plate_name]["args"] self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size self._local_plate = (plate_name, plate_full_size, plate_subsample_size) rng_key = numpyro.prng_key() with handlers.block(), handlers.seed(rng_seed=rng_key): global_output = self.base_guide.model(*args, **kwargs) ( _, self._local_potential_fn_gen, self._local_postprecess_fn, _, ) = initialize_model( numpyro.prng_key(), partial(_subsample_model, self.local_model), init_strategy=self.init_loc_fn, dynamic_args=True, model_args=(global_output,), model_kwargs={ "_subsample_idx": { plate_name: subsample_plates[plate_name]["value"] } }, ) def __call__(self, *args, **kwargs): if self.prototype_trace is None: # run model to inspect the model structure self._setup_prototype(*args, **kwargs) global_latents, local_latent_flat = self._sample_latent(*args, **kwargs) # unpack continuous latent samples result = global_latents.copy() _, N, subsample_size = self._local_plate for name, unconstrained_value in self._unpack_local_latent( local_latent_flat ).items(): site = self.prototype_trace[name] with helpful_support_errors(site): transform = biject_to(site["fn"].support) value = transform(unconstrained_value) event_ndim = site["fn"].event_dim if numpyro.get_mask() is False: log_density = 0.0 else: log_density = -transform.log_abs_det_jacobian( unconstrained_value, value ) log_density = (N / subsample_size) * sum_rightmost( log_density, jnp.ndim(log_density) - jnp.ndim(value) + event_ndim ) delta_dist = dist.Delta( value, log_density=log_density, event_dim=event_ndim ) result[name] = numpyro.sample(name, delta_dist) return result def _get_posterior(self): raise NotImplementedError def _sample_latent(self, *args, **kwargs): kwargs.pop("sample_shape", ()) def make_local_log_density(*local_args, **local_kwargs): def fn(x): x_unpack = self._unpack_local_latent(x) with numpyro.handlers.block(): return -self._local_potential_fn_gen(*local_args, **local_kwargs)( x_unpack ) return fn global_latents = self.base_guide(*args, **kwargs) rng_key = numpyro.prng_key() with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute( data=global_latents ): global_output = self.base_guide.model(*args, **kwargs) plate_name, N, subsample_size = self._local_plate D, K = self._local_latent_dim, self.K with numpyro.plate(plate_name, N, subsample_size=subsample_size) as idx: eta0 = numpyro.param( "{}_eta0".format(self.prefix), jnp.ones(N) * self.eta_init, constraint=constraints.interval(0, self.eta_max), event_dim=0, ) eta_coeff = numpyro.param( "{}_eta_coeff".format(self.prefix), jnp.zeros(N), event_dim=0 ) gamma = numpyro.param( "{}_gamma".format(self.prefix), jnp.ones(N) * 0.9, constraint=constraints.interval(0, 1), event_dim=0, ) betas = numpyro.param( "{}_beta_increments".format(self.prefix), jnp.ones((N, K)), constraint=constraints.positive, event_dim=1, ) betas = jnp.cumsum(betas, axis=-1) betas = betas / betas[..., -1:] mass_matrix = numpyro.param( "{}_mass_matrix".format(self.prefix), jnp.ones((N, D)), constraint=constraints.positive, event_dim=1, ) inv_mass_matrix = 0.5 / mass_matrix assert inv_mass_matrix.shape == (subsample_size, D) z_0_loc_init = jnp.zeros((N, D)) z_0_loc = numpyro.param( "{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1 ) z_0_scale_init = jnp.ones((N, D)) * self.init_scale z_0_scale = numpyro.param( "{}_z_0_scale".format(self.prefix), z_0_scale_init, constraint=constraints.positive, event_dim=1, ) base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1) assert base_z_dist.shape() == (subsample_size, D) z_0 = numpyro.sample( "{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True} ) def base_z_dist_log_prob(x): return base_z_dist.log_prob(x).sum() momentum_dist = dist.Normal(0, mass_matrix).to_event(1) eps = numpyro.sample( "{}_momentum".format(self.prefix), dist.Normal(0, mass_matrix[..., None]) .expand([subsample_size, D, K]) .to_event(2) .mask(False), infer={"is_auxiliary": True}, ) local_log_density = make_local_log_density( global_output, _subsample_idx={plate_name: idx} ) def scan_body(carry, eps_beta): eps, beta = eps_beta eta = eta0 + eta_coeff * beta eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max) assert eps.shape == (subsample_size, D) assert eta.shape == beta.shape == (subsample_size,) z_prev, v_prev, log_factor = carry z_half = z_prev + v_prev * eta[:, None] * inv_mass_matrix q_grad = (1.0 - beta[:, None]) * grad(base_z_dist_log_prob)(z_half) p_grad = ( beta[:, None] * (subsample_size / N) * grad(local_log_density)(z_half) ) assert q_grad.shape == p_grad.shape == (subsample_size, D) v_hat = v_prev + eta[:, None] * (q_grad + p_grad) z = z_half + v_hat * eta[:, None] * inv_mass_matrix v = gamma[:, None] * v_hat + jnp.sqrt(1 - gamma[:, None] ** 2) * eps delta_ke = momentum_dist.log_prob(v_prev) - momentum_dist.log_prob( v_hat ) assert delta_ke.shape == (subsample_size,) log_factor = log_factor + delta_ke return (z, v, log_factor), None v_0 = eps[ :, :, -1 ] # note the return value of scan doesn't depend on eps[:, :, -1] assert eps.shape == (subsample_size, D, K) assert betas.shape == (subsample_size, K) eps_T = jnp.moveaxis(eps, -1, 0) (z, _, log_factor), _ = jax.lax.scan( scan_body, (z_0, v_0, jnp.zeros(subsample_size)), (eps_T, betas.T) ) assert log_factor.shape == (subsample_size,) numpyro.factor("{}_local_dais_factor".format(self.prefix), log_factor) return global_latents, z
[docs] def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): def _single_sample(_rng_key): global_latents, local_flat = handlers.substitute( handlers.seed(self._sample_latent, _rng_key), params )(*args, **kwargs) results = global_latents.copy() for name, unconstrained_value in self._unpack_local_latent( local_flat ).items(): site = self.prototype_trace[name] transform = biject_to(site["fn"].support) value = transform(unconstrained_value) results[name] = value return results if sample_shape: rng_key = random.split(rng_key, int(np.prod(sample_shape))) samples = lax.map(_single_sample, rng_key) return tree_map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), samples, ) else: return _single_sample(rng_key)
[docs]class AutoDiagonalNormal(AutoContinuous): """ This implementation of :class:`AutoContinuous` uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model's ``*args, **kwargs``. Usage:: guide = AutoDiagonalNormal(model, ...) svi = SVI(model, guide, ...) """ scale_constraint = constraints.softplus_positive def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, init_scale=0.1, ): if init_scale <= 0: raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) self._init_scale = init_scale super().__init__(model, prefix=prefix, init_loc_fn=init_loc_fn) def _get_posterior(self): loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent) scale = numpyro.param( "{}_scale".format(self.prefix), jnp.full(self.latent_dim, self._init_scale), constraint=self.scale_constraint, ) return dist.Normal(loc, scale)
[docs] def get_base_dist(self): return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1)
[docs] def get_transform(self, params): loc = params["{}_loc".format(self.prefix)] scale = params["{}_scale".format(self.prefix)] return IndependentTransform(AffineTransform(loc, scale), 1)
[docs] def get_posterior(self, params): """ Returns a diagonal Normal posterior distribution. """ transform = self.get_transform(params).base_transform return dist.Normal(transform.loc, transform.scale)
[docs] def median(self, params): loc = params["{}_loc".format(self.prefix)] return self._unpack_and_constrain(loc, params)
[docs] def quantiles(self, params, quantiles): quantiles = jnp.array(quantiles)[..., None] latent = self.get_posterior(params).icdf(quantiles) return self._unpack_and_constrain(latent, params)
[docs]class AutoMultivariateNormal(AutoContinuous): """ This implementation of :class:`AutoContinuous` uses a MultivariateNormal distribution to construct a guide over the entire latent space. The guide does not depend on the model's ``*args, **kwargs``. Usage:: guide = AutoMultivariateNormal(model, ...) svi = SVI(model, guide, ...) """ scale_tril_constraint = constraints.scaled_unit_lower_cholesky def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, init_scale=0.1, ): if init_scale <= 0: raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) self._init_scale = init_scale super().__init__(model, prefix=prefix, init_loc_fn=init_loc_fn) def _get_posterior(self): loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent) scale_tril = numpyro.param( "{}_scale_tril".format(self.prefix), jnp.identity(self.latent_dim) * self._init_scale, constraint=self.scale_tril_constraint, ) return dist.MultivariateNormal(loc, scale_tril=scale_tril)
[docs] def get_base_dist(self): return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1)
[docs] def get_transform(self, params): loc = params["{}_loc".format(self.prefix)] scale_tril = params["{}_scale_tril".format(self.prefix)] return LowerCholeskyAffine(loc, scale_tril)
[docs] def get_posterior(self, params): """ Returns a multivariate Normal posterior distribution. """ transform = self.get_transform(params) return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril)
[docs] def median(self, params): loc = params["{}_loc".format(self.prefix)] return self._unpack_and_constrain(loc, params)
[docs] def quantiles(self, params, quantiles): transform = self.get_transform(params) quantiles = jnp.array(quantiles)[..., None] latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf( quantiles ) return self._unpack_and_constrain(latent, params)
[docs]class AutoLowRankMultivariateNormal(AutoContinuous): """ This implementation of :class:`AutoContinuous` uses a LowRankMultivariateNormal distribution to construct a guide over the entire latent space. The guide does not depend on the model's ``*args, **kwargs``. Usage:: guide = AutoLowRankMultivariateNormal(model, rank=2, ...) svi = SVI(model, guide, ...) """ scale_constraint = constraints.softplus_positive def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, init_scale=0.1, rank=None, ): if init_scale <= 0: raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) self._init_scale = init_scale self.rank = rank super(AutoLowRankMultivariateNormal, self).__init__( model, prefix=prefix, init_loc_fn=init_loc_fn ) def _get_posterior(self, *args, **kwargs): rank = int(round(self.latent_dim**0.5)) if self.rank is None else self.rank loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent) cov_factor = numpyro.param( "{}_cov_factor".format(self.prefix), jnp.zeros((self.latent_dim, rank)) ) scale = numpyro.param( "{}_scale".format(self.prefix), jnp.full(self.latent_dim, self._init_scale), constraint=self.scale_constraint, ) cov_diag = scale * scale cov_factor = cov_factor * scale[..., None] return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
[docs] def get_base_dist(self): return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1)
[docs] def get_transform(self, params): posterior = self.get_posterior(params) return LowerCholeskyAffine(posterior.loc, posterior.scale_tril)
[docs] def get_posterior(self, params): """ Returns a lowrank multivariate Normal posterior distribution. """ loc = params["{}_loc".format(self.prefix)] cov_factor = params["{}_cov_factor".format(self.prefix)] scale = params["{}_scale".format(self.prefix)] cov_diag = scale * scale cov_factor = cov_factor * scale[..., None] return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
[docs] def median(self, params): loc = params["{}_loc".format(self.prefix)] return self._unpack_and_constrain(loc, params)
[docs] def quantiles(self, params, quantiles): loc = params[f"{self.prefix}_loc"] cov_factor = params[f"{self.prefix}_cov_factor"] scale = params[f"{self.prefix}_scale"] scale = scale * jnp.sqrt(jnp.square(cov_factor).sum(-1) + 1) quantiles = jnp.array(quantiles)[..., None] latent = dist.Normal(loc, scale).icdf(quantiles) return self._unpack_and_constrain(latent, params)
[docs]class AutoLaplaceApproximation(AutoContinuous): r""" Laplace approximation (quadratic approximation) approximates the posterior :math:`\log p(z | x)` by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the hessian of :math:`-\log p(x, z)` at the MAP point of `z`. Usage:: guide = AutoLaplaceApproximation(model, ...) svi = SVI(model, guide, ...) :param callable hessian_fn: EXPERIMENTAL a function that takes a function `f` and a vector `x`and returns the hessian of `f` at `x`. By default, we use ``lambda f, x: jax.hessian(f)(x)``. Other alternatives can be ``lambda f, x: jax.jacobian(jax.jacobian(f))(x)`` or ``lambda f, x: jax.hessian(f)(x) + 1e-3 * jnp.eye(x.shape[0])``. The later example is helpful when the hessian of `f` at `x` is not positive definite. Note that the output hessian is the precision matrix of the laplace approximation. """ def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None, hessian_fn=None, ): super().__init__( model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates ) self._hessian_fn = ( hessian_fn if hessian_fn is not None else (lambda f, x: hessian(f)(x)) ) def _setup_prototype(self, *args, **kwargs): super(AutoLaplaceApproximation, self)._setup_prototype(*args, **kwargs) def loss_fn(params): # we are doing maximum likelihood, so only require `num_particles=1` and an arbitrary rng_key. return Trace_ELBO().loss( random.PRNGKey(0), params, self.model, self, *args, **kwargs ) self._loss_fn = loss_fn def _get_posterior(self, *args, **kwargs): # sample from Delta guide loc = numpyro.param("{}_loc".format(self.prefix), self._init_latent) return dist.Delta(loc, event_dim=1)
[docs] def get_base_dist(self): return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1)
[docs] def get_transform(self, params): def loss_fn(z): params1 = params.copy() params1["{}_loc".format(self.prefix)] = z return self._loss_fn(params1) loc = params["{}_loc".format(self.prefix)] precision = self._hessian_fn(loss_fn, loc) scale_tril = cholesky_of_inverse(precision) if not_jax_tracer(scale_tril): if np.any(np.isnan(scale_tril)): warnings.warn( "Hessian of log posterior at the MAP point is singular. Posterior" " samples from AutoLaplaceApproxmiation will be constant (equal to" " the MAP point). Please consider using an AutoNormal guide.", stacklevel=find_stack_level(), ) scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril) return LowerCholeskyAffine(loc, scale_tril)
[docs] def get_posterior(self, params): """ Returns a multivariate Normal posterior distribution. """ transform = self.get_transform(params) return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril)
[docs] def sample_posterior(self, rng_key, params, sample_shape=()): latent_sample = self.get_posterior(params).sample(rng_key, sample_shape) return self._unpack_and_constrain(latent_sample, params)
[docs] def median(self, params): loc = params["{}_loc".format(self.prefix)] return self._unpack_and_constrain(loc, params)
[docs] def quantiles(self, params, quantiles): transform = self.get_transform(params) quantiles = jnp.array(quantiles)[..., None] latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf( quantiles ) return self._unpack_and_constrain(latent, params)
[docs]class AutoIAFNormal(AutoContinuous): """ This implementation of :class:`AutoContinuous` uses a Diagonal Normal distribution transformed via a :class:`~numpyro.distributions.flows.InverseAutoregressiveTransform` to construct a guide over the entire latent space. The guide does not depend on the model's ``*args, **kwargs``. Usage:: guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...) svi = SVI(model, guide, ...) :param callable model: a generative model. :param str prefix: a prefix that will be prefixed to all param internal sites. :param callable init_loc_fn: A per-site initialization function. :param int num_flows: the number of flows to be used, defaults to 3. :param list hidden_dims: the dimensionality of the hidden units per layer. Defaults to ``[latent_dim, latent_dim]``. :param bool skip_connections: whether to add skip connections from the input to the output of each flow. Defaults to False. :param callable nonlinearity: the nonlinearity to use in the feedforward network. Defaults to :func:`jax.example_libraries.stax.Elu`. """ def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, num_flows=3, hidden_dims=None, skip_connections=False, nonlinearity=stax.Elu, ): self.num_flows = num_flows # 2-layer, stax.Elu, skip_connections=False by default following the experiments in # IAF paper (https://arxiv.org/abs/1606.04934) # and Neutra paper (https://arxiv.org/abs/1903.03704) self._hidden_dims = hidden_dims self._skip_connections = skip_connections self._nonlinearity = nonlinearity super(AutoIAFNormal, self).__init__( model, prefix=prefix, init_loc_fn=init_loc_fn ) def _get_posterior(self): if self.latent_dim == 1: raise ValueError( "latent dim = 1. Consider using AutoDiagonalNormal instead" ) hidden_dims = ( [self.latent_dim, self.latent_dim] if self._hidden_dims is None else self._hidden_dims ) flows = [] for i in range(self.num_flows): if i > 0: flows.append(PermuteTransform(jnp.arange(self.latent_dim)[::-1])) arn = AutoregressiveNN( self.latent_dim, hidden_dims, permutation=jnp.arange(self.latent_dim), skip_connections=self._skip_connections, nonlinearity=self._nonlinearity, ) arnn = numpyro.module( "{}_arn__{}".format(self.prefix, i), arn, (self.latent_dim,) ) flows.append(InverseAutoregressiveTransform(arnn)) return dist.TransformedDistribution(self.get_base_dist(), flows)
[docs] def get_base_dist(self): return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1)
[docs]class AutoBNAFNormal(AutoContinuous): """ This implementation of :class:`AutoContinuous` uses a Diagonal Normal distribution transformed via a :class:`~numpyro.distributions.flows.BlockNeuralAutoregressiveTransform` to construct a guide over the entire latent space. The guide does not depend on the model's ``*args, **kwargs``. Usage:: guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[50, 50], ...) svi = SVI(model, guide, ...) **References** 1. *Block Neural Autoregressive Flow*, Nicola De Cao, Ivan Titov, Wilker Aziz :param callable model: a generative model. :param str prefix: a prefix that will be prefixed to all param internal sites. :param callable init_loc_fn: A per-site initialization function. :param int num_flows: the number of flows to be used, defaults to 1. :param list hidden_factors: Hidden layer i has ``hidden_factors[i]`` hidden units per input dimension. This corresponds to both :math:`a` and :math:`b` in reference [1]. The elements of hidden_factors must be integers. """ def __init__( self, model, *, prefix="auto", init_loc_fn=init_to_uniform, num_flows=1, hidden_factors=[8, 8], ): self.num_flows = num_flows self._hidden_factors = hidden_factors super(AutoBNAFNormal, self).__init__( model, prefix=prefix, init_loc_fn=init_loc_fn ) def _get_posterior(self): if self.latent_dim == 1: raise ValueError( "latent dim = 1. Consider using AutoDiagonalNormal instead" ) flows = [] for i in range(self.num_flows): if i > 0: flows.append(PermuteTransform(jnp.arange(self.latent_dim)[::-1])) residual = "gated" if i < (self.num_flows - 1) else None arn = BlockNeuralAutoregressiveNN( self.latent_dim, self._hidden_factors, residual ) arnn = numpyro.module( "{}_arn__{}".format(self.prefix, i), arn, (self.latent_dim,) ) flows.append(BlockNeuralAutoregressiveTransform(arnn)) return dist.TransformedDistribution(self.get_base_dist(), flows)
[docs] def get_base_dist(self): return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1)