# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
import math
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import biject_to, constraints
from numpyro.distributions.util import is_identically_one, safe_normalize, sum_rightmost
from numpyro.infer.autoguide import AutoContinuous
from numpyro.util import not_jax_tracer
class Reparam(ABC):
    """
    Base class for reparameterizers.

    Subclasses implement :meth:`__call__`. The private helpers
    :meth:`_unwrap` and :meth:`_wrap` strip and re-apply
    ``Independent`` / ``ExpandedDistribution`` wrappers around a
    distribution.
    """

    @abstractmethod
    def __call__(self, name, fn, obs):
        """
        :param str name: A sample site name.
        :param ~numpyro.distributions.Distribution fn: A distribution.
        :param numpy.ndarray obs: Observed value or None.
        :return: A pair (``new_fn``, ``value``).
        """
        return fn, obs

    def _unwrap(self, fn):
        """
        Unwrap Independent(...) and ExpandedDistribution(...) distributions.
        We can recover the input `fn` from the result triple `(fn, expand_shape, event_dim)`
        with `fn.expand(expand_shape).to_event(event_dim - fn.event_dim)`.

        :param fn: A possibly wrapped distribution.
        :return: A triple ``(base_fn, expand_shape, event_dim)`` where
            ``event_dim`` is the event dimension of the *input* distribution.
        """
        # Record the properties of the wrapped distribution before peeling
        # wrappers off, since `fn` is rebound in the loop below.
        shape = fn.shape()
        event_dim = fn.event_dim
        while isinstance(fn, (dist.Independent, dist.ExpandedDistribution)):
            fn = fn.base_dist
        # Everything left of the innermost distribution's event dims.
        expand_shape = shape[: len(shape) - fn.event_dim]
        return fn, expand_shape, event_dim

    def _wrap(self, fn, expand_shape, event_dim):
        """
        Wrap in Independent and ExpandedDistribution distributions.

        :param fn: A distribution whose ``event_dim`` is at most ``event_dim``.
        :param expand_shape: Shape passed to ``fn.expand``.
        :param int event_dim: Event dimension the result must have.
        :return: The expanded distribution with ``event_dim`` event dimensions.
        """
        # Match batch_shape.
        assert fn.event_dim <= event_dim
        fn = fn.expand(expand_shape)  # no-op if expand_shape == fn.batch_shape

        # Match event_dim.
        if fn.event_dim < event_dim:
            fn = fn.to_event(event_dim - fn.event_dim)
        assert fn.event_dim == event_dim
        return fn
class LocScaleReparam(Reparam):
    """
    Generic decentering reparameterizer [1] for latent variables parameterized
    by ``loc`` and ``scale`` (and possibly additional ``shape_params``).

    This reparameterization works only for latent variables, not likelihoods.

    **References:**

    1. *Automatic Reparameterisation of Probabilistic Programs*,
       Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

    :param float centered: optional centered parameter. If None (default) learn
        a per-site per-element centering parameter in ``[0,1]``. If 0, fully
        decenter the distribution; if 1, preserve the centered distribution
        unchanged.
    :param shape_params: list of additional parameter names to copy unchanged from
        the centered to decentered distribution.
    :type shape_params: tuple or list
    :raises ValueError: if a concrete ``centered`` lies outside ``[0, 1]``.
    """

    def __init__(self, centered=None, shape_params=()):
        assert centered is None or isinstance(
            centered, (int, float, np.generic, np.ndarray, jnp.ndarray, jax.core.Tracer)
        )
        assert isinstance(shape_params, (tuple, list))
        assert all(isinstance(name, str) for name in shape_params)
        if centered is not None:
            is_valid = constraints.unit_interval.check(centered)
            # The range check is only evaluated when `not_jax_tracer` says
            # the check result is not a JAX tracer.
            if not_jax_tracer(is_valid):
                if not np.all(is_valid):
                    raise ValueError(
                        "`centered` argument does not satisfy `0 <= centered <= 1`."
                    )

        self.centered = centered
        self.shape_params = shape_params

    def __call__(self, name, fn, obs):
        """
        Reparameterize the latent site ``name`` with distribution ``fn``.

        :param str name: A sample site name.
        :param fn: A distribution with real support exposing ``loc`` and
            ``scale`` attributes.
        :param obs: Must be None; observed sites are rejected.
        :return: ``(fn, obs)`` unchanged when ``centered`` is identically
            one, otherwise ``(None, value)`` with the transformed value.
        :raises ValueError: if the (base) support of ``fn`` is not
            ``constraints.real``.
        """
        assert obs is None, "LocScaleReparam does not support observe statements"
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        if support is not constraints.real:
            raise ValueError(
                "LocScaleReparam only supports distributions with real "
                f"support, but got {support} support at site {name}."
            )

        centered = self.centered
        if is_identically_one(centered):
            return fn, obs
        # Captured before unwrapping: `fn` is rebound to the base distribution.
        event_shape = fn.event_shape
        fn, expand_shape, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = numpyro.param(
                f"{name}_centered",
                jnp.full(event_shape, 0.5),
                constraint=constraints.unit_interval,
            )
        if isinstance(centered, (int, float, np.generic)) and centered == 0.0:
            # Fully decentered: zero loc and unit scale.
            params["loc"] = jnp.zeros_like(fn.loc)
            params["scale"] = jnp.ones_like(fn.scale)
        else:
            params["loc"] = fn.loc * centered
            params["scale"] = fn.scale**centered
        decentered_fn = self._wrap(type(fn)(**params), expand_shape, event_dim)

        # Draw decentered noise.
        decentered_value = numpyro.sample(f"{name}_decentered", decentered_fn)

        # Differentiably transform.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        return None, value
class ProjectedNormalReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.ProjectedNormal` latent
    variables.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Replace the site with a standard-normal auxiliary site.

        :param str name: A sample site name; the auxiliary site is named
            ``"{name}_normal"``.
        :param fn: A (possibly wrapped) ``ProjectedNormal`` distribution.
        :param obs: Must be None; observed sites are rejected.
        :return: ``(None, value)`` where ``value`` is the normalized sum of
            the auxiliary draw and ``fn.concentration``.
        """
        assert obs is None, "ProjectedNormalReparam does not support observe statements"
        fn, expand_shape, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.ProjectedNormal)

        # Draw parameter-free noise.
        new_fn = dist.Normal(jnp.zeros(fn.concentration.shape), 1).to_event(1)
        x = numpyro.sample(
            f"{name}_normal", self._wrap(new_fn, expand_shape, event_dim)
        )

        # Differentiably transform.
        value = safe_normalize(x + fn.concentration)

        # Simulate a pyro.deterministic() site.
        return None, value
class NeuTraReparam(Reparam):
    """
    Neural Transport reparameterizer [1] of multiple latent variables.

    This uses a trained :class:`~numpyro.infer.autoguide.AutoContinuous`
    guide to alter the geometry of a model, typically for use e.g. in MCMC.
    Example usage::

        # Step 1. Train a guide
        guide = AutoIAFNormal(model)
        svi = SVI(model, guide, ...)
        # ...train the guide and keep its trained parameters as `params`...

        # Step 2. Use trained guide in NeuTra MCMC
        neutra = NeuTraReparam(guide, params)
        model = neutra.reparam(model)
        nuts = NUTS(model)
        # ...now use the model in HMC or NUTS...

    This reparameterization works only for latent variables, not likelihoods.
    Note that all sites must share a single common :class:`NeuTraReparam`
    instance, and that the model must have static structure.

    [1] Hoffman, M. et al. (2019)
        "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport"
        https://arxiv.org/abs/1903.03704

    :param ~numpyro.infer.autoguide.AutoContinuous guide: A guide.
    :param params: trained parameters of the guide.
    :raises TypeError: if ``guide`` is not an ``AutoContinuous`` instance.
    :raises ValueError: if ``guide.get_transform(params)`` raises
        ``NotImplementedError`` or ``TypeError``.
    """

    def __init__(self, guide, params):
        if not isinstance(guide, AutoContinuous):
            raise TypeError(
                f"NeuTraReparam expected an AutoContinuous guide, but got {type(guide)}"
            )
        self.guide = guide
        self.params = params
        try:
            self.transform = self.guide.get_transform(params)
        except (NotImplementedError, TypeError) as e:
            raise ValueError(
                "NeuTraReparam only supports guides that implement "
                "`get_transform` method that does not depend on the "
                "model's `*args, **kwargs`"
            ) from e
        # Per-site unconstrained values still waiting to be consumed by
        # `__call__`; empty means the shared latent must be (re)sampled.
        self._x_unconstrained = {}

    def _reparam_config(self, site):
        """
        Config callback for ``numpyro.handlers.reparam``.

        :param dict site: A site dictionary with at least a ``"name"`` key.
        :return: ``self`` for sites that are unobserved in the guide's
            prototype trace, otherwise ``None``.
        """
        if site["name"] in self.guide.prototype_trace:
            # We only reparam if this is an unobserved site in the guide
            # prototype trace.
            guide_site = self.guide.prototype_trace[site["name"]]
            if not guide_site.get("is_observed", False):
                return self
        return None

    def reparam(self, fn=None):
        """
        Wrap ``fn`` in ``numpyro.handlers.reparam`` configured by
        :meth:`_reparam_config`.
        """
        return numpyro.handlers.reparam(fn, config=self._reparam_config)

    def __call__(self, name, fn, obs):
        """
        Reparameterize site ``name`` through the guide's shared latent.

        :param str name: A sample site name.
        :param fn: The site's distribution.
        :param obs: Must be None for sites in the guide's prototype trace.
        :return: ``(fn, obs)`` unchanged for sites absent from the guide's
            prototype trace, otherwise ``(None, value)``.
        """
        if name not in self.guide.prototype_trace:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        log_density = 0.0
        compute_density = numpyro.get_mask() is not False
        if not self._x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            z_unconstrained = numpyro.sample(
                f"{self.guide.prefix}_shared_latent",
                self.guide.get_base_dist().mask(False),
            )

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained
                )
            self._x_unconstrained = self.guide._unpack_latent(x_unconstrained)

        # Extract a single site's value from the shared latent.
        # `pop` consumes the entry, so the dict empties once every site
        # has been visited.
        unconstrained_value = self._x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            logdet = sum_rightmost(
                logdet, jnp.ndim(logdet) - jnp.ndim(value) + len(fn.event_shape)
            )
            log_density = log_density + fn.log_prob(value) + logdet
        numpyro.factor(f"_{name}_log_prob", log_density)
        return None, value
class CircularReparam(Reparam):
    """
    Reparametrizer for :class:`~numpyro.distributions.VonMises` latent
    variables.
    """

    def __call__(self, name, fn, obs):
        """
        Sample an unconstrained real value and wrap it into ``[-pi, pi)``.

        :param str name: A sample site name; the auxiliary sites are named
            ``"{name}_unwrapped"`` and ``"{name}_factor"``.
        :param fn: A distribution whose (base) support is
            ``constraints.circular``.
        :param obs: Observed value or None; forwarded to the auxiliary
            sample site.
        :return: ``(None, value)`` with the wrapped value.
        """
        # Support must be circular
        support = fn.support
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        assert (
            support is constraints.circular
        ), "CircularReparam only supports distributions with circular support"

        # Draw parameter-free noise.
        new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
        value = numpyro.sample(
            f"{name}_unwrapped",
            new_fn,
            obs=obs,
        )

        # Differentiably transform.
        value = jnp.remainder(value + math.pi, 2 * math.pi) - math.pi

        # Simulate a pyro.deterministic() site.
        numpyro.factor(f"{name}_factor", fn.log_prob(value))
        return None, value