Source code for numpyro.contrib.funsor.infer_util

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from contextlib import contextmanager

import funsor
import numpyro
from numpyro.contrib.funsor.enum_messenger import infer_config
from numpyro.contrib.funsor.enum_messenger import plate as enum_plate
from numpyro.contrib.funsor.enum_messenger import trace as packed_trace
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute

funsor.set_backend("jax")


@contextmanager
def plate_to_enum_plate():
    """
    A context manager to replace `numpyro.plate` statements with a funsor-based
    :class:`~numpyro.contrib.funsor.enum_messenger.plate`.

    This is useful when doing inference for the usual NumPyro programs with
    `numpyro.plate` statements. For example, to get a trace of a `model` whose
    discrete latent sites are enumerated, we can use::

        enum_model = numpyro.contrib.funsor.enum(model)
        with plate_to_enum_plate():
            model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
                *model_args, **model_kwargs)
    """
    try:
        numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs)
        yield
    finally:
        numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)
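
# A minimal usage sketch of `plate_to_enum_plate`, not part of this module:
# `_example_plate_to_enum_plate` and its toy model are illustrative names, and the
# sketch assumes a single plate (dim -1) with one discrete latent manually tagged
# for parallel enumeration. Inside the context manager, `numpyro.plate` statements
# are backed by the funsor-aware enum_plate, so the packed trace records
# `dim_to_name` metadata for each site.
def _example_plate_to_enum_plate():
    import jax.numpy as jnp

    import numpyro.distributions as dist
    from numpyro.contrib.funsor.enum_messenger import enum

    def toy_model(data):
        with numpyro.plate("N", data.shape[0]):
            # manually tag the discrete site for parallel enumeration
            z = numpyro.sample("z", dist.Bernoulli(0.3),
                               infer={"enumerate": "parallel"})
            numpyro.sample("x", dist.Normal(z.astype(float), 1.0), obs=data)

    # one plate at dim -1, so enumeration starts at dim -2
    enum_model = enum(toy_model, first_available_dim=-2)
    with plate_to_enum_plate():
        model_trace = packed_trace(enum_model).get_trace(jnp.zeros(3))
    return model_trace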
def config_enumerate(fn, default='parallel'):
    """
    Configures enumeration for all relevant sites in a NumPyro model.

    When configuring for exhaustive enumeration of discrete variables, this
    configures all sample sites whose distribution satisfies
    ``.has_enumerate_support == True``.

    This can be used as either a function::

        model = config_enumerate(model)

    or as a decorator::

        @config_enumerate
        def model(*args, **kwargs):
            ...

    .. note:: Currently, only ``default='parallel'`` is supported.

    :param callable fn: Python callable with NumPyro primitives.
    :param str default: Which enumerate strategy to use, one of "sequential",
        "parallel", or None. Defaults to "parallel".
    """
    def config_fn(site):
        if site['type'] == 'sample' and (not site['is_observed']) \
                and site['fn'].has_enumerate_support:
            return {'enumerate': site['infer'].get('enumerate', default)}
        return {}

    return infer_config(fn, config_fn)
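
# A minimal decorator sketch, not part of this module: `_example_config_enumerate`
# and the mixture model inside it are illustrative. Every non-observed sample site
# whose distribution has enumerate support (here, `assignment`) is tagged with
# infer={'enumerate': 'parallel'} by config_enumerate.
def _example_config_enumerate():
    import jax.numpy as jnp

    import numpyro.distributions as dist

    @config_enumerate
    def mixture(data):
        weights = numpyro.sample("weights", dist.Dirichlet(jnp.ones(2)))
        locs = numpyro.sample("locs", dist.Normal(jnp.zeros(2), 10.0).to_event(1))
        with numpyro.plate("data", data.shape[0]):
            # discrete site with enumerate support -> enumerated in parallel
            assignment = numpyro.sample("assignment", dist.Categorical(weights))
            numpyro.sample("obs", dist.Normal(locs[assignment], 1.0), obs=data)

    return mixture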
def compute_markov_factors(time_to_factors, time_to_init_vars, time_to_markov_dims,
                           sum_vars, prod_vars):
    """
    :param dict time_to_factors: a map from time variable to the log prob factors.
    :param dict time_to_init_vars: a map from time variable to init discrete sites.
    :param dict time_to_markov_dims: a map from time variable to dimensions at markov sites
        (discrete sites that depend on previous steps).
    :param frozenset sum_vars: all plate and enum dimensions in the trace.
    :param frozenset prod_vars: all plate dimensions in the trace.
    :returns: a list of factors after eliminating the time dimensions.
    """
    markov_factors = []
    for time_var, log_factors in time_to_factors.items():
        prev_vars = time_to_init_vars[time_var]
        # remove `_init/` prefix to convert prev to curr
        prev_to_curr = {k: "/".join(k.split("/")[1:]) for k in prev_vars}
        # eliminate all plate and enum dimensions not available at markov sites
        eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var]
        with funsor.interpreter.interpretation(funsor.terms.lazy):
            lazy_result = funsor.sum_product.sum_product(
                funsor.ops.logaddexp, funsor.ops.add, log_factors,
                eliminate=eliminate_vars, plates=prod_vars)
        trans = funsor.optimizer.apply_optimizer(lazy_result)
        markov_factors.append(funsor.sum_product.sequential_sum_product(
            funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr))
    return markov_factors
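
# A small, plain-Python illustration of the naming convention that
# compute_markov_factors relies on (the site names here are illustrative):
# previous-step markov sites are recorded under an `_init/` prefix, and
# dropping the first path component recovers the current-step site name.
def _example_prev_to_curr():
    prev_vars = frozenset({"_init/x", "_init/y"})
    prev_to_curr = {k: "/".join(k.split("/")[1:]) for k in prev_vars}
    assert prev_to_curr == {"_init/x": "x", "_init/y": "y"}
    return prev_to_curr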
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using the
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _init/... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob = funsor.to_funsor(log_prob, output=funsor.reals(), dim_to_name=dim_to_name)

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.domains.bint(site["value"].shape[dim]))
                    time_to_factors[time_dim].append(log_prob)
                    time_to_init_vars[time_dim] |= frozenset(
                        s for s in dim_to_name.values() if s.startswith("_init"))
                    break
            if time_dim is None:
                log_factors.append(log_prob)

            if not site['is_observed']:
                sum_vars |= frozenset({site['name']})
            prod_vars |= frozenset(f.name for f in site['cond_indep_stack'] if f.dim is not None)

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = "/".join(var.split("/")[1:])
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. _init (i.e. prev) is in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(name for name in dim_to_name.values())

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(time_to_factors, time_to_init_vars,
                                                time_to_markov_dims, sum_vars, prod_vars)
        log_factors = log_factors + markov_factors

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        lazy_result = funsor.sum_product.sum_product(
            funsor.ops.logaddexp, funsor.ops.add, log_factors,
            eliminate=sum_vars | prod_vars, plates=prod_vars)
    result = funsor.optimizer.apply_optimizer(lazy_result)
    return result.data, model_trace
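
# A minimal end-to-end sketch, not part of this module: `_example_log_density` and
# its toy model are illustrative, and the sketch assumes a single plate so that
# enumeration can start at dim -2. The discrete site `z` is enumerated and
# marginalized by funsor, so `params` only needs a value for the continuous site `p`.
def _example_log_density():
    import jax.numpy as jnp

    import numpyro.distributions as dist
    from numpyro.contrib.funsor.enum_messenger import enum

    def toy_model(data):
        p = numpyro.sample("p", dist.Beta(1.0, 1.0))
        with numpyro.plate("data", data.shape[0]):
            z = numpyro.sample("z", dist.Bernoulli(p),
                               infer={"enumerate": "parallel"})
            numpyro.sample("obs", dist.Normal(z.astype(float), 1.0), obs=data)

    data = jnp.array([0.1, 1.2, -0.3])
    enum_model = enum(toy_model, first_available_dim=-2)
    log_joint, model_trace = log_density(enum_model, (data,), {}, {"p": jnp.array(0.5)})
    return log_joint, model_trace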