Reparameterizers

The numpyro.infer.reparam module contains reparameterization strategies for the numpyro.handlers.reparam effect handler. These are useful for altering the geometry of a poorly-conditioned parameter space to make the posterior better shaped. They can be used with a variety of inference algorithms, e.g. Auto*Normal guides and MCMC.

class Reparam

Bases: ABC

Base class for reparameterizers.

Loc-Scale Decentering

class LocScaleReparam(centered=None, shape_params=())

Bases: Reparam

Generic decentering reparameterizer [1] for latent variables parameterized by loc and scale (and possibly additional shape_params).

This reparameterization works only for latent variables, not likelihoods.

References:

  1. Automatic Reparameterisation of Probabilistic Programs, Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

Parameters:
  • centered (float) – optional centered parameter. If None (default), learn a per-site per-element centering parameter in [0,1] initialized at value 0.5. To sample the parameter, consider using the lift handler with a prior like Uniform(0, 1) to cast the parameter to a latent variable. If 0, fully decenter the distribution; if 1, preserve the centered distribution unchanged.

  • shape_params (tuple or list) – list of additional parameter names to copy unchanged from the centered to decentered distribution.
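The role of centered can be made concrete for a Normal site x ~ N(μ, σ), where N(m, s) denotes a Normal with mean m and standard deviation s. Writing c for the centering value, the following is a sketch of the partial decentering described in [1]. We assume here that it matches the transform this class applies to the loc and scale parameters.

```latex
\tilde{x} \sim \mathcal{N}\!\left(c\,\mu,\ \sigma^{c}\right),
\qquad
x = \mu + \sigma^{1-c}\left(\tilde{x} - c\,\mu\right)

% Check: \tilde{x} - c\mu \sim \mathcal{N}(0, \sigma^{c})
%   \Rightarrow \sigma^{1-c}(\tilde{x} - c\mu) \sim \mathcal{N}(0, \sigma)
%   \Rightarrow x \sim \mathcal{N}(\mu, \sigma) \quad \text{for every } c \in [0, 1].
```

At c = 1, the auxiliary variable is x itself, which is the centered model. At c = 0, it is standard Normal noise and x = μ + σ x̃, which is fully decentered.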

__call__(name, fn, obs)
Parameters:
  • name (str) – A sample site name.

  • fn (Distribution) – A distribution.

  • obs (numpy.ndarray) – Observed value or None.

Returns:

A pair (new_fn, value).

Neural Transport

class NeuTraReparam(guide, params)

Bases: Reparam

Neural Transport reparameterizer [1] of multiple latent variables.

This uses a trained AutoContinuous guide to alter the geometry of a model, typically for use in MCMC. Example usage:

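from numpyro.infer import NUTS, SVI
from numpyro.infer.autoguide import AutoIAFNormal
from numpyro.infer.reparam import NeuTraReparam

# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
svi_result = svi.run(...)  # ...train the guide...

# Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide, svi_result.params)  # params is required
model = neutra.reparam(model)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...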

This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common NeuTraReparam instance, and that the model must have static structure.

[1] NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport, M. Hoffman et al. (2019), https://arxiv.org/abs/1903.03704

Parameters:
  • guide (AutoContinuous) – A guide.

  • params – trained parameters of the guide.

reparam(fn=None)
__call__(name, fn, obs)
Parameters:
  • name (str) – A sample site name.

  • fn (Distribution) – A distribution.

  • obs (numpy.ndarray) – Observed value or None.

Returns:

A pair (new_fn, value).

transform_sample(latent)

Given latent samples from the warped posterior (with possible batch dimensions), return a dict of samples from the latent sites in the model.

Parameters:

latent – sample from the warped posterior (possibly batched).

Returns:

a dict of samples keyed by latent sites in the model.

Return type:

dict

Transformed Distributions

class TransformReparam

Bases: Reparam

Reparameterizer for TransformedDistribution latent variables.

This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has a simple shape in the space of base_dist.

This reparameterization works only for latent variables, not likelihoods.

__call__(name, fn, obs)
Parameters:
  • name (str) – A sample site name.

  • fn (Distribution) – A distribution.

  • obs (numpy.ndarray) – Observed value or None.

Returns:

A pair (new_fn, value).

Projected Normal Distributions

class ProjectedNormalReparam

Bases: Reparam

Reparameterizer for ProjectedNormal latent variables.

This reparameterization works only for latent variables, not likelihoods.

__call__(name, fn, obs)
Parameters:
  • name (str) – A sample site name.

  • fn (Distribution) – A distribution.

  • obs (numpy.ndarray) – Observed value or None.

Returns:

A pair (new_fn, value).

Circular Distributions

class CircularReparam

Bases: Reparam

Reparameterizer for VonMises latent variables.

__call__(name, fn, obs)
Parameters:
  • name (str) – A sample site name.

  • fn (Distribution) – A distribution.

  • obs (numpy.ndarray) – Observed value or None.

Returns:

A pair (new_fn, value).

Explicit Reparameterization

class ExplicitReparam(transform)

Bases: Reparam

Explicit reparameterizer of a latent variable x to a transformed space y = transform(x) with more amenable geometry. This reparameterizer is similar to TransformReparam but allows reparameterizations to be decoupled from the model declaration.

Parameters:

transform – Bijective transform to the reparameterized space.

Example:

>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import handlers, distributions as dist
>>> from numpyro.infer import MCMC, NUTS
>>> from numpyro.infer.reparam import ExplicitReparam
>>>
>>> def model():
...    numpyro.sample("x", dist.Gamma(4, 4))
>>>
>>> # Sample in unconstrained space using a soft-plus instead of exp transform.
>>> reparam = ExplicitReparam(dist.transforms.SoftplusTransform().inv)
>>> reparametrized = handlers.reparam(model, {"x": reparam})
>>> kernel = NUTS(model=reparametrized)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
>>> mcmc.run(random.PRNGKey(2))  
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2306.47it/s, 3 steps of size 9.65e-01. acc. prob=0.93]

__call__(name, fn, obs)
Parameters:
  • name (str) – A sample site name.

  • fn (Distribution) – A distribution.

  • obs (numpy.ndarray) – Observed value or None.

Returns:

A pair (new_fn, value).