Markov Chain Monte Carlo (MCMC)
We provide a high-level overview of the MCMC algorithms in NumPyro:
NUTS, which is an adaptive variant of HMC, is probably the most commonly used MCMC algorithm in NumPyro. Note that NUTS and HMC are not directly applicable to models with discrete latent variables, but in cases where the discrete variables have finite support and summing them out (i.e. enumeration) is tractable, NumPyro will automatically sum out discrete latent variables and perform NUTS/HMC on the remaining continuous latent variables. As discussed above, model reparameterization may be important in some cases to get good performance. Note that, generally speaking, we expect inference to be harder as the dimension of the latent space increases. See the bad geometry tutorial for additional tips and tricks.
MixedHMC can be an effective inference strategy for models that contain both continuous and discrete latent variables.
HMCECS can be an effective inference strategy for models with a large number of data points. It is applicable to models with continuous latent variables. See this example for detailed usage.
BarkerMH is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables.
HMCGibbs combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
DiscreteHMCGibbs combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
SA is a gradient-free MCMC method. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a very large number of samples, as mixing tends to be slow. On the plus side, individual steps can be fast.
AIES is a gradient-free ensemble MCMC method that informs Metropolis-Hastings proposals by sharing information between chains. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities, and it can be robust for likelihood-free models. AIES generally requires the number of chains to be at least twice the number of latent parameters (and ideally larger).
ESS is a gradient-free ensemble MCMC method that shares information between chains to find good slice sampling directions. It tends to be more sample efficient than AIES. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate and may be a good choice for models with non-differentiable log densities. ESS generally requires the number of chains to be at least twice the number of latent parameters (and ideally larger).
Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see restrictions). Enumerated sites need to be marked with infer={'enumerate': 'parallel'} like in the annotation example; a minimal sketch follows below.
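For instance, a minimal sketch of a model with a discrete latent site marked for enumeration and sampled with NUTS (the mixture weights and locations are made-up values, and automatic enumeration in MCMC may require the funsor package to be installed):

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    # Discrete site with finite support, marked for enumeration so the
    # sampler only needs to handle the continuous latent variable "x".
    c = numpyro.sample(
        "c",
        dist.Categorical(jnp.array([0.3, 0.7])),
        infer={"enumerate": "parallel"},
    )
    locs = jnp.array([-1.0, 1.0])
    numpyro.sample("x", dist.Normal(locs[c], 1.0))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0))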
- class MCMC(sampler, *, num_warmup, num_samples, num_chains=1, thinning=1, postprocess_fn=None, chain_method='parallel', progress_bar=True, jit_model_args=False)[source]
Bases:
object
Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro.
Note
chain_method is an experimental arg, which might be removed in a future version.
Note
Setting progress_bar=False will improve the speed for many cases, but it might require more memory than running with the progress bar enabled.
Note
If setting num_chains greater than 1 in a Jupyter Notebook, then you will need to have installed ipywidgets in the environment from which you launched Jupyter in order for the progress bars to render correctly. If you are using Jupyter Notebook or Jupyter Lab, please also install the corresponding extension package like widgetsnbextension or jupyterlab_widgets.
Note
If your dataset is large and you have access to multiple acceleration devices, you can distribute the computation across multiple devices. Make sure that your jax version is v0.4.4 or newer. For example,
import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

X = np.random.randn(128, 3)
y = np.random.randn(128)

def model(X, y):
    beta = numpyro.sample("beta", dist.Normal(0, 1).expand([3]))
    numpyro.sample("obs", dist.Normal(X @ beta, 1), obs=y)

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
# See https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
X_shard = jax.device_put(X, sharding.reshape(8, 1))
y_shard = jax.device_put(y, sharding.reshape(8))
mcmc.run(jax.random.PRNGKey(0), X_shard, y_shard)
- Parameters:
sampler (MCMCKernel) – an instance of MCMCKernel that determines the sampler for running MCMC. Currently, only HMC and NUTS are available.
num_warmup (int) – Number of warmup steps.
num_samples (int) – Number of samples to generate from the Markov chain.
thinning (int) – Positive integer that controls the fraction of post-warmup samples that are retained. For example, if thinning is 2 then every other sample is retained. Defaults to 1, i.e. no thinning.
num_chains (int) – Number of MCMC chains to run. By default, chains will be run in parallel using jax.pmap(). If there are not enough devices available, chains will be run in sequence.
postprocess_fn – Post-processing callable used to convert a collection of unconstrained sample values returned from the sampler to constrained values that lie within the support of the sample sites. Additionally, this is used to return values at deterministic sites in the model.
chain_method (str) – A callable jax transform like jax.vmap, or one of 'parallel' (default), 'sequential', 'vectorized'. The 'parallel' method executes the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for 'parallel', we fall back to the 'sequential' method to draw chains sequentially. The 'vectorized' method is an experimental feature which vectorizes the drawing method, hence allowing us to collect samples in parallel on a single device.
progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
jit_model_args (bool) – If set to True, this will compile the potential energy computation as a function of model arguments. As such, calling MCMC.run again on a same-sized but different dataset will not result in additional compilation cost (see the sketch after this list). Note that currently, this does not take effect for the case num_chains > 1 and chain_method == 'parallel'.
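For instance, a minimal sketch of reusing the compiled potential energy across same-shaped datasets (the model and the X_batch_*/y_batch_* arrays are hypothetical placeholders):

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500, jit_model_args=True)
mcmc.run(jax.random.PRNGKey(0), X_batch_1, y_batch_1)
# A second run on same-shaped (but different) data reuses the compiled computation.
mcmc.run(jax.random.PRNGKey(1), X_batch_2, y_batch_2)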
Note
It is possible to mix parallel and vectorized sampling, i.e., run vectorized chains on multiple devices using explicit pmap. Currently, doing so requires disabling the progress bar. For example,
def do_mcmc(rng_key, n_vectorized=8):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=1000,
        num_samples=1000,
        progress_bar=False,
        num_chains=n_vectorized,
        chain_method='vectorized'
    )
    mcmc.run(
        rng_key,
        extra_fields=("potential_energy",),
    )
    return {**mcmc.get_samples(), **mcmc.get_extra_fields()}

# Number of devices to pmap over
n_parallel = jax.local_device_count()
rng_keys = jax.random.split(PRNGKey(rng_seed), n_parallel)
traces = pmap(do_mcmc)(rng_keys)
# concatenate traces along pmap'ed axis
trace = {k: np.concatenate(v) for k, v in traces.items()}
- property post_warmup_state
The state before the sampling phase. If this attribute is not None,
run()
will skip the warmup phase and start with the state specified in this attribute.
Note
This attribute can be used to sequentially draw MCMC samples. For example,
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0))
first_100_samples = mcmc.get_samples()
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)  # or mcmc.run(random.PRNGKey(1))
second_100_samples = mcmc.get_samples()
- property last_state
The final MCMC state at the end of the sampling phase.
- warmup(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)[source]
Run the MCMC warmup adaptation phase. After this call, self.post_warmup_state will be set and the
run()
method will skip the warmup adaptation phase. To run warmup again for new data, it is required to run warmup() again. A usage sketch follows after this parameter list.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.
args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
extra_fields (tuple or list) – Extra fields (aside from default_fields()) from the state object (e.g. numpyro.infer.hmc.HMCState for HMC) to collect during the MCMC run. Exclude sample sites from collection with "~`sampler.sample_field`.`sample_site`", e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler. To collect samples of a site "a" in the unconstrained space, we can specify the variable here, e.g. extra_fields=("z.a",).
collect_warmup (bool) – Whether to collect samples from the warmup phase. Defaults to False.
init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn provided to the kernel. If the kernel is instantiated by a numpyro model, the initial parameters here correspond to latent values in unconstrained space.
kwargs – Keyword arguments to be provided to the
numpyro.infer.mcmc.MCMCKernel.init()
method. These are typically the keyword arguments needed by the model.
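A minimal sketch of splitting adaptation and sampling with this method (assuming a model and data are already defined):

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
# Run only the warmup adaptation phase; post_warmup_state is set afterwards.
mcmc.warmup(jax.random.PRNGKey(0), data, collect_warmup=True)
warmup_draws = mcmc.get_samples()  # warmup draws, since collect_warmup=True
# The subsequent run() skips adaptation and performs only the sampling phase.
mcmc.run(jax.random.PRNGKey(1), data)
posterior_draws = mcmc.get_samples()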
- run(rng_key, *args, extra_fields=(), init_params=None, **kwargs)[source]
Run the MCMC samplers and collect samples.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to be used for the sampling. For multiple chains, a batch of num_chains keys can be supplied. If rng_key does not have batch_size, it will be split into a batch of num_chains keys.
args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
extra_fields (tuple or list of str) – Extra fields (aside from "z", "diverging") from the state object (e.g. numpyro.infer.hmc.HMCState for HMC) to be collected during the MCMC run. Note that subfields can be accessed using dots, e.g. "adapt_state.step_size" can be used to collect step sizes at each step. Exclude sample sites from collection with "~`sampler.sample_field`.`sample_site`", e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler. To collect samples of a site "a" in the unconstrained space, we can specify the variable here, e.g. extra_fields=("z.a",). A usage sketch follows after this parameter list.
init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn provided to the kernel. If the kernel is instantiated by a numpyro model, the initial parameters here correspond to latent values in unconstrained space.
kwargs – Keyword arguments to be provided to the
numpyro.infer.mcmc.MCMCKernel.init()
method. These are typically the keyword arguments needed by the model.
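For example, a short sketch of collecting extra diagnostics during a run with a NUTS kernel (assuming mcmc and data as above):

mcmc.run(
    jax.random.PRNGKey(0),
    data,
    extra_fields=("potential_energy", "adapt_state.step_size"),
)
extras = mcmc.get_extra_fields()
# extras is a dict with one entry per collected field, e.g. extras["potential_energy"].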
Note
jax allows python code to continue even when the compiled code has not finished yet. This can cause troubles when trying to profile the code for speed. See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs.
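For example, a rough timing sketch that waits for the asynchronously dispatched work to finish before reading the clock (a sketch only; jax.block_until_ready is applied to the pytree of samples):

import time

start = time.time()
mcmc.run(jax.random.PRNGKey(0), data)
jax.block_until_ready(mcmc.get_samples())  # wait for all queued device work
print(f"sampling took {time.time() - start:.1f}s")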
- get_samples(group_by_chain=False)[source]
Get samples from the MCMC run.
- Parameters:
group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension.
- Returns:
Samples having the same data type as init_params. The data type is a dict keyed on site names if a model containing Pyro primitives is used, but can be any
jaxlib.pytree()
, more generally (e.g. when defining a potential_fn for HMC that takes list args).
Example:
You can then pass those samples to Predictive:
posterior_samples = mcmc.get_samples()
predictive = Predictive(model, posterior_samples=posterior_samples)
samples = predictive(rng_key1, *model_args, **model_kwargs)
MCMC Kernels
MCMCKernel
- class MCMCKernel[source]
Bases:
ABC
Defines the interface for the Markov transition kernel that is used for
MCMC
inference.
Example:
>>> from collections import namedtuple
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC

>>> MHState = namedtuple("MHState", ["u", "rng_key"])

>>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
...     sample_field = "u"
...
...     def __init__(self, potential_fn, step_size=0.1):
...         self.potential_fn = potential_fn
...         self.step_size = step_size
...
...     def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
...         return MHState(init_params, rng_key)
...
...     def sample(self, state, model_args, model_kwargs):
...         u, rng_key = state
...         rng_key, key_proposal, key_accept = random.split(rng_key, 3)
...         u_proposal = dist.Normal(u, self.step_size).sample(key_proposal)
...         accept_prob = jnp.exp(self.potential_fn(u) - self.potential_fn(u_proposal))
...         u_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, u_proposal, u)
...         return MHState(u_new, rng_key)

>>> def f(x):
...     return ((x - 2) ** 2).sum()

>>> kernel = MetropolisHastings(f)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
>>> posterior_samples = mcmc.get_samples()
>>> mcmc.print_summary()
- postprocess_fn(model_args, model_kwargs)[source]
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
- abstract init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- abstract sample(state, model_args, model_kwargs)[source]
Given the current state, return the next state using the given transition kernel.
- property sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- property default_fields
The attributes of the state object to be collected by default during the MCMC run (when
MCMC.run()
is called).
- property is_ensemble_kernel
Denotes whether the kernel is an ensemble kernel. If True, diagnostics_str will be displayed during the MCMC run (when
MCMC.run()
is called) if chain_method = “vectorized”.
BarkerMH
- class BarkerMH(model=None, potential_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.4, init_strategy=<function init_to_uniform>)[source]
Bases:
MCMCKernel
This is a gradient-based MCMC algorithm of Metropolis-Hastings type that uses a skew-symmetric proposal distribution that depends on the gradient of the potential (the Barker proposal; see reference [1]). In particular the proposal distribution is skewed in the direction of the gradient at the current sample.
We expect this algorithm to be particularly effective for low to moderate dimensional models, where it may be competitive with HMC and NUTS.
Note
We recommend using this kernel with progress_bar=False in MCMC to reduce JAX's dispatch overhead.
References:
The Barker proposal: combining robustness and efficiency in gradient-based MCMC. Samuel Livingstone, Giacomo Zanella.
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init() has the same type.
step_size (float) – (Initial) step size to use in the Barker proposal.
adapt_step_size (bool) – Whether to adapt the step size during warm-up. Defaults to adapt_step_size==True.
adapt_mass_matrix (bool) – Whether to adapt the mass matrix during warm-up. Defaults to adapt_mass_matrix==True.
dense_mass (bool) – Whether to use a dense (i.e. full-rank) or diagonal mass matrix. Defaults to dense_mass=False.
target_accept_prob (float) – The target acceptance probability that is used to guide step size adaptation. Defaults to target_accept_prob=0.4.
init_strategy (callable) – a per-site initialization function. See the Initialization Strategies section for available functions.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, BarkerMH

>>> def model():
...     x = numpyro.sample("x", dist.Normal().expand([10]))
...     numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
>>>
>>> kernel = BarkerMH(model)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=True)
>>> mcmc.run(jax.random.PRNGKey(0))
>>> mcmc.print_summary()
- property model
- property sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- get_diagnostics_str(state)[source]
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- postprocess_fn(args, kwargs)[source]
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
HMC
- class HMC(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, num_steps=None, trajectory_length=6.283185307179586, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]
Bases:
MCMCKernel
Hamiltonian Monte Carlo inference, using fixed trajectory length, with provision for step size and mass matrix adaptation.
Note
Until the kernel is used in an MCMC run, postprocess_fn will return the identity function.
Note
The default init strategy init_to_uniform might not be a good strategy for some models. You might want to try other init strategies like init_to_median.
References:
MCMC Using Hamiltonian Dynamics, Radford M. Neal
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init() has the same type.
kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is Euclidean kinetic energy.
step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
inverse_mass_matrix (numpy.ndarray or dict) – Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with jax.tree_flatten, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If model is not None, here we can specify a structured block mass matrix as a dictionary, where keys are tuple of site names and values are the corresponding block of the mass matrix. For more information about structured mass matrix, see dense_mass argument.
adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
dense_mass (bool or list) – This flag controls whether the mass matrix is dense (i.e. full-rank) or diagonal (defaults to dense_mass=False). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables "x", "y", "z" (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows (see also the sketch after this parameter list):
dense_mass=[("x", "y")]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z
dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z)
dense_mass=[(“x”, “y”, “z”)] (equivalent to full_mass=True): use a dense mass matrix for the joint (x, y, z)
dense_mass=[("x",), ("y",), ("z",)]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks)
target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
num_steps (int) – if different than None, fix the number of steps allowed for each iteration.
trajectory_length (float) – Length of a MCMC trajectory for HMC. Default value is \(2\pi\).
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
find_heuristic_step_size (bool) – whether or not to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode but the forward mode can be useful in some cases to improve the performance. In addition, some control flow utilities in JAX such as jax.lax.while_loop or jax.lax.fori_loop only support forward-mode differentiation. See JAX's The Autodiff Cookbook for more information.
regularize_mass_matrix (bool) – whether or not to regularize the estimated mass matrix for numerical stability during the warmup phase. Defaults to True. This flag does not take effect if adapt_mass_matrix == False.
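As a brief sketch of the structured dense_mass specification described above (assuming a model with latent sites "x", "y", and "z"):

# Dense mass-matrix block over (x, y), diagonal entries for z.
kernel = HMC(model, dense_mass=[("x", "y")])
# Fully dense mass matrix over all latent sites.
kernel = HMC(model, dense_mass=True)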
- property model
- property sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- property default_fields
The attributes of the state object to be collected by default during the MCMC run (when
MCMC.run()
is called).
- get_diagnostics_str(state)[source]
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- postprocess_fn(args, kwargs)[source]
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
- sample(state, model_args, model_kwargs)[source]
Run HMC from the given HMCState and return the resulting HMCState.
- Parameters:
state (HMCState) – Represents the current state.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
Next state after running HMC.
NUTS
- class NUTS(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=None, max_tree_depth=10, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]
Bases:
HMC
Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) with adaptive path length and mass matrix adaptation.
Note
Until the kernel is used in an MCMC run, postprocess_fn will return the identity function.
Note
The default init strategy init_to_uniform might not be a good strategy for some models. You might want to try other init strategies like init_to_median.
References:
MCMC Using Hamiltonian Dynamics, Radford M. Neal
The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, and Andrew Gelman.
A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init_kernel has the same type.
kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
inverse_mass_matrix (numpy.ndarray or dict) – Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with jax.tree_flatten, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If model is not None, here we can specify a structured block mass matrix as a dictionary, where keys are tuple of site names and values are the corresponding block of the mass matrix. For more information about structured mass matrix, see dense_mass argument.
adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
dense_mass (bool or list) – This flag controls whether the mass matrix is dense (i.e. full-rank) or diagonal (defaults to dense_mass=False). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables "x", "y", "z" (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows:
dense_mass=[("x", "y")]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z
dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z)
dense_mass=[(“x”, “y”, “z”)] (equivalent to full_mass=True): use a dense mass matrix for the joint (x, y, z)
dense_mass=[("x",), ("y",), ("z",)]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks)
target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
trajectory_length (float) – Length of a MCMC trajectory for HMC. This arg has no effect in NUTS sampler.
max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of integers (d1, d2), where d1 is the max tree depth during warmup phase and d2 is the max tree depth during post warmup phase.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
find_heuristic_step_size (bool) – whether or not to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode but the forward mode can be useful in some cases to improve the performance. In addition, some control flow utilities in JAX such as jax.lax.while_loop or jax.lax.fori_loop only support forward-mode differentiation. See JAX's The Autodiff Cookbook for more information.
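For instance, a minimal sketch (assuming a model function is already defined) using the max_tree_depth and target_accept_prob parameters described above, with a shallower tree during warmup:

kernel = NUTS(model, max_tree_depth=(8, 10), target_accept_prob=0.9)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0))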
HMCGibbs
- class HMCGibbs(inner_kernel, gibbs_fn, gibbs_sites)[source]
Bases:
MCMCKernel
[EXPERIMENTAL INTERFACE]
HMC-within-Gibbs. This inference algorithm allows the user to combine general purpose gradient-based inference (HMC or NUTS) with custom Gibbs samplers.
Note that it is the user’s responsibility to provide a correct implementation of gibbs_fn that samples from the corresponding posterior conditional.
- Parameters:
inner_kernel – One of HMC or NUTS.
gibbs_fn – A Python callable that returns a dictionary of Gibbs samples conditioned on the HMC sites. Must include an argument rng_key that should be used for all sampling. Must also include arguments hmc_sites and gibbs_sites, each of which is a dictionary with keys that are site names and values that are sample values. Note that a given gibbs_fn may not need to make use of all these sample values.
gibbs_sites (list) – a list of site names for the latent variables that are covered by the Gibbs sampler.
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS, HMCGibbs
...
>>> def model():
...     x = numpyro.sample("x", dist.Normal(0.0, 2.0))
...     y = numpyro.sample("y", dist.Normal(0.0, 2.0))
...     numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))
...
>>> def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
...     y = hmc_sites['y']
...     new_x = dist.Normal(0.8 * (1 - y), jnp.sqrt(0.8)).sample(rng_key)
...     return {'x': new_x}
...
>>> hmc_kernel = NUTS(model)
>>> kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
>>> mcmc = MCMC(kernel, num_warmup=100, num_samples=100, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0))
>>> mcmc.print_summary()
- sample_field = 'z'
- property model
- get_diagnostics_str(state)[source]
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- postprocess_fn(args, kwargs)[source]
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
DiscreteHMCGibbs
- class DiscreteHMCGibbs(inner_kernel, *, random_walk=False, modified=False)[source]
Bases:
HMCGibbs
[EXPERIMENTAL INTERFACE]
A subclass of
HMCGibbs
which performs Metropolis updates for discrete latent sites.
Note
The site update order is randomly permuted at each step.
Note
This class supports enumeration of discrete latent variables. To marginalize out a discrete latent site, we can specify the infer={'enumerate': 'parallel'} keyword in its corresponding sample() statement.
- Parameters:
random_walk (bool) – If False, Gibbs sampling will be used to draw a sample from the conditional p(gibbs_site | remaining sites). Otherwise, a sample will be drawn uniformly from the domain of gibbs_site. Defaults to False.
modified (bool) – whether to use a modified proposal, as suggested in reference [1], which always proposes a new state for the current Gibbs site. Defaults to False. The modified scheme appears in the literature under the name “modified Gibbs sampler” or “Metropolised Gibbs sampler”.
References:
Peskun’s theorem and a modified discrete-state Gibbs sampler, Liu, J. S. (1996)
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import DiscreteHMCGibbs, MCMC, NUTS
...
>>> def model(probs, locs):
...     c = numpyro.sample("c", dist.Categorical(probs))
...     numpyro.sample("x", dist.Normal(locs[c], 0.5))
...
>>> probs = jnp.array([0.15, 0.3, 0.3, 0.25])
>>> locs = jnp.array([-2, 0, 2, 4])
>>> kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0), probs, locs)
>>> mcmc.print_summary()
>>> samples = mcmc.get_samples()["x"]
>>> assert abs(jnp.mean(samples) - 1.3) < 0.1
>>> assert abs(jnp.var(samples) - 4.36) < 0.5
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
MixedHMC
- class MixedHMC(inner_kernel, *, num_discrete_updates=None, random_walk=False, modified=False)[source]
Bases:
DiscreteHMCGibbs
Implementation of Mixed Hamiltonian Monte Carlo (reference [1]).
Note
The number of discrete sites to update at each MCMC iteration (n_D in reference [1]) is fixed at value 1.
References
Mixed Hamiltonian Monte Carlo for Mixed Discrete and Continuous Variables, Guangyao Zhou (2020)
Peskun’s theorem and a modified discrete-state Gibbs sampler, Liu, J. S. (1996)
- Parameters:
inner_kernel – A HMC kernel.
num_discrete_updates (int) – Number of times to update discrete variables. Defaults to the number of discrete latent variables.
random_walk (bool) – If False, Gibbs sampling will be used to draw a sample from the conditional p(gibbs_site | remaining sites), where gibbs_site is one of the discrete sample sites in the model. Otherwise, a sample will be drawn uniformly from the domain of gibbs_site. Defaults to False.
modified (bool) – whether to use a modified proposal, as suggested in reference [2], which always proposes a new state for the current Gibbs site (i.e. discrete site). Defaults to False. The modified scheme appears in the literature under the name “modified Gibbs sampler” or “Metropolised Gibbs sampler”.
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import HMC, MCMC, MixedHMC
...
>>> def model(probs, locs):
...     c = numpyro.sample("c", dist.Categorical(probs))
...     numpyro.sample("x", dist.Normal(locs[c], 0.5))
...
>>> probs = jnp.array([0.15, 0.3, 0.3, 0.25])
>>> locs = jnp.array([-2, 0, 2, 4])
>>> kernel = MixedHMC(HMC(model, trajectory_length=1.2), num_discrete_updates=20)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0), probs, locs)
>>> mcmc.print_summary()
>>> samples = mcmc.get_samples()
>>> assert "x" in samples and "c" in samples
>>> assert abs(jnp.mean(samples["x"]) - 1.3) < 0.1
>>> assert abs(jnp.var(samples["x"]) - 4.36) < 0.5
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
HMCECS
- class HMCECS(inner_kernel, *, num_blocks=1, proxy=None)[source]
Bases:
HMCGibbs
[EXPERIMENTAL INTERFACE]
HMC with Energy Conserving Subsampling.
A subclass of HMCGibbs for performing HMC-within-Gibbs for models with subsample statements using the plate primitive. This implements Algorithm 1 of reference [1] but uses a naive estimation (without control variates) of the log likelihood, hence it might incur a high variance.
The function can divide subsample indices into blocks and update only one block at each MCMC step to improve the acceptance rate of proposed subsamples, as detailed in [3].
Note
New subsample indices are proposed randomly with replacement at each MCMC step.
References:
Hamiltonian Monte Carlo with energy conserving subsampling, Dang, K. D., Quiroz, M., Kohn, R., Minh-Ngoc, T., & Villani, M. (2019)
Speeding Up MCMC by Efficient Data Subsampling, Quiroz, M., Kohn, R., Villani, M., & Tran, M. N. (2018)
The Block Pseudo-Marginal Sampler, Tran, M.-N., Kohn, R., Quiroz, M., Villani, M. (2017)
The Fundamental Incompatibility of Scalable Hamiltonian Monte Carlo and Naive Data Subsampling Betancourt, M. (2015)
- Parameters:
num_blocks (int) – Number of blocks to partition subsample into.
proxy – Either taylor_proxy() for likelihood estimation, or None for naive (in-between trajectory) subsampling as outlined in [4]. A sketch of using a Taylor proxy follows the example below.
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import HMCECS, MCMC, NUTS
...
>>> def model(data):
...     x = numpyro.sample("x", dist.Normal(0, 1))
...     with numpyro.plate("N", data.shape[0], subsample_size=100):
...         batch = numpyro.subsample(data, event_dim=0)
...         numpyro.sample("obs", dist.Normal(x, 1), obs=batch)
...
>>> data = random.normal(random.PRNGKey(0), (10000,)) + 1
>>> kernel = HMCECS(NUTS(model), num_blocks=10)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), data)
>>> samples = mcmc.get_samples()["x"]
>>> assert abs(jnp.mean(samples) - 1.) < 0.1
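A sketch of supplying a Taylor proxy built around reference parameters (the map_params values here are hypothetical, e.g. obtained from an SVI/MAP fit), which can reduce the variance of the subsampled likelihood estimate:

>>> map_params = {"x": jnp.array(1.0)}  # hypothetical reference point
>>> proxy = HMCECS.taylor_proxy(map_params)
>>> kernel = HMCECS(NUTS(model), num_blocks=10, proxy=proxy)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(1), data)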
- postprocess_fn(args, kwargs)[source]
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
SA
- class SA(model=None, potential_fn=None, adapt_state_size=None, dense_mass=True, init_strategy=<function init_to_uniform>)[source]
Bases:
MCMCKernel
Sample Adaptive MCMC, a gradient-free sampler.
This is a very fast sampler (in terms of n_eff / s) but it requires many warmup (burn-in) steps. In each MCMC step, we only need to evaluate the potential function at one point.
Note that unlike in reference [1], we return a randomly selected (i.e. thinned) subset of approximate posterior samples of size num_chains x num_samples instead of num_chains x num_samples x adapt_state_size.
Note
We recommend using this kernel with progress_bar=False in MCMC to reduce JAX's dispatch overhead.
References:
Sample Adaptive MCMC (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc), Michael Zhu
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init() has the same type.
adapt_state_size (int) – The number of points used to generate the proposal distribution. Defaults to 2 times the latent size.
dense_mass (bool) – A flag to decide if the mass matrix is dense or diagonal (defaults to dense_mass=True).
init_strategy (callable) – a per-site initialization function. See the Initialization Strategies section for available functions.
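A minimal usage sketch (assuming a model function is already defined); note the large number of warmup steps and the disabled progress bar recommended above:

kernel = SA(model)
mcmc = MCMC(kernel, num_warmup=10000, num_samples=10000, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0))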
- init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- property model
- property sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- property default_fields
The attributes of the state object to be collected by default during the MCMC run (when
MCMC.run()
is called).
- get_diagnostics_str(state)[source]
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- postprocess_fn(args, kwargs)[source]
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
- sample(state, model_args, model_kwargs)[source]
Run SA from the given SAState and return the resulting SAState.
- Parameters:
state (SAState) – Represents the current state.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
Next state after running SA.
EnsembleSampler
- class EnsembleSampler(model=None, potential_fn=None, *, randomize_split, init_strategy)[source]
Bases:
MCMCKernel, ABC
Abstract class for ensemble samplers. Each MCMC sample is divided into two sub-iterations in which half of the ensemble is updated.
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init() has the same type.
randomize_split (bool) – whether or not to permute the chain order at each iteration.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
- property model
- property sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- property is_ensemble_kernel
Denotes whether the kernel is an ensemble kernel. If True, diagnostics_str will be displayed during the MCMC run (when
MCMC.run()
is called) if chain_method = “vectorized”.
- abstract update_active_chains(active, inactive, inner_state)[source]
return (updated active set of chains, updated inner state)
- init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- postprocess_fn(args, kwargs)[source]
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
AIES
- class AIES(model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=<function init_to_uniform>)[source]
Bases:
EnsembleSampler
Affine-Invariant Ensemble Sampling: a gradient free method that informs Metropolis-Hastings proposals by sharing information between chains. Suitable for low to moderate dimensional models. Generally, num_chains should be at least twice the dimensionality of the model.
Note
This kernel must be used with num_chains > 1 and chain_method="vectorized" in MCMC. The number of chains must be divisible by 2.
References:
emcee: The MCMC Hammer (https://iopscience.iop.org/article/10.1086/670067), Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman.
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init() has the same type.
randomize_split (bool) – whether or not to permute the chain order at each iteration. Defaults to False.
moves – a dictionary mapping moves to their respective probabilities of being selected. Valid keys are AIES.DEMove() and AIES.StretchMove(). Both tend to work well in practice. If the sum of probabilities exceeds 1, the probabilities will be normalized. Defaults to {AIES.DEMove(): 1.0}.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, AIES

>>> def model():
...     x = numpyro.sample("x", dist.Normal().expand([10]))
...     numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
>>>
>>> kernel = AIES(model, moves={AIES.DEMove() : 0.5,
...                             AIES.StretchMove() : 0.5})
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized')
>>> mcmc.run(jax.random.PRNGKey(0))
- get_diagnostics_str(state)[source]
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- update_active_chains(active, inactive, inner_state)[source]
return (updated active set of chains, updated inner state)
- static DEMove(sigma=1e-05, g0=None)[source]
A proposal using differential evolution.
This differential evolution proposal is implemented following Nelson et al. (2013).
- Parameters:
sigma – (optional) The standard deviation of the Gaussian used to stretch the proposal vector. Defaults to 1.0e-5.
g0 – (optional) The mean stretch factor for the proposal vector. By default, it is 2.38 / sqrt(2 * ndim), as recommended by the two references.
- static StretchMove(a=2.0)[source]
A Goodman & Weare (2010) “stretch move” with parallelization as described in Foreman-Mackey et al. (2013).
- Parameters:
a – (optional) The stretch scale parameter. Defaults to 2.0.
ESS
- class ESS(model=None, potential_fn=None, randomize_split=True, moves=None, max_steps=10000, max_iter=10000, init_mu=1.0, tune_mu=True, init_strategy=<function init_to_uniform>)[source]
Bases:
EnsembleSampler
Ensemble Slice Sampling: a gradient free method that finds better slice sampling directions by sharing information between chains. Suitable for low to moderate dimensional models. Generally, num_chains should be at least twice the dimensionality of the model.
Note
This kernel must be used with num_chains > 1 and chain_method="vectorized" in MCMC. The number of chains must be divisible by 2.
References:
zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference (https://academic.oup.com/mnras/article/508/3/3589/6381726), Minas Karamanis, Florian Beutler, and John A. Peacock.
Ensemble slice sampling (https://link.springer.com/article/10.1007/s11222-021-10038-2), Minas Karamanis, Florian Beutler.
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init() has the same type.
randomize_split (bool) – whether or not to permute the chain order at each iteration. Defaults to True.
moves – a dictionary mapping moves to their respective probabilities of being selected. If the sum of probabilities exceeds 1, the probabilities will be normalized. Defaults to {ESS.DifferentialMove(): 1.0}. Valid keys include:
ESS.DifferentialMove() – the default proposal; works well along a wide range of target distributions
ESS.GaussianMove() – for approximately normally distributed targets
ESS.KDEMove() – for multimodal posteriors; requires large num_chains, and the chains must be well initialized
ESS.RandomMove() – no chain interaction; useful for debugging
max_steps (int) – number of maximum stepping-out steps per sample. Defaults to 10,000.
max_iter (int) – number of maximum expansions/contractions per sample. Defaults to 10,000.
init_mu (float) – initial scale factor. Defaults to 1.0.
tune_mu (bool) – whether or not to tune the initial scale factor. Defaults to True.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, ESS

>>> def model():
...     x = numpyro.sample("x", dist.Normal().expand([10]))
...     numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
>>>
>>> kernel = ESS(model, moves={ESS.DifferentialMove() : 0.8,
...                            ESS.RandomMove() : 0.2})
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized')
>>> mcmc.run(jax.random.PRNGKey(0))
- update_active_chains(active, inactive, inner_state)[source]
return (updated active set of chains, updated inner state)
- static RandomMove()[source]
The Karamanis & Beutler (2020) “Random Move” with parallelization. When this move is used the walkers move along random directions. There is no communication between the walkers and this Move corresponds to the vanilla Slice Sampling method. This Move should be used for debugging purposes only.
- static KDEMove(bw_method=None)[source]
The Karamanis & Beutler (2020) “KDE Move” with parallelization. When this Move is used, the distribution of the walkers of the complementary ensemble is traced using a Gaussian Kernel Density Estimation method. The walkers then move along random direction vectors sampled from this distribution.
- static GaussianMove()[source]
The Karamanis & Beutler (2020) “Gaussian Move” with parallelization. When this Move is used the walkers move along directions defined by random vectors sampled from the Gaussian approximation of the walkers of the complementary ensemble.
- static DifferentialMove()[source]
The Karamanis & Beutler (2020) “Differential Move” with parallelization. When this Move is used the walkers move along directions defined by random pairs of walkers sampled (with no replacement) from the complementary ensemble. This is the default choice and performs well along a wide range of target distributions.
- hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS')[source]
Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length.
References:
MCMC Using Hamiltonian Dynamics, Radford M. Neal
The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, and Andrew Gelman.
A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
- Parameters:
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init_kernel has the same type.
potential_fn_gen – Python callable that when provided with model arguments / keyword arguments returns potential_fn. This may be provided to do inference on the same model with changing data. If the data shape remains the same, we can compile sample_kernel once, and use the same for multiple inference runs.
kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
algo (str) – Whether to run HMC with a fixed number of steps or NUTS with adaptive path length. Default is NUTS.
- Returns:
a tuple of callables (init_kernel, sample_kernel), the first one to initialize the sampler, and the second one to generate samples given an existing one.
Warning
Instead of using this interface directly, we highly recommend you use the higher level MCMC API instead.
Example
>>> import jax
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer.hmc import hmc
>>> from numpyro.infer.util import initialize_model
>>> from numpyro.util import fori_collect

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(3), jnp.ones(3)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,))
>>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(model_info.param_info,
...                         trajectory_length=10,
...                         num_warmup=300)
>>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
...                        transform=lambda state: model_info.postprocess_fn(state.z))
>>> print(jnp.mean(samples['coefs'], axis=0))
[0.9153987 2.0754058 2.9621222]
- init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, *, num_steps=None, trajectory_length=6.283185307179586, max_tree_depth=10, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True, model_args=(), model_kwargs=None, rng_key=None)
Initializes the HMC sampler.
- Parameters:
init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
num_warmup (int) – Number of warmup steps; samples generated during warmup are discarded.
step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
inverse_mass_matrix (numpy.ndarray or dict) – Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with jax.tree_flatten, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If model is not None, here we can specify a structured block mass matrix as a dictionary, where keys are tuple of site names and values are the corresponding block of the mass matrix. For more information about structured mass matrix, see dense_mass argument.
adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
dense_mass (bool or list) – This flag controls whether the mass matrix is dense (i.e. full-rank) or diagonal (defaults to dense_mass=False). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables “x”, “y”, “z” (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows (see the sketch after this parameter list for a usage example):
dense_mass=[(“x”, “y”)]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z
dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z)
dense_mass=[(“x”, “y”, “z”)] (equivalent to dense_mass=True): use a dense mass matrix for the joint (x, y, z)
dense_mass=[(“x”,), (“y”,), (“z”,)]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks)
target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
num_steps (int) – if not None, fix the number of steps allowed for each iteration.
trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of integers (d1, d2), where d1 is the max tree depth during warmup phase and d2 is the max tree depth during post warmup phase.
find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode, but forward mode can be useful in some cases to improve performance. In addition, some control flow utilities in JAX, such as jax.lax.while_loop or jax.lax.fori_loop, only support forward-mode differentiation. See JAX’s The Autodiff Cookbook for more information.
regularize_mass_matrix (bool) – whether or not to regularize the estimated mass matrix for numerical stability during the warmup phase. Defaults to True. This flag does not take effect if adapt_mass_matrix == False.
model_args (tuple) – Model arguments if potential_fn_gen is specified.
model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
rng_key (jax.random.PRNGKey) – random key to be used as the source of randomness.
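As a minimal sketch of the structured dense_mass specification described above, the example below uses the higher-level NUTS kernel (which forwards dense_mass to the sampler) rather than calling init_kernel directly; the toy model is a placeholder.

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    # latent sites "x", "y", "z"
    x = numpyro.sample("x", dist.Normal(jnp.zeros(2), jnp.ones(2)))
    y = numpyro.sample("y", dist.Normal(x.sum(), 1.0))
    numpyro.sample("z", dist.Normal(0.0, 1.0))

# dense mass matrix for the joint (x, y) block, diagonal mass for z
kernel = NUTS(model, dense_mass=[("x", "y")])
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0))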
- sample_kernel(hmc_state, model_args=(), model_kwargs=None)
Given an existing HMCState, run HMC with a fixed (possibly adapted) step size and return a new HMCState.
- Parameters:
hmc_state – Current sample (and associated state).
model_args (tuple) – Model arguments if potential_fn_gen is specified.
model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
- Returns:
a new proposed HMCState from simulating Hamiltonian dynamics given the existing state.
- taylor_proxy(reference_params, degree)[source]
Control variate for unbiased log likelihood estimation using a Taylor expansion around a reference parameter. Suggested for subsampling in [1].
- Parameters:
reference_params (dict) – Model parameterization at MLE or MAP-estimate.
degree – number of terms in the Taylor expansion, either one or two.
References:
- [1] On Markov chain Monte Carlo Methods For Tall Data
Bardenet, R., Doucet, A., Holmes, C. (2017)
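For instance, here is a minimal sketch (not taken from the reference above) of how such a Taylor-expansion proxy is typically combined with HMCECS for data subsampling. The toy model is a placeholder, and reference_params would normally come from an MLE or MAP fit rather than the fixed guess used here.

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCECS

data = random.normal(random.PRNGKey(1), (10_000,))

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    # subsample 100 data points per MCMC step
    with numpyro.plate("N", data.shape[0], subsample_size=100):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=batch)

# placeholder reference point; in practice use an MLE/MAP estimate of "loc"
ref_params = {"loc": jnp.array(0.0)}
proxy = HMCECS.taylor_proxy(ref_params, degree=2)
kernel = HMCECS(NUTS(model), num_blocks=10, proxy=proxy)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), data)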
- BarkerMHState = <class 'numpyro.infer.barker.BarkerMHState'>
A namedtuple() consisting of the following fields:
i - iteration. This is reset to 0 after warmup.
z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
potential_energy - Potential energy computed at the given value of z.
z_grad - Gradient of potential energy w.r.t. latent sample sites.
accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
mean_accept_prob - Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics).
adapt_state - A HMCAdaptState namedtuple which contains adaptation information during warmup:
step_size - Step size to be used by the integrator in the next iteration.
inverse_mass_matrix - The inverse mass matrix to be used for the next iteration.
mass_matrix_sqrt - The square root of mass matrix to be used for the next iteration. In case of dense mass, this is the Cholesky factorization of the mass matrix.
rng_key - random number generator seed used for generating proposals, etc.
- HMCState = <class 'numpyro.infer.hmc.HMCState'>
A namedtuple() consisting of the following fields:
i - iteration. This is reset to 0 after warmup.
z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
z_grad - Gradient of potential energy w.r.t. latent sample sites.
potential_energy - Potential energy computed at the given value of z.
energy - Sum of potential energy and kinetic energy of the current state.
r - The current momentum variable. If this is None, a new momentum variable will be drawn at the beginning of each sampling step.
trajectory_length - The amount of time to run HMC dynamics in each sampling step. This field is not used in NUTS.
num_steps - Number of steps in the Hamiltonian trajectory (for diagnostics). In the HMC sampler, trajectory_length should be None for step_size to be adapted. In the NUTS sampler, the tree depth of a trajectory can be computed from this field with tree_depth = np.log2(num_steps).astype(int) + 1.
accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
mean_accept_prob - Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics).
diverging - A boolean value to indicate whether the current trajectory is diverging.
adapt_state - A HMCAdaptState namedtuple which contains adaptation information during warmup:
step_size - Step size to be used by the integrator in the next iteration.
inverse_mass_matrix - The inverse mass matrix to be used for the next iteration.
mass_matrix_sqrt - The square root of mass matrix to be used for the next iteration. In case of dense mass, this is the Cholesky factorization of the mass matrix.
rng_key - random number generator seed used for the iteration.
- HMCGibbsState = <class 'numpyro.infer.hmc_gibbs.HMCGibbsState'>
z - a dict of the current latent values (both HMC and Gibbs sites)
hmc_state - current HMCState
rng_key - random key for the current step
- SAState = <class 'numpyro.infer.sa.SAState'>
A namedtuple() used in Sample Adaptive MCMC. This consists of the following fields:
i - iteration. This is reset to 0 after warmup.
z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
potential_energy - Potential energy computed at the given value of z.
accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
mean_accept_prob - Mean acceptance probability until current iteration during warmup or sampling (for diagnostics).
diverging - A boolean value to indicate whether the new sample potential energy is diverging from the current one.
adapt_state - A SAAdaptState namedtuple which contains adaptation information:
zs - Step size to be used by the integrator in the next iteration.
pes - Potential energies of zs.
loc - Mean of those zs.
inv_mass_matrix_sqrt - If using a dense mass matrix, this is the Cholesky of the covariance of zs. Otherwise, this is the standard deviation of those zs.
rng_key - random number generator seed used for the iteration.
- EnsembleSamplerState = <class 'numpyro.infer.ensemble.EnsembleSamplerState'>
A namedtuple() consisting of the following fields:
z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
inner_state - A namedtuple containing information needed to update half the ensemble.
rng_key - random number generator seed used for generating proposals, etc.
- AIESState = <class 'numpyro.infer.ensemble.AIESState'>
A namedtuple() consisting of the following fields:
i - iteration.
accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
mean_accept_prob - Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics).
rng_key - random number generator seed used for generating proposals, etc.
- ESSState = <class 'numpyro.infer.ensemble.ESSState'>
A namedtuple() used as an inner state for Ensemble Sampler. This consists of the following fields:
i - iteration.
n_expansions - number of expansions in the current batch. Used for tuning mu.
n_contractions - number of contractions in the current batch. Used for tuning mu.
mu - Scale factor. This is tuned if tune_mu=True.
rng_key - random number generator seed used for generating proposals, etc.
TensorFlow Kernels
Thin wrappers around TensorFlow Probability (TFP) MCMC kernels. For details on the TFP MCMC kernel interface, see its TransitionKernel docs.
TFPKernel
- class TFPKernel(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)[source]
A thin wrapper for TensorFlow Probability (TFP) MCMC transition kernels. The argument target_log_prob_fn in TFP is replaced by either model or potential_fn (which is the negative of target_log_prob_fn).
This class can be used to convert a TFP kernel to a NumPyro-compatible one as follows:
from numpyro.contrib.tfp.mcmc import TFPKernel
kernel = TFPKernel[tfp.mcmc.NoUTurnSampler](model, step_size=1.)
Note
By default, uncalibrated kernels will be inner kernels of the MetropolisHastings kernel.
Note
For ReplicaExchangeMC, TFP requires that the shape of step_size of the inner kernel must be [len(inverse_temperatures), 1] or [len(inverse_temperatures), latent_size].
- Parameters:
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the target potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that the init_params argument to init() has the same type.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
kernel_kwargs – other arguments to be passed to the TFP kernel constructor.
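For instance, here is a minimal sketch of running a TFP kernel through the standard MCMC API, using the NoUTurnSampler wrapper documented below. It requires tensorflow-probability to be installed; the toy model is a placeholder.

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.tfp.mcmc import NoUTurnSampler
from numpyro.infer import MCMC

y = random.normal(random.PRNGKey(1), (100,)) + 1.0

def model(y):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=y)

# step_size is forwarded to the underlying TFP kernel constructor
kernel = NoUTurnSampler(model, step_size=0.1)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), y)
mcmc.print_summary()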
HamiltonianMonteCarlo
- class HamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.hmc.HamiltonianMonteCarlo with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
MetropolisAdjustedLangevinAlgorithm
- class MetropolisAdjustedLangevinAlgorithm(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.langevin.MetropolisAdjustedLangevinAlgorithm with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
NoUTurnSampler
- class NoUTurnSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.nuts.NoUTurnSampler with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
RandomWalkMetropolis
- class RandomWalkMetropolis(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.RandomWalkMetropolis with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
ReplicaExchangeMC
- class ReplicaExchangeMC(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.replica_exchange_mc.ReplicaExchangeMC with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
SliceSampler
- class SliceSampler(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.slice_sampler_kernel.SliceSampler with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
UncalibratedHamiltonianMonteCarlo
- class UncalibratedHamiltonianMonteCarlo(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.hmc.UncalibratedHamiltonianMonteCarlo with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
UncalibratedLangevin
- class UncalibratedLangevin(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.langevin.UncalibratedLangevin with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
UncalibratedRandomWalk
- class UncalibratedRandomWalk(model=None, potential_fn=None, init_strategy=<function init_to_uniform>, **kernel_kwargs)
Wraps tensorflow_probability.substrates.jax.mcmc.random_walk_metropolis.UncalibratedRandomWalk with
TFPKernel
. The first argument target_log_prob_fn in TFP kernel construction is replaced by either model or potential_fn.
MCMC Utilities
- initialize_model(rng_key, model, *, init_strategy=<function init_to_uniform>, dynamic_args=False, model_args=(), model_kwargs=None, forward_mode_differentiation=False, validate_grad=True)[source]
(EXPERIMENTAL INTERFACE) Helper function that calls get_potential_fn() and find_valid_initial_params() under the hood to return a tuple of (init_params_info, potential_fn, postprocess_fn, model_trace).
- Parameters:
rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
model – Python callable containing Pyro primitives.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
dynamic_args (bool) – if True, the potential_fn and constraints_fn are themselves dependent on model arguments. When provided with *model_args, **model_kwargs, they return potential_fn and constraints_fn callables, respectively.
model_args (tuple) – args provided to the model.
model_kwargs (dict) – kwargs provided to the model.
forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode, but forward mode can be useful in some cases to improve performance. In addition, some control flow utilities in JAX, such as jax.lax.while_loop or jax.lax.fori_loop, only support forward-mode differentiation. See JAX’s The Autodiff Cookbook for more information.
validate_grad (bool) – whether to validate gradient of the initial params. Defaults to True.
- Returns:
a namedtuple ModelInfo which contains the fields (param_info, potential_fn, postprocess_fn, model_trace), where param_info is a namedtuple ParamInfo containing values from the prior used to initiate MCMC, their corresponding potential energy, and their gradients; postprocess_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within the site’s support, in addition to returning values at deterministic sites in the model.
- fori_collect(lower: int, upper: int, body_fun: ~typing.Callable, init_val: ~typing.Any, transform: ~typing.Callable = <function identity>, progbar: bool = True, return_last_val: bool = False, collection_size=None, thinning=1, **progbar_opts)[source]
This looping construct works like fori_loop() but with the additional effect of collecting values from the loop body. In addition, this allows for post-processing of these samples via transform, and for progress bar updates. Note that progbar=False will be faster, especially when collecting a lot of samples. Refer to example usage in hmc().
- Parameters:
lower (int) – the index to start the collective work. In other words, we will skip collecting the first lower values.
upper (int) – number of times to run the loop body.
body_fun – a callable that takes a collection of np.ndarray and returns a collection with the same shape and dtype.
init_val – initial value to pass as argument to body_fun. Can be any Python collection type containing np.ndarray objects.
transform – a callable to post-process the values returned by body_fun.
progbar – whether to post progress bar updates.
return_last_val (bool) – If True, the last value is also returned. This has the same type as init_val.
thinning – Positive integer that controls the thinning ratio for retained values. Defaults to 1, i.e. no thinning.
collection_size (int) – Size of the returned collection. If not specified, the size will be (upper - lower) // thinning. If the size is larger than (upper - lower) // thinning, only the top (upper - lower) // thinning entries will be non-zero.
**progbar_opts – optional additional progress bar arguments. A diagnostics_fn can be supplied which, when passed the current value from body_fun, returns a string that is used to update the progress bar postfix. Also a progbar_desc keyword argument can be supplied which is used to label the progress bar.
- Returns:
collection with the same type as init_val with values collected along the leading axis of np.ndarray objects.
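For instance, here is a minimal sketch of fori_collect outside of an MCMC loop; the loop body here is a placeholder.

import jax.numpy as jnp
from numpyro.util import fori_collect

# collect iterates of a simple loop body, skipping the first 2 and keeping every 2nd
collected = fori_collect(2, 10, lambda x: x + 1.0, jnp.zeros(3),
                         thinning=2, progbar=False)
print(collected.shape)  # (4, 3): ((upper - lower) // thinning, *value_shape)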
- consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None)[source]
Merges subposteriors following the consensus Monte Carlo algorithm.
References:
Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch
- Parameters:
subposteriors (list) – a list in which each element is a collection of samples.
num_draws (int) – number of draws from the merged posterior.
diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
- Returns:
if num_draws is None, merges subposteriors without resampling; otherwise, returns a collection of num_draws samples with the same data structure as each subposterior.
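For instance, here is a minimal sketch with synthetic subposterior samples standing in for real MCMC output from two data shards.

import jax.numpy as jnp
from jax import random
from numpyro.infer.hmc_util import consensus

# two subposteriors over the same latent site, e.g. from MCMC runs on two data shards
sub1 = {"coef": random.normal(random.PRNGKey(0), (1000, 3))}
sub2 = {"coef": 0.1 + random.normal(random.PRNGKey(1), (1000, 3))}

merged = consensus([sub1, sub2], num_draws=500, rng_key=random.PRNGKey(2))
print(merged["coef"].shape)  # (500, 3)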
- parametric(subposteriors, diagonal=False)[source]
Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.
References:
Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
- parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None)[source]
Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.
References:
Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
- Parameters:
subposteriors (list) – a list in which each element is a collection of samples.
num_draws (int) – number of draws from the merged posterior.
diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
- Returns:
a collection of num_draws samples with the same data structure as each subposterior.