Runtime Utilities
enable_validation
enable_validation(is_validate=True)
Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors (e.g. NaN checks and validation of distribution arguments and support values), which are useful for debugging.
Note
This utility does not take effect under JAX’s JIT compilation or vectorized transformation jax.vmap().
Parameters: is_validate (bool) – whether to enable validation checks.
validation_enabled
enable_x64
set_platform
set_host_device_count
set_host_device_count(n)
By default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use. As a consequence, this allows parallel mapping in JAX, jax.pmap(), to work on the CPU platform.
Note
This utility only takes effect at the beginning of your program. Under the hood, it sets the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices], where [num_devices] is the desired number of CPU devices n.
Warning
Our understanding of the side effects of using the xla_force_host_platform_device_count flag in XLA is incomplete. If you observe some strange phenomenon when using this utility, please let us know through our issue or forum page. More information is available in this JAX issue.
Parameters: n (int) – number of CPU devices to use.
Inference Utilities
Predictive
class Predictive(model, posterior_samples=None, guide=None, params=None, num_samples=None, return_sites=None, parallel=False, batch_ndims=1)
Bases: object
This class is used to construct a predictive distribution. The predictive distribution is obtained by running the model conditioned on latent samples from posterior_samples.
Warning
The interface for the Predictive class is experimental, and might change in the future.
Parameters: - model – Python callable containing Pyro primitives.
- posterior_samples (dict) – dictionary of samples from the posterior.
- guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.
- params (dict) – dictionary of values for param sites of model/guide.
- num_samples (int) – number of samples
- return_sites (list) – sites to return; by default only sample sites not present in posterior_samples are returned.
- parallel (bool) – whether to predict in parallel using JAX vectorized map jax.vmap(). Defaults to False.
- batch_ndims – the number of batch dimensions in posterior samples. Some usages:
- set batch_ndims=0 to get predictions for a single sample
- set batch_ndims=1 to get predictions for posterior_samples with shapes (num_samples x …)
- set batch_ndims=2 to get predictions for posterior_samples with shapes (num_chains x N x …). Note that if the num_samples argument is not None, its value should be equal to num_chains x N.
Returns: dict of samples from the predictive distribution.
log_density
transform_fn
transform_fn(transforms, params, invert=False)
(EXPERIMENTAL INTERFACE) Callable that applies a transformation from the transforms dict to values in the params dict and returns the transformed values keyed on the same names.
Parameters: - transforms – Dictionary of transforms keyed by names. Names in transforms and params should align.
- params – Dictionary of arrays keyed by names.
- invert – Whether to apply the inverse of the transforms.
Returns: dict of transformed params.
constrain_fn
constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False)
(EXPERIMENTAL INTERFACE) Gets the value at each latent site in model given unconstrained parameters params. Transforms are used to map these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of the base distribution. Otherwise, the base value lies in the support of the distribution.
Parameters: - model – a callable containing NumPyro primitives.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- params (dict) – dictionary of unconstrained values keyed by site names.
- return_deterministic (bool) – whether to return the value of deterministic sites from the model. Defaults to False.
Returns: dict of transformed params.
potential_energy
potential_energy(model, model_args, model_kwargs, params, enum=False)
(EXPERIMENTAL INTERFACE) Computes the potential energy of a model given unconstrained params. Under the hood, these unconstrained parameters are transformed to values that belong to the supports of the corresponding priors in model.
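Parameters: - model – a callable containing NumPyro primitives.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- params (dict) – unconstrained parameters of model.
- enum (bool) – whether to enumerate over discrete latent sites.
Returns: potential energy given unconstrained parameters.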
log_likelihood
log_likelihood(model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs)
(EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model, given samples of all latent variables.
Parameters: - model – Python callable containing Pyro primitives.
- posterior_samples (dict) – dictionary of samples from the posterior.
- args – model arguments.
- batch_ndims –
the number of batch dimensions in posterior samples. Some usages:
- set batch_ndims=0 to get log likelihoods for a single sample
- set batch_ndims=1 to get log likelihoods for posterior_samples with shapes (num_samples x …)
- set batch_ndims=2 to get log likelihoods for posterior_samples with shapes (num_chains x num_samples x …)
- kwargs – model kwargs.
Returns: dict of log likelihoods at observation sites.
find_valid_initial_params
find_valid_initial_params(rng_key, model, init_strategy=<function init_to_uniform>, enum=False, model_args=(), model_kwargs=None, prototype_params=None, forward_mode_differentiation=False)
(EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial valid unconstrained value for all the parameters. This function also returns the corresponding potential energy, the gradients, and an is_valid flag to say whether the initial parameters are valid. Parameter values are considered valid if the values and the gradients for the log density have finite values.
Parameters: - rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
- model – Python callable containing Pyro primitives.
- init_strategy (callable) – a per-site initialization function.
- enum (bool) – whether to enumerate over discrete latent sites.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- prototype_params (dict) – optional prototype parameters, used to define the shape of the initial parameters.
Returns: tuple of init_params_info and is_valid, where init_params_info is the tuple containing the initial params, their potential energy, and their gradients.
Initialization Strategies
init_to_feasible
init_to_median
init_to_median(site=None, num_samples=15)
Initialize to the prior median. For priors with no .sample method implemented, we defer to the init_to_uniform() strategy.
Parameters: num_samples (int) – number of prior samples used to estimate the median.
init_to_sample
init_to_sample(site=None)
Initialize to a prior sample. For priors with no .sample method implemented, we defer to the init_to_uniform() strategy.
init_to_uniform
init_to_value
init_to_value(site=None, values={})
Initialize to the value specified in values. We defer to the init_to_uniform() strategy for sites which do not appear in values.
Parameters: values (dict) – dictionary of initial values keyed by site name.
Tensor Indexing
vindex(tensor, args)
Vectorized advanced indexing with broadcasting semantics.
See also the convenience wrapper Vindex.
This is useful for writing indexing code that is compatible with batching and enumeration, especially for selecting mixture components with discrete random variables.
For example, suppose x is a parameter with len(x.shape) == 3 and we wish to generalize the expression x[i, :, j] from integer i, j to tensors i, j with batch dims and enum dims (but no event dims). Then we can write the generalized version using Vindex:
    xij = Vindex(x)[i, :, j]
    batch_shape = broadcast_shape(i.shape, j.shape)
    event_shape = (x.shape[1],)
    assert xij.shape == batch_shape + event_shape
To handle the case when x may also contain batch dimensions (e.g. if x was sampled in a plated context, as when using vectorized particles), vindex() uses the special convention that Ellipsis denotes batch dimensions (hence ... can appear only on the left, never in the middle or on the right). Suppose x has event dim 3. Then we can write:
    old_batch_shape = x.shape[:-3]
    old_event_shape = x.shape[-3:]
    xij = Vindex(x)[..., i, :, j]  # The ... denotes unknown batch shape.
    new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape)
    new_event_shape = (x.shape[-2],)
    assert xij.shape == new_batch_shape + new_event_shape
Note that this special handling of Ellipsis differs from the NEP [1].
Formally, this function assumes:
- Each arg is either Ellipsis, slice(None), an integer, or a batched integer tensor (i.e. with empty event shape). This function does not support nontrivial slices or boolean tensor masks. Ellipsis can only appear on the left as args[0].
- If args[0] is not Ellipsis then tensor is not batched, and its event dim is equal to len(args).
- If args[0] is Ellipsis then tensor is batched and its event dim is equal to len(args[1:]). Dims of tensor to the left of the event dims are considered batch dims and will be broadcast with dims of tensor args.
Note that if none of the args is a tensor with len(shape) > 0, then this function behaves like standard indexing:
    if not any(isinstance(a, jnp.ndarray) and len(a.shape) > 0 for a in args):
        assert (Vindex(x)[args] == x[args]).all()
References
- [1] https://www.numpy.org/neps/nep-0021-advanced-indexing.html introduces vindex as a helper for vectorized indexing. This implementation is similar to the proposed notation x.vindex[] except for slightly different handling of Ellipsis.
Parameters: - tensor (jnp.ndarray) – A tensor to be indexed.
- args (tuple) – An index, as args to __getitem__.
Returns: A nonstandard interpretation of tensor[args].
Return type: jnp.ndarray