# Source code for numpyro.infer.reparam

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
import math

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import biject_to, constraints
from numpyro.distributions.util import is_identically_one, safe_normalize, sum_rightmost
from numpyro.infer.autoguide import AutoContinuous
from numpyro.util import not_jax_tracer

class Reparam(ABC):
    """
    Base class for reparameterizers.

    Subclasses implement :meth:`__call__`; the ``_unwrap`` / ``_wrap`` helpers
    let them strip and later re-apply ``Independent`` / ``ExpandedDistribution``
    wrappers around the distribution they operate on.
    """

    @abstractmethod
    def __call__(self, name, fn, obs):
        """
        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution.
        :param numpy.ndarray obs: Observed value or None.
        :return: A pair (``new_fn``, ``value``).
        """
        # Default behaviour (reachable via ``super()``): leave the site untouched.
        return fn, obs

    def _unwrap(self, fn):
        """
        Unwrap Independent(...) and ExpandedDistribution(...) distributions.
        We can recover the input `fn` from the result triple
        `(fn, expand_shape, event_dim)` with
        `fn.expand(expand_shape).to_event(event_dim - fn.event_dim)`.

        :param fn: a (possibly wrapped) distribution.
        :return: a triple ``(base_fn, expand_shape, event_dim)``.
        """
        # Record the outer shape and event_dim before peeling wrappers off.
        shape = fn.shape()
        event_dim = fn.event_dim
        while isinstance(fn, (dist.Independent, dist.ExpandedDistribution)):
            fn = fn.base_dist
        # Everything left of the innermost distribution's event dims is the
        # shape the base distribution has to be expanded back to.
        expand_shape = shape[: len(shape) - fn.event_dim]
        return fn, expand_shape, event_dim

    def _wrap(self, fn, expand_shape, event_dim):
        """
        Wrap in Independent and ExpandedDistribution distributions.

        :param fn: the distribution to wrap.
        :param expand_shape: shape passed to ``fn.expand``.
        :param int event_dim: the ``event_dim`` the result must have.
        :return: the wrapped distribution, whose ``event_dim`` equals
            ``event_dim``.
        """
        # Match batch_shape.
        assert fn.event_dim <= event_dim
        fn = fn.expand(expand_shape)  # no-op if expand_shape == fn.batch_shape

        # Match event_dim.
        if fn.event_dim < event_dim:
            fn = fn.to_event(event_dim - fn.event_dim)
        assert fn.event_dim == event_dim
        return fn
class LocScaleReparam(Reparam):
    """
    Generic decentering reparameterizer [1] for latent variables parameterized
    by ``loc`` and ``scale`` (and possibly additional ``shape_params``).

    This reparameterization works only for latent variables, not likelihoods.

    **References:**

    1. *Automatic Reparameterisation of Probabilistic Programs*,
       Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

    :param float centered: optional centered parameter. If None (default) learn
        a per-site per-element centering parameter in ``[0,1]`` initialized at
        value 0.5. To sample the parameter, consider using
        :class:`~numpyro.handlers.lift` handler with a prior like
        ``Uniform(0, 1)`` to cast the parameter to a latent variable. If 0,
        fully decenter the distribution; if 1, preserve the centered
        distribution unchanged.
    :param shape_params: list of additional parameter names to copy unchanged
        from the centered to decentered distribution.
    :type shape_params: tuple or list
    :raises ValueError: if a concrete (non-traced) ``centered`` lies outside
        ``[0, 1]``.
    """

    def __init__(self, centered=None, shape_params=()):
        assert centered is None or isinstance(
            centered, (int, float, np.generic, np.ndarray, jnp.ndarray, jax.core.Tracer)
        )
        assert isinstance(shape_params, (tuple, list))
        assert all(isinstance(name, str) for name in shape_params)
        if centered is not None:
            is_valid = constraints.unit_interval.check(centered)
            # The range check can only be evaluated eagerly when the result is
            # a concrete value rather than a JAX tracer.
            if not_jax_tracer(is_valid):
                if not np.all(is_valid):
                    raise ValueError(
                        "`centered` argument does not satisfy `0 <= centered <= 1`."
                    )

        self.centered = centered
        self.shape_params = shape_params

    def __call__(self, name, fn, obs):
        """
        Replace site ``name`` by a partially decentered auxiliary site
        ``"{name}_decentered"`` and recompute the original value from it.

        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution with
            ``loc`` and ``scale`` attributes and real support.
        :param obs: must be None; observed sites are not supported.
        :return: ``(fn, obs)`` unchanged when ``centered`` is identically one,
            otherwise ``(None, value)``.
        :raises ValueError: if the (base) support of ``fn`` is not
            ``constraints.real``.
        """
        assert obs is None, "LocScaleReparam does not support observe statements"
        # NOTE(review): the two right-hand sides below were missing from the
        # extracted text and are restored to match upstream numpyro - confirm.
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        if support is not constraints.real:
            raise ValueError(
                "LocScaleReparam only supports distributions with real "
                f"support, but got {support} support at site {name}."
            )

        centered = self.centered
        if is_identically_one(centered):
            return fn, obs
        # Capture the event shape before unwrapping; it sizes the learned
        # centering parameter below.
        event_shape = fn.event_shape
        fn, expand_shape, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = numpyro.param(
                "{}_centered".format(name),
                jnp.full(event_shape, 0.5),
                constraint=constraints.unit_interval,
            )
        if isinstance(centered, (int, float, np.generic)) and centered == 0.0:
            # Fully decentered: use exact zeros/ones for loc/scale.
            params["loc"] = jnp.zeros_like(fn.loc)
            params["scale"] = jnp.ones_like(fn.scale)
        else:
            params["loc"] = fn.loc * centered
            params["scale"] = fn.scale**centered
        decentered_fn = self._wrap(type(fn)(**params), expand_shape, event_dim)

        # Draw decentered noise.
        decentered_value = numpyro.sample("{}_decentered".format(name), decentered_fn)

        # Differentiably transform.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        return None, value
class TransformReparam(Reparam):
    """
    Reparameterizer for
    :class:`~numpyro.distributions.TransformedDistribution` latent variables.

    This is useful for transformed distributions with complex,
    geometry-changing transforms, where the posterior has simple shape in
    the space of ``base_dist``.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Sample an auxiliary site ``"{name}_base"`` from the base distribution
        and push it through the transforms of ``fn``.

        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A (possibly wrapped)
            ``TransformedDistribution``.
        :param obs: must be None; observed sites are not supported.
        :return: the pair ``(None, value)`` where ``value`` is the transformed
            base sample.
        :raises ValueError: if the unwrapped ``fn`` is not a
            ``TransformedDistribution``.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        fn, expand_shape, event_dim = self._unwrap(fn)
        if not isinstance(fn, dist.TransformedDistribution):
            raise ValueError(
                "TransformReparam does not automatically work with {}"
                " distribution anymore. Please explicitly using"
                " TransformedDistribution(base_dist, AffineTransform(...)) pattern"
                " with TransformReparam.".format(type(fn).__name__)
            )

        # Draw noise from the base distribution.
        # Walk the transforms from output back to input, adjusting the event
        # dim by each transform's domain/codomain difference.
        base_event_dim = event_dim
        for t in reversed(fn.transforms):
            base_event_dim += t.domain.event_dim - t.codomain.event_dim
        x = numpyro.sample(
            "{}_base".format(name),
            self._wrap(fn.base_dist, expand_shape, base_event_dim),
        )

        # Differentiably transform.
        for t in fn.transforms:
            x = t(x)

        # Simulate a pyro.deterministic() site.
        return None, x
class ProjectedNormalReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.ProjectedNormal` latent
    variables.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Sample an auxiliary site ``"{name}_normal"`` from a standard normal
        and normalize it after shifting by ``fn.concentration``.

        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A (possibly wrapped)
            ``ProjectedNormal`` distribution.
        :param obs: must be None; observed sites are not supported.
        :return: the pair ``(None, value)``.
        """
        assert obs is None, "ProjectedNormalReparam does not support observe statements"
        fn, expand_shape, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.ProjectedNormal)

        # Draw parameter-free noise.
        new_fn = dist.Normal(jnp.zeros(fn.concentration.shape), 1).to_event(1)
        x = numpyro.sample(
            "{}_normal".format(name), self._wrap(new_fn, expand_shape, event_dim)
        )

        # Differentiably transform.
        value = safe_normalize(x + fn.concentration)

        # Simulate a pyro.deterministic() site.
        return None, value
class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterizer [1] of multiple latent variables.

    This uses a trained :class:`~numpyro.infer.autoguide.AutoContinuous`
    guide to alter the geometry of a model, typically for use e.g. in MCMC.
    Example usage::

        # Step 1. Train a guide
        guide = AutoIAFNormal(model)
        svi = SVI(model, guide, ...)
        # ...train the guide and keep its trained `params`...

        # Step 2. Use trained guide in NeuTra MCMC
        neutra = NeuTraReparam(guide, params)
        model = neutra.reparam(model)
        nuts = NUTS(model)
        # ...now use the model in HMC or NUTS...

    This reparameterization works only for latent variables, not likelihoods.
    Note that all sites must share a single common :class:`NeuTraReparam`
    instance, and that the model must have static structure.

    [1] Hoffman, M. et al. (2019)
        "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport"

    :param ~numpyro.infer.autoguide.AutoContinuous guide: A guide.
    :param params: trained parameters of the guide.
    :raises TypeError: if ``guide`` is not an ``AutoContinuous`` instance.
    :raises ValueError: if the guide's ``get_transform`` raises
        ``NotImplementedError`` or ``TypeError``.
    """

    # NOTE(review): every ``self.guide...`` expression in this class (and the
    # ``fn.support`` argument of ``biject_to``) was missing from the extracted
    # text and has been restored to match upstream numpyro - confirm against
    # the installed version.

    def __init__(self, guide, params):
        if not isinstance(guide, AutoContinuous):
            raise TypeError(
                "NeuTraReparam expected an AutoContinuous guide, but got {}".format(
                    type(guide)
                )
            )
        self.guide = guide
        self.params = params
        try:
            self.transform = self.guide.get_transform(params)
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`"
            ) from e
        # Per-site unconstrained values unpacked from the shared latent; filled
        # on the first reparameterized site and drained via ``pop`` afterwards.
        self._x_unconstrained = {}

    def _reparam_config(self, site):
        """
        Config callable for the ``reparam`` handler.

        :param dict site: a site dict with at least a ``"name"`` entry.
        :return: ``self`` for sites that appear unobserved in the guide's
            prototype trace, otherwise ``None`` (implicitly).
        """
        if site["name"] in self.guide.prototype_trace:
            # We only reparam if this is an unobserved site in the guide
            # prototype trace.
            guide_site = self.guide.prototype_trace[site["name"]]
            if not guide_site.get("is_observed", False):
                return self

    def reparam(self, fn=None):
        """
        Wrap ``fn`` in :class:`numpyro.handlers.reparam` configured with this
        instance's ``_reparam_config``.

        :param fn: the callable (model) to wrap, or None.
        :return: the ``numpyro.handlers.reparam`` handler.
        """
        return numpyro.handlers.reparam(fn, config=self._reparam_config)

    def __call__(self, name, fn, obs):
        """
        Replace site ``name`` by a slice of one shared latent sample.

        On the first reparameterized site a single ``"{prefix}_shared_latent"``
        site is sampled and transformed; every site then pops its own value
        from the unpacked result and contributes its log density through a
        ``"_{name}_log_prob"`` factor.

        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution.
        :param obs: must be None for sites found in the guide's prototype trace.
        :return: ``(fn, obs)`` unchanged for sites absent from the guide's
            prototype trace, otherwise ``(None, value)``.
        """
        if name not in self.guide.prototype_trace:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        log_density = 0.0
        # Skip density bookkeeping when the enclosing mask is exactly False.
        compute_density = numpyro.get_mask() is not False
        if not self._x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            z_unconstrained = numpyro.sample(
                "{}_shared_latent".format(self.guide.prefix),
                self.guide.get_base_dist().mask(False),
            )

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained
                )
            self._x_unconstrained = self.guide._unpack_latent(x_unconstrained)

        # Extract a single site's value from the shared latent.
        unconstrained_value = self._x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            # Sum out the rightmost dims so logdet lines up with fn.log_prob.
            logdet = sum_rightmost(
                logdet, jnp.ndim(logdet) - jnp.ndim(value) + len(fn.event_shape)
            )
            log_density = log_density + fn.log_prob(value) + logdet
        numpyro.factor("_{}_log_prob".format(name), log_density)
        return None, value

    def transform_sample(self, latent):
        """
        Given latent samples from the warped posterior (with possible batch
        dimensions), return a `dict` of samples from the latent sites in the
        model.

        :param latent: sample from the warped posterior (possibly batched).
        :return: a `dict` of samples keyed by latent sites in the model.
        :rtype: dict
        """
        x_unconstrained = self.transform(latent)
        return self.guide._unpack_and_constrain(x_unconstrained, self.params)
class CircularReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.VonMises` latent
    variables.
    """

    def __call__(self, name, fn, obs):
        """
        Sample an auxiliary site ``"{name}_unwrapped"`` over the reals, wrap it
        into ``[-pi, pi)`` and score it under ``fn`` via a
        ``"{name}_factor"`` factor site.

        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution whose
            (base) support is ``constraints.circular``.
        :param obs: Observed value or None; forwarded to the auxiliary site.
        :return: the pair ``(None, value)``.
        """
        # Support must be circular
        # NOTE(review): the two right-hand sides below were missing from the
        # extracted text and are restored to match upstream numpyro - confirm.
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        assert support is constraints.circular

        # Draw parameter-free noise.
        new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
        value = numpyro.sample(
            f"{name}_unwrapped",
            new_fn,
            obs=obs,
        )

        # Differentiably transform.
        value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

        # Simulate a pyro.deterministic() site.
        numpyro.factor(f"{name}_factor", fn.log_prob(value))
        return None, value