Runtime Utilities

enable_validation

enable_validation(is_validate=True)[source]

Enable or disable validation checks in NumPyro. Validation checks (e.g. NaN checks and validation of distribution arguments and support values) raise warnings and errors that are useful for debugging.

Note

This utility does not take effect under JAX’s JIT compilation or vectorized transformation jax.vmap().

Parameters:

is_validate (bool) – whether to enable validation checks.

validation_enabled

validation_enabled(is_validate=True)[source]

Context manager that is useful when temporarily enabling/disabling validation checks.

Parameters:

is_validate (bool) – whether to enable validation checks.

enable_x64

enable_x64(use_x64=True)[source]

Changes the default array type to use 64 bit precision as in NumPy.

Parameters:

use_x64 (bool) – when True, JAX arrays will use 64 bits by default; else 32 bits.

set_platform

set_platform(platform=None)[source]

Changes platform to CPU, GPU, or TPU. This utility only takes effect at the beginning of your program.

Parameters:

platform (str) – either ‘cpu’, ‘gpu’, or ‘tpu’.

set_host_device_count

set_host_device_count(n)[source]

By default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use. As a consequence, this allows parallel mapping with jax.pmap() to work on the CPU platform.

Note

This utility only takes effect at the beginning of your program. Under the hood, this sets the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices], where [num_devices] is the desired number of CPU devices n.

Warning

Our understanding of the side effects of using the xla_force_host_platform_device_count flag in XLA is incomplete. If you observe some strange phenomenon when using this utility, please let us know through our issue or forum page. More information is available in this JAX issue.

Parameters:

n (int) – number of CPU devices to use.
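The environment-variable mechanism described in the note can be sketched with the standard library alone. This is an illustration of the idea, not the library's implementation; force_host_device_count is a hypothetical name, and the sketch assumes any other flags already present in XLA_FLAGS should be preserved.

```python
import os
import re


def force_host_device_count(n: int) -> str:
    """Hypothetical helper: rewrite XLA_FLAGS so XLA exposes n host devices."""
    flags = os.environ.get("XLA_FLAGS", "")
    # Drop any earlier setting of the same flag and keep all other flags.
    flags = re.sub(
        r"--xla_force_host_platform_device_count=\S+", "", flags
    ).split()
    os.environ["XLA_FLAGS"] = " ".join(
        [f"--xla_force_host_platform_device_count={n}"] + flags
    )
    return os.environ["XLA_FLAGS"]


# This has to run before JAX initializes its backends to have any effect.
xla_flags = force_host_device_count(4)
```

Because XLA reads the variable once at start-up, changing it later in the program has no effect, which is why the utility must be called first.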

Inference Utilities

Predictive

class Predictive(model: Callable, posterior_samples: dict | None = None, *, guide: Callable | None = None, params: dict | None = None, num_samples: int | None = None, return_sites: list[str] | None = None, infer_discrete: bool = False, parallel: bool = False, batch_ndims: int | None = None)[source]

Bases: object

This class is used to construct a predictive distribution. The predictive distribution is obtained by running model conditioned on latent samples from posterior_samples.

Warning

The interface for the Predictive class is experimental, and might change in the future.

Parameters:
  • model – Python callable containing Pyro primitives.

  • posterior_samples (dict) – dictionary of samples from the posterior.

  • guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.

  • params (dict) – dictionary of values for param sites of model/guide.

  • num_samples (int) – number of samples

  • return_sites (list) – sites to return; by default only sample sites not present in posterior_samples are returned.

  • infer_discrete (bool) – whether or not to sample discrete sites from the posterior, conditioned on observations and other latent values in posterior_samples. Under the hood, those sites will be marked with site["infer"]["enumerate"] = "parallel". See how infer_discrete works at the Pyro enumeration tutorial. Note that this requires funsor installation.

  • parallel (bool) – whether to predict in parallel using JAX vectorized map jax.vmap(). Defaults to False.

  • batch_ndims

    the number of batch dimensions in posterior samples or parameters. If None defaults to 0 if guide is set (i.e. not None) and 1 otherwise. Usages for batched posterior samples:

    • set batch_ndims=0 to get prediction for 1 single sample

    • set batch_ndims=1 to get prediction for posterior_samples with shapes (num_samples x …) (same as batch_ndims=None with guide=None)

    • set batch_ndims=2 to get prediction for posterior_samples with shapes (num_chains x N x …). Note that if num_samples argument is not None, its value should be equal to num_chains x N.

    Usages for batched parameters:

    • set batch_ndims=0 to get 1 sample from the guide and parameters (same as batch_ndims=None with guide)

    • set batch_ndims=1 to get predictions from a one dimensional batch of the guide and parameters with shapes (num_samples x batch_size x …)

Returns:

dict of samples from the predictive distribution.

Example:

Given a model:

def model(X, y=None):
    ...
    return numpyro.sample("obs", likelihood, obs=y)

you can sample from the prior predictive:

predictive = Predictive(model, num_samples=1000)
y_pred = predictive(rng_key, X)["obs"]

If you also have posterior samples, you can sample from the posterior predictive:

predictive = Predictive(model, posterior_samples=posterior_samples)
y_pred = predictive(rng_key, X)["obs"]

See docstrings for SVI and MCMCKernel to see example code of this in context.

log_density

log_density(model, model_args, model_kwargs, params)[source]

(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given latent values params.

Parameters:
  • model – Python callable containing NumPyro primitives.

  • model_args (tuple) – args provided to the model.

  • model_kwargs (dict) – kwargs provided to the model.

  • params (dict) – dictionary of current parameter values keyed by site name.

Returns:

log of joint density and a corresponding model trace

get_transforms

get_transforms(model, model_args, model_kwargs, params)[source]

(EXPERIMENTAL INTERFACE) Retrieve (inverse) transforms via biject_to() given a NumPyro model. This function supports ‘param’ sites. NB: Parameter values are only used to retrieve the model trace.

Parameters:
  • model – a callable containing NumPyro primitives.

  • model_args (tuple) – args provided to the model.

  • model_kwargs (dict) – kwargs provided to the model.

  • params (dict) – dictionary of values keyed by site names.

Returns:

dict of transforms keyed by site names.

transform_fn

transform_fn(transforms, params, invert=False)[source]

(EXPERIMENTAL INTERFACE) Callable that applies a transformation from the transforms dict to values in the params dict and returns the transformed values keyed on the same names.

Parameters:
  • transforms – Dictionary of transforms keyed by names. Names in transforms and params should align.

  • params – Dictionary of arrays keyed by names.

  • invert – Whether to apply the inverse of the transforms.

Returns:

dict of transformed params.

constrain_fn

constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False)[source]

(EXPERIMENTAL INTERFACE) Gets value at each latent site in model given unconstrained parameters params. The transforms are used to transform these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution. Otherwise, the base value lies in the support of the distribution.

Parameters:
  • model – a callable containing NumPyro primitives.

  • model_args (tuple) – args provided to the model.

  • model_kwargs (dict) – kwargs provided to the model.

  • params (dict) – dictionary of unconstrained values keyed by site names.

  • return_deterministic (bool) – whether to return the value of deterministic sites from the model. Defaults to False.

Returns:

dict of transformed params.

unconstrain_fn

unconstrain_fn(model, model_args, model_kwargs, params)[source]

(EXPERIMENTAL INTERFACE) Given a NumPyro model and a dict of parameters, this function applies the right transformation to convert parameter values from constrained space to unconstrained space.

Parameters:
  • model – a callable containing NumPyro primitives.

  • model_args (tuple) – args provided to the model.

  • model_kwargs (dict) – kwargs provided to the model.

  • params (dict) – dictionary of constrained values keyed by site names.

Returns:

dict of unconstrained values keyed by site names.

potential_energy

potential_energy(model, model_args, model_kwargs, params, enum=False)[source]

(EXPERIMENTAL INTERFACE) Computes potential energy of a model given unconstrained params. Under the hood, we transform these unconstrained parameters to values that belong to the supports of the corresponding priors in model.

Parameters:
  • model – a callable containing NumPyro primitives.

  • model_args (tuple) – args provided to the model.

  • model_kwargs (dict) – kwargs provided to the model.

  • params (dict) – unconstrained parameters of model.

  • enum (bool) – whether to enumerate over discrete latent sites.

Returns:

potential energy given unconstrained parameters.

log_likelihood

log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)[source]

(EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model, given samples of all latent variables.

Parameters:
  • model – Python callable containing Pyro primitives.

  • posterior_samples (dict) – dictionary of samples from the posterior.

  • args – model arguments.

  • batch_ndims

    the number of batch dimensions in posterior samples. Some usages:

    • set batch_ndims=0 to get log likelihoods for 1 single sample

    • set batch_ndims=1 to get log likelihoods for posterior_samples with shapes (num_samples x …)

    • set batch_ndims=2 to get log likelihoods for posterior_samples with shapes (num_chains x num_samples x …)

  • kwargs – model kwargs.

Returns:

dict of log likelihoods at observation sites.

find_valid_initial_params

find_valid_initial_params(rng_key, model, *, init_strategy=<function init_to_uniform>, enum=False, model_args=(), model_kwargs=None, prototype_params=None, forward_mode_differentiation=False, validate_grad=True)[source]

(EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial valid unconstrained value for all the parameters. This function also returns the corresponding potential energy, the gradients, and an is_valid flag to say whether the initial parameters are valid. Parameter values are considered valid if the values and the gradients for the log density have finite values.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].

  • model – Python callable containing Pyro primitives.

  • init_strategy (callable) – a per-site initialization function.

  • enum (bool) – whether to enumerate over discrete latent sites.

  • model_args (tuple) – args provided to the model.

  • model_kwargs (dict) – kwargs provided to the model.

  • prototype_params (dict) – optional prototype parameters, used to define the shape of the initial parameters.

  • forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. Defaults to False.

  • validate_grad (bool) – whether to validate gradient of the initial params. Defaults to True.

Returns:

tuple of init_params_info and is_valid, where init_params_info is the tuple containing the initial params, their potential energy, and their gradients.

Initialization Strategies

init_to_feasible

init_to_feasible(site=None)[source]

Initialize to an arbitrary feasible point, ignoring distribution parameters.

init_to_mean

init_to_mean(site=None)[source]

Initialize to the prior mean. For priors with no .mean property implemented, we defer to the init_to_median() strategy.

init_to_median

init_to_median(site=None, num_samples=15)[source]

Initialize to the prior median. For priors with no .sample method implemented, we defer to the init_to_uniform() strategy.

Parameters:

num_samples (int) – number of prior points to calculate median.

init_to_sample

init_to_sample(site=None)[source]

Initialize to a prior sample. For priors with no .sample method implemented, we defer to the init_to_uniform() strategy.

init_to_uniform

init_to_uniform(site=None, radius=2)[source]

Initialize to a random point in the interval (-radius, radius) of the unconstrained domain.

Parameters:

radius (float) – specifies the range to draw an initial point in the unconstrained domain.

init_to_value

init_to_value(site=None, values={})[source]

Initialize to the value specified in values. We defer to init_to_uniform() strategy for sites which do not appear in values.

Parameters:

values (dict) – dictionary of initial values keyed by site name.

Tensor Indexing

vindex(tensor, args)[source]

Vectorized advanced indexing with broadcasting semantics.

See also the convenience wrapper Vindex.

This is useful for writing indexing code that is compatible with batching and enumeration, especially for selecting mixture components with discrete random variables.

For example, suppose x is a parameter with len(x.shape) == 3 and we wish to generalize the expression x[i, :, j] from integer i,j to tensors i,j with batch dims and enum dims (but no event dims). Then we can write the generalized version using Vindex:

xij = Vindex(x)[i, :, j]

batch_shape = broadcast_shape(i.shape, j.shape)
event_shape = (x.shape[1],)
assert xij.shape == batch_shape + event_shape

To handle the case when x may also contain batch dimensions (e.g. if x was sampled in a plated context as when using vectorized particles), vindex() uses the special convention that Ellipsis denotes batch dimensions (hence ... can appear only on the left, never in the middle or on the right). Suppose x has event dim 3. Then we can write:

old_batch_shape = x.shape[:-3]
old_event_shape = x.shape[-3:]

xij = Vindex(x)[..., i, :, j]   # The ... denotes unknown batch shape.

new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape)
new_event_shape = (x.shape[-2],)
assert xij.shape == new_batch_shape + new_event_shape

Note that this special handling of Ellipsis differs from the NEP [1].

Formally, this function assumes:

  1. Each arg is either Ellipsis, slice(None), an integer, or a batched integer tensor (i.e. with empty event shape). This function does not support nontrivial slices or boolean tensor masks. Ellipsis can only appear on the left as args[0].

  2. If args[0] is not Ellipsis then tensor is not batched, and its event dim is equal to len(args).

  3. If args[0] is Ellipsis then tensor is batched and its event dim is equal to len(args[1:]). Dims of tensor to the left of the event dims are considered batch dims and will be broadcasted with dims of tensor args.

Note that if none of the args is a tensor with len(shape) > 0, then this function behaves like standard indexing:

if not any(isinstance(a, jnp.ndarray) and len(a.shape) > 0 for a in args):
    assert Vindex(x)[args] == x[args]

References

[1] https://www.numpy.org/neps/nep-0021-advanced-indexing.html introduces vindex as a helper for vectorized indexing. This implementation is similar to the proposed notation x.vindex[] except for slightly different handling of Ellipsis.

Parameters:
  • tensor (jnp.ndarray) – A tensor to be indexed.

  • args (tuple) – An index, as args to __getitem__.

Returns:

A nonstandard interpretation of tensor[args].

Return type:

jnp.ndarray

class Vindex(tensor)[source]

Bases: object

Convenience wrapper around vindex().

The following are equivalent:

Vindex(x)[..., i, j, :]
vindex(x, (Ellipsis, i, j, slice(None)))

Parameters:

tensor (jnp.ndarray) – A tensor to be indexed.

Returns:

An object with a special __getitem__() method.

Model Inspection

get_dependencies

get_dependencies(model: Callable, model_args: tuple | None = None, model_kwargs: dict | None = None) → dict[str, object][source]

Infers dependency structure about a conditioned model.

This returns a nested dictionary with structure like:

{
    "prior_dependencies": {
        "variable1": {"variable1": set()},
        "variable2": {"variable1": set(), "variable2": set()},
        ...
    },
    "posterior_dependencies": {
        "variable1": {"variable1": {"plate1"}, "variable2": set()},
        ...
    },
}

where

  • prior_dependencies is a dict mapping downstream latent and observed variables to dictionaries mapping upstream latent variables on which they depend to sets of plates inducing full dependencies. That is, included plates introduce quadratically many dependencies as in complete-bipartite graphs, whereas excluded plates introduce only linearly many dependencies as in independent sets of parallel edges. Prior dependencies follow the original model order.

  • posterior_dependencies is a similar dict, but mapping latent variables to the latent or observed sites on which they depend in the posterior. Posterior dependencies are reversed from the model order.

Dependencies elide numpyro.deterministic sites and numpyro.sample(..., Delta(...)) sites.

Examples

Here is a simple example with no plates. We see every node depends on itself, and only the latent variables appear in the posterior:

def model_1():
    a = numpyro.sample("a", dist.Normal(0, 1))
    numpyro.sample("b", dist.Normal(a, 1), obs=0.0)

assert get_dependencies(model_1) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set()},
    },
}

Here is an example where two variables a and b start out conditionally independent in the prior, but become conditionally dependent in the posterior due to the so-called collider variable c on which they both depend. This is called “moralization” in the graphical model literature:

def model_2():
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.LogNormal(0, 1))
    c = numpyro.sample("c", dist.Normal(a, b))
    numpyro.sample("d", dist.Normal(c, 1), obs=0.)

assert get_dependencies(model_2) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"b": set()},
        "c": {"a": set(), "b": set(), "c": set()},
        "d": {"c": set(), "d": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set(), "c": set()},
        "b": {"b": set(), "c": set()},
        "c": {"c": set(), "d": set()},
    },
}

Dependencies can be more complex in the presence of plates. So far all the dict values have been empty sets of plates, but in the following posterior we see that a depends on itself across the plate p. This means that, among the elements of a, e.g. a[0] depends on a[1] (this is why we explicitly allow variables to depend on themselves):

def model_3():
    with numpyro.plate("p", 5):
        a = numpyro.sample("a", dist.Normal(0, 1))
    numpyro.sample("b", dist.Normal(a.sum(), 1), obs=0.0)

assert get_dependencies(model_3) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": {"p"}, "b": set()},
    },
}

References

[1] S. Webb, A. Goliński, R. Zinkov, N. Siddharth, T. Rainforth, Y.W. Teh, F. Wood (2018). “Faithful inversion of generative models for effective amortized inference”. https://dl.acm.org/doi/10.5555/3327144.3327229

Parameters:
  • model (callable) – A model.

  • model_args (tuple) – Optional tuple of model args.

  • model_kwargs (dict) – Optional dict of model kwargs.

Returns:

A dictionary of metadata (see above).

Return type:

dict

get_model_relations

get_model_relations(model, model_args=None, model_kwargs=None)[source]

Infer relations of RVs and plates from given model and optionally data. See https://github.com/pyro-ppl/numpyro/issues/949 for more details.

This returns a dictionary with keys:

  • “sample_sample” maps each downstream sample site to a list of the upstream sample sites on which it depends;

  • “sample_param” maps each downstream sample site to a list of the upstream param sites on which it depends;

  • “sample_dist” maps each sample site to the name of the distribution at that site;

  • “param_constraint” maps each param site to the name of the constraints at that site;

  • “plate_sample” maps each plate name to a list of the sample sites within that plate; and

  • “observed” is a list of observed sample sites.

For example for the model:

def model(data):
    m = numpyro.sample('m', dist.Normal(0, 1))
    sd = numpyro.sample('sd', dist.LogNormal(m, 1))
    with numpyro.plate('N', len(data)):
        numpyro.sample('obs', dist.Normal(m, sd), obs=data)

the relation is:

{'sample_sample': {'m': [], 'sd': ['m'], 'obs': ['m', 'sd']},
 'sample_dist': {'m': 'Normal', 'sd': 'LogNormal', 'obs': 'Normal'},
 'plate_sample': {'N': ['obs']},
 'observed': ['obs']}

Parameters:
  • model (callable) – A model to inspect.

  • model_args – Optional tuple of model args.

  • model_kwargs – Optional dict of model kwargs.

Return type:

dict

Visualization Utilities

render_model

render_model(model, model_args=None, model_kwargs=None, filename=None, render_distributions=False, render_params=False)[source]

Wrap all functions needed to automatically render a model.

Warning

This utility does not support the scan() primitive. If you want to render a time-series model, you can try to rewrite the code using a Python for loop.

Parameters:
  • model – Model to render.

  • model_args – Positional arguments to pass to the model.

  • model_kwargs – Keyword arguments to pass to the model.

  • filename (str) – File to save rendered model in.

  • render_distributions (bool) – Whether to include RV distribution annotations in the plot.

  • render_params (bool) – Whether to show params in the plot.

Trace Inspection

format_shapes(trace, *, compute_log_prob=False, title='Trace Shapes:', last_site=None)[source]

Given the trace of a function, returns a string showing a table of the shapes of all sites in the trace.

Use the trace handler (or the funsor trace handler for enumeration) to produce the trace.

Parameters:
  • trace (dict) – The model trace to format.

  • compute_log_prob – Compute log probabilities and display the shapes in the table. Accepts True / False or a function which when given a dictionary containing site-level metadata returns whether the log probability should be calculated and included in the table.

  • title (str) – Title for the table of shapes.

  • last_site (str) – Name of a site in the model. If supplied, subsequent sites are not displayed in the table.

Usage:

def model(*args, **kwargs):
    ...

with numpyro.handlers.seed(rng_seed=1):
    trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs)
print(numpyro.util.format_shapes(trace))