# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import OrderedDict, defaultdict
import functools
from jax import random
import jax.numpy as jnp
import funsor
from numpyro.contrib.funsor.enum_messenger import enum
from numpyro.contrib.funsor.infer_util import _enum_log_density, _get_shift, _shift_name
from numpyro.handlers import block, seed, substitute, trace
from numpyro.infer.util import _guess_max_plate_nesting
@functools.singledispatch
def _get_support_value(funsor_dist, name, **kwargs):
raise ValueError(
"Could not extract point from {} at name {}".format(funsor_dist, name)
)
@_get_support_value.register(funsor.cnf.Contraction)
def _get_support_value_contraction(funsor_dist, name, **kwargs):
    """
    Extract the value of ``name`` from a funsor ``Contraction``.

    Exactly one of the contraction's terms is expected (via ``assert``) to be
    a ``funsor.delta.Delta`` whose ``fresh`` variables include ``name``;
    extraction is then re-dispatched on that single term.
    """
    matching = []
    for term in funsor_dist.terms:
        is_delta = isinstance(term, funsor.delta.Delta)
        # short-circuit keeps ``.fresh`` from being read on non-Delta terms
        if is_delta and name in term.fresh:
            matching.append(term)
    assert len(matching) == 1
    only_delta = matching[0]
    return _get_support_value(only_delta, name, **kwargs)
@_get_support_value.register(funsor.delta.Delta)
def _get_support_value_delta(funsor_dist, name, **kwargs):
    """
    Extract the value of ``name`` from a funsor ``Delta``.

    ``name`` must be one of the Delta's ``fresh`` variables (checked with an
    ``assert``). The Delta's ``terms`` are looked up by name and the first
    element of the matching entry is returned.
    """
    assert name in funsor_dist.fresh
    entries_by_name = OrderedDict(funsor_dist.terms)
    # NOTE(review): presumably each entry is a (point, log_density) pair in
    # funsor; only element 0 (the point) is used here.
    entry = entries_by_name[name]
    return entry[0]
def _sample_posterior(
    model, first_available_dim, temperature, rng_key, *args, **kwargs
):
    """
    Sample (``temperature=1``) or maximize (``temperature=0``) the
    parallel-enumerated discrete sites of ``model``.

    The model is run once under :func:`enum` inside a funsor
    ``AdjointTape``. The tape's adjoint is then evaluated under an
    approximation (``argmax_approximate`` for temperature 0, ``MonteCarlo``
    for temperature 1), and a point value is read back for every sample site
    whose ``infer["enumerate"]`` is ``"parallel"``.

    :param model: a stochastic function containing NumPyro primitive calls.
    :param first_available_dim: the first (negative) tensor dim available for
        enumeration. When None it is guessed from a seeded trace of ``model``.
    :param int temperature: 0 (max / MAP) or 1 (sample).
    :param rng_key: JAX PRNG key. It is split when ``temperature == 1`` and is
        used to seed ``model`` when ``first_available_dim is None``.
    :param args: positional arguments forwarded to ``model``.
    :param kwargs: keyword arguments forwarded to ``model``.
    :return: dict mapping every sample site name to its value. Sites whose
        names start with ``_PREV_`` are merged into their root site.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        # sub_key drives the Monte Carlo approximation; rng_key is reused below
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    # Unenumerated trace of the model. Its plate nesting and site shapes are
    # used below. It stays None when first_available_dim was supplied, and is
    # then only built on demand for the _PREV_ concatenation step.
    model_trace = None
    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op
            )

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["infer"].get("enumerate") == "parallel":
            log_measure = approx_factors[log_measures[name]]
            value = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"]
            )

    data = {
        name: site["value"]
        for name, site in sample_tr.items()
        if site["type"] == "sample"
    }

    # group _PREV_-shifted sites under their root (unshifted) name
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        # longest name (most _PREV_ prefixes) first; presumably the earliest
        # time step comes first
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    if time_vars and model_trace is None:
        # Previously this raised NameError when first_available_dim was given.
        # The trace is only used for site shapes here, which presumably do
        # not depend on the key, so fall back to a fixed key if none was
        # supplied (allowed when temperature == 0).
        shape_key = random.PRNGKey(0) if rng_key is None else rng_key
        with block():
            model_trace = trace(seed(model, shape_key)).get_trace(*args, **kwargs)

    for root_name, var_names in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(var_name) for var_name in var_names]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data
def infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None):
    """
    A handler that samples discrete sites marked with
    ``site["infer"]["enumerate"] = "parallel"`` from the posterior,
    conditioned on observations.

    Example::

        @infer_discrete(first_available_dim=-1, temperature=0)
        @config_enumerate
        def viterbi_decoder(data, hidden_dim=10):
            transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
            means = jnp.arange(float(hidden_dim))
            states = [0]
            for t in markov(range(len(data))):
                states.append(numpyro.sample("states_{}".format(t),
                                             dist.Categorical(transition[states[-1]])))
                numpyro.sample("obs_{}".format(t),
                               dist.Normal(means[states[-1]], 1.),
                               obs=data[t])
            return states  # returns maximum likelihood states

    .. warning:: This does not yet support
        :func:`numpyro.contrib.control_flow.scan` primitive.

    .. warning:: The ``log_prob`` values of the inferred model's trace are not
        meaningful, and may be changed in a future release.

    :param fn: a stochastic function (callable containing NumPyro primitive calls)
    :param int first_available_dim: The first tensor dimension (counting
        from the right) that is available for parallel enumeration. This
        dimension and all dimensions left may be used internally by NumPyro.
        This should be a negative integer.
    :param int temperature: Either 1 (sample via forward-filter backward-sample)
        or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample).
    :param jax.random.PRNGKey rng_key: a random number generator key, to be used in
        cases ``temperature=1`` or ``first_available_dim is None``.
    :raises ValueError: if ``rng_key`` is None but ``temperature == 1`` or
        ``first_available_dim is None``.
    """
    # explicit check instead of ``assert`` so it survives ``python -O``
    if (temperature == 1 or first_available_dim is None) and rng_key is None:
        raise ValueError(
            "rng_key must be provided when temperature == 1 "
            "or first_available_dim is None"
        )
    if fn is None:  # support use as a decorator
        return functools.partial(
            infer_discrete,
            first_available_dim=first_available_dim,
            temperature=temperature,
            rng_key=rng_key,
        )

    @functools.wraps(fn)
    def wrap_fn(*args, **kwargs):
        samples = _sample_posterior(
            fn, first_available_dim, temperature, rng_key, *args, **kwargs
        )
        # replay the model with the inferred discrete values substituted in
        with substitute(data=samples):
            return fn(*args, **kwargs)

    return wrap_fn