NumPyro documentation¶
Pyro Primitives¶
param¶
-
param
(name, init_value=None, **kwargs)[source]¶ Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().
Parameters: - name (str) – name of site.
- init_value (numpy.ndarray) – initial value specified by the user. Note that the onus of using this to initialize the optimizer is on the user / inference algorithm, since there is no global parameter store in NumPyro.
Returns: value for the parameter. Unless wrapped inside a handler like substitute, this will simply return the initial value.
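For illustration, a minimal sketch (not part of the original docstring) of this behavior:
>>> import numpyro
>>> from numpyro import handlers
>>> def model():
...     return numpyro.param('scale', 1.)
>>> model()  # no handler on the stack: simply returns the initial value
1.0
>>> handlers.substitute(model, {'scale': 2.})()  # the substituted value is returned instead
2.0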
sample¶
-
sample
(name, fn, obs=None, rng_key=None, sample_shape=())[source]¶ Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.
Note
By design, the sample primitive is meant to be used inside a NumPyro model. The seed handler is then used to inject a random state into fn. In that situation, the rng_key keyword has no effect.
Parameters: - name (str) – name of the sample site
- fn – Python callable
- obs (numpy.ndarray) – observed value
- rng_key (jax.random.PRNGKey) – an optional random key for fn.
- sample_shape – Shape of samples to be drawn.
Returns: sample from the stochastic fn.
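For illustration, a minimal sketch (not part of the original docstring): inside a seeded model no rng_key is needed, while a standalone call requires one explicitly:
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> def model():
...     return numpyro.sample('x', dist.Normal(0., 1.))
>>> x = handlers.seed(model, random.PRNGKey(0))()   # seed injects the random state
>>> y = numpyro.sample('y', dist.Normal(0., 1.), rng_key=random.PRNGKey(0))  # outside a model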
plate¶
-
class
plate
(name, size, subsample_size=None, dim=None)[source]¶ Construct for annotating conditionally independent variables. Within a plate context manager, sample sites will be automatically broadcasted to the size of the plate. Additionally, a scale factor might be applied by certain inference algorithms if subsample_size is specified.
Parameters: - name (str) – Name of the plate.
- size (int) – Size of the plate.
- subsample_size (int) – Optional argument denoting the size of the mini-batch. This can be used by inference algorithms to apply a scaling factor, e.g. when computing the ELBO using a mini-batch.
- dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the leftmost available dim is allocated.
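For illustration, a minimal sketch (not part of the original docstring) of the automatic broadcasting:
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> def model():
...     with numpyro.plate('N', 10):
...         return numpyro.sample('x', dist.Normal(0., 1.))
>>> x = handlers.seed(model, random.PRNGKey(0))()
>>> assert x.shape == (10,)   # the sample site is broadcast to the plate size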
factor¶
-
factor
(name, log_factor)[source]¶ Factor statement to add an arbitrary log probability factor to a probabilistic model.
Parameters: - name (str) – Name of the trivial sample.
- log_factor (numpy.ndarray) – A possibly batched log probability factor.
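For illustration, a minimal sketch (not part of the original docstring); the factor shows up in the trace as a trivial sample site backed by the Unit distribution:
>>> from jax import random
>>> import numpyro
>>> from numpyro import handlers
>>> def model(x):
...     numpyro.factor('penalty', -0.5 * x ** 2)   # add an arbitrary term to the joint log density
>>> exec_trace = handlers.trace(handlers.seed(model, random.PRNGKey(0))).get_trace(3.)
>>> assert 'penalty' in exec_trace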
module¶
-
module
(name, nn, input_shape=None)[source]¶ Declare a stax style neural network inside a model so that its parameters are registered for optimization via param() statements.
Parameters: - name (str) – name of the module to be registered.
- nn (tuple) – a tuple of (init_fn, apply_fn) obtained by a stax constructor.
- input_shape (tuple) – shape of the input taken by the neural network.
Returns: an apply_fn with bound parameters that takes an array as an input and returns the neural network transformed output array.
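For illustration, a minimal sketch (not part of the original docstring) declaring a small stax network; a seed handler is needed so the network parameters can be initialized:
>>> from jax import random
>>> from jax.experimental import stax
>>> import numpyro
>>> from numpyro import handlers
>>> def model(x):
...     net = numpyro.module('net', stax.serial(stax.Dense(4), stax.Relu, stax.Dense(1)),
...                          input_shape=(-1, 3))
...     return net(x)
>>> x = random.normal(random.PRNGKey(0), (5, 3))
>>> y = handlers.seed(model, random.PRNGKey(1))(x)   # y has shape (5, 1)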
Effect Handlers¶
This provides a small set of effect handlers in NumPyro that are modeled after Pyro’s poutine module. For a tutorial on effect handlers more generally, readers are encouraged to read Poutine: A Guide to Programming with Effect Handlers in Pyro. These simple effect handlers can be composed together or new ones added to enable implementation of custom inference utilities and algorithms.
Example
As an example, we use the seed, trace and substitute handlers to define the log_likelihood function below.
We first create a logistic regression model and sample from the posterior distribution over
the regression parameters using MCMC()
. The log_likelihood function
uses effect handlers to run the model by substituting sample sites with values from the posterior
distribution and computes the log density for a single data point. The log_predictive_density
function computes the log likelihood for each draw from the joint posterior and aggregates the
results for all the data points, but does so by using JAX’s auto-vectorize transform called
vmap so that we do not need to loop over all the data points.
>>> import jax.numpy as np
>>> from jax import random, vmap
>>> from jax.scipy.special import logsumexp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> from numpyro.infer import MCMC, NUTS
>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
... coefs = numpyro.sample('coefs', dist.Normal(np.zeros(D), np.ones(D)))
... intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
... logits = np.sum(coefs * data + intercept, axis=-1)
... return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
>>> data = random.normal(random.PRNGKey(0), (N, D))
>>> true_coefs = np.arange(1., D + 1.)
>>> logits = np.sum(true_coefs * data, axis=-1)
>>> labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
>>> num_warmup, num_samples = 1000, 1000
>>> mcmc = MCMC(NUTS(model=logistic_regression), num_warmup, num_samples)
>>> mcmc.run(random.PRNGKey(2), data, labels)
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]
>>> mcmc.print_summary()
mean sd 5.5% 94.5% n_eff Rhat
coefs[0] 0.96 0.07 0.85 1.07 455.35 1.01
coefs[1] 2.05 0.09 1.91 2.20 332.00 1.01
coefs[2] 3.18 0.13 2.96 3.37 320.27 1.00
intercept -0.03 0.02 -0.06 0.00 402.53 1.00
>>> def log_likelihood(rng_key, params, model, *args, **kwargs):
... model = handlers.substitute(handlers.seed(model, rng_key), params)
... model_trace = handlers.trace(model).get_trace(*args, **kwargs)
... obs_node = model_trace['obs']
... return obs_node['fn'].log_prob(obs_node['value'])
>>> def log_predictive_density(rng_key, params, model, *args, **kwargs):
... n = list(params.values())[0].shape[0]
... log_lk_fn = vmap(lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs))
... log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
... return np.sum(logsumexp(log_lk_vals, 0) - np.log(n))
>>> print(log_predictive_density(random.PRNGKey(2), mcmc.get_samples(),
... logistic_regression, data, labels))
-874.89813
block¶
-
class
block
(fn=None, hide_fn=<function block.<lambda>>)[source]¶ Bases:
numpyro.primitives.Messenger
Given a callable fn, return another callable that selectively hides primitive sites where hide_fn returns True from other effect handlers on the stack.
Parameters: - fn – Python callable with NumPyro primitives.
- hide_fn – function which when given a dictionary containing site-level metadata returns whether it should be blocked.
Example:
>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import block, seed, trace
>>> import numpyro.distributions as dist
>>> def model():
...     a = numpyro.sample('a', dist.Normal(0., 1.))
...     return numpyro.sample('b', dist.Normal(a, 1.))
>>> model = seed(model, random.PRNGKey(0))
>>> block_all = block(model)
>>> block_a = block(model, lambda site: site['name'] == 'a')
>>> trace_block_all = trace(block_all).get_trace()
>>> assert not {'a', 'b'}.intersection(trace_block_all.keys())
>>> trace_block_a = trace(block_a).get_trace()
>>> assert 'a' not in trace_block_a
>>> assert 'b' in trace_block_a
condition¶
-
class
condition
(fn=None, param_map=None, substitute_fn=None)[source]¶ Bases:
numpyro.primitives.Messenger
Conditions unobserved sample sites to values from param_map or substitute_fn. Similar to substitute, except that it only affects sample sites and changes the is_observed property to True.
Parameters: - fn – Python callable with NumPyro primitives.
- param_map (dict) – dictionary of numpy.ndarray values keyed by site names.
- substitute_fn – callable that takes in a site dict and returns a numpy array or None (in which case the handler has no side effect).
Example:
>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import condition, seed, substitute, trace
>>> import numpyro.distributions as dist
>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))
>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(condition(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1
>>> assert exec_trace['a']['is_observed']
replay¶
-
class
replay
(fn, guide_trace)[source]¶ Bases:
numpyro.primitives.Messenger
Given a callable fn and an execution trace guide_trace, return a callable which substitutes sample calls in fn with values from the corresponding site names in guide_trace.
Parameters: - fn – Python callable with NumPyro primitives.
- guide_trace – an OrderedDict containing execution metadata.
Example
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import replay, seed, trace
>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))
>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> print(exec_trace['a']['value'])
-0.20584235
>>> replayed_trace = trace(replay(model, exec_trace)).get_trace()
>>> print(replayed_trace['a']['value'])
-0.20584235
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
scale¶
-
class
scale
(fn=None, scale_factor=1.0)[source]¶ Bases:
numpyro.primitives.Messenger
This messenger rescales the log probability score.
This is typically used for data subsampling or for stratified sampling of data (e.g. in fraud detection where negatives vastly outnumber positives).
Parameters: scale_factor (float) – a positive scaling factor
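For illustration, a minimal sketch (not part of the original docstring), assuming the handler records the factor in the site's scale field:
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro import handlers
>>> def model():
...     numpyro.sample('x', dist.Normal(0., 1.))
>>> seeded = handlers.seed(model, random.PRNGKey(0))
>>> tr = handlers.trace(handlers.scale(seeded, scale_factor=10.)).get_trace()
>>> assert tr['x']['scale'] == 10.   # the log probability of 'x' is multiplied by 10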
seed¶
-
class
seed
(fn=None, rng_seed=None, rng=None)[source]¶ Bases:
numpyro.primitives.Messenger
JAX uses a functional pseudo random number generator that requires passing a seed PRNGKey() to every stochastic function. The seed handler allows us to initially seed a stochastic function with a PRNGKey(). Every call to the sample() primitive inside the function results in a splitting of this initial seed, so that a fresh seed is used for each subsequent call without having to explicitly pass in a PRNGKey to each sample call.
Parameters: - fn – Python callable with NumPyro primitives.
- rng_seed (int, np.ndarray scalar, or jax.random.PRNGKey) – a random number generator seed.
Note
Unlike in Pyro, the numpyro.sample primitive cannot be used without wrapping it in a seed handler, since there is no global random state. As such, users need to use seed as a context manager to generate samples from distributions, or as a decorator for their model callable (see below).
Example:
>>> from jax import random
>>> import numpyro
>>> from numpyro import handlers
>>> import numpyro.distributions as dist
>>> # as context manager
>>> with handlers.seed(rng_seed=1):
...     x = numpyro.sample('x', dist.Normal(0., 1.))
>>> def model():
...     return numpyro.sample('y', dist.Normal(0., 1.))
>>> # as function decorator (/modifier)
>>> y = handlers.seed(model, rng_seed=1)()
>>> assert x == y
substitute¶
-
class
substitute
(fn=None, param_map=None, base_param_map=None, substitute_fn=None)[source]¶ Bases:
numpyro.primitives.Messenger
Given a callable fn and a dict param_map keyed by site names (alternatively, a callable substitute_fn), return a callable which substitutes all primitive calls in fn with values from param_map whose key matches the site name. If the site name is not present in param_map, there is no side effect.
If a substitute_fn is provided, then the value at the site is replaced by the value returned from the call to substitute_fn for the given site.
Parameters: - fn – Python callable with NumPyro primitives.
- param_map (dict) – dictionary of numpy.ndarray values keyed by site names.
- base_param_map (dict) – similar to param_map but only holds samples from base distributions.
- substitute_fn – callable that takes in a site dict and returns a numpy array or None (in which case the handler has no side effect).
Example:
>>> from jax import random
>>> import numpyro
>>> from numpyro.handlers import seed, substitute, trace
>>> import numpyro.distributions as dist
>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))
>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1
trace¶
-
class
trace
(fn=None)[source]¶ Bases:
numpyro.primitives.Messenger
Returns a handler that records the inputs and outputs at primitive calls inside fn.
Example
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.handlers import seed, trace
>>> import pprint as pp
>>> def model():
...     numpyro.sample('a', dist.Normal(0., 1.))
>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> pp.pprint(exec_trace)
OrderedDict([('a',
              {'args': (),
               'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
               'is_observed': False,
               'kwargs': {'rng_key': DeviceArray([0, 0], dtype=uint32)},
               'name': 'a',
               'type': 'sample',
               'value': DeviceArray(-0.20584235, dtype=float32)})])
Base Distribution¶
Distribution¶
-
class
Distribution
(batch_shape=(), event_shape=(), validate_args=None)[source]¶ Bases:
object
Base class for probability distributions in NumPyro. The design largely follows from
torch.distributions
.Parameters: - batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
- event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
- validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
As an example:
>>> import jax.numpy as np
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(np.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)
-
arg_constraints
= {}¶
-
support
= None¶
-
reparametrized_params
= []¶
-
batch_shape
¶ Returns the shape over which the distribution parameters are batched.
Returns: batch shape of the distribution. Return type: tuple
-
event_shape
¶ Returns the shape of a single sample from the distribution without batching.
Returns: event shape of the distribution. Return type: tuple
-
sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters: - key (jax.random.PRNGKey) – the rng key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:
-
sample_with_intermediates
(key, sample_shape=())[source]¶ Same as
sample
, except that any intermediate computations are returned (useful for TransformedDistribution).
Parameters: - key (jax.random.PRNGKey) – the rng key to be used for the distribution.
- sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:
-
log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array with shape value.shape[:-len(self.event_shape)], i.e. the event dimensions are collapsed. Return type: numpy.ndarray
-
mean
¶ Mean of the distribution.
-
variance
¶ Variance of the distribution.
-
to_event
(reinterpreted_batch_ndims=None)[source]¶ Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
Parameters: reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims. Returns: An instance of Independent distribution. Return type: Independent
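For illustration, a minimal sketch (not part of the original docstring) of to_event moving batch dimensions into the event shape:
>>> import jax.numpy as np
>>> import numpyro.distributions as dist
>>> d = dist.Normal(np.zeros((3, 4)), 1.).to_event(1)
>>> d.batch_shape, d.event_shape
((3,), (4,))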
Independent¶
-
class
Independent
(base_dist, reinterpreted_batch_ndims, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.
From a practical standpoint, this is useful when changing the result of
log_prob()
. For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:
>>> import jax.numpy as np
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(np.zeros(3), np.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
Parameters: - base_distribution (numpyro.distribution.Distribution) – a distribution instance.
- reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.
-
arg_constraints
= {}¶
-
support
¶
-
reparametrized_params
¶
-
mean
¶
-
variance
¶
TransformedDistribution¶
-
class
TransformedDistribution
(base_distribution, transforms, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see
LogNormal and HalfNormal.
Parameters: - base_distribution – the base distribution over which to apply transforms.
- transforms – a single transform or a list of transforms.
- validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
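For illustration, a minimal sketch (not part of the original docstring) building a log-normal distribution by pushing a Normal through an ExpTransform:
>>> from jax import random
>>> import numpyro.distributions as dist
>>> from numpyro.distributions.transforms import ExpTransform
>>> log_normal = dist.TransformedDistribution(dist.Normal(0., 1.), ExpTransform())
>>> x = log_normal.sample(random.PRNGKey(0))
>>> assert x > 0   # the transform maps real draws to the positive reals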
-
arg_constraints
= {}¶
-
support
¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
sample_with_intermediates
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample_with_intermediates()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
mean
¶
-
variance
¶
Unit¶
-
class
Unit
(log_factor, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Trivial nonnormalized distribution representing the unit type.
The unit type has a single value with no data, i.e.
value.size == 0
.This is used for
numpyro.factor()
statements.
-
arg_constraints
= {'log_factor': <numpyro.distributions.constraints._Real object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
Continuous Distributions¶
Beta¶
-
class
Beta
(concentration1, concentration0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Interval object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
Cauchy¶
-
class
Cauchy
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
Chi2¶
-
class
Chi2
(df, validate_args=None)[source]¶ Bases:
numpyro.distributions.continuous.Gamma
-
arg_constraints
= {'df': <numpyro.distributions.constraints._GreaterThan object>}¶
-
Dirichlet¶
-
class
Dirichlet
(concentration, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Simplex object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
Exponential¶
-
class
Exponential
(rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
reparametrized_params
= ['rate']¶
-
arg_constraints
= {'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
Gamma¶
-
class
Gamma
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
reparametrized_params
= ['rate']¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
GaussianRandomWalk¶
-
class
GaussianRandomWalk
(scale=1.0, num_steps=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'num_steps': <numpyro.distributions.constraints._IntegerGreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._RealVector object>¶
-
reparametrized_params
= ['scale']¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
HalfCauchy¶
-
class
HalfCauchy
(scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
reparametrized_params
= ['scale']¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
arg_constraints
= {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
HalfNormal¶
-
class
HalfNormal
(scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
reparametrized_params
= ['scale']¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
arg_constraints
= {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
InverseGamma¶
-
class
InverseGamma
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._GreaterThan object>¶
-
reparametrized_params
= ['rate']¶
-
LKJ¶
-
class
LKJ
(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
LKJ distribution for correlation matrices. The distribution is controlled by
concentration
parameter \(\eta\) to make the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over correlation matrices.
When concentration > 1, the distribution favors samples with a large determinant. This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with a small determinant. This is useful when we know a priori that some underlying variables are correlated.
Parameters: - dimension (int) – dimension of the matrices
- concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
- sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices. But they are different in how to generate samples. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
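For illustration, a minimal sketch (not part of the original docstring) drawing a correlation matrix:
>>> import jax.numpy as np
>>> from jax import random
>>> import numpyro.distributions as dist
>>> d = dist.LKJ(dimension=3, concentration=1.)   # uniform over 3x3 correlation matrices
>>> corr = d.sample(random.PRNGKey(0))
>>> assert corr.shape == (3, 3)
>>> assert np.allclose(np.diagonal(corr), 1.)   # unit diagonal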
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._CorrMatrix object>¶
LKJCholesky¶
-
class
LKJCholesky
(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by
concentration
parameter \(\eta\) to make the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.
When concentration > 1, the distribution favors samples with large diagonal entries (hence a large determinant). This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small diagonal entries (hence a small determinant). This is useful when we know a priori that some underlying variables are correlated.
Parameters: - dimension (int) – dimension of the matrices
- concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
- sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices. But they are different in how to generate samples. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._CorrCholesky object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
LogNormal¶
-
class
LogNormal
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['loc', 'scale']¶
-
MultivariateNormal¶
-
class
MultivariateNormal
(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'covariance_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'loc': <numpyro.distributions.constraints._RealVector object>, 'precision_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'scale_tril': <numpyro.distributions.constraints._LowerCholesky object>}¶
-
support
= <numpyro.distributions.constraints._RealVector object>¶
-
reparametrized_params
= ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
LowRankMultivariateNormal¶
-
class
LowRankMultivariateNormal
(loc, cov_factor, cov_diag, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'cov_diag': <numpyro.distributions.constraints._GreaterThan object>, 'cov_factor': <numpyro.distributions.constraints._Real object>, 'loc': <numpyro.distributions.constraints._RealVector object>}¶
-
support
= <numpyro.distributions.constraints._RealVector object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
Normal¶
-
class
Normal
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
Pareto¶
-
class
Pareto
(alpha, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
¶
-
StudentT¶
-
class
StudentT
(df, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
reparametrized_params
= ['loc', 'scale']¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
TruncatedCauchy¶
-
class
TruncatedCauchy
(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['low', 'loc', 'scale']¶
-
TruncatedNormal¶
-
class
TruncatedNormal
(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
-
reparametrized_params
= ['low', 'loc', 'scale']¶
-
Uniform¶
-
class
Uniform
(low=0.0, high=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
-
arg_constraints
= {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}¶
-
reparametrized_params
= ['low', 'high']¶
-
Discrete Distributions¶
BernoulliLogits¶
-
class
BernoulliLogits
(logits=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>}¶
-
support
= <numpyro.distributions.constraints._Boolean object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
BernoulliProbs¶
-
class
BernoulliProbs
(probs, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'probs': <numpyro.distributions.constraints._Interval object>}¶
-
support
= <numpyro.distributions.constraints._Boolean object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
BetaBinomial¶
-
class
BetaBinomial
(concentration1, concentration0, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and randomly drawn from a Beta distribution prior to a certain number of Bernoulli trials given by total_count.
Parameters: - concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
- concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
- total_count (numpy.ndarray) – number of Bernoulli trials.
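For illustration, a minimal sketch (not part of the original docstring):
>>> from jax import random
>>> import numpyro.distributions as dist
>>> d = dist.BetaBinomial(concentration1=2., concentration0=5., total_count=10)
>>> counts = d.sample(random.PRNGKey(0), (3,))   # integer counts in {0, ..., 10}
>>> assert counts.shape == (3,)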
-
arg_constraints
= {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
log_prob
(*args, **kwargs)¶
-
mean
¶
-
variance
¶
-
support
¶
BinomialLogits¶
-
class
BinomialLogits
(logits, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
support
¶
-
BinomialProbs¶
-
class
BinomialProbs
(probs, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
support
¶
-
CategoricalLogits¶
-
class
CategoricalLogits
(logits, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
support
¶
-
CategoricalProbs¶
-
class
CategoricalProbs
(probs, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'probs': <numpyro.distributions.constraints._Simplex object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
support
¶
-
Delta¶
-
class
Delta
(value=0.0, log_density=0.0, event_ndim=0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'log_density': <numpyro.distributions.constraints._Real object>, 'value': <numpyro.distributions.constraints._Real object>}¶
-
support
= <numpyro.distributions.constraints._Real object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
GammaPoisson¶
-
class
GammaPoisson
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and randomly drawn from a Gamma distribution.
Parameters: - concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
- rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
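For illustration, a minimal sketch (not part of the original docstring); for a gamma-Poisson mixture the mean is concentration / rate and the variance is concentration * (rate + 1) / rate**2:
>>> import numpyro.distributions as dist
>>> d = dist.GammaPoisson(concentration=2., rate=1.)
>>> assert d.mean == 2. and d.variance == 4.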
-
arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
-
log_prob
(*args, **kwargs)¶
-
mean
¶
-
variance
¶
MultinomialLogits¶
-
class
MultinomialLogits
(logits, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
support
¶
-
MultinomialProbs¶
-
class
MultinomialProbs
(probs, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
support
¶
-
OrderedLogistic¶
-
class
OrderedLogistic
(predictor, cutpoints, validate_args=None)[source]¶ Bases:
numpyro.distributions.discrete.CategoricalProbs
A categorical distribution with ordered outcomes.
References:
- Stan Functions Reference, v2.20 section 12.6, Stan Development Team
Parameters: - predictor (numpy.ndarray) – prediction in real domain; typically this is output of a linear model.
- cutpoints (numpy.ndarray) – positions in real domain to separate categories.
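For illustration, a minimal sketch (not part of the original docstring); with K cutpoints the distribution has K + 1 ordered categories:
>>> import jax.numpy as np
>>> from jax import random
>>> import numpyro.distributions as dist
>>> cutpoints = np.array([-1., 0., 1.])
>>> d = dist.OrderedLogistic(predictor=0.5, cutpoints=cutpoints)
>>> y = d.sample(random.PRNGKey(0))   # an integer category in {0, 1, 2, 3}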
-
arg_constraints
= {'cutpoints': <numpyro.distributions.constraints._OrderedVector object>, 'predictor': <numpyro.distributions.constraints._Real object>}¶
Poisson¶
-
class
Poisson
(rate, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
-
arg_constraints
= {'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
-
sample
(key, sample_shape=())[source]¶ See
numpyro.distributions.distribution.Distribution.sample()
-
log_prob
(*args, **kwargs)¶ See
numpyro.distributions.distribution.Distribution.log_prob()
-
PRNGIdentity¶
ZeroInflatedPoisson¶
-
class
ZeroInflatedPoisson
(gate, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
A Zero Inflated Poisson distribution.
Parameters: - gate (numpy.ndarray) – probability of extra zeros.
- rate (numpy.ndarray) – rate of Poisson distribution.
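For illustration, a minimal sketch (not part of the original docstring); with gate=0.5 roughly half of the draws are extra zeros:
>>> from jax import random
>>> import numpyro.distributions as dist
>>> d = dist.ZeroInflatedPoisson(gate=0.5, rate=3.)
>>> x = d.sample(random.PRNGKey(0), (1000,))
>>> assert (x == 0).mean() > 0.4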
-
arg_constraints
= {'gate': <numpyro.distributions.constraints._Interval object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
-
support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
-
log_prob
(*args, **kwargs)¶
Constraints¶
nonnegative_integer¶
-
nonnegative_integer
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
positive_definite¶
-
positive_definite
= <numpyro.distributions.constraints._PositiveDefinite object>¶
Transforms¶
Transform¶
AbsTransform¶
AffineTransform¶
ComposeTransform¶
CorrCholeskyTransform¶
-
class
CorrCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonal entries and unit Euclidean norm for each row. The transform proceeds as follows:
1. First we convert \(x\) into a lower triangular matrix with the following order:
\[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]
2. For each row \(X_i\) of the lower triangular part, we apply a signed version of the StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:
- Scale into the interval \((-1, 1)\): \(r_i = \tanh(X_i)\).
- Transform into an unsigned domain: \(z_i = r_i^2\).
- Apply \(s_i = StickBreakingTransform(z_i)\).
- Transform back into the signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).
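For illustration, a minimal sketch (not part of the original docstring); a vector of length D*(D-1)/2 = 6 maps to a 4 x 4 Cholesky factor with unit-norm rows:
>>> import jax.numpy as np
>>> from numpyro.distributions.transforms import CorrCholeskyTransform
>>> t = CorrCholeskyTransform()
>>> L = t(np.zeros(6))
>>> assert L.shape == (4, 4)
>>> assert np.allclose((L ** 2).sum(-1), 1.)   # each row has unit Euclidean norm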
-
domain
= <numpyro.distributions.constraints._RealVector object>¶
-
codomain
= <numpyro.distributions.constraints._CorrCholesky object>¶
-
event_dim
= 2¶
ExpTransform¶
IdentityTransform¶
InvCholeskyTransform¶
-
class
InvCholeskyTransform
(domain=<numpyro.distributions.constraints._LowerCholesky object>)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = x @ x.T\), where x is a lower triangular matrix with positive diagonal.
-
event_dim
= 2¶
-
codomain
¶
-
LowerCholeskyTransform¶
MultivariateAffineTransform¶
-
class
MultivariateAffineTransform
(loc, scale_tril)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = loc + scale\_tril\ @\ x\).
Parameters: - loc – a real vector.
- scale_tril – a lower triangular matrix with positive diagonal.
-
domain
= <numpyro.distributions.constraints._RealVector object>¶
-
codomain
= <numpyro.distributions.constraints._RealVector object>¶
-
event_dim
= 1¶
OrderedTransform¶
-
class
OrderedTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform a real vector to an ordered vector.
References:
- Stan Reference Manual v2.20, section 10.6, Stan Development Team
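For illustration, a minimal sketch (not part of the original docstring); the output vector is strictly increasing regardless of the input:
>>> import jax.numpy as np
>>> from numpyro.distributions.transforms import OrderedTransform
>>> y = OrderedTransform()(np.array([1., -2., 3.]))
>>> assert np.all(y[1:] > y[:-1])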
-
domain
= <numpyro.distributions.constraints._RealVector object>¶
-
codomain
= <numpyro.distributions.constraints._OrderedVector object>¶
-
event_dim
= 1¶
PermuteTransform¶
PowerTransform¶
SigmoidTransform¶
Flows¶
InverseAutoregressiveTransform¶
-
class
InverseAutoregressiveTransform
(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]¶ Bases:
numpyro.distributions.transforms.Transform
An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,
\(\mathbf{y} = \mu_t + \sigma_t \odot \mathbf{x}\)
where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t, \sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t > 0\).
References
- Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
-
domain
= <numpyro.distributions.constraints._RealVector object>¶
-
codomain
= <numpyro.distributions.constraints._RealVector object>¶
-
event_dim
= 1¶
-
inv
(y)[source]¶ Parameters: y (numpy.ndarray) – the output of the transform to be inverted
-
log_abs_det_jacobian
(x, y, intermediates=None)[source]¶ Calculates the elementwise log absolute determinant of the Jacobian.
Parameters: - x (numpy.ndarray) – the input to the transform
- y (numpy.ndarray) – the output of the transform
Markov Chain Monte Carlo (MCMC)¶
Hamiltonian Monte Carlo¶
-
class
MCMC
(sampler, num_warmup, num_samples, num_chains=1, constrain_fn=None, chain_method='parallel', progress_bar=True, jit_model_args=False)[source]¶ Bases:
object
Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro.
Note
chain_method is an experimental arg, which might be removed in a future version.
Note
Setting progress_bar=False will improve the speed for many cases.
Parameters: - sampler (MCMCKernel) – an instance of
MCMCKernel
that determines the sampler for running MCMC. Currently, only HMC and NUTS are available.
- num_warmup (int) – Number of warmup steps.
- num_samples (int) – Number of samples to generate from the Markov chain.
- num_chains (int) – Number of MCMC chains to run. By default,
chains will be run in parallel using
jax.pmap()
, failing which, chains will be run in sequence. - constrain_fn – Callable that converts a collection of unconstrained sample values returned from the sampler to constrained values that lie within the support of the sample sites.
- chain_method (str) – One of ‘parallel’ (default), ‘sequential’, ‘vectorized’. The method ‘parallel’ is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for ‘parallel’, we fall back to the ‘sequential’ method to draw chains sequentially. The ‘vectorized’ method is an experimental feature that vectorizes the drawing method, hence allowing us to collect samples in parallel on a single device.
- progress_bar (bool) – Whether to enable progress bar updates. Defaults to
True
- jit_model_args (bool) – If set to True, this will compile the potential energy computation as a function of model arguments. As such, calling MCMC.run again on a dataset of the same size but with different values will not result in additional compilation cost.
-
warmup
(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)[source]¶ Run the MCMC warmup adaptation phase. After this call, the
run()
method will skip the warmup adaptation phase. To run warmup again for new data, it is required to run warmup() again.
Parameters: - rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.
- args – Arguments to be provided to the
numpyro.infer.mcmc.MCMCKernel.init()
method. These are typically the arguments needed by the model. - extra_fields (tuple or list) – Extra fields (aside from z, diverging) from
numpyro.infer.mcmc.HMCState
to collect during the MCMC run. - collect_warmup (bool) – Whether to collect samples from the warmup phase. Defaults to False.
- init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
- kwargs – Keyword arguments to be provided to the
numpyro.infer.mcmc.MCMCKernel.init()
method. These are typically the keyword arguments needed by the model.
-
run
(rng_key, *args, extra_fields=(), init_params=None, **kwargs)[source]¶ Run the MCMC samplers and collect samples.
Parameters: - rng_key (random.PRNGKey) – Random number generator key to be used for the sampling. For multi-chains, a batch of num_chains keys can be supplied. If rng_key does not have batch_size, it will be split in to a batch of num_chains keys.
- args – Arguments to be provided to the
numpyro.infer.mcmc.MCMCKernel.init()
method. These are typically the arguments needed by the model. - extra_fields (tuple or list) – Extra fields (aside from z, diverging) from
numpyro.infer.mcmc.HMCState
to collect during the MCMC run. - init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
- kwargs – Keyword arguments to be provided to the
numpyro.infer.mcmc.MCMCKernel.init()
method. These are typically the keyword arguments needed by the model.
-
get_samples
(group_by_chain=False)[source]¶ Get samples from the MCMC run.
Parameters: group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension. Returns: Samples having the same data type as init_params. The data type is a dict keyed on site names if a model containing Pyro primitives is used, but can be any jaxlib.pytree()
, more generally (e.g. when defining a potential_fn for HMC that takes list args).
-
get_extra_fields
(group_by_chain=False)[source]¶ Get extra fields from the MCMC run.
Parameters: group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension. Returns: Extra fields keyed by field names which are specified in the extra_fields keyword of run()
.
- sampler (MCMCKernel) – an instance of
-
class
HMC
(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, init_strategy=functools.partial(<function _init_to_uniform>, radius=2))[source]¶ Bases:
numpyro.infer.mcmc.MCMCKernel
Hamiltonian Monte Carlo inference, using fixed trajectory length, with provision for step size and mass matrix adaptation.
References:
- MCMC Using Hamiltonian Dynamics, Radford M. Neal
Parameters: - model – Python callable containing Pyro
primitives
. If model is provided, potential_fn will be inferred using the model. - potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init_kernel has the same type.
- kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
- step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
- adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
- adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
- dense_mass (bool) – A flag to decide if mass matrix is dense or
diagonal (default when
dense_mass=False
) - target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8.
- trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
- init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
-
model
¶
-
sample
(state, model_args, model_kwargs)[source]¶ Run HMC from the given
HMCState
and return the resultingHMCState
.Parameters: - state (HMCState) – Represents the current state.
- model_args – Arguments provided to the model.
- model_kwargs – Keyword arguments provided to the model.
Returns: Next state after running HMC.
-
class
NUTS
(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=None, max_tree_depth=10, init_strategy=functools.partial(<function _init_to_uniform>, radius=2))[source]¶ Bases:
numpyro.infer.mcmc.HMC
Hamiltonian Monte Carlo inference, using the No-U-Turn Sampler (NUTS) with adaptive path length and mass matrix adaptation.
References:
- MCMC Using Hamiltonian Dynamics, Radford M. Neal
- The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, and Andrew Gelman.
- A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters: - model – Python callable containing Pyro
primitives
. If model is provided, potential_fn will be inferred using the model. - potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init_kernel has the same type.
- kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
- step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
- adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
- adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
- dense_mass (bool) – A flag to decide if mass matrix is dense or
diagonal (default when
dense_mass=False
) - target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8.
- trajectory_length (float) – Length of an MCMC trajectory for HMC. This arg has no effect in the NUTS sampler.
- max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10.
- init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
-
hmc
(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS')[source]¶ Hamiltonian Monte Carlo inference, using either fixed number of steps or the No U-Turn Sampler (NUTS) with adaptive path length.
References:
- MCMC Using Hamiltonian Dynamics, Radford M. Neal
- The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, and Andrew Gelman.
- A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters: - potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init_kernel has the same type.
- potential_fn_gen – Python callable that when provided with model arguments / keyword arguments returns potential_fn. This may be provided to do inference on the same model with changing data. If the data shape remains the same, we can compile sample_kernel once, and use the same for multiple inference runs.
- kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
- algo (str) – Whether to run HMC with a fixed number of steps or NUTS with adaptive path length. Default is NUTS.
Returns: a tuple of callables (init_kernel, sample_kernel), the first one to initialize the sampler, and the second one to generate samples given an existing one.
Warning
Instead of using this interface directly, we highly recommend using the higher-level numpyro.infer.MCMC API instead.
Example
>>> import jax
>>> from jax import random
>>> import jax.numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer.mcmc import hmc
>>> from numpyro.infer.util import initialize_model
>>> from numpyro.util import fori_collect
>>> true_coefs = np.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(2), (2000, 3))
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs_mean = np.zeros(dim)
...     coefs = numpyro.sample('beta', dist.Normal(coefs_mean, np.ones(3)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
...                                                            model, model_args=(data, labels,))
>>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
>>> hmc_state = init_kernel(init_params,
...                         trajectory_length=10,
...                         num_warmup=300)
>>> samples = fori_collect(0, 500, sample_kernel, hmc_state,
...                        transform=lambda state: constrain_fn(state.z))
>>> print(np.mean(samples['beta'], axis=0))
[0.9153987 2.0754058 2.9621222]
-
init_kernel
(init_params, num_warmup, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, max_tree_depth=10, model_args=(), model_kwargs=None, rng_key=DeviceArray([0, 0], dtype=uint32))¶ Initializes the HMC sampler.
Parameters: - init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
- num_warmup (int) – Number of warmup steps; samples generated during warmup are discarded.
- step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
- inverse_mass_matrix (numpy.ndarray) – Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix.
- adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
- adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
- dense_mass (bool) – A flag to decide if mass matrix is dense or
diagonal (default when
dense_mass=False
) - target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Default to 0.8.
- trajectory_length (float) – Length of an MCMC trajectory for HMC. Default value is \(2\pi\).
- max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10.
- model_args (tuple) – Model arguments if potential_fn_gen is specified.
- model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
- rng_key (jax.random.PRNGKey) – random key to be used as the source of randomness.
-
sample_kernel
(hmc_state, model_args=(), model_kwargs=None)¶ Given an existing HMCState, run HMC with a fixed (possibly adapted) step size and return a new HMCState.
Parameters: - hmc_state – current HMCState from which to simulate Hamiltonian dynamics.
- model_args (tuple) – Model arguments if potential_fn_gen is specified.
- model_kwargs (dict) – Model keyword arguments if potential_fn_gen is specified.
Returns: new proposed HMCState from simulating Hamiltonian dynamics given the existing state.
-
HMCState
= <class 'numpyro.infer.mcmc.HMCState'>¶ A namedtuple() consisting of the following fields (a sketch showing how to collect some of these diagnostics follows the list):
- i - iteration. This is reset to 0 after warmup.
- z - Python collection representing values (unconstrained samples from the posterior) at latent sites.
- z_grad - Gradient of potential energy w.r.t. latent sample sites.
- potential_energy - Potential energy computed at the given value of z.
- energy - Sum of potential energy and kinetic energy of the current state.
- num_steps - Number of steps in the Hamiltonian trajectory (for diagnostics).
- accept_prob - Acceptance probability of the proposal. Note that z does not correspond to the proposal if it is rejected.
- mean_accept_prob - Mean acceptance probability until the current iteration during warmup adaptation or sampling (for diagnostics).
- diverging - A boolean value to indicate whether the current trajectory is diverging.
- adapt_state - An AdaptState namedtuple which contains adaptation information during warmup:
- step_size - Step size to be used by the integrator in the next iteration.
- inverse_mass_matrix - The inverse mass matrix to be used for the next iteration.
- mass_matrix_sqrt - The square root of mass matrix to be used for the next iteration. In case of dense mass, this is the Cholesky factorization of the mass matrix.
- rng_key - random number generator seed used for the iteration.
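Since sample_kernel returns a full HMCState, the transform argument of fori_collect() can pull diagnostic fields out alongside the samples. A minimal sketch, reusing sample_kernel, hmc_state and constrain_fn from the hmc() example above, and assuming transform may return any collection of arrays (fori_collect() stacks each array along its leading axis):

>>> collected = fori_collect(0, 500, sample_kernel, hmc_state,
...                          transform=lambda state: (constrain_fn(state.z),
...                                                   state.accept_prob,
...                                                   state.diverging))
>>> samples, accept_probs, diverging = collected
>>> print(np.mean(accept_probs))  # mean acceptance probability over collected iterations
>>> print(np.sum(diverging))      # number of divergent transitions (ideally 0)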
MCMC Utilities¶
-
initialize_model
(rng_key, model, init_strategy=functools.partial(<function _init_to_uniform>, radius=2), dynamic_args=False, model_args=(), model_kwargs=None)[source]¶ (EXPERIMENTAL INTERFACE) Helper function that calls get_potential_fn() and find_valid_initial_params() under the hood to return a tuple of (init_params, potential_fn, constrain_fn).
Parameters: - rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
- model – Python callable containing Pyro primitives.
- init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
- dynamic_args (bool) – if True, the returned potential_fn and constraints_fn are themselves dependent on the model arguments: when called with *model_args, **model_kwargs, they return the potential_fn and constraints_fn callables, respectively.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
Returns: tuple of (init_params, potential_fn, constrain_fn), where init_params are values sampled from the prior used to initiate MCMC, and constrain_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within each site's support.
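With dynamic_args=True, the second and third elements of the returned tuple must first be applied to model arguments before use. A minimal sketch under that reading of the dynamic_args parameter, reusing model, data and labels from the hmc() example above:

>>> init_params, potential_fn_gen, constrain_fn_gen = initialize_model(
...     random.PRNGKey(0), model, dynamic_args=True, model_args=(data, labels))
>>> potential_fn = potential_fn_gen(data, labels)   # bind the current arguments
>>> constrain_fn = constrain_fn_gen(data, labels)
>>> print(potential_fn(init_params))  # scalar potential energy at the initial point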
-
fori_collect
(lower, upper, body_fun, init_val, transform=<function identity>, progbar=True, return_last_val=False, collection_size=None, **progbar_opts)[source]¶ This looping construct works like fori_loop() but with the additional effect of collecting values from the loop body. It also allows for post-processing of these samples via transform and for progress bar updates. Note that progbar=False will be faster, especially when collecting a large number of samples. Refer to example usage in hmc().
Parameters: - lower (int) – the index at which to start collecting; in other words, the first lower values are skipped.
- upper (int) – number of times to run the loop body.
- body_fun – a callable that takes a collection of np.ndarray and returns a collection with the same shape and dtype.
- init_val – initial value to pass as argument to body_fun. Can be any Python collection type containing np.ndarray objects.
- transform – a callable to post-process the values returned by body_fun.
- progbar – whether to post progress bar updates.
- return_last_val (bool) – If True, the last value is also returned. This has the same type as init_val.
- collection_size (int) – Size of the returned collection. If not specified, the size will be upper - lower. If the specified size is larger than upper - lower, only the top upper - lower entries will be non-zero.
- **progbar_opts – optional additional progress bar arguments. A diagnostics_fn can be supplied which, when passed the current value from body_fun, returns a string that is used to update the progress bar postfix. A progbar_desc keyword argument can also be supplied, which is used to label the progress bar.
Returns: collection with the same type as init_val with values collected along the leading axis of np.ndarray objects.
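Outside of MCMC, fori_collect() can collect the iterates of any loop. A minimal, self-contained sketch that collects a running sum while skipping the first two iterations:

>>> import jax.numpy as np
>>> from numpyro.util import fori_collect
>>> # body_fun maps a state to a new state with the same shape and dtype;
>>> # here the state is a single scalar array holding a running sum.
>>> collected = fori_collect(2, 5, lambda x: x + 1., np.array(0.), progbar=False)
>>> print(collected)  # values after iterations 3, 4 and 5: [3. 4. 5.]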
-
consensus
(subposteriors, num_draws=None, diagonal=False, rng_key=None)[source]¶ Merges subposteriors following the consensus Monte Carlo algorithm.
References:
- Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fernando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch
Parameters: - subposteriors (list) – a list in which each element is a collection of samples.
- num_draws (int) – number of draws from the merged posterior.
- diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
- rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns: if num_draws is None, merges subposteriors without resampling; otherwise, returns a collection of num_draws samples with the same data structure as each subposterior.
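A minimal sketch of merging two subposteriors, each represented here as a dict holding samples of the same latent site (the numpyro.infer.mcmc import path is assumed for this version; in practice each subposterior would come from running MCMC on a shard of the data):

>>> from jax import random
>>> import jax.numpy as np
>>> from numpyro.infer.mcmc import consensus  # import path assumed
>>> sub1 = {'beta': random.normal(random.PRNGKey(0), (1000, 3))}
>>> sub2 = {'beta': random.normal(random.PRNGKey(1), (1000, 3)) + 1.}
>>> merged = consensus([sub1, sub2], num_draws=500, rng_key=random.PRNGKey(2))
>>> print(merged['beta'].shape)  # (500, 3)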
-
parametric
(subposteriors, diagonal=False)[source]¶ Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.
References:
- Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters: - subposteriors (list) – a list in which each element is a collection of samples.
- diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
Returns: the estimated mean and variance/covariance parameters of the joined posterior
-
parametric_draws
(subposteriors, num_draws, diagonal=False, rng_key=None)[source]¶ Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm and draws samples from the merged posterior.
References:
- Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters: - subposteriors (list) – a list in which each element is a collection of samples.
- num_draws (int) – number of draws from the merged posterior.
- diagonal (bool) – whether to compute weights using variance or covariance, defaults to False (using covariance).
- rng_key (jax.random.PRNGKey) – source of the randomness, defaults to jax.random.PRNGKey(0).
Returns: a collection of num_draws samples with the same data structure as each subposterior.
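parametric() returns the Gaussian parameters of the merged posterior, while parametric_draws() resamples from it. A minimal sketch reusing sub1 and sub2 from the consensus example above (the numpyro.infer.mcmc import path is again assumed):

>>> from numpyro.infer.mcmc import parametric, parametric_draws  # import path assumed
>>> mean, cov = parametric([sub1, sub2], diagonal=False)  # moments over the flattened latent space
>>> draws = parametric_draws([sub1, sub2], num_draws=500, rng_key=random.PRNGKey(3))
>>> print(draws['beta'].shape)  # (500, 3)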
Stochastic Variational Inference (SVI)¶
-
class
SVI
(model, guide, optim, loss, **static_kwargs)[source]¶ Bases:
object
Stochastic Variational Inference given an ELBO loss objective.
Parameters: - model – Python callable with Pyro primitives for the model.
- guide – Python callable with Pyro primitives for the guide (recognition network).
- optim – an instance of _NumpyroOptim.
- loss – ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
- static_kwargs – static arguments for the model / guide, i.e. arguments that remain constant during fitting.
-
init
(rng_key, *args, **kwargs)[source]¶ Parameters: - rng_key (jax.random.PRNGKey) – random number generator seed.
- args – arguments to the model / guide (these can possibly vary during the course of fitting).
- kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns: tuple containing the initial SVIState and get_params, a callable that transforms unconstrained parameter values from the optimizer to the specified constrained domain.
-
get_params
(svi_state)[source]¶ Gets values at param sites of the model and guide.
Parameters: svi_state – current state of the optimizer.
-
update
(svi_state, *args, **kwargs)[source]¶ Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer.
Parameters: - svi_state – current state of SVI.
- args – arguments to the model / guide (these can possibly vary during the course of fitting).
- kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns: tuple of (svi_state, loss).
-
evaluate
(svi_state, *args, **kwargs)[source]¶ Evaluate the ELBO loss at the current parameter values (possibly on a batch / minibatch of data), without taking a gradient step.
Parameters: - svi_state – current state of SVI.
- args – arguments to the model / guide (these can possibly vary during the course of fitting).
- kwargs – keyword arguments to the model / guide.
Returns: ELBO loss given the current parameter values (held within svi_state.optim_state).
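Putting init, update and get_params together, a typical fitting loop looks like the following minimal sketch: a toy location model with an ELBO loss (documented below). The import paths numpyro.infer and numpyro.optim are assumed for this version, and init is taken to return a (state, get_params) pair as documented above.

>>> from jax import random
>>> import jax.numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, ELBO  # import paths assumed
>>> from numpyro.optim import Adam      # import path assumed
>>>
>>> def model(data):
...     loc = numpyro.sample('loc', dist.Normal(0., 10.))
...     numpyro.sample('obs', dist.Normal(loc, 1.), obs=data)
>>>
>>> def guide(data):
...     guide_loc = numpyro.param('guide_loc', np.array(0.))
...     guide_scale = np.exp(numpyro.param('guide_log_scale', np.array(0.)))
...     numpyro.sample('loc', dist.Normal(guide_loc, guide_scale))
>>>
>>> data = random.normal(random.PRNGKey(0), (100,)) + 3.
>>> svi = SVI(model, guide, Adam(step_size=0.01), ELBO())
>>> svi_state, get_params = svi.init(random.PRNGKey(1), data)
>>> for _ in range(1000):
...     svi_state, loss = svi.update(svi_state, data)
>>> params = svi.get_params(svi_state)
>>> print(params['guide_loc'])  # should be close to 3.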
ELBO¶
-
class
ELBO
(num_particles=1)[source]¶ Bases:
object
A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide.
This is the most basic implementation of the Evidence Lower Bound, which is the fundamental objective in Variational Inference. This implementation has various limitations (for example it only supports random variables with reparameterized samplers) but can be used as a template to build more sophisticated loss objectives.
For more details, refer to http://pyro.ai/examples/svi_part_i.html.
References:
- Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber
- Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei
Parameters: num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.
-
loss
(rng_key, param_map, model, guide, *args, **kwargs)[source]¶ Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
Parameters: - rng_key (jax.random.PRNGKey) – random number generator seed.
- param_map (dict) – dictionary of current parameter values keyed by site name.
- model – Python callable with NumPyro primitives for the model.
- guide – Python callable with NumPyro primitives for the guide.
- args – arguments to the model / guide (these can possibly vary during the course of fitting).
- kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns: negative of the Evidence Lower Bound (ELBO) to be minimized.
RenyiELBO¶
-
class
RenyiELBO
(alpha=0, num_particles=2)[source]¶ Bases:
numpyro.infer.elbo.ELBO
An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1]. In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].
Note
Setting \(\alpha < 1\) gives a better bound than the usual ELBO.
Parameters: - alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.
- num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.
References:
- Renyi Divergence Variational Inference, Yingzhen Li, Richard E. Turner
- Importance Weighted Autoencoders, Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
-
loss
(rng_key, param_map, model, guide, *args, **kwargs)[source]¶ Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles.
Parameters: - rng_key (jax.random.PRNGKey) – random number generator seed.
- param_map (dict) – dictionary of current parameter values keyed by site name.
- model – Python callable with NumPyro primitives for the model.
- guide – Python callable with NumPyro primitives for the guide.
- args – arguments to the model / guide (these can possibly vary during the course of fitting).
- kwargs – keyword arguments to the model / guide (these can possibly vary during the course of fitting).
Returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
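Since RenyiELBO subclasses ELBO, it can be passed to SVI anywhere an ELBO loss is expected. A minimal sketch reusing the model/guide pair from the SVI fitting loop above (the import path is assumed):

>>> from numpyro.infer import RenyiELBO  # import path assumed
>>> svi = SVI(model, guide, Adam(step_size=0.01), RenyiELBO(alpha=0.5, num_particles=10))
>>> svi_state, get_params = svi.init(random.PRNGKey(1), data)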
Automatic Guide Generation¶
Warning
The interface for the contrib.autoguide module is experimental, and subject to frequent revisions.
AutoContinuous¶
-
class
AutoContinuous
(model, prefix='auto', init_strategy=functools.partial(<function _init_to_uniform>, radius=2))[source]¶ Bases:
numpyro.contrib.autoguide.AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
Each derived class implements its own _get_transform() method.
Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
Note
We recommend using AutoContinuousELBO as the objective function loss in SVI. In addition, we recommend using the sample_posterior() method for drawing posterior samples from the autoguide, as it will automatically do any unpacking and transformations required to constrain the samples to the support of the latent sites.
Reference:
- Automatic Differentiation Variational Inference, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
Parameters: - model (callable) – A NumPyro model.
- prefix (str) – a prefix that will be prefixed to all param internal sites.
- init_strategy (callable) – A per-site initialization function. See Initialization Strategies section for available functions.
-
base_dist
¶ Base distribution of the posterior. By default, it is standard normal.
-
get_transform
(params)[source]¶ Returns the transformation learned by the guide to generate samples from the unconstrained (approximate) posterior.
Parameters: params (dict) – Current parameters of the model and autoguide. Returns: the transform of the posterior distribution. Return type: Transform
-
sample_posterior
(rng_key, params, sample_shape=())[source]¶ Get samples from the learned posterior.
Parameters: Returns: a dict containing samples drawn from this guide.
Return type: dict
-
median
(params)[source]¶ Returns the posterior median value of each latent variable.
Parameters: params (dict) – A dict containing parameter values. Returns: A dict mapping sample site name to median tensor. Return type: dict
AutoDiagonalNormal¶
-
class
AutoDiagonalNormal
(model, prefix='auto', init_strategy=functools.partial(<function _init_to_uniform>, radius=2))[source]¶ Bases:
numpyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
guide = AutoDiagonalNormal(model, ...) svi = SVI(model, guide, ...)
AutoMultivariateNormal¶
-
class
AutoMultivariateNormal
(model, prefix='auto', init_strategy=functools.partial(<function _init_to_uniform>, radius=2))[source]¶ Bases:
numpyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a MultivariateNormal distribution to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
guide = AutoMultivariateNormal(model, ...) svi = SVI(model, guide, ...)
AutoIAFNormal¶
-
class
AutoIAFNormal
(model, prefix='auto', init_strategy=functools.partial(<function _init_to_uniform>, radius=2), num_flows=3, **arn_kwargs)[source]¶ Bases:
numpyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a diagonal Normal distribution transformed via an InverseAutoregressiveTransform to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
guide = AutoIAFNormal(model, hidden_dims=[20], skip_connections=True, ...) svi = SVI(model, guide, ...)
Parameters:
- model (callable) – a generative model.
- prefix (str) – a prefix that will be prefixed to all param internal sites.
- init_strategy (callable) – A per-site initialization function.
- num_flows (int) – the number of flows to be used, defaults to 3.
- **arn_kwargs –
keywords for constructing autoregressive neural networks, which includes:
- hidden_dims (list[int]) - the dimensionality of the hidden units per layer. Defaults to [latent_size, latent_size].
- skip_connections (bool) - whether to add skip connections from the input to the output of each flow. Defaults to False.
- nonlinearity (callable) - the nonlinearity to use in the feedforward network. Defaults to jax.experimental.stax.Relu.
AutoLaplaceApproximation¶
-
class
AutoLaplaceApproximation
(model, prefix='auto', init_strategy=functools.partial(<function _init_to_uniform>, radius=2))[source]¶ Bases:
numpyro.contrib.autoguide.AutoContinuous
Laplace approximation (quadratic approximation) approximates the posterior \(p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
Usage:
guide = AutoLaplaceApproximation(model, ...) svi = SVI(model, guide, ...)
AutoContinuousELBO¶
-
class
AutoContinuousELBO
(num_particles=1)[source]¶ Bases:
numpyro.infer.elbo.ELBO
An ELBO implementation specific to AutoContinuous guides. In these guides, the latent variables of the model are transformed to unconstrained domains. This class provides the ELBO of the “transformed” model (i.e. the corresponding model with unconstrained variables) and the guide.
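A minimal end-to-end sketch combining an autoguide with AutoContinuousELBO and sample_posterior(), reusing the model and data from the SVI fitting loop above (the numpyro.contrib.autoguide import path shown on this page is assumed):

>>> from numpyro.contrib.autoguide import AutoDiagonalNormal, AutoContinuousELBO  # import path assumed
>>> guide = AutoDiagonalNormal(model)
>>> svi = SVI(model, guide, Adam(step_size=0.01), AutoContinuousELBO())
>>> svi_state, get_params = svi.init(random.PRNGKey(2), data)
>>> for _ in range(1000):
...     svi_state, loss = svi.update(svi_state, data)
>>> params = svi.get_params(svi_state)
>>> posterior = guide.sample_posterior(random.PRNGKey(3), params, sample_shape=(1000,))
>>> print(posterior['loc'].shape)  # (1000,)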
Optimizers¶
Optimizer classes defined here are light wrappers over the corresponding optimizers
sourced from jax.experimental.optimizers
with an interface that is better
suited for working with NumPyro inference algorithms.
Adam¶
-
class
Adam
(*args, **kwargs)[source]¶ Wrapper class for the JAX optimizer:
adam()
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
update
(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]¶ Gradient update for the optimizer.
Parameters: - g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
-
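All of the wrapped optimizers share the init / update / get_params interface shown above, so they can also be exercised outside of SVI. A minimal sketch minimizing f(x) = x**2 with Adam, with a hand-computed gradient standing in for jax.grad (the numpyro.optim import path is assumed for this version):

>>> import jax.numpy as np
>>> from numpyro.optim import Adam  # import path assumed
>>> opt = Adam(step_size=0.1)
>>> state = opt.init({'x': np.array(5.)})
>>> for _ in range(200):
...     params = opt.get_params(state)
...     grads = {'x': 2. * params['x']}  # gradient of f(x) = x**2
...     state = opt.update(grads, state)
>>> print(opt.get_params(state)['x'])  # close to 0.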
Adagrad¶
-
class
Adagrad
(*args, **kwargs)[source]¶ Wrapper class for the JAX optimizer:
adagrad()
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
update
(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]¶ Gradient update for the optimizer.
Parameters: - g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
-
ClippedAdam¶
-
class
ClippedAdam
(*args, clip_norm=10.0, **kwargs)[source]¶ Adam optimizer with gradient clipping.
Parameters: clip_norm (float) – All gradient values will be clipped to the range [-clip_norm, clip_norm].
Reference:
Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba. https://arxiv.org/abs/1412.6980
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
Momentum¶
-
class
Momentum
(*args, **kwargs)[source]¶ Wrapper class for the JAX optimizer:
momentum()
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
update
(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]¶ Gradient update for the optimizer.
Parameters: - g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
-
RMSProp¶
-
class
RMSProp
(*args, **kwargs)[source]¶ Wrapper class for the JAX optimizer:
rmsprop()
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
update
(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]¶ Gradient update for the optimizer.
Parameters: - g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
-
RMSPropMomentum¶
-
class
RMSPropMomentum
(*args, **kwargs)[source]¶ Wrapper class for the JAX optimizer:
rmsprop_momentum()
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
update
(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]¶ Gradient update for the optimizer.
Parameters: - g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
-
SGD¶
-
class
SGD
(*args, **kwargs)[source]¶ Wrapper class for the JAX optimizer:
sgd()
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
update
(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]¶ Gradient update for the optimizer.
Parameters: - g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
-
SM3¶
-
class
SM3
(*args, **kwargs)[source]¶ Wrapper class for the JAX optimizer:
sm3()
-
get_params
(state: Tuple[int, _OptState]) → _Params¶ Get current parameter values.
Parameters: state – current optimizer state. Returns: collection with current value for parameters.
-
init
(params: _Params) → Tuple[int, _OptState]¶ Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays. Returns: initial optimizer state.
-
update
(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]¶ Gradient update for the optimizer.
Parameters: - g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
-
Diagnostics¶
This provides a small set of utilities in NumPyro that are used to diagnose posterior samples.
Autocorrelation¶
-
autocorrelation
(x, axis=0)[source]¶ Computes the autocorrelation of samples along dimension axis.
Parameters: - x (numpy.ndarray) – the input array.
- axis (int) – the dimension along which to calculate autocorrelation.
Returns: autocorrelation of x.
Return type: numpy.ndarray
Autocovariance¶
-
autocovariance
(x, axis=0)[source]¶ Computes the autocovariance of samples along dimension axis.
Parameters: - x (numpy.ndarray) – the input array.
- axis (int) – the dimension along which to calculate autocovariance.
Returns: autocovariance of x.
Return type: numpy.ndarray
Effective Sample Size¶
-
effective_sample_size
(x)[source]¶ Computes the effective sample size of input x, where the first dimension of x is the chain dimension and the second dimension is the draw dimension.
References:
- Introduction to Markov Chain Monte Carlo, Charles J. Geyer
- Stan Reference Manual version 2.18, Stan Development Team
Parameters: x (numpy.ndarray) – the input array.
Returns: effective sample size of x.
Return type: numpy.ndarray
Gelman Rubin¶
-
gelman_rubin
(x)[source]¶ Computes R-hat over chains of samples x, where the first dimension of x is the chain dimension and the second dimension is the draw dimension. It is required that x.shape[0] >= 2 and x.shape[1] >= 2.
Parameters: x (numpy.ndarray) – the input array.
Returns: R-hat of x.
Return type: numpy.ndarray
Split Gelman Rubin¶
-
split_gelman_rubin
(x)[source]¶ Computes split R-hat over chains of samples x, where the first dimension of x is the chain dimension and the second dimension is the draw dimension. It is required that x.shape[1] >= 4.
Parameters: x (numpy.ndarray) – the input array.
Returns: split R-hat of x.
Return type: numpy.ndarray
HPDI¶
-
hpdi
(x, prob=0.9, axis=0)[source]¶ Computes the “highest posterior density interval” (HPDI), which is the narrowest interval with probability mass prob.
Parameters: - x (numpy.ndarray) – the input array.
- prob (float) – the probability mass of samples within the interval.
- axis (int) – the dimension along which to calculate the HPDI.
Returns: the lower and upper bounds of the HPDI of x.
Return type: numpy.ndarray
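A minimal sketch exercising the diagnostics above on synthetic chains, here two chains of 1000 draws of a scalar site (the numpyro.diagnostics import path is assumed for this version):

>>> from jax import random
>>> import jax.numpy as np
>>> from numpyro.diagnostics import effective_sample_size, gelman_rubin, hpdi  # import path assumed
>>> x = random.normal(random.PRNGKey(0), (2, 1000))  # chain dim x draw dim
>>> print(effective_sample_size(x))  # large for well-mixed, independent draws
>>> print(gelman_rubin(x))           # close to 1. when chains agree
>>> print(hpdi(x.reshape(-1), prob=0.9))  # narrowest 90% interval, roughly [-1.64, 1.64]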
Summary¶
-
summary
(samples, prob=0.9, group_by_chain=True)[source]¶ Returns a summary table displaying diagnostics of samples from the posterior. The diagnostics displayed are mean, standard deviation, median, the 90% credibility interval (hpdi()), effective_sample_size(), and split_gelman_rubin().
Parameters: - samples (dict or numpy.ndarray) – a collection of input samples whose leftmost dimension is the chain dimension and whose second-leftmost dimension is the draw dimension.
- prob (float) – the probability mass of samples within the HPDI interval.
- group_by_chain (bool) – If True, each variable in samples will be treated as having shape num_chains x num_samples x sample_shape. Otherwise, the corresponding shape will be num_samples x sample_shape (i.e. without chain dimension).
-
print_summary
(samples, prob=0.9, group_by_chain=True)[source]¶ Prints a summary table displaying diagnostics of samples from the posterior. The diagnostics displayed are mean, standard deviation, median, the 90% credibility interval (hpdi()), effective_sample_size(), and split_gelman_rubin().
Parameters: - samples (dict or numpy.ndarray) – a collection of input samples whose leftmost dimension is the chain dimension and whose second-leftmost dimension is the draw dimension.
- prob (float) – the probability mass of samples within the HPDI interval.
- group_by_chain (bool) – If True, each variable in samples will be treated as having shape num_chains x num_samples x sample_shape. Otherwise, the corresponding shape will be num_samples x sample_shape (i.e. without chain dimension).
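A short sketch, reusing the synthetic chains x from the diagnostics example above (the numpyro.diagnostics import path is again assumed):

>>> from numpyro.diagnostics import print_summary  # import path assumed
>>> print_summary({'theta': x}, prob=0.9, group_by_chain=True)
>>> # prints mean, std, median, 90% HPDI, n_eff and split R-hat for site 'theta'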
Runtime Utilities¶
enable_validation¶
-
enable_validation
(is_validate=True)[source]¶ Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors, e.g. NaN checks, validation of distribution arguments and support values, etc., which is useful for debugging.
Note
This utility does not take effect under JAX’s JIT compilation or the vectorized transformation jax.vmap().
Parameters: is_validate (bool) – whether to enable validation checks.
validation_enabled¶
-
validation_enabled
(is_validate=True)[source]¶ Context manager that is useful for temporarily enabling or disabling validation checks.
enable_x64¶
-
enable_x64
(use_x64=True)[source]¶ Changes the default array type to use 64-bit precision as in NumPy.
set_platform¶
-
set_platform
(platform=None)[source]¶ Changes the platform to CPU, GPU, or TPU. This utility only takes effect at the beginning of your program.
set_host_device_count¶
-
set_host_device_count
(n)[source]¶ By default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use. As a consequence, this allows parallel mapping with jax.pmap() to work on the CPU platform.
Note
This utility only takes effect at the beginning of your program. Under the hood, this sets the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices], where [num_devices] is the desired number of CPU devices n.
Warning
Our understanding of the side effects of using the xla_force_host_platform_device_count flag in XLA is incomplete. If you observe some strange phenomenon when using this utility, please let us know through our issue or forum page. More information is available in this JAX issue.
Parameters: n (int) – number of CPU devices to use.
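Because the flag is read when the XLA backend initializes, set_host_device_count() must run before any JAX computation, typically at the very top of a script. A minimal sketch exposing four CPU devices (the numpyro.util import path is assumed for this version):

>>> from numpyro.util import set_host_device_count  # import path assumed
>>> set_host_device_count(4)  # must come before any JAX computation
>>> import jax
>>> print(jax.device_count())  # 4, so jax.pmap can map over four CPU devices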
Inference Utilities¶
Predictive¶
-
class
Predictive
(model, posterior_samples=None, guide=None, params=None, num_samples=None, return_sites=None, parallel=False)[source]¶ Bases:
object
This class is used to construct the predictive distribution. The predictive distribution is obtained by running the model conditioned on latent samples from posterior_samples.
Warning
The interface for the Predictive class is experimental, and might change in the future.
Parameters: - model – Python callable containing Pyro primitives.
- posterior_samples (dict) – dictionary of samples from the posterior.
- guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.
- params (dict) – dictionary of values for param sites of model/guide.
- num_samples (int) – number of samples
- return_sites (list) – sites to return; by default only sample sites not present in posterior_samples are returned.
- parallel (bool) – whether to predict in parallel using the JAX vectorized map jax.vmap(). Defaults to False.
Returns: dict of samples from the predictive distribution.
-
get_samples
(rng_key, *args, **kwargs)[source]¶ Returns a dict of samples from the predictive distribution. By default, only sample sites not contained in posterior_samples are returned. This can be modified by changing the return_sites keyword argument of this Predictive instance.
Parameters: - rng_key (jax.random.PRNGKey) – random key to draw samples.
- args – model arguments.
- kwargs – model kwargs.
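A minimal sketch of drawing posterior-predictive samples for the logistic regression model from the hmc() example above, reusing the samples dict collected there (the numpyro.infer.util import path is assumed for this version). Passing labels=None leaves the y site unobserved, so it is sampled:

>>> from numpyro.infer.util import Predictive  # import path assumed
>>> predictive = Predictive(model, posterior_samples=samples)
>>> preds = predictive.get_samples(random.PRNGKey(1), data, labels=None)
>>> print(preds['y'].shape)  # (500, 2000): one prediction per posterior draw and data point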
log_density¶
-
log_density
(model, model_args, model_kwargs, params, skip_dist_transforms=False)[source]¶ (EXPERIMENTAL INTERFACE) Computes the log of the joint density of the model given latent values params.
Parameters: - model – Python callable containing NumPyro primitives.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- params (dict) – dictionary of current parameter values keyed by site name.
- skip_dist_transforms (bool) – whether to compute log probability of a site (if its prior is a transformed distribution) in its base distribution domain.
Returns: log of joint density and a corresponding model trace
transform_fn¶
-
transform_fn
(transforms, params, invert=False)[source]¶ (EXPERIMENTAL INTERFACE) Callable that applies a transformation from the transforms dict to values in the params dict and returns the transformed values keyed on the same names.
Parameters: - transforms – Dictionary of transforms keyed by names. Names in transforms and params should align.
- params – Dictionary of arrays keyed by names.
- invert – Whether to apply the inverse of the transforms.
Returns: dict of transformed params.
constrain_fn¶
-
constrain_fn
(model, transforms, model_args, model_kwargs, params)[source]¶ (EXPERIMENTAL INTERFACE) Gets value at each latent site in model given unconstrained parameters params. The transforms is used to transform these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of base distribution. Otherwise, the base value lies in the support of the distribution.
Parameters: - model – a callable containing NumPyro primitives.
- transforms (dict) – dictionary of transforms keyed by names. Names in transforms and params should align.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- params (dict) – dictionary of unconstrained values keyed by site names.
Returns: dict of transformed params.
potential_energy¶
-
potential_energy
(model, inv_transforms, model_args, model_kwargs, params)[source]¶ (EXPERIMENTAL INTERFACE) Computes potential energy of a model given unconstrained params. The inv_transforms is used to transform these unconstrained parameters to base values of the corresponding priors in model. If a prior is a transformed distribution, the corresponding base value lies in the support of base distribution. Otherwise, the base value lies in the support of the distribution.
Parameters: - model – a callable containing NumPyro primitives.
- inv_transforms (dict) – dictionary of transforms keyed by names.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
- params (dict) – dictionary of unconstrained values keyed by site names.
Returns: potential energy given unconstrained parameters.
log_likelihood¶
-
log_likelihood
(model, posterior_samples, *args, **kwargs)[source]¶ (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model, given samples of all latent variables.
Parameters: - model – Python callable containing Pyro primitives.
- posterior_samples (dict) – dictionary of samples from the posterior.
- args – model arguments.
- kwargs – model kwargs.
Returns: dict of log likelihoods at observation sites.
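A minimal sketch computing pointwise log likelihoods, reusing model, data, labels and the samples dict from the hmc() example above:

>>> from jax.scipy.special import logsumexp
>>> from numpyro.infer.util import log_likelihood  # import path assumed
>>> loglik = log_likelihood(model, samples, data, labels)
>>> print(loglik['y'].shape)  # (500, 2000): one value per posterior draw and data point
>>> # log pointwise predictive density, averaging over posterior draws
>>> lppd = np.sum(logsumexp(loglik['y'], axis=0) - np.log(loglik['y'].shape[0]))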
find_valid_initial_params¶
-
find_valid_initial_params
(rng_key, model, init_strategy=functools.partial(<function _init_to_uniform>, radius=2), param_as_improper=False, model_args=(), model_kwargs=None)[source]¶ (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial valid unconstrained value for all the parameters. This function also returns an is_valid flag to say whether the initial parameters are valid. Parameter values are considered valid if the values and the gradients for the log density have finite values.
Parameters: - rng_key (jax.random.PRNGKey) – random number generator seed to sample from the prior. The returned init_params will have the batch shape rng_key.shape[:-1].
- model – Python callable containing Pyro primitives.
- init_strategy (callable) – a per-site initialization function.
- param_as_improper (bool) – a flag to decide whether to consider sites with param statement as sites with improper priors.
- model_args (tuple) – args provided to the model.
- model_kwargs (dict) – kwargs provided to the model.
Returns: tuple of (init_params, is_valid).
Initialization Strategies¶
init_to_median¶
-
init_to_median
(site, num_samples=15)[source]¶ Initialize to the prior median.
init_to_uniform¶
-
init_to_uniform
(site, radius=2)[source]¶ Initialize to a random point in the area (-radius, radius) of the unconstrained domain.
init_to_feasible¶
-
init_to_feasible
(site)[source]¶ Initialize to an arbitrary feasible point, ignoring distribution parameters.
init_to_value¶
-
init_to_value
(values)[source]¶ Initialize to the values specified in values. We defer to the init_to_uniform() strategy for sites that do not appear in values.
Parameters: values (dict) – dictionary of initial values keyed by site name.