Reparameterizers
The numpyro.infer.reparam module contains reparameterization strategies for the numpyro.handlers.reparam effect. These are useful for altering the geometry of a poorly conditioned parameter space so that the posterior is better shaped. They can be used with a variety of inference algorithms, e.g. Auto*Normal guides and MCMC.
Loc-Scale Decentering
- class LocScaleReparam(centered=None, shape_params=())
Bases: Reparam
Generic decentering reparameterizer [1] for latent variables parameterized by loc and scale (and possibly additional shape_params).
This reparameterization works only for latent variables, not likelihoods.
References:
- [1] Gorinova, M. I., Moore, D., Hoffman, M. D. (2019) “Automatic Reparameterisation of Probabilistic Programs”
- Parameters:
centered (float) – optional centered parameter. If None (default), learn a per-site, per-element centering parameter in [0, 1], initialized at 0.5. To sample the parameter instead, consider using the lift handler with a prior such as Uniform(0, 1) to cast the parameter to a latent variable. If 0, fully decenter the distribution; if 1, preserve the centered distribution unchanged.
shape_params (tuple or list) – list of additional parameter names to copy unchanged from the centered to the decentered distribution.
- __call__(name, fn, obs)
- Parameters:
name (str) – A sample site name.
fn (Distribution) – A distribution.
obs (numpy.ndarray) – Observed value or None.
- Returns:
A pair (new_fn, value).
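The decentering can be illustrated in plain NumPy for a Normal(loc, scale) site. Following the partial non-centering of [1], a centering value c replaces the site with an auxiliary variable z ~ Normal(c * loc, scale ** c) and recovers x = loc + scale ** (1 - c) * (z - c * loc); c = 0 is the classic non-centered form x = loc + scale * z, and c = 1 leaves the draw unchanged. This is a sketch of the math only: the helper names decentered_params and recenter are hypothetical, and this is not the library's code.

```python
import numpy as np


def decentered_params(loc, scale, c):
    """Hypothetical helper: loc/scale of the auxiliary Normal that replaces Normal(loc, scale)."""
    return c * loc, scale ** c


def recenter(z, loc, scale, c):
    """Hypothetical helper: map an auxiliary draw z back to the centered variable x."""
    return loc + scale ** (1 - c) * (z - c * loc)


rng = np.random.default_rng(0)
loc, scale = 3.0, 0.1
for c in (0.0, 0.5, 1.0):
    aux_loc, aux_scale = decentered_params(loc, scale, c)
    z = rng.normal(aux_loc, aux_scale, size=200_000)
    x = recenter(z, loc, scale, c)
    # For every c, x should follow Normal(loc, scale): mean close to 3.0, std close to 0.1.
    print(f"c={c}: mean={x.mean():.3f}, std={x.std():.3f}")
```

The distribution of x is the same for every c; what changes is the space the sampler or guide works in. With c = 0 the auxiliary variable has unit scale regardless of scale, which is what removes funnel-like coupling between a latent and its scale parameter.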
Neural Transport
- class NeuTraReparam(guide, params)
Bases: Reparam
Neural Transport reparameterizer [1] of multiple latent variables.
This uses a trained AutoContinuous guide to alter the geometry of a model, typically for use in MCMC. Example usage:

```python
# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
svi_result = svi.run(rng_key, num_steps)  # ...train the guide...

# Step 2. Use the trained guide and its parameters in NeuTra MCMC
neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(model)
nuts = NUTS(neutra_model)
# ...now use the model in HMC or NUTS...
```
This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common NeuTraReparam instance, and that the model must have static structure.
- [1] Hoffman, M. et al. (2019) “NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport” https://arxiv.org/abs/1903.03704
- Parameters:
guide (AutoContinuous) – A guide.
params – trained parameters of the guide.
- __call__(name, fn, obs)
- Parameters:
name (str) – A sample site name.
fn (Distribution) – A distribution.
obs (numpy.ndarray) – Observed value or None.
- Returns:
A pair (new_fn, value).
- transform_sample(latent)
Given latent samples from the warped posterior (with possible batch dimensions), return a dict of samples from the latent sites in the model.
- Parameters:
latent – sample from the warped posterior (possibly batched).
- Returns:
a dict of samples keyed by latent sites in the model.
- Return type:
dict
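The effect of neural transport can be checked on a toy case in NumPy. In this sketch the "posterior" is a correlated Gaussian, and an affine map x = mu + L z stands in for the trained AutoContinuous guide (an assumption: NeuTraReparam uses the guide's own learned transform, e.g. the IAF of AutoIAFNormal). Pulling the model density back through the map and adding log|det J| gives the warped density the sampler explores; when the guide fits perfectly, that density is exactly a standard normal. The names model_log_density, transform_sample and warped_log_density here are hypothetical NumPy functions, not the NeuTraReparam methods.

```python
import numpy as np

# Toy "posterior": a correlated Gaussian N(mu, cov), awkward for HMC with a unit mass matrix.
mu = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.9], [0.9, 1.0]])
# Hypothetical stand-in for a trained guide: the affine map x = mu + L @ z,
# where L is the Cholesky factor of cov (i.e. a perfectly fitted Gaussian guide).
L = np.linalg.cholesky(cov)
log_abs_det_jacobian = np.sum(np.log(np.diag(L)))  # log|det L| = 0.5 * log det cov


def model_log_density(x):
    """log N(x; mu, cov) for a single (unbatched) x, via the Cholesky factor."""
    white = np.linalg.solve(L, x - mu)  # L^{-1} (x - mu)
    d = x.shape[-1]
    return -0.5 * white @ white - 0.5 * d * np.log(2 * np.pi) - log_abs_det_jacobian


def transform_sample(z):
    """Map (possibly batched) warped-space samples z to model-space samples x."""
    return mu + z @ L.T


def warped_log_density(z):
    """Model density pulled back through the map, plus log|det J|, for a single z."""
    return model_log_density(transform_sample(z)) + log_abs_det_jacobian
```

With a real model the guide only approximates the posterior, so the warped density is close to, rather than exactly, standard normal; the remaining mismatch is what NUTS explores in the better-conditioned space, and samples are mapped back with a transform_sample-style step.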
Transformed Distributions
- class TransformReparam
Bases: Reparam
Reparameterizer for TransformedDistribution latent variables.
This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has a simple shape in the space of base_dist.
This reparameterization works only for latent variables, not likelihoods.
- __call__(name, fn, obs)
- Parameters:
name (str) – A sample site name.
fn (Distribution) – A distribution.
obs (numpy.ndarray) – Observed value or None.
- Returns:
A pair (new_fn, value).
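Conceptually, this reparameterizer replaces a draw from a transformed distribution with a draw from its base distribution, followed by a deterministic application of the transforms in order, so the sampler only ever sees the base variable. The NumPy sketch below mimics that for a hypothetical site whose base is Normal(0, 1.5) and whose transforms are an affine map followed by exp; the names transforms and reparam_sample are illustrative only, not NumPyro API.

```python
import numpy as np

# Hypothetical toy site: x ~ TransformedDistribution(Normal(mu, sigma), transforms),
# with two transforms applied in order: an affine map, then exp.
mu, sigma = 0.0, 1.5
transforms = [lambda z: 2.0 * z + 1.0, np.exp]


def reparam_sample(rng, size):
    """Draw the base value (the variable a sampler would see), then push it
    through each transform in order to obtain the model-space value x."""
    value = rng.normal(mu, sigma, size=size)  # simple Gaussian geometry
    for t in transforms:
        value = t(value)
    return value


rng = np.random.default_rng(1)
x = reparam_sample(rng, 100_000)
# log(x) = 2 * z + 1 with z ~ Normal(0, 1.5), so log(x) should look like Normal(1.0, 3.0).
print(np.log(x).mean(), np.log(x).std())
```

In model space x is heavily skewed, while the base variable is a plain Gaussian; that difference in geometry is the reason to sample in base_dist space.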
Projected Normal Distributions
- class ProjectedNormalReparam
Bases: Reparam
Reparameterizer for ProjectedNormal latent variables.
This reparameterization works only for latent variables, not likelihoods.
- __call__(name, fn, obs)
- Parameters:
name (str) – A sample site name.
fn (Distribution) – A distribution.
obs (numpy.ndarray) – Observed value or None.
- Returns:
A pair (new_fn, value).
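ProjectedNormal(concentration) is the distribution of y / ||y|| for y ~ Normal(concentration, I). A reparameterization along these lines lets the sampler work with unconstrained Gaussian noise in R^d instead of a point constrained to the unit sphere. The NumPy sketch below shows that mapping; the function name projected_normal_reparam and the clamping of the norm away from zero are assumptions of the sketch, not a description of NumPyro's internals.

```python
import numpy as np


def projected_normal_reparam(concentration, noise):
    """Hypothetical helper: map standard-normal noise to a point on the unit sphere.

    y = concentration + noise is Normal(concentration, I); normalizing y gives a
    ProjectedNormal(concentration) draw. The norm is clamped to avoid division by zero.
    """
    y = concentration + noise
    norm = np.linalg.norm(y, axis=-1, keepdims=True)
    return y / np.maximum(norm, np.finfo(y.dtype).tiny)


rng = np.random.default_rng(0)
concentration = np.array([4.0, 0.0, 0.0])
noise = rng.standard_normal(size=(10_000, 3))  # the unconstrained variables a sampler would see
x = projected_normal_reparam(concentration, noise)
print(x.mean(axis=0))  # directions cluster around the concentration direction [1, 0, 0]
```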
Circular Distributions
Explicit Reparameterization
- class ExplicitReparam(transform)
Bases: Reparam
Explicit reparameterizer of a latent variable x to a transformed space y = transform(x) with more amenable geometry. This reparameterizer is similar to TransformReparam but allows reparameterizations to be decoupled from the model declaration.
- Parameters:
transform – Bijective transform to the reparameterized space.
Example:
```python
>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import handlers, distributions as dist
>>> from numpyro.infer import MCMC, NUTS
>>> from numpyro.infer.reparam import ExplicitReparam
>>>
>>> def model():
...     numpyro.sample("x", dist.Gamma(4, 4))
>>>
>>> # Sample in unconstrained space using a soft-plus instead of exp transform.
>>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
>>> reparametrized = handlers.reparam(model, {"x": reparam})
>>> kernel = NUTS(model=reparametrized)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
>>> mcmc.run(random.PRNGKey(2))
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]
```
- __call__(name, fn, obs)
- Parameters:
name (str) – A sample site name.
fn (Distribution) – A distribution.
obs (numpy.ndarray) – Observed value or None.
- Returns:
A pair (new_fn, value).
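The soft-plus example above can be checked by hand. The sampler works with y = softplus^{-1}(x), so its target density is the Gamma density evaluated at softplus(y) plus log|dx/dy| = log sigmoid(y). The NumPy sketch below writes that density out for the Gamma(4, 4) site (assuming the concentration/rate parameterization) and verifies numerically that it integrates to one. The helper names are hypothetical, and this illustrates the change of variables rather than how ExplicitReparam is implemented.

```python
import math

import numpy as np


def softplus(y):
    """log(1 + exp(y)), computed stably."""
    return np.logaddexp(0.0, y)


def inv_softplus(x):
    """log(exp(x) - 1) for x > 0, computed stably."""
    return x + np.log(-np.expm1(-x))


def gamma_log_prob(x, concentration=4.0, rate=4.0):
    """Log density of Gamma(concentration, rate) at x > 0."""
    return (concentration * math.log(rate) - math.lgamma(concentration)
            + (concentration - 1.0) * np.log(x) - rate * x)


def reparam_log_prob(y):
    """Log density of y = inv_softplus(x) when x ~ Gamma(4, 4)."""
    x = softplus(y)
    log_abs_det_jacobian = -softplus(-y)  # log sigmoid(y) = log |dx/dy|
    return gamma_log_prob(x) + log_abs_det_jacobian


# Numerically check that the density in y space integrates to one (Riemann sum).
ys = np.linspace(-30.0, 10.0, 400_001)
total = np.sum(np.exp(reparam_log_prob(ys))) * (ys[1] - ys[0])
print(total)  # close to 1.0
```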