Source code for numpyro.primitives

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from contextlib import ExitStack, contextmanager
import functools
import warnings

import jax
from jax import lax, random
import jax.numpy as jnp

import numpyro
from numpyro.util import find_stack_level, identity

_PYRO_STACK = []

CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])


def default_process_message(msg):
    if msg["value"] is None:
        if msg["type"] == "sample":
            msg["value"], msg["intermediates"] = msg["fn"](
                *msg["args"], sample_intermediates=True, **msg["kwargs"]
            )
        else:
            msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])


def apply_stack(msg):
    """
    Execute the effect stack at a single site according to the following scheme:

        1. For each ``Messenger`` in the stack from bottom to top,
           execute ``Messenger.process_message`` with the message;
           if the message field "stop" is True, stop;
           otherwise, continue
        2. Apply default behavior (``default_process_message``) to finish remaining
           site execution
        3. For each ``Messenger`` in the stack from top to bottom,
           execute ``Messenger.postprocess_message`` to update the message
           and internal messenger state with the site results
    """
    pointer = 0
    for pointer, handler in enumerate(reversed(_PYRO_STACK)):
        handler.process_message(msg)
        # When a Messenger sets the "stop" field of a message,
        # it prevents any Messengers above it on the stack from being applied.
        if msg.get("stop"):
            break

    default_process_message(msg)

    # A Messenger that sets msg["stop"] == True also prevents application
    # of postprocess_message by Messengers above it on the stack
    # via the pointer variable from the process_message loop
    for handler in _PYRO_STACK[-pointer - 1 :]:
        handler.postprocess_message(msg)
    return msg
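
# A minimal sketch (not part of the library) of the message flow described
# above: `seed` and `trace` are Messengers on ``_PYRO_STACK``, so the `sample`
# call below is routed through `apply_stack`, and `trace` records the final
# message. The function name `_example_apply_stack` is illustrative only.
def _example_apply_stack():
    import numpyro.distributions as dist
    import numpyro.handlers as handlers

    with handlers.trace() as tr, handlers.seed(rng_seed=0):
        sample("x", dist.Normal(0.0, 1.0))
    msg = tr["x"]
    # the message is a plain dict populated by default_process_message
    assert msg["type"] == "sample" and msg["value"] is not None
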


class Messenger(object):
    def __init__(self, fn=None):
        if fn is not None and not callable(fn):
            raise ValueError(
                "Expected `fn` to be a Python callable object; "
                "instead found type(fn) = {}.".format(type(fn))
            )
        self.fn = fn
        functools.update_wrapper(self, fn, updated=[])

    def __enter__(self):
        _PYRO_STACK.append(self)

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            assert _PYRO_STACK[-1] is self
            _PYRO_STACK.pop()
        else:
            # NB: this mimics Pyro's exception handling.
            # When the wrapped function or enclosed block raises an exception,
            # find this handler's position in the stack,
            # then remove it and everything below it in the stack.
            if self in _PYRO_STACK:
                loc = _PYRO_STACK.index(self)
                for i in range(loc, len(_PYRO_STACK)):
                    _PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        if self.fn is None:
            # Assume self is being used as a decorator.
            assert len(args) == 1 and not kwargs
            self.fn = args[0]
            return self
        with self:
            return self.fn(*args, **kwargs)
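
# A hedged sketch (not part of NumPyro's API) of a custom Messenger, similar
# in spirit to `numpyro.handlers.scale`: it rescales the log density of every
# sample site it encloses by multiplying into ``msg["scale"]``. The class name
# `_ExampleScale` is illustrative only.
#
# Usage: ``with _ExampleScale(scale=10.0): sample("x", ...)``
class _ExampleScale(Messenger):
    def __init__(self, fn=None, scale=1.0):
        self.scale = scale
        super().__init__(fn)

    def process_message(self, msg):
        if msg["type"] == "sample":
            prev = 1.0 if msg["scale"] is None else msg["scale"]
            msg["scale"] = prev * self.scale
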


def _masked_observe(name, fn, obs, obs_mask, **kwargs):
    # Split into two auxiliary sample sites.
    with numpyro.handlers.mask(mask=obs_mask):
        observed = sample(f"{name}_observed", fn, **kwargs, obs=obs)
    with numpyro.handlers.mask(mask=(obs_mask ^ True)):  # i.e., where obs_mask is False
        unobserved = sample(f"{name}_unobserved", fn, **kwargs)

    # Interleave observed and unobserved events.
    shape = jnp.shape(obs_mask) + (1,) * fn.event_dim
    batch_mask = jnp.reshape(obs_mask, shape)
    value = jnp.where(batch_mask, observed, unobserved)
    return deterministic(name, value)
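
# Usage sketch (not part of the library): with `obs_mask`, entries where the
# mask is True are conditioned on `obs` and the rest are imputed by sampling;
# the helper above creates the auxiliary sites ``y_observed`` and
# ``y_unobserved`` under the hood. `_example_obs_mask` is an illustrative name.
def _example_obs_mask():
    import numpyro.distributions as dist
    import numpyro.handlers as handlers

    obs = jnp.array([0.5, -1.0, 2.0])
    mask = jnp.array([True, False, True])
    with handlers.trace() as tr, handlers.seed(rng_seed=0):
        sample("y", dist.Normal(0.0, 1.0), obs=obs, obs_mask=mask)
    assert "y_unobserved" in tr  # latent site for the imputed entries
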


def sample(
    name, fn, obs=None, rng_key=None, sample_shape=(), infer=None, obs_mask=None
):
    """
    Returns a random sample from the stochastic function `fn`. This can have
    additional side effects when wrapped inside effect handlers like
    :class:`~numpyro.handlers.substitute`.

    .. note::
        By design, the `sample` primitive is meant to be used inside a NumPyro
        model. A :class:`~numpyro.handlers.seed` handler is then used to inject
        a random state into `fn`. In that situation, the `rng_key` keyword has
        no effect.

    :param str name: name of the sample site.
    :param fn: a stochastic function that returns a sample.
    :param jnp.ndarray obs: observed value
    :param jax.random.PRNGKey rng_key: an optional random key for `fn`.
    :param sample_shape: Shape of samples to be drawn.
    :param dict infer: an optional dictionary containing additional information
        for inference algorithms. For example, if `fn` is a discrete
        distribution, setting `infer={'enumerate': 'parallel'}` tells MCMC to
        marginalize this discrete latent site.
    :param jnp.ndarray obs_mask: Optional boolean array mask of shape
        broadcastable with ``fn.batch_shape``. If provided, events with
        mask=True will be conditioned on ``obs`` and remaining events will be
        imputed by sampling. This introduces a latent sample site named
        ``name + "_unobserved"`` which should be used by guides in SVI. Note
        that this argument is not intended to be used with MCMC.
    :return: sample from the stochastic `fn`.
    """
    assert isinstance(
        sample_shape, tuple
    ), "sample_shape needs to be a tuple of integers"
    if not isinstance(fn, numpyro.distributions.Distribution):
        type_error = TypeError(
            "It looks like you tried to use a fn that isn't an instance of "
            "numpyro.distributions.Distribution, funsor.Funsor or "
            "tensorflow_probability.distributions.Distribution. If you're using "
            "funsor or tensorflow_probability, make sure they are correctly installed."
        )
        # fn can be a funsor.Funsor, but funsor won't be installed for all users
        try:
            from funsor import Funsor
        except ImportError:
            Funsor = None

        # if the funsor import failed, or fn is not a Funsor, it's still
        # possible that fn is a tensorflow_probability distribution
        if Funsor is None or not isinstance(fn, Funsor):
            try:
                from tensorflow_probability.substrates.jax import distributions as tfd

                from numpyro.contrib.tfp.distributions import TFPDistribution
            except ImportError:
                # if tensorflow_probability fails to import here, then fn is not a
                # numpyro Distribution or a Funsor, and it can't have been a tfp
                # distribution either, so raising TypeError is ok
                raise type_error

            if isinstance(fn, tfd.Distribution):
                with warnings.catch_warnings():
                    # ignore FutureWarnings when instantiating TFPDistribution
                    warnings.simplefilter("ignore", category=FutureWarning)
                    # if fn is a tfp distribution we need to wrap it
                    fn = TFPDistribution[fn.__class__](**fn.parameters)
            else:
                # if tensorflow_probability imported, but fn is not a
                # tfd.Distribution, we still need to raise a type error
                raise type_error

    # if no active Messengers, draw a sample or return obs as expected:
    if not _PYRO_STACK:
        if obs is None:
            return fn(rng_key=rng_key, sample_shape=sample_shape)
        else:
            return obs

    if obs_mask is not None:
        return _masked_observe(
            name, fn, obs, obs_mask, rng_key=rng_key, sample_shape=(), infer=infer
        )

    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "sample",
        "name": name,
        "fn": fn,
        "args": (),
        "kwargs": {"rng_key": rng_key, "sample_shape": sample_shape},
        "value": obs,
        "scale": None,
        "is_observed": obs is not None,
        "intermediates": [],
        "cond_indep_stack": [],
        "infer": {} if infer is None else infer,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]
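
# Usage sketch (not part of the library): outside any handler, `sample` draws
# directly from `fn` using an explicit `rng_key`; under `seed`, the handler
# injects the key and `rng_key` is not needed. `_example_sample` is an
# illustrative name.
def _example_sample():
    import numpyro.distributions as dist
    import numpyro.handlers as handlers

    # no active Messengers: an explicit key is required
    x = sample("x", dist.Normal(0.0, 1.0), rng_key=random.PRNGKey(0))

    # under `seed`, the handler supplies the key
    with handlers.seed(rng_seed=0):
        y = sample("y", dist.Normal(0.0, 1.0))
    return x, y
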
def param(name, init_value=None, **kwargs):
    """
    Annotate the given site as an optimizable parameter for use with
    :mod:`jax.example_libraries.optimizers`. For an example of how `param`
    statements can be used in inference algorithms, refer to
    :class:`~numpyro.infer.SVI`.

    :param str name: name of site.
    :param init_value: initial value specified by the user or a lazy callable
        that accepts a JAX random PRNGKey and returns an array. Note that the
        onus of using this to initialize the optimizer is on the user or the
        inference algorithm, since there is no global parameter store in
        NumPyro.
    :type init_value: jnp.ndarray or callable
    :param constraint: NumPyro constraint, defaults to ``constraints.real``.
    :type constraint: numpyro.distributions.constraints.Constraint
    :param int event_dim: (optional) number of rightmost dimensions unrelated
        to batching. Dimensions to the left of this will be considered batch
        dimensions; if the param statement is inside a subsampled plate, then
        corresponding batch dimensions of the parameter will be correspondingly
        subsampled. If unspecified, all dimensions will be considered event
        dims and no subsampling will be performed.
    :return: value for the parameter. Unless wrapped inside a handler like
        :class:`~numpyro.handlers.substitute`, this will simply return the
        initial value.
    """
    # if there are no active Messengers, simply return the initial value:
    if not _PYRO_STACK:
        assert not callable(
            init_value
        ), "A callable init_value needs to be put inside a numpyro.handlers.seed handler."
        return init_value

    if callable(init_value):

        def fn(init_fn, *args, **kwargs):
            return init_fn(prng_key())

    else:
        fn = identity

    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "param",
        "name": name,
        "fn": fn,
        "args": (init_value,),
        "kwargs": kwargs,
        "value": None,
        "scale": None,
        "cond_indep_stack": [],
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]
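
# Usage sketch (not part of the library): outside any handler, `param` returns
# its initial value; under `substitute`, the value can be replaced, which is
# how SVI feeds optimized parameters back into the model. `_example_param` is
# an illustrative name.
def _example_param():
    import numpyro.handlers as handlers

    def model():
        return param("loc", 0.0)

    assert model() == 0.0  # no handlers: initial value is returned
    with handlers.substitute(data={"loc": 1.0}):
        assert model() == 1.0
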
def deterministic(name, value):
    """
    Used to designate deterministic sites in the model. Note that most effect
    handlers will not operate on deterministic sites (except
    :func:`~numpyro.handlers.trace`), so deterministic sites should be
    side-effect free. The use case for deterministic nodes is to record any
    values in the model execution trace.

    :param str name: name of the deterministic site.
    :param jnp.ndarray value: deterministic value to record in the trace.
    """
    if not _PYRO_STACK:
        return value

    initial_msg = {
        "type": "deterministic",
        "name": name,
        "value": value,
        "cond_indep_stack": [],
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]
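
# Usage sketch (not part of the library): `deterministic` records a computed
# value in the trace so it can be inspected after running the model.
# `_example_deterministic` is an illustrative name.
def _example_deterministic():
    import numpyro.distributions as dist
    import numpyro.handlers as handlers

    with handlers.trace() as tr, handlers.seed(rng_seed=0):
        x = sample("x", dist.Normal(0.0, 1.0))
        deterministic("x_squared", x * x)
    assert tr["x_squared"]["type"] == "deterministic"
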
def mutable(name, init_value=None):
    """
    This primitive is used to store a mutable value that can be changed
    during model execution::

        a = numpyro.mutable("a", {"value": 1.})
        a["value"] = 2.
        assert numpyro.mutable("a")["value"] == 2.

    For example, this can be used to store and update information like
    running mean/variance in a neural network batch normalization layer.

    :param str name: name of the mutable site.
    :param init_value: mutable value to record in the trace.
    """
    if not _PYRO_STACK:
        return init_value

    initial_msg = {
        "type": "mutable",
        "name": name,
        "fn": identity,
        "args": (init_value,),
        "kwargs": {},
        "value": init_value,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]


def _inspect():
    """
    EXPERIMENTAL Inspect the Pyro stack.

    .. warning:: The format of the returned message may change at any time and
        does not guarantee backwards compatibility.

    :returns: A message with mask effects applied.
    :rtype: dict
    """
    # NB: this differs from Pyro, where all effects are applied.
    # Here, we only apply the mask effect handler.
    msg = {
        "type": "inspect",
        "fn": lambda: True,
        "args": (),
        "kwargs": {},
        "value": None,
        "mask": None,
    }
    apply_stack(msg)
    return msg
def get_mask():
    """
    Records the effects of enclosing ``handlers.mask`` handlers.
    This is useful for avoiding expensive ``numpyro.factor()`` computations
    during prediction, when the log density need not be computed, e.g.::

        def model():
            # ...
            if numpyro.get_mask() is not False:
                log_density = my_expensive_computation()
                numpyro.factor("foo", log_density)
            # ...

    :returns: The mask.
    :rtype: None, bool, or jnp.ndarray
    """
    return _inspect()["mask"]
def module(name, nn, input_shape=None):
    """
    Declare a :mod:`~jax.example_libraries.stax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a
        :mod:`~jax.example_libraries.stax` constructor function.
    :param tuple input_shape: shape of the input taken by the neural network.
    :return: an `apply_fn` with bound parameters that takes an array as an
        input and returns the neural network transformed output array.
    """
    module_key = name + "$params"
    nn_init, nn_apply = nn
    nn_params = param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError("Valid value for `input_shape` needed to initialize.")
        rng_key = prng_key()
        _, nn_params = nn_init(rng_key, input_shape)
        param(module_key, nn_params)
    return functools.partial(nn_apply, nn_params)
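
# Usage sketch (not part of the library): registering a small stax network.
# The first call initializes parameters, which requires a `seed` handler
# because it calls `prng_key()`. `_example_module` is an illustrative name.
def _example_module():
    from jax.example_libraries import stax

    import numpyro.handlers as handlers

    def model(x):
        net = module("mlp", stax.serial(stax.Dense(4), stax.Relu), input_shape=(-1, 3))
        return net(x)  # apply_fn with parameters bound via `param`

    with handlers.seed(rng_seed=0):
        out = model(jnp.ones((2, 3)))
    assert out.shape == (2, 4)
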
def _subsample_fn(size, subsample_size, rng_key=None):
    if rng_key is None:
        raise ValueError(
            "Missing random key to generate subsample indices."
            " Algorithms like HMC/NUTS do not support subsampling."
            " You might want to use SVI or HMCECS instead."
        )
    if jax.default_backend() == "cpu":
        # ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
        rng_keys = random.split(rng_key, subsample_size)

        def body_fn(val, idx):
            i_p1 = size - idx
            i = i_p1 - 1
            j = random.randint(rng_keys[idx], (), 0, i_p1)
            val = val.at[jnp.array([i, j])].set(val[jnp.array([j, i])])
            return val, None

        val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size))
        return val[-subsample_size:]
    else:
        return random.choice(rng_key, size, (subsample_size,), replace=False)
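
# Illustration (not part of the library): `_subsample_fn` returns
# `subsample_size` distinct indices drawn uniformly from `range(size)`; on CPU
# it runs the partial Fisher-Yates shuffle above, elsewhere `random.choice`
# without replacement. `_example_subsample_fn` is an illustrative name.
def _example_subsample_fn():
    idx = _subsample_fn(10, 4, rng_key=random.PRNGKey(0))
    assert idx.shape == (4,) and len(set(idx.tolist())) == 4
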
class plate(Messenger):
    """
    Construct for annotating conditionally independent variables. Within a
    `plate` context manager, `sample` sites will be automatically broadcasted
    to the size of the plate. Additionally, a scale factor might be applied by
    certain inference algorithms if `subsample_size` is specified.

    .. note:: This can be used to subsample minibatches of data:

        .. code-block:: python

            with plate("data", len(data), subsample_size=100) as ind:
                batch = data[ind]
                assert len(batch) == 100

    :param str name: Name of the plate.
    :param int size: Size of the plate.
    :param int subsample_size: Optional argument denoting the size of the
        mini-batch. This can be used to apply a scaling factor by inference
        algorithms, e.g. when computing ELBO using a mini-batch.
    :param int dim: Optional argument to specify which dimension in the tensor
        is used as the plate dim. If `None` (default), the rightmost available
        dim is allocated.
    """

    def __init__(self, name, size, subsample_size=None, dim=None):
        self.name = name
        assert size > 0, "size of plate should be positive"
        self.size = size
        if dim is not None and dim >= 0:
            raise ValueError("dim arg must be negative.")
        self.dim, self._indices = self._subsample(
            self.name, self.size, subsample_size, dim
        )
        self.subsample_size = self._indices.shape[0]
        super(plate, self).__init__()

    # XXX: different from Pyro, this method returns dim and indices
    @staticmethod
    def _subsample(name, size, subsample_size, dim):
        msg = {
            "type": "plate",
            "fn": _subsample_fn,
            "name": name,
            "args": (size, subsample_size),
            "kwargs": {"rng_key": None},
            "value": (
                None
                if (subsample_size is not None and size != subsample_size)
                else jnp.arange(size)
            ),
            "scale": 1.0,
            "cond_indep_stack": [],
        }
        apply_stack(msg)
        subsample = msg["value"]
        subsample_size = msg["args"][1]
        if subsample_size is not None and subsample_size != subsample.shape[0]:
            warnings.warn(
                "subsample_size does not match len(subsample), {} vs {}.".format(
                    subsample_size, len(subsample)
                )
                + " Did you accidentally use different subsample_size in the model and guide?",
                stacklevel=find_stack_level(),
            )
        cond_indep_stack = msg["cond_indep_stack"]
        occupied_dims = {f.dim for f in cond_indep_stack}
        if dim is None:
            new_dim = -1
            while new_dim in occupied_dims:
                new_dim -= 1
            dim = new_dim
        else:
            assert dim not in occupied_dims
        return dim, subsample

    def __enter__(self):
        super().__enter__()
        return self._indices

    @staticmethod
    def _get_batch_shape(cond_indep_stack):
        n_dims = max(-f.dim for f in cond_indep_stack)
        batch_shape = [1] * n_dims
        for f in cond_indep_stack:
            batch_shape[f.dim] = f.size
        return tuple(batch_shape)

    def process_message(self, msg):
        if msg["type"] not in ("param", "sample", "plate", "deterministic"):
            if msg["type"] == "control_flow":
                raise NotImplementedError(
                    "Cannot use control flow primitive under a `plate` primitive."
                    " Please move those `plate` statements into the control flow"
                    " body function. See `scan` documentation for more information."
                )
            return

        if (
            "block_plates" in msg.get("infer", {})
            and self.name in msg["infer"]["block_plates"]
        ):
            return

        cond_indep_stack = msg["cond_indep_stack"]
        frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
        cond_indep_stack.append(frame)
        if msg["type"] == "deterministic":
            return
        if msg["type"] == "sample":
            expected_shape = self._get_batch_shape(cond_indep_stack)
            dist_batch_shape = msg["fn"].batch_shape
            if "sample_shape" in msg["kwargs"]:
                dist_batch_shape = msg["kwargs"]["sample_shape"] + dist_batch_shape
                msg["kwargs"]["sample_shape"] = ()
            overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
            trailing_shape = expected_shape[overlap_idx:]
            broadcast_shape = lax.broadcast_shapes(
                trailing_shape, tuple(dist_batch_shape)
            )
            batch_shape = expected_shape[:overlap_idx] + broadcast_shape
            msg["fn"] = msg["fn"].expand(batch_shape)
        if self.size != self.subsample_size:
            scale = 1.0 if msg["scale"] is None else msg["scale"]
            msg["scale"] = scale * (
                self.size / self.subsample_size if self.subsample_size else 1
            )

    def postprocess_message(self, msg):
        if msg["type"] in ("subsample", "param") and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                assert event_dim >= 0
                dim = self.dim - event_dim
                shape = jnp.shape(msg["value"])
                if len(shape) >= -dim and shape[dim] != 1:
                    if shape[dim] != self.size:
                        if msg["type"] == "param":
                            statement = "numpyro.param({}, ..., event_dim={})".format(
                                msg["name"], event_dim
                            )
                        else:
                            statement = "numpyro.subsample(..., event_dim={})".format(
                                event_dim
                            )
                        raise ValueError(
                            "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}".format(
                                self.name, self.size, self.dim, statement, shape
                            )
                        )
                    if self.subsample_size < self.size:
                        value = msg["value"]
                        new_value = jnp.take(value, self._indices, dim)
                        msg["value"] = new_value
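
# Usage sketch (not part of the library): inside `plate`, the Normal site is
# automatically expanded to the plate's batch shape. `_example_plate` is an
# illustrative name.
def _example_plate():
    import numpyro.distributions as dist
    import numpyro.handlers as handlers

    with handlers.seed(rng_seed=0):
        with plate("N", 5):
            x = sample("x", dist.Normal(0.0, 1.0))
    assert x.shape == (5,)
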
@contextmanager
def plate_stack(prefix, sizes, rightmost_dim=-1):
    """
    Create a contiguous stack of :class:`plate` s with dimensions::

        rightmost_dim - len(sizes), ..., rightmost_dim

    :param str prefix: Name prefix for plates.
    :param iterable sizes: An iterable of plate sizes.
    :param int rightmost_dim: The rightmost dim, counting from the right.
    """
    assert rightmost_dim < 0
    with ExitStack() as stack:
        for i, size in enumerate(reversed(sizes)):
            plate_i = plate("{}_{}".format(prefix, i), size, dim=rightmost_dim - i)
            stack.enter_context(plate_i)
        yield
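
# Usage sketch (not part of the library): two stacked plates occupy dims
# (-2, -1), so the site below has batch shape (2, 3). `_example_plate_stack`
# is an illustrative name.
def _example_plate_stack():
    import numpyro.distributions as dist
    import numpyro.handlers as handlers

    with handlers.seed(rng_seed=0):
        with plate_stack("p", [2, 3]):
            x = sample("x", dist.Normal(0.0, 1.0))
    assert x.shape == (2, 3)
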
def factor(name, log_factor):
    """
    Factor statement to add arbitrary log probability factor to a
    probabilistic model.

    :param str name: Name of the trivial sample.
    :param jnp.ndarray log_factor: A possibly batched log probability factor.
    """
    unit_dist = numpyro.distributions.distribution.Unit(log_factor)
    unit_value = unit_dist.sample(None)
    sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True})
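
# Usage sketch (not part of the library): `factor` adds `log_factor` to the
# model's joint log density via a trivial observed `Unit` sample site, as
# checked here with `numpyro.infer.util.log_density`. `_example_factor` is an
# illustrative name.
def _example_factor():
    from numpyro.infer.util import log_density

    def model():
        factor("penalty", -3.0)

    logp, _ = log_density(model, (), {}, {})
    assert jnp.isclose(logp, -3.0)
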
def prng_key():
    """
    A statement to draw a pseudo-random number generator key
    :func:`~jax.random.PRNGKey` under :class:`~numpyro.handlers.seed` handler.

    :return: a PRNG key of shape (2,) and dtype uint32.
    """
    if not _PYRO_STACK:
        warnings.warn(
            "Cannot generate JAX PRNG key outside of `seed` handler.",
            stacklevel=find_stack_level(),
        )
        return

    initial_msg = {
        "type": "prng_key",
        "fn": lambda rng_key: rng_key,
        "args": (),
        "kwargs": {"rng_key": None},
        "value": None,
    }

    msg = apply_stack(initial_msg)
    return msg["value"]
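
# Usage sketch (not part of the library): `prng_key` yields a fresh key on
# each call under `seed`; outside a `seed` handler it warns and returns None.
# `_example_prng_key` is an illustrative name.
def _example_prng_key():
    import numpyro.handlers as handlers

    with handlers.seed(rng_seed=0):
        k1, k2 = prng_key(), prng_key()
    assert not jnp.array_equal(k1, k2)  # keys are split, not repeated
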
def subsample(data, event_dim):
    """
    EXPERIMENTAL Subsampling statement to subsample data based on enclosing
    :class:`~numpyro.primitives.plate` s.

    This is typically called on arguments to ``model()`` when subsampling is
    performed automatically by :class:`~numpyro.primitives.plate` s by passing
    the ``subsample_size`` kwarg. For example the following are equivalent::

        # Version 1. using indexing
        def model(data):
            with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim) as ind:
                data = data[ind]
                # ...

        # Version 2. using numpyro.subsample()
        def model(data):
            with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim):
                data = numpyro.subsample(data, event_dim=0)
                # ...

    :param jnp.ndarray data: A tensor of batched data.
    :param int event_dim: The event dimension of the data tensor. Dimensions to
        the left are considered batch dimensions.
    :returns: A subsampled version of ``data``
    :rtype: ~jnp.ndarray
    """
    if not _PYRO_STACK:
        return data

    assert isinstance(event_dim, int) and event_dim >= 0
    initial_msg = {
        "type": "subsample",
        "value": data,
        "kwargs": {"event_dim": event_dim},
    }

    msg = apply_stack(initial_msg)
    return msg["value"]