Markov Chain Monte Carlo (MCMC)

Hamiltonian Monte Carlo

class MCMC(sampler, num_warmup, num_samples, num_chains=1, constrain_fn=None, chain_method='parallel', progress_bar=True)

Bases: object

Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro.

Note

chain_method is an experimental arg, which might be removed in a future version.

Parameters:
  • sampler (MCMCKernel) – an instance of MCMCKernel that determines the sampler for running MCMC. Currently, only HMC and NUTS are available.
  • num_warmup (int) – Number of warmup steps.
  • num_samples (int) – Number of samples to generate from the Markov chain.
  • num_chains (int) – Number of MCMC chains to run. By default, chains will be run in parallel using jax.pmap(); if that is not possible, chains will be run in sequence.
  • constrain_fn – Callable that converts a collection of unconstrained sample values returned from the sampler to constrained values that lie within the support of the sample sites.
  • chain_method (str) – One of ‘parallel’ (default), ‘sequential’, ‘vectorized’. The ‘parallel’ method draws chains in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for ‘parallel’, we fall back to the ‘sequential’ method, which draws chains one after another. The ‘vectorized’ method is an experimental feature which vectorizes the drawing process, allowing samples to be collected in parallel on a single device.
  • progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
run(rng, *args, collect_fields=('z', ), collect_warmup=False, init_params=None, **kwargs)

Run the MCMC samplers and collect samples.

Parameters:
  • rng (random.PRNGKey) – Random number generator key to be used for the sampling.
  • args – Arguments to be provided to the numpyro.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
  • collect_fields (tuple or list) – Fields from numpyro.mcmc.HMCState to collect during the MCMC run. By default, only the latent sample values z are collected.
  • collect_warmup (bool) – Whether to collect samples from the warmup phase. Defaults to False.
  • init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • kwargs – Keyword arguments to be provided to the numpyro.mcmc.MCMCKernel.init() method. These are typically the keyword arguments needed by the model.
get_samples(group_by_chain=False)

Get samples from the MCMC run.

Parameters: group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension.
Returns: Samples having the same data type as init_params. If multiple fields are collected via the collect_fields arg to run(), then a tuple is returned with one element per field, each of the same data type. The data type for a particular field is a dict keyed on site names if a model containing Pyro primitives is used; more generally, it can be any jaxlib.pytree() (e.g. when defining a potential_fn for HMC that takes list args).
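The relationship between the grouped and flat layouts can be sketched with plain NumPy. This assumes that samples are held with a leading chain axis of shape (num_chains, num_samples, ...) and that group_by_chain=False simply merges the first two axes; the helper flatten_chains is hypothetical and not part of NumPyro.

```python
import numpy as np

def flatten_chains(samples):
    """Hypothetical helper (assumed layout): merge the (num_chains, num_samples)
    leading axes of every site into one axis of length num_chains * num_samples."""
    return {name: value.reshape((-1,) + value.shape[2:])
            for name, value in samples.items()}

num_chains, num_samples = 4, 100
grouped = {
    "beta": np.zeros((num_chains, num_samples, 3)),    # vector-valued site
    "intercept": np.zeros((num_chains, num_samples)),  # scalar site
}
flat = flatten_chains(grouped)
print(flat["beta"].shape, flat["intercept"].shape)  # (400, 3) (400,)
```

With this layout, all samples of chain 0 come first, followed by those of chain 1, and so on.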
print_summary()
class HMC(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586)

Bases: numpyro.mcmc.MCMCKernel

Hamiltonian Monte Carlo inference, using fixed trajectory length, with provision for step size and mass matrix adaptation.

References:

  1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
Parameters:
  • model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
  • kinetic_fn – Python callable that returns the kinetic energy given the inverse mass matrix and momentum. If not provided, the default is the Euclidean kinetic energy.
  • step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
  • adapt_step_size (bool) – Whether to adapt step_size during the warmup phase using the Dual Averaging scheme.
  • adapt_mass_matrix (bool) – Whether to adapt the mass matrix during the warmup phase using the Welford scheme.
  • dense_mass (bool) – Whether the mass matrix is dense. If False (default), a diagonal mass matrix is used.
  • target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence sampling will be slower but more robust. Defaults to 0.8.
  • trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
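To make step_size, trajectory_length and the Euclidean kinetic energy concrete, here is a minimal NumPy sketch of one HMC transition with a diagonal inverse mass matrix and a fixed trajectory length. It is an illustration under stated assumptions, not NumPyro's implementation: the number of leapfrog (Verlet) steps is assumed to be ceil(trajectory_length / step_size), no adaptation is performed, and potential_fn and grad_fn are user-supplied functions on a flat array.

```python
import numpy as np

def kinetic_energy(inverse_mass_matrix, r):
    # Euclidean kinetic energy for a diagonal inverse mass matrix: 0.5 * r^T M^{-1} r.
    return 0.5 * np.sum(inverse_mass_matrix * r ** 2)

def leapfrog(z, r, grad_fn, step_size, inverse_mass_matrix, num_steps):
    # Velocity Verlet (leapfrog) integration of Hamiltonian dynamics.
    r = r - 0.5 * step_size * grad_fn(z)              # initial half step for momentum
    for i in range(num_steps):
        z = z + step_size * inverse_mass_matrix * r   # full step for position
        if i < num_steps - 1:
            r = r - step_size * grad_fn(z)            # full step for momentum
    r = r - 0.5 * step_size * grad_fn(z)              # final half step for momentum
    return z, r

def hmc_step(rng, z, potential_fn, grad_fn, step_size, trajectory_length,
             inverse_mass_matrix=None):
    if inverse_mass_matrix is None:
        inverse_mass_matrix = np.ones_like(z)
    # Assumption: a fixed number of steps derived from the trajectory length.
    num_steps = max(1, int(np.ceil(trajectory_length / step_size)))
    # Momentum r ~ N(0, M) with M = diag(1 / inverse_mass_matrix).
    r = rng.normal(size=z.shape) / np.sqrt(inverse_mass_matrix)
    energy = potential_fn(z) + kinetic_energy(inverse_mass_matrix, r)
    z_new, r_new = leapfrog(z, r, grad_fn, step_size, inverse_mass_matrix, num_steps)
    energy_new = potential_fn(z_new) + kinetic_energy(inverse_mass_matrix, r_new)
    # Metropolis correction for the discretization error of the integrator.
    accept_prob = float(np.exp(min(0.0, energy - energy_new)))
    if rng.uniform() < accept_prob:
        return z_new, accept_prob
    return z, accept_prob

# Usage: sample a 2-D standard normal (potential energy 0.5 * |z|^2).
rng = np.random.default_rng(0)
z, samples = np.zeros(2), []
for _ in range(3000):
    z, _ = hmc_step(rng, z, lambda x: 0.5 * np.sum(x ** 2), lambda x: x,
                    step_size=0.1, trajectory_length=1.5)
    samples.append(z)
samples = np.array(samples)
```

For this particular target, a trajectory_length equal to a full period of the dynamics (\(2\pi\) for a standard normal with unit mass) would carry the chain back near its starting point, which is why the usage lines pick a shorter trajectory.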
init(rng, num_warmup, init_params=None, model_args=(), model_kwargs={})
sample(state)

Run HMC from the given HMCState and return the resulting HMCState.

Parameters: state (HMCState) – Represents the current state.
Returns: Next state after running HMC.
class NUTS(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, max_tree_depth=10)

Bases: numpyro.mcmc.HMC

Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) with adaptive path length and mass matrix adaptation.

References:

  1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
  2. The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman and Andrew Gelman.
  3. A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters:
  • model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
  • kinetic_fn – Python callable that returns the kinetic energy given the inverse mass matrix and momentum. If not provided, the default is the Euclidean kinetic energy.
  • step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
  • adapt_step_size (bool) – Whether to adapt step_size during the warmup phase using the Dual Averaging scheme.
  • adapt_mass_matrix (bool) – Whether to adapt the mass matrix during the warmup phase using the Welford scheme.
  • dense_mass (bool) – Whether the mass matrix is dense. If False (default), a diagonal mass matrix is used.
  • target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence sampling will be slower but more robust. Defaults to 0.8.
  • trajectory_length (float) – Length of an MCMC trajectory for HMC. This argument has no effect in the NUTS sampler, which adapts the path length instead. Default value is \(2\pi\).
  • max_tree_depth (int) – Maximum depth of the binary tree created during the doubling scheme of the NUTS sampler. Defaults to 10.
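The Dual Averaging scheme used for adapt_step_size and target_accept_prob can be sketched as follows, following Algorithm 5 of Hoffman and Gelman (reference 2 above). The constants gamma=0.05, t0=10 and kappa=0.75 are that paper's suggested values and are assumptions here, not necessarily NumPyro's; the toy acceptance function is a hypothetical stand-in for running a trajectory. The sketch also reproduces the documented behaviour that a higher target_accept_prob yields a smaller step size.

```python
import math

class DualAveraging:
    """Sketch of dual-averaging step size adaptation (Hoffman & Gelman,
    Algorithm 5). Constants are the paper's suggestions (an assumption here)."""

    def __init__(self, step_size=1.0, target_accept_prob=0.8,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * step_size)  # shrinkage target for log step size
        self.target = target_accept_prob
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.h_bar = 0.0            # running average of (target - accept_prob)
        self.log_step_avg = 0.0     # weighted average of log step sizes
        self.m = 0                  # number of updates so far
        self.step_size = step_size  # step size to use in the next iteration

    def update(self, accept_prob):
        self.m += 1
        m = self.m
        w = 1.0 / (m + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target - accept_prob)
        log_step = self.mu - math.sqrt(m) / self.gamma * self.h_bar
        eta = m ** (-self.kappa)
        self.log_step_avg = eta * log_step + (1.0 - eta) * self.log_step_avg
        self.step_size = math.exp(log_step)

    def final_step_size(self):
        # After warmup, the averaged step size is used for sampling.
        return math.exp(self.log_step_avg)

def toy_accept_prob(step_size):
    # Hypothetical stand-in for an HMC trajectory: acceptance falls as the step grows.
    return math.exp(-step_size)

def adapt(target, num_warmup=2000):
    da = DualAveraging(step_size=1.0, target_accept_prob=target)
    for _ in range(num_warmup):
        da.update(toy_accept_prob(da.step_size))
    return da.final_step_size()

eps_08, eps_09 = adapt(0.8), adapt(0.9)
```

For this toy acceptance function the ideal step size for a target \(\delta\) is \(-\log\delta\), so adapt(0.8) should land near 0.22 and adapt(0.9) near 0.11.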
hmc(potential_fn, kinetic_fn=None, algo='NUTS')

Hamiltonian Monte Carlo inference, using either a fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length.

References:

  1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
  2. The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman and Andrew Gelman.
  3. A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters:
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
  • kinetic_fn – Python callable that returns the kinetic energy given the inverse mass matrix and momentum. If not provided, the default is the Euclidean kinetic energy.
  • algo (str) – Whether to run HMC with a fixed number of steps ('HMC') or NUTS with adaptive path length ('NUTS'). Defaults to 'NUTS'.
Returns:

a tuple of callables (init_kernel, sample_kernel): the first initializes the sampler, and the second generates a new sample given an existing state.

Warning

Instead of using this interface directly, we highly recommend using the higher-level numpyro.mcmc.MCMC API.

Example

>>> import jax.numpy as np
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.mcmc import hmc
>>> from numpyro.util import fori_collect
>>> # initialize_model is documented under "MCMC Utilities" below; import it
>>> # from the NumPyro module that provides it in your installed version.
>>>
>>> true_coefs = np.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs_mean = np.zeros(dim)
...     coefs = numpyro.sample('beta', dist.Normal(coefs_mean, np.ones(dim)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
...                                                            model, data, labels)
>>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(init_params,
...                         trajectory_length=10,
...                         num_warmup=300)
>>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
...                        transform=lambda state: constrain_fn(state.z))
>>> print(np.mean(samples['beta'], axis=0))
[0.9153987 2.0754058 2.9621222]
init_kernel(init_params, num_warmup, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, max_tree_depth=10, run_warmup=True, progbar=True, rng=jax.random.PRNGKey(0))

Initializes the HMC sampler.

Parameters:
  • init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • num_warmup (int) – Number of warmup steps; samples generated during warmup are discarded.
  • step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
  • adapt_step_size (bool) – Whether to adapt step_size during the warmup phase using the Dual Averaging scheme.
  • adapt_mass_matrix (bool) – Whether to adapt the mass matrix during the warmup phase using the Welford scheme.
  • dense_mass (bool) – Whether the mass matrix is dense. If False (default), a diagonal mass matrix is used.
  • target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence sampling will be slower but more robust. Defaults to 0.8.
  • trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
  • max_tree_depth (int) – Maximum depth of the binary tree created during the doubling scheme of the NUTS sampler. Defaults to 10.
  • run_warmup (bool) – Whether to run warmup. If True, init_kernel returns an initial HMCState that can be used to generate samples using MCMC. Otherwise, it returns the arguments and the callable that perform the initial adaptation.
  • progbar (bool) – Whether to enable progress bar updates. Defaults to True.
  • rng (jax.random.PRNGKey) – random key to be used as the source of randomness.
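Mass matrix adaptation with the Welford scheme relies on Welford's online algorithm, which updates a running mean and variance one draw at a time without storing the warmup samples. The NumPy sketch below estimates a diagonal inverse mass matrix as the sample variance of warmup draws. It is a sketch under assumptions: the class name is hypothetical, and NumPyro may additionally split warmup into adaptation windows or regularize the estimate.

```python
import numpy as np

class WelfordVariance:
    """Hypothetical helper: Welford's online algorithm for a per-dimension
    sample variance, used here as a diagonal inverse mass matrix estimate."""

    def __init__(self, dim):
        self.n = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros(dim)  # running sum of squared deviations from the mean

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.n
        self.m2 = self.m2 + delta * (x - self.mean)

    def variance(self):
        # Unbiased sample variance (requires at least two updates).
        return self.m2 / (self.n - 1)

# Usage: stand-in "warmup draws" with very different scales per coordinate.
rng = np.random.default_rng(1)
draws = rng.normal(loc=[0.0, 5.0], scale=[1.0, 3.0], size=(500, 2))
welford = WelfordVariance(dim=2)
for x in draws:
    welford.update(x)
inverse_mass_matrix = welford.variance()
```

Using the estimated posterior variance as the inverse mass matrix rescales each coordinate to roughly unit scale, which lets a single step size suit all dimensions.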
sample_kernel(hmc_state)

Given an existing HMCState, run HMC with fixed (possibly adapted) step size and return a new HMCState.

Parameters: hmc_state – Current sample (and associated state).
Returns: New HMCState proposed by simulating Hamiltonian dynamics from the existing state.
HMCState = <class 'numpyro.mcmc.HMCState'>

A namedtuple() consisting of the following fields:

  • i - iteration. This is reset to 0 after warmup.
  • z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
  • z_grad - Gradient of potential energy w.r.t. latent sample sites.
  • potential_energy - Potential energy computed at the given value of z.
  • num_steps - Number of steps in the Hamiltonian trajectory (for diagnostics).
  • accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
  • mean_accept_prob - Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics).
  • diverging - A boolean value to indicate whether the current trajectory is diverging.
  • adapt_state - An AdaptState namedtuple which contains adaptation information during warmup:
    • step_size - Step size to be used by the integrator in the next iteration.
    • inverse_mass_matrix - The inverse mass matrix to be used for the next iteration.
    • mass_matrix_sqrt - The square root of mass matrix to be used for the next iteration. In case of dense mass, this is the Cholesky factorization of the mass matrix.
  • rng - random number generator seed used for the iteration.
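For a dense mass matrix M, mass_matrix_sqrt is described above as its Cholesky factor L, with M = L L^T. One common use of such a factor, sketched below in NumPy under that assumption, is drawing momentum r ~ N(0, M) as r = L z with z standard normal, while the kinetic energy uses inverse_mass_matrix = M^{-1}. The variable names mirror the AdaptState fields, but the code is an illustration, not NumPyro's internals.

```python
import numpy as np

# A dense inverse mass matrix, e.g. an adapted posterior covariance estimate.
inverse_mass_matrix = np.array([[2.0, 0.6],
                                [0.6, 1.0]])
mass_matrix = np.linalg.inv(inverse_mass_matrix)
# Cholesky factor L of the mass matrix, so that mass_matrix == L @ L.T.
mass_matrix_sqrt = np.linalg.cholesky(mass_matrix)

rng = np.random.default_rng(2)
z = rng.standard_normal(size=(50000, 2))
# Each row is r = L z, which has covariance L L^T = M, i.e. momentum ~ N(0, M).
momentum = z @ mass_matrix_sqrt.T

def kinetic_energy(r):
    # Euclidean kinetic energy 0.5 * r^T M^{-1} r.
    return 0.5 * r @ inverse_mass_matrix @ r
```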

MCMC Utilities

initialize_model(rng, model, *model_args, init_strategy=init_to_uniform, **model_kwargs)

Given a model with Pyro primitives, returns a function which, given unconstrained parameters, evaluates the potential energy (negative log joint density). In addition, this also returns initial parameters (generated according to init_strategy) to initiate MCMC sampling, and a function to transform unconstrained values at sample sites to constrained values within their respective support.

Parameters:
  • rng (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng.shape[:-1].
  • model – Python callable containing Pyro primitives.
  • *model_args – args provided to the model.
  • init_strategy (callable) – a per-site initialization function.
  • **model_kwargs – kwargs provided to the model.
Returns:

a tuple (init_params, potential_fn, constrain_fn), where init_params are the initial values (generated according to init_strategy) used to initiate MCMC, potential_fn computes the potential energy given unconstrained values, and constrain_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within each site’s support.
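The pair potential_fn / constrain_fn can be illustrated by writing both by hand for a small hypothetical model: sigma ~ Exponential(1) and y_i ~ Normal(0, sigma), with HMC working on the unconstrained value u = log(sigma). initialize_model derives such functions automatically from the Pyro primitives; the hand-written versions below are only a sketch showing that the potential energy in unconstrained space includes the log-Jacobian of the transform.

```python
import numpy as np

y = np.array([0.5, -1.2, 0.3, 2.0])  # hypothetical observations

# Hypothetical model: sigma ~ Exponential(1), y_i ~ Normal(0, sigma).
# HMC works on the unconstrained value u = log(sigma).

def constrain_fn(params):
    # Inverse transform: maps the real line back to the support sigma > 0.
    return {"sigma": np.exp(params["sigma"])}

def potential_fn(params):
    # Negative log joint density, expressed in unconstrained space.
    u = params["sigma"]
    sigma = np.exp(u)
    log_prior = -sigma                      # log density of Exponential(1)
    log_lik = np.sum(-0.5 * (y / sigma) ** 2 - np.log(sigma)
                     - 0.5 * np.log(2 * np.pi))
    log_det_jacobian = u                    # log |d sigma / d u| for sigma = exp(u)
    return -(log_prior + log_lik + log_det_jacobian)

init_params = {"sigma": 0.0}  # unconstrained value, i.e. sigma = 1
```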

fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=True, **progbar_opts)

This looping construct works like fori_loop() but with the additional effect of collecting values from the loop body. In addition, this allows for post-processing of these samples via transform, and progress bar updates. Note that progbar=False will be faster, especially when collecting a lot of samples. Refer to the example usage in hmc().

Parameters:
  • lower (int) – the index at which to start collecting values; the first lower values are not collected.
  • upper (int) – number of times to run the loop body.
  • body_fun – a callable that takes a collection of np.ndarray and returns a collection with the same shape and dtype.
  • init_val – initial value to pass as argument to body_fun. Can be any Python collection type containing np.ndarray objects.
  • transform – a callable to post-process the values returned by body_fun.
  • progbar – whether to post progress bar updates.
  • **progbar_opts – optional additional progress bar arguments. A diagnostics_fn can be supplied which, given the current value from body_fun, returns a string that is used to update the progress bar postfix. A progbar_desc keyword argument can also be supplied to label the progress bar.
Returns:

collection with the same type as init_val with values collected along the leading axis of np.ndarray objects.
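The collection semantics described above can be mimicked with a plain Python loop. The sketch below uses a hypothetical fori_collect_sketch with no progress bar and no JIT compilation: it runs body_fun upper times, keeps the transformed values from iteration lower onward, and stacks them along a new leading axis. Applying transform before storing each value is an assumption consistent with the hmc() example, where the collected samples are indexed by site name.

```python
import numpy as np

def _stack(values):
    # Stack a list of arrays, or a list of dicts of arrays, along a new axis 0.
    if isinstance(values[0], dict):
        return {k: np.stack([v[k] for v in values]) for k in values[0]}
    return np.stack(values)

def fori_collect_sketch(lower, upper, body_fun, init_val, transform=lambda x: x):
    # Hypothetical pure-Python stand-in for fori_collect (no progress bar, no jit).
    val, collected = init_val, []
    for i in range(upper):
        val = body_fun(val)
        if i >= lower:  # skip the first `lower` values
            collected.append(transform(val))
    return _stack(collected)

# Usage: a toy loop body that increments a counter; keep iterations 3..9.
out = fori_collect_sketch(3, 10, lambda x: x + 1, np.zeros(2),
                          transform=lambda x: {"double": 2 * x})
```

Here out["double"] has shape (7, 2): upper - lower collected values, each of the same shape as init_val.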

consensus(subposteriors, num_draws=None, diagonal=False, rng=None)

Merges subposteriors following the consensus Monte Carlo algorithm.

References:

  1. Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch
Parameters:
  • subposteriors (list) – a list in which each element is a collection of samples.
  • num_draws (int) – number of draws from the merged posterior.
  • diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
  • rng (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns:

if num_draws is None, merges subposteriors without resampling; otherwise, returns a collection of num_draws samples with the same data structure as each subposterior.
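The merge step of consensus Monte Carlo combines draws index by index, weighting each subposterior by its inverse covariance: \(\theta_g = (\sum_s W_s)^{-1} \sum_s W_s \theta_{s,g}\). The NumPy sketch below assumes each subposterior is a flat array of shape (num_samples, dim) with the same num_samples; unlike the documented function, the hypothetical consensus_sketch does not handle general collections and does not resample to num_draws.

```python
import numpy as np

def consensus_sketch(subposteriors, diagonal=False):
    """Hypothetical helper. Draw g of the result is
    (sum_s W_s)^{-1} sum_s W_s theta_{s,g}, where W_s is the inverse covariance
    (or inverse variance when diagonal=True) of subposterior s."""
    if diagonal:
        weights = [1.0 / np.var(s, axis=0, ddof=1) for s in subposteriors]
        weighted_sum = sum(w * s for w, s in zip(weights, subposteriors))
        return weighted_sum / sum(weights)
    # atleast_2d guards against np.cov returning a 0-d array when dim == 1.
    weights = [np.linalg.inv(np.atleast_2d(np.cov(s, rowvar=False)))
               for s in subposteriors]
    # Row g of (s @ W.T) equals W_s theta_{s,g}, written as a row vector.
    weighted_sum = sum(s @ w.T for w, s in zip(weights, subposteriors))
    return weighted_sum @ np.linalg.inv(sum(weights)).T

rng = np.random.default_rng(3)
sub_a = rng.normal(0.0, 1.0, size=(5000, 1))  # subposterior with variance 1
sub_b = rng.normal(3.0, 0.5, size=(5000, 1))  # subposterior with variance 0.25
merged = consensus_sketch([sub_a, sub_b])
# Precision weighting pulls the merged draws toward the tighter subposterior:
# their mean should be close to (1 * 0 + 4 * 3) / (1 + 4) = 2.4.
```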

parametric(subposteriors, diagonal=False)

Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.

References:

  1. Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters:
  • subposteriors (list) – a list in which each element is a collection of samples.
  • diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
Returns:

the estimated mean and variance/covariance parameters of the joined posterior
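The parametric approach fits a Gaussian to each subposterior and multiplies the Gaussians, which gives precision \(\Sigma^{-1} = \sum_s \Sigma_s^{-1}\) and mean \(\mu = \Sigma \sum_s \Sigma_s^{-1} \mu_s\). The NumPy sketch below is a hypothetical parametric_sketch that assumes each subposterior is a flat array of shape (num_samples, dim) rather than a general collection.

```python
import numpy as np

def parametric_sketch(subposteriors, diagonal=False):
    """Hypothetical helper: fit a Gaussian to each subposterior and return the
    mean and covariance (or variance when diagonal=True) of their normalized product."""
    means = [s.mean(axis=0) for s in subposteriors]
    if diagonal:
        precisions = [1.0 / s.var(axis=0, ddof=1) for s in subposteriors]
        var = 1.0 / sum(precisions)
        return var * sum(p * m for p, m in zip(precisions, means)), var
    precisions = [np.linalg.inv(np.atleast_2d(np.cov(s, rowvar=False)))
                  for s in subposteriors]
    cov = np.linalg.inv(sum(precisions))                        # (sum_s Sigma_s^{-1})^{-1}
    mean = cov @ sum(p @ m for p, m in zip(precisions, means))  # cov @ sum_s Sigma_s^{-1} mu_s
    return mean, cov

rng = np.random.default_rng(5)
sub_a = rng.normal([0.0, 0.0], 1.0, size=(5000, 2))  # approximately N(0, I)
sub_b = rng.normal([3.0, 3.0], 0.5, size=(5000, 2))  # approximately N(3, 0.25 I)
mean, cov = parametric_sketch([sub_a, sub_b])
# Expected: mean close to [2.4, 2.4] and covariance close to 0.2 * I.
```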

parametric_draws(subposteriors, num_draws, diagonal=False, rng=None)

Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.

References:

  1. Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters:
  • subposteriors (list) – a list in which each element is a collection of samples.
  • num_draws (int) – number of draws from the merged posterior.
  • diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
  • rng (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns:

a collection of num_draws samples with the same data structure as each subposterior.
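Drawing from the merged posterior amounts to sampling the Gaussian produced by the parametric merge. Assuming a merged mean and covariance are already available (for instance as returned by parametric()), the hypothetical parametric_draws_sketch below samples num_draws points with the Cholesky factor of the covariance. It works on flat arrays only, whereas the documented function returns draws with the same data structure as each subposterior.

```python
import numpy as np

def parametric_draws_sketch(mean, cov, num_draws, rng):
    # Hypothetical helper: sample num_draws points from N(mean, cov) via Cholesky.
    chol = np.linalg.cholesky(cov)
    z = rng.standard_normal(size=(num_draws, mean.shape[0]))
    return mean + z @ chol.T  # each row is mean + L z, with covariance L L^T = cov

rng = np.random.default_rng(4)
mean = np.array([2.4, -1.0])
cov = np.array([[0.2, 0.05],
                [0.05, 0.1]])
draws = parametric_draws_sketch(mean, cov, 20000, rng)
```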