# Source code for numpyro.contrib.funsor.enum_messenger

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, namedtuple
from contextlib import ExitStack  # python 3
from enum import Enum

from jax import lax
import jax.numpy as jnp

import funsor
from numpyro.handlers import trace as OrigTraceMessenger
from numpyro.primitives import CondIndepStackFrame, Messenger, apply_stack
from numpyro.primitives import plate as OrigPlateMessenger

funsor.set_backend("jax")


__all__ = [
    "enum",
    "infer_config",
    "markov",
    "plate",
    "to_data",
    "to_funsor",
    "trace",
]


##################################
# DimStack to store global state
##################################

# name_to_dim : dict, dim_to_name : dict, parents : tuple, iter_parents : tuple
class StackFrame(namedtuple('StackFrame',
                            ['name_to_dim', 'dim_to_name', 'parents', 'iter_parents', 'keep'])):
    """One frame of the global :class:`DimStack`: a bidirectional name<->dim
    mapping plus the parent and iter-parent frames consulted on lookup."""

    def read(self, name, dim):
        """Look up ``(name, dim)`` in this frame.

        Returns ``(name, dim, hit)`` where each missing half is filled in from
        the frame's mappings when present, and ``hit`` says whether either key
        was recorded here.
        """
        hit = name in self.name_to_dim or dim in self.dim_to_name
        return self.dim_to_name.get(dim, name), self.name_to_dim.get(name, dim), hit

    def write(self, name, dim):
        """Record the binding in both directions; neither key may be None."""
        assert name is not None and dim is not None
        self.name_to_dim[name] = dim
        self.dim_to_name[dim] = name

    def free(self, name, dim):
        """Drop the binding in both directions (missing keys are ignored)."""
        self.name_to_dim.pop(name, None)
        self.dim_to_name.pop(dim, None)
        return name, dim


class DimType(Enum):
    """Enumerates the possible types of dimensions to allocate"""
    LOCAL = 0    # allocated within the current markov frame only
    GLOBAL = 1   # allocated in the global frame, recyclable by handlers
    VISIBLE = 2  # user-visible (plate) dimension, kept right of enumeration dims


# Requests for a (possibly fresh) dimension or name from the DimStack.
# `defaults` (namedtuple keyword, Python 3.7+) replaces the legacy
# `__new__.__defaults__` mutation: both fields default to (None, DimType.LOCAL).
DimRequest = namedtuple('DimRequest', ['dim', 'dim_type'], defaults=(None, DimType.LOCAL))
NameRequest = namedtuple('NameRequest', ['name', 'dim_type'], defaults=(None, DimType.LOCAL))


class DimStack:
    """
    Single piece of global state to keep track of the mapping between names and dimensions.

    Replaces the plate DimAllocator, the enum EnumAllocator, the stack in MarkovMessenger,
    _param_dims and _value_dims in EnumMessenger, and dim_to_symbol in msg['infer']
    """
    def __init__(self):
        # The bottom frame is the permanent "global" frame; LocalNamedMessenger
        # pushes and pops additional frames on top of it.
        self._stack = [StackFrame(
            name_to_dim=OrderedDict(), dim_to_name=OrderedDict(),
            parents=(), iter_parents=(), keep=False,
        )]
        self._first_available_dim = -1
        # Outermost DimStackCleanupMessenger currently active, if any.
        self.outermost = None

    # Hard lower bound on allocatable (negative) dims; allocation past this fails.
    MAX_DIM = -25

    def set_first_available_dim(self, dim):
        """Set the first dim available for allocation and return the previous value."""
        assert dim is None or (self.MAX_DIM < dim < 0)
        old_dim, self._first_available_dim = self._first_available_dim, dim
        return old_dim

    def push(self, frame):
        # Push a new local frame on top of the stack.
        self._stack.append(frame)

    def pop(self):
        # Pop the innermost local frame; the global frame is never popped.
        assert len(self._stack) > 1, "cannot pop the global frame"
        return self._stack.pop()

    @property
    def current_frame(self):
        # Innermost (most recently pushed) frame.
        return self._stack[-1]

    @property
    def global_frame(self):
        # The permanent bottom frame.
        return self._stack[0]

    def _gendim(self, name_request, dim_request):
        """Generate a fresh (name, dim) pair satisfying both requests.

        Raises ValueError when the requested/next dim collides with the
        current frame, its parents/iter-parents, or the global frame,
        or falls outside the allowed range.
        """
        assert isinstance(name_request, NameRequest) and isinstance(dim_request, DimRequest)
        dim_type = dim_request.dim_type

        if name_request.name is None:
            fresh_name = f"_pyro_dim_{-dim_request.dim}"
        else:
            fresh_name = name_request.name

        conflict_frames = (self.current_frame, self.global_frame) + \
            self.current_frame.parents + self.current_frame.iter_parents
        if dim_request.dim is None:
            # VISIBLE dims grow leftward from -1; other dims start at
            # _first_available_dim. Scan left until no conflict remains.
            fresh_dim = self._first_available_dim if dim_type != DimType.VISIBLE else -1
            fresh_dim = -1 if fresh_dim is None else fresh_dim
            while any(fresh_dim in p.dim_to_name for p in conflict_frames):
                fresh_dim -= 1
        else:
            fresh_dim = dim_request.dim

        if fresh_dim < self.MAX_DIM or \
                any(fresh_dim in p.dim_to_name for p in conflict_frames) or \
                (dim_type == DimType.VISIBLE and fresh_dim <= self._first_available_dim):
            raise ValueError(f"Ran out of free dims during allocation for {fresh_name}")

        return fresh_name, fresh_dim

    def request(self, name, dim):
        """Resolve a NameRequest or DimRequest to a concrete (name, dim) pair.

        Exactly one of `name`/`dim` is a request object; the other holds the
        already-known half (possibly None). An existing binding is reused if
        found; otherwise a fresh one is generated and recorded.
        """
        assert isinstance(name, NameRequest) ^ isinstance(dim, DimRequest)
        if isinstance(dim, DimRequest):
            dim, dim_type = dim.dim, dim.dim_type
        elif isinstance(name, NameRequest):
            name, dim_type = name.name, name.dim_type

        # LOCAL requests consult the whole local lineage before the global
        # frame; GLOBAL/VISIBLE requests consult only the global frame.
        read_frames = (self.global_frame,) if dim_type != DimType.LOCAL else \
            (self.current_frame,) + self.current_frame.parents + self.current_frame.iter_parents + (self.global_frame,)

        # read dimension
        for frame in read_frames:
            name, dim, found = frame.read(name, dim)
            if found:
                break

        # generate fresh name or dimension
        if not found:
            name, dim = self._gendim(NameRequest(name, dim_type), DimRequest(dim, dim_type))

            # LOCAL bindings go into the current frame (plus parents when the
            # frame is replayable); GLOBAL/VISIBLE go into the global frame.
            write_frames = (self.global_frame,) if dim_type != DimType.LOCAL else \
                (self.current_frame,) + (self.current_frame.parents if self.current_frame.keep else ())

            # store the fresh dimension
            for frame in write_frames:
                frame.write(name, dim)

        return name, dim


_DIM_STACK = DimStack()  # only one global instance


#################################################
# Messengers that implement guts of enumeration
#################################################

class ReentrantMessenger(Messenger):
    """A messenger that may be entered recursively but only installs itself on
    the handler stack once: the outermost ``__enter__`` pushes, the matching
    outermost ``__exit__`` pops; nested uses just track depth."""

    def __init__(self, fn=None):
        self._ref_count = 0  # current nesting depth of __enter__ calls
        super().__init__(fn)

    def __enter__(self):
        self._ref_count += 1
        if self._ref_count != 1:
            return self
        super().__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._ref_count -= 1
        if self._ref_count != 0:
            return None
        return super().__exit__(exc_type, exc_value, traceback)


class DimStackCleanupMessenger(ReentrantMessenger):
    """Saves and restores the global DimStack frame across activations.

    The outermost active instance claims ownership of ``_DIM_STACK``: on entry
    it replays the (name, dim) bindings it saved at its previous exit, and on
    exit it frees every binding in the global frame (remembering the pairs) so
    stack state does not leak between handler activations.
    """

    def __init__(self, fn=None):
        self._saved_dims = ()  # (name, dim) pairs to restore on next entry
        return super().__init__(fn)

    def __enter__(self):
        # Only the first entry of the outermost handler claims ownership.
        if self._ref_count == 0 and _DIM_STACK.outermost is None:
            _DIM_STACK.outermost = self
            for name, dim in self._saved_dims:
                _DIM_STACK.global_frame.write(name, dim)
            self._saved_dims = ()
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        if self._ref_count == 1 and _DIM_STACK.outermost is self:
            _DIM_STACK.outermost = None
            # Free in reverse insertion order, saving the pairs for re-entry.
            for name, dim in reversed(tuple(_DIM_STACK.global_frame.name_to_dim.items())):
                self._saved_dims += (_DIM_STACK.global_frame.free(name, dim),)
        return super().__exit__(*args, **kwargs)


class NamedMessenger(DimStackCleanupMessenger):
    """Base handler implementing the ``to_funsor``/``to_data`` effects by
    translating between funsor named dimensions and positional batch
    dimensions via the global ``_DIM_STACK``."""

    def process_message(self, msg):
        if msg["type"] == "to_funsor":
            self._pyro_to_funsor(msg)
        elif msg["type"] == "to_data":
            self._pyro_to_data(msg)

    @staticmethod
    def _get_name_to_dim(batch_names, name_to_dim=None, dim_type=DimType.LOCAL):
        # Resolve each requested name to a concrete negative batch dim.
        name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim.copy()

        # interpret all names/dims as requests since we only run this function once
        for name in batch_names:
            dim = name_to_dim.get(name, None)
            name_to_dim[name] = dim if isinstance(dim, DimRequest) else DimRequest(dim, dim_type)

        # read dimensions and allocate fresh dimensions as necessary
        for name, dim_request in name_to_dim.items():
            name_to_dim[name] = _DIM_STACK.request(name, dim_request)[1]

        return name_to_dim

    @classmethod  # only depends on the global _DIM_STACK state, not self
    def _pyro_to_data(cls, msg):
        # Fill in the name_to_dim kwarg consumed by funsor.to_data.
        funsor_value, = msg["args"]
        name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
        dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)

        batch_names = tuple(funsor_value.inputs.keys())

        name_to_dim.update(cls._get_name_to_dim(batch_names, name_to_dim=name_to_dim, dim_type=dim_type))
        msg["stop"] = True  # only need to run this once per to_data call

    @staticmethod
    def _get_dim_to_name(batch_shape, dim_to_name=None, dim_type=DimType.LOCAL):
        # Resolve each batch dimension of `batch_shape` to a name (fresh if needed).
        dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name.copy()
        batch_dim = len(batch_shape)

        # interpret all names/dims as requests since we only run this function once
        for dim in range(-batch_dim, 0):
            name = dim_to_name.get(dim, None)
            # the time dimension on the left sometimes necessitates empty dimensions appearing
            # before they have been assigned a name
            if batch_shape[dim] == 1 and name is None:
                continue
            dim_to_name[dim] = name if isinstance(name, NameRequest) else NameRequest(name, dim_type)

        for dim, name_request in dim_to_name.items():
            dim_to_name[dim] = _DIM_STACK.request(name_request, dim)[0]

        return dim_to_name

    @classmethod  # only depends on the global _DIM_STACK state, not self
    def _pyro_to_funsor(cls, msg):
        # Fill in the dim_to_name kwarg consumed by funsor.to_funsor.
        if len(msg["args"]) == 2:
            raw_value, output = msg["args"]
        else:
            raw_value = msg["args"][0]
            output = msg["kwargs"].setdefault("output", None)
        dim_to_name = msg["kwargs"].setdefault("dim_to_name", OrderedDict())
        dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)

        event_dim = len(output.shape) if output else 0
        try:
            batch_shape = raw_value.batch_shape  # TODO make this more robust
        except AttributeError:
            batch_shape = raw_value.shape[:len(raw_value.shape) - event_dim]

        dim_to_name.update(cls._get_dim_to_name(batch_shape, dim_to_name=dim_to_name, dim_type=dim_type))
        msg["stop"] = True  # only need to run this once per to_funsor call


class LocalNamedMessenger(NamedMessenger):
    """
    Handler for converting to/from funsors consistent with Pyro's positional batch dimensions.

    :param int history: The number of previous contexts visible from the
        current context. Defaults to 1. If zero, this is similar to
        :class:`pyro.plate`.
    :param bool keep: If true, frames are replayable. This is important
        when branching: if ``keep=True``, neighboring branches at the same
        level can depend on each other; if ``keep=False``, neighboring branches
        are independent (conditioned on their shared ancestors).
    """
    def __init__(self, fn=None, history=1, keep=False):
        self.history = history
        self.keep = keep
        self._iterable = None     # set by generator() when used as a loop wrapper
        self._saved_frames = []   # frames stashed on exit when keep=True
        self._iter_parents = ()
        super().__init__(fn)

    def generator(self, iterable):
        # Use this messenger as an iterable wrapper: `for x in markov(seq): ...`
        self._iterable = iterable
        return self

    def _get_iter_parents(self, frame):
        # Transitive closure of iter_parents reachable from `frame`,
        # collected level by level.
        iter_parents = [frame]
        frontier = (frame,)
        while frontier:
            frontier = sum([p.iter_parents for p in frontier], ())
            iter_parents += frontier
        return tuple(iter_parents)

    def __iter__(self):
        assert self._iterable is not None
        self._iter_parents = self._get_iter_parents(_DIM_STACK.current_frame)
        with ExitStack() as stack:
            # Re-enter this messenger once per iteration; ExitStack unwinds
            # all the nested entries when the loop finishes.
            for value in self._iterable:
                stack.enter_context(self)
                yield value

    def __enter__(self):
        if self.keep and self._saved_frames:
            # Replay a previously saved frame so sibling branches can share bindings.
            saved_frame = self._saved_frames.pop()
            name_to_dim, dim_to_name = saved_frame.name_to_dim, saved_frame.dim_to_name
        else:
            name_to_dim, dim_to_name = OrderedDict(), OrderedDict()

        frame = StackFrame(
            name_to_dim=name_to_dim, dim_to_name=dim_to_name,
            # the last `history` frames on the stack become visible parents
            parents=tuple(reversed(_DIM_STACK._stack[len(_DIM_STACK._stack) - self.history:])),
            iter_parents=tuple(self._iter_parents),
            keep=self.keep
        )

        _DIM_STACK.push(frame)
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        if self.keep:
            # don't keep around references to other frames
            old_frame = _DIM_STACK.pop()
            saved_frame = StackFrame(
                name_to_dim=old_frame.name_to_dim, dim_to_name=old_frame.dim_to_name,
                parents=(), iter_parents=(), keep=self.keep
            )
            self._saved_frames.append(saved_frame)
        else:
            _DIM_STACK.pop()
        return super().__exit__(*args, **kwargs)


class GlobalNamedMessenger(NamedMessenger):
    """Base class for handlers that allocate GLOBAL/VISIBLE dims (e.g. plate).

    Global-frame bindings created while this handler is active are freed on
    exit and replayed on re-entry, so global dims are recycled per handler.
    """

    def __init__(self, fn=None):
        self._saved_globals = ()  # (name, dim) pairs allocated under this handler
        super().__init__(fn)

    def __enter__(self):
        if self._ref_count == 0:
            # Restore bindings saved by a previous exit of this same handler.
            for name, dim in self._saved_globals:
                _DIM_STACK.global_frame.write(name, dim)
            self._saved_globals = ()
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        if self._ref_count == 1:
            # Remove this handler's bindings; keep the pairs for re-entry.
            for name, dim in self._saved_globals:
                _DIM_STACK.global_frame.free(name, dim)
        return super().__exit__(*args, **kwargs)

    def postprocess_message(self, msg):
        if msg["type"] == "to_funsor":
            self._pyro_post_to_funsor(msg)
        elif msg["type"] == "to_data":
            self._pyro_post_to_data(msg)

    def _pyro_post_to_funsor(self, msg):
        # Record any global/visible dims named by this conversion for recycling.
        if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
            for name in msg["value"].inputs:
                self._saved_globals += ((name, _DIM_STACK.global_frame.name_to_dim[name]),)

    def _pyro_post_to_data(self, msg):
        # Record any global/visible dims named by this conversion for recycling.
        if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
            for name in msg["args"][0].inputs:
                self._saved_globals += ((name, _DIM_STACK.global_frame.name_to_dim[name]),)


class BaseEnumMessenger(NamedMessenger):
    """
    Handles first_available_dim management, enum effects should inherit from this
    """
    def __init__(self, fn=None, first_available_dim=None):
        assert first_available_dim is None or first_available_dim < 0, first_available_dim
        self.first_available_dim = first_available_dim
        super().__init__(fn)

    def __enter__(self):
        # Only the outermost entry (ref count still zero) swaps the stack's
        # first available dim, remembering the previous value for restoration.
        is_outermost = self._ref_count == 0
        if is_outermost and self.first_available_dim is not None:
            self._prev_first_dim = _DIM_STACK.set_first_available_dim(self.first_available_dim)
        return super().__enter__()

    def __exit__(self, *args, **kwargs):
        # The matching outermost exit (ref count back to one) restores it.
        is_outermost = self._ref_count == 1
        if is_outermost and self.first_available_dim is not None:
            _DIM_STACK.set_first_available_dim(self._prev_first_dim)
        return super().__exit__(*args, **kwargs)


##########################################
# User-facing handler implementations
##########################################

class plate(GlobalNamedMessenger):
    """
    An alternative implementation of :class:`numpyro.primitives.plate` primitive. Note
    that only this version is compatible with enumeration.

    There is also a context manager
    :func:`~numpyro.contrib.funsor.infer_util.plate_to_enum_plate` which converts
    `numpyro.plate` statements to this version.

    :param str name: Name of the plate.
    :param int size: Size of the plate.
    :param int subsample_size: Optional argument denoting the size of the mini-batch.
        This can be used to apply a scaling factor by inference algorithms. e.g.
        when computing ELBO using a mini-batch.
    :param int dim: Optional argument to specify which dimension in the tensor
        is used as the plate dim. If `None` (default), the leftmost available dim
        is allocated.
    """
    def __init__(self, name, size, subsample_size=None, dim=None):
        self.name = name
        self.size = size
        if dim is not None and dim >= 0:
            raise ValueError('dim arg must be negative.')
        self.dim, indices = OrigPlateMessenger._subsample(self.name, self.size, subsample_size, dim)
        self.subsample_size = indices.shape[0]
        # funsor representation of the subsample indices, keyed by the plate name
        self._indices = funsor.Tensor(
            indices,
            OrderedDict([(self.name, funsor.bint(self.subsample_size))]),
            self.subsample_size
        )
        super(plate, self).__init__(None)

    def __enter__(self):
        super().__enter__()  # do this first to take care of globals recycling
        name_to_dim = OrderedDict([(self.name, self.dim)]) if self.dim is not None else OrderedDict()
        indices = to_data(self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE)
        # extract the dimension allocated by to_data to match plate's current behavior
        self.dim, self.indices = -len(indices.shape), indices.squeeze()
        return self.indices

    @staticmethod
    def _get_batch_shape(cond_indep_stack):
        # Broadcasted batch shape implied by all enclosing plate frames.
        n_dims = max(-f.dim for f in cond_indep_stack)
        batch_shape = [1] * n_dims
        for f in cond_indep_stack:
            batch_shape[f.dim] = f.size
        return tuple(batch_shape)

    def process_message(self, msg):  # copied almost verbatim from plate
        if msg['type'] not in ('sample',):
            return super().process_message(msg)

        cond_indep_stack = msg['cond_indep_stack']
        frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
        cond_indep_stack.append(frame)
        expected_shape = self._get_batch_shape(cond_indep_stack)
        if msg['type'] == 'sample':
            dist_batch_shape = msg['fn'].batch_shape
            if 'sample_shape' in msg['kwargs']:
                # fold any sample_shape into the distribution's batch shape
                dist_batch_shape = msg['kwargs']['sample_shape'] + dist_batch_shape
                msg['kwargs']['sample_shape'] = ()

            # broadcast the distribution's batch shape against the plate shape
            overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
            trailing_shape = expected_shape[overlap_idx:]
            broadcast_shape = lax.broadcast_shapes(trailing_shape, tuple(dist_batch_shape))
            batch_shape = expected_shape[:overlap_idx] + broadcast_shape
            msg['fn'] = msg['fn'].expand(batch_shape)
        if self.size != self.subsample_size:
            # rescale the log-density to account for subsampling
            scale = 1. if msg['scale'] is None else msg['scale']
            msg['scale'] = scale * self.size / self.subsample_size

    def postprocess_message(self, msg):
        if msg["type"] in ("subsample", "param") and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                assert event_dim >= 0
                dim = self.dim - event_dim
                shape = msg["value"].shape
                if len(shape) >= -dim and shape[dim] != 1:
                    if shape[dim] != self.size:
                        if msg["type"] == "param":
                            statement = "numpyro.param({}, ..., event_dim={})".format(msg["name"], event_dim)
                        else:
                            statement = "numpyro.subsample(..., event_dim={})".format(event_dim)
                        raise ValueError(
                            "Inside numpyro.plate({}, {}, dim={}) invalid shape of {}: {}"
                            .format(self.name, self.size, self.dim, statement, shape))
                    # NOTE(review): nesting of this subsampling branch follows
                    # the upstream layout; confirm against numpyro source.
                    if self.subsample_size < self.size:
                        value = msg["value"]
                        new_value = jnp.take(value, self._indices, dim)
                        msg["value"] = new_value
class enum(BaseEnumMessenger):
    """
    Enumerates in parallel over discrete sample sites marked
    ``infer={"enumerate": "parallel"}``.

    :param callable fn: Python callable with NumPyro primitives.
    :param int first_available_dim: The first tensor dimension (counting
        from the right) that is available for parallel enumeration. This
        dimension and all dimensions left may be used internally by Pyro.
        This should be a negative integer or None.
    """
    def process_message(self, msg):
        # Skip sites that cannot (or should not) be enumerated and defer
        # to the base handler.
        if msg["type"] != "sample" or \
                msg.get("done", False) or msg["is_observed"] or msg["infer"].get("expand", False) or \
                msg["infer"].get("enumerate") != "parallel" or (not msg["fn"].has_enumerate_support):
            if msg["type"] == "control_flow":
                # notify control-flow primitives that enumeration is active
                msg["kwargs"]["enum"] = True
            return super().process_message(msg)

        if msg["infer"].get("num_samples", None) is not None:
            raise NotImplementedError("TODO implement multiple sampling")

        if msg["infer"].get("expand", False):
            raise NotImplementedError("expand=True not implemented")

        # Replace the site's value with the full enumerated support, placed
        # on a fresh funsor-named dimension keyed by the site name.
        size = msg["fn"].enumerate_support(expand=False).shape[0]
        raw_value = jnp.arange(0, size)
        funsor_value = funsor.Tensor(
            raw_value,
            OrderedDict([(msg["name"], funsor.bint(size))]),
            size
        )
        msg["value"] = to_data(funsor_value)
        msg["done"] = True
class trace(OrigTraceMessenger):
    """
    This version of :class:`~numpyro.handlers.trace` handler records
    information necessary to do packing after execution.

    Each sample site is annotated with a "dim_to_name" dictionary,
    which can be passed directly to :func:`to_funsor`.
    """
    def postprocess_message(self, msg):
        if msg["type"] == "sample":
            # batch shape of the site's value broadcast with the fn's batch shape
            total_batch_shape = lax.broadcast_shapes(
                tuple(msg["fn"].batch_shape),
                msg["value"].shape[:len(msg["value"].shape) - len(msg["fn"].event_shape)]
            )
            msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(total_batch_shape)
        if msg["type"] in ("sample", "param"):
            super().postprocess_message(msg)
def markov(fn=None, history=1, keep=False):
    """
    Markov dependency declaration.

    This is a statistical equivalent of a memory management arena.

    :param callable fn: Python callable with NumPyro primitives.
    :param int history: The number of previous contexts visible from the
        current context. Defaults to 1. If zero, this is similar to
        :class:`numpyro.primitives.plate`.
    :param bool keep: If true, frames are replayable. This is important
        when branching: if ``keep=True``, neighboring branches at the same
        level can depend on each other; if ``keep=False``, neighboring branches
        are independent (conditioned on their shared ancestors).
    """
    if fn is not None and not callable(fn):  # Used as a generator
        return LocalNamedMessenger(fn=None, history=history, keep=keep).generator(iterable=fn)
    return LocalNamedMessenger(fn, history=history, keep=keep)
class infer_config(Messenger):
    """
    Given a callable `fn` that contains Pyro primitive calls and a callable
    `config_fn` taking a trace site and returning a dictionary, updates the
    value of the infer kwarg at a sample site to config_fn(site).

    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param config_fn: a callable taking a site and returning an infer dict
    """
    def __init__(self, fn, config_fn):
        super().__init__(fn)
        self.config_fn = config_fn

    def process_message(self, msg):
        # Merge the user-provided configuration into the site's infer dict.
        if msg["type"] in ("sample", "param"):
            msg["infer"].update(self.config_fn(msg))
####################
# New primitives
####################
def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
    """
    A primitive to convert a Python object to a :class:`~funsor.terms.Funsor`.

    :param x: An object.
    :param funsor.domains.Domain output: An optional output hint to uniquely
        convert a data to a Funsor (e.g. when `x` is a string).
    :param OrderedDict dim_to_name: An optional mapping from negative
        batch dimensions to name strings.
    :param int dim_type: Either 0, 1, or 2. This optional argument indicates
        a dimension should be treated as 'local', 'global', or 'visible',
        which can be used to interact with the global :class:`DimStack`.
    :return: A Funsor equivalent to `x`.
    :rtype: funsor.terms.Funsor
    """
    dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name
    # Route the conversion through the effect-handler stack so NamedMessenger
    # handlers can fill in dim_to_name before funsor.to_funsor runs.
    initial_msg = {
        'type': 'to_funsor',
        'fn': lambda x, output, dim_to_name, dim_type: funsor.to_funsor(x, output=output, dim_to_name=dim_to_name),
        'args': (x,),
        'kwargs': {"output": output, "dim_to_name": dim_to_name, "dim_type": dim_type},
        'value': None,
        'mask': None,
    }
    msg = apply_stack(initial_msg)
    return msg['value']
def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
    """
    A primitive to extract a python object from a :class:`~funsor.terms.Funsor`.

    :param ~funsor.terms.Funsor x: A funsor object
    :param OrderedDict name_to_dim: An optional inputs hint which maps
        dimension names from `x` to dimension positions of the returned value.
    :param int dim_type: Either 0, 1, or 2. This optional argument indicates
        a dimension should be treated as 'local', 'global', or 'visible',
        which can be used to interact with the global :class:`DimStack`.
    :return: A non-funsor equivalent to `x`.
    """
    name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim
    # Route the conversion through the effect-handler stack so NamedMessenger
    # handlers can fill in name_to_dim before funsor.to_data runs.
    initial_msg = {
        'type': 'to_data',
        'fn': lambda x, name_to_dim, dim_type: funsor.to_data(x, name_to_dim=name_to_dim),
        'args': (x,),
        'kwargs': {"name_to_dim": name_to_dim, "dim_type": dim_type},
        'value': None,
        'mask': None,
    }
    msg = apply_stack(initial_msg)
    return msg['value']