Markov Chain Monte Carlo (MCMC)

class MCMC(sampler, num_warmup, num_samples, num_chains=1, postprocess_fn=None, chain_method='parallel', progress_bar=True, jit_model_args=False)[source]

Bases: object

Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro.

Note

chain_method is an experimental arg, which might be removed in a future version.

Note

Setting progress_bar=False will improve the speed for many cases.

Parameters:
  • sampler (MCMCKernel) – an instance of MCMCKernel that determines the sampler for running MCMC. Built-in kernels include HMC, NUTS, and SA; custom kernels can be defined by subclassing MCMCKernel.
  • num_warmup (int) – Number of warmup steps.
  • num_samples (int) – Number of samples to generate from the Markov chain.
  • num_chains (int) – Number of MCMC chains to run. By default, chains will be run in parallel using jax.pmap(), failing which, chains will be run in sequence.
  • postprocess_fn – Post-processing callable - used to convert a collection of unconstrained sample values returned from the sampler to constrained values that lie within the support of the sample sites. Additionally, this is used to return values at deterministic sites in the model.
  • chain_method (str) – One of ‘parallel’ (default), ‘sequential’, ‘vectorized’. The method ‘parallel’ is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for ‘parallel’, we fall back to the ‘sequential’ method to draw chains sequentially. The ‘vectorized’ method is an experimental feature which vectorizes the drawing method, hence allowing us to collect samples in parallel on a single device.
  • progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
  • jit_model_args (bool) – If set to True, this will compile the potential energy computation as a function of model arguments. As such, calling MCMC.run again on a same sized but different dataset will not result in additional compilation cost.
warmup(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)[source]

Run the MCMC warmup adaptation phase. After this call, the run() method will skip the warmup adaptation phase. To run warmup again for new data, call warmup() again.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.
  • args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
  • extra_fields (tuple or list) – Extra fields (aside from default_fields()) from the state object (e.g. numpyro.infer.mcmc.HMCState for HMC) to collect during the MCMC run.
  • collect_warmup (bool) – Whether to collect samples from the warmup phase. Defaults to False.
  • init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the keyword arguments needed by the model.
run(rng_key, *args, extra_fields=(), init_params=None, **kwargs)[source]

Run the MCMC samplers and collect samples.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used for the sampling. For multi-chains, a batch of num_chains keys can be supplied. If rng_key is not batched, it will be split into a batch of num_chains keys.
  • args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
  • extra_fields (tuple or list) – Extra fields (aside from z, diverging) from numpyro.infer.mcmc.HMCState to collect during the MCMC run.
  • init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the keyword arguments needed by the model.

Note

JAX allows Python code to continue even when the compiled code has not finished yet. This can cause trouble when profiling the code for speed. See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling JAX programs.

get_samples(group_by_chain=False)[source]

Get samples from the MCMC run.

Parameters:group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension.
Returns:Samples having the same data type as init_params. The data type is a dict keyed on site names if a model containing Pyro primitives is used, but can be any jaxlib.pytree(), more generally (e.g. when defining a potential_fn for HMC that takes list args).
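
The effect of group_by_chain on the leading dimension can be pictured with plain Python lists. This is only an illustrative sketch: the labels are made up, and the chain-major ordering of the flattened result is an assumption rather than something stated above.

```python
# Illustrative only: nested lists stand in for arrays of samples.
num_chains, num_samples = 2, 3

# group_by_chain=True: the leading dimension has size num_chains.
grouped = [
    [f"chain{c}_draw{s}" for s in range(num_samples)]
    for c in range(num_chains)
]

# group_by_chain=False: the chain dimension is merged away, leaving
# num_chains * num_samples draws (assumed here to be chain-major).
flat = [draw for chain in grouped for draw in chain]

print(len(grouped), len(grouped[0]))  # 2 3
print(len(flat))                      # 6
```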
get_extra_fields(group_by_chain=False)[source]

Get extra fields from the MCMC run.

Parameters:group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension.
Returns:Extra fields keyed by field names which are specified in the extra_fields keyword of run().
print_summary(prob=0.9, exclude_deterministic=True)[source]

MCMC Kernels

class MCMCKernel[source]

Bases: abc.ABC

Defines the interface for the Markov transition kernel that is used for MCMC inference.

Example:

>>> from collections import namedtuple
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC

>>> MHState = namedtuple("MHState", ["z", "rng_key"])

>>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
...     sample_field = "z"
...
...     def __init__(self, potential_fn, step_size=0.1):
...         self.potential_fn = potential_fn
...         self.step_size = step_size
...
...     def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
...         return MHState(init_params, rng_key)
...
...     def sample(self, state, model_args, model_kwargs):
...         z, rng_key = state
...         rng_key, key_proposal, key_accept = random.split(rng_key, 3)
...         z_proposal = dist.Normal(z, self.step_size).sample(key_proposal)
...         accept_prob = jnp.exp(self.potential_fn(z) - self.potential_fn(z_proposal))
...         z_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, z_proposal, z)
...         return MHState(z_new, rng_key)

>>> def f(x):
...     return ((x - 2) ** 2).sum()

>>> kernel = MetropolisHastings(f)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
>>> samples = mcmc.get_samples()
postprocess_fn(model_args, model_kwargs)[source]

Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.

Parameters:
  • model_args – Arguments to the model.
  • model_kwargs – Keyword arguments to the model.
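
As a rough illustration of what such a function does, the sketch below maps unconstrained values back to each site's support. The site names, the transforms dictionary, and the use of exp for a positive-valued site are assumptions made for this example; it does not reproduce NumPyro's implementation.

```python
import math


def make_postprocess_fn(transforms):
    """Return a function mapping unconstrained values to constrained ones.

    `transforms` is a hypothetical dict of per-site callables.
    """
    def postprocess(unconstrained):
        return {name: transforms[name](value)
                for name, value in unconstrained.items()}
    return postprocess


# "mu" is unconstrained (identity); "sigma" must be positive, so the
# sampler is assumed to work with log(sigma) and exp maps it back.
transforms = {"mu": lambda x: x, "sigma": math.exp}
postprocess = make_postprocess_fn(transforms)

constrained = postprocess({"mu": -0.3, "sigma": 0.0})
print(constrained)  # {'mu': -0.3, 'sigma': 1.0}
```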
init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]

Initialize the MCMCKernel and return an initial state to begin sampling from.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
  • num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
  • init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • model_args – Arguments provided to the model.
  • model_kwargs – Keyword arguments provided to the model.
Returns:

The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

sample(state, model_args, model_kwargs)[source]

Given the current state, return the next state using the given transition kernel.

Parameters:
  • state

    A pytree class representing the state for the kernel. For HMC, this is given by HMCState. In general, this could be any class that supports getattr.

  • model_args – Arguments provided to the model.
  • model_kwargs – Keyword arguments provided to the model.
Returns:

Next state.

sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

get_diagnostics_str(state)[source]

Given the current state, return the diagnostics string to be shown in the progress bar.

class HMC(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False)[source]

Bases: numpyro.infer.mcmc.MCMCKernel

Hamiltonian Monte Carlo inference, using fixed trajectory length, with provision for step size and mass matrix adaptation.

References:

  1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
Parameters:
  • model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init() has the same type.
  • kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
  • step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
  • adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
  • adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
  • dense_mass (bool) – A flag to decide if mass matrix is dense or diagonal (default when dense_mass=False)
  • target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
  • trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
  • init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
  • find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
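
To make the roles of step_size and trajectory_length concrete, here is a minimal pure-Python sketch of one HMC transition for a one-dimensional target with unit mass. It is not NumPyro's implementation: the helper names (leapfrog, hmc_step), the rule num_steps = max(1, int(trajectory_length / step_size)), and the standard normal target are assumptions chosen for illustration.

```python
import math
import random


def leapfrog(q, p, grad_u, step_size, num_steps):
    """Velocity Verlet (leapfrog) integration with unit mass."""
    p = p - 0.5 * step_size * grad_u(q)        # initial half step for momentum
    for i in range(num_steps):
        q = q + step_size * p                  # full step for position
        if i < num_steps - 1:
            p = p - step_size * grad_u(q)      # full step for momentum
    p = p - 0.5 * step_size * grad_u(q)        # final half step for momentum
    return q, p


def hmc_step(q, u, grad_u, step_size, trajectory_length, rng):
    """One HMC transition: resample momentum, integrate, accept or reject."""
    # Assumed rule for turning a trajectory length into a step count.
    num_steps = max(1, int(trajectory_length / step_size))
    p0 = rng.gauss(0.0, 1.0)
    q_new, p_new = leapfrog(q, p0, grad_u, step_size, num_steps)
    h_old = u(q) + 0.5 * p0 ** 2               # potential + kinetic energy
    h_new = u(q_new) + 0.5 * p_new ** 2
    accept_prob = min(1.0, math.exp(h_old - h_new))
    return (q_new if rng.random() < accept_prob else q), accept_prob


# Toy target: standard normal, so U(q) = q^2 / 2 and grad U(q) = q.
u = lambda q: 0.5 * q * q
grad_u = lambda q: q

rng = random.Random(0)
q, samples = 0.0, []
for _ in range(2000):
    q, _accept = hmc_step(q, u, grad_u, step_size=0.1,
                          trajectory_length=1.5, rng=rng)
    samples.append(q)
```

A smaller step_size makes the integrator more accurate, which raises the acceptance probability, but it needs more gradient evaluations to cover the same trajectory_length. The trajectory length of 1.5 is used here only because, for this particular target, a trajectory close to \(2\pi\) returns almost to its starting point.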
model
sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

get_diagnostics_str(state)[source]

Given the current state, return the diagnostics string to be shown in the progress bar.

init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]

Initialize the MCMCKernel and return an initial state to begin sampling from.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
  • num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
  • init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • model_args – Arguments provided to the model.
  • model_kwargs – Keyword arguments provided to the model.
Returns:

The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

postprocess_fn(args, kwargs)[source]

Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.

Parameters:
  • model_args – Arguments to the model.
  • model_kwargs – Keyword arguments to the model.
sample(state, model_args, model_kwargs)[source]

Run HMC from the given HMCState and return the resulting HMCState.

Parameters:
  • state (HMCState) – Represents the current state.
  • model_args – Arguments provided to the model.
  • model_kwargs – Keyword arguments provided to the model.
Returns:

Next state after running HMC.

class NUTS(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=None, max_tree_depth=10, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False)[source]

Bases: numpyro.infer.hmc.HMC

Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) with adaptive path length and mass matrix adaptation.

References:

  1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
  2. The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, and Andrew Gelman.
  3. A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters:
  • model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init_kernel has the same type.
  • kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
  • step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
  • adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
  • adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
  • dense_mass (bool) – A flag to decide if mass matrix is dense or diagonal (default when dense_mass=False)
  • target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
  • trajectory_length (float) – Length of an MCMC trajectory for HMC. This arg has no effect in the NUTS sampler.
  • max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10.
  • init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
  • find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
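
The Welford scheme mentioned for adapt_mass_matrix is an online algorithm for estimating mean and variance in a single pass. The scalar sketch below shows only the core update; the function name and the toy data are illustrative, and how NumPyro turns such estimates into a mass matrix is not shown here.

```python
def welford_update(count, mean, m2, x):
    """Fold one observation x into the running (count, mean, m2) state.

    m2 accumulates the sum of squared deviations from the running mean.
    """
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)   # uses the updated mean
    return count, mean, m2


# Toy "warmup draws" for a single coordinate.
draws = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]

count, mean, m2 = 0, 0.0, 0.0
for x in draws:
    count, mean, m2 = welford_update(count, mean, m2, x)

sample_variance = m2 / (count - 1)
print(mean, sample_variance)
```

With dense_mass=False only per-coordinate variances like this one would be needed, whereas a dense mass matrix requires a full covariance estimate.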
class SA(model=None, potential_fn=None, adapt_state_size=None, dense_mass=True, init_strategy=<function init_to_uniform>)[source]

Bases: numpyro.infer.mcmc.MCMCKernel

Sample Adaptive MCMC, a gradient-free sampler.

This is a very fast (in terms of n_eff / s) sampler but requires many warmup (burn-in) steps. In each MCMC step, we only need to evaluate the potential function at one point.

Note that unlike in reference [1], we return a randomly selected (i.e. thinned) subset of approximate posterior samples of size num_chains x num_samples instead of num_chains x num_samples x adapt_state_size.

Note

We recommend using this kernel with progress_bar=False in MCMC to reduce JAX’s dispatch overhead.

References:

  1. Sample Adaptive MCMC (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc), Michael Zhu
Parameters:
  • model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init() has the same type.
  • adapt_state_size (int) – The number of points to generate proposal distribution. Defaults to 2 times latent size.
  • dense_mass (bool) – A flag to decide if mass matrix is dense or diagonal (default to dense_mass=True)
  • init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]

Initialize the MCMCKernel and return an initial state to begin sampling from.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
  • num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
  • init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • model_args – Arguments provided to the model.
  • model_kwargs – Keyword arguments provided to the model.
Returns:

The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

sample_field

The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

get_diagnostics_str(state)[source]

Given the current state, return the diagnostics string to be shown in the progress bar.

postprocess_fn(args, kwargs)[source]

Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.

Parameters:
  • model_args – Arguments to the model.
  • model_kwargs – Keyword arguments to the model.
sample(state, model_args, model_kwargs)[source]

Run SA from the given SAState and return the resulting SAState.

Parameters:
  • state (SAState) – Represents the current state.
  • model_args – Arguments provided to the model.
  • model_kwargs – Keyword arguments provided to the model.
Returns:

Next state after running SA.

hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS')[source]

Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length.

References:

  1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
  2. The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, and Andrew Gelman.
  3. A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters:
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init_kernel has the same type.
  • potential_fn_gen – Python callable that when provided with model arguments / keyword arguments returns potential_fn. This may be provided to do inference on the same model with changing data. If the data shape remains the same, we can compile sample_kernel once, and use the same for multiple inference runs.
  • kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
  • algo (str) – Whether to run HMC with fixed number of steps or NUTS with adaptive path length. Default is NUTS.
Returns:

a tuple of callables (init_kernel, sample_kernel), the first one to initialize the sampler, and the second one to generate samples given an existing one.

Warning

Instead of using this interface directly, we strongly recommend using the higher-level numpyro.infer.MCMC API.

Example

>>> import jax
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer.hmc import hmc
>>> from numpyro.infer.util import initialize_model
>>> from numpyro.util import fori_collect

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs_mean = jnp.zeros(dim)
...     coefs = numpyro.sample('beta', dist.Normal(coefs_mean, jnp.ones(3)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,))
>>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(model_info.param_info,
...                         trajectory_length=10,
...                         num_warmup=300)
>>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
...                        transform=lambda state: model_info.postprocess_fn(state.z))
>>> print(jnp.mean(samples['beta'], axis=0))  
[0.9153987 2.0754058 2.9621222]
init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, max_tree_depth=10, find_heuristic_step_size=False, model_args=(), model_kwargs=None, rng_key=DeviceArray([0, 0], dtype=uint32))

Initializes the HMC sampler.

Parameters:
  • init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • num_warmup (int) – Number of warmup steps; samples generated during warmup are discarded.
  • step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
  • inverse_mass_matrix (numpy.ndarray) – Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix.
  • adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
  • adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
  • dense_mass (bool) – A flag to decide if mass matrix is dense or diagonal (default when dense_mass=False)
  • target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
  • trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
  • max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10.
  • find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
  • model_args (tuple) – Model arguments if potential_fn_gen is specified.
  • model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
  • rng_key (jax.random.PRNGKey) – random key to be used as the source of randomness.
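
The interplay between adapt_step_size and target_accept_prob can be sketched with the dual averaging recursion from the NUTS paper by Hoffman and Gelman. The code below is a simplified pure-Python illustration, not NumPyro's adaptation code: the constants gamma, t0 and kappa are the defaults suggested in that paper, and accept_prob_fn is a hypothetical stand-in for the acceptance probability that a real sampler would observe at a given step size.

```python
import math


def dual_averaging(accept_prob_fn, init_step_size=1.0, target_accept_prob=0.8,
                   num_warmup=500, gamma=0.05, t0=10.0, kappa=0.75):
    """Tune a step size so the acceptance probability approaches the target."""
    mu = math.log(10.0 * init_step_size)   # point the iterates shrink towards
    h_bar = 0.0                            # running average of (target - accept)
    log_eps = math.log(init_step_size)
    log_eps_bar = 0.0                      # averaged iterate, used after warmup
    for m in range(1, num_warmup + 1):
        accept_prob = accept_prob_fn(math.exp(log_eps))
        w = 1.0 / (m + t0)
        h_bar = (1.0 - w) * h_bar + w * (target_accept_prob - accept_prob)
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        eta = m ** (-kappa)
        log_eps_bar = eta * log_eps + (1.0 - eta) * log_eps_bar
    return math.exp(log_eps_bar)


# Hypothetical acceptance curve: larger steps are accepted less often.
toy_accept_prob = lambda step_size: math.exp(-step_size)

tuned = dual_averaging(toy_accept_prob)
print(tuned)  # close to -log(0.8), i.e. roughly 0.22, for this toy curve
```

Raising target_accept_prob in this sketch pushes the tuned value down, which matches the parameter description above: a higher target leads to a smaller step size.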
sample_kernel(hmc_state, model_args=(), model_kwargs=None)

Given an existing HMCState, run HMC with fixed (possibly adapted) step size and return a new HMCState.

Parameters:
  • hmc_state – Current sample (and associated state).
  • model_args (tuple) – Model arguments if potential_fn_gen is specified.
  • model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
Returns:

new proposed HMCState from simulating Hamiltonian dynamics given existing state.

HMCState = <class 'numpyro.infer.hmc.HMCState'>

A namedtuple() consisting of the following fields:

  • i - iteration. This is reset to 0 after warmup.
  • z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
  • z_grad - Gradient of potential energy w.r.t. latent sample sites.
  • potential_energy - Potential energy computed at the given value of z.
  • energy - Sum of potential energy and kinetic energy of the current state.
  • num_steps - Number of steps in the Hamiltonian trajectory (for diagnostics).
  • accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
  • mean_accept_prob - Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics).
  • diverging - A boolean value to indicate whether the current trajectory is diverging.
  • adapt_state - A HMCAdaptState namedtuple which contains adaptation information during warmup:
    • step_size - Step size to be used by the integrator in the next iteration.
    • inverse_mass_matrix - The inverse mass matrix to be used for the next iteration.
    • mass_matrix_sqrt - The square root of mass matrix to be used for the next iteration. In case of dense mass, this is the Cholesky factorization of the mass matrix.
  • rng_key - random number generator seed used for the iteration.
SAState = <class 'numpyro.infer.sa.SAState'>

A namedtuple() used in Sample Adaptive MCMC. This consists of the following fields:

  • i - iteration. This is reset to 0 after warmup.
  • z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
  • potential_energy - Potential energy computed at the given value of z.
  • accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
  • mean_accept_prob - Mean acceptance probability until current iteration during warmup or sampling (for diagnostics).
  • diverging - A boolean value to indicate whether the new sample potential energy is diverging from the current one.
  • adapt_state - A SAAdaptState namedtuple which contains adaptation information:
    • zs - The points (of size adapt_state_size) used to generate the proposal distribution.
    • pes - Potential energies of zs.
    • loc - Mean of those zs.
    • inv_mass_matrix_sqrt - If using dense mass matrix, this is Cholesky of the covariance of zs. Otherwise, this is standard deviation of those zs.
  • rng_key - random number generator seed used for the iteration.

TensorFlow Kernels

Thin wrappers around TensorFlow Probability (TFP) MCMC kernels. For details on the TFP MCMC kernel interface, see its TransitionKernel docs.

TFPKernel

class TFPKernel(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)[source]

A thin wrapper for TensorFlow Probability (TFP) MCMC transition kernels. The argument target_log_prob_fn in TFP is replaced by either model or potential_fn (which is the negative of target_log_prob_fn).

This class can be used to convert a TFP kernel to a NumPyro-compatible one as follows:

kernel = TFPKernel[tfp.mcmc.NoUTurnSampler](model, step_size=1.)

Note

By default, uncalibrated kernels will be inner kernels of the MetropolisHastings kernel.

Note

For ReplicaExchangeMC, TFP requires that the shape of step_size of the inner kernel must be [len(inverse_temperatures), 1] or [len(inverse_temperatures), latent_size].

Parameters:
  • model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
  • potential_fn – Python callable that computes the target potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init() has the same type.
  • init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
  • kernel_kwargs – other arguments to be passed to TFP kernel constructor.
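
The sign convention between the two interfaces is easy to get wrong, so here is a tiny sketch in plain Python. The standard normal target and both function bodies are illustrative assumptions; only the relation potential_fn = -target_log_prob_fn comes from the description above.

```python
import math


def target_log_prob_fn(z):
    """Log density of a standard normal (TFP-style convention)."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def potential_fn(z):
    """Potential energy (NumPyro-style convention): the negative log density."""
    return -target_log_prob_fn(z)


# Lower potential energy corresponds to higher density.
print(potential_fn(0.0) < potential_fn(1.0))  # True
```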

HamiltonianMonteCarlo

class HamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.hmc.HamiltonianMonteCarlo with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

MetropolisAdjustedLangevinAlgorithm

class MetropolisAdjustedLangevinAlgorithm(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.langevin.MetropolisAdjustedLangevinAlgorithm with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

NoUTurnSampler

class NoUTurnSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.nuts.NoUTurnSampler with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

RandomWalkMetropolis

class RandomWalkMetropolis(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.RandomWalkMetropolis with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

ReplicaExchangeMC

class ReplicaExchangeMC(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.replica_exchange_mc.ReplicaExchangeMC with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

SliceSampler

class SliceSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.slice_sampler_kernel.SliceSampler with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

UncalibratedHamiltonianMonteCarlo

class UncalibratedHamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.hmc.UncalibratedHamiltonianMonteCarlo with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

UncalibratedLangevin

class UncalibratedLangevin(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.langevin.UncalibratedLangevin with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

UncalibratedRandomWalk

class UncalibratedRandomWalk(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)

Wraps tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.UncalibratedRandomWalk with TFPKernel. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.

MCMC Utilities

initialize_model(rng_key, model, init_strategy=<function init_to_uniform>, dynamic_args=False, model_args=(), model_kwargs=None)[source]

(EXPERIMENTAL INTERFACE) Helper function that calls get_potential_fn() and find_valid_initial_params() under the hood to return a tuple of (init_params_info, potential_fn, postprocess_fn, model_trace).

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
  • model – Python callable containing Pyro primitives.
  • init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
  • dynamic_args (bool) – if True, the potential_fn and constraints_fn are themselves dependent on model arguments. When called with *model_args and **model_kwargs, they return the potential_fn and constraints_fn callables, respectively.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
Returns:

a namedtuple ModelInfo which contains the fields (param_info, potential_fn, postprocess_fn, model_trace), where param_info is a namedtuple ParamInfo containing values from the prior used to initiate MCMC, their corresponding potential energy, and their gradients; postprocess_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within the site’s support, in addition to returning values at deterministic sites in the model.

fori_collect(lower, upper, body_fun, init_val, transform=<function identity>, progbar=True, return_last_val=False, collection_size=None, **progbar_opts)[source]

This looping construct works like fori_loop() but with the additional effect of collecting values from the loop body. In addition, it allows post-processing of the collected samples via transform, as well as progress bar updates. Note that progbar=False will be faster, especially when collecting a lot of samples. Refer to the example usage in hmc().

Parameters:
  • lower (int) – the index at which to start collecting values. In other words, we will skip collecting the first lower values.
  • upper (int) – number of times to run the loop body.
  • body_fun – a callable that takes a collection of np.ndarray and returns a collection with the same shape and dtype.
  • init_val – initial value to pass as argument to body_fun. Can be any Python collection type containing np.ndarray objects.
  • transform – a callable to post-process the values returned by body_fun.
  • progbar – whether to post progress bar updates.
  • return_last_val (bool) – If True, the last value is also returned. This has the same type as init_val.
  • collection_size (int) – Size of the returned collection. If not specified, the size will be upper - lower. If the size is larger than upper - lower, only the top upper - lower entries will be non-zero.
  • **progbar_opts – optional additional progress bar arguments. A diagnostics_fn can be supplied which when passed the current value from body_fun returns a string that is used to update the progress bar postfix. Also a progbar_desc keyword argument can be supplied which is used to label the progress bar.
Returns:

collection with the same type as init_val with values collected along the leading axis of np.ndarray objects.
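The collection semantics described above can be illustrated with a minimal pure-Python sketch. This is not the JAX implementation: the helper name fori_collect_sketch is hypothetical, it collects into a plain list rather than stacking arrays along a leading axis, and it omits progress bars and collection_size. It assumes, following the parameter descriptions, that the loop body runs upper times and that the first lower values are skipped.

```python
def fori_collect_sketch(lower, upper, body_fun, init_val,
                        transform=lambda x: x, return_last_val=False):
    """Hypothetical pure-Python stand-in for the collection semantics."""
    collection = []
    val = init_val
    for i in range(upper):      # the loop body runs `upper` times
        val = body_fun(val)
        if i >= lower:          # the first `lower` values are not collected
            collection.append(transform(val))
    return (collection, val) if return_last_val else collection


# Toy chain: the state is an integer that is incremented at each step.
samples, last_val = fori_collect_sketch(
    2, 5, lambda x: x + 1, 0, transform=lambda x: 10 * x, return_last_val=True
)
```

Here samples holds upper - lower = 3 transformed values, while last_val is the untransformed final state.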

consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None)[source]

Merges subposteriors following the consensus Monte Carlo algorithm.

References:

  1. Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch
Parameters:
  • subposteriors (list) – a list in which each element is a collection of samples.
  • num_draws (int) – number of draws from the merged posterior.
  • diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
  • rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns:

if num_draws is None, merges subposteriors without resampling; otherwise, returns a collection of num_draws samples with the same data structure as each subposterior.
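As a rough illustration of the merging step without resampling, the sketch below applies the consensus weighting from the referenced paper to scalar draws: each subposterior is weighted by the inverse of its sample variance, and draws with the same index are averaged with those weights. The function name consensus_scalar is hypothetical, and the sketch assumes scalar parameters and variance-based weights (the analogue of diagonal=True); it is not the library's implementation.

```python
from statistics import variance


def consensus_scalar(subposteriors):
    """Hypothetical scalar sketch of consensus Monte Carlo merging."""
    # One weight per subposterior: the inverse of its sample variance.
    weights = [1.0 / variance(draws) for draws in subposteriors]
    total = sum(weights)
    # Merge draws index by index, up to the shortest subposterior.
    num = min(len(draws) for draws in subposteriors)
    return [
        sum(w * draws[i] for w, draws in zip(weights, subposteriors)) / total
        for i in range(num)
    ]


# Two toy subposteriors with sample variances 4.0 and 1.0.
merged = consensus_scalar([[0.0, 2.0, 4.0], [10.0, 11.0, 12.0]])
```

The tighter second subposterior receives four times the weight of the first, so the merged draws sit closer to it.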

parametric(subposteriors, diagonal=False)[source]

Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.

References:

  1. Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters:
  • subposteriors (list) – a list in which each element is a collection of samples.
  • diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
Returns:

the estimated mean and variance/covariance parameters of the joined posterior
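The parametric approach in the referenced paper approximates each subposterior by a Gaussian and multiplies the Gaussians together. For scalar parameters this reduces to a precision-weighted combination, sketched below. The function name parametric_scalar is hypothetical, and the sketch covers only the scalar, variance-based case; it is not the library's implementation.

```python
from statistics import mean, variance


def parametric_scalar(subposteriors):
    """Hypothetical scalar sketch of the product-of-Gaussians merge."""
    # Precision (inverse sample variance) of each subposterior.
    precisions = [1.0 / variance(draws) for draws in subposteriors]
    # Precisions add under a product of Gaussian densities.
    merged_var = 1.0 / sum(precisions)
    # The merged mean is the precision-weighted average of the sub-means.
    merged_mean = merged_var * sum(
        p * mean(draws) for p, draws in zip(precisions, subposteriors)
    )
    return merged_mean, merged_var


# Sub-means 2.0 and 11.0, sample variances 4.0 and 1.0.
merged_mean, merged_var = parametric_scalar(
    [[0.0, 2.0, 4.0], [10.0, 11.0, 12.0]]
)
```

The merged variance is smaller than either subposterior variance, as expected when combining independent Gaussian factors.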

parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None)[source]

Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.

References:

  1. Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters:
  • subposteriors (list) – a list in which each element is a collection of samples.
  • num_draws (int) – number of draws from the merged posterior.
  • diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
  • rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns:

a collection of num_draws samples with the same data structure as each subposterior.
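Drawing from the merged posterior can be sketched for scalar parameters as sampling from a Gaussian whose mean and variance come from a precision-weighted combination of the subposteriors. The function name parametric_draws_scalar and its seed argument are hypothetical, and Python's random module stands in for a jax.random.PRNGKey; this is an illustration under those assumptions, not the library's implementation.

```python
import random
from statistics import mean, variance


def parametric_draws_scalar(subposteriors, num_draws, seed=0):
    """Hypothetical scalar sketch: fit a merged Gaussian, then sample it."""
    precisions = [1.0 / variance(draws) for draws in subposteriors]
    merged_var = 1.0 / sum(precisions)
    merged_mean = merged_var * sum(
        p * mean(draws) for p, draws in zip(precisions, subposteriors)
    )
    rng = random.Random(seed)  # seeded generator for reproducible draws
    # random.gauss takes a standard deviation, hence the square root.
    return [rng.gauss(merged_mean, merged_var ** 0.5) for _ in range(num_draws)]


draws = parametric_draws_scalar(
    [[0.0, 2.0, 4.0], [10.0, 11.0, 12.0]], num_draws=1000
)
```

Calling the function again with the same seed reproduces the same draws.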