"""
This provides a small set of effect handlers in NumPyro that are modeled
after Pyro's `poutine <http://docs.pyro.ai/en/stable/poutine.html>`_ module.
For a tutorial on effect handlers more generally, readers are encouraged to
read `Poutine: A Guide to Programming with Effect Handlers in Pyro
<http://pyro.ai/examples/effect_handlers.html>`_. These simple effect handlers
can be composed together or new ones added to enable implementation of custom
inference utilities and algorithms.
**Example**
As an example, we are using :class:`~numpyro.handlers.seed`, :class:`~numpyro.handlers.trace`
and :class:`~numpyro.handlers.substitute` handlers to define the `log_likelihood` function below.
We first create a logistic regression model and sample from the posterior distribution over
the regression parameters using :func:`~numpyro.mcmc.mcmc`. The `log_likelihood` function
uses effect handlers to run the model by substituting sample sites with values from the posterior
distribution and computes the log density of the observed data for a single posterior draw. The
`expected_log_likelihood` function computes the log likelihood for each draw from the joint posterior
and aggregates the results, but does so by using JAX's auto-vectorize transform called `vmap` so that
we do not need to loop over all the posterior draws.
.. testsetup::
import jax.numpy as np
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpyro.distributions as dist
from numpyro.handlers import sample, seed, substitute, trace
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc
.. doctest::
>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
... coefs = sample('coefs', dist.Normal(np.zeros(D), np.ones(D)))
... intercept = sample('intercept', dist.Normal(0., 10.))
... logits = np.sum(coefs * data + intercept, axis=-1)
... return sample('obs', dist.Bernoulli(logits=logits), obs=labels)
>>> data = random.normal(random.PRNGKey(0), (N, D))
>>> true_coefs = np.arange(1., D + 1.)
>>> logits = np.sum(true_coefs * data, axis=-1)
>>> labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), logistic_regression, data, labels)
>>> num_warmup, num_samples = 1000, 1000
>>> samples = mcmc(num_warmup, num_samples, init_params,
... potential_fn=potential_fn,
... constrain_fn=constrain_fn) # doctest: +SKIP
warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]
mean sd 5.5% 94.5% n_eff Rhat
coefs[0] 0.96 0.07 0.85 1.07 455.35 1.01
coefs[1] 2.05 0.09 1.91 2.20 332.00 1.01
coefs[2] 3.18 0.13 2.96 3.37 320.27 1.00
intercept -0.03 0.02 -0.06 0.00 402.53 1.00
>>> def log_likelihood(rng, params, model, *args, **kwargs):
... model = substitute(seed(model, rng), params)
... model_trace = trace(model).get_trace(*args, **kwargs)
... obs_node = model_trace['obs']
... return np.sum(obs_node['fn'].log_prob(obs_node['value']))
>>> def expected_log_likelihood(rng, params, model, *args, **kwargs):
... n = list(params.values())[0].shape[0]
... log_lk_fn = vmap(lambda rng, params: log_likelihood(rng, params, model, *args, **kwargs))
... log_lk_vals = log_lk_fn(random.split(rng, n), params)
... return logsumexp(log_lk_vals) - np.log(n)
>>> print(expected_log_likelihood(random.PRNGKey(2), samples, logistic_regression, data, labels)) # doctest: +SKIP
-876.172
"""
from __future__ import absolute_import, division, print_function
from collections import OrderedDict
from jax import random
# Module-level stack of the effect handlers that are currently active.
# `Messenger.__enter__` appends to it, `Messenger.__exit__` pops from it, and
# `apply_stack` walks it from the last entry (innermost handler) backwards.
_PYRO_STACK = []
class Messenger(object):
    """
    Base class for effect handlers.

    A `Messenger` optionally wraps a callable `fn`. While it is active (used
    as a context manager, or while the wrapped callable is running) it sits on
    the module-level `_PYRO_STACK`, where `apply_stack` hands it every message
    via `process_message` and `postprocess_message`. Both hooks are no-ops
    here and are meant to be overridden by subclasses.

    :param fn: callable to run under this handler; may be `None` when the
        handler is only used as a context manager.
    """

    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        # Activate the handler by making it the innermost entry of the stack.
        _PYRO_STACK.append(self)

    def __exit__(self, *args, **kwargs):
        # Handlers are expected to be exited in the reverse order of entry,
        # so the innermost entry has to be this very handler.
        innermost = _PYRO_STACK[-1]
        assert innermost is self
        _PYRO_STACK.pop()

    def process_message(self, msg):
        """Hook called by `apply_stack` before the site value is computed."""
        pass

    def postprocess_message(self, msg):
        """Hook called by `apply_stack` after the site value is available."""
        pass

    def __call__(self, *args, **kwargs):
        """Run the wrapped callable with this handler active and return its result."""
        with self:
            result = self.fn(*args, **kwargs)
        return result
class trace(Messenger):
    """
    Handler that records, for every primitive call made inside `fn`, the
    message describing that call (its inputs and its resulting value).

    **Example**

    .. testsetup::

       from jax import random
       import numpyro.distributions as dist
       from numpyro.handlers import sample, seed, trace
       import pprint as pp

    .. doctest::

       >>> def model():
       ...     sample('a', dist.Normal(0., 1.))
       >>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
       >>> pp.pprint(exec_trace)  # doctest: +SKIP
       OrderedDict([('a',
                     {'args': (),
                      'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
                      'is_observed': False,
                      'kwargs': {'random_state': DeviceArray([0, 0], dtype=uint32)},
                      'name': 'a',
                      'type': 'sample',
                      'value': DeviceArray(-0.20584235, dtype=float32)})])
    """

    def __enter__(self):
        # Every activation starts from an empty record, so a handler that is
        # run again does not carry over sites from a previous run.
        self.trace = OrderedDict()
        super(trace, self).__enter__()
        return self.trace

    def postprocess_message(self, msg):
        """Store a shallow copy of `msg` in the trace under its site name."""
        site_name = msg['name']
        assert site_name not in self.trace, 'all sites must have unique names'
        self.trace[site_name] = msg.copy()

    def get_trace(self, *args, **kwargs):
        """
        Run the wrapped callable and return the recorded trace.

        :param `*args`: arguments to the callable.
        :param `**kwargs`: keyword arguments to the callable.
        :return: `OrderedDict` containing the execution trace.
        """
        # The return value of the callable is discarded; only the sites
        # recorded while it ran are of interest here.
        self(*args, **kwargs)
        return self.trace
class replay(Messenger):
    """
    Given a callable `fn` and an execution trace `guide_trace`,
    return a callable which substitutes `sample` calls in `fn` with
    values from the corresponding site names in `guide_trace`.

    :param fn: Python callable with NumPyro primitives.
    :param guide_trace: an OrderedDict containing execution metadata.

    **Example**

    .. testsetup::

       from jax import random
       import numpyro.distributions as dist
       from numpyro.handlers import replay, sample, seed, trace

    .. doctest::

       >>> def model():
       ...     sample('a', dist.Normal(0., 1.))
       >>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
       >>> print(exec_trace['a']['value'])  # doctest: +SKIP
       -0.20584235
       >>> replayed_trace = trace(replay(model, exec_trace)).get_trace()
       >>> print(replayed_trace['a']['value'])  # doctest: +SKIP
       -0.20584235
       >>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
    """

    def __init__(self, fn, guide_trace):
        super(replay, self).__init__(fn)
        self.guide_trace = guide_trace

    def process_message(self, msg):
        """Overwrite the site value with the one recorded in `guide_trace`, if any."""
        site_name = msg['name']
        # Sites that do not appear in the guide trace are left untouched.
        if site_name in self.guide_trace:
            recorded_site = self.guide_trace[site_name]
            msg['value'] = recorded_site['value']
class block(Messenger):
    """
    Given a callable `fn`, return another callable that selectively hides
    primitive sites where `hide_fn` returns True from other effect handlers
    on the stack.

    :param fn: Python callable with NumPyro primitives.
    :param hide_fn: function which when given a dictionary containing
        site-level metadata returns whether it should be blocked. The
        default hides every site.

    **Example:**

    .. testsetup::

       from jax import random
       from numpyro.handlers import block, sample, seed, trace
       import numpyro.distributions as dist

    .. doctest::

       >>> def model():
       ...     a = sample('a', dist.Normal(0., 1.))
       ...     return sample('b', dist.Normal(a, 1.))
       >>> model = seed(model, random.PRNGKey(0))
       >>> block_all = block(model)
       >>> block_a = block(model, lambda site: site['name'] == 'a')
       >>> trace_block_all = trace(block_all).get_trace()
       >>> assert not {'a', 'b'}.intersection(trace_block_all.keys())
       >>> trace_block_a = trace(block_a).get_trace()
       >>> assert 'a' not in trace_block_a
       >>> assert 'b' in trace_block_a
    """

    def __init__(self, fn=None, hide_fn=lambda msg: True):
        super(block, self).__init__(fn)
        self.hide_fn = hide_fn

    def process_message(self, msg):
        """Flag the message with `stop` when `hide_fn` asks for it to be hidden."""
        should_hide = self.hide_fn(msg)
        if should_hide:
            # `apply_stack` stops walking towards the outer handlers as soon
            # as it sees this flag, which is what hides the site from them.
            msg['stop'] = True
class seed(Messenger):
    """
    JAX uses a functional pseudo random number generator that requires passing
    in a seed :func:`~jax.random.PRNGKey` to every stochastic function. The
    `seed` handler allows us to initially seed a stochastic function with a
    :func:`~jax.random.PRNGKey`. Every call to the :func:`~numpyro.handlers.sample`
    primitive inside the function results in a splitting of this initial seed
    so that we use a fresh seed for each subsequent call without having to
    explicitly pass in a `PRNGKey` to each `sample` call.

    The key stored on the handler is only ever used as input to
    :func:`jax.random.split`; each site receives its own subkey. As a result a
    key is never used both to draw a value and to derive further keys.

    :param fn: Python callable with NumPyro primitives.
    :param rng: initial :func:`~jax.random.PRNGKey` for the handler.
    """

    def __init__(self, fn, rng):
        self.rng = rng
        super(seed, self).__init__(fn)

    def process_message(self, msg):
        """Attach a fresh subkey to `sample` messages as `kwargs['random_state']`."""
        if msg['type'] == 'sample':
            # Split first and hand the site a dedicated subkey. Passing
            # `self.rng` itself to the site and then splitting that same key
            # would reuse one key for two purposes, which JAX's PRNG design
            # asks callers to avoid.
            self.rng, site_rng = random.split(self.rng)
            msg['kwargs']['random_state'] = site_rng
class substitute(Messenger):
    """
    Given a callable `fn` and a dict `param_map` keyed by site names,
    return a callable which substitutes all primitive calls in `fn` with
    values from `param_map` whose key matches the site name. If the
    site name is not present in `param_map`, there is no side effect.

    :param fn: Python callable with NumPyro primitives.
    :param dict param_map: dictionary of `numpy.ndarray` values keyed by
        site names. When left as `None`, no site is substituted.

    **Example:**

    .. testsetup::

       from jax import random
       from numpyro.handlers import sample, seed, substitute, trace
       import numpyro.distributions as dist

    .. doctest::

       >>> def model():
       ...     sample('a', dist.Normal(0., 1.))
       >>> model = seed(model, random.PRNGKey(0))
       >>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
       >>> assert exec_trace['a']['value'] == -1
    """

    def __init__(self, fn=None, param_map=None):
        self.param_map = param_map
        super(substitute, self).__init__(fn)

    def process_message(self, msg):
        """Replace the site value with the entry of `param_map` for its name, if any."""
        # `param_map` defaults to None; a membership test against None would
        # raise TypeError, so treat a missing map as "nothing to substitute".
        if self.param_map is None:
            return
        if msg['name'] in self.param_map:
            msg['value'] = self.param_map[msg['name']]
def apply_stack(msg):
    """
    Send `msg` through the handlers on `_PYRO_STACK` and fill in its value.

    Handlers are visited from the innermost (last entered) outwards and each
    gets to run `process_message`. If no handler has provided a value,
    `msg['fn']` is called with `msg['args']` and `msg['kwargs']` to produce
    one. Finally `postprocess_message` is run on the handlers that were
    visited in the first pass.

    :param dict msg: message describing the primitive site; mutated in place.
    :return: the same `msg` dictionary.
    """
    # Position, counted from the innermost handler, of the last handler that
    # saw the message. It stays 0 when the stack is empty.
    last_seen = 0
    for last_seen, messenger in enumerate(reversed(_PYRO_STACK)):
        messenger.process_message(msg)
        # A handler that sets the "stop" field keeps the handlers further
        # out on the stack from being applied to this message.
        if msg.get("stop"):
            break

    if msg['value'] is None:
        site_fn, site_args, site_kwargs = msg['fn'], msg['args'], msg['kwargs']
        msg['value'] = site_fn(*site_args, **site_kwargs)

    # Only the handlers reached above are post-processed, so a handler that
    # set "stop" also shields the message from `postprocess_message` of the
    # handlers further out than itself.
    first_visited = -last_seen - 1
    for messenger in _PYRO_STACK[first_visited:]:
        messenger.postprocess_message(msg)
    return msg
def sample(name, fn, obs=None):
    """
    Returns a random sample from the stochastic function `fn`. This can have
    additional side effects when wrapped inside effect handlers like
    :class:`~numpyro.handlers.substitute`.

    :param str name: name of the sample site
    :param fn: Python callable; it is invoked without positional arguments
        and with whatever keyword arguments the active handlers supply.
    :param numpy.ndarray obs: observed value
    :return: `obs` if it is given, otherwise a sample from the stochastic
        `fn` (or the value supplied by an active effect handler).
    """
    # If there are no active Messengers there is nothing to intercept the
    # site: return the observation when one is given, otherwise draw a sample.
    # Returning `obs` here matches what happens under handlers, where the
    # message value is initialised to `obs` and `fn` is not called.
    if not _PYRO_STACK:
        return fn() if obs is None else obs

    # Otherwise, we initialize a message...
    initial_msg = {
        'type': 'sample',
        'name': name,
        'fn': fn,
        'args': (),
        'kwargs': {},
        'value': obs,
        'is_observed': obs is not None,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg['value']
def identity(x):
    """Return `x` unchanged; serves as the `fn` of `param` sites."""
    return x
def param(name, init_value):
    """
    Annotate the given site as an optimizable parameter for use with
    :mod:`jax.experimental.optimizers`. For an example of how `param` statements
    can be used in inference algorithms, refer to :func:`~numpyro.svi.svi`.

    :param str name: name of site.
    :param numpy.ndarray init_value: initial value specified by the user. Note that
        the onus of using this to initialize the optimizer is on the user /
        inference algorithm, since there is no global parameter store in
        NumPyro.
    :return: value for the parameter. Unless wrapped inside a
        handler like :class:`~numpyro.handlers.substitute`, this will simply
        return the initial value.
    """
    # With no active Messengers nothing can override the parameter, so the
    # initial value is handed straight back.
    if not _PYRO_STACK:
        return init_value

    # Otherwise describe the site as a message whose value is still unset;
    # unless a handler fills it in, `identity(init_value)` supplies it.
    site = dict(
        type='param',
        name=name,
        fn=identity,
        args=(init_value,),
        kwargs={},
        value=None,
    )
    return apply_stack(site)['value']