# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import namedtuple
from contextlib import ExitStack, contextmanager
import functools
import warnings
import jax
from jax import lax, random
import jax.numpy as jnp
import numpyro
from numpyro.util import find_stack_level, identity
_PYRO_STACK = []
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])
def default_process_message(msg):
if msg["value"] is None:
if msg["type"] == "sample":
msg["value"], msg["intermediates"] = msg["fn"](
*msg["args"], sample_intermediates=True, **msg["kwargs"]
)
else:
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
def apply_stack(msg):
"""
Execute the effect stack at a single site according to the following scheme:
1. For each ``Messenger`` in the stack from bottom to top,
execute ``Messenger.process_message`` with the message;
if the message field "stop" is True, stop;
otherwise, continue
2. Apply default behavior (``default_process_message``) to finish remaining
site execution
3. For each ``Messenger`` in the stack from top to bottom,
execute ``Messenger.postprocess_message`` to update the message
and internal messenger state with the site results
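
    **Example** (a minimal sketch; ``Printer`` is an illustrative ``Messenger``)::

        from jax import random
        import numpyro
        import numpyro.distributions as dist
        from numpyro.primitives import Messenger

        class Printer(Messenger):
            def process_message(self, msg):
                print("process", msg["name"])

            def postprocess_message(self, msg):
                print("postprocess", msg["name"])

        with Printer():
            numpyro.sample("x", dist.Normal(0.0, 1.0), rng_key=random.PRNGKey(0))
        # prints "process x" followed by "postprocess x"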
"""
pointer = 0
for pointer, handler in enumerate(reversed(_PYRO_STACK)):
handler.process_message(msg)
# When a Messenger sets the "stop" field of a message,
# it prevents any Messengers above it on the stack from being applied.
if msg.get("stop"):
break
default_process_message(msg)
# A Messenger that sets msg["stop"] == True also prevents application
# of postprocess_message by Messengers above it on the stack
# via the pointer variable from the process_message loop
for handler in _PYRO_STACK[-pointer - 1 :]:
handler.postprocess_message(msg)
return msg
class Messenger(object):
def __init__(self, fn=None):
if fn is not None and not callable(fn):
raise ValueError(
"Expected `fn` to be a Python callable object; "
"instead found type(fn) = {}.".format(type(fn))
)
self.fn = fn
functools.update_wrapper(self, fn, updated=[])
def __enter__(self):
_PYRO_STACK.append(self)
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
assert _PYRO_STACK[-1] is self
_PYRO_STACK.pop()
else:
            # NB: this mimics Pyro's exception handling:
            # when the wrapped function or enclosed block raises an exception,
            # find this handler's position in the stack,
            # then remove it and everything below it in the stack.
if self in _PYRO_STACK:
loc = _PYRO_STACK.index(self)
for i in range(loc, len(_PYRO_STACK)):
_PYRO_STACK.pop()
def process_message(self, msg):
pass
def postprocess_message(self, msg):
pass
def __call__(self, *args, **kwargs):
if self.fn is None:
# Assume self is being used as a decorator.
assert len(args) == 1 and not kwargs
self.fn = args[0]
return self
with self:
return self.fn(*args, **kwargs)
def _masked_observe(name, fn, obs, obs_mask, **kwargs):
# Split into two auxiliary sample sites.
with numpyro.handlers.mask(mask=obs_mask):
observed = sample(f"{name}_observed", fn, **kwargs, obs=obs)
with numpyro.handlers.mask(mask=(obs_mask ^ True)):
unobserved = sample(f"{name}_unobserved", fn, **kwargs)
# Interleave observed and unobserved events.
shape = jnp.shape(obs_mask) + (1,) * fn.event_dim
batch_mask = jnp.reshape(obs_mask, shape)
value = jnp.where(batch_mask, observed, unobserved)
return deterministic(name, value)
def sample(
name, fn, obs=None, rng_key=None, sample_shape=(), infer=None, obs_mask=None
):
"""
Returns a random sample from the stochastic function `fn`. This can have
additional side effects when wrapped inside effect handlers like
:class:`~numpyro.handlers.substitute`.
    .. note::
        By design, the `sample` primitive is meant to be used inside a NumPyro
        model. The :class:`~numpyro.handlers.seed` handler is then used to
        inject a random state into `fn`. In those situations, the `rng_key`
        keyword has no effect.
:param str name: name of the sample site.
:param fn: a stochastic function that returns a sample.
:param jnp.ndarray obs: observed value
:param jax.random.PRNGKey rng_key: an optional random key for `fn`.
:param sample_shape: Shape of samples to be drawn.
    :param dict infer: an optional dictionary containing additional information
        for inference algorithms. For example, if `fn` is a discrete distribution,
        set `infer={'enumerate': 'parallel'}` to tell MCMC to marginalize
        this discrete latent site.
:param jnp.ndarray obs_mask: Optional boolean array mask of shape
broadcastable with ``fn.batch_shape``. If provided, events with
mask=True will be conditioned on ``obs`` and remaining events will be
imputed by sampling. This introduces a latent sample site named ``name
+ "_unobserved"`` which should be used by guides in SVI. Note that this
argument is not intended to be used with MCMC.
:return: sample from the stochastic `fn`.
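
    **Example** (a minimal sketch; ``model`` and ``rng_seed=0`` are illustrative)::

        import jax
        import numpyro
        import numpyro.distributions as dist

        def model():
            return numpyro.sample("x", dist.Normal(0.0, 1.0))

        # inside a model, a random state is injected by the `seed` handler
        x = numpyro.handlers.seed(model, rng_seed=0)()

        # outside of any handlers, `rng_key` must be provided explicitly
        y = numpyro.sample("y", dist.Normal(0.0, 1.0), rng_key=jax.random.PRNGKey(0))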
"""
assert isinstance(
sample_shape, tuple
), "sample_shape needs to be a tuple of integers"
if not isinstance(fn, numpyro.distributions.Distribution):
type_error = TypeError(
"It looks like you tried to use a fn that isn't an instance of "
"numpyro.distributions.Distribution, funsor.Funsor or "
"tensorflow_probability.distributions.Distribution. If you're using "
"funsor or tensorflow_probability, make sure they are correctly installed."
)
# fn can be a funsor.Funsor, but this won't be installed for all users
try:
from funsor import Funsor
except ImportError:
Funsor = None
# if Funsor import failed, or fn is not a Funsor it's also possible fn could be
# a tensorflow_probability distribution
if Funsor is None or not isinstance(fn, Funsor):
try:
from tensorflow_probability.substrates.jax import distributions as tfd
from numpyro.contrib.tfp.distributions import TFPDistribution
except ImportError:
# if tensorflow_probability fails to import here, then fn is not a
# numpyro Distribution or a Funsor, and it can't have been a tfp
# distribution either, so raising TypeError is ok
raise type_error
if isinstance(fn, tfd.Distribution):
with warnings.catch_warnings():
# ignore FutureWarnings when instantiating TFPDistribution
warnings.simplefilter("ignore", category=FutureWarning)
# if fn is a tfp distribution we need to wrap it
fn = TFPDistribution[fn.__class__](**fn.parameters)
else:
# if tensorflow_probability imported, but fn is not tfd.Distribution we
# still need to raise a type error
raise type_error
# if no active Messengers, draw a sample or return obs as expected:
if not _PYRO_STACK:
if obs is None:
return fn(rng_key=rng_key, sample_shape=sample_shape)
else:
return obs
if obs_mask is not None:
        return _masked_observe(
            name,
            fn,
            obs,
            obs_mask,
            rng_key=rng_key,
            sample_shape=sample_shape,
            infer=infer,
        )
# Otherwise, we initialize a message...
initial_msg = {
"type": "sample",
"name": name,
"fn": fn,
"args": (),
"kwargs": {"rng_key": rng_key, "sample_shape": sample_shape},
"value": obs,
"scale": None,
"is_observed": obs is not None,
"intermediates": [],
"cond_indep_stack": [],
"infer": {} if infer is None else infer,
}
# ...and use apply_stack to send it to the Messengers
msg = apply_stack(initial_msg)
return msg["value"]
def param(name, init_value=None, **kwargs):
"""
Annotate the given site as an optimizable parameter for use with
:mod:`jax.example_libraries.optimizers`. For an example of how `param` statements
can be used in inference algorithms, refer to :class:`~numpyro.infer.SVI`.
:param str name: name of site.
:param init_value: initial value specified by the user or a lazy callable
that accepts a JAX random PRNGKey and returns an array.
Note that the onus of using this to initialize the optimizer is
        on the user's inference algorithm, since there is no global parameter
store in NumPyro.
:type init_value: jnp.ndarray or callable
:param constraint: NumPyro constraint, defaults to ``constraints.real``.
:type constraint: numpyro.distributions.constraints.Constraint
:param int event_dim: (optional) number of rightmost dimensions unrelated
        to batching. Dimensions to the left of this will be considered batch
        dimensions; if the param statement is inside a subsampled plate, then
        the corresponding batch dimensions of the parameter will also be
        subsampled. If unspecified, all dimensions will be considered event
dims and no subsampling will be performed.
:return: value for the parameter. Unless wrapped inside a
handler like :class:`~numpyro.handlers.substitute`, this will simply
return the initial value.
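
    **Example** (a minimal sketch; ``model`` is illustrative)::

        import numpyro
        from numpyro.handlers import substitute

        def model():
            return numpyro.param("loc", 0.0)

        # outside of handlers, `param` simply returns its initial value
        assert model() == 0.0

        # handlers such as `substitute` can override it
        assert substitute(model, data={"loc": 1.0})() == 1.0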
"""
    # if there are no active Messengers, we return the initial value as expected:
if not _PYRO_STACK:
assert not callable(
init_value
), "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
return init_value
if callable(init_value):
def fn(init_fn, *args, **kwargs):
return init_fn(prng_key())
else:
fn = identity
# Otherwise, we initialize a message...
initial_msg = {
"type": "param",
"name": name,
"fn": fn,
"args": (init_value,),
"kwargs": kwargs,
"value": None,
"scale": None,
"cond_indep_stack": [],
}
# ...and use apply_stack to send it to the Messengers
msg = apply_stack(initial_msg)
return msg["value"]
def deterministic(name, value):
"""
Used to designate deterministic sites in the model. Note that most effect
handlers will not operate on deterministic sites (except
:func:`~numpyro.handlers.trace`), so deterministic sites should be
    side-effect free. The use case for deterministic nodes is to record
    intermediate values computed during model execution in the trace.
:param str name: name of the deterministic site.
:param jnp.ndarray value: deterministic value to record in the trace.
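
    **Example** (a minimal sketch; ``model`` is illustrative)::

        import numpyro
        import numpyro.distributions as dist
        from numpyro.handlers import seed, trace

        def model():
            x = numpyro.sample("x", dist.Normal(0.0, 1.0))
            numpyro.deterministic("x_squared", x**2)

        tr = trace(seed(model, rng_seed=0)).get_trace()
        assert "x_squared" in tr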
"""
if not _PYRO_STACK:
return value
initial_msg = {
"type": "deterministic",
"name": name,
"value": value,
"cond_indep_stack": [],
}
# ...and use apply_stack to send it to the Messengers
msg = apply_stack(initial_msg)
return msg["value"]
def mutable(name, init_value=None):
"""
This primitive is used to store a mutable value that can be changed
during model execution::
a = numpyro.mutable("a", {"value": 1.})
a["value"] = 2.
assert numpyro.mutable("a")["value"] == 2.
For example, this can be used to store and update information like
running mean/variance in a neural network batch normalization layer.
:param str name: name of the mutable site.
:param init_value: mutable value to record in the trace.
"""
if not _PYRO_STACK:
return init_value
initial_msg = {
"type": "mutable",
"name": name,
"fn": identity,
"args": (init_value,),
"kwargs": {},
"value": init_value,
}
# ...and use apply_stack to send it to the Messengers
msg = apply_stack(initial_msg)
return msg["value"]
def _inspect():
"""
EXPERIMENTAL Inspect the Pyro stack.
.. warning:: The format of the returned message may change at any time and
does not guarantee backwards compatibility.
:returns: A message with mask effects applied.
:rtype: dict
"""
    # NB: this differs from Pyro, where all effects are applied;
    # here, we only apply the mask effect handler.
msg = {
"type": "inspect",
"fn": lambda: True,
"args": (),
"kwargs": {},
"value": None,
"mask": None,
}
apply_stack(msg)
return msg
def get_mask():
"""
Records the effects of enclosing ``handlers.mask`` handlers.
This is useful for avoiding expensive ``numpyro.factor()`` computations during
prediction, when the log density need not be computed, e.g.::
def model():
# ...
if numpyro.get_mask() is not False:
log_density = my_expensive_computation()
numpyro.factor("foo", log_density)
# ...
:returns: The mask.
:rtype: None, bool, or jnp.ndarray
"""
return _inspect()["mask"]
def module(name, nn, input_shape=None):
"""
Declare a :mod:`~jax.example_libraries.stax` style neural network inside a
model so that its parameters are registered for optimization via
:func:`~numpyro.primitives.param` statements.
:param str name: name of the module to be registered.
:param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a :mod:`~jax.example_libraries.stax`
constructor function.
:param tuple input_shape: shape of the input taken by the
neural network.
    :return: an `apply_fn` with bound parameters that takes an array
as an input and returns the neural network transformed output
array.
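
    **Example** (a minimal sketch; the network and shapes are illustrative)::

        import jax.numpy as jnp
        from jax.example_libraries import stax
        import numpyro
        from numpyro.handlers import seed

        def model(x):
            net = numpyro.module("net", stax.Dense(4), input_shape=(1, 2))
            return net(x)

        # `module` draws a key via `prng_key()`, so a seed handler is required
        y = seed(model, rng_seed=0)(jnp.ones((1, 2)))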
"""
module_key = name + "$params"
nn_init, nn_apply = nn
nn_params = param(module_key)
if nn_params is None:
if input_shape is None:
raise ValueError("Valid value for `input_shape` needed to initialize.")
rng_key = prng_key()
_, nn_params = nn_init(rng_key, input_shape)
param(module_key, nn_params)
return functools.partial(nn_apply, nn_params)
def _subsample_fn(size, subsample_size, rng_key=None):
if rng_key is None:
raise ValueError(
"Missing random key to generate subsample indices."
" Algorithms like HMC/NUTS do not support subsampling."
" You might want to use SVI or HMCECS instead."
)
if jax.default_backend() == "cpu":
# ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
rng_keys = random.split(rng_key, subsample_size)
def body_fn(val, idx):
i_p1 = size - idx
i = i_p1 - 1
j = random.randint(rng_keys[idx], (), 0, i_p1)
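            # swap elements i and j in place (one backward Fisher-Yates step)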
val = val.at[jnp.array([i, j])].set(val[jnp.array([j, i])])
return val, None
val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size))
return val[-subsample_size:]
else:
return random.choice(rng_key, size, (subsample_size,), replace=False)
class plate(Messenger):
"""
Construct for annotating conditionally independent variables. Within a
`plate` context manager, `sample` sites will be automatically broadcasted to
the size of the plate. Additionally, a scale factor might be applied by
certain inference algorithms if `subsample_size` is specified.
.. note:: This can be used to subsample minibatches of data:
.. code-block:: python
with plate("data", len(data), subsample_size=100) as ind:
batch = data[ind]
assert len(batch) == 100
:param str name: Name of the plate.
:param int size: Size of the plate.
    :param int subsample_size: Optional argument denoting the size of the mini-batch.
        Inference algorithms can use this to apply a scaling factor, e.g. when
        computing the ELBO from a mini-batch.
:param int dim: Optional argument to specify which dimension in the tensor
is used as the plate dim. If `None` (default), the rightmost available dim
is allocated.
"""
def __init__(self, name, size, subsample_size=None, dim=None):
self.name = name
assert size > 0, "size of plate should be positive"
self.size = size
if dim is not None and dim >= 0:
raise ValueError("dim arg must be negative.")
self.dim, self._indices = self._subsample(
self.name, self.size, subsample_size, dim
)
self.subsample_size = self._indices.shape[0]
super(plate, self).__init__()
    # XXX: unlike in Pyro, this method returns dim and indices
@staticmethod
def _subsample(name, size, subsample_size, dim):
msg = {
"type": "plate",
"fn": _subsample_fn,
"name": name,
"args": (size, subsample_size),
"kwargs": {"rng_key": None},
"value": (
None
if (subsample_size is not None and size != subsample_size)
else jnp.arange(size)
),
"scale": 1.0,
"cond_indep_stack": [],
}
apply_stack(msg)
subsample = msg["value"]
subsample_size = msg["args"][1]
if subsample_size is not None and subsample_size != subsample.shape[0]:
warnings.warn(
"subsample_size does not match len(subsample), {} vs {}.".format(
subsample_size, len(subsample)
)
+ " Did you accidentally use different subsample_size in the model and guide?",
stacklevel=find_stack_level(),
)
cond_indep_stack = msg["cond_indep_stack"]
occupied_dims = {f.dim for f in cond_indep_stack}
if dim is None:
new_dim = -1
while new_dim in occupied_dims:
new_dim -= 1
dim = new_dim
else:
assert dim not in occupied_dims
return dim, subsample
def __enter__(self):
super().__enter__()
return self._indices
@staticmethod
def _get_batch_shape(cond_indep_stack):
n_dims = max(-f.dim for f in cond_indep_stack)
batch_shape = [1] * n_dims
for f in cond_indep_stack:
batch_shape[f.dim] = f.size
return tuple(batch_shape)
def process_message(self, msg):
if msg["type"] not in ("param", "sample", "plate", "deterministic"):
if msg["type"] == "control_flow":
raise NotImplementedError(
"Cannot use control flow primitive under a `plate` primitive."
" Please move those `plate` statements into the control flow"
" body function. See `scan` documentation for more information."
)
return
if (
"block_plates" in msg.get("infer", {})
and self.name in msg["infer"]["block_plates"]
):
return
cond_indep_stack = msg["cond_indep_stack"]
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
if msg["type"] == "deterministic":
return
if msg["type"] == "sample":
expected_shape = self._get_batch_shape(cond_indep_stack)
dist_batch_shape = msg["fn"].batch_shape
if "sample_shape" in msg["kwargs"]:
dist_batch_shape = msg["kwargs"]["sample_shape"] + dist_batch_shape
msg["kwargs"]["sample_shape"] = ()
overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
trailing_shape = expected_shape[overlap_idx:]
broadcast_shape = lax.broadcast_shapes(
trailing_shape, tuple(dist_batch_shape)
)
batch_shape = expected_shape[:overlap_idx] + broadcast_shape
msg["fn"] = msg["fn"].expand(batch_shape)
if self.size != self.subsample_size:
scale = 1.0 if msg["scale"] is None else msg["scale"]
msg["scale"] = scale * (
self.size / self.subsample_size if self.subsample_size else 1
)
def postprocess_message(self, msg):
if msg["type"] in ("subsample", "param") and self.dim is not None:
event_dim = msg["kwargs"].get("event_dim")
if event_dim is not None:
assert event_dim >= 0
dim = self.dim - event_dim
shape = jnp.shape(msg["value"])
if len(shape) >= -dim and shape[dim] != 1:
if shape[dim] != self.size:
if msg["type"] == "param":
statement = "numpyro.param({}, ..., event_dim={})".format(
msg["name"], event_dim
)
else:
statement = "numpyro.subsample(..., event_dim={})".format(
event_dim
)
raise ValueError(
"Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}".format(
self.name, self.size, self.dim, statement, shape
)
)
if self.subsample_size < self.size:
value = msg["value"]
new_value = jnp.take(value, self._indices, dim)
msg["value"] = new_value
@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
"""
Create a contiguous stack of :class:`plate` s with dimensions::
        rightmost_dim - len(sizes) + 1, ..., rightmost_dim
:param str prefix: Name prefix for plates.
:param iterable sizes: An iterable of plate sizes.
:param int rightmost_dim: The rightmost dim, counting from the right.
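
    **Example** (a minimal sketch; ``model`` is illustrative)::

        import numpyro
        import numpyro.distributions as dist
        from numpyro.handlers import seed, trace

        def model():
            # creates plates "p_0" (size 2, dim=-1) and "p_1" (size 3, dim=-2)
            with numpyro.plate_stack("p", [3, 2]):
                numpyro.sample("x", dist.Normal(0.0, 1.0))

        tr = trace(seed(model, rng_seed=0)).get_trace()
        assert tr["x"]["value"].shape == (3, 2)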
"""
assert rightmost_dim < 0
with ExitStack() as stack:
for i, size in enumerate(reversed(sizes)):
plate_i = plate("{}_{}".format(prefix, i), size, dim=rightmost_dim - i)
stack.enter_context(plate_i)
yield
def factor(name, log_factor):
"""
Factor statement to add arbitrary log probability factor to a
probabilistic model.
:param str name: Name of the trivial sample.
:param jnp.ndarray log_factor: A possibly batched log probability factor.
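
    **Example** (a minimal sketch; the penalty term is illustrative)::

        import numpyro
        import numpyro.distributions as dist

        def model():
            x = numpyro.sample("x", dist.Normal(0.0, 1.0))
            # soft constraint discouraging large values of x
            numpyro.factor("penalty", -(x**2))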
"""
unit_dist = numpyro.distributions.distribution.Unit(log_factor)
unit_value = unit_dist.sample(None)
sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True})
def prng_key():
"""
A statement to draw a pseudo-random number generator key
    :func:`~jax.random.PRNGKey` under the :class:`~numpyro.handlers.seed` handler.

    :return: a PRNG key of shape (2,) and dtype uint32.
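
    **Example** (a minimal sketch)::

        import numpyro
        from numpyro.handlers import seed

        def f():
            return numpyro.prng_key()

        key = seed(f, rng_seed=0)()  # a fresh PRNG key
        # outside of `seed`, `prng_key()` warns and returns None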
"""
if not _PYRO_STACK:
warnings.warn(
"Cannot generate JAX PRNG key outside of `seed` handler.",
stacklevel=find_stack_level(),
)
return
initial_msg = {
"type": "prng_key",
"fn": lambda rng_key: rng_key,
"args": (),
"kwargs": {"rng_key": None},
"value": None,
}
msg = apply_stack(initial_msg)
return msg["value"]
def subsample(data, event_dim):
"""
EXPERIMENTAL Subsampling statement to subsample data based on enclosing
:class:`~numpyro.primitives.plate` s.
This is typically called on arguments to ``model()`` when subsampling is
performed automatically by :class:`~numpyro.primitives.plate` s by passing
    ``subsample_size`` kwarg. For example, the following are equivalent::
# Version 1. using indexing
def model(data):
with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
data = data[ind]
# ...
# Version 2. using numpyro.subsample()
def model(data):
with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
data = numpyro.subsample(data, event_dim=0)
# ...
:param jnp.ndarray data: A tensor of batched data.
:param int event_dim: The event dimension of the data tensor. Dimensions to
the left are considered batch dimensions.
:returns: A subsampled version of ``data``
:rtype: ~jnp.ndarray
"""
if not _PYRO_STACK:
return data
assert isinstance(event_dim, int) and event_dim >= 0
initial_msg = {
"type": "subsample",
"value": data,
"kwargs": {"event_dim": event_dim},
}
msg = apply_stack(initial_msg)
return msg["value"]