Markov Chain Monte Carlo (MCMC)

class MCMC(sampler, num_warmup, num_samples, num_chains=1, postprocess_fn=None, chain_method='parallel', progress_bar=True, jit_model_args=False)
Bases: object
Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro.
Note
chain_method is an experimental argument and may be removed in a future version.
Note
Setting progress_bar=False will improve speed in many cases.
Parameters:  sampler (MCMCKernel) – an instance of MCMCKernel that determines the sampler for running MCMC. Currently, only HMC and NUTS are available.
 num_warmup (int) – Number of warmup steps.
 num_samples (int) – Number of samples to generate from the Markov chain.
 num_chains (int) – Number of MCMC chains to run. By default, chains will be run in parallel using jax.pmap(); if that fails, chains will be run in sequence.
 postprocess_fn – Post-processing callable used to convert a collection of unconstrained sample values returned from the sampler to constrained values that lie within the support of the sample sites. Additionally, this is used to return values at deterministic sites in the model.
 chain_method (str) – One of ‘parallel’ (default), ‘sequential’, ‘vectorized’. The ‘parallel’ method draws chains in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for ‘parallel’, it falls back to the ‘sequential’ method, which draws chains one after another. The ‘vectorized’ method is an experimental feature that vectorizes the drawing method, allowing samples to be collected in parallel on a single device.
 progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
 jit_model_args (bool) – If set to True, this will compile the potential energy computation as a function of model arguments. As such, calling MCMC.run again on a dataset of the same size but with different values will not incur additional compilation cost.
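To make the roles of num_warmup, num_samples and num_chains concrete, here is a framework-free sketch of the sampling loop when chains are drawn one after another (roughly what chain_method=‘sequential’ implies). It is plain Python, not NumPyro's implementation: run_chains and rw_step are hypothetical helpers, and the kernel is a simple random-walk Metropolis step targeting a standard normal.

```python
import math
import random


def rw_step(x, rng, scale=0.5):
    """One random-walk Metropolis step targeting a standard normal (hypothetical kernel)."""
    proposal = x + rng.gauss(0.0, scale)
    # log p(proposal) - log p(x) for p = Normal(0, 1), constants dropped.
    log_ratio = (x * x - proposal * proposal) / 2.0
    accept_prob = math.exp(min(0.0, log_ratio))
    return proposal if rng.random() < accept_prob else x


def run_chains(step, init_value, num_warmup, num_samples, num_chains=1, seed=0):
    """Run chains one after another; warmup states are advanced but not kept."""
    rng = random.Random(seed)
    chains = []
    for _ in range(num_chains):
        state = init_value
        # Warmup phase: move the chain, record nothing.
        for _ in range(num_warmup):
            state = step(state, rng)
        # Sampling phase: record every post-warmup state.
        draws = []
        for _ in range(num_samples):
            state = step(state, rng)
            draws.append(state)
        chains.append(draws)
    # One list per chain, each holding num_samples draws.
    return chains
```

Calling run_chains(rw_step, 0.0, num_warmup=500, num_samples=1000, num_chains=2) yields two lists of 1000 draws each. No adaptation happens in this sketch; in NumPyro the warmup phase is also where the step size and mass matrix are tuned.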

warmup(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)
Run the MCMC warmup adaptation phase. After this call, the run() method will skip the warmup adaptation phase. To run warmup again on new data, call warmup() again.
Parameters:  rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.
 args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
 extra_fields (tuple or list) – Extra fields (aside from default_fields()) from the state object (e.g. numpyro.infer.mcmc.HMCState for HMC) to collect during the MCMC run.
 collect_warmup (bool) – Whether to collect samples from the warmup phase. Defaults to False.
 init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
 kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the keyword arguments needed by the model.

run(rng_key, *args, extra_fields=(), init_params=None, **kwargs)
Run the MCMC samplers and collect samples.
Parameters:  rng_key (random.PRNGKey) – Random number generator key to be used for the sampling. For multiple chains, a batch of num_chains keys can be supplied. If rng_key is not batched, it will be split into a batch of num_chains keys.
 args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
 extra_fields (tuple or list) – Extra fields (aside from z, diverging) from numpyro.infer.mcmc.HMCState to collect during the MCMC run.
 init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
 kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the keyword arguments needed by the model.
Note
JAX dispatches work asynchronously, so Python code can continue running before the compiled code has finished. This can cause trouble when profiling the code for speed. See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling JAX programs.

get_samples(group_by_chain=False)
Get samples from the MCMC run.
Parameters:  group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension.
Returns: Samples having the same data type as init_params. The data type is a dict keyed on site names if a model containing Pyro primitives is used, but more generally can be any jaxlib.pytree() (e.g. when defining a potential_fn for HMC that takes list args).
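The group_by_chain flag only changes how draws are laid out. A minimal plain-Python sketch of that layout contract, assuming draws are stored as one list per chain (the helper name and the list-of-lists storage are hypothetical; NumPyro itself returns arrays inside a pytree):

```python
def arrange_samples(chains, group_by_chain=False):
    """Mimic the layout of get_samples for per-chain lists of draws (hypothetical helper)."""
    if group_by_chain:
        # Leading dimension is the chain: [num_chains][num_samples].
        return [list(draws) for draws in chains]
    # Chains concatenated along a single axis: [num_chains * num_samples].
    return [x for draws in chains for x in draws]
```

With two chains [[1, 2, 3], [4, 5, 6]], grouping returns the nested lists unchanged, while the default returns [1, 2, 3, 4, 5, 6]. For arrays, the ungrouped form corresponds to reshaping a (num_chains, num_samples, ...) array to (num_chains * num_samples, ...).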

get_extra_fields(group_by_chain=False)
Get extra fields from the MCMC run.
Parameters:  group_by_chain (bool) – Whether to preserve the chain dimension. If True, all extra fields will have num_chains as the size of their leading dimension.
Returns: Extra fields keyed by the field names specified in the extra_fields keyword of run().
MCMC Kernels

class MCMCKernel
Bases: abc.ABC
Defines the interface for the Markov transition kernel that is used for MCMC inference.
Example:
>>> from collections import namedtuple
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC
>>> MHState = namedtuple("MHState", ["z", "rng_key"])
>>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
...     sample_field = "z"
...
...     def __init__(self, potential_fn, step_size=0.1):
...         self.potential_fn = potential_fn
...         self.step_size = step_size
...
...     def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
...         return MHState(init_params, rng_key)
...
...     def sample(self, state, model_args, model_kwargs):
...         z, rng_key = state
...         rng_key, key_proposal, key_accept = random.split(rng_key, 3)
...         z_proposal = dist.Normal(z, self.step_size).sample(key_proposal)
...         # Metropolis ratio exp(U(z) - U(z_proposal)) for potential energy U.
...         accept_prob = jnp.exp(self.potential_fn(z) - self.potential_fn(z_proposal))
...         z_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, z_proposal, z)
...         return MHState(z_new, rng_key)
>>> def f(x):
...     return ((x - 2) ** 2).sum()
>>> kernel = MetropolisHastings(f)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
>>> samples = mcmc.get_samples()

postprocess_fn(model_args, model_kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
Parameters:  model_args – Arguments to the model.
 model_kwargs – Keyword arguments to the model.

init(rng_key, num_warmup, init_params, model_args, model_kwargs)
Initialize the MCMCKernel and return an initial state to begin sampling from.
Parameters:  rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
 num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
 init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
 model_args – Arguments provided to the model.
 model_kwargs – Keyword arguments provided to the model.
Returns: The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

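sample(state, model_args, model_kwargs)
Given the current state, return the next state using the given transition kernel.
Parameters:  state – The current state of the kernel (for HMC, this is an HMCState).
 model_args – Arguments provided to the model.
 model_kwargs – Keyword arguments provided to the model.
Returns: Next state.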

sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

default_fields
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).


class HMC(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False)
Bases: numpyro.infer.mcmc.MCMCKernel
Hamiltonian Monte Carlo inference, using a fixed trajectory length, with provision for step size and mass matrix adaptation.
References:
 MCMC Using Hamiltonian Dynamics, Radford M. Neal
Parameters:  model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
 potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
 kinetic_fn – Python callable that returns the kinetic energy given the inverse mass matrix and momentum. If not provided, the default is Euclidean kinetic energy.
 step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
 adapt_step_size (bool) – A flag to decide whether to adapt step_size during the warmup phase using the Dual Averaging scheme.
 adapt_mass_matrix (bool) – A flag to decide whether to adapt the mass matrix during the warmup phase using the Welford scheme.
 dense_mass (bool) – A flag to decide whether the mass matrix is dense or diagonal (diagonal is the default, dense_mass=False).
 target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
 trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
 init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
 find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
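A fixed trajectory length is what separates HMC from NUTS: each transition integrates Hamiltonian dynamics for roughly trajectory_length / step_size leapfrog (Verlet) steps, then accepts or rejects the endpoint with a Metropolis test on the total energy. Below is a one-dimensional plain-Python sketch with a unit mass and Euclidean kinetic energy; the helper names and the ceil rounding of the step count are assumptions, not NumPyro's internals.

```python
import math
import random


def leapfrog(q, p, grad_u, step_size, num_steps):
    """Leapfrog (velocity Verlet) integration with unit mass."""
    p = p - 0.5 * step_size * grad_u(q)  # initial half step for momentum
    for i in range(num_steps):
        q = q + step_size * p  # full step for position
        if i < num_steps - 1:
            p = p - step_size * grad_u(q)  # full step for momentum
    p = p - 0.5 * step_size * grad_u(q)  # final half step for momentum
    return q, p


def hmc_step(q, potential, grad_u, rng, step_size=0.1, trajectory_length=2 * math.pi):
    """One HMC transition with a fixed trajectory length (hypothetical helper)."""
    num_steps = max(1, math.ceil(trajectory_length / step_size))
    p = rng.gauss(0.0, 1.0)  # resample the momentum
    q_new, p_new = leapfrog(q, p, grad_u, step_size, num_steps)
    # Total energy = potential energy + kinetic energy p**2 / 2.
    energy = potential(q) + 0.5 * p * p
    energy_new = potential(q_new) + 0.5 * p_new * p_new
    accept_prob = math.exp(min(0.0, energy - energy_new))
    q_next = q_new if rng.random() < accept_prob else q
    return q_next, accept_prob
```

For a standard normal target, pass potential=lambda x: 0.5 * x * x and grad_u=lambda x: x. A shorter trajectory_length such as 1.0 is a better choice for that particular target: with unit mass its dynamics have period exactly 2π, so a 2π trajectory returns close to where it started.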

model

sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

default_fields
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

get_diagnostics_str(state)
Given the current state, returns the diagnostics string to be added to the progress bar.

init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})
Initialize the MCMCKernel and return an initial state to begin sampling from.
Parameters:  rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
 num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
 init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
 model_args – Arguments provided to the model.
 model_kwargs – Keyword arguments provided to the model.
Returns: The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

postprocess_fn(args, kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
Parameters:  model_args – Arguments to the model.
 model_kwargs – Keyword arguments to the model.

sample(state, model_args, model_kwargs)
Run HMC from the given HMCState and return the resulting HMCState.
Parameters:  state (HMCState) – Represents the current state.
 model_args – Arguments provided to the model.
 model_kwargs – Keyword arguments provided to the model.
Returns: Next state after running HMC.

class NUTS(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=None, max_tree_depth=10, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False)
Bases: numpyro.infer.hmc.HMC
Hamiltonian Monte Carlo inference, using the No-U-Turn Sampler (NUTS) with adaptive path length and mass matrix adaptation.
References:
 MCMC Using Hamiltonian Dynamics, Radford M. Neal
 The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman and Andrew Gelman.
 A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters:  model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
 potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
 kinetic_fn – Python callable that returns the kinetic energy given the inverse mass matrix and momentum. If not provided, the default is Euclidean kinetic energy.
 step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
 adapt_step_size (bool) – A flag to decide whether to adapt step_size during the warmup phase using the Dual Averaging scheme.
 adapt_mass_matrix (bool) – A flag to decide whether to adapt the mass matrix during the warmup phase using the Welford scheme.
 dense_mass (bool) – A flag to decide whether the mass matrix is dense or diagonal (diagonal is the default, dense_mass=False).
 target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
 trajectory_length (float) – Length of an MCMC trajectory for HMC. This argument has no effect in the NUTS sampler.
 max_tree_depth (int) – Maximum depth of the binary tree created during the doubling scheme of the NUTS sampler. Defaults to 10.
 init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
 find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
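Both HMC and NUTS adapt step_size during warmup with dual averaging, steering the observed acceptance probability toward target_accept_prob. The plain-Python sketch below follows the dual averaging update described in Hoffman and Gelman's NUTS paper, using that paper's suggested constants gamma=0.05, t0=10, kappa=0.75. The class name is hypothetical, and NumPyro's own implementation may differ in details such as restarting at each adaptation window.

```python
import math


class DualAveraging:
    """Dual averaging step-size adaptation (hypothetical helper)."""

    def __init__(self, init_step_size, target_accept_prob=0.8,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * init_step_size)  # shrinkage point for log step size
        self.target = target_accept_prob
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.m = 0
        self.h_bar = 0.0  # running average of (target - observed acceptance)
        self.log_step_avg = 0.0  # averaged iterate, used once warmup ends

    def update(self, accept_prob):
        """Record one acceptance probability; return the next step size to try."""
        self.m += 1
        w = 1.0 / (self.m + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target - accept_prob)
        log_step = self.mu - math.sqrt(self.m) / self.gamma * self.h_bar
        eta = self.m ** (-self.kappa)
        self.log_step_avg = eta * log_step + (1.0 - eta) * self.log_step_avg
        return math.exp(log_step)

    def final_step_size(self):
        """Averaged step size to keep fixed after warmup."""
        return math.exp(self.log_step_avg)
```

As a toy check, model the acceptance probability as exp(-step_size). Feeding that back into update() for a couple of thousand iterations should drive the averaged step size toward -log(0.8), about 0.223. Raising target_accept_prob lowers that fixed point, which matches the parameter description above.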

class SA(model=None, potential_fn=None, adapt_state_size=None, dense_mass=True, init_strategy=<function init_to_uniform>)
Bases: numpyro.infer.mcmc.MCMCKernel
Sample Adaptive MCMC, a gradient-free sampler.
This is a very fast sampler (in terms of n_eff / s) but requires many warmup (burn-in) steps. In each MCMC step, the potential function only needs to be evaluated at one point.
Note that unlike in reference [1], we return a randomly selected (i.e. thinned) subset of approximate posterior samples of size num_chains x num_samples instead of num_chains x num_samples x adapt_state_size.
Note
We recommend using this kernel with progress_bar=False in MCMC to reduce JAX’s dispatch overhead.
References:
 [1] Sample Adaptive MCMC (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc), Michael Zhu
Parameters:  model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
 potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
 adapt_state_size (int) – The number of points used to generate the proposal distribution. Defaults to 2 times the latent size.
 dense_mass (bool) – A flag to decide whether the mass matrix is dense (the default, dense_mass=True) or diagonal.
 init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.

init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})
Initialize the MCMCKernel and return an initial state to begin sampling from.
Parameters:  rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
 num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
 init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
 model_args – Arguments provided to the model.
 model_kwargs – Keyword arguments provided to the model.
Returns: The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().

default_fields
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

get_diagnostics_str(state)
Given the current state, returns the diagnostics string to be added to the progress bar.

postprocess_fn(args, kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
Parameters:  model_args – Arguments to the model.
 model_kwargs – Keyword arguments to the model.

sample(state, model_args, model_kwargs)
Run SA from the given SAState and return the resulting SAState.
Parameters:  state (SAState) – Represents the current state.
 model_args – Arguments provided to the model.
 model_kwargs – Keyword arguments provided to the model.
Returns: Next state after running SA.

hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS')
Hamiltonian Monte Carlo inference, using either a fixed number of steps or the No-U-Turn Sampler (NUTS) with adaptive path length.
References:
 MCMC Using Hamiltonian Dynamics, Radford M. Neal
 The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman and Andrew Gelman.
 A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters:  potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
 potential_fn_gen – Python callable that, when provided with model arguments / keyword arguments, returns potential_fn. This may be provided to do inference on the same model with changing data. If the data shape remains the same, sample_kernel can be compiled once and reused for multiple inference runs.
 kinetic_fn – Python callable that returns the kinetic energy given the inverse mass matrix and momentum. If not provided, the default is Euclidean kinetic energy.
 algo (str) – Whether to run HMC with a fixed number of steps or NUTS with adaptive path length. Default is NUTS.
Returns: a tuple of callables (init_kernel, sample_kernel), the first one to initialize the sampler, and the second one to generate samples given an existing one.
Warning
Instead of using this interface directly, we highly recommend using the higher-level numpyro.infer.MCMC API.
Example
>>> import jax
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer.hmc import hmc
>>> from numpyro.infer.util import initialize_model
>>> from numpyro.util import fori_collect
>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs_mean = jnp.zeros(dim)
...     coefs = numpyro.sample('beta', dist.Normal(coefs_mean, jnp.ones(3)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(1)), obs=labels)
>>>
>>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,))
>>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(model_info.param_info,
...                         trajectory_length=10,
...                         num_warmup=300)
>>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
...                        transform=lambda state: model_info.postprocess_fn(state.z))
>>> print(jnp.mean(samples['beta'], axis=0))
[0.9153987 2.0754058 2.9621222]
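The potential_fn_gen argument expects a callable that takes model arguments and returns a potential_fn. Here is a plain-Python sketch of that "data in, potential function out" contract for a hypothetical model with likelihood x_i ~ Normal(mu, 1) and prior mu ~ Normal(0, 10). The model and its body are illustrative only.

```python
def potential_fn_gen(data):
    """Return the potential energy of mu for one dataset (hypothetical model)."""

    def potential_fn(mu):
        # Negative log joint with constants dropped:
        # likelihood x_i ~ Normal(mu, 1), prior mu ~ Normal(0, 10).
        neg_log_lik = 0.5 * sum((x - mu) ** 2 for x in data)
        neg_log_prior = 0.5 * (mu / 10.0) ** 2
        return neg_log_lik + neg_log_prior

    return potential_fn
```

For example, potential_fn_gen([1.0, 2.0, 3.0])(2.0) evaluates to 1.02. The data enter as arguments rather than being baked into a single potential_fn, and that is what allows a compiled sample_kernel to be reused across datasets of the same shape, as the parameter description notes.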

init_kernel(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, max_tree_depth=10, find_heuristic_step_size=False, model_args=(), model_kwargs=None, rng_key=DeviceArray([0, 0], dtype=uint32))
Initializes the HMC sampler.
Parameters:  init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
 num_warmup (int) – Number of warmup steps; samples generated during warmup are discarded.
 step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
 inverse_mass_matrix (numpy.ndarray) – Initial value for the inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix.
 adapt_step_size (bool) – A flag to decide whether to adapt step_size during the warmup phase using the Dual Averaging scheme.
 adapt_mass_matrix (bool) – A flag to decide whether to adapt the mass matrix during the warmup phase using the Welford scheme.
 dense_mass (bool) – A flag to decide whether the mass matrix is dense or diagonal (diagonal is the default, dense_mass=False).
 target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
 trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
 max_tree_depth (int) – Maximum depth of the binary tree created during the doubling scheme of the NUTS sampler. Defaults to 10.
 find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
 model_args (tuple) – Model arguments if potential_fn_gen is specified.
 model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
 rng_key (jax.random.PRNGKey) – random key to be used as the source of randomness.
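With adapt_mass_matrix=True and dense_mass=False, warmup estimates a diagonal inverse mass matrix from the per-coordinate variance of warmup samples, accumulated with Welford's streaming algorithm. A minimal plain-Python sketch of that accumulator follows. The class name is hypothetical, the dense case is not covered, and library implementations commonly also shrink the estimate toward a small constant, which is omitted here.

```python
class WelfordVariance:
    """Streaming per-coordinate mean and variance (hypothetical helper)."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sum of squared deviations from the mean

    def update(self, sample):
        """Fold one warmup sample (a list of floats) into the running moments."""
        self.n += 1
        for i, x in enumerate(sample):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.n
            self.m2[i] += delta * (x - self.mean[i])

    def variance(self):
        """Unbiased sample variance, usable as a diagonal inverse mass matrix."""
        return [m2 / (self.n - 1) for m2 in self.m2]
```

The streaming form never stores the warmup samples. It also avoids the cancellation error of the naive sum-of-squares formula, which matters when a coordinate has a large mean relative to its spread.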

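sample_kernel(hmc_state, model_args=(), model_kwargs=None)
Given an existing HMCState, run HMC with a fixed (possibly adapted) step size and return a new HMCState.
Parameters:  hmc_state (HMCState) – The current state.
 model_args (tuple) – Model arguments if potential_fn_gen is specified.
 model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
Returns: new proposed HMCState from simulating Hamiltonian dynamics given the existing state.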

HMCState = <class 'numpyro.infer.hmc.HMCState'>
A namedtuple() consisting of the following fields:
 i – iteration. This is reset to 0 after warmup.
 z – Python collection representing values (unconstrained samples from the posterior) at latent sites.
 z_grad – Gradient of potential energy w.r.t. latent sample sites.
 potential_energy – Potential energy computed at the given value of z.
 energy – Sum of potential energy and kinetic energy of the current state.
 num_steps – Number of steps in the Hamiltonian trajectory (for diagnostics).
 accept_prob – Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
 mean_accept_prob – Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics).
 diverging – A boolean value to indicate whether the current trajectory is diverging.
 adapt_state – An HMCAdaptState namedtuple which contains adaptation information during warmup:
   step_size – Step size to be used by the integrator in the next iteration.
   inverse_mass_matrix – The inverse mass matrix to be used for the next iteration.
   mass_matrix_sqrt – The square root of the mass matrix to be used for the next iteration. In case of dense mass, this is the Cholesky factorization of the mass matrix.
 rng_key – random number generator seed used for the iteration.

SAState = <class 'numpyro.infer.sa.SAState'>
A namedtuple() used in Sample Adaptive MCMC. This consists of the following fields:
 i – iteration. This is reset to 0 after warmup.
 z – Python collection representing values (unconstrained samples from the posterior) at latent sites.
 potential_energy – Potential energy computed at the given value of z.
 accept_prob – Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
 mean_accept_prob – Mean acceptance probability until current iteration during warmup or sampling (for diagnostics).
 diverging – A boolean value to indicate whether the new sample potential energy is diverging from the current one.
 adapt_state – An SAAdaptState namedtuple which contains adaptation information:
   zs – The points used to build the proposal distribution.
   pes – Potential energies of zs.
   loc – Mean of zs.
   inv_mass_matrix_sqrt – If using a dense mass matrix, this is the Cholesky factor of the covariance of zs. Otherwise, this is the standard deviation of zs.
 rng_key – random number generator seed used for the iteration.
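The loc and inv_mass_matrix_sqrt fields describe a Gaussian fitted to the current points zs, and SA draws its proposal from that Gaussian. The plain-Python sketch below covers the diagonal case and only the fit-and-propose step, not the rule that decides which point the proposal replaces. The function name and the use of the population standard deviation are assumptions.

```python
import random
import statistics


def sa_proposal(zs, rng):
    """Fit a diagonal Gaussian to the points zs and draw one proposal (hypothetical helper)."""
    dim = len(zs[0])
    columns = [[z[i] for z in zs] for i in range(dim)]
    loc = [statistics.fmean(col) for col in columns]  # plays the role of loc
    scale = [statistics.pstdev(col) for col in columns]  # diagonal inv_mass_matrix_sqrt
    proposal = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]
    return proposal, loc, scale
```

Only the proposal requires a new potential energy evaluation, because the energies of the existing points are already cached in pes. This is the sense in which SA evaluates the potential function at one point per step.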
MCMC Utilities

initialize_model(rng_key, model, init_strategy=<function init_to_uniform>, dynamic_args=False, model_args=(), model_kwargs=None)
(EXPERIMENTAL INTERFACE) Helper function that calls get_potential_fn() and find_valid_initial_params() under the hood to return a tuple of (init_params_info, potential_fn, postprocess_fn, model_trace).
Parameters:  rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
 model – Python callable containing Pyro primitives.
 init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
 dynamic_args (bool) – if True, the potential_fn and constraints_fn are themselves dependent on model arguments. When provided with *model_args, **model_kwargs, they return potential_fn and constraints_fn callables, respectively.
 model_args (tuple) – args provided to the model.
 model_kwargs (dict) – kwargs provided to the model.
Returns: a namedtuple ModelInfo which contains the fields (param_info, potential_fn, postprocess_fn, model_trace), where param_info is a namedtuple ParamInfo containing values from the prior used to initiate MCMC, their corresponding potential energy, and their gradients; postprocess_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within the site’s support, in addition to returning values at deterministic sites in the model.
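Both potential_fn and postprocess_fn rest on a change of variables. HMC moves in unconstrained space, so a constrained site must be mapped back into its support, and the potential energy must include the log-Jacobian of that map. The sketch below is hand-written for a single hypothetical site sigma ~ Exponential(1) with the transform sigma = exp(u). It illustrates the idea and is not the pair of functions that initialize_model returns.

```python
import math


def potential_fn(u):
    """Potential energy of the unconstrained value u for sigma ~ Exponential(1).

    With sigma = exp(u), the density of u gains the log-Jacobian
    log|d sigma / d u| = u, so U(u) = -(log p(sigma) + u).
    """
    sigma = math.exp(u)
    log_prior = -sigma  # log density of Exponential(1) at sigma > 0
    return -(log_prior + u)


def postprocess_fn(u):
    """Map an unconstrained sample back into the support (0, inf) of sigma."""
    return {"sigma": math.exp(u)}
```

Here potential_fn(0.0) equals 1.0 and is the minimum, since the derivative exp(u) - 1 vanishes at u = 0. Dropping the "+ u" Jacobian term would silently target the wrong distribution for u.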

fori_collect(lower, upper, body_fun, init_val, transform=<function identity>, progbar=True, return_last_val=False, collection_size=None, **progbar_opts)
This looping construct works like fori_loop() but also collects values from the loop body. In addition, it allows post-processing of these samples via transform, and progress bar updates. Note that progbar=False will be faster, especially when collecting a lot of samples. Refer to the example usage in hmc().
Parameters:  lower (int) – the index at which to start collecting values. In other words, the first lower values are not collected.
 upper (int) – number of times to run the loop body.
 body_fun – a callable that takes a collection of np.ndarray and returns a collection with the same shape and dtype.
 init_val – initial value to pass as argument to body_fun. Can be any Python collection type containing np.ndarray objects.
 transform – a callable to post-process the values returned by body_fun.
 progbar – whether to post progress bar updates.
 return_last_val (bool) – If True, the last value is also returned. This has the same type as init_val.
 collection_size (int) – Size of the returned collection. If not specified, the size will be upper - lower. If the size is larger than upper - lower, only the top upper - lower entries will be non-zero.
 **progbar_opts – optional additional progress bar arguments. A diagnostics_fn can be supplied which, when passed the current value from body_fun, returns a string that is used to update the progress bar postfix. A progbar_desc keyword argument can also be supplied to label the progress bar.
Returns: collection with the same type as init_val with values collected along the leading axis of np.ndarray objects.
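Stripped of compilation, the progress bar and collection_size, the collection logic described above amounts to the loop below. This plain-Python sketch (fori_collect_sketch is a hypothetical name) assumes the body runs upper times and values are kept from iteration index lower onward.

```python
def fori_collect_sketch(lower, upper, body_fun, init_val,
                        transform=lambda x: x, return_last_val=False):
    """Plain-Python analogue of fori_collect (no jit, no progress bar)."""
    val = init_val
    collected = []
    for i in range(upper):
        val = body_fun(val)
        # Iterations 0 .. lower-1 only advance the state; later ones are kept.
        if i >= lower:
            collected.append(transform(val))
    return (collected, val) if return_last_val else collected
```

For instance, fori_collect_sketch(2, 5, lambda v: v + 1, 0) runs the body five times and returns [3, 4, 5], which is upper - lower = 3 entries. In the hmc() example, lower=0 keeps every draw because warmup already happened inside init_kernel.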

consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None)
Merges subposteriors following the consensus Monte Carlo algorithm.
References:
 Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch
Parameters:  subposteriors (list) – a list in which each element is a collection of samples.
 num_draws (int) – number of draws from the merged posterior.
 diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
 rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns: if num_draws is None, merges subposteriors without resampling; otherwise, returns a collection of num_draws samples with the same data structure as each subposterior.
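For scalar parameters, where variance and covariance coincide, consensus Monte Carlo reduces to a precision-weighted average. Each subposterior is weighted by the inverse of its sample variance, and the i-th merged draw combines the i-th draw of every subposterior. The plain-Python sketch below covers the num_draws=None case, assuming all subposteriors hold the same number of draws. The function name is hypothetical.

```python
import statistics


def consensus_scalar(subposteriors):
    """Consensus Monte Carlo merge for scalar draws, without resampling.

    subposteriors: list of equally long lists of draws, one list per shard.
    """
    # Each shard is weighted by its precision (inverse sample variance).
    weights = [1.0 / statistics.variance(draws) for draws in subposteriors]
    total = sum(weights)
    num_draws = len(subposteriors[0])
    # The i-th merged draw is the weighted average of the i-th draws.
    return [
        sum(w * draws[i] for w, draws in zip(weights, subposteriors)) / total
        for i in range(num_draws)
    ]
```

Shards whose posterior is tighter dominate the average. That is the intended behaviour when each shard's posterior is roughly Gaussian.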

parametric(subposteriors, diagonal=False)
Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.
References:
 Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters:  subposteriors (list) – a list in which each element is a collection of samples.
 diagonal (bool) – whether to use variance or covariance, defaults to False (using covariance).
Returns: the estimated mean and variance/covariance parameters of the joined posterior
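The parametric method fits a Gaussian to each subposterior and multiplies the fitted densities. The product is again Gaussian, and its precision is the sum of the subposterior precisions. Below is a scalar plain-Python sketch of the mean and variance it produces. The function name is hypothetical.

```python
import statistics


def parametric_scalar(subposteriors):
    """Product of per-shard Gaussian fits for scalar draws; returns (mean, var)."""
    precisions = [1.0 / statistics.variance(draws) for draws in subposteriors]
    means = [statistics.fmean(draws) for draws in subposteriors]
    # Precisions add under a product of Gaussian densities.
    joint_var = 1.0 / sum(precisions)
    # The joint mean is the precision-weighted average of the shard means.
    joint_mean = joint_var * sum(p * m for p, m in zip(precisions, means))
    return joint_mean, joint_var
```

For the two shards [0.0, 2.0] and [1.0, 5.0], the shard means are 1 and 3 and the sample variances are 2 and 8. The joint variance is therefore 1 / (0.5 + 0.125) = 1.6 and the joint mean is 1.6 × (0.5 × 1 + 0.125 × 3) = 1.4.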

parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None)
Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.
References:
 Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters:  subposteriors (list) – a list in which each element is a collection of samples.
 num_draws (int) – number of draws from the merged posterior.
 diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
 rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns: a collection of num_draws samples with the same data structure as each subposterior.