# Source code for numpyro.contrib.control_flow.scan

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from functools import partial

from jax import lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from numpyro import handlers
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack
from numpyro.util import not_jax_tracer

@register_pytree_node_class
class PytreeTrace:
    """
    Wrapper that lets a trace (a dict of site dicts) travel through JAX
    transformations such as :func:`jax.lax.scan` as a pytree.

    Only `sample` and `deterministic` sites are kept. For each kept site the
    `fn`, `args`, `value` and `intermediates` fields become pytree children;
    every other field (except `stop`) is stored as static auxiliary data.

    :param dict trace: mapping from site name to site dict.
    """

    def __init__(self, trace):
        self.trace = trace

    def tree_flatten(self):
        """
        Split the trace into pytree children and auxiliary data.

        :return: a pair `((children_trace,), aux_trace)` where both dicts are
            keyed by site name.
        """
        trace, aux_trace = {}, {}
        for name, site in self.trace.items():
            if site['type'] in ('sample', 'deterministic'):
                # mark the site so that it can be recognised as coming from control flow
                trace[name], aux_trace[name] = {}, {'_control_flow_done': True}
                for key, field in site.items():
                    if key in ('fn', 'args', 'value', 'intermediates'):
                        trace[name][key] = field
                    # scanned sites have stop field because we trace them inside a block handler
                    elif key != 'stop':
                        aux_trace[name][key] = field
        return (trace,), aux_trace

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Rebuild a :class:`PytreeTrace` from the output of :meth:`tree_flatten`.

        :param dict aux_data: per-site auxiliary fields.
        :param tuple children: one-element tuple holding the per-site children dict.
        :return: a new :class:`PytreeTrace` whose sites contain both the
            children fields and the auxiliary fields.
        """
        trace, = children
        for name, site in trace.items():
            # merge the static fields back into each site
            site.update(aux_data[name])
        return cls(trace)

def _subs_wrapper(subs_map, i, length, site):
    """
    Look up the value to substitute at `site` for scan step `i`.

    :param subs_map: either a dict keyed by site name or a callable taking a
        site and returning a value (or None).
    :param i: index of the current scan step.
    :param int length: total scan length.
    :param dict site: the site being processed; its `name`, `fn` and
        `kwargs` (`rng_key`, `sample_shape`) entries are read.
    :return: None when nothing is substituted for this site, otherwise the
        value to use at step `i`.
    :raises RuntimeError: if the substituted series is longer than `length`,
        or if the value's number of dimensions matches neither a single
        sample nor a series of samples.
    """
    value = None
    if isinstance(subs_map, dict) and site['name'] in subs_map:
        value = subs_map[site['name']]
    elif callable(subs_map):
        rng_key = site['kwargs'].get('rng_key')
        subs_map = handlers.seed(subs_map, rng_seed=rng_key) if rng_key is not None else subs_map
        value = subs_map(site)

    if value is None:
        return None

    value_ndim = jnp.ndim(value)
    sample_shape = site['kwargs']['sample_shape']
    fn_ndim = len(sample_shape + site['fn'].shape())
    if value_ndim == fn_ndim:
        # this branch happens when substitute_fn is init_strategy,
        # where we apply init_strategy to each element in the scanned series
        return value

    if value_ndim != fn_ndim + 1:
        raise RuntimeError(f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim+1},"
                           f" but got {value_ndim}. This might happen when you use nested scan,"
                           " which is currently not supported. Please report the issue to us!")

    # from here on we substitute a series of values
    shape = jnp.shape(value)
    if shape[0] == length:
        return value[i]
    if shape[0] > length:
        raise RuntimeError(f"Substituted value for site {site['name']} "
                           "requires length less than or equal to scan length."
                           f" Expected length <= {length}, but got {shape[0]}.")

    rng_key = site['kwargs']['rng_key']
    assert rng_key is not None
    # we use the substituted values if i < shape[0]
    # and generate a new sample otherwise; all inputs are passed through a
    # single operand so that the key reaches the sampling branch
    return lax.cond(i < shape[0],
                    lambda operand: operand[0][operand[1]],
                    lambda operand: site['fn'](rng_key=operand[2], sample_shape=sample_shape),
                    (value, i, rng_key))

class promote_shapes(Messenger):
    # Helper messenger that aligns the number of batch dimensions of a sample
    # site's `fn` and `value` by prepending singleton dimensions:
    #   + fn.batch_shape = (2, 3), value.shape = (3,) + fn.event_shape
    #     -> value is reshaped so that value.shape = (1, 3) + fn.event_shape
    #   + fn.batch_shape = (3,), value.shape = (2, 3) + fn.event_shape
    #     -> fn is reshaped so that fn.batch_shape = (1, 3).
    def process_message(self, msg):
        # only sample sites that already carry a value are touched
        if msg["type"] != "sample" or msg["value"] is None:
            return

        dist, observed = msg["fn"], msg["value"]
        n_value_batch = jnp.ndim(observed) - dist.event_dim
        n_fn_batch = len(dist.batch_shape)
        gap = n_fn_batch - n_value_batch
        ones = (1,) * abs(gap)

        if gap > 0:
            # value has fewer batch dims than fn: pad the value on the left
            msg["value"] = jnp.reshape(observed, ones + jnp.shape(observed))
        elif gap < 0:
            # fn has fewer batch dims than value: pad every leaf of fn on the left
            msg["fn"] = tree_map(lambda leaf: jnp.reshape(leaf, ones + jnp.shape(leaf)), dist)

def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
    """
    Variant of the scan wrapper used when enumeration is enabled.

    The first step of `f` is run outside of :func:`jax.lax.scan` under a
    `_init` scope; the remaining `length - 1` steps are run with
    :func:`jax.lax.scan` while their trace is collected under the funsor
    `enum`/`markov` handlers.

    :param callable f: transition function taking `(carry, x)` and returning
        `(new_carry, y)`.
    :param init: the initial carry.
    :param xs: pytree of values scanned over their leading axis.
    :param int length: number of scan steps.
    :param bool reverse: whether to scan in reverse.
    :param rng_key: optional random key that is split at every step.
    :param substitute_stack: optional iterable of `(subs_type, subs_map)`
        pairs, where `subs_type` is `'condition'` or `'substitute'`.
    :return: `(wrapped_carry, (pytree_trace, ys))` where `wrapped_carry` is
        `(step, rng_key, carry)`.
    """
    from numpyro.contrib.funsor import enum, config_enumerate, markov, trace as packed_trace

    # the default is None; treat it as "nothing to substitute"
    if substitute_stack is None:
        substitute_stack = []

    # XXX: This implementation only works for history size=1 but can be
    # extended to history size > 1 by running `f` `history_size` times
    # for initialization. However, `sequential_sum_product` does not
    # support history size > 1, so we skip supporting it here.
    # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
    if reverse:
        x0 = tree_map(lambda x: x[-1], xs)
        xs_ = tree_map(lambda x: x[:-1], xs)
    else:
        x0 = tree_map(lambda x: x[0], xs)
        xs_ = tree_map(lambda x: x[1:], xs)

    carry_shape_at_t1 = None

    def body_fn(wrapped_carry, x, prefix=None):
        nonlocal carry_shape_at_t1
        i, rng_key, carry = wrapped_carry
        # the init step is the only one called with a concrete step counter of 0
        is_init_step = bool(not_jax_tracer(i) and i == 0)
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if is_init_step:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry at a global variable
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)

    with markov():
        wrapped_carry = (0, rng_key, init)
        wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
        if length == 1:
            ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
            return wrapped_carry, (PytreeTrace({}), ys)
        wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name
        leftmost_dim = min(site['infer']['dim_to_name'])
        site['infer']['dim_to_name'][leftmost_dim - 1] = '_time_{}'.format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys)
    # we also need to reshape `carry` to match sequential behavior
    if length % 2 == 0:
        t, rng_key, carry = wrapped_carry
        flatten_carry, treedef = tree_flatten(carry)
        flatten_carry = [jnp.reshape(x, t1_shape)
                         for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)]
        carry = tree_unflatten(treedef, flatten_carry)
        wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)

def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False):
    """
    Run `f` with :func:`jax.lax.scan` while recording the trace of every step.

    :param callable f: transition function taking `(carry, x)` and returning
        `(new_carry, y)`.
    :param init: the initial carry.
    :param xs: pytree of values scanned over their leading axis.
    :param length: number of steps; when None it is read from the leading
        axis of the first leaf of `xs`.
    :param bool reverse: whether to scan in reverse.
    :param rng_key: optional random key that is split at every step.
    :param substitute_stack: iterable of `(subs_type, subs_map)` pairs, where
        `subs_type` is `'condition'` or `'substitute'`.
    :param bool enum: when True the work is delegated to :func:`scan_enum`.
    :return: `((step, rng_key, carry), (pytree_trace, ys))`.
    """
    if length is None:
        leaves, _ = tree_flatten(xs)
        length = leaves[0].shape[0]

    if enum:
        return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack)

    def body_fn(wrapped_carry, x):
        step, key, state = wrapped_carry
        if key is None:
            next_key, step_key = None, None
        else:
            next_key, step_key = random.split(key)

        # block outer handlers while tracing; the recorded sites are returned
        # to the caller through the PytreeTrace instead
        with handlers.block():
            step_fn = f if step_key is None else handlers.seed(f, step_key)
            for kind, mapping in substitute_stack:
                site_fn = partial(_subs_wrapper, mapping, step, length)
                if kind == 'condition':
                    step_fn = handlers.condition(step_fn, condition_fn=site_fn)
                elif kind == 'substitute':
                    step_fn = handlers.substitute(step_fn, substitute_fn=site_fn)

            with handlers.trace() as step_trace:
                state, y = step_fn(state, x)

        return (step + 1, next_key, state), (PytreeTrace(step_trace), y)

    start = (jnp.array(0), rng_key, init)
    return lax.scan(body_fn, start, xs, length=length, reverse=reverse)

def scan(f, init, xs, length=None, reverse=False):
    """
    This primitive scans a function over the leading array axes of
    `xs` while carrying along state. See :func:`jax.lax.scan` for more
    information.

    **Usage**:

    .. doctest::

       >>> import numpy as np
       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from numpyro.contrib.control_flow import scan
       >>>
       >>> def gaussian_hmm(y=None, T=10):
       ...     def transition(x_prev, y_curr):
       ...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
       ...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
       ...         return x_curr, (x_curr, y_curr)
       ...
       ...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
       ...     _, (x, y) = scan(transition, x0, y, length=T)
       ...     return (x, y)
       >>>
       >>> # here we do some quick tests
       >>> with numpyro.handlers.seed(rng_seed=0):
       ...     x, y = gaussian_hmm(np.arange(10.))
       >>> assert x.shape == (10,) and y.shape == (10,)
       >>> assert np.all(y == np.arange(10))
       >>>
       >>> with numpyro.handlers.seed(rng_seed=0):  # generative
       ...     x, y = gaussian_hmm()
       >>> assert x.shape == (10,) and y.shape == (10,)

    .. warning:: This is an experimental utility function that allows users to use
        JAX control flow with NumPyro's effect handlers. Currently, `sample` and
        `deterministic` sites within the scan body `f` are supported. If you notice
        that any effect handlers or distributions are unsupported, please file an issue.

    .. note:: It is ambiguous to align `scan` dimension inside a `plate` context.
        So the following pattern won't be supported

        .. code-block:: python

            with numpyro.plate('N', 10):
                last, ys = scan(f, init, xs)

        All `plate` statements should be put inside `f`. For example, the corresponding
        working code is

        .. code-block:: python

            def g(*args, **kwargs):
                with numpyro.plate('N', 10):
                    return f(*args, **kwargs)

            last, ys = scan(g, init, xs)

    .. note:: Nested scan is currently not supported.

    .. note:: We can scan over discrete latent variables in `f`. The joint density is
        evaluated using parallel-scan (reference [1]) over time dimension, which
        reduces parallel complexity to `O(log(length))`.

        Currently, only the equivalence to
        :class:`~numpyro.contrib.funsor.enum_messenger.markov(history_size=1)`
        is supported. A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
        variables will contain the following sites:

            + init sites: those sites belong to the first trace of `f`. Each of
              them will have name prefixed with `_init/`.
            + scanned sites: those sites collect the values of the remaining scan
              loop over `f`. An additional time dimension `_time_foo` will be
              added to those sites, where `foo` is the name of the first site
              appearing in `f`.

        Not all transition functions `f` are supported. All of the restrictions from
        Pyro's enumeration tutorial [2] still apply here. In addition, no site outside
        of `scan` should depend on the first output of `scan` (the last carry value).

    **References**

    1. *Temporal Parallelization of Bayesian Smoothers*,
       Simo Sarkka, Angel F. Garcia-Fernandez
       (https://arxiv.org/abs/1905.13002)

    2. *Inference with Discrete Latent Variables*
       (http://pyro.ai/examples/enumeration.html)

    :param callable f: a function to be scanned.
    :param init: the initial carrying state
    :param xs: the values over which we scan along the leading axis. This can
        be any JAX pytree (e.g. list/dict of arrays).
    :param length: optional value specifying the length of `xs`
        but can be used when `xs` is an empty pytree (e.g. None)
    :param bool reverse: optional boolean specifying whether to run the scan iteration
        forward (the default) or in reverse
    :return: output of scan, quoted from :func:`jax.lax.scan` docs:
        "pair of type (c, [b]) where the first element represents the final loop
        carry value and the second element represents the stacked outputs of the
        second output of f when scanned over the leading axis of the inputs".
    """
    # if there are no active Messengers, we just run and return it as expected:
    if not _PYRO_STACK:
        (_, _, carry), (_, ys) = scan_wrapper(
            f, init, xs, length=length, reverse=reverse)
        # `msg` only exists on the handler path below, so return here instead
        # of falling through to the replay logic
        return carry, ys

    # Otherwise, we initialize a message...
    initial_msg = {
        'type': 'control_flow',
        'fn': scan_wrapper,
        'args': (f, init, xs, length, reverse),
        'kwargs': {'rng_key': None,
                   'substitute_stack': []},
        'value': None,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    (_, _, carry), (pytree_trace, ys) = msg['value']

    # replay every scanned site through the handler stack
    if not msg["kwargs"].get("enum", False):
        for site in pytree_trace.trace.values():
            apply_stack(site)
    else:
        from numpyro.contrib.funsor import to_funsor
        from numpyro.contrib.funsor.enum_messenger import LocalNamedMessenger

        for site in pytree_trace.trace.values():
            with LocalNamedMessenger():
                dim_to_name = site["infer"].get("dim_to_name")
                # pass the dimension names ordered by dimension index
                to_funsor(site["value"],
                          dim_to_name=OrderedDict([(k, dim_to_name[k]) for k in sorted(dim_to_name)]))
                apply_stack(site)

    return carry, ys