Runtime Utilities

enable_validation

enable_validation(is_validate=True)

Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors (e.g., NaN checks and validation of distribution arguments and support values), which are helpful for debugging.

Note

This utility does not take effect under JAX’s JIT compilation or vectorized transformation jax.vmap().

Parameters: is_validate (bool) – whether to enable validation checks.

validation_enabled

validation_enabled(is_validate=True)

Context manager for temporarily enabling or disabling validation checks.

Parameters: is_validate (bool) – whether to enable validation checks.

enable_x64

enable_x64(use_x64=True)

Changes the default array type to use 64-bit precision, as in NumPy.

Parameters: use_x64 (bool) – when True, JAX arrays use 64 bits by default; otherwise 32 bits.

set_platform

set_platform(platform=None)

Changes platform to CPU, GPU, or TPU. This utility only takes effect at the beginning of your program.

Parameters: platform (str) – either ‘cpu’, ‘gpu’, or ‘tpu’.

set_host_device_count

set_host_device_count(n)

By default, XLA treats all CPU cores as a single device. This utility tells XLA that there are n host (CPU) devices available, which allows parallel mapping with jax.pmap() to work on the CPU platform.

Note

This utility only takes effect at the beginning of your program. Under the hood, it sets the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices], where [num_devices] is the desired number of CPU devices n.

Warning

Our understanding of the side effects of the xla_force_host_platform_device_count flag in XLA is incomplete. If you observe strange behavior when using this utility, please let us know through our issue tracker or forum page. More information is available in this JAX issue.

Parameters: n (int) – number of CPU devices to use.

Inference Utilities

Predictive

class Predictive(model, posterior_samples=None, guide=None, params=None, num_samples=None, return_sites=None, parallel=False)

Bases: object

This class is used to construct a predictive distribution. The predictive distribution is obtained by running the model conditioned on latent samples from posterior_samples.

Warning

The interface for the Predictive class is experimental, and might change in the future.

Parameters:
  • model – Python callable containing Pyro primitives.
  • posterior_samples (dict) – dictionary of samples from the posterior.
  • guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.
  • params (dict) – dictionary of values for param sites of model/guide.
  • num_samples (int) – number of samples
  • return_sites (list) – sites to return; by default only sample sites not present in posterior_samples are returned.
  • parallel (bool) – whether to predict in parallel using JAX vectorized map jax.vmap(). Defaults to False.
Returns:

dict of samples from the predictive distribution.

get_samples(rng_key, *args, **kwargs)

Returns a dict of samples from the predictive distribution. By default, only sample sites not contained in posterior_samples are returned; this can be changed through the return_sites keyword argument passed when constructing this Predictive instance.

Parameters:
  • rng_key (jax.random.PRNGKey) – random key to draw samples.
  • args – model arguments.
  • kwargs – model kwargs.

log_density

log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False)

(EXPERIMENTAL INTERFACE) Computes the log of the joint density of the model, given latent values params.

Parameters:
  • model – Python callable containing NumPyro primitives.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
  • params (dict) – dictionary of current parameter values keyed by site name.
  • skip_dist_transforms (bool) – whether to compute log probability of a site (if its prior is a transformed distribution) in its base distribution domain.
Returns:

log of joint density and a corresponding model trace

transform_fn

transform_fn(transforms, params, invert=False)

(EXPERIMENTAL INTERFACE) Callable that applies a transformation from the transforms dict to values in the params dict and returns the transformed values keyed on the same names.

Parameters:
  • transforms – Dictionary of transforms keyed by names. Names in transforms and params should align.
  • params – Dictionary of arrays keyed by names.
  • invert – Whether to apply the inverse of the transforms.
Returns:

dict of transformed params.

constrain_fn

constrain_fn(model, transforms, model_args, model_kwargs, params)

(EXPERIMENTAL INTERFACE) Gets the value at each latent site in model, given unconstrained parameters params. The transforms dict is used to transform these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution; otherwise, it lies in the support of the distribution itself.

Parameters:
  • model – a callable containing NumPyro primitives.
  • transforms (dict) – dictionary of transforms keyed by names. Names in transforms and params should align.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
  • params (dict) – dictionary of unconstrained values keyed by site names.
Returns:

dict of transformed params.

potential_energy

potential_energy(model, inv_transforms, model_args, model_kwargs, params)

(EXPERIMENTAL INTERFACE) Computes the potential energy of a model given unconstrained params. The inv_transforms dict is used to transform these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution; otherwise, it lies in the support of the distribution itself.

Parameters:
  • model – a callable containing NumPyro primitives.
  • inv_transforms (dict) – dictionary of transforms keyed by names.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
  • params (dict) – unconstrained parameters of model.
Returns:

potential energy given unconstrained parameters.

log_likelihood

log_likelihood(model, posterior_samples, *args, **kwargs)

(EXPERIMENTAL INTERFACE) Returns the log likelihood at the observation nodes of model, given samples of all latent variables.

Parameters:
  • model – Python callable containing Pyro primitives.
  • posterior_samples (dict) – dictionary of samples from the posterior.
  • args – model arguments.
  • kwargs – model kwargs.
Returns:

dict of log likelihoods at observation sites.

find_valid_initial_params

find_valid_initial_params(rng_key, model, init_strategy=functools.partial(<function _init_to_uniform>, radius=2), param_as_improper=False, model_args=(), model_kwargs=None)

(EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial valid unconstrained value for all the parameters, together with an is_valid flag indicating whether those initial parameters are valid. Parameter values are considered valid if both the value and the gradient of the log density are finite.

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
  • model – Python callable containing Pyro primitives.
  • init_strategy (callable) – a per-site initialization function.
  • param_as_improper (bool) – a flag to decide whether to consider sites with param statement as sites with improper priors.
  • model_args (tuple) – args provided to the model.
  • model_kwargs (dict) – kwargs provided to the model.
Returns:

tuple of (init_params, is_valid).

Initialization Strategies

init_to_median

init_to_median(num_samples=15)

Initialize to the prior median.

Parameters: num_samples (int) – number of prior samples used to compute the median.

init_to_prior

init_to_prior()

Initialize to a prior sample.

init_to_uniform

init_to_uniform(radius=2)

Initialize to a random point drawn uniformly from the interval (-radius, radius) in the unconstrained domain.

Parameters: radius (float) – half-width of the interval, in the unconstrained domain, from which the initial point is drawn.

init_to_feasible

init_to_feasible()

Initialize to an arbitrary feasible point, ignoring distribution parameters.

init_to_value

init_to_value(values)

Initialize to the values specified in values. Sites that do not appear in values fall back to the init_to_uniform() strategy.

Parameters: values (dict) – dictionary of initial values keyed by site name.