Runtime Utilities
enable_validation
- enable_validation(is_validate=True)
Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors, e.g. NaN checks and validation of distribution arguments and support values, which are useful for debugging.
Note
This utility does not take effect under JAX’s JIT compilation or vectorized transformation jax.vmap().
- Parameters:
is_validate (bool) – whether to enable validation checks.
validation_enabled
enable_x64
set_platform
set_host_device_count
- set_host_device_count(n)
By default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use. As a consequence, this allows parallel mapping in JAX, jax.pmap(), to work on the CPU platform.
Note
This utility only takes effect at the beginning of your program. Under the hood, this sets the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices], where [num_devices] is the desired number of CPU devices n.
Warning
Our understanding of the side effects of using the xla_force_host_platform_device_count flag in XLA is incomplete. If you observe unexpected behavior when using this utility, please let us know through our issue tracker or forum. More information is available in this JAX issue.
- Parameters:
n (int) – number of CPU devices to use.
Inference Utilities
Predictive
- class Predictive(model: Callable, posterior_samples: dict | None = None, *, guide: Callable | None = None, params: dict | None = None, num_samples: int | None = None, return_sites: list[str] | None = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: int | None = None, exclude_deterministic: bool = True)
Bases: object
This class is used to construct a predictive distribution. The predictive distribution is obtained by running the model conditioned on latent samples from posterior_samples.
Warning
The interface for the Predictive class is experimental, and might change in the future.
Note that for the predictive distribution to be returned as intended, observed variables in the model (constraining the likelihood term) must be set to None (see Example).
- Parameters:
model – Python callable containing Pyro primitives.
posterior_samples (dict) – dictionary of samples from the posterior.
guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.
params (dict) – dictionary of values for param sites of model/guide.
num_samples (int) – number of samples
return_sites (list) – sites to return; by default only sample sites not present in posterior_samples are returned.
infer_discrete (bool) – whether or not to sample discrete sites from the posterior, conditioned on observations and other latent values in posterior_samples. Under the hood, those sites will be marked with site["infer"]["enumerate"] = "parallel". See how infer_discrete works at the Pyro enumeration tutorial. Note that this requires funsor to be installed.
parallel (bool) – whether to predict in parallel using JAX vectorized map jax.vmap(). Defaults to False.
batch_ndims –
the number of batch dimensions in posterior samples or parameters. If None, defaults to 0 if guide is set (i.e. not None) and 1 otherwise. Usages for batched posterior samples:
set batch_ndims=0 to get prediction for a single sample
set batch_ndims=1 to get prediction for posterior_samples with shapes (num_samples x …) (same as batch_ndims=None with guide=None)
set batch_ndims=2 to get prediction for posterior_samples with shapes (num_chains x N x …). Note that if the num_samples argument is not None, its value should be equal to num_chains x N.
Usages for batched parameters:
set batch_ndims=0 to get 1 sample from the guide and parameters (same as batch_ndims=None with guide)
set batch_ndims=1 to get predictions from a one-dimensional batch of the guide and parameters with shapes (num_samples x batch_size x …)
exclude_deterministic – indicates whether to ignore deterministic sites from the posterior samples.
- Returns:
dict of samples from the predictive distribution.
Example:
Given a model:
    def model(X, y=None):
        ...
        return numpyro.sample("obs", likelihood, obs=y)
you can sample from the prior predictive:
    predictive = Predictive(model, num_samples=1000)
    y_pred = predictive(rng_key, X)["obs"]
Note how above, no value for y is passed to predictive, resulting in y being set to None. Setting the observed variable(s) to None when using Predictive is required for the method to function as expected.
If you also have posterior samples, you can sample from the posterior predictive:
    predictive = Predictive(model, posterior_samples=posterior_samples)
    y_pred = predictive(rng_key, X)["obs"]
See the docstrings for SVI and MCMCKernel for example code of this in context.
log_density
get_transforms
transform_fn
- transform_fn(transforms, params, invert=False)
(EXPERIMENTAL INTERFACE) Callable that applies a transformation from the transforms dict to values in the params dict and returns the transformed values keyed on the same names.
- Parameters:
transforms – Dictionary of transforms keyed by names. Names in transforms and params should align.
params – Dictionary of arrays keyed by names.
invert – Whether to apply the inverse of the transforms.
- Returns:
dict of transformed params.
constrain_fn
- constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False)
(EXPERIMENTAL INTERFACE) Gets the value at each latent site in model given unconstrained parameters params. Transforms are used to map these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution. Otherwise, the base value lies in the support of the distribution.
- Parameters:
model – a callable containing NumPyro primitives.
model_args (tuple) – args provided to the model.
model_kwargs (dict) – kwargs provided to the model.
params (dict) – dictionary of unconstrained values keyed by site names.
return_deterministic (bool) – whether to return the value of deterministic sites from the model. Defaults to False.
- Returns:
dict of transformed params.
unconstrain_fn
potential_energy
- potential_energy(model, model_args, model_kwargs, params, enum=False)
(EXPERIMENTAL INTERFACE) Computes the potential energy of a model given unconstrained params. Under the hood, these unconstrained parameters are transformed to values belonging to the supports of the corresponding priors in model.
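- Parameters:
model – Python callable containing Pyro primitives.
model_args (tuple) – args provided to the model.
model_kwargs (dict) – kwargs provided to the model.
params (dict) – unconstrained parameters of model.
enum (bool) – whether to enumerate over discrete latent sites.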
- Returns:
potential energy given unconstrained parameters.
log_likelihood
- log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)
(EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model, given samples of all latent variables.
- Parameters:
model – Python callable containing Pyro primitives.
posterior_samples (dict) – dictionary of samples from the posterior.
args – model arguments.
batch_ndims –
the number of batch dimensions in posterior samples. Some usages:
set batch_ndims=0 to get log likelihoods for 1 single sample
set batch_ndims=1 to get log likelihoods for posterior_samples with shapes (num_samples x …)
set batch_ndims=2 to get log likelihoods for posterior_samples with shapes (num_chains x num_samples x …)
kwargs – model kwargs.
- Returns:
dict of log likelihoods at observation sites.
find_valid_initial_params
- find_valid_initial_params(rng_key, model, *, init_strategy=<function init_to_uniform>, enum=False, model_args=(), model_kwargs=None, prototype_params=None, forward_mode_differentiation=False, validate_grad=True)
(EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial valid unconstrained value for all the parameters. This function also returns the corresponding potential energy, the gradients, and an is_valid flag to say whether the initial parameters are valid. Parameter values are considered valid if the values and the gradients for the log density have finite values.
- Parameters:
rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
model – Python callable containing Pyro primitives.
init_strategy (callable) – a per-site initialization function.
enum (bool) – whether to enumerate over discrete latent sites.
model_args (tuple) – args provided to the model.
model_kwargs (dict) – kwargs provided to the model.
prototype_params (dict) – optional prototype parameters, used to define the shape of the initial parameters.
forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. Defaults to False.
validate_grad (bool) – whether to validate gradient of the initial params. Defaults to True.
- Returns:
tuple of init_params_info and is_valid, where init_params_info is the tuple containing the initial params, their potential energy, and their gradients.
Initialization Strategies
init_to_feasible
init_to_mean
- init_to_mean(site=None)
Initialize to the prior mean. For priors with no .mean property implemented, we defer to the init_to_median() strategy.
init_to_median
- init_to_median(site=None, num_samples=15)
Initialize to the prior median. For priors with no .sample method implemented, we defer to the init_to_uniform() strategy.
- Parameters:
num_samples (int) – number of prior points used to calculate the median.
init_to_sample
- init_to_sample(site=None)
Initialize to a prior sample. For priors with no .sample method implemented, we defer to the init_to_uniform() strategy.
init_to_uniform
init_to_value
- init_to_value(site=None, values={})
Initialize to the value specified in values. We defer to the init_to_uniform() strategy for sites which do not appear in values.
- Parameters:
values (dict) – dictionary of initial values keyed by site name.
Tensor Indexing
- vindex(tensor, args)
Vectorized advanced indexing with broadcasting semantics.
See also the convenience wrapper Vindex.
This is useful for writing indexing code that is compatible with batching and enumeration, especially for selecting mixture components with discrete random variables.
For example, suppose x is a parameter with len(x.shape) == 3 and we wish to generalize the expression x[i, :, j] from integer i,j to tensors i,j with batch dims and enum dims (but no event dims). Then we can write the generalized version using Vindex:
    xij = Vindex(x)[i, :, j]
    batch_shape = broadcast_shape(i.shape, j.shape)
    event_shape = (x.shape[1],)
    assert xij.shape == batch_shape + event_shape
To handle the case when x may also contain batch dimensions (e.g. if x was sampled in a plated context as when using vectorized particles), vindex() uses the special convention that Ellipsis denotes batch dimensions (hence ... can appear only on the left, never in the middle or on the right). Suppose x has event dim 3. Then we can write:
    old_batch_shape = x.shape[:-3]
    old_event_shape = x.shape[-3:]
    xij = Vindex(x)[..., i, :, j]  # The ... denotes unknown batch shape.
    new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape)
    new_event_shape = (x.shape[-2],)
    assert xij.shape == new_batch_shape + new_event_shape
Note that this special handling of Ellipsis differs from the NEP [1].
Formally, this function assumes:
1. Each arg is either Ellipsis, slice(None), an integer, or a batched integer tensor (i.e. with empty event shape). This function does not support nontrivial slices or boolean tensor masks. Ellipsis can only appear on the left as args[0].
2. If args[0] is not Ellipsis then tensor is not batched, and its event dim is equal to len(args).
3. If args[0] is Ellipsis then tensor is batched and its event dim is equal to len(args[1:]). Dims of tensor to the left of the event dims are considered batch dims and will be broadcasted with dims of tensor args.
Note that if none of the args is a tensor with len(shape) > 0, then this function behaves like standard indexing:
    if not any(isinstance(a, jnp.ndarray) and len(a.shape) > 0 for a in args):
        assert (Vindex(x)[args] == x[args]).all()
References
- [1] https://www.numpy.org/neps/nep-0021-advanced-indexing.html introduces vindex as a helper for vectorized indexing. This implementation is similar to the proposed notation x.vindex[] except for slightly different handling of Ellipsis.
- Parameters:
tensor (jnp.ndarray) – A tensor to be indexed.
args (tuple) – An index, as args to __getitem__.
- Returns:
A nonstandard interpretation of tensor[args].
- Return type:
jnp.ndarray
Model Inspection
get_dependencies
- get_dependencies(model: Callable, model_args: tuple | None = None, model_kwargs: dict | None = None) → dict[str, object]
Infers dependency structure about a conditioned model.
This returns a nested dictionary with structure like:
    {
        "prior_dependencies": {
            "variable1": {"variable1": set()},
            "variable2": {"variable1": set(), "variable2": set()},
            ...
        },
        "posterior_dependencies": {
            "variable1": {"variable1": {"plate1"}, "variable2": set()},
            ...
        },
    }
where
prior_dependencies is a dict mapping downstream latent and observed variables to dictionaries mapping upstream latent variables on which they depend to sets of plates inducing full dependencies. That is, included plates introduce quadratically many dependencies as in complete-bipartite graphs, whereas excluded plates introduce only linearly many dependencies as in independent sets of parallel edges. Prior dependencies follow the original model order.
posterior_dependencies is a similar dict, but mapping latent variables to the latent or observed sites on which they depend in the posterior. Posterior dependencies are reversed from the model order.
Dependencies elide numpyro.deterministic sites and numpyro.sample(..., Delta(...)) sites.
Examples
Here is a simple example with no plates. We see every node depends on itself, and only the latent variables appear in the posterior:
    def model_1():
        a = numpyro.sample("a", dist.Normal(0, 1))
        numpyro.sample("b", dist.Normal(a, 1), obs=0.0)

    assert get_dependencies(model_1) == {
        "prior_dependencies": {
            "a": {"a": set()},
            "b": {"a": set(), "b": set()},
        },
        "posterior_dependencies": {
            "a": {"a": set(), "b": set()},
        },
    }
Here is an example where two variables a and b start out conditionally independent in the prior, but become conditionally dependent in the posterior due to the so-called collider variable c on which they both depend. This is called “moralization” in the graphical model literature:
    def model_2():
        a = numpyro.sample("a", dist.Normal(0, 1))
        b = numpyro.sample("b", dist.LogNormal(0, 1))
        c = numpyro.sample("c", dist.Normal(a, b))
        numpyro.sample("d", dist.Normal(c, 1), obs=0.)

    assert get_dependencies(model_2) == {
        "prior_dependencies": {
            "a": {"a": set()},
            "b": {"b": set()},
            "c": {"a": set(), "b": set(), "c": set()},
            "d": {"c": set(), "d": set()},
        },
        "posterior_dependencies": {
            "a": {"a": set(), "b": set(), "c": set()},
            "b": {"b": set(), "c": set()},
            "c": {"c": set(), "d": set()},
        },
    }
Dependencies can be more complex in the presence of plates. So far all the dict values have been empty sets of plates, but in the following posterior we see that a depends on itself across the plate p. This means that, among the elements of a, e.g. a[0] depends on a[1] (this is why we explicitly allow variables to depend on themselves):
    def model_3():
        with numpyro.plate("p", 5):
            a = numpyro.sample("a", dist.Normal(0, 1))
        numpyro.sample("b", dist.Normal(a.sum(), 1), obs=0.0)

    assert get_dependencies(model_3) == {
        "prior_dependencies": {
            "a": {"a": set()},
            "b": {"a": set(), "b": set()},
        },
        "posterior_dependencies": {
            "a": {"a": {"p"}, "b": set()},
        },
    }
- [1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)
“Faithful inversion of generative models for effective amortized inference” https://dl.acm.org/doi/10.5555/3327144.3327229
get_model_relations
- get_model_relations(model, model_args=None, model_kwargs=None)
Infer relations of RVs and plates from a given model and, optionally, data. See https://github.com/pyro-ppl/numpyro/issues/949 for more details.
This returns a dictionary with keys:
“sample_sample” maps each downstream sample site to a list of the upstream sample sites on which it depends;
“sample_param” maps each downstream sample site to a list of the upstream param sites on which it depends;
“sample_dist” maps each sample site to the name of the distribution at that site;
“param_constraint” maps each param site to the name of the constraints at that site;
“plate_sample” maps each plate name to a list of the sample sites within that plate; and
“observed” is a list of observed sample sites.
For example for the model:
    def model(data):
        m = numpyro.sample('m', dist.Normal(0, 1))
        sd = numpyro.sample('sd', dist.LogNormal(m, 1))
        with numpyro.plate('N', len(data)):
            numpyro.sample('obs', dist.Normal(m, sd), obs=data)
the relation is:
    {'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
     'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
     'plate_sample': {'N': ['obs']},
     'observed': ['obs']}
- Parameters:
model (callable) – A model to inspect.
model_args – Optional tuple of model args.
model_kwargs – Optional dict of model kwargs.
- Return type:
dict
Visualization Utilities
render_model
- render_model(model, model_args=None, model_kwargs=None, filename=None, render_distributions=False, render_params=False)
Wrap all functions needed to automatically render a model.
Warning
This utility does not support the scan() primitive. If you want to render a time-series model, you can try to rewrite the code using a Python for loop.
- Parameters:
model – Model to render.
model_args – Positional arguments to pass to the model.
model_kwargs – Keyword arguments to pass to the model.
filename (str) – File to save rendered model in.
render_distributions (bool) – Whether to include RV distribution annotations in the plot.
render_params (bool) – Whether to show params in the plot.
Trace Inspection
- format_shapes(trace, *, compute_log_prob=False, title='Trace Shapes:', last_site=None)
Given the trace of a function, returns a string showing a table of the shapes of all sites in the trace.
Use the trace handler (or the funsor trace handler for enumeration) to produce the trace.
- Parameters:
trace (dict) – The model trace to format.
compute_log_prob – Compute log probabilities and display the shapes in the table. Accepts True / False or a function which when given a dictionary containing site-level metadata returns whether the log probability should be calculated and included in the table.
title (str) – Title for the table of shapes.
last_site (str) – Name of a site in the model. If supplied, subsequent sites are not displayed in the table.
Usage:
    def model(*args, **kwargs):
        ...

    with numpyro.handlers.seed(rng_seed=1):
        trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs)
    print(numpyro.util.format_shapes(trace))