Funsor-based NumPyro
Effect handlers
-
class enum(fn=None, first_available_dim=None)
Bases: numpyro.contrib.funsor.enum_messenger.BaseEnumMessenger
Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.
Parameters:
- fn (callable) – Python callable with NumPyro primitives.
- first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by NumPyro. This should be a negative integer or None.
-
class infer_config(fn=None, config_fn=None)
Bases: numpyro.primitives.Messenger
Given a callable fn that contains NumPyro primitive calls and a callable config_fn taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site).
Parameters:
- fn – a stochastic function (callable containing NumPyro primitive calls)
- config_fn – a callable taking a site and returning an infer dict
-
markov(fn=None, history=1, keep=False)
Markov dependency declaration.
This is a statistical equivalent of a memory management arena.
Parameters:
- fn (callable) – Python callable with NumPyro primitives.
- history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to numpyro.primitives.plate.
- keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).
-
class plate(name, size, subsample_size=None, dim=None)
Bases: numpyro.contrib.funsor.enum_messenger.GlobalNamedMessenger
An alternative implementation of the numpyro.primitives.plate primitive. Note that only this version is compatible with enumeration.
There is also a context manager plate_to_enum_plate() which converts numpyro.plate statements to this version.
Parameters:
- name (str) – Name of the plate.
- size (int) – Size of the plate.
- subsample_size (int) – Optional argument denoting the size of the mini-batch. This can be used to apply a scaling factor by inference algorithms, e.g. when computing ELBO using a mini-batch.
- dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the leftmost available dim is allocated.
-
to_data(x, name_to_dim=None, dim_type=<DimType.LOCAL: 0>)
A primitive to extract a Python object from a Funsor.
Parameters:
- x (Funsor) – A funsor object.
- name_to_dim (OrderedDict) – An optional inputs hint which maps dimension names from x to dimension positions of the returned value.
- dim_type (int) – Either 0, 1, or 2. This optional argument indicates whether a dimension should be treated as ‘local’, ‘global’, or ‘visible’, which can be used to interact with the global DimStack.
Returns: A non-funsor equivalent to x.
-
to_funsor(x, output=None, dim_to_name=None, dim_type=<DimType.LOCAL: 0>)
A primitive to convert a Python object to a Funsor.
Parameters:
- x – An object.
- output (funsor.domains.Domain) – An optional output hint to uniquely convert data to a Funsor (e.g. when x is a string).
- dim_to_name (OrderedDict) – An optional mapping from negative batch dimensions to name strings.
- dim_type (int) – Either 0, 1, or 2. This optional argument indicates whether a dimension should be treated as ‘local’, ‘global’, or ‘visible’, which can be used to interact with the global DimStack.
Returns: A Funsor equivalent to x.
Return type: funsor.terms.Funsor
-
class trace(fn=None)
Bases: numpyro.handlers.trace
This version of the trace handler records information necessary to do packing after execution.
Each sample site is annotated with a “dim_to_name” dictionary, which can be passed directly to to_funsor().
Inference Utilities
-
config_enumerate(fn=None, default='parallel')
Configures enumeration for all relevant sites in a NumPyro model.
When configuring for exhaustive enumeration of discrete variables, this configures all sample sites whose distribution satisfies .has_enumerate_support == True.
This can be used as either a function:
    model = config_enumerate(model)
or as a decorator:
    @config_enumerate
    def model(*args, **kwargs):
        ...
Note
Currently, only default='parallel' is supported.
Parameters:
- fn (callable) – Python callable with NumPyro primitives.
- default (str) – Which enumerate strategy to use, one of “sequential”, “parallel”, or None. Defaults to “parallel”.
-
infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None)
A handler that samples discrete sites marked with site["infer"]["enumerate"] = "parallel" from the posterior, conditioned on observations.
Example:
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.funsor import config_enumerate, infer_discrete, markov

    @infer_discrete(first_available_dim=-1, temperature=0)
    @config_enumerate
    def viterbi_decoder(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim)
        means = jnp.arange(float(hidden_dim))
        states = [0]
        for t in markov(range(len(data))):
            states.append(numpyro.sample("states_{}".format(t),
                                         dist.Categorical(transition[states[-1]])))
            numpyro.sample("obs_{}".format(t),
                           dist.Normal(means[states[-1]], 1.),
                           obs=data[t])
        return states  # returns maximum likelihood states
Parameters:
- fn – a stochastic function (callable containing NumPyro primitive calls)
- first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by NumPyro. This should be a negative integer.
- temperature (int) – Either 1 (sample via forward-filter backward-sample) or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample).
- rng_key (jax.random.PRNGKey) – a random number generator key, to be used when temperature=1 or first_available_dim is None.
-
log_density(model, model_args, model_kwargs, params)
Similar to numpyro.infer.util.log_density() but works for models with discrete latent variables. Internally, this uses funsor to marginalize discrete latent sites and evaluate the joint log probability.
Parameters:
- model – Python callable containing NumPyro primitives. Typically, the model has been enumerated by using the enum handler:
    def model(*args, **kwargs):
        ...

    log_joint, model_trace = log_density(enum(config_enumerate(model)), args, kwargs, params)
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- params (dict) – dictionary of current parameter values keyed by site name.
Returns: log of joint density and a corresponding model trace
-
plate_to_enum_plate()
A context manager that replaces numpyro.plate statements with the funsor-based plate.
This is useful when doing inference for the usual NumPyro programs with numpyro.plate statements. For example, to get the trace of a model whose discrete latent sites are enumerated, we can use:
    enum_model = numpyro.contrib.funsor.enum(model)
    with plate_to_enum_plate():
        model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
            *model_args, **model_kwargs)