from collections import namedtuple
import math
import os
from jax import device_put, lax, partial, random, vmap
from jax.dtypes import canonicalize_dtype
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from numpyro.infer.hmc_util import (
IntegratorState,
build_tree,
euclidean_kinetic_energy,
find_reasonable_step_size,
velocity_verlet,
warmup_adapter
)
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import ParamInfo, init_to_uniform, initialize_model
from numpyro.util import cond, fori_loop, identity
HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
                                   'mean_accept_prob', 'diverging', 'adapt_state', 'rng_key'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:

 - **i** - iteration counter. ``sample_kernel`` increments it on every call,
   warmup iterations included; it is not reset when warmup ends (adaptation is
   applied only while ``i`` is below the number of warmup steps).
 - **z** - Python collection representing values (unconstrained samples from
   the posterior) at latent sites.
 - **z_grad** - Gradient of potential energy w.r.t. latent sample sites.
 - **potential_energy** - Potential energy computed at the given value of ``z``.
 - **energy** - Sum of potential energy and kinetic energy of the current state.
 - **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics).
 - **accept_prob** - Acceptance probability of the proposal. Note that ``z``
   does not correspond to the proposal if it is rejected.
 - **mean_accept_prob** - Mean acceptance probability until current iteration
   during warmup adaptation or sampling (for diagnostics).
 - **diverging** - A boolean value to indicate whether the current trajectory is diverging.
 - **adapt_state** - A ``HMCAdaptState`` namedtuple which contains adaptation information
   during warmup:

   + **step_size** - Step size to be used by the integrator in the next iteration.
   + **inverse_mass_matrix** - The inverse mass matrix to be used for the next
     iteration.
   + **mass_matrix_sqrt** - The square root of mass matrix to be used for the next
     iteration. In case of dense mass, this is the Cholesky factorization of the
     mass matrix.

 - **rng_key** - random number generator seed used for the iteration.
"""
def _get_num_steps(step_size, trajectory_length):
num_steps = jnp.clip(trajectory_length / step_size, a_min=1)
# NB: casting to jnp.int64 does not take effect (returns jnp.int32 instead)
# if jax_enable_x64 is False
return num_steps.astype(canonicalize_dtype(jnp.int64))
def momentum_generator(prototype_r, mass_matrix_sqrt, rng_key):
    """
    Draw a momentum with the same pytree structure as ``prototype_r``.

    A standard normal vector ``eps`` (length = leading dim of ``mass_matrix_sqrt``)
    is scaled by the mass-matrix square root: elementwise for a 1-D (diagonal)
    ``mass_matrix_sqrt``, by matrix-vector product for a 2-D (dense) one.

    :param prototype_r: pytree whose flattened structure the result follows.
    :param mass_matrix_sqrt: 1-D or 2-D array, square root of the mass matrix.
    :param rng_key: PRNG key used to draw ``eps``.
    :return: the momentum, unflattened into ``prototype_r``'s structure.
    :raises ValueError: if ``mass_matrix_sqrt`` is neither 1-D nor 2-D.
    """
    _, unravel = ravel_pytree(prototype_r)
    noise = random.normal(rng_key, jnp.shape(mass_matrix_sqrt)[:1])
    rank = mass_matrix_sqrt.ndim
    if rank == 1:
        flat_momentum = jnp.multiply(mass_matrix_sqrt, noise)
    elif rank == 2:
        flat_momentum = jnp.dot(mass_matrix_sqrt, noise)
    else:
        raise ValueError("Mass matrix has incorrect number of dims.")
    return unravel(flat_momentum)
def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'):
    r"""
    Hamiltonian Monte Carlo inference, using either fixed number of
    steps or the No U-Turn Sampler (NUTS) with adaptive path length.

    **References:**

    1. *MCMC Using Hamiltonian Dynamics*,
       Radford M. Neal
    2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, and Andrew Gelman.
    3. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        `init_kernel` has the same type.
    :param potential_fn_gen: Python callable that when provided with model
        arguments / keyword arguments returns `potential_fn`. This
        may be provided to do inference on the same model with changing data.
        If the data shape remains the same, we can compile `sample_kernel`
        once, and use the same for multiple inference runs. Exactly one of
        `potential_fn` and `potential_fn_gen` must be provided.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum. If not provided, the default is
        euclidean kinetic energy.
    :param str algo: Whether to run ``HMC`` with fixed number of steps or ``NUTS``
        with adaptive path length. Default is ``NUTS``.
    :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first
        one to initialize the sampler, and the second one to generate samples
        given an existing one.
    :raises ValueError: if `algo` is neither ``'HMC'`` nor ``'NUTS'``.

    .. warning::
        Instead of using this interface directly, we would highly recommend you
        to use the higher level :class:`numpyro.infer.MCMC` API instead.

    **Example**

    .. doctest::

        >>> import jax
        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer.hmc import hmc
        >>> from numpyro.infer.util import initialize_model
        >>> from numpyro.util import fori_collect

        >>> true_coefs = jnp.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = jnp.zeros(dim)
        ...     coefs = numpyro.sample('beta', dist.Normal(coefs_mean, jnp.ones(3)))
        ...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
        ...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,))
        >>> init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
        >>> hmc_state = init_kernel(model_info.param_info,
        ...                         trajectory_length=10,
        ...                         num_warmup=300)
        >>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
        ...                        transform=lambda state: model_info.postprocess_fn(state.z))
        >>> print(jnp.mean(samples['beta'], axis=0))  # doctest: +SKIP
        [0.9153987 2.0754058 2.9621222]
    """
    if kinetic_fn is None:
        kinetic_fn = euclidean_kinetic_energy
    # Sampler configuration shared through closures: ``init_kernel`` assigns
    # these (via ``nonlocal``) and the transition functions / ``sample_kernel`` read them.
    vv_update = None
    trajectory_len = None
    max_treedepth = None
    wa_update = None
    wa_steps = None
    # Energy-error threshold above which a trajectory is flagged as diverging.
    max_delta_energy = 1000.
    if algo not in {'HMC', 'NUTS'}:
        raise ValueError('`algo` must be one of `HMC` or `NUTS`.')

    def init_kernel(init_params,
                    num_warmup,
                    step_size=1.0,
                    inverse_mass_matrix=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    dense_mass=False,
                    target_accept_prob=0.8,
                    trajectory_length=2 * math.pi,
                    max_tree_depth=10,
                    find_heuristic_step_size=False,
                    model_args=(),
                    model_kwargs=None,
                    rng_key=random.PRNGKey(0)):
        """
        Initializes the HMC sampler.

        :param init_params: Initial parameters to begin sampling. The type must
            be consistent with the input type to `potential_fn`.
        :param int num_warmup: Number of warmup steps; samples generated
            during warmup are discarded.
        :param float step_size: Determines the size of a single step taken by the
            verlet integrator while computing the trajectory using Hamiltonian
            dynamics. If not specified, it will be set to 1.
        :param numpy.ndarray inverse_mass_matrix: Initial value for inverse mass matrix.
            This may be adapted during warmup if adapt_mass_matrix = True.
            If no value is specified, then it is initialized to the identity matrix.
        :param bool adapt_step_size: A flag to decide if we want to adapt step_size
            during warm-up phase using Dual Averaging scheme.
        :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
            matrix during warm-up phase using Welford scheme.
        :param bool dense_mass: A flag to decide if mass matrix is dense or
            diagonal (default when ``dense_mass=False``)
        :param float target_accept_prob: Target acceptance probability for step size
            adaptation using Dual Averaging. Increasing this value will lead to a smaller
            step size, hence the sampling will be slower but more robust. Default to 0.8.
        :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
            value is :math:`2\\pi`.
        :param int max_tree_depth: Max depth of the binary tree created during the doubling
            scheme of NUTS sampler. Defaults to 10.
        :param bool find_heuristic_step_size: whether to use a heuristic function to adjust the
            step size at the beginning of each adaptation window. Defaults to False.
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :param jax.random.PRNGKey rng_key: random key to be used as the source of
            randomness.
        :return: the initial :data:`~numpyro.infer.hmc.HMCState`.
        :raises ValueError: if both or neither of `potential_fn` and
            `potential_fn_gen` were given to :func:`hmc`.
        """
        nonlocal wa_update, trajectory_len, max_treedepth, vv_update, wa_steps
        step_size = lax.convert_element_type(step_size, canonicalize_dtype(jnp.float64))
        wa_steps = num_warmup
        trajectory_len = trajectory_length
        max_treedepth = max_tree_depth
        # A ParamInfo already carries potential energy and gradient; raw values
        # leave them to be computed by the integrator's init.
        if isinstance(init_params, ParamInfo):
            z, pe, z_grad = init_params
        else:
            z, pe, z_grad = init_params, None, None
        pe_fn = potential_fn
        if potential_fn_gen:
            if pe_fn is not None:
                raise ValueError('Only one of `potential_fn` or `potential_fn_gen` must be provided.')
            else:
                kwargs = {} if model_kwargs is None else model_kwargs
                pe_fn = potential_fn_gen(*model_args, **kwargs)
        if pe_fn is None:
            raise ValueError('One of `potential_fn` or `potential_fn_gen` must be provided.')

        find_reasonable_ss = None
        if find_heuristic_step_size:
            find_reasonable_ss = partial(find_reasonable_step_size,
                                         pe_fn,
                                         kinetic_fn,
                                         momentum_generator)

        wa_init, wa_update = warmup_adapter(num_warmup,
                                            adapt_step_size=adapt_step_size,
                                            adapt_mass_matrix=adapt_mass_matrix,
                                            dense_mass=dense_mass,
                                            target_accept_prob=target_accept_prob,
                                            find_reasonable_step_size=find_reasonable_ss)

        rng_key_hmc, rng_key_wa, rng_key_momentum = random.split(rng_key, 3)
        z_info = IntegratorState(z=z, potential_energy=pe, z_grad=z_grad)
        wa_state = wa_init(z_info, rng_key_wa, step_size,
                           inverse_mass_matrix=inverse_mass_matrix,
                           mass_matrix_size=jnp.size(ravel_pytree(z)[0]))
        r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
        vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn)
        vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
        # Total Hamiltonian (potential + kinetic), matching the ``energy`` that
        # the transition functions report on every subsequent iteration.
        energy = vv_state.potential_energy + kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
        hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
                             0, 0., 0., False, wa_state, rng_key_hmc)
        return device_put(hmc_state)

    def _hmc_next(step_size, inverse_mass_matrix, vv_state,
                  model_args, model_kwargs, rng_key):
        # Fixed-length HMC transition: run ``num_steps`` velocity Verlet steps,
        # then accept or reject the endpoint with a Metropolis correction.
        if potential_fn_gen:
            # Rebuild the integrator so the potential uses the current model args.
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        num_steps = _get_num_steps(step_size, trajectory_len)
        vv_state_new = fori_loop(0, num_steps,
                                 lambda i, val: vv_update(step_size, inverse_mass_matrix, val),
                                 vv_state)
        energy_old = vv_state.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state.r)
        energy_new = vv_state_new.potential_energy + kinetic_fn(inverse_mass_matrix, vv_state_new.r)
        delta_energy = energy_new - energy_old
        # A NaN energy error is mapped to +inf: acceptance probability becomes 0
        # and the trajectory is flagged as diverging.
        delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
        diverging = delta_energy > max_delta_energy
        transition = random.bernoulli(rng_key, accept_prob)
        vv_state, energy = cond(transition,
                                (vv_state_new, energy_new), identity,
                                (vv_state, energy_old), identity)
        return vv_state, energy, num_steps, accept_prob, diverging

    def _nuts_next(step_size, inverse_mass_matrix, vv_state,
                   model_args, model_kwargs, rng_key):
        # NUTS transition: build a doubling tree and take its proposal.
        if potential_fn_gen:
            # Rebuild the integrator so the potential uses the current model args.
            nonlocal vv_update
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
            _, vv_update = velocity_verlet(pe_fn, kinetic_fn)

        binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                 inverse_mass_matrix, step_size, rng_key,
                                 max_delta_energy=max_delta_energy,
                                 max_tree_depth=max_treedepth)
        accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
        num_steps = binary_tree.num_proposals
        vv_state = IntegratorState(z=binary_tree.z_proposal,
                                   r=vv_state.r,
                                   potential_energy=binary_tree.z_proposal_pe,
                                   z_grad=binary_tree.z_proposal_grad)
        return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging

    _next = _nuts_next if algo == 'NUTS' else _hmc_next

    def sample_kernel(hmc_state, model_args=(), model_kwargs=None):
        """
        Given an existing :data:`~numpyro.infer.hmc.HMCState`, run HMC with fixed (possibly adapted)
        step size and return a new :data:`~numpyro.infer.hmc.HMCState`.

        :param hmc_state: Current sample (and associated state).
        :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
        :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
        :return: new proposed :data:`~numpyro.infer.hmc.HMCState` from simulating
            Hamiltonian dynamics given existing state.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        rng_key, rng_key_momentum, rng_key_transition = random.split(hmc_state.rng_key, 3)
        r = momentum_generator(hmc_state.z, hmc_state.adapt_state.mass_matrix_sqrt, rng_key_momentum)
        vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
        vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
                                                                    hmc_state.adapt_state.inverse_mass_matrix,
                                                                    vv_state,
                                                                    model_args,
                                                                    model_kwargs,
                                                                    rng_key_transition)
        # not update adapt_state after warmup phase
        adapt_state = cond(hmc_state.i < wa_steps,
                           (hmc_state.i, accept_prob, vv_state, hmc_state.adapt_state),
                           lambda args: wa_update(*args),
                           hmc_state.adapt_state,
                           identity)

        itr = hmc_state.i + 1
        # Running mean of accept_prob; the count restarts when sampling begins so
        # warmup and post-warmup means are tracked separately.
        n = jnp.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n

        return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
                        accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)

    # Make `init_kernel` and `sample_kernel` visible from the global scope once
    # `hmc` is called for sphinx doc generation.
    if 'SPHINX_BUILD' in os.environ:
        hmc.init_kernel = init_kernel
        hmc.sample_kernel = sample_kernel

    return init_kernel, sample_kernel
class HMC(MCMCKernel):
    """
    Hamiltonian Monte Carlo inference, using fixed trajectory length, with
    provision for step size and mass matrix adaptation.

    **References:**

    1. *MCMC Using Hamiltonian Dynamics*,
       Radford M. Neal

    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
        If model is provided, `potential_fn` will be inferred using the model.
    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        :meth:`init` has the same type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum. If not provided, the default is
        euclidean kinetic energy.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool dense_mass: A flag to decide if mass matrix is dense or
        diagonal (default when ``dense_mass=False``)
    :param float target_accept_prob: Target acceptance probability for step size
        adaptation using Dual Averaging. Increasing this value will lead to a smaller
        step size, hence the sampling will be slower but more robust. Default to 0.8.
    :param float trajectory_length: Length of a MCMC trajectory for HMC. Default
        value is :math:`2\\pi`.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool find_heuristic_step_size: whether to use a heuristic function to adjust the
        step size at the beginning of each adaptation window. Defaults to False.
    """
    def __init__(self,
                 model=None,
                 potential_fn=None,
                 kinetic_fn=None,
                 step_size=1.0,
                 adapt_step_size=True,
                 adapt_mass_matrix=True,
                 dense_mass=False,
                 target_accept_prob=0.8,
                 trajectory_length=2 * math.pi,
                 init_strategy=init_to_uniform,
                 find_heuristic_step_size=False):
        if not (model is None) ^ (potential_fn is None):
            raise ValueError('Only one of `model` or `potential_fn` must be specified.')
        self._model = model
        self._potential_fn = potential_fn
        self._kinetic_fn = kinetic_fn if kinetic_fn is not None else euclidean_kinetic_energy
        self._step_size = step_size
        self._adapt_step_size = adapt_step_size
        self._adapt_mass_matrix = adapt_mass_matrix
        self._dense_mass = dense_mass
        self._target_accept_prob = target_accept_prob
        self._trajectory_length = trajectory_length
        self._algo = 'HMC'
        self._max_tree_depth = 10
        self._init_strategy = init_strategy
        self._find_heuristic_step_size = find_heuristic_step_size
        # Set on first call to init
        self._init_fn = None
        self._postprocess_fn = None
        self._sample_fn = None
        # The un-vectorized sample function built by ``hmc``. ``init`` derives
        # ``_sample_fn`` from it, so repeated ``init`` calls never stack ``vmap``s.
        self._unbatched_sample_fn = None

    def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        # Builds the ``hmc`` init/sample functions on first use and, when a
        # model is given, derives initial params and the postprocess function.
        if self._model is not None:
            model_init_params, potential_fn, postprocess_fn, _ = initialize_model(
                rng_key,
                self._model,
                dynamic_args=True,
                init_strategy=self._init_strategy,
                model_args=model_args,
                model_kwargs=model_kwargs)
            # Respect user-supplied initial values; fall back to the ones found
            # by ``initialize_model`` only when none were given.
            if init_params is None:
                init_params = model_init_params
            if self._init_fn is None:
                self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn,
                                                     kinetic_fn=self._kinetic_fn,
                                                     algo=self._algo)
                self._unbatched_sample_fn = self._sample_fn
            self._postprocess_fn = postprocess_fn
        elif self._init_fn is None:
            self._init_fn, self._sample_fn = hmc(potential_fn=self._potential_fn,
                                                 kinetic_fn=self._kinetic_fn,
                                                 algo=self._algo)
            self._unbatched_sample_fn = self._sample_fn
        return init_params

    @property
    def model(self):
        return self._model

    @property
    def sample_field(self):
        return 'z'

    @property
    def default_fields(self):
        return ('z', 'diverging')

    def get_diagnostics_str(self, state):
        """Return a one-line summary: step count, step size and mean acceptance probability."""
        return '{} steps of size {:.2e}. acc. prob={:.2f}'.format(state.num_steps,
                                                                  state.adapt_state.step_size,
                                                                  state.mean_accept_prob)

    def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs=None):
        """
        Initialize the sampler state.

        :param rng_key: a single PRNG key (1-D), or a batch of keys (2-D) to
            initialize several chains at once under ``vmap``.
        :param int num_warmup: number of warmup (adaptation) steps.
        :param init_params: initial parameter values. Required when the kernel
            was built from `potential_fn`; optional with a `model`.
        :param tuple model_args: arguments provided to the model.
        :param dict model_kwargs: keyword arguments provided to the model.
        :return: the initial :data:`~numpyro.infer.hmc.HMCState` (batched when
            `rng_key` is batched).
        :raises ValueError: if `potential_fn` was used and `init_params` is None.
        """
        model_kwargs = {} if model_kwargs is None else model_kwargs
        if self._potential_fn is not None and init_params is None:
            raise ValueError('Valid value of `init_params` must be provided with'
                             ' `potential_fn`.')
        # non-vectorized
        if rng_key.ndim == 1:
            rng_key, rng_key_init_model = random.split(rng_key)
        # vectorized
        else:
            rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
        init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)

        def hmc_init_fn(init_params, rng_key):
            return self._init_fn(
                init_params,
                num_warmup=num_warmup,
                step_size=self._step_size,
                adapt_step_size=self._adapt_step_size,
                adapt_mass_matrix=self._adapt_mass_matrix,
                dense_mass=self._dense_mass,
                target_accept_prob=self._target_accept_prob,
                trajectory_length=self._trajectory_length,
                max_tree_depth=self._max_tree_depth,
                find_heuristic_step_size=self._find_heuristic_step_size,
                model_args=model_args,
                model_kwargs=model_kwargs,
                rng_key=rng_key,
            )

        if rng_key.ndim == 1:
            init_state = hmc_init_fn(init_params, rng_key)
            self._sample_fn = self._unbatched_sample_fn
        else:
            # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
            # nonlocal variables: wa_update, vv_update, trajectory_len, max_treedepth, wa_steps
            # because those variables do not depend on traced args: init_params, rng_key.
            init_state = vmap(hmc_init_fn)(init_params, rng_key)
            # Always vectorize the un-batched function so calling ``init`` again
            # does not wrap an already-vmapped sampler a second time.
            self._sample_fn = vmap(self._unbatched_sample_fn, in_axes=(0, None, None))
        return init_state

    def postprocess_fn(self, args, kwargs):
        """Return the model's postprocess function for the given args, or ``identity`` if none."""
        if self._postprocess_fn is None:
            return identity
        return self._postprocess_fn(*args, **kwargs)

    def sample(self, state, model_args, model_kwargs):
        """
        Run HMC from the given :data:`~numpyro.infer.hmc.HMCState` and return the resulting
        :data:`~numpyro.infer.hmc.HMCState`.

        :param HMCState state: Represents the current state.
        :param model_args: Arguments provided to the model.
        :param model_kwargs: Keyword arguments provided to the model.
        :return: Next `state` after running HMC.
        """
        return self._sample_fn(state, model_args, model_kwargs)
class NUTS(HMC):
    """
    Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS)
    with adaptive path length and mass matrix adaptation.

    **References:**

    1. *MCMC Using Hamiltonian Dynamics*,
       Radford M. Neal
    2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, and Andrew Gelman.
    3. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
        If model is provided, `potential_fn` will be inferred using the model.
    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        `init_kernel` has the same type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum. If not provided, the default is
        euclidean kinetic energy.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool dense_mass: A flag to decide if mass matrix is dense or
        diagonal (default when ``dense_mass=False``)
    :param float target_accept_prob: Target acceptance probability for step size
        adaptation using Dual Averaging. Increasing this value will lead to a smaller
        step size, hence the sampling will be slower but more robust. Default to 0.8.
    :param float trajectory_length: Length of a MCMC trajectory for HMC. This arg has
        no effect in NUTS sampler.
    :param int max_tree_depth: Max depth of the binary tree created during the doubling
        scheme of NUTS sampler. Defaults to 10.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool find_heuristic_step_size: whether to use a heuristic function to adjust the
        step size at the beginning of each adaptation window. Defaults to False.
    """
    def __init__(self,
                 model=None,
                 potential_fn=None,
                 kinetic_fn=None,
                 step_size=1.0,
                 adapt_step_size=True,
                 adapt_mass_matrix=True,
                 dense_mass=False,
                 target_accept_prob=0.8,
                 trajectory_length=None,
                 max_tree_depth=10,
                 init_strategy=init_to_uniform,
                 find_heuristic_step_size=False):
        super().__init__(potential_fn=potential_fn, model=model, kinetic_fn=kinetic_fn,
                         step_size=step_size, adapt_step_size=adapt_step_size,
                         adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass,
                         target_accept_prob=target_accept_prob,
                         trajectory_length=trajectory_length,
                         init_strategy=init_strategy,
                         find_heuristic_step_size=find_heuristic_step_size)
        # Override the HMC defaults set by the base constructor.
        self._max_tree_depth = max_tree_depth
        self._algo = 'NUTS'