Source code for numpyro.infer.reparam

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
import math
from typing import Iterable

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import biject_to, constraints
from numpyro.distributions.util import is_identically_one, safe_normalize, sum_rightmost
from numpyro.infer.autoguide import AutoContinuous
from numpyro.util import not_jax_tracer


class Reparam(ABC):
    """
    Abstract base class shared by all reparameterizers.
    """

    @abstractmethod
    def __call__(self, name, fn, obs):
        """
        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution.
        :param numpy.ndarray obs: Observed value or None.
        :return: A pair (``new_fn``, ``value``).
        """
        return fn, obs

    def _unwrap(self, fn):
        """
        Strip every outer ``Independent(...)`` / ``ExpandedDistribution(...)``
        layer from ``fn``.

        The input can be rebuilt from the returned triple
        ``(fn, expand_shape, event_dim)`` via
        ``fn.expand(expand_shape).to_event(event_dim - fn.event_dim)``.
        """
        # Record the outer shape and event_dim before peeling wrappers off.
        full_shape = fn.shape()
        outer_event_dim = fn.event_dim

        wrapper_types = (dist.Independent, dist.ExpandedDistribution)
        base = fn
        while isinstance(base, wrapper_types):
            base = base.base_dist

        # Everything left of the base distribution's event dims is kept as
        # the shape to expand back to.
        num_leading = len(full_shape) - base.event_dim
        return base, full_shape[:num_leading], outer_event_dim

    def _wrap(self, fn, expand_shape, event_dim):
        """
        Re-apply ``expand`` and ``to_event`` so that the result has
        ``event_dim`` event dimensions.
        """
        assert fn.event_dim <= event_dim

        # Match batch_shape first (no-op if expand_shape == fn.batch_shape).
        wrapped = fn.expand(expand_shape)

        # Then reinterpret as many dims as needed to match event_dim.
        missing = event_dim - wrapped.event_dim
        if missing > 0:
            wrapped = wrapped.to_event(missing)

        assert wrapped.event_dim == event_dim
        return wrapped
class LocScaleReparam(Reparam):
    """
    Generic decentering reparameterizer [1] for latent variables
    parameterized by ``loc`` and ``scale`` (and possibly additional
    ``shape_params``).

    This reparameterization works only for latent variables, not likelihoods.

    **References:**

    1. *Automatic Reparameterisation of Probabilistic Programs*,
       Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

    :param float centered: optional centered parameter. If None (default) learn
        a per-site per-element centering parameter in ``[0,1]`` initialized
        at value 0.5. To sample the parameter, consider using
        :class:`~numpyro.handlers.lift` handler with a prior like
        ``Uniform(0, 1)`` to cast the parameter to a latent variable.
        If 0, fully decenter the distribution; if 1, preserve the centered
        distribution unchanged.
    :param shape_params: list of additional parameter names to copy unchanged
        from the centered to decentered distribution.
    :type shape_params: tuple or list
    """

    def __init__(self, centered=None, shape_params=()):
        allowed_types = (
            int,
            float,
            np.generic,
            np.ndarray,
            jnp.ndarray,
            jax.core.Tracer,
        )
        assert centered is None or isinstance(centered, allowed_types)
        assert isinstance(shape_params, (tuple, list))
        assert all(isinstance(name, str) for name in shape_params)

        if centered is not None:
            in_unit_interval = constraints.unit_interval.check(centered)
            # The range check is only enforced when its result passes
            # `not_jax_tracer`.
            if not_jax_tracer(in_unit_interval) and not np.all(in_unit_interval):
                raise ValueError(
                    "`centered` argument does not satisfy `0 <= centered <= 1`."
                )

        self.centered = centered
        self.shape_params = shape_params

    def __call__(self, name, fn, obs):
        """
        Replace site ``name`` by a ``{name}_decentered`` sample site and
        return ``(None, value)`` where ``value`` is computed from that sample.

        :raises ValueError: if the (base) support of ``fn`` is not
            ``constraints.real``.
        """
        assert obs is None, "LocScaleReparam does not support observe statements"

        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        if support is not constraints.real:
            raise ValueError(
                "LocScaleReparam only supports distributions with real "
                f"support, but got {support} support at site {name}."
            )

        centered = self.centered
        # Fully centered: leave the site untouched.
        if is_identically_one(centered):
            return fn, obs

        # The learnable centering parameter uses the event shape of the
        # original (still wrapped) distribution.
        event_shape = fn.event_shape
        base_fn, expand_shape, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(base_fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = numpyro.param(
                "{}_centered".format(name),
                jnp.full(event_shape, 0.5),
                constraint=constraints.unit_interval,
            )

        fully_decentered = (
            isinstance(centered, (int, float, np.generic)) and centered == 0.0
        )
        if fully_decentered:
            params["loc"] = jnp.zeros_like(base_fn.loc)
            params["scale"] = jnp.ones_like(base_fn.scale)
        else:
            params["loc"] = base_fn.loc * centered
            params["scale"] = base_fn.scale**centered
        decentered_fn = self._wrap(type(base_fn)(**params), expand_shape, event_dim)

        # Draw decentered noise.
        decentered_value = numpyro.sample("{}_decentered".format(name), decentered_fn)

        # Differentiably transform.
        delta = decentered_value - centered * base_fn.loc
        rescale = jnp.power(base_fn.scale, 1 - centered)
        value = base_fn.loc + rescale * delta

        # Simulate a pyro.deterministic() site.
        return None, value
class TransformReparam(Reparam):
    """
    Reparameterizer for
    :class:`~numpyro.distributions.TransformedDistribution` latent variables.

    This is useful for transformed distributions with complex,
    geometry-changing transforms, where the posterior has simple shape in
    the space of ``base_dist``.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Sample ``{name}_base`` from the base distribution and push it through
        the transforms of ``fn``; returns ``(None, transformed_value)``.

        :raises ValueError: if the unwrapped ``fn`` is not a
            ``TransformedDistribution``.
        """
        assert obs is None, "TransformReparam does not support observe statements"

        fn, expand_shape, event_dim = self._unwrap(fn)
        if not isinstance(fn, dist.TransformedDistribution):
            raise ValueError(
                "TransformReparam does not automatically work with {}"
                " distribution anymore. Please explicitly using"
                " TransformedDistribution(base_dist, AffineTransform(...)) pattern"
                " with TransformReparam.".format(type(fn).__name__)
            )

        # Draw noise from the base distribution. Each transform may change
        # the number of event dims, so account for the net difference.
        event_dim_shift = sum(
            t.domain.event_dim - t.codomain.event_dim for t in fn.transforms
        )
        base_event_dim = event_dim + event_dim_shift
        base_fn = self._wrap(fn.base_dist, expand_shape, base_event_dim)
        value = numpyro.sample("{}_base".format(name), base_fn)

        # Differentiably transform.
        for transform in fn.transforms:
            value = transform(value)

        # Simulate a pyro.deterministic() site.
        return None, value
class ProjectedNormalReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.ProjectedNormal` latent
    variables.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Sample ``{name}_normal`` from a unit-scale, zero-mean Normal and
        return ``(None, safe_normalize(noise + concentration))``.
        """
        assert obs is None, "ProjectedNormalReparam does not support observe statements"

        fn, expand_shape, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.ProjectedNormal)
        concentration = fn.concentration

        # Draw parameter-free noise.
        noise_fn = dist.Normal(jnp.zeros(concentration.shape), 1).to_event(1)
        noise = numpyro.sample(
            "{}_normal".format(name),
            self._wrap(noise_fn, expand_shape, event_dim),
        )

        # Differentiably transform.
        direction = safe_normalize(noise + concentration)

        # Simulate a pyro.deterministic() site.
        return None, direction
class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterizer [1] of multiple latent variables.

    This uses a trained :class:`~numpyro.infer.autoguide.AutoContinuous`
    guide to alter the geometry of a model, typically for use e.g. in MCMC.
    Example usage::

        # Step 1. Train a guide
        guide = AutoIAFNormal(model)
        svi = SVI(model, guide, ...)
        # ...train the guide, obtaining its trained `params`...

        # Step 2. Use trained guide in NeuTra MCMC
        neutra = NeuTraReparam(guide, params)
        model = neutra.reparam(model)
        nuts = NUTS(model)
        # ...now use the model in HMC or NUTS...

    This reparameterization works only for latent variables, not likelihoods.
    Note that all sites must share a single common :class:`NeuTraReparam`
    instance, and that the model must have static structure.

    [1] Hoffman, M. et al. (2019)
        "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport"
        https://arxiv.org/abs/1903.03704

    :param ~numpyro.infer.autoguide.AutoContinuous guide: A guide.
    :param params: trained parameters of the guide.
    """

    def __init__(self, guide, params):
        if not isinstance(guide, AutoContinuous):
            raise TypeError(
                "NeuTraReparam expected an AutoContinuous guide, but got {}".format(
                    type(guide)
                )
            )
        self.guide = guide
        self.params = params
        try:
            self.transform = self.guide.get_transform(params)
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`"
            ) from e
        # Per-site unconstrained values that have been unpacked from the
        # shared latent but not yet consumed by `__call__`.
        self._x_unconstrained = {}

    def _reparam_config(self, site):
        """
        Config callable handed to ``numpyro.handlers.reparam``: returns this
        instance for sites that appear unobserved in the guide's prototype
        trace, and ``None`` otherwise.
        """
        prototype_trace = self.guide.prototype_trace
        site_name = site["name"]
        if site_name not in prototype_trace:
            return None
        # We only reparam if this is an unobserved site in the guide
        # prototype trace.
        if prototype_trace[site_name].get("is_observed", False):
            return None
        return self

    def reparam(self, fn=None):
        """
        Wrap ``fn`` in ``numpyro.handlers.reparam`` configured with this
        instance's site selection.
        """
        return numpyro.handlers.reparam(fn, config=self._reparam_config)

    def __call__(self, name, fn, obs):
        """
        Replace site ``name`` by a value extracted from the shared latent and
        register a ``_{name}_log_prob`` factor; returns ``(None, value)``.
        Sites missing from the guide's prototype trace are returned unchanged.
        """
        if name not in self.guide.prototype_trace:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        log_density = 0.0
        compute_density = numpyro.get_mask() is not False

        if not self._x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            shared_latent = numpyro.sample(
                "{}_shared_latent".format(self.guide.prefix),
                self.guide.get_base_dist().mask(False),
            )

            # Differentiably transform.
            transported = self.transform(shared_latent)
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    shared_latent, transported
                )
            self._x_unconstrained = self.guide._unpack_latent(transported)

        # Extract a single site's value from the shared latent.
        site_unconstrained = self._x_unconstrained.pop(name)
        to_support = biject_to(fn.support)
        value = to_support(site_unconstrained)

        if compute_density:
            logdet = to_support.log_abs_det_jacobian(site_unconstrained, value)
            extra_dims = jnp.ndim(logdet) - jnp.ndim(value) + len(fn.event_shape)
            logdet = sum_rightmost(logdet, extra_dims)
            log_density = log_density + fn.log_prob(value) + logdet

        numpyro.factor("_{}_log_prob".format(name), log_density)
        return None, value

    def transform_sample(self, latent):
        """
        Given latent samples from the warped posterior (with possible batch
        dimensions), return a `dict` of samples from the latent sites in the
        model.

        :param latent: sample from the warped posterior (possibly batched).
        :return: a `dict` of samples keyed by latent sites in the model.
        :rtype: dict
        """
        unwarped = self.transform(latent)
        return self.guide._unpack_and_constrain(unwarped, self.params)
class CircularReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.VonMises` latent
    variables.
    """

    def __call__(self, name, fn, obs):
        """
        Sample ``{name}_unwrapped`` on the real line, wrap it into
        ``[-pi, pi)`` and score it under ``fn`` through a ``{name}_factor``
        site; returns ``(None, wrapped_value)``.
        """
        # Support must be circular
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        assert support is constraints.circular

        # Draw parameter-free noise.
        unwrapped_fn = dist.ImproperUniform(
            constraints.real, fn.batch_shape, fn.event_shape
        )
        unwrapped = numpyro.sample(f"{name}_unwrapped", unwrapped_fn, obs=obs)

        # Differentiably transform.
        period = 2 * math.pi
        value = jnp.remainder(unwrapped + math.pi, period) - math.pi

        # Account for the density of `fn` at the wrapped value.
        numpyro.factor(f"{name}_factor", fn.log_prob(value))

        # Simulate a pyro.deterministic() site.
        return None, value
class ExplicitReparam(Reparam):
    """
    Explicit reparametrizer of a latent variable :code:`x` to a transformed
    space :code:`y = transform(x)` with more amenable geometry. This
    reparametrizer is similar to :class:`.TransformReparam` but allows
    reparametrizations to be decoupled from the model declaration.

    :param transform: Bijective transform to the reparameterized space. An
        iterable of transforms (including one-shot iterators such as
        generators) is combined into a single ``ComposeTransform``.

    **Example:**

    .. doctest::

        >>> from jax import random
        >>> from jax import numpy as jnp
        >>> import numpyro
        >>> from numpyro import handlers, distributions as dist
        >>> from numpyro.infer import MCMC, NUTS
        >>> from numpyro.infer.reparam import ExplicitReparam
        >>>
        >>> def model():
        ...    numpyro.sample("x", dist.Gamma(4, 4))
        >>>
        >>> # Sample in unconstrained space using a soft-plus instead of exp transform.
        >>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
        >>> reparametrized = handlers.reparam(model, {"x": reparam})
        >>> kernel = NUTS(model=reparametrized)
        >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
        >>> mcmc.run(random.PRNGKey(2))  # doctest: +SKIP
        sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
    """

    def __init__(self, transform):
        if isinstance(transform, Iterable):
            # Validation iterates over the parts once; a one-shot iterator
            # would then be empty when handed to ComposeTransform. Materialize
            # anything that is not already a list/tuple so the same parts are
            # both validated and composed.
            if isinstance(transform, (list, tuple)):
                parts = transform
            else:
                parts = list(transform)
            if all(isinstance(t, dist.transforms.Transform) for t in parts):
                transform = dist.transforms.ComposeTransform(parts)
        self.transform = transform

    def __call__(self, name, fn, obs):
        """
        Sample ``{name}_base`` from ``fn`` pushed through ``self.transform``
        and map the draw back with the inverse transform; returns
        ``(None, value)``.
        """
        assert obs is None, "ExplicitReparam does not support observe statements"
        transformed = dist.TransformedDistribution(fn, self.transform)
        x = numpyro.sample(f"{name}_base", transformed)
        return None, self.transform.inv(x)