Source code for numpyro.contrib.funsor.infer_util

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, defaultdict
from contextlib import contextmanager
import functools
import re

import funsor
import numpyro
from numpyro.contrib.funsor.enum_messenger import (
    infer_config,
    plate as enum_plate,
    trace as packed_trace,
)
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute

funsor.set_backend("jax")


@contextmanager
def plate_to_enum_plate():
    """
    A context manager to replace `numpyro.plate` statements with a funsor-based
    :class:`~numpyro.contrib.funsor.enum_messenger.plate`. This is useful when
    doing inference for the usual NumPyro programs with `numpyro.plate`
    statements. For example, to get the trace of a `model` whose discrete
    latent sites are enumerated, we can use::

        enum_model = numpyro.contrib.funsor.enum(model)
        with plate_to_enum_plate():
            model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
                *model_args, **model_kwargs)
    """
    try:
        numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs)
        yield
    finally:
        numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)
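A runnable sketch of the recipe above (the model, data, and `first_available_dim` value are illustrative, not part of this module)::

    import jax.numpy as jnp
    import jax.random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.funsor import enum, trace as enum_trace
    from numpyro.handlers import seed

    def model():
        # a discrete latent inside a plate
        with numpyro.plate("data", 3):
            z = numpyro.sample("z", dist.Bernoulli(0.5))
            numpyro.sample("x", dist.Normal(z, 1.0), obs=jnp.zeros(3))

    enum_model = enum(config_enumerate(model), first_available_dim=-2)
    with plate_to_enum_plate():  # numpyro.plate now creates a funsor-based plate
        model_trace = enum_trace(seed(enum_model, jax.random.PRNGKey(0))).get_trace()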
def _config_enumerate_fn(site, default):
    """helper function used internally in config_enumerate"""
    if (
        site["type"] == "sample"
        and (not site["is_observed"])
        and site["fn"].has_enumerate_support
    ):
        return {"enumerate": site["infer"].get("enumerate", default)}
    return {}
def config_enumerate(fn=None, default="parallel"):
    """
    Configures enumeration for all relevant sites in a NumPyro model.

    When configuring for exhaustive enumeration of discrete variables, this
    configures all sample sites whose distribution satisfies
    ``.has_enumerate_support == True``.

    This can be used as either a function::

        model = config_enumerate(model)

    or as a decorator::

        @config_enumerate
        def model(*args, **kwargs):
            ...

    .. note:: Currently, only ``default='parallel'`` is supported.

    :param callable fn: Python callable with NumPyro primitives.
    :param str default: Which enumerate strategy to use, one of "sequential",
        "parallel", or None. Defaults to "parallel".
    """
    if fn is None:  # support use as a decorator
        return functools.partial(config_enumerate, default=default)
    return infer_config(fn, functools.partial(_config_enumerate_fn, default=default))
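For example (a hypothetical model; only the discrete, non-observed site receives the flag)::

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    @config_enumerate
    def model():
        # discrete with enumerate support -> infer={"enumerate": "parallel"}
        z = numpyro.sample("z", dist.Categorical(jnp.ones(3) / 3))
        # continuous -> no enumerate flag is added
        numpyro.sample("x", dist.Normal(z, 1.0))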
def _config_kl_fn(site, sites):
    """helper function used internally in config_kl"""
    if (
        site["type"] == "sample"
        and (not site["is_observed"])
        and (sites is None or site["name"] in sites)
    ):
        return {"kl": site["infer"].get("kl", "analytic")}
    return {}


def config_kl(fn=None, sites=None):
    """
    Configures the ``kl`` flag in the ``infer`` dict for all sample sites in
    ``sites`` in a NumPyro model. If ``kl`` is ``analytic`` and
    ``TraceEnum_ELBO`` is being used, an attempt is made to compute the KL
    divergence in the ELBO analytically at the corresponding site. If
    ``analytic`` is specified and the KL divergence cannot be computed
    analytically, an error is raised.

    This can be used as either a function::

        model = config_kl(model)

    or as a decorator::

        @config_kl
        def model(*args, **kwargs):
            ...

    :param callable fn: Python callable with NumPyro primitives.
    :param set sites: Sites for which to use the analytic KL solution. If
        ``None``, all sites are set to analytic.
    """
    if fn is None:  # support use as a decorator
        return functools.partial(config_kl, sites=sites)
    return infer_config(fn, functools.partial(_config_kl_fn, sites=sites))


def _get_shift(name):
    """helper function used internally in sarkka_bilmes_product"""
    # count the leading "_PREV_" prefixes; each prefix is 6 characters long
    return len(re.search(r"^(_PREV_)*", name).group(0)) // 6


def _shift_name(name, t):
    """helper function used internally in sarkka_bilmes_product"""
    if t >= 0:
        return t * "_PREV_" + name
    return name.replace("_PREV_" * -t, "", 1)


def compute_markov_factors(
    time_to_factors,
    time_to_init_vars,
    time_to_markov_dims,
    sum_vars,
    prod_vars,
    history,
    sum_op,
    prod_op,
):
    """
    :param dict time_to_factors: a map from time variable to the log prob factors.
    :param dict time_to_init_vars: a map from time variable to init discrete sites.
    :param dict time_to_markov_dims: a map from time variable to dimensions at
        markov sites (discrete sites that depend on previous steps).
    :param frozenset sum_vars: all plate and enum dimensions in the trace.
    :param frozenset prod_vars: all plate dimensions in the trace.
    :param int history: The number of previous contexts visible from the
        current context.
    :returns: a list of factors after eliminating time dimensions.
    """
    markov_factors = []
    for time_var, log_factors in time_to_factors.items():
        prev_vars = time_to_init_vars[time_var]

        # we eliminate all plate and enum dimensions not available at markov sites.
        eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var]
        with funsor.interpretations.lazy:
            lazy_result = funsor.sum_product.sum_product(
                sum_op,
                prod_op,
                log_factors,
                eliminate=eliminate_vars,
                plates=prod_vars,
            )
        trans = funsor.optimizer.apply_optimizer(lazy_result)
        if history > 1:
            global_vars = frozenset(
                set(trans.inputs)
                - {time_var.name}
                - prev_vars
                - {_shift_name(k, -_get_shift(k)) for k in prev_vars}
            )
            markov_factors.append(
                funsor.sum_product.sarkka_bilmes_product(
                    sum_op, prod_op, trans, time_var, global_vars
                )
            )
        else:
            # remove `_PREV_` prefix to convert prev to curr
            prev_to_curr = {k: _shift_name(k, -_get_shift(k)) for k in prev_vars}
            markov_factors.append(
                funsor.sum_product.sequential_sum_product(
                    sum_op, prod_op, trans, time_var, prev_to_curr
                )
            )
    return markov_factors


def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
    """Helper function to compute the ELBO and extract its components from execution traces."""
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    log_factors = []
    time_to_factors = defaultdict(list)  # log prob factors
    time_to_init_vars = defaultdict(frozenset)  # _PREV_... variables
    time_to_markov_dims = defaultdict(frozenset)  # dimensions at markov sites
    sum_vars, prod_vars = frozenset(), frozenset()
    history = 0
    log_measures = {}
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            if all(dim == 1 for dim in log_prob.shape) and dim_to_name == OrderedDict():
                log_prob = log_prob.squeeze()
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            time_dim = None
            for dim, name in dim_to_name.items():
                if name.startswith("_time"):
                    time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                    history = max(
                        history,
                        max(_get_shift(s) for s in dim_to_name.values()),
                    )
                    if history == 0:
                        log_factors.append(log_prob_factor)
                        prod_vars |= frozenset({name})
                    else:
                        time_to_factors[time_dim].append(log_prob_factor)
                        time_to_init_vars[time_dim] |= frozenset(
                            s for s in dim_to_name.values() if s.startswith("_PREV_")
                        )
                    break

            if time_dim is None:
                log_factors.append(log_prob_factor)

            if not site["is_observed"]:
                log_measures[site["name"]] = log_prob_factor
                sum_vars |= frozenset({site["name"]})

            prod_vars |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.dim is not None
            )

    for time_dim, init_vars in time_to_init_vars.items():
        for var in init_vars:
            curr_var = _shift_name(var, -_get_shift(var))
            dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"]
            if var in dim_to_name.values():  # i.e. the _PREV_-prefixed (prev) name is in dim_to_name
                time_to_markov_dims[time_dim] |= frozenset(
                    name for name in dim_to_name.values()
                )

    if len(time_to_factors) > 0:
        markov_factors = compute_markov_factors(
            time_to_factors,
            time_to_init_vars,
            time_to_markov_dims,
            sum_vars,
            prod_vars,
            history,
            sum_op,
            prod_op,
        )
        log_factors = log_factors + markov_factors

    with funsor.interpretations.lazy:
        lazy_result = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            log_factors,
            eliminate=sum_vars | prod_vars,
            plates=prod_vars,
        )
    result = funsor.optimizer.apply_optimizer(lazy_result)
    if len(result.inputs) > 0:
        raise ValueError(
            "Expected the joint log density to be a scalar, but got {}. "
            "There seems to be something wrong at the following sites: {}.".format(
                result.data.shape,
                {k.split("__BOUND")[0] for k in result.inputs},
            )
        )
    return result, model_trace, log_measures
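A quick illustration of the `_PREV_` name-shifting helpers defined above (a small sketch; the names are arbitrary)::

    assert _get_shift("_PREV__PREV_x") == 2          # two "_PREV_" prefixes
    assert _shift_name("x", 2) == "_PREV__PREV_x"    # shift two steps back in time
    assert _shift_name("_PREV__PREV_x", -2) == "x"   # shift back to the current step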
def log_density(model, model_args, model_kwargs, params):
    """
    Similar to :func:`numpyro.infer.util.log_density` but works for models
    with discrete latent variables. Internally, this uses :mod:`funsor`
    to marginalize discrete latent sites and evaluate the joint log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using the
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    result, model_trace, _ = _enum_log_density(
        model,
        model_args,
        model_kwargs,
        params,
        funsor.ops.logaddexp,
        funsor.ops.add,
    )
    return result.data, model_trace
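A minimal end-to-end sketch (the mixture model and values are illustrative): the Bernoulli latent below is enumerated and marginalized, so the result equals the logsumexp over ``z`` of ``log p(z) + log p(obs | z)``::

    import jax.numpy as jnp
    import jax.random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.funsor import enum
    from numpyro.handlers import seed

    def model():
        z = numpyro.sample("z", dist.Bernoulli(0.3))
        numpyro.sample("obs", dist.Normal(2.0 * z, 1.0), obs=jnp.array(1.0))

    enum_model = enum(config_enumerate(model), first_available_dim=-1)
    log_joint, model_trace = log_density(
        seed(enum_model, jax.random.PRNGKey(0)), (), {}, {}
    )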