Markov Chain Monte Carlo (MCMC)

Hamiltonian Monte Carlo

mcmc(num_warmup, num_samples, init_params, sampler='hmc', constrain_fn=None, print_summary=True, **sampler_kwargs)

Convenience wrapper for MCMC samplers: runs warmup, prints a diagnostic summary, and returns a collection of samples from the posterior.

Parameters:
  • num_warmup – Number of warmup steps.
  • num_samples – Number of samples to generate from the Markov chain.
  • init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • sampler – Name of the sampler to use. Currently only 'hmc' (the default) is implemented.
  • constrain_fn – Callable that converts a collection of unconstrained sample values returned from the sampler to constrained values that lie within the support of the sample sites.
  • print_summary – Whether to print a diagnostic summary for each sample site. Default is True.
  • **sampler_kwargs

    Sampler-specific keyword arguments.

    • HMC: Refer to hmc() and init_kernel() for accepted arguments. Note that all arguments must be provided as keywords.
Returns:

collection of samples from the posterior.

>>> true_coefs = np.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs_mean = np.zeros(dim)
...     coefs = sample('coefs', dist.Normal(coefs_mean, np.ones(3)))
...     intercept = sample('intercept', dist.Normal(0., 10.))
...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
...                                                            data, labels)
>>> num_warmup, num_samples = 1000, 1000
>>> samples = mcmc(num_warmup, num_samples, init_params,
...                potential_fn=potential_fn,
...                constrain_fn=constrain_fn)  
warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]


                   mean         sd       5.5%      94.5%      n_eff       Rhat
    coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
    coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
    coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
   intercept      -0.03       0.02      -0.06       0.00     402.53       1.00
hmc(potential_fn, kinetic_fn=None, algo='NUTS')

Hamiltonian Monte Carlo inference, using either a fixed number of steps or the No-U-Turn Sampler (NUTS) with adaptive path length.
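The stopping rule that gives NUTS its name fits in a few lines. The sketch below is a minimal pure-Python illustration of the original criterion from reference 2, assuming an identity mass matrix and plain lists for positions and momenta; is_u_turn is a hypothetical helper, not part of this module.

```python
# Hypothetical sketch, not numpyro's implementation.
def is_u_turn(z_minus, z_plus, r_minus, r_plus):
    """NUTS stopping criterion, assuming an identity mass matrix.

    z_minus/z_plus are the leftmost/rightmost positions of the trajectory,
    r_minus/r_plus the momenta at those two endpoints.
    """
    # Displacement from the leftmost to the rightmost position.
    dz = [zp - zm for zp, zm in zip(z_plus, z_minus)]
    # Project the displacement onto the momentum at each endpoint.
    dot_minus = sum(d * r for d, r in zip(dz, r_minus))
    dot_plus = sum(d * r for d, r in zip(dz, r_plus))
    # A negative projection means continuing at that end would start to
    # shrink the distance between the endpoints: the path makes a U-turn.
    return dot_minus < 0 or dot_plus < 0


# Both endpoints still moving apart: keep extending the trajectory.
print(is_u_turn([0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]))   # False
# The right endpoint has reversed direction: stop.
print(is_u_turn([0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]))  # True
```

In the doubling scheme, NUTS keeps extending the trajectory until a check of this kind fires (or the tree reaches max_tree_depth), so the path length adapts to the local geometry instead of being fixed in advance.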

References:

  1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
  2. The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman and Andrew Gelman
  3. A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters:
  • potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any Python collection type, provided that the init_params argument to init_kernel has the same type.
  • kinetic_fn – Python callable that returns the kinetic energy given the inverse mass matrix and momentum. If not provided, the default is the Euclidean kinetic energy.
  • algo (str) – Whether to run HMC with a fixed number of steps or NUTS with adaptive path length. Default is 'NUTS'.
Returns:

a tuple of callables (init_kernel, sample_kernel): the first initializes the sampler, and the second generates a new sample given the existing state.

Example

>>> true_coefs = np.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs_mean = np.zeros(dim)
...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
...     intercept = sample('intercept', dist.Normal(0., 10.))
...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
...                                                            model, data, labels)
>>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(init_params,
...                         trajectory_length=10,
...                         num_warmup=300)
>>> samples = fori_collect(500, sample_kernel, hmc_state,
...                        transform=lambda state: constrain_fn(state.z))
>>> print(np.mean(samples['beta'], axis=0))  
[0.9153987 2.0754058 2.9621222]
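To make the (init_kernel, sample_kernel) contract concrete for the fixed-number-of-steps variant, here is a small pure-Python sketch. It is not this module's implementation: it assumes an identity mass matrix, the Euclidean kinetic energy, a user-supplied gradient function grad_fn, and takes num_steps directly rather than deriving it from trajectory_length and step_size. The names hmc_fixed_steps, leapfrog and euclidean_kinetic_energy are hypothetical, and sample_kernel here returns only the position instead of a full HMCState.

```python
# Hypothetical sketch, not numpyro's implementation.
import math
import random


def euclidean_kinetic_energy(inverse_mass_matrix, r):
    # K(r) = 0.5 * r^T M^{-1} r for a diagonal inverse mass matrix.
    return 0.5 * sum(m * ri * ri for m, ri in zip(inverse_mass_matrix, r))


def leapfrog(z, r, grad_fn, step_size, num_steps, inverse_mass_matrix):
    # Velocity Verlet (leapfrog) integration of Hamiltonian dynamics.
    g = grad_fn(z)
    for _ in range(num_steps):
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, g)]  # half step in momentum
        z = [zi + step_size * m * ri
             for zi, m, ri in zip(z, inverse_mass_matrix, r)]     # full step in position
        g = grad_fn(z)
        r = [ri - 0.5 * step_size * gi for ri, gi in zip(r, g)]  # half step in momentum
    return z, r


def hmc_fixed_steps(potential_fn, grad_fn, step_size, num_steps, seed=0):
    """Hypothetical factory returning an (init_kernel, sample_kernel) pair."""
    rng = random.Random(seed)

    def init_kernel(init_params):
        return list(init_params)

    def sample_kernel(z):
        inverse_mass_matrix = [1.0] * len(z)  # identity mass matrix
        r = [rng.gauss(0.0, 1.0) for _ in z]  # fresh momentum r ~ N(0, I)
        energy_old = potential_fn(z) + euclidean_kinetic_energy(inverse_mass_matrix, r)
        z_new, r_new = leapfrog(z, r, grad_fn, step_size, num_steps, inverse_mass_matrix)
        energy_new = potential_fn(z_new) + euclidean_kinetic_energy(inverse_mass_matrix, r_new)
        # Metropolis correction; min(0, .) keeps exp() from overflowing.
        accept_prob = math.exp(min(0.0, energy_old - energy_new))
        return z_new if rng.random() < accept_prob else z

    return init_kernel, sample_kernel


# Usage: sample a 1-D standard normal, whose potential energy is z**2 / 2.
init_kernel, sample_kernel = hmc_fixed_steps(
    potential_fn=lambda z: 0.5 * sum(zi * zi for zi in z),
    grad_fn=lambda z: list(z),
    step_size=0.2,
    num_steps=10,
    seed=1,
)
z = init_kernel([3.0])
draws = []
for _ in range(4000):
    z = sample_kernel(z)
    draws.append(z[0])
draws = draws[200:]  # drop the first draws, analogous to warmup
mean = sum(draws) / len(draws)
var = sum((d - mean) ** 2 for d in draws) / len(draws)
```

Each call draws a fresh momentum, integrates for num_steps leapfrog steps, and applies a Metropolis correction; with this 1-D standard normal target, the retained draws should have a mean near 0 and a variance near 1.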
init_kernel(init_params, num_warmup, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, max_tree_depth=10, run_warmup=True, progbar=True, rng=DeviceArray([0, 0], dtype=uint32))

Initializes the HMC sampler.

Parameters:
  • init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
  • num_warmup (int) – Number of warmup steps; samples generated during warmup are discarded.
  • step_size (float) – Determines the size of a single step taken by the Verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
  • adapt_step_size (bool) – Whether to adapt step_size during the warmup phase using the Dual Averaging scheme.
  • adapt_mass_matrix (bool) – Whether to adapt the mass matrix during the warmup phase using the Welford scheme.
  • dense_mass (bool) – Whether the mass matrix is dense (dense_mass=True) or diagonal (the default, dense_mass=False).
  • target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
  • trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
  • max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of the NUTS sampler. Defaults to 10.
  • run_warmup (bool) – Whether to run warmup. If True, init_kernel returns an initial HMCState that can be used to generate samples using MCMC. Otherwise, it returns the arguments and the callable that perform the initial adaptation.
  • progbar (bool) – Whether to enable progress bar updates. Defaults to True.
  • heuristic_step_size (bool) – If True, a coarse-grained adjustment of step size is done at the beginning of each adaptation window to achieve target_accept_prob.
  • rng (jax.random.PRNGKey) – random key to be used as the source of randomness.
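The effect of adapt_step_size and target_accept_prob can be illustrated with the dual averaging update from reference 2 of hmc(). This standalone sketch assumes the constants suggested in that paper (gamma=0.05, t0=10, kappa=0.75) and replaces real HMC transitions with a deterministic toy acceptance function; dual_averaging_step_size is hypothetical, and the library may use different constants or restart the scheme in each adaptation window.

```python
# Hypothetical sketch, not numpyro's implementation.
import math


def dual_averaging_step_size(accept_prob_fn, init_step_size=1.0, target_accept_prob=0.8,
                             num_iters=2000, gamma=0.05, t0=10.0, kappa=0.75):
    """Adapt a step size with Nesterov-style dual averaging.

    accept_prob_fn(step_size) stands in for running one HMC transition and
    reading off its acceptance probability.
    """
    mu = math.log(10.0 * init_step_size)  # point the iterates are shrunk towards
    log_step = math.log(init_step_size)
    log_step_avg = 0.0                    # averaged iterate, returned at the end
    h_avg = 0.0                           # running average of (target - accept_prob)
    for m in range(1, num_iters + 1):
        alpha = accept_prob_fn(math.exp(log_step))
        w = 1.0 / (m + t0)
        h_avg = (1.0 - w) * h_avg + w * (target_accept_prob - alpha)
        # Too many rejections (alpha < target) push h_avg up and the step size down.
        log_step = mu - math.sqrt(m) / gamma * h_avg
        eta = m ** (-kappa)
        log_step_avg = eta * log_step + (1.0 - eta) * log_step_avg
    return math.exp(log_step_avg)


# Toy stand-in for HMC: acceptance falls as the step size grows.
step_size = dual_averaging_step_size(lambda eps: math.exp(-eps), num_iters=5000)
print(step_size, -math.log(0.8))  # adapted vs. exact step size for 0.8 acceptance
```

With accept_prob_fn(eps) = exp(-eps), the step size that gives an acceptance probability of exactly 0.8 is -log(0.8) ≈ 0.223, and the averaged step size returned after a few thousand iterations should be close to it. Raising target_accept_prob lowers that fixed point, which is why a higher target yields smaller steps.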
sample_kernel(hmc_state)

Given an existing HMCState, run HMC with fixed (possibly adapted) step size and return a new HMCState.

Parameters: hmc_state – Current sample (and associated state).
Returns: new proposed HMCState from simulating Hamiltonian dynamics given existing state.
HMCState = <class 'numpyro.mcmc.HMCState'>

A namedtuple() consisting of the following fields:

  • i - iteration. This is reset to 0 after warmup.
  • z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
  • z_grad - Gradient of potential energy w.r.t. latent sample sites.
  • potential_energy - Potential energy computed at the given value of z.
  • num_steps - Number of steps in the Hamiltonian trajectory (for diagnostics).
  • accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
  • mean_accept_prob - Mean acceptance probability until current iteration during warmup adaptation or sampling (for diagnostics).
  • step_size - Step size to be used by the integrator in the next iteration. This is adapted during warmup.
  • inverse_mass_matrix - The inverse mass matrix to be used for the next iteration. This is adapted during warmup.
  • rng - random number generator seed used for the iteration.
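The inverse_mass_matrix field is adapted during warmup. For a diagonal mass matrix (dense_mass=False), a common choice, as in Stan, is to set it to the per-dimension variance of the warmup draws, which Welford's algorithm accumulates in a single pass. The class below is a hypothetical pure-Python sketch of that accumulator; it leaves out any regularization or windowing the library may apply.

```python
# Hypothetical sketch, not numpyro's implementation.
class WelfordVariance:
    """One-pass accumulator of per-dimension mean and variance."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sums of squared deviations from the mean

    def update(self, z):
        # Welford's update: numerically stable, no need to store past draws.
        self.n += 1
        for i, x in enumerate(z):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.n
            self.m2[i] += delta * (x - self.mean[i])

    def variance(self):
        # Unbiased sample variance; requires at least two draws.
        return [m2 / (self.n - 1) for m2 in self.m2]


# Usage: feed warmup draws one at a time, then read off the estimate.
welford = WelfordVariance(dim=2)
for z in [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]:
    welford.update(z)
inverse_mass_matrix = welford.variance()  # diagonal entries, one per dimension
```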

MCMC Utilities

initialize_model(rng, model, *model_args, init_strategy='uniform', **model_kwargs)

Given a model with Pyro primitives, returns a function which, given unconstrained parameters, evaluates the potential energy (negative log joint density). It also returns initial parameters to initiate MCMC sampling (drawn according to init_strategy) and a function to transform unconstrained values at sample sites to constrained values within their respective support.

Parameters:
  • rng (jax.random.PRNGKey) – random number generator seed to sample from the prior.
  • model – Python callable containing Pyro primitives.
  • *model_args – args provided to the model.
  • init_strategy (str) – Initialization strategy: 'uniform' initializes the unconstrained parameters by drawing from a Uniform(-2, 2) distribution (as used by Stan), whereas 'prior' initializes the parameters by sampling from the prior for each of the sample sites.
  • **model_kwargs – kwargs provided to the model.
Returns:

tuple of (init_params, potential_fn, constrain_fn). init_params are the initial values used to initiate MCMC (drawn according to init_strategy), potential_fn computes the potential energy from unconstrained parameters, and constrain_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within the site’s support.
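As a sketch of what the returned potential_fn and constrain_fn compute, consider a hypothetical one-site model: scale ~ Exponential(1), with observations ~ Normal(0, scale). Working in the unconstrained space u = log(scale), the potential energy is -(log p(scale, data) + log|d scale / du|). Everything below (the site name, make_potential_fn, init_uniform) is illustrative pure Python, not the code initialize_model actually generates.

```python
# Hypothetical sketch, not numpyro's implementation.
import math
import random


def constrain_fn(params):
    # The hypothetical 'scale' site has support (0, inf); exp maps R onto it.
    return {"scale": math.exp(params["scale"])}


def make_potential_fn(data):
    def potential_fn(params):
        u = params["scale"]            # unconstrained value
        scale = math.exp(u)            # constrained value
        log_prior = -scale             # Exponential(1) log density, scale > 0
        log_lik = sum(-0.5 * (x / scale) ** 2 - math.log(scale)
                      - 0.5 * math.log(2.0 * math.pi) for x in data)  # Normal(0, scale)
        log_abs_det_jacobian = u       # log |d scale / d u| = log exp(u) = u
        # Potential energy = negative log joint density in unconstrained space.
        return -(log_prior + log_lik + log_abs_det_jacobian)
    return potential_fn


def init_uniform(rng, site_names):
    # Mirrors the 'uniform' strategy: unconstrained values ~ Uniform(-2, 2).
    return {name: rng.uniform(-2.0, 2.0) for name in site_names}


potential_fn = make_potential_fn([1.0, -1.0])
init_params = init_uniform(random.Random(0), ["scale"])
print(potential_fn(init_params), constrain_fn(init_params))
```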

fori_collect(n, body_fun, init_val, transform=<function identity>, progbar=True, **progbar_opts)

This looping construct works like fori_loop() but additionally collects the values produced by the loop body. It also allows post-processing of these values via transform, and progress bar updates. Note that progbar=False can be faster in some cases, e.g. when collecting a large number of samples. Refer to the example usage in hmc().

Parameters:
  • n (int) – number of times to run the loop body.
  • body_fun – a callable that takes a collection of np.ndarray and returns a collection with the same shape and dtype.
  • init_val – initial value to pass as argument to body_fun. Can be any Python collection type containing np.ndarray objects.
  • transform – A callable applied to each value returned by body_fun before it is collected, e.g. to convert unconstrained samples to constrained values. Defaults to the identity function.
  • progbar – Whether to post progress bar updates.
  • **progbar_opts – Optional additional progress bar arguments. A diagnostics_fn can be supplied which, when passed the current value from body_fun, returns a string that is used to update the progress bar postfix. A progbar_desc keyword argument can also be supplied to label the progress bar.
Returns:

collection with the same type as init_val with values collected along the leading axis of np.ndarray objects.
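The collecting behaviour described above can be pictured with a plain Python loop. This hypothetical sketch handles dict-valued states only, appends transform(value) after every call to body_fun, and uses Python lists in place of the leading np.ndarray axis; whether fori_collect also records init_val itself is not stated above, so the sketch records only values returned by body_fun.

```python
# Hypothetical sketch, not numpyro's implementation.
def fori_collect_sketch(n, body_fun, init_val, transform=lambda x: x):
    """Run body_fun n times, collecting transform(value) after every call.

    Each key of the transformed dict is stacked along a leading "axis" (a list).
    """
    # Call transform once on init_val only to discover the output keys.
    collected = {key: [] for key in transform(init_val)}
    val = init_val
    for _ in range(n):
        val = body_fun(val)
        for key, item in transform(val).items():
            collected[key].append(item)
    return collected


# Usage: a counter state; collect the square of the counter after each step.
samples = fori_collect_sketch(
    3,
    body_fun=lambda state: {"i": state["i"] + 1},
    init_val={"i": 0},
    transform=lambda state: {"sq": state["i"] ** 2},
)
print(samples)  # {'sq': [1, 4, 9]}
```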

summary(samples, prob=0.89)

Prints a summary table displaying diagnostics of samples from the posterior. The diagnostics displayed are the mean, standard deviation, the 89% credible interval (set by prob), effective_sample_size() and split_gelman_rubin().

Parameters:
  • samples – a collection of input samples.
  • prob (float) – the probability mass of samples within the HPDI interval.
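The per-site statistics that summary prints can be approximated as follows. This pure-Python sketch assumes each site is given as a list of chains of scalar draws, computes equal-tailed interval bounds by linear interpolation (like the 5.5% and 94.5% columns in the mcmc() example output; an HPDI would instead pick the narrowest interval holding that mass), and implements the standard split R-hat formula. quantile, split_rhat and summary_row are hypothetical names, and effective sample size is omitted.

```python
# Hypothetical sketch, not numpyro's implementation.
import math
import random


def quantile(xs, q):
    # Linear interpolation between order statistics.
    xs = sorted(xs)
    pos = q * (len(xs) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    frac = pos - lo
    return xs[lo] * (1.0 - frac) + xs[hi] * frac


def split_rhat(chains):
    # Split every chain in half, so a drift within one chain inflates R-hat.
    halves = []
    for chain in chains:
        n = len(chain) // 2
        halves.append(chain[:n])
        halves.append(chain[n:2 * n])  # drops the last draw if the length is odd
    m, n = len(halves), len(halves[0])
    means = [sum(h) / n for h in halves]
    grand_mean = sum(means) / m
    between = n / (m - 1) * sum((mu - grand_mean) ** 2 for mu in means)
    within = sum(sum((x - mu) ** 2 for x in h) / (n - 1)
                 for h, mu in zip(halves, means)) / m
    var_plus = (n - 1) / n * within + between / n
    return math.sqrt(var_plus / within)


def summary_row(chains, prob=0.89):
    draws = [x for chain in chains for x in chain]
    mean = sum(draws) / len(draws)
    sd = math.sqrt(sum((x - mean) ** 2 for x in draws) / (len(draws) - 1))
    tail = (1.0 - prob) / 2.0  # prob=0.89 gives the 5.5% and 94.5% bounds
    return mean, sd, quantile(draws, tail), quantile(draws, 1.0 - tail), split_rhat(chains)


rng = random.Random(0)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
drifting = [[float(i) for i in range(1000)]]
print(summary_row(mixed))
print(split_rhat(drifting))  # far above 1: the two halves disagree
```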