Source code for numpyro.contrib.control_flow.cond

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import partial

from jax import device_put, lax

from numpyro import handlers
from numpyro.ops.pytree import PytreeTrace
from numpyro.primitives import _PYRO_STACK, apply_stack


def _subs_wrapper(subs_map, site):
    if isinstance(subs_map, dict) and site["name"] in subs_map:
        return subs_map[site["name"]]
    elif callable(subs_map):
        if site["type"] == "deterministic":
            return subs_map(site)
        rng_key = site["kwargs"].get("rng_key")
        subs_map = (
            handlers.seed(subs_map, rng_seed=rng_key)
            if rng_key is not None
            else subs_map
        )
        return subs_map(site)
    return None


def wrap_fn(fn, substitute_stack):
    def wrapper(wrapped_operand):
        rng_key, operand = wrapped_operand

        with handlers.block():
            seeded_fn = handlers.seed(fn, rng_key) if rng_key is not None else fn
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map)
                if subs_type == "condition":
                    seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
                elif subs_type == "substitute":
                    seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                value = seeded_fn(operand)

        return value, PytreeTrace(trace)

    return wrapper


def cond_wrapper(
    pred,
    true_fun,
    false_fun,
    operand,
    rng_key=None,
    substitute_stack=None,
    enum=False,
    first_available_dim=None,
):
    if enum:
        # TODO: support enumeration. note that pred passed to lax.cond must be scalar
        # which means that even simple conditions like `x == 0` can get complicated if
        # x is an enumerated discrete random variable
        raise RuntimeError("The cond primitive does not currently support enumeration")

    if substitute_stack is None:
        substitute_stack = []

    wrapped_true_fun = wrap_fn(true_fun, substitute_stack)
    wrapped_false_fun = wrap_fn(false_fun, substitute_stack)
    wrapped_operand = device_put((rng_key, operand))
    return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand)


[docs] def cond(pred, true_fun, false_fun, operand): """ This primitive conditionally applies ``true_fun`` or ``false_fun``. See :func:`jax.lax.cond` for more information. **Usage**: .. doctest:: >>> import numpyro >>> import numpyro.distributions as dist >>> from jax import random >>> from numpyro.contrib.control_flow import cond >>> from numpyro.infer import SVI, Trace_ELBO >>> >>> def model(): ... def true_fun(_): ... return numpyro.sample("x", dist.Normal(20.0)) ... ... def false_fun(_): ... return numpyro.sample("x", dist.Normal(0.0)) ... ... cluster = numpyro.sample("cluster", dist.Normal()) ... return cond(cluster > 0, true_fun, false_fun, None) >>> >>> def guide(): ... m1 = numpyro.param("m1", 10.0) ... s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive) ... m2 = numpyro.param("m2", 10.0) ... s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive) ... ... def true_fun(_): ... return numpyro.sample("x", dist.Normal(m1, s1)) ... ... def false_fun(_): ... return numpyro.sample("x", dist.Normal(m2, s2)) ... ... cluster = numpyro.sample("cluster", dist.Normal()) ... return cond(cluster > 0, true_fun, false_fun, None) >>> >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100)) >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500) .. warning:: This is an experimental utility function that allows users to use JAX control flow with NumPyro's effect handlers. Currently, `sample` and `deterministic` sites within `true_fun` and `false_fun` are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue. .. warning:: The ``cond`` primitive does not currently support enumeration and can not be used inside a ``numpyro.plate`` context. .. note:: All ``sample`` sites must belong to the same distribution class. For example the following is not supported .. code-block:: python cond( True, lambda _: numpyro.sample("x", dist.Normal()), lambda _: numpyro.sample("x", dist.Laplace()), None, ) :param bool pred: Boolean scalar type indicating which branch function to apply :param callable true_fun: A function to be applied if ``pred`` is true. :param callable false_fun: A function to be applied if ``pred`` is false. :param operand: Operand input to either branch depending on ``pred``. This can be any JAX PyTree (e.g. list / dict of arrays). :return: Output of the applied branch function. """ if not _PYRO_STACK: value, _ = cond_wrapper(pred, true_fun, false_fun, operand) return value initial_msg = { "type": "control_flow", "fn": cond_wrapper, "args": (pred, true_fun, false_fun, operand), "kwargs": {"rng_key": None, "substitute_stack": []}, "value": None, } msg = apply_stack(initial_msg) value, pytree_trace = msg["value"] for msg in pytree_trace.trace.values(): if msg["type"] == "plate": continue apply_stack(msg) return value