# Base Distribution

## Distribution

class Distribution(batch_shape=(), event_shape=(), validate_args=None)

Bases: object

Base class for probability distributions in NumPyro. The design largely follows from torch.distributions.

Parameters: batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters. event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob. validate_args – Whether to enable validation of distribution parameters and arguments to the .log_prob method.

As an example:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)

arg_constraints = {}
support = None
has_enumerate_support = False
is_discrete = False
reparametrized_params = []
tree_flatten()
classmethod tree_unflatten(aux_data, params)
static set_default_validate_args(value)
batch_shape

Returns the shape over which the distribution parameters are batched.

Returns: batch shape of the distribution.
Return type: tuple
event_shape

Returns the shape of a single sample from the distribution without batching.

Returns: event shape of the distribution.
Return type: tuple
event_dim
Returns: number of dimensions of individual events.
Return type: int
shape(sample_shape=())

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

Parameters: sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.
Returns: shape of samples.
Return type: tuple
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
sample_with_intermediates(key, sample_shape=())

Same as sample except that any intermediate computations are returned as well (useful for TransformedDistribution).

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: a tuple of the sample (an array of shape sample_shape + batch_shape + event_shape) and a list of intermediate values, which is empty for distributions without intermediate computations.
Return type: tuple
log_prob(value)

Evaluates the log probability density for a batch of samples given by value.

Parameters: value – a batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its trailing event dimensions removed.
Return type: numpy.ndarray
mean

Mean of the distribution.

variance

Variance of the distribution.

to_event(reinterpreted_batch_ndims=None)

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters: reinterpreted_batch_ndims – number of rightmost batch dims to interpret as event dims. If None, all batch dims are reinterpreted.
Returns: an instance of Independent distribution.
Return type: Independent
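enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support. If expand is False, the result is not broadcast over the batch and has shape (len(support),) + (1,) * len(batch_shape).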

expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters: batch_shape (tuple) – batch shape to expand to.
Returns: an instance of ExpandedDistribution.
Return type: ExpandedDistribution
expand_by(sample_shape)

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters: sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.
Returns: an expanded version of this distribution.
Return type: ExpandedDistribution
mask(mask)

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.

Parameters: mask (bool or jnp.ndarray) – a boolean or boolean-valued array (True includes a site, False excludes a site).
Returns: a masked copy of this distribution.
Return type: MaskedDistribution

## ExpandedDistribution

class ExpandedDistribution(base_dist, batch_shape=())
arg_constraints = {}
expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters: batch_shape (tuple) – batch shape to expand to.
Returns: an instance of ExpandedDistribution.
Return type: ExpandedDistribution
has_enumerate_support
is_discrete
support
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)

Evaluates the log probability density for a batch of samples given by value.

Parameters: value – a batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its trailing event dimensions removed.
Return type: numpy.ndarray
enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()
classmethod tree_unflatten(aux_data, params)

## ImproperUniform

class ImproperUniform(support, batch_shape, event_shape, validate_args=None)

A helper distribution with zero log_prob() over the support domain.

Note

The sample method is not implemented for this distribution. In autoguide and MCMC, initial parameters for improper sites are derived from the init_to_uniform or init_to_value strategies.

Usage:

>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))


To place an improper prior over all values greater than a, where a is another random variable, you might use:

>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))


or, to reparameterize it:

>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))

Parameters: support (Constraint) – the support of this distribution. batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=(). event_shape (tuple) – event shape of this distribution.
arg_constraints = {}
log_prob(value)
tree_flatten()

## Independent

class Independent(base_dist, reinterpreted_batch_ndims, validate_args=None)

Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.

From a practical standpoint, this is useful for changing the result of log_prob(): the reinterpreted dimensions are summed out of it. For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]

Parameters: base_dist (numpyro.distributions.Distribution) – a distribution instance. reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.
arg_constraints = {}
support
has_enumerate_support
is_discrete
reparameterized_params
mean

Mean of the distribution.

variance

Variance of the distribution.

sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)

Evaluates the log probability density for a batch of samples given by value.

Parameters: value – a batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its trailing event dimensions removed.
Return type: numpy.ndarray
expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters: batch_shape (tuple) – batch shape to expand to.
Returns: an instance of ExpandedDistribution.
Return type: ExpandedDistribution
tree_flatten()
classmethod tree_unflatten(aux_data, params)

## MaskedDistribution

class MaskedDistribution(base_dist, mask)

Masks a distribution by a boolean array that is broadcastable to the distribution’s Distribution.batch_shape. In the special case mask is False, computation of log_prob() is skipped, and constant zero values are returned instead.

Parameters: base_dist (Distribution) – the distribution to mask. mask (jnp.ndarray or bool) – a boolean or boolean-valued array.
arg_constraints = {}
has_enumerate_support
is_discrete
support
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)

Evaluates the log probability density for a batch of samples given by value.

Parameters: value – a batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its trailing event dimensions removed.
Return type: numpy.ndarray
enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()
classmethod tree_unflatten(aux_data, params)

## TransformedDistribution

class TransformedDistribution(base_distribution, transforms, validate_args=None)

Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.

Parameters: base_distribution – the base distribution over which to apply transforms. transforms – a single transform or a list of transforms. validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
arg_constraints = {}
support
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
sample_with_intermediates(key, sample_shape=())

Same as sample except that any intermediate computations are returned as well (useful for TransformedDistribution).

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: a tuple of the sample (an array of shape sample_shape + batch_shape + event_shape) and a list of intermediate values.
Return type: tuple
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()

## Unit

class Unit(log_factor, validate_args=None)

Trivial nonnormalized distribution representing the unit type.

The unit type has a single value with no data, i.e. value.size == 0.

This is used for numpyro.factor() statements.

arg_constraints = {'log_factor': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Real object>
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)

Evaluates the log probability density for a batch of samples given by value.

Parameters: value – a batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its trailing event dimensions removed.
Return type: numpy.ndarray

# Continuous Distributions

## Beta

class Beta(concentration1, concentration0, validate_args=None)
arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## Cauchy

class Cauchy(loc=0.0, scale=1.0, validate_args=None)
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## Chi2

class Chi2(df, validate_args=None)
arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>}

## Dirichlet

class Dirichlet(concentration, validate_args=None)
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Simplex object>
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## Exponential

class Exponential(rate=1.0, validate_args=None)
reparametrized_params = ['rate']
arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## Gamma

class Gamma(concentration, rate=1.0, validate_args=None)
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
reparametrized_params = ['rate']
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## Gumbel

class Gumbel(loc=0.0, scale=1.0, validate_args=None)
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## GaussianRandomWalk

class GaussianRandomWalk(scale=1.0, num_steps=1, validate_args=None)
arg_constraints = {'num_steps': <numpyro.distributions.constraints._IntegerGreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._RealVector object>
reparametrized_params = ['scale']
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()
classmethod tree_unflatten(aux_data, params)

## HalfCauchy

class HalfCauchy(scale=1.0, validate_args=None)
reparametrized_params = ['scale']
support = <numpyro.distributions.constraints._GreaterThan object>
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## HalfNormal

class HalfNormal(scale=1.0, validate_args=None)
reparametrized_params = ['scale']
support = <numpyro.distributions.constraints._GreaterThan object>
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## InverseGamma

class InverseGamma(concentration, rate=1.0, validate_args=None)
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
reparametrized_params = ['rate']
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()

## Laplace

class Laplace(loc=0.0, scale=1.0, validate_args=None)
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## LKJ

class LKJ(dimension, concentration=1.0, sample_method='onion', validate_args=None)

LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter $$\eta$$, which makes the probability of the correlation matrix $$M$$ proportional to $$\det(M)^{\eta - 1}$$. Because of that, when concentration == 1, we have a uniform distribution over correlation matrices.

When concentration > 1, the distribution favors samples with large determinant. This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small determinant. This is useful when we know a priori that some underlying variables are correlated.

Parameters: dimension (int) – dimension of the matrices. concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta). sample_method (str) – either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._CorrMatrix object>
mean

Mean of the distribution.

tree_flatten()
classmethod tree_unflatten(aux_data, params)

## LKJCholesky

class LKJCholesky(dimension, concentration=1.0, sample_method='onion', validate_args=None)

LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter $$\eta$$, which makes the probability of the correlation matrix $$M$$ generated from a Cholesky factor proportional to $$\det(M)^{\eta - 1}$$. Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.

When concentration > 1, the distribution favors samples with large diagonal entries (hence large determinant). This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small diagonal entries (hence small determinant). This is useful when we know a priori that some underlying variables are correlated.

Parameters: dimension (int) – dimension of the matrices. concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta). sample_method (str) – either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._CorrCholesky object>
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
tree_flatten()
classmethod tree_unflatten(aux_data, params)

## LogNormal

class LogNormal(loc=0.0, scale=1.0, validate_args=None)
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['loc', 'scale']
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()

## Logistic

class Logistic(loc=0.0, scale=1.0, validate_args=None)
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'real']
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
mean

Mean of the distribution.

variance

Variance of the distribution.

## MultivariateNormal

class MultivariateNormal(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)
arg_constraints = {'covariance_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'loc': <numpyro.distributions.constraints._RealVector object>, 'precision_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'scale_tril': <numpyro.distributions.constraints._LowerCholesky object>}
support = <numpyro.distributions.constraints._RealVector object>
reparametrized_params = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the PRNG key used to draw the sample. sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape.
Return type: numpy.ndarray
log_prob(value)
covariance_matrix
precision_matrix
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()
classmethod tree_unflatten(aux_data, params)

## LowRankMultivariateNormal¶

class LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)[source]
arg_constraints = {'cov_diag': <numpyro.distributions.constraints._GreaterThan object>, 'cov_factor': <numpyro.distributions.constraints._Real object>, 'loc': <numpyro.distributions.constraints._RealVector object>}
support = <numpyro.distributions.constraints._RealVector object>
mean

Mean of the distribution.

variance[source]
scale_tril[source]
covariance_matrix[source]
precision_matrix[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
entropy()[source]

## Normal¶

class Normal(loc=0.0, scale=1.0, validate_args=None)[source]
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
icdf(q)[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

## Pareto¶

class Pareto(scale, alpha, validate_args=None)[source]
arg_constraints = {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
mean

Mean of the distribution.

variance

Variance of the distribution.

support
tree_flatten()[source]

## StudentT¶

class StudentT(df, loc=0.0, scale=1.0, validate_args=None)[source]
arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

## TruncatedCauchy¶

class TruncatedCauchy(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
support
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

## TruncatedNormal¶

class TruncatedNormal(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]
arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
support
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

## TruncatedPolyaGamma¶

class TruncatedPolyaGamma(batch_shape=(), validate_args=None)[source]
truncation_point = 2.5
num_log_prob_terms = 7
num_gamma_variates = 8
arg_constraints = {}
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

## Uniform¶

class Uniform(low=0.0, high=1.0, validate_args=None)[source]
arg_constraints = {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}
reparametrized_params = ['low', 'high']
support
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

# Discrete Distributions¶

## Bernoulli¶

Bernoulli(probs=None, logits=None, validate_args=None)[source]

## BernoulliLogits¶

class BernoulliLogits(logits=None, validate_args=None)[source]
arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Boolean object>
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

## BernoulliProbs¶

class BernoulliProbs(probs, validate_args=None)[source]
arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}
support = <numpyro.distributions.constraints._Boolean object>
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

## BetaBinomial¶

class BetaBinomial(concentration1, concentration0, total_count=1, validate_args=None)[source]

Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and is randomly drawn from a Beta distribution prior to a fixed number of Bernoulli trials given by total_count.

Parameters: concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution. concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution. total_count (numpy.ndarray) – number of Bernoulli trials.
arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

## Binomial¶

Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

## BinomialLogits¶

class BinomialLogits(logits, total_count=1, validate_args=None)[source]
arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

## BinomialProbs¶

class BinomialProbs(probs, total_count=1, validate_args=None)[source]
arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

## Categorical¶

Categorical(probs=None, logits=None, validate_args=None)[source]

## CategoricalLogits¶

class CategoricalLogits(logits, validate_args=None)[source]
arg_constraints = {'logits': <numpyro.distributions.constraints._RealVector object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

## CategoricalProbs¶

class CategoricalProbs(probs, validate_args=None)[source]
arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

## Delta¶

class Delta(value=0.0, log_density=0.0, event_dim=0, validate_args=None)[source]
arg_constraints = {'log_density': <numpyro.distributions.constraints._Real object>, 'value': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Real object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

## GammaPoisson¶

class GammaPoisson(concentration, rate=1.0, validate_args=None)[source]

Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and is randomly drawn from a Gamma distribution.

Parameters: concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution. rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

## Geometric¶

Geometric(probs=None, logits=None, validate_args=None)[source]

## GeometricLogits¶

class GeometricLogits(logits, validate_args=None)[source]
arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
probs[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

## GeometricProbs¶

class GeometricProbs(probs, validate_args=None)[source]
arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

## Multinomial¶

Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

## MultinomialLogits¶

class MultinomialLogits(logits, total_count=1, validate_args=None)[source]
arg_constraints = {'logits': <numpyro.distributions.constraints._RealVector object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support

## MultinomialProbs¶

class MultinomialProbs(probs, total_count=1, validate_args=None)[source]
arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support

## OrderedLogistic¶

class OrderedLogistic(predictor, cutpoints, validate_args=None)[source]

A categorical distribution with ordered outcomes.

References:

1. Stan Functions Reference, v2.20, section 12.6, Stan Development Team

Parameters: predictor (numpy.ndarray) – prediction in the real domain; typically this is the output of a linear model. cutpoints (numpy.ndarray) – positions in the real domain that separate the categories.
arg_constraints = {'cutpoints': <numpyro.distributions.constraints._OrderedVector object>, 'predictor': <numpyro.distributions.constraints._Real object>}

## Poisson¶

class Poisson(rate, validate_args=None)[source]
arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

## PRNGIdentity¶

class PRNGIdentity[source]

Distribution over PRNGKey(). This can be used to draw a batch of PRNGKey() using the seed handler. Only the sample method is supported.

is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray

## ZeroInflatedPoisson¶

class ZeroInflatedPoisson(gate, rate=1.0, validate_args=None)[source]

A zero-inflated Poisson distribution.

Parameters: gate (numpy.ndarray) – probability of extra zeros. rate (numpy.ndarray) – rate of the Poisson distribution.
arg_constraints = {'gate': <numpyro.distributions.constraints._Interval object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters: key (jax.random.PRNGKey) – the random number generator key to be used for the distribution. sample_shape (tuple) – the sample shape for the distribution. Returns: an array of shape sample_shape + batch_shape + event_shape. Return type: numpy.ndarray
log_prob(*args, **kwargs)
mean[source]
variance[source]

# Directional Distributions¶

## VonMises¶

class VonMises(loc, concentration, validate_args=None)[source]
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())[source]

Generate samples from the von Mises distribution.

Parameters: key – random number generator key. sample_shape – shape of samples. Returns: samples from the von Mises distribution.
log_prob(*args, **kwargs)
mean

Computes the circular mean of the distribution. Note: this equals the location loc when mapped to the support [-pi, pi].

variance

Computes the circular variance of the distribution.

# Constraints¶

## Constraint¶

class Constraint[source]

Bases: object

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

check(value)[source]

Returns a boolean array of shape sample_shape + batch_shape indicating whether each event in value satisfies this constraint.

## boolean¶

boolean = <numpyro.distributions.constraints._Boolean object>

## corr_cholesky¶

corr_cholesky = <numpyro.distributions.constraints._CorrCholesky object>

## corr_matrix¶

corr_matrix = <numpyro.distributions.constraints._CorrMatrix object>

## dependent¶

dependent = <numpyro.distributions.constraints._Dependent object>

## greater_than¶

greater_than(lower_bound)

Returns a constraint for real values greater than lower_bound.

## integer_interval¶

integer_interval(lower_bound, upper_bound)

Returns a constraint for integer values between lower_bound and upper_bound, inclusive.

## integer_greater_than¶

integer_greater_than(lower_bound)

Returns a constraint for integer values bounded below by lower_bound.

## interval¶

interval(lower_bound, upper_bound)

Returns a constraint for real values between lower_bound and upper_bound.

## less_than¶

less_than(upper_bound)

Returns a constraint for real values less than upper_bound.

## lower_cholesky¶

lower_cholesky = <numpyro.distributions.constraints._LowerCholesky object>

## multinomial¶

multinomial(upper_bound)

Returns a constraint for vectors of nonnegative integer counts whose total count is given by upper_bound.

## nonnegative_integer¶

nonnegative_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>

## ordered_vector¶

ordered_vector = <numpyro.distributions.constraints._OrderedVector object>

## positive¶

positive = <numpyro.distributions.constraints._GreaterThan object>

## positive_definite¶

positive_definite = <numpyro.distributions.constraints._PositiveDefinite object>

## positive_integer¶

positive_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>

## real¶

real = <numpyro.distributions.constraints._Real object>

## real_vector¶

real_vector = <numpyro.distributions.constraints._RealVector object>

## simplex¶

simplex = <numpyro.distributions.constraints._Simplex object>

## unit_interval¶

unit_interval = <numpyro.distributions.constraints._Interval object>

# Transforms¶

## biject_to¶

biject_to(constraint)

## Transform¶

class Transform[source]

Bases: object

domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._Real object>
event_dim = 0
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]

## AbsTransform¶

class AbsTransform[source]
domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._GreaterThan object>
inv(y)[source]

## AffineTransform¶

class AffineTransform(loc, scale, domain=<numpyro.distributions.constraints._Real object>)[source]

Note

When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

codomain
event_dim

inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

## ComposeTransform¶

class ComposeTransform(parts)[source]
domain
codomain
event_dim

inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]

## CorrCholeskyTransform¶

class CorrCholeskyTransform[source]

Transforms an unconstrained real vector $$x$$ with length $$D(D-1)/2$$ into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with a positive diagonal and unit Euclidean norm for each row. The transform is processed as follows:

1. First we convert $$x$$ into a lower triangular matrix with the following order:
$$\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}$$

2. For each row $$X_i$$ of the lower triangular part, we apply a signed version of the StickBreakingTransform class to transform $$X_i$$ into a vector of unit Euclidean length using the following steps:

    1. Scales into the interval $$(-1, 1)$$ domain: $$r_i = \tanh(X_i)$$.
    2. Transforms into an unsigned domain: $$z_i = r_i^2$$.
    3. Applies $$s_i = \mathrm{StickBreakingTransform}(z_i)$$.
    4. Transforms back into signed domain: $$y_i = (\mathrm{sign}(r_i), 1) * \sqrt{s_i}$$.

domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._CorrCholesky object>
event_dim = 2
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

## ExpTransform¶

class ExpTransform(domain=<numpyro.distributions.constraints._Real object>)[source]
codomain
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

## IdentityTransform¶

class IdentityTransform(event_dim=0)[source]
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

## InvCholeskyTransform¶

class InvCholeskyTransform(domain=<numpyro.distributions.constraints._LowerCholesky object>)[source]

Transform via the mapping $$y = x x^T$$, where x is a lower triangular matrix with positive diagonal.

event_dim = 2
codomain
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

## LowerCholeskyAffine¶

class LowerCholeskyAffine(loc, scale_tril)[source]

Transform via the mapping $$y = loc + scale\_tril\ @\ x$$.

Parameters: loc – a real vector. scale_tril – a lower triangular matrix with positive diagonal.
domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._RealVector object>
event_dim = 1
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

## LowerCholeskyTransform¶

class LowerCholeskyTransform[source]
domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._LowerCholesky object>
event_dim = 2
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

## OrderedTransform

class OrderedTransform

Transform a real vector into an ordered (strictly increasing) vector.

References

1. Stan Reference Manual v2.20, section 10.6, Stan Development Team
domain = constraints.real_vector
codomain = constraints.ordered_vector
event_dim = 1
inv(y)
log_abs_det_jacobian(x, y, intermediates=None)
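
The construction in the cited Stan reference builds an increasing vector from unconstrained reals by keeping the first entry and adding exponentiated, hence positive, increments: $$y_1 = x_1$$ and $$y_k = y_{k-1} + \exp(x_k)$$. The plain-JAX sketch below uses hypothetical helper names and illustrates that construction; it is not a copy of NumPyro's source. Because $$y$$ is a cumulative sum, the Jacobian is lower triangular with diagonal $$(1, \exp(x_2), \ldots, \exp(x_K))$$, so $$\log|\det J| = \sum_{k \ge 2} x_k$$.

```python
import jax
import jax.numpy as jnp


def ordered_forward(x):
    # y_1 = x_1 and y_k = y_{k-1} + exp(x_k); every increment after the first is positive.
    increments = jnp.concatenate([x[:1], jnp.exp(x[1:])])
    return jnp.cumsum(increments)


def ordered_inverse(y):
    # Undo the cumulative sum, then take logs of the strictly positive gaps.
    return jnp.concatenate([y[:1], jnp.log(jnp.diff(y))])


def ordered_log_abs_det_jacobian(x):
    # Triangular Jacobian with diagonal (1, exp(x_2), ..., exp(x_K)).
    return jnp.sum(x[1:])


x = jnp.array([0.5, -1.0, 2.0, 0.0])
y = ordered_forward(x)  # strictly increasing
```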

## PermuteTransform

class PermuteTransform(permutation)
domain = constraints.real_vector
codomain = constraints.real_vector
event_dim = 1
inv(y)
log_abs_det_jacobian(x, y, intermediates=None)

## PowerTransform

class PowerTransform(exponent)
domain = constraints.positive
codomain = constraints.positive
inv(y)
log_abs_det_jacobian(x, y, intermediates=None)

## SigmoidTransform

class SigmoidTransform
codomain = constraints.unit_interval
inv(y)
log_abs_det_jacobian(x, y, intermediates=None)

## StickBreakingTransform

class StickBreakingTransform
domain = constraints.real_vector
codomain = constraints.simplex
event_dim = 1
inv(y)
log_abs_det_jacobian(x, y, intermediates=None)

# Flows

## InverseAutoregressiveTransform

class InverseAutoregressiveTransform(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)

An implementation of Inverse Autoregressive Flow, using Eq. (10) from Kingma et al. (2016):

$$\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}$$

where $$\mathbf{x}$$ are the inputs, $$\mathbf{y}$$ are the outputs, $$\mu_t,\sigma_t$$ are calculated from an autoregressive network on $$\mathbf{x}$$, and $$\sigma_t>0$$.

References

1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
domain = constraints.real_vector
codomain = constraints.real_vector
event_dim = 1
call_with_intermediates(x)
inv(y)
Parameters: y (numpy.ndarray) – the output of the transform to be inverted
log_abs_det_jacobian(x, y, intermediates=None)

Calculates the log of the absolute value of the Jacobian determinant, $$\log\left|\det \partial\mathbf{y}/\partial\mathbf{x}\right|$$, for each element of the batch.

Parameters: x (numpy.ndarray) – the input to the transform y (numpy.ndarray) – the output of the transform
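
To make Eq. (10) concrete, the plain-JAX sketch below replaces the autoregressive network with a hypothetical masked linear map (`toy_arn`, `W_MU`, `W_S` and `MASK` are illustrative names, not NumPyro API). Because $$\mu_t$$ and $$\log\sigma_t$$ depend only on $$x_1, \ldots, x_{t-1}$$, the Jacobian $$\partial\mathbf{y}/\partial\mathbf{x}$$ is lower triangular with diagonal $$\sigma_t$$, so its log absolute determinant is $$\sum_t \log\sigma_t$$. The forward pass is one parallel step, while the inverse must recover $$x_t$$ one coordinate at a time.

```python
import jax
import jax.numpy as jnp

D = 4
key_mu, key_s = jax.random.split(jax.random.PRNGKey(0))
# Hypothetical weights for a toy "autoregressive network".
W_MU = jax.random.normal(key_mu, (D, D))
W_S = 0.1 * jax.random.normal(key_s, (D, D))
# Strictly lower-triangular mask: output t depends only on x[:t].
MASK = jnp.tril(jnp.ones((D, D)), k=-1)


def toy_arn(x):
    """Return (mu, log_sigma), each autoregressive in x."""
    return (W_MU * MASK) @ x, (W_S * MASK) @ x


def iaf_forward(x):
    mu, log_sigma = toy_arn(x)
    y = mu + jnp.exp(log_sigma) * x  # Eq. (10): y = mu_t + sigma_t * x
    log_det = jnp.sum(log_sigma)     # triangular Jacobian with diagonal sigma_t
    return y, log_det


def iaf_inverse(y):
    # Sequential: x_t needs mu_t and sigma_t, which in turn need x[:t].
    x = jnp.zeros_like(y)
    for t in range(y.shape[-1]):
        mu, log_sigma = toy_arn(x)
        x = x.at[t].set((y[t] - mu[t]) * jnp.exp(-log_sigma[t]))
    return x


x = jnp.array([0.3, -1.2, 0.5, 2.0])
y, log_det = iaf_forward(x)
x_back = iaf_inverse(y)
```

The sketch applies no clipping; judging by their names, the `log_scale_min_clip` and `log_scale_max_clip` constructor arguments bound $$\log\sigma_t$$ in NumPyro's version.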

## BlockNeuralAutoregressiveTransform

class BlockNeuralAutoregressiveTransform(bn_arn)

An implementation of Block Neural Autoregressive Flow.

References

1. Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
event_dim = 1
call_with_intermediates(x)
inv(y)
log_abs_det_jacobian(x, y, intermediates=None)

Calculates the log of the absolute value of the Jacobian determinant, $$\log\left|\det \partial\mathbf{y}/\partial\mathbf{x}\right|$$, for each element of the batch.

Parameters: x (numpy.ndarray) – the input to the transform y (numpy.ndarray) – the output of the transform