# Funsor-based NumPyro

## Effect handlers

class enum(fn=None, first_available_dim=None)

Bases: numpyro.contrib.funsor.enum_messenger.BaseEnumMessenger

Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.

Parameters:

- fn (callable) – Python callable with NumPyro primitives.
- first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by NumPyro. This should be a negative integer or None.

process_message(msg)
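As a rough illustration of what first_available_dim controls, the sketch below computes the shape that an enumerated site's values would take if its support is laid out along a single fresh dimension. The helper `enumerated_shape` is hypothetical and not part of NumPyro; it assumes one enumerated site with no event dimensions.

```python
def enumerated_shape(cardinality, dim):
    """Shape of the enumerated values when a support of size
    `cardinality` is placed at negative dimension `dim`.

    Hypothetical helper for illustration only; not a NumPyro function.
    """
    if dim >= 0:
        raise ValueError("dim must be a negative integer")
    # The support occupies position `dim` counting from the right;
    # every dimension to its right is a broadcastable size-1 dimension.
    return (cardinality,) + (1,) * (-dim - 1)


# A 3-valued site enumerated at dim -2 broadcasts against one batch
# dimension on its right.
shape = enumerated_shape(3, -2)
```

Under these assumptions, choosing a more negative first_available_dim leaves more dimensions on the right free for batch (plate) dimensions.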
class infer_config(fn, config_fn)

Bases: numpyro.primitives.Messenger

Given a callable fn that contains NumPyro primitive calls and a callable config_fn that takes a trace site and returns a dictionary, updates the value of the infer kwarg at each sample site to config_fn(site).

Parameters:

- fn – a stochastic function (callable containing NumPyro primitive calls).
- config_fn – a callable taking a site and returning an infer dict.

process_message(msg)
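A config_fn is an ordinary function from a site to an infer dictionary. The sketch below shows one possible shape for such a function, exercised on hand-built stand-ins for trace sites. The name `enumerate_discrete` is hypothetical, and the site keys used here ("type", "is_observed", "infer") are an assumption about the site layout, not something this page specifies.

```python
def enumerate_discrete(site):
    """Hypothetical config_fn: request parallel enumeration for
    unobserved sample sites and leave every other site unchanged."""
    if site["type"] == "sample" and not site.get("is_observed", False):
        # Copy the existing infer dict and add the enumeration marker.
        return {**site["infer"], "enumerate": "parallel"}
    return site["infer"]


# Hand-built stand-ins for trace sites (assumed dict layout).
latent = {"type": "sample", "name": "z", "is_observed": False, "infer": {}}
observed = {"type": "sample", "name": "x", "is_observed": True, "infer": {}}

latent_infer = enumerate_discrete(latent)
observed_infer = enumerate_discrete(observed)
```

A function of this form would be passed as the second argument, as in infer_config(model, enumerate_discrete).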
class markov

Markov dependency declaration.

This is a statistical equivalent of a memory management arena.

Parameters:

- fn (callable) – Python callable with NumPyro primitives.
- history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to numpyro.primitives.plate.
- keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).
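The meaning of history can be pictured with a small plain-Python sketch that lists which earlier steps of a sequential loop remain visible from step t. The helper `visible_steps` is hypothetical and only mirrors the description above; it does not use or reproduce the handler itself.

```python
def visible_steps(t, history=1):
    """Indices of earlier steps visible from step `t` when only the
    `history` most recent contexts stay in scope (illustration only)."""
    return list(range(max(0, t - history), t))


first_order = visible_steps(5, history=1)   # only the previous step
second_order = visible_steps(5, history=2)  # the two previous steps
independent = visible_steps(5, history=0)   # no earlier step is visible
```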
class plate(name, size, subsample_size=None, dim=None)

Bases: numpyro.contrib.funsor.enum_messenger.GlobalNamedMessenger

An alternative implementation of the numpyro.primitives.plate primitive. Note that only this version is compatible with enumeration.

There is also a context manager plate_to_enum_plate() which converts numpyro.plate statements to this version.

Parameters:

- name (str) – Name of the plate.
- size (int) – Size of the plate.
- subsample_size (int) – Optional argument denoting the size of the mini-batch. Inference algorithms can use this to apply a scaling factor, for example when computing the ELBO on a mini-batch.
- dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the leftmost available dim is allocated.

process_message(msg)

postprocess_message(msg)
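The scaling that subsample_size makes possible can be sketched in plain Python. The helper `scaled_log_likelihood` is hypothetical, and the factor size / subsample_size is an assumption here (it is the usual mini-batch estimate of a full-data sum); this page does not state the exact factor the inference algorithms apply.

```python
import math


def scaled_log_likelihood(batch_log_probs, size):
    """Estimate the full-data log-likelihood from a mini-batch by
    multiplying the batch sum by size / subsample_size."""
    subsample_size = len(batch_log_probs)
    scale = size / subsample_size  # assumed scaling factor
    return scale * sum(batch_log_probs)


# Two observed points stand in for a data set of ten.
batch = [math.log(0.5), math.log(0.25)]
estimate = scaled_log_likelihood(batch, size=10)
```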
class to_data

A primitive to extract a Python object from a Funsor.

Parameters:

- x (Funsor) – A funsor object.
- name_to_dim (OrderedDict) – An optional inputs hint which maps dimension names from x to dimension positions of the returned value.
- dim_type (int) – Either 0, 1, or 2. This optional argument indicates a dimension should be treated as ‘local’, ‘global’, or ‘visible’, which can be used to interact with the global DimStack.

Returns: A non-funsor equivalent to x.
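To make the role of name_to_dim concrete, the following sketch computes the batch shape that results when each named input is placed at the negative position given by the mapping. The helper `data_shape` is hypothetical; it assumes unused positions are filled with size 1 and ignores the funsor's output (event) shape.

```python
from collections import OrderedDict


def data_shape(input_sizes, name_to_dim):
    """Batch shape obtained when each named input is placed at the
    negative position given by name_to_dim (illustration only)."""
    if not name_to_dim:
        return ()
    ndim = -min(name_to_dim.values())
    shape = [1] * ndim  # positions without a named input get size 1
    for name, dim in name_to_dim.items():
        shape[dim] = input_sizes[name]
    return tuple(shape)


sizes = OrderedDict(a=2, b=3)
adjacent = data_shape(sizes, OrderedDict(a=-2, b=-1))
padded = data_shape(sizes, OrderedDict(a=-3, b=-1))
```

The dim_to_name argument of to_funsor describes the same correspondence in the opposite direction, from negative positions back to names.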
class to_funsor

A primitive to convert a Python object to a Funsor.

Parameters:

- x – An object.
- output (funsor.domains.Domain) – An optional output hint to uniquely convert data to a Funsor (e.g. when x is a string).
- dim_to_name (OrderedDict) – An optional mapping from negative batch dimensions to name strings.
- dim_type (int) – Either 0, 1, or 2. This optional argument indicates a dimension should be treated as ‘local’, ‘global’, or ‘visible’, which can be used to interact with the global DimStack.

Returns: A Funsor equivalent to x.

Return type: funsor.terms.Funsor

class trace(fn=None)

This version of the trace handler records the information necessary to do packing after execution.

Each sample site is annotated with a “dim_to_name” dictionary, which can be passed directly to to_funsor().

postprocess_message(msg)

## Inference Utilities

config_enumerate(fn, default='parallel')

Configures enumeration for all relevant sites in a NumPyro model.

When configuring for exhaustive enumeration of discrete variables, this configures all sample sites whose distribution satisfies .has_enumerate_support == True.

This can be used as either a function:

model = config_enumerate(model)


or as a decorator:

@config_enumerate
def model(*args, **kwargs):
    ...


Note

Currently, only default='parallel' is supported.

Parameters:

- fn (callable) – Python callable with NumPyro primitives.
- default (str) – Which enumerate strategy to use, one of “sequential”, “parallel”, or None. Defaults to “parallel”.

log_density(model, model_args, model_kwargs, params)

Similar to numpyro.infer.util.log_density() but works for models with discrete latent variables. Internally, this uses funsor to marginalize discrete latent sites and evaluate the joint log probability.

Parameters:

- model – Python callable containing NumPyro primitives. Typically, the model has been enumerated by using the enum handler:

      def model(*args, **kwargs):
          ...

      log_joint, model_trace = log_density(enum(config_enumerate(model)), args, kwargs, params)

- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- params (dict) – dictionary of current parameter values keyed by site name.

Returns: log of joint density and a corresponding model trace.
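What marginalizing a discrete latent site means numerically can be shown without the library. The sketch below sums a two-component Gaussian mixture over its discrete assignment with a log-sum-exp. All function names are hypothetical, and the code only illustrates the quantity being computed, not how funsor computes it.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def normal_logpdf(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def marginal_log_density(x, weights, locs, scale=1.0):
    # Joint log p(z = k, x) for every value k of the discrete latent z.
    joint = [
        math.log(w) + normal_logpdf(x, loc, scale)
        for w, loc in zip(weights, locs)
    ]
    # Summing out z in log space gives log p(x).
    return logsumexp(joint)


log_joint = marginal_log_density(0.0, weights=[0.5, 0.5], locs=[-1.0, 1.0])
```

Because x = 0 is equidistant from both component means, the result equals the log density of a single unit-variance Gaussian evaluated one unit away from its mean.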
plate_to_enum_plate()

A context manager that replaces numpyro.plate statements with funsor-based plates.

This is useful when doing inference for ordinary NumPyro programs that use numpyro.plate statements. For example, to get the trace of a model whose discrete latent sites are enumerated, we can use:

enum_model = numpyro.contrib.funsor.enum(model)
with plate_to_enum_plate():
    model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
        *model_args, **model_kwargs)