Base Distribution¶
Distribution¶

class
Distribution
(batch_shape=(), event_shape=(), validate_args=None)[source]¶ Bases:
object
Base class for probability distributions in NumPyro. The design largely follows from torch.distributions.
Parameters:  batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
 event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
 validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
As an example:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)

arg_constraints
= {}¶

support
= None¶

has_enumerate_support
= False¶

is_discrete
= False¶

reparametrized_params
= []¶

batch_shape
¶ Returns the shape over which the distribution parameters are batched.
Returns: batch shape of the distribution. Return type: tuple

event_shape
¶ Returns the shape of a single sample from the distribution without batching.
Returns: event shape of the distribution. Return type: tuple

shape
(sample_shape=())[source]¶ The tensor shape of samples from this distribution.
Samples are of shape:
d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters: sample_shape (tuple) – the size of the iid batch to be drawn from the distribution. Returns: shape of samples. Return type: tuple

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: a tuple (sample, intermediates), where sample is an array of shape sample_shape + batch_shape + event_shape and intermediates holds the intermediate computations
Return type: tuple

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

to_event
(reinterpreted_batch_ndims=None)[source]¶ Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
Parameters: reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims. Returns: An instance of Independent distribution. Return type: Independent

enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

expand
(batch_shape)[source]¶ Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters: batch_shape (tuple) – batch shape to expand to. Returns: an instance of ExpandedDistribution. Return type: ExpandedDistribution

expand_by
(sample_shape)[source]¶ Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.
Parameters: sample_shape (tuple) – The size of the iid batch to be drawn from the distribution. Returns: An expanded version of this distribution. Return type: ExpandedDistribution

mask
(mask)[source]¶ Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution’s Distribution.batch_shape.
Parameters: mask (bool or jnp.ndarray) – A boolean or boolean-valued array (True includes a site, False excludes a site). Returns: A masked copy of this distribution. Return type: MaskedDistribution
ExpandedDistribution¶

class
ExpandedDistribution
(base_dist, batch_shape=())[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {}¶

expand
(batch_shape)[source]¶ Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters: batch_shape (tuple) – batch shape to expand to. Returns: an instance of ExpandedDistribution. Return type: ExpandedDistribution

has_enumerate_support
¶ Whether enumerate_support() is available for this distribution (bool).

is_discrete
¶ Whether this distribution is discrete (bool).

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

ImproperUniform¶

class
ImproperUniform
(support, batch_shape, event_shape, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
A helper distribution with zero log_prob() over the support domain.
Note: the sample method is not implemented for this distribution. In autoguide and MCMC, initial parameters for improper sites are derived from the init_to_uniform or init_to_value strategies.
Usage:
>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))
If you want to set an improper prior over all values greater than a, where a is another random variable, you might use
>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))
or if you want to reparameterize it
>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))
Parameters:  support (Constraint) – the support of this distribution.
 batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=().
 event_shape (tuple) – event shape of this distribution.

arg_constraints
= {}¶

log_prob
(*args, **kwargs)¶
Independent¶

class
Independent
(base_dist, reinterpreted_batch_ndims, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.
From a practical standpoint, this is useful when changing the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
Parameters:  base_dist (numpyro.distributions.Distribution) – a distribution instance.
 reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.

arg_constraints
= {}¶

support
¶

has_enumerate_support
¶ Whether enumerate_support() is available for this distribution (bool).

is_discrete
¶ Whether this distribution is discrete (bool).

reparameterized_params
¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

expand
(batch_shape)[source]¶ Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters: batch_shape (tuple) – batch shape to expand to. Returns: an instance of ExpandedDistribution. Return type: ExpandedDistribution
MaskedDistribution¶

class
MaskedDistribution
(base_dist, mask)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Masks a distribution by a boolean array that is broadcastable to the distribution’s Distribution.batch_shape. In the special case mask is False, computation of log_prob() is skipped, and constant zero values are returned instead.
Parameters: mask (jnp.ndarray or bool) – A boolean or boolean-valued array.

arg_constraints
= {}¶

has_enumerate_support
¶ Whether enumerate_support() is available for this distribution (bool).

is_discrete
¶ Whether this distribution is discrete (bool).

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

TransformedDistribution¶

class
TransformedDistribution
(base_distribution, transforms, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.
Parameters:  base_distribution – the base distribution over which to apply transforms.
 transforms – a single transform or a list of transforms.
 validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.

arg_constraints
= {}¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: a tuple (sample, intermediates), where sample is an array of shape sample_shape + batch_shape + event_shape and intermediates holds the intermediate computations
Return type: tuple

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
Unit¶

class
Unit
(log_factor, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Trivial non-normalized distribution representing the unit type.
The unit type has a single value with no data, i.e. value.size == 0. This is used for numpyro.factor() statements.

arg_constraints
= {'log_factor': <numpyro.distributions.constraints._Real object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

Continuous Distributions¶
Beta¶

class
Beta
(concentration1, concentration0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Interval object>¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Cauchy¶

class
Cauchy
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Chi2¶

class
Chi2
(df, validate_args=None)[source]¶ Bases:
numpyro.distributions.continuous.Gamma

arg_constraints
= {'df': <numpyro.distributions.constraints._GreaterThan object>}¶

Dirichlet¶

class
Dirichlet
(concentration, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Simplex object>¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Exponential¶

class
Exponential
(rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

reparametrized_params
= ['rate']¶

arg_constraints
= {'rate': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._GreaterThan object>¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Gamma¶

class
Gamma
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._GreaterThan object>¶

reparametrized_params
= ['rate']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Gumbel¶

class
Gumbel
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

GaussianRandomWalk¶

class
GaussianRandomWalk
(scale=1.0, num_steps=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'num_steps': <numpyro.distributions.constraints._IntegerGreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._RealVector object>¶

reparametrized_params
= ['scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

HalfCauchy¶

class
HalfCauchy
(scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

reparametrized_params
= ['scale']¶

support
= <numpyro.distributions.constraints._GreaterThan object>¶

arg_constraints
= {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

HalfNormal¶

class
HalfNormal
(scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

reparametrized_params
= ['scale']¶

support
= <numpyro.distributions.constraints._GreaterThan object>¶

arg_constraints
= {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

InverseGamma¶

class
InverseGamma
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._GreaterThan object>¶

reparametrized_params
= ['rate']¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Laplace¶

class
Laplace
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

LKJ¶

class
LKJ
(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter \(\eta\) to make the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over correlation matrices.
When concentration > 1, the distribution favors samples with large determinant. This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small determinant. This is useful when we know a priori that some underlying variables are correlated.
Parameters:  dimension (int) – dimension of the matrices
 concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
 sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and produce the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._CorrMatrix object>¶

mean
¶ Mean of the distribution.
LKJCholesky¶

class
LKJCholesky
(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\) to make the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.
When concentration > 1, the distribution favors samples with large diagonal entries (hence large determinant). This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small diagonal entries (hence small determinant). This is useful when we know a priori that some underlying variables are correlated.
Parameters:  dimension (int) – dimension of the matrices
 concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
 sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and produce the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._CorrCholesky object>¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶
LogNormal¶

class
LogNormal
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

reparametrized_params
= ['loc', 'scale']¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Logistic¶

class
Logistic
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

reparametrized_params
= ['loc', 'real']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

MultivariateNormal¶

class
MultivariateNormal
(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'covariance_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'loc': <numpyro.distributions.constraints._RealVector object>, 'precision_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'scale_tril': <numpyro.distributions.constraints._LowerCholesky object>}¶

support
= <numpyro.distributions.constraints._RealVector object>¶

reparametrized_params
= ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
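MultivariateNormal accepts the covariance through covariance_matrix, precision_matrix, or scale_tril. These describe the same matrix, with covariance_matrix = scale_tril @ scale_tril.T. The plain-Python sketch below uses hypothetical helper names. It computes a lower Cholesky factor and then draws a reparametrized sample as loc + scale_tril @ eps. This is one standard way to sample and is not necessarily NumPyro's exact code path.

```python
import math
import random

def cholesky(a):
    # Lower-triangular L with a = L @ L.T (a must be symmetric positive definite)
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def mvn_sample(loc, scale_tril, rng):
    # Reparametrized draw: x = loc + L @ eps with eps ~ Normal(0, I)
    eps = [rng.gauss(0.0, 1.0) for _ in loc]
    return [m + sum(row[k] * eps[k] for k in range(len(eps)))
            for m, row in zip(loc, scale_tril)]

cov = [[4.0, 2.0], [2.0, 3.0]]
L = cholesky(cov)  # [[2.0, 0.0], [1.0, sqrt(2)]]
x = mvn_sample([0.0, 1.0], L, random.Random(0))
```

Because the randomness enters only through eps, gradients can flow through loc and scale_tril. This is what listing all of them in reparametrized_params expresses.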

LowRankMultivariateNormal¶

class
LowRankMultivariateNormal
(loc, cov_factor, cov_diag, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'cov_diag': <numpyro.distributions.constraints._GreaterThan object>, 'cov_factor': <numpyro.distributions.constraints._Real object>, 'loc': <numpyro.distributions.constraints._RealVector object>}¶

support
= <numpyro.distributions.constraints._RealVector object>¶

mean
¶ Mean of the distribution.

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶
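LowRankMultivariateNormal uses a low-rank-plus-diagonal covariance, cov_factor @ cov_factor.T + diag(cov_diag), where cov_factor has shape (n, rank). Assuming that structure, the plain-Python sketch below (hypothetical helpers) computes the implied covariance. It also draws a sample without ever forming the full n x n matrix, which is the main reason to prefer this parameterization when rank is much smaller than n.

```python
import math
import random

def low_rank_covariance(cov_factor, cov_diag):
    # Sigma = W @ W.T + diag(D), with W of shape (n, rank)
    n = len(cov_diag)
    return [[sum(cov_factor[i][r] * cov_factor[j][r] for r in range(len(cov_factor[i])))
             + (cov_diag[i] if i == j else 0.0)
             for j in range(n)] for i in range(n)]

def low_rank_sample(loc, cov_factor, cov_diag, rng):
    # x = loc + W @ z_w + sqrt(D) * z_d, z_w ~ N(0, I_rank), z_d ~ N(0, I_n)
    rank = len(cov_factor[0])
    z_w = [rng.gauss(0.0, 1.0) for _ in range(rank)]
    return [m + sum(w[r] * z_w[r] for r in range(rank)) + math.sqrt(d) * rng.gauss(0.0, 1.0)
            for m, w, d in zip(loc, cov_factor, cov_diag)]

W = [[1.0], [2.0]]  # n = 2, rank = 1
D = [0.5, 0.25]
sigma = low_rank_covariance(W, D)  # [[1.5, 2.0], [2.0, 4.25]]
x = low_rank_sample([0.0, 0.0], W, D, random.Random(0))
```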

Normal¶

class
Normal
(loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Pareto¶

class
Pareto
(scale, alpha, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶
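Pareto is also documented as a TransformedDistribution. One standard construction is scale * exp(E) with E ~ Exponential(alpha). It is used here to convey the idea, as an assumption rather than a description of NumPyro's code. The plain-Python sketch below pairs that sampler with the textbook mean and variance. The mean is finite only for alpha > 1 and the variance only for alpha > 2; the hypothetical helpers simply return inf otherwise.

```python
import math
import random

def pareto_sample(scale, alpha, rng):
    # If E ~ Exponential(rate=alpha), then scale * exp(E) ~ Pareto(scale, alpha)
    return scale * math.exp(rng.expovariate(alpha))

def pareto_mean(scale, alpha):
    # alpha * scale / (alpha - 1), finite only for alpha > 1
    return alpha * scale / (alpha - 1.0) if alpha > 1.0 else math.inf

def pareto_variance(scale, alpha):
    # scale^2 * alpha / ((alpha - 1)^2 * (alpha - 2)), finite only for alpha > 2
    if alpha <= 2.0:
        return math.inf
    return scale ** 2 * alpha / ((alpha - 1.0) ** 2 * (alpha - 2.0))

draws = [pareto_sample(1.0, 3.0, random.Random(i)) for i in range(100)]
```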

StudentT¶

class
StudentT
(df, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
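The plain-Python sketch below shows the Student-t variance and a sampler that combines a Normal draw with a chi-squared draw; the helper names are hypothetical. The mean equals loc when df > 1. The variance depends on df as shown in the code.

```python
import math
import random

def student_t_sample(df, loc, scale, rng):
    # T = Z / sqrt(V / df) with Z ~ Normal(0, 1) and V ~ Chi2(df) = Gamma(df / 2, scale=2)
    z = rng.gauss(0.0, 1.0)
    v = rng.gammavariate(df / 2.0, 2.0)
    return loc + scale * z / math.sqrt(v / df)

def student_t_variance(df, scale):
    # Finite for df > 2, infinite for 1 < df <= 2, undefined (nan) for df <= 1
    if df > 2.0:
        return scale ** 2 * df / (df - 2.0)
    return math.inf if df > 1.0 else math.nan
```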

TruncatedCauchy¶

class
TruncatedCauchy
(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

reparametrized_params
= ['low', 'loc', 'scale']¶

support
¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

TruncatedNormal¶

class
TruncatedNormal
(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶

reparametrized_params
= ['low', 'loc', 'scale']¶

support
¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
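In this version TruncatedNormal takes only a lower bound, low, so it describes a Normal(loc, scale) restricted to [low, inf). The plain-Python sketch below uses statistics.NormalDist for inverse-CDF sampling and for the closed-form mean; the helper names are hypothetical. This simple approach loses precision far in the upper tail, where 1 - Phi(a) underflows.

```python
import math
import random
from statistics import NormalDist

def truncated_normal_sample(low, loc, scale, rng):
    # Inverse-CDF sampling restricted to [low, inf)
    base = NormalDist(loc, scale)
    u_low = base.cdf(low)
    u = u_low + (1.0 - u_low) * rng.random()
    u = min(max(u, 1e-12), 1.0 - 1e-12)  # inv_cdf needs an argument strictly inside (0, 1)
    return max(base.inv_cdf(u), low)

def truncated_normal_mean(low, loc, scale):
    # E[X | X >= low] = loc + scale * phi(a) / (1 - Phi(a)), a = (low - loc) / scale
    std = NormalDist()
    a = (low - loc) / scale
    return loc + scale * std.pdf(a) / (1.0 - std.cdf(a))

rng = random.Random(0)
draws = [truncated_normal_sample(0.0, 0.0, 1.0, rng) for _ in range(1000)]
```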

TruncatedPolyaGamma¶

class
TruncatedPolyaGamma
(batch_shape=(), validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

truncation_point
= 2.5¶

num_log_prob_terms
= 7¶

num_gamma_variates
= 8¶

arg_constraints
= {}¶

support
= <numpyro.distributions.constraints._Interval object>¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶
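The class attributes truncation_point = 2.5 and num_gamma_variates = 8 suggest a truncated version of the series representation of the Pólya-Gamma PG(1, 0) distribution, X = 1/(2π²) Σₖ gₖ / (k - 1/2)² with gₖ ~ Gamma(1, 1). The plain-Python sketch below keeps only the first num_gamma_variates terms and clips the result to the support. That sampling scheme is an assumption made for illustration and is not confirmed by these docs.

```python
import math
import random

TRUNCATION_POINT = 2.5   # mirrors the documented class attribute
NUM_GAMMA_VARIATES = 8   # mirrors the documented class attribute

def truncated_polya_gamma_sample(rng):
    # PG(1, 0) = 1 / (2 pi^2) * sum_{k>=1} g_k / (k - 1/2)^2 with g_k ~ Gamma(1, 1);
    # assumption: keep only the first NUM_GAMMA_VARIATES terms, then clip to the support
    total = sum(rng.gammavariate(1.0, 1.0) / (k - 0.5) ** 2
                for k in range(1, NUM_GAMMA_VARIATES + 1))
    return min(total / (2.0 * math.pi ** 2), TRUNCATION_POINT)

rng = random.Random(0)
draws = [truncated_polya_gamma_sample(rng) for _ in range(1000)]
```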

Uniform¶

class
Uniform
(low=0.0, high=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}¶

reparametrized_params
= ['low', 'high']¶

support
¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Discrete Distributions¶
BernoulliLogits¶

class
BernoulliLogits
(logits=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>}¶

support
= <numpyro.distributions.constraints._Boolean object>¶

has_enumerate_support
= True¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
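BernoulliLogits parameterizes the success probability through logits = log(p / (1 - p)), so p = sigmoid(logits). The mean is p and the variance is p(1 - p). The plain-Python sketch below (hypothetical helper names) computes the log-probability directly in logit form, which avoids taking log(p) of a probability that may have rounded to 0 or 1.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_logits_log_prob(value, logits):
    # log p(value) = value * logits - log(1 + exp(logits)), value in {0, 1}
    return value * logits - softplus(logits)

def logits_to_probs(logits):
    # Sigmoid, written to avoid overflow for large |logits|
    if logits >= 0:
        return 1.0 / (1.0 + math.exp(-logits))
    e = math.exp(logits)
    return e / (1.0 + e)

p = logits_to_probs(2.0)
```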

BernoulliProbs¶

class
BernoulliProbs
(probs, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': <numpyro.distributions.constraints._Interval object>}¶

support
= <numpyro.distributions.constraints._Boolean object>¶

has_enumerate_support
= True¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

BetaBinomial¶

class
BetaBinomial
(concentration1, concentration0, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and randomly drawn from a Beta distribution prior to a certain number of Bernoulli trials given by total_count.
Parameters:  concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
 concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
 total_count (numpy.ndarray) – number of Bernoulli trials.

arg_constraints
= {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶

has_enumerate_support
= True¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶
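Integrating the Beta-distributed success probability out of the Binomial gives a closed-form probability mass function, C(n, k) B(k + α, n - k + β) / B(α, β), with mean n α / (α + β). The plain-Python sketch below computes it with math.lgamma; the helper names are hypothetical.

```python
import math

def log_beta(a, b):
    # log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_binomial_log_prob(k, concentration1, concentration0, total_count):
    # log [ C(n, k) * B(k + alpha, n - k + beta) / B(alpha, beta) ]
    n = total_count
    log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return (log_choose
            + log_beta(k + concentration1, n - k + concentration0)
            - log_beta(concentration1, concentration0))

def beta_binomial_mean(concentration1, concentration0, total_count):
    return total_count * concentration1 / (concentration1 + concentration0)

# With alpha = beta = 1 the success probability is uniform, and so is the count
probs = [math.exp(beta_binomial_log_prob(k, 1.0, 1.0, 4)) for k in range(5)]
```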
BinomialLogits¶

class
BinomialLogits
(logits, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶

has_enumerate_support
= True¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

BinomialProbs¶

class
BinomialProbs
(probs, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶

has_enumerate_support
= True¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

CategoricalLogits¶

class
CategoricalLogits
(logits, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': <numpyro.distributions.constraints._RealVector object>}¶

has_enumerate_support
= True¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶
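CategoricalLogits takes unnormalized log-probabilities and normalizes them with a log-softmax. Adding the same constant to every logit therefore leaves the distribution unchanged. The plain-Python sketch below (hypothetical helpers) shows the normalization and a simple inverse-CDF sampler.

```python
import math
import random

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_norm for l in logits]

def categorical_sample(logits, rng):
    # Inverse-CDF sampling over the normalized probabilities
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    u = rng.random()
    cumulative = 0.0
    for index, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return index
    return len(probs) - 1  # guard against round-off in the cumulative sum

logits = [0.0, math.log(2.0), math.log(3.0)]  # probabilities 1/6, 2/6, 3/6
rng = random.Random(0)
draws = [categorical_sample(logits, rng) for _ in range(100)]
```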

CategoricalProbs¶

class
CategoricalProbs
(probs, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': <numpyro.distributions.constraints._Simplex object>}¶

has_enumerate_support
= True¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

Delta¶

class
Delta
(value=0.0, log_density=0.0, event_dim=0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'log_density': <numpyro.distributions.constraints._Real object>, 'value': <numpyro.distributions.constraints._Real object>}¶

support
= <numpyro.distributions.constraints._Real object>¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

GammaPoisson¶

class
GammaPoisson
(concentration, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and randomly drawn from a Gamma distribution.
Parameters:  concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
 rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.

arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
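Integrating the Gamma(concentration, rate)-distributed Poisson rate out gives a negative binomial distribution with mean concentration / rate. The plain-Python sketch below (hypothetical helper names) shows the resulting log-pmf and a two-stage sampler. The Poisson step uses Knuth's simple multiplication method, which is only suitable for modest rates and is not what NumPyro necessarily uses.

```python
import math
import random

def gamma_poisson_log_prob(k, concentration, rate):
    # Negative binomial form obtained by integrating lambda ~ Gamma(alpha, beta) out of Poisson(lambda)
    a, b = concentration, rate
    return (math.lgamma(k + a) - math.lgamma(a) - math.lgamma(k + 1)
            + a * math.log(b / (1.0 + b)) - k * math.log1p(b))

def poisson_sample(lam, rng):
    # Knuth's method: count uniform factors until their product drops below exp(-lam)
    threshold, k, prod = math.exp(-lam), 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k

def gamma_poisson_sample(concentration, rate, rng):
    # Two-stage draw: lambda ~ Gamma(shape=alpha, rate=beta), then k ~ Poisson(lambda)
    lam = rng.gammavariate(concentration, 1.0 / rate)
    return poisson_sample(lam, rng)
```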
GeometricLogits¶

class
GeometricLogits
(logits, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': <numpyro.distributions.constraints._Real object>}¶

support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

GeometricProbs¶

class
GeometricProbs
(probs, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': <numpyro.distributions.constraints._Interval object>}¶

support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

MultinomialLogits¶

class
MultinomialLogits
(logits, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': <numpyro.distributions.constraints._RealVector object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

MultinomialProbs¶

class
MultinomialProbs
(probs, total_count=1, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶
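For MultinomialProbs, a value is a vector of non-negative counts that sum to total_count. The plain-Python sketch below computes the log-pmf; the helper name is hypothetical. MultinomialLogits would first pass its logits through a log-softmax to obtain the log-probabilities.

```python
import math

def multinomial_log_prob(counts, probs):
    # log [ n! / prod(x_i!) * prod(p_i ** x_i) ], with n = sum(counts) = total_count
    n = sum(counts)
    log_coeff = math.lgamma(n + 1) - sum(math.lgamma(x + 1) for x in counts)
    # Terms with x_i = 0 contribute nothing, so skip them (also avoids log(0) when p_i = 0)
    return log_coeff + sum(x * math.log(p) for x, p in zip(counts, probs) if x > 0)
```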

OrderedLogistic¶

class
OrderedLogistic
(predictor, cutpoints, validate_args=None)[source]¶ Bases:
numpyro.distributions.discrete.CategoricalProbs
A categorical distribution with ordered outcomes.
References:
 Stan Functions Reference, v2.20 section 12.6, Stan Development Team
Parameters:  predictor (numpy.ndarray) – prediction in real domain; typically this is output of a linear model.
 cutpoints (numpy.ndarray) – positions in real domain to separate categories.

arg_constraints
= {'cutpoints': <numpyro.distributions.constraints._OrderedVector object>, 'predictor': <numpyro.distributions.constraints._Real object>}¶
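Because OrderedLogistic subclasses CategoricalProbs, its job is to turn predictor and cutpoints into category probabilities. It follows the ordered-logistic model in the cited Stan reference: with K - 1 ordered cutpoints there are K categories, and P(Y ≤ k) = sigmoid(cₖ - predictor). The plain-Python sketch below implements that formula; the helper names are hypothetical and this is not NumPyro's source.

```python
import math

def sigmoid(x):
    # Overflow-safe logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def ordered_logistic_probs(predictor, cutpoints):
    # Cumulative probabilities P(Y <= k) = sigmoid(c_k - predictor), padded with 0 and 1;
    # each category's probability is the difference of successive cumulative values
    cdf = [0.0] + [sigmoid(c - predictor) for c in cutpoints] + [1.0]
    return [upper - lower for lower, upper in zip(cdf[:-1], cdf[1:])]

probs = ordered_logistic_probs(0.5, [-1.0, 0.0, 2.0])  # 3 cutpoints -> 4 categories
```

Increasing cutpoints are what keep every difference positive, which is why cutpoints is constrained to an ordered vector.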
Poisson¶

class
Poisson
(rate, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'rate': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

PRNGIdentity¶

class
PRNGIdentity
[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Distribution over PRNGKey(). This can be used to draw a batch of PRNGKey() values using the seed handler. Only the sample method is supported.
is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

ZeroInflatedPoisson¶

class
ZeroInflatedPoisson
(gate, rate=1.0, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
A zero-inflated Poisson distribution.
Parameters:  gate (numpy.ndarray) – probability of extra zeros.
 rate (numpy.ndarray) – rate of the Poisson distribution.

arg_constraints
= {'gate': <numpyro.distributions.constraints._Interval object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶

support
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶

is_discrete
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶
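A zero-inflated Poisson mixes a point mass at zero (weight gate) with a Poisson(rate) (weight 1 - gate). A zero can therefore come from either component, while positive counts come only from the Poisson. The plain-Python sketch below computes the resulting log-pmf and mean; the helper names are hypothetical.

```python
import math

def poisson_log_prob(k, rate):
    # log [ rate^k * exp(-rate) / k! ]
    return k * math.log(rate) - rate - math.lgamma(k + 1)

def zip_log_prob(k, gate, rate):
    # With probability `gate` emit a structural zero, otherwise draw from Poisson(rate)
    if k == 0:
        return math.log(gate + (1.0 - gate) * math.exp(-rate))
    return math.log1p(-gate) + poisson_log_prob(k, rate)

def zip_mean(gate, rate):
    return (1.0 - gate) * rate
```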
Directional Distributions¶
VonMises¶

class
VonMises
(loc, concentration, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>}¶

support
= <numpyro.distributions.constraints._Interval object>¶

sample
(key, sample_shape=())[source]¶ Generates samples from the von Mises distribution.
Parameters:  key – random number generator key
 sample_shape – shape of samples
Returns: samples from the von Mises distribution

log_prob
(*args, **kwargs)¶

mean
¶ Computes the circular mean of the distribution. NOTE: this is the same as the location when mapped to the support [-pi, pi].

variance
¶ Computes the circular variance of the distribution.
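The von Mises density on [-π, π] is exp(κ cos(x - μ)) / (2π I₀(κ)). Its circular variance is 1 - I₁(κ)/I₀(κ), where I₀ and I₁ are modified Bessel functions of the first kind. The plain-Python sketch below uses a truncated power series for the Bessel functions, which is adequate for moderate concentration but not a production implementation. The helper names are hypothetical.

```python
import math

def bessel_i(order, x, terms=50):
    # Modified Bessel function of the first kind I_order(x) via its power series
    return sum((x / 2.0) ** (2 * m + order) / (math.factorial(m) * math.factorial(m + order))
               for m in range(terms))

def von_mises_log_prob(x, loc, concentration):
    # log p(x) = kappa * cos(x - mu) - log(2 * pi * I0(kappa)), for x in [-pi, pi]
    return concentration * math.cos(x - loc) - math.log(2.0 * math.pi * bessel_i(0, concentration))

def von_mises_circular_variance(concentration):
    # 1 - I1(kappa) / I0(kappa): equals 1 for the uniform case kappa = 0, tends to 0 as kappa grows
    return 1.0 - bessel_i(1, concentration) / bessel_i(0, concentration)
```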

Constraints¶
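Constraint¶
Abstract base class for constraints. A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.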
greater_than¶

greater_than
(lower_bound)¶ Constrains a real value to be strictly greater than lower_bound.
integer_interval¶

integer_interval
(lower_bound, upper_bound)¶ Constrains a value to be an integer between lower_bound and upper_bound, inclusive.
integer_greater_than¶

integer_greater_than
(lower_bound)¶ Constrains a value to be an integer greater than or equal to lower_bound.
interval¶

interval
(lower_bound, upper_bound)¶ Constrains a real value to lie in the interval bounded by lower_bound and upper_bound.
less_than¶

less_than
(upper_bound)¶ Constrains a real value to be strictly less than upper_bound.
multinomial¶

multinomial
(upper_bound)¶ Constrains a vector of non-negative integers to sum to upper_bound (the total count of a multinomial).
nonnegative_integer¶

nonnegative_integer
= <numpyro.distributions.constraints._IntegerGreaterThan object>¶
positive_definite¶

positive_definite
= <numpyro.distributions.constraints._PositiveDefinite object>¶
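Constraint objects are used to check whether a value lies in a region. The plain-Python predicates below are hypothetical analogues of a few of the constraints listed above and are not NumPyro's API. They assume that greater_than is strict and integer_interval is inclusive. Positive definiteness is checked by attempting a Cholesky factorization, a common technique chosen here for illustration.

```python
import math

def cholesky_or_none(a):
    # Returns the lower Cholesky factor, or None if `a` is not positive definite
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = a[i][i] - s
                if d <= 0.0:
                    return None
                L[i][i] = math.sqrt(d)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

# Hypothetical predicate versions of some constraints (not NumPyro objects)
def greater_than(lower_bound):
    return lambda x: x > lower_bound

def integer_interval(lower_bound, upper_bound):
    return lambda x: x % 1 == 0 and lower_bound <= x <= upper_bound

def positive_definite(a):
    n = len(a)
    symmetric = all(a[i][j] == a[j][i] for i in range(n) for j in range(n))
    return symmetric and cholesky_or_none(a) is not None
```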
Transforms¶
Transform¶
AbsTransform¶
AffineTransform¶

class
AffineTransform
(loc, scale, domain=<numpyro.distributions.constraints._Real object>)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Note
When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

codomain
¶

event_dim
¶
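AffineTransform maps x to loc + scale * x. The plain-Python sketch below (hypothetical helper names) shows the forward map, its inverse, and the element-wise log absolute Jacobian determinant, log|scale|. The sign of scale matters for the codomain, which is why the note above matters when scale is a traced value. For example, with a negative scale the half-line (a, ∞) is mapped onto (-∞, loc + scale·a) rather than onto another half-line bounded below.

```python
import math

def affine_forward(x, loc, scale):
    # y = loc + scale * x
    return loc + scale * x

def affine_inverse(y, loc, scale):
    # x = (y - loc) / scale
    return (y - loc) / scale

def affine_log_abs_det_jacobian(scale):
    # dy/dx = scale, so the element-wise log|det J| is log|scale| (independent of x)
    return math.log(abs(scale))
```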

ComposeTransform¶

class
ComposeTransform
(parts)[source]¶ Bases:
numpyro.distributions.transforms.Transform

domain
¶

codomain
¶

event_dim
¶
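ComposeTransform applies its parts in sequence. By the chain rule, the log absolute Jacobian determinant of the composition is the sum of the parts' log-determinants, each evaluated at that part's own input. The plain-Python sketch below represents each part as a hypothetical (forward, log_abs_det_jacobian) pair rather than a NumPyro Transform object.

```python
import math

def compose(parts):
    # Each part is a (forward, log_abs_det_jacobian) pair; parts are applied left to right
    def forward(x):
        total_log_det = 0.0
        for fn, log_det in parts:
            y = fn(x)
            total_log_det += log_det(x, y)  # chain rule: log-determinants add up
            x = y
        return x, total_log_det
    return forward

exp_part = (math.exp, lambda x, y: x)  # d/dx exp(x) = exp(x), whose log is x
scale_part = (lambda x: 2.0 * x, lambda x, y: math.log(2.0))
transform = compose([exp_part, scale_part])
y, log_det = transform(0.5)  # y = 2 * exp(0.5), log_det = 0.5 + log(2)
```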

CorrCholeskyTransform¶

class
CorrCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:
 1. First we convert \(x\) into a lower triangular matrix with the following order:
\[\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\]
 2. For each row \(X_i\) of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:
  a. Scale into the interval \((-1, 1)\): \(r_i = \tanh(X_i)\).
  b. Transform into an unsigned domain: \(z_i = r_i^2\).
  c. Apply \(s_i = StickBreakingTransform(z_i)\).
  d. Transform back into the signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).

domain
= <numpyro.distributions.constraints._RealVector object>¶

codomain
= <numpyro.distributions.constraints._CorrCholesky object>¶

event_dim
= 2¶
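The steps described for CorrCholeskyTransform can be followed literally in plain Python. In the sketch below, the stick-breaking piece sⱼ = zⱼ · (remaining stick) combined with sign(rⱼ)·√sⱼ simplifies to rⱼ·√(remaining), and the leftover stick becomes the positive diagonal entry. The sketch mirrors the documented steps only. NumPyro's own implementation may differ in details such as clipping tanh outputs away from ±1, and the function name is hypothetical.

```python
import math

def corr_cholesky(x):
    # Map a vector of length D*(D-1)/2 to the Cholesky factor of a D x D correlation matrix
    D = int(round((1 + math.sqrt(1 + 8 * len(x))) / 2))
    if D * (D - 1) // 2 != len(x):
        raise ValueError("len(x) must equal D*(D-1)/2 for some integer D")
    L = [[0.0] * D for _ in range(D)]
    L[0][0] = 1.0
    pos = 0
    for i in range(1, D):
        # Steps 1 and 2a: the i entries of row i (x is filled row by row), squashed into (-1, 1)
        r = [math.tanh(v) for v in x[pos:pos + i]]
        pos += i
        remaining = 1.0  # stick length left: product of (1 - z_k) over earlier entries
        for j, r_ij in enumerate(r):
            # Steps 2b-2d: s_j = r_ij**2 * remaining, and sign(r_ij) * sqrt(s_j) = r_ij * sqrt(remaining)
            L[i][j] = r_ij * math.sqrt(remaining)
            remaining *= 1.0 - r_ij ** 2
        # The leftover piece of the stick gives the positive diagonal entry
        L[i][i] = math.sqrt(remaining)
    return L

L = corr_cholesky([0.3, -1.2, 0.8])  # D = 3
corr = [[sum(L[i][k] * L[j][k] for k in range(3)) for j in range(3)] for i in range(3)]
```

Each row of L has unit norm, so corr = L @ L.T has ones on its diagonal, as a correlation matrix must.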
ExpTransform¶
IdentityTransform¶
InvCholeskyTransform¶

class
InvCholeskyTransform
(domain=<numpyro.distributions.constraints._LowerCholesky object>)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = x @ x.T\), where x is a lower triangular matrix with positive diagonal.

event_dim
= 2¶

codomain
¶
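The mapping y = x @ x.T can be written out directly. In the plain-Python sketch below (hypothetical helper name), the output is symmetric and positive definite whenever the input is lower triangular with a positive diagonal. Recovering x from y is a Cholesky decomposition.

```python
def inv_cholesky(L):
    # y = L @ L.T for a lower-triangular L with positive diagonal
    n = len(L)
    return [[sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)] for i in range(n)]

y = inv_cholesky([[2.0, 0.0], [1.0, 3.0]])  # [[4.0, 2.0], [2.0, 10.0]]
```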

LowerCholeskyAffine¶

class
LowerCholeskyAffine
(loc, scale_tril)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = loc + scale\_tril\ @\ x\).
Parameters:  loc – a real vector.
 scale_tril – a lower triangular matrix with positive diagonal.

domain
= <numpyro.distributions.constraints._RealVector object>¶

codomain
= <numpyro.distributions.constraints._RealVector object>¶

event_dim
= 1¶
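Because scale_tril is triangular, the inverse of y = loc + scale_tril @ x can be computed by forward substitution without forming a matrix inverse. The log absolute Jacobian determinant is the sum of the logs of the diagonal entries. Applying this map to a standard normal vector gives a multivariate normal with covariance scale_tril @ scale_tril.T. The plain-Python sketch below uses hypothetical helper names.

```python
import math

def lower_cholesky_affine(x, loc, scale_tril):
    # y = loc + scale_tril @ x
    return [m + sum(row[k] * x[k] for k in range(len(x))) for m, row in zip(loc, scale_tril)]

def lower_cholesky_affine_inverse(y, loc, scale_tril):
    # Forward substitution solves scale_tril @ x = y - loc row by row
    x = []
    for i, row in enumerate(scale_tril):
        s = sum(row[k] * x[k] for k in range(i))
        x.append((y[i] - loc[i] - s) / row[i])
    return x

def lower_cholesky_affine_log_abs_det_jacobian(scale_tril):
    # The determinant of a triangular matrix is the product of its diagonal
    return sum(math.log(abs(scale_tril[i][i])) for i in range(len(scale_tril)))
```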
LowerCholeskyTransform¶
OrderedTransform¶

class
OrderedTransform
[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform a real vector to an ordered vector.
References:
 Stan Reference Manual v2.20, section 10.6, Stan Development Team

domain
= <numpyro.distributions.constraints._RealVector object>¶

codomain
= <numpyro.distributions.constraints._OrderedVector object>¶

event_dim
= 1¶
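The ordered-vector transform described in the cited Stan manual keeps the first element and adds the exponential of each later element to the running value. This makes the output strictly increasing, which is the shape required of constrained ordered vectors such as OrderedLogistic's cutpoints. The plain-Python sketch below uses hypothetical helper names and is not NumPyro's source.

```python
import math

def ordered_forward(x):
    # y_0 = x_0, y_k = y_{k-1} + exp(x_k): a strictly increasing output
    y = [x[0]]
    for v in x[1:]:
        y.append(y[-1] + math.exp(v))
    return y

def ordered_inverse(y):
    # x_0 = y_0, x_k = log(y_k - y_{k-1})
    return [y[0]] + [math.log(b - a) for a, b in zip(y[:-1], y[1:])]

def ordered_log_abs_det_jacobian(x):
    # The Jacobian is lower triangular with diagonal (1, exp(x_1), ..., exp(x_{K-1}))
    return sum(x[1:])
```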
PermuteTransform¶
PowerTransform¶
SigmoidTransform¶
Flows¶
InverseAutoregressiveTransform¶

class
InverseAutoregressiveTransform
(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]¶ Bases:
numpyro.distributions.transforms.Transform
An implementation of Inverse Autoregressive Flow, using Eq. (10) from Kingma et al., 2016,
\[\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\]
where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).
References
 Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling

domain
= <numpyro.distributions.constraints._RealVector object>¶

codomain
= <numpyro.distributions.constraints._RealVector object>¶

event_dim
= 1¶

inv
(y)[source]¶ Parameters: y (numpy.ndarray) – the output of the transform to be inverted

log_abs_det_jacobian
(x, y, intermediates=None)[source]¶ Calculates the element-wise log absolute determinant of the Jacobian.
Parameters:  x (numpy.ndarray) – the input to the transform
 y (numpy.ndarray) – the output of the transform
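In the equation above, μₜ and σₜ at position i depend only on earlier inputs. The forward pass can therefore be computed in one shot, its Jacobian is triangular with log-determinant Σ log σ, and the inverse has to be recovered one position at a time. The plain-Python sketch below shows this. toy_autoregressive_nn is a hypothetical stand-in for the autoregressive_nn argument (a real one would be a masked neural network), and the log-scales are clipped to [log_scale_min_clip, log_scale_max_clip] on the assumption that this is what those arguments bound.

```python
import math

def toy_autoregressive_nn(x):
    # Hypothetical stand-in for `autoregressive_nn`: output i depends only on x[:i]
    mu = [0.5 * sum(x[:i]) for i in range(len(x))]
    log_sigma = [0.1 * sum(x[:i]) - 0.2 for i in range(len(x))]
    return mu, log_sigma

def iaf_forward(x, nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0):
    # y = mu + sigma * x (Eq. 10), with log-scales clipped for numerical stability
    mu, log_sigma = nn(x)
    log_sigma = [min(max(s, log_scale_min_clip), log_scale_max_clip) for s in log_sigma]
    y = [m + math.exp(s) * xi for m, s, xi in zip(mu, log_sigma, x)]
    return y, sum(log_sigma)  # log|det J| = sum of log sigma (triangular Jacobian)

def iaf_inverse(y, nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0):
    # Inversion is sequential: position i needs x[:i], which is recovered first
    x = [0.0] * len(y)
    for i in range(len(y)):
        mu, log_sigma = nn(x)
        s = min(max(log_sigma[i], log_scale_min_clip), log_scale_max_clip)
        x[i] = (y[i] - mu[i]) * math.exp(-s)
    return x
```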
BlockNeuralAutoregressiveTransform¶

class
BlockNeuralAutoregressiveTransform
(bn_arn)[source]¶ Bases:
numpyro.distributions.transforms.Transform
An implementation of Block Neural Autoregressive flow.
References
 Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz

event_dim
= 1¶

log_abs_det_jacobian
(x, y, intermediates=None)[source]¶ Calculates the element-wise log absolute determinant of the Jacobian.
Parameters:  x (numpy.ndarray) – the input to the transform
 y (numpy.ndarray) – the output of the transform