import math
import os
from collections import namedtuple
import tqdm
import jax.numpy as np
from jax import jit, partial, random
from jax.flatten_util import ravel_pytree
from jax.random import PRNGKey
from jax.tree_util import register_pytree_node
from numpyro.diagnostics import summary
from numpyro.hmc_util import IntegratorState, build_tree, find_reasonable_step_size, velocity_verlet, warmup_adapter
from numpyro.util import cond, fori_collect, fori_loop, identity
HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
                                   'mean_accept_prob', 'step_size', 'inverse_mass_matrix', 'mass_matrix_sqrt',
                                   'rng'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:

 - **i** - iteration. This is reset to 0 after warmup.
 - **z** - Python collection representing values (unconstrained samples from
   the posterior) at latent sites.
 - **z_grad** - Gradient of potential energy w.r.t. latent sample sites.
 - **potential_energy** - Potential energy computed at the given value of ``z``.
 - **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics).
 - **accept_prob** - Acceptance probability of the proposal. Note that ``z``
   does not correspond to the proposal if it is rejected.
 - **mean_accept_prob** - Mean acceptance probability until current iteration
   during warmup adaptation or sampling (for diagnostics).
 - **step_size** - Step size to be used by the integrator in the next iteration.
   This is adapted during warmup.
 - **inverse_mass_matrix** - The inverse mass matrix to be used for the next
   iteration. This is adapted during warmup.
 - **mass_matrix_sqrt** - Square root of the mass matrix (a vector for a
   diagonal mass matrix, a matrix for a dense one); standard-normal draws are
   scaled by it to sample the momentum for the next iteration. This is adapted
   during warmup.
 - **rng** - random number generator seed used for the iteration.
"""

# Register HMCState as a JAX pytree so it can be carried through `jit` and
# `fori_loop`: every field is a child leaf and no auxiliary data is needed.
register_pytree_node(
    HMCState,
    lambda xs: (tuple(xs), None),
    lambda _, xs: HMCState(*xs)
)
# Alias `_replace` as `update`. Like `_replace`, it returns a NEW HMCState and
# does not modify the receiver, so the result must be re-assigned.
HMCState.update = HMCState._replace
def _get_num_steps(step_size, trajectory_length):
    """
    Number of integrator steps used to cover ``trajectory_length`` with steps
    of size ``step_size``; at least one step is always taken.

    :param step_size: step size of the integrator.
    :param trajectory_length: total length of the Hamiltonian trajectory.
    :return: integer array equal to ``max(trajectory_length / step_size, 1)``
        with the fractional part truncated.
    """
    # `np.maximum(x, 1)` is exactly what `np.clip(x, a_min=1)` computes, but it
    # avoids the `a_min` keyword that newer JAX releases deprecate.
    num_steps = np.maximum(trajectory_length / step_size, 1)
    # NB: casting to np.int64 does not take effect (returns np.int32 instead)
    # if jax_enable_x64 is False
    return num_steps.astype(np.int64)
def _sample_momentum(unpack_fn, mass_matrix_sqrt, rng):
    """
    Draw a momentum sample scaled by the square root of the mass matrix.

    :param unpack_fn: callable turning the flat momentum vector back into the
        structure expected by the integrator.
    :param mass_matrix_sqrt: 1-D (diagonal) or 2-D (dense) square root of the
        mass matrix.
    :param rng: JAX PRNG key.
    :return: ``unpack_fn`` applied to the scaled standard-normal draw.
    :raises ValueError: if ``mass_matrix_sqrt`` is neither 1-D nor 2-D.
    """
    # One standard-normal draw per coordinate (leading dimension of the matrix).
    noise = random.normal(rng, np.shape(mass_matrix_sqrt)[:1])
    rank = mass_matrix_sqrt.ndim
    if rank not in (1, 2):
        raise ValueError("Mass matrix has incorrect number of dims.")
    # Diagonal case scales elementwise; dense case is a matrix-vector product.
    scaled = np.multiply(mass_matrix_sqrt, noise) if rank == 1 else np.dot(mass_matrix_sqrt, noise)
    return unpack_fn(scaled)
def _euclidean_ke(inverse_mass_matrix, r):
    """
    Euclidean kinetic energy ``0.5 * r^T M^{-1} r``.

    :param inverse_mass_matrix: inverse mass matrix, either dense (2-D) or
        diagonal (1-D, stored as a vector).
    :param r: momentum; any pytree, flattened with ``ravel_pytree``.
    :return: scalar kinetic energy.
    :raises ValueError: if ``inverse_mass_matrix`` is neither 1-D nor 2-D.
    """
    r, _ = ravel_pytree(r)
    if inverse_mass_matrix.ndim == 2:
        v = np.matmul(inverse_mass_matrix, r)
    elif inverse_mass_matrix.ndim == 1:
        v = np.multiply(inverse_mass_matrix, r)
    else:
        # Fail with a clear message (mirroring `_sample_momentum`) instead of
        # an UnboundLocalError on `v` further down.
        raise ValueError("Inverse mass matrix has incorrect number of dims.")
    return 0.5 * np.dot(v, r)
def get_diagnostics_str(hmc_state):
    """
    Build the one-line diagnostics string shown in the progress bars.

    :param hmc_state: state exposing ``num_steps``, ``step_size`` and
        ``mean_accept_prob`` (e.g. an :data:`HMCState`).
    :return: a string such as ``'3 steps of size 1.25e-01. acc. prob=0.80'``.
    """
    template = '{} steps of size {:.2e}. acc. prob={:.2f}'
    fields = (hmc_state.num_steps, hmc_state.step_size, hmc_state.mean_accept_prob)
    return template.format(*fields)
def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
    r"""
    Hamiltonian Monte Carlo inference, using either fixed number of
    steps or the No U-Turn Sampler (NUTS) with adaptive path length.

    **References:**

    1. *MCMC Using Hamiltonian Dynamics*,
       Radford M. Neal
    2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, and Andrew Gelman.
    3. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        `init_kernel` has the same type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum. If not provided, the default is
        euclidean kinetic energy.
    :param str algo: Whether to run ``HMC`` with fixed number of steps or ``NUTS``
        with adaptive path length. Default is ``NUTS``.
    :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first
        one to initialize the sampler, and the second one to generate samples
        given an existing one.
    :raises ValueError: if ``algo`` is not one of ``HMC`` or ``NUTS``.

    **Example**

    .. testsetup::

        import jax
        from jax import random
        import jax.numpy as np
        import numpyro.distributions as dist
        from numpyro.handlers import sample
        from numpyro.hmc_util import initialize_model
        from numpyro.mcmc import hmc
        from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
        ...                                                            model, data, labels)
        >>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
        >>> hmc_state = init_kernel(init_params,
        ...                         trajectory_length=10,
        ...                         num_warmup=300)
        >>> samples = fori_collect(500, sample_kernel, hmc_state,
        ...                        transform=lambda state: constrain_fn(state.z))
        >>> print(np.mean(samples['beta'], axis=0))  # doctest: +SKIP
        [0.9153987 2.0754058 2.9621222]
    """
    if kinetic_fn is None:
        kinetic_fn = _euclidean_ke
    vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    # These are (re)bound by `init_kernel` and read by the closures below.
    trajectory_len = None
    max_treedepth = None
    momentum_generator = None
    wa_update = None

    if algo not in {'HMC', 'NUTS'}:
        raise ValueError('`algo` must be one of `HMC` or `NUTS`.')

    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2*math.pi,
                    max_tree_depth=10,
                    run_warmup=True,
                    progbar=True,
                    rng=PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool run_warmup: Flag to decide whether warmup is run. If ``True``,
            `init_kernel` returns an initial :data:`HMCState` that can be used to
            generate samples using MCMC. Else, returns the arguments and callable
            that does the initial adaptation.
        :param bool progbar: Whether to enable progress bar updates. Defaults to
            ``True``. When ``False``, warmup runs inside a single JIT-compiled
            ``fori_loop``.
        :param jax.random.PRNGKey rng: random key to be used as the source of
            randomness.
        :return: if ``run_warmup`` is ``True``, the :data:`HMCState` after warmup
            with ``i`` and ``mean_accept_prob`` reset to 0; otherwise the tuple
            ``(hmc_state, wa_state, warmup_update)``.
        """
        nonlocal momentum_generator, wa_update, trajectory_len, max_treedepth
        step_size = float(step_size)
        trajectory_len = float(trajectory_length)
        max_treedepth = max_tree_depth
        z = init_params
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)
        find_reasonable_ss = partial(find_reasonable_step_size,
                                     potential_fn, kinetic_fn,
                                     momentum_generator)
        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)
        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
        # Only `z`, `z_grad` and `potential_energy` of `vv_state` are stored in
        # the HMCState below; the momentum drawn here is not carried over.
        r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
                             wa_state.step_size, wa_state.inverse_mass_matrix, wa_state.mass_matrix_sqrt,
                             rng_hmc)

        wa_update = jit(wa_update)
        if run_warmup:
            # JIT if progress bar updates not required
            if not progbar:
                hmc_state, _ = jit(fori_loop, static_argnums=(2,))(0, num_warmup,
                                                                   warmup_update,
                                                                   (hmc_state, wa_state))
            else:
                with tqdm.trange(num_warmup, desc='warmup') as t:
                    for i in t:
                        hmc_state, wa_state = warmup_update(i, (hmc_state, wa_state))
                        # TODO: set refresh=True when its performance issue is resolved
                        t.set_postfix_str(get_diagnostics_str(hmc_state), refresh=False)
            # Reset `i` and `mean_accept_prob` for fresh diagnostics. `update`
            # returns a new state (namedtuple `_replace`), so re-assign it; the
            # float `0.` matches the initial `mean_accept_prob` above.
            hmc_state = hmc_state.update(i=0, mean_accept_prob=0.)
            return hmc_state
        else:
            return hmc_state, wa_state, warmup_update

    def warmup_update(t, states):
        # One warmup iteration: draw a sample, feed its acceptance probability
        # and position to the adapter, then copy the adapted step size and
        # mass matrix back into the HMC state.
        hmc_state, wa_state = states
        hmc_state = sample_kernel(hmc_state)
        wa_state = wa_update(t, hmc_state.accept_prob, hmc_state.z, wa_state)
        hmc_state = hmc_state.update(step_size=wa_state.step_size,
                                     inverse_mass_matrix=wa_state.inverse_mass_matrix,
                                     mass_matrix_sqrt=wa_state.mass_matrix_sqrt)
        return hmc_state, wa_state

    def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
        # Fixed-length trajectory followed by a Metropolis accept/reject step.
        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(0, num_steps,
                                 lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                                 vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        # A NaN energy (diverging trajectory) is treated as +inf -> rejected.
        delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
        # Same as `np.clip(..., a_max=1.0)` without the deprecated keyword.
        accept_prob = np.minimum(np.exp(-delta_energy), 1.0)
        transition = random.bernoulli(rng, accept_prob)
        vv_state = cond(transition,
                        vv_state_new, lambda state: state,
                        vv_state, lambda state: state)
        return vv_state, num_steps, accept_prob

    def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = vv_state.update(z=binary_tree.z_proposal,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, num_steps, accept_prob

    _next = _nuts_next if algo == 'NUTS' else _hmc_next

    @jit
    def sample_kernel(hmc_state):
        """
        Given an existing :data:`HMCState`, run HMC with fixed (possibly adapted)
        step size and return a new :data:`HMCState`.

        :param hmc_state: Current sample (and associated state).
        :return: new proposed :data:`HMCState` from simulating
            Hamiltonian dynamics given existing state.
        """
        rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
        r = momentum_generator(hmc_state.mass_matrix_sqrt, rng_momentum)
        vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
        vv_state, num_steps, accept_prob = _next(hmc_state.step_size,
                                                 hmc_state.inverse_mass_matrix,
                                                 vv_state, rng_transition)
        itr = hmc_state.i + 1
        # Incremental running mean of the acceptance probability.
        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / itr
        return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
                        accept_prob, mean_accept_prob, hmc_state.step_size, hmc_state.inverse_mass_matrix,
                        hmc_state.mass_matrix_sqrt, rng)

    # Make `init_kernel` and `sample_kernel` visible from the global scope once
    # `hmc` is called for sphinx doc generation.
    if 'SPHINX_BUILD' in os.environ:
        hmc.init_kernel = init_kernel
        hmc.sample_kernel = sample_kernel

    return init_kernel, sample_kernel
def mcmc(num_warmup, num_samples, init_params, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints
    diagnostic summary and returns a collection of samples
    from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites. Defaults to the identity.
    :param print_summary: Whether to print diagnostics summary for
        each sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

        - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
          :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
          that all arguments must be provided as keywords. ``potential_fn``
          is required (a ``KeyError`` is raised if it is missing);
          ``kinetic_fn`` and ``algo`` are passed to :func:`~numpyro.mcmc.hmc`,
          ``progbar`` controls both the warmup and sampling progress bars, and
          every remaining keyword is forwarded to ``init_kernel``.

    :return: collection of samples from the posterior.
    :raises ValueError: if ``sampler`` is not recognized.

    .. testsetup::

        import jax
        from jax import random
        import jax.numpy as np
        import numpyro.distributions as dist
        from numpyro.handlers import sample
        from numpyro.hmc_util import initialize_model
        from numpyro.mcmc import hmc, mcmc
        from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
        ...                                                            data, labels)
        >>> num_warmup, num_samples = 1000, 1000
        >>> samples = mcmc(num_warmup, num_samples, init_params,
        ...                potential_fn=potential_fn,
        ...                constrain_fn=constrain_fn)  # doctest: +SKIP
        warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
        sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]
                                   mean         sd       5.5%      94.5%      n_eff       Rhat
                    beta[0]        0.96       0.07       0.85       1.07     455.35       1.01
                    beta[1]        2.05       0.09       1.91       2.20     332.00       1.01
                    beta[2]        3.18       0.13       2.96       3.37     320.27       1.00
                  intercept       -0.03       0.02      -0.06       0.00     402.53       1.00
    """
    if sampler != 'hmc':
        raise ValueError('sampler: {} not recognized'.format(sampler))

    if constrain_fn is None:
        constrain_fn = identity
    # Pull out the arguments consumed by `hmc` itself / the progress bars;
    # everything left in `sampler_kwargs` goes to `init_kernel`.
    potential_fn = sampler_kwargs.pop('potential_fn')
    kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
    algo = sampler_kwargs.pop('algo', 'NUTS')
    progbar = sampler_kwargs.pop('progbar', True)

    init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
    hmc_state = init_kernel(init_params, num_warmup, progbar=progbar, **sampler_kwargs)
    samples = fori_collect(num_samples, sample_kernel, hmc_state,
                           transform=lambda x: constrain_fn(x.z),
                           progbar=progbar,
                           diagnostics_fn=get_diagnostics_str,
                           progbar_desc='sample')
    if print_summary:
        summary(samples)
    return samples