# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
import math
from typing import Iterable
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import biject_to, constraints
from numpyro.distributions.util import is_identically_one, safe_normalize, sum_rightmost
from numpyro.infer.autoguide import AutoContinuous
from numpyro.util import not_jax_tracer
class Reparam(ABC):
    """
    Abstract base class for all reparameterizers.

    Concrete subclasses implement :meth:`__call__`, which receives a sample
    site and returns the replacement distribution (or ``None``) and a value.
    """

    @abstractmethod
    def __call__(self, name, fn, obs):
        """
        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution.
        :param numpy.ndarray obs: Observed value or None.
        :return: A pair (``new_fn``, ``value``).
        """
        return fn, obs

    def _unwrap(self, fn):
        """
        Peel off every ``Independent(...)`` and ``ExpandedDistribution(...)``
        layer surrounding ``fn``.

        The returned triple ``(base_fn, expand_shape, event_dim)`` is enough to
        rebuild the input as
        ``base_fn.expand(expand_shape).to_event(event_dim - base_fn.event_dim)``.
        """
        full_shape = fn.shape()
        outer_event_dim = fn.event_dim
        base = fn
        while isinstance(base, (dist.Independent, dist.ExpandedDistribution)):
            base = base.base_dist
        # Keep every dimension to the left of the innermost distribution's
        # own event dimensions.
        cut = len(full_shape) - base.event_dim
        return base, full_shape[:cut], outer_event_dim

    def _wrap(self, fn, expand_shape, event_dim):
        """
        Inverse of :meth:`_unwrap`: expand ``fn`` to ``expand_shape`` and then
        reinterpret batch dimensions as event dimensions until
        ``fn.event_dim == event_dim``.
        """
        assert fn.event_dim <= event_dim
        # Batch shape first; expand() is a no-op when the shapes already agree.
        wrapped = fn.expand(expand_shape)
        missing = event_dim - wrapped.event_dim
        if missing > 0:
            wrapped = wrapped.to_event(missing)
        assert wrapped.event_dim == event_dim
        return wrapped
class LocScaleReparam(Reparam):
    """
    Generic decentering reparameterizer [1] for latent variables parameterized
    by ``loc`` and ``scale`` (and possibly additional ``shape_params``).

    This reparameterization works only for latent variables, not likelihoods.

    **References:**

    1. *Automatic Reparameterisation of Probabilistic Programs*,
       Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

    :param float centered: optional centered parameter. If None (default) learn
        a per-site per-element centering parameter in ``[0,1]`` initialized at
        value 0.5. To sample the parameter, consider using
        :class:`~numpyro.handlers.lift` handler with a prior like
        ``Uniform(0, 1)`` to cast the parameter to a latent variable. If 0,
        fully decenter the distribution; if 1, preserve the centered
        distribution unchanged.
    :param shape_params: list of additional parameter names to copy unchanged
        from the centered to decentered distribution.
    :type shape_params: tuple or list
    """

    def __init__(self, centered=None, shape_params=()):
        if centered is not None:
            assert isinstance(
                centered,
                (int, float, np.generic, np.ndarray, jnp.ndarray, jax.core.Tracer),
            )
        assert isinstance(shape_params, (tuple, list))
        assert all(isinstance(param_name, str) for param_name in shape_params)
        if centered is not None:
            in_range = constraints.unit_interval.check(centered)
            # The range check is only enforced when its result is a concrete
            # (non-tracer) value.
            if not_jax_tracer(in_range) and not np.all(in_range):
                raise ValueError(
                    "`centered` argument does not satisfy `0 <= centered <= 1`."
                )
        self.centered = centered
        self.shape_params = shape_params

    def __call__(self, name, fn, obs):
        assert obs is None, "LocScaleReparam does not support observe statements"
        base_support = fn.support
        if isinstance(base_support, constraints.independent):
            base_support = fn.support.base_constraint
        if base_support is not constraints.real:
            raise ValueError(
                "LocScaleReparam only supports distributions with real "
                f"support, but got {base_support} support at site {name}."
            )

        centered = self.centered
        if is_identically_one(centered):
            # Fully centered: the site is left untouched.
            return fn, obs

        # Captured before unwrapping: sizes the learnable centering parameter.
        event_shape = fn.event_shape
        base_fn, expand_shape, event_dim = self._unwrap(fn)

        # Extra parameters are copied verbatim into the decentered distribution.
        new_params = {key: getattr(base_fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = numpyro.param(
                "{}_centered".format(name),
                jnp.full(event_shape, 0.5),
                constraint=constraints.unit_interval,
            )

        fully_decentered = (
            isinstance(centered, (int, float, np.generic)) and centered == 0.0
        )
        if fully_decentered:
            new_params["loc"] = jnp.zeros_like(base_fn.loc)
            new_params["scale"] = jnp.ones_like(base_fn.scale)
        else:
            # Partial decentering: loc * c and scale ** c.
            new_params["loc"] = base_fn.loc * centered
            new_params["scale"] = base_fn.scale**centered
        decentered_fn = self._wrap(
            type(base_fn)(**new_params), expand_shape, event_dim
        )

        decentered_value = numpyro.sample("{}_decentered".format(name), decentered_fn)

        # Map back differentiably: x = loc + scale ** (1 - c) * (z - c * loc).
        loc = base_fn.loc
        scale = base_fn.scale
        shifted = decentered_value - centered * loc
        value = loc + jnp.power(scale, 1 - centered) * shifted

        # Returning no distribution makes the site behave like a deterministic one.
        return None, value
class ProjectedNormalReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.ProjectedNormal` latent
    variables.

    The site is replaced by a standard normal sample ``{name}_normal``, which
    is shifted by the concentration and then normalized.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        assert obs is None, "ProjectedNormalReparam does not support observe statements"
        base_fn, expand_shape, event_dim = self._unwrap(fn)
        assert isinstance(base_fn, dist.ProjectedNormal)

        concentration = base_fn.concentration
        # Parameter-free standard normal noise shaped like the concentration,
        # rewrapped to the site's original batch/event structure.
        standard_normal = dist.Normal(jnp.zeros(concentration.shape), 1).to_event(1)
        noise = numpyro.sample(
            "{}_normal".format(name),
            self._wrap(standard_normal, expand_shape, event_dim),
        )

        # Differentiable map: shift by the concentration, then normalize.
        # Returning no distribution makes the site behave like a deterministic one.
        return None, safe_normalize(noise + concentration)
class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterizer [1] of multiple latent variables.

    This uses a trained :class:`~numpyro.infer.autoguide.AutoContinuous`
    guide to alter the geometry of a model, typically for use e.g. in MCMC.
    Example usage::

        # Step 1. Train a guide
        guide = AutoIAFNormal(model)
        svi = SVI(model, guide, ...)
        # ...train the guide...
        svi_result = svi.run(rng_key, num_steps, ...)

        # Step 2. Use trained guide in NeuTra MCMC
        neutra = NeuTraReparam(guide, svi_result.params)
        model = neutra.reparam(model)
        nuts = NUTS(model)
        # ...now use the model in HMC or NUTS...

    This reparameterization works only for latent variables, not likelihoods.
    Note that all sites must share a single common :class:`NeuTraReparam`
    instance, and that the model must have static structure.

    [1] Hoffman, M. et al. (2019)
        "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport"
        https://arxiv.org/abs/1903.03704

    :param ~numpyro.infer.autoguide.AutoContinuous guide: A guide.
    :param params: trained parameters of the guide.
    :raises TypeError: if ``guide`` is not an ``AutoContinuous`` instance.
    :raises ValueError: if the guide's ``get_transform`` is unavailable or
        requires the model's arguments.
    """

    def __init__(self, guide, params):
        if not isinstance(guide, AutoContinuous):
            raise TypeError(
                "NeuTraReparam expected an AutoContinuous guide, but got {}".format(
                    type(guide)
                )
            )
        self.guide = guide
        self.params = params
        try:
            # The transform must be computable from the trained params alone.
            self.transform = self.guide.get_transform(params)
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`"
            ) from e
        # Per-site unconstrained values unpacked from the shared latent. It is
        # filled at the first reparameterized site of a model run and drained
        # one site at a time; once empty, the next site resamples the latent.
        self._x_unconstrained = {}

    def _reparam_config(self, site):
        """
        Config callback for :func:`numpyro.handlers.reparam`: return ``self``
        for sites that appear unobserved in the guide's prototype trace, and
        ``None`` (leave the site alone) otherwise.
        """
        if site["name"] in self.guide.prototype_trace:
            guide_site = self.guide.prototype_trace[site["name"]]
            if not guide_site.get("is_observed", False):
                return self
        return None

    def reparam(self, fn=None):
        """
        Wrap ``fn`` with :func:`numpyro.handlers.reparam`, using
        :meth:`_reparam_config` to select which sites this instance handles.
        """
        return numpyro.handlers.reparam(fn, config=self._reparam_config)

    def __call__(self, name, fn, obs):
        # Sites unknown to the guide pass through unchanged.
        if name not in self.guide.prototype_trace:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        log_density = 0.0
        # Density bookkeeping is skipped when the active mask is exactly False.
        compute_density = numpyro.get_mask() is not False
        if not self._x_unconstrained:  # On first sample site.
            # Sample the shared latent from the guide's base distribution. It is
            # masked so that it adds no log-density of its own; the model density
            # enters through the factor statements below.
            z_unconstrained = numpyro.sample(
                "{}_shared_latent".format(self.guide.prefix),
                self.guide.get_base_dist().mask(False),
            )
            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                # The Jacobian of the guide's transform is counted once, at
                # the first site.
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained
                )
            self._x_unconstrained = self.guide._unpack_latent(x_unconstrained)

        # Extract a single site's value from the shared latent.
        unconstrained_value = self._x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            # Sum out the event dimensions so the Jacobian term lines up with
            # ``fn.log_prob(value)``.
            logdet = sum_rightmost(
                logdet, jnp.ndim(logdet) - jnp.ndim(value) + len(fn.event_shape)
            )
            log_density = log_density + fn.log_prob(value) + logdet
        numpyro.factor("_{}_log_prob".format(name), log_density)
        return None, value
class CircularReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.VonMises` latent
    variables.

    Any distribution with ``constraints.circular`` support (directly, or as the
    base constraint of an ``independent`` support) is accepted. The site is
    replaced by:

    - an improper flat sample ``{name}_unwrapped`` on the real line, which is
      wrapped into ``[-pi, pi)``, and
    - a factor ``{name}_factor`` carrying ``fn.log_prob`` of the wrapped value.
    """

    def __call__(self, name, fn, obs):
        # Support must be circular.
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        assert support is constraints.circular, (
            "CircularReparam only supports distributions with circular "
            f"support, but got {support} support at site {name}."
        )

        # Flat (improper) sample over the reals with the site's shapes; any
        # observed value is forwarded to it.
        new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
        value = numpyro.sample(
            f"{name}_unwrapped",
            new_fn,
            obs=obs,
        )

        # Differentiably wrap onto [-pi, pi).
        value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

        # The original density is restored through a factor. Returning no
        # distribution makes the site itself behave like a deterministic one.
        numpyro.factor(f"{name}_factor", fn.log_prob(value))
        return None, value
class ExplicitReparam(Reparam):
    """
    Explicit reparametrizer of a latent variable :code:`x` to a transformed space
    :code:`y = transform(x)` with more amenable geometry. This reparametrizer is similar
    to :class:`.TransformReparam` but allows reparametrizations to be decoupled from the
    model declaration.

    :param transform: Bijective transform to the reparameterized space. An
        iterable whose items are all transforms is composed into a single
        :class:`~numpyro.distributions.transforms.ComposeTransform`.

    **Example:**

    .. doctest::

        >>> from jax import random
        >>> from jax import numpy as jnp
        >>> import numpyro
        >>> from numpyro import handlers, distributions as dist
        >>> from numpyro.infer import MCMC, NUTS
        >>> from numpyro.infer.reparam import ExplicitReparam
        >>>
        >>> def model():
        ...     numpyro.sample("x", dist.Gamma(4, 4))
        >>>
        >>> # Sample in unconstrained space using a soft-plus instead of exp transform.
        >>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
        >>> reparametrized = handlers.reparam(model, {"x": reparam})
        >>> kernel = NUTS(model=reparametrized)
        >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
        >>> mcmc.run(random.PRNGKey(2))  # doctest: +SKIP
        sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
    """

    def __init__(self, transform):
        if isinstance(transform, Iterable):
            # Materialize first. Otherwise a one-shot iterator (e.g. a generator)
            # is exhausted by the type check and silently becomes an empty,
            # identity composition.
            parts = list(transform)
            if all(isinstance(t, dist.transforms.Transform) for t in parts):
                transform = dist.transforms.ComposeTransform(parts)
        self.transform = transform

    def __call__(self, name, fn, obs):
        assert obs is None, "ExplicitReparam does not support observe statements"
        # Sample y = transform(x) under the pushforward of ``fn``.
        transformed = dist.TransformedDistribution(fn, self.transform)
        x = numpyro.sample(f"{name}_base", transformed)
        # Map back to the original space; the site behaves like a deterministic one.
        return None, self.transform.inv(x)