Distributions

Base Distribution

Distribution

class Distribution(batch_shape=(), event_shape=(), *, validate_args=None)[source]

Bases: object

Base class for probability distributions in NumPyro. The design largely follows torch.distributions.

Parameters:
  • batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
  • event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
  • validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.

As an example:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)
arg_constraints = {}
support = None
has_enumerate_support = False
reparametrized_params = []
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
static set_default_validate_args(value)[source]
batch_shape

Returns the shape over which the distribution parameters are batched.

Returns:batch shape of the distribution.
Return type:tuple
event_shape

Returns the shape of a single sample from the distribution without batching.

Returns:event shape of the distribution.
Return type:tuple
event_dim
Returns:Number of dimensions of individual events.
Return type:int
has_rsample
rsample(key, sample_shape=())[source]
shape(sample_shape=())[source]

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters:sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.
Returns:shape of samples.
Return type:tuple
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample_with_intermediates(key, sample_shape=())[source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:len(value.shape) - len(self.event_shape)], i.e. value.shape without its trailing event dimensions.
Return type:numpy.ndarray
mean

Mean of the distribution.

variance

Variance of the distribution.

to_event(reinterpreted_batch_ndims=None)[source]

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters:reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims.
Returns:An instance of Independent distribution.
Return type:numpyro.distributions.distribution.Independent
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

expand(batch_shape)[source]

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:batch_shape (tuple) – batch shape to expand to.
Returns:an instance of ExpandedDistribution.
Return type:ExpandedDistribution
expand_by(sample_shape)[source]

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters:sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.
Returns:An expanded version of this distribution.
Return type:ExpandedDistribution
mask(mask)[source]

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.

Parameters:mask (bool or jnp.ndarray) – A boolean or boolean-valued array (True includes a site, False excludes a site).
Returns:A masked copy of this distribution.
Return type:MaskedDistribution

Example:

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO

>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)


>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))


>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
classmethod infer_shapes(*args, **kwargs)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
is_discrete

ExpandedDistribution

class ExpandedDistribution(base_dist, batch_shape=())[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {}
has_enumerate_support
has_rsample
rsample(key, sample_shape=())[source]
support
sample_with_intermediates(key, sample_shape=())[source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value, intermediates=None)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:len(value.shape) - len(self.event_shape)], i.e. value.shape without its trailing event dimensions.
Return type:numpy.ndarray
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

FoldedDistribution

class FoldedDistribution(base_dist, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

Equivalent to TransformedDistribution(base_dist, AbsTransform()), but additionally supports log_prob().

Parameters:base_dist (Distribution) – A univariate distribution to reflect.
support = GreaterThan(lower_bound=0.0)
log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

ImproperUniform

class ImproperUniform(support, batch_shape, event_shape, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

A helper distribution with zero log_prob() over the support domain.

Note

The sample method is not implemented for this distribution. In autoguide and mcmc, initial parameters for improper sites are derived from init_to_uniform or init_to_value strategies.

Usage:

>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))

If you want to set an improper prior over all values greater than a, where a is another random variable, you might use

>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))

or if you want to reparameterize it

>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))
Parameters:
  • support (Constraint) – the support of this distribution.
  • batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=().
  • event_shape (tuple) – event shape of this distribution.
arg_constraints = {}
support = Dependent()
log_prob(*args, **kwargs)
tree_flatten()[source]

Independent

class Independent(base_dist, reinterpreted_batch_ndims, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.

From a practical standpoint, this is useful when changing the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
Parameters:
  • base_dist (numpyro.distributions.distribution.Distribution) – a distribution instance.
  • reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.
arg_constraints = {}
support
has_enumerate_support
reparametrized_params
mean

Mean of the distribution.

variance

Variance of the distribution.

has_rsample
rsample(key, sample_shape=())[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:len(value.shape) - len(self.event_shape)], i.e. value.shape without its trailing event dimensions.
Return type:numpy.ndarray
expand(batch_shape)[source]

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:batch_shape (tuple) – batch shape to expand to.
Returns:an instance of ExpandedDistribution.
Return type:ExpandedDistribution
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

MaskedDistribution

class MaskedDistribution(base_dist, mask)[source]

Bases: numpyro.distributions.distribution.Distribution

Masks a distribution by a boolean array that is broadcastable to the distribution’s Distribution.batch_shape. In the special case where mask is False, computation of log_prob() is skipped and constant zero values are returned instead.

Parameters:mask (jnp.ndarray or bool) – A boolean or boolean-valued array.
arg_constraints = {}
has_enumerate_support
has_rsample
rsample(key, sample_shape=())[source]
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:len(value.shape) - len(self.event_shape)], i.e. value.shape without its trailing event dimensions.
Return type:numpy.ndarray
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

TransformedDistribution

class TransformedDistribution(base_distribution, transforms, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.

Parameters:
  • base_distribution – the base distribution over which to apply transforms.
  • transforms – a single transform or a list of transforms.
  • validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
arg_constraints = {}
has_rsample
rsample(key, sample_shape=())[source]
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample_with_intermediates(key, sample_shape=())[source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]

Delta

class Delta(v=0.0, log_density=0.0, event_dim=0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'log_density': Real(), 'v': Dependent()}
reparametrized_params = ['v', 'log_density']
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

Unit

class Unit(log_factor, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Trivial nonnormalized distribution representing the unit type.

The unit type has a single value with no data, i.e. value.size == 0.

This is used for numpyro.factor() statements.

arg_constraints = {'log_factor': Real()}
support = Real()
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:len(value.shape) - len(self.event_shape)], i.e. value.shape without its trailing event dimensions.
Return type:numpy.ndarray

Continuous Distributions

AsymmetricLaplace

class AsymmetricLaplace(loc=0.0, scale=1.0, asymmetry=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'asymmetry': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['loc', 'scale', 'asymmetry']
support = Real()
left_scale[source]
right_scale[source]
log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:len(value.shape) - len(self.event_shape)], i.e. value.shape without its trailing event dimensions.
Return type:numpy.ndarray
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(value)[source]

The inverse cumulative distribution function of this distribution.

Parameters:value – quantile values (named q in the base class), should belong to [0, 1].
Returns:the samples whose cdf values equal value.

AsymmetricLaplaceQuantile

class AsymmetricLaplaceQuantile(loc=0.0, scale=1.0, quantile=0.5, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

An alternative parameterization of AsymmetricLaplace commonly applied in Bayesian quantile regression.

Instead of the asymmetry parameter employed by AsymmetricLaplace, to define the balance between left- versus right-hand sides of the distribution, this class utilizes a quantile parameter, which describes the proportion of probability density that falls to the left-hand side of the distribution.

The scale parameter is also interpreted slightly differently than in AsymmetricLaplace. When loc=0 and scale=1, AsymmetricLaplace(0,1,1) is equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is equivalent to Laplace(0,2).

arg_constraints = {'loc': Real(), 'quantile': OpenInterval(lower_bound=0.0, upper_bound=1.0), 'scale': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['loc', 'scale', 'quantile']
support = Real()
log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:len(value.shape) - len(self.event_shape)], i.e. value.shape without its trailing event dimensions.
Return type:numpy.ndarray
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(value)[source]

The inverse cumulative distribution function of this distribution.

Parameters:value – quantile values (named q in the base class), should belong to [0, 1].
Returns:the samples whose cdf values equal value.

Beta

class Beta(concentration1, concentration0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['concentration1', 'concentration0']
support = Interval(lower_bound=0.0, upper_bound=1.0)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.

BetaProportion

class BetaProportion(mean, concentration, *, validate_args=None)[source]

Bases: numpyro.distributions.continuous.Beta

The BetaProportion distribution is a reparameterization of the conventional Beta distribution in terms of the variate mean and a precision parameter.

Reference:
Beta regression for modelling rates and proportions, Ferrari Silvia, and
Francisco Cribari-Neto. Journal of Applied Statistics 31.7 (2004): 799-815.
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'mean': OpenInterval(lower_bound=0.0, upper_bound=1.0)}
reparametrized_params = ['mean', 'concentration']
support = Interval(lower_bound=0.0, upper_bound=1.0)

CAR

class CAR(loc, correlation, conditional_precision, adj_matrix, *, is_sparse=False, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

The Conditional Autoregressive (CAR) distribution is a special case of the multivariate normal in which the precision matrix is structured according to the adjacency matrix of sites. The amount of autocorrelation between sites is controlled by correlation. The distribution is a popular prior for areal spatial data.

Parameters:
  • loc (float or ndarray) – mean of the multivariate normal
  • correlation (float) – autoregression parameter. For most cases, the value should lie between 0 (sites are independent, collapses to an iid multivariate normal) and 1 (perfect autocorrelation between sites), but the specification allows for negative correlations.
  • conditional_precision (float) – positive precision for the multivariate normal
  • adj_matrix (ndarray or scipy.sparse.csr_matrix) – symmetric adjacency matrix where 1 indicates adjacency between sites and 0 otherwise. jax.numpy.ndarray adj_matrix is supported but is not recommended over numpy.ndarray or scipy.sparse.spmatrix.
  • is_sparse (bool) – whether to use a sparse form of adj_matrix in calculations (must be True if adj_matrix is a scipy.sparse.spmatrix)
arg_constraints = {'adj_matrix': Dependent(), 'conditional_precision': GreaterThan(lower_bound=0.0), 'correlation': OpenInterval(lower_bound=-1, upper_bound=1), 'loc': IndependentConstraint(Real(), 1)}
support = IndependentConstraint(Real(), 1)
reparametrized_params = ['loc', 'correlation', 'conditional_precision', 'adj_matrix']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

precision_matrix[source]
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
static infer_shapes(loc, correlation, conditional_precision, adj_matrix)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

Cauchy

class Cauchy(loc=0.0, scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = Real()
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.

Chi2

class Chi2(df, *, validate_args=None)[source]

Bases: numpyro.distributions.continuous.Gamma

arg_constraints = {'df': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['df']

Dirichlet

class Dirichlet(concentration, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': IndependentConstraint(GreaterThan(lower_bound=0.0), 1)}
reparametrized_params = ['concentration']
support = Simplex()
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

static infer_shapes(concentration)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
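As a rough illustration of the shape bookkeeping, the following plain-Python sketch assumes that the last dimension of concentration is the event dimension and that all leading dimensions are batch dimensions, which matches the Dirichlet(jnp.ones((2, 3, 4))) example given for the base Distribution class. The helper names dirichlet_shapes and sample_shape_of are hypothetical and are not part of NumPyro.

```python
def dirichlet_shapes(concentration_shape):
    """Split a concentration shape into (batch_shape, event_shape).

    Assumption: the trailing dimension indexes the simplex (event),
    and every leading dimension is a batch dimension.
    """
    if len(concentration_shape) < 1:
        raise ValueError("concentration must be at least one-dimensional")
    return tuple(concentration_shape[:-1]), tuple(concentration_shape[-1:])


def sample_shape_of(sample_shape, batch_shape, event_shape):
    """Shape of a draw: sample_shape + batch_shape + event_shape."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


batch_shape, event_shape = dirichlet_shapes((2, 3, 4))
draw_shape = sample_shape_of((5,), batch_shape, event_shape)
print(batch_shape, event_shape, draw_shape)
```

Passing shape tuples rather than arrays mirrors the convention described for infer_shapes, where each tensor input is replaced by a tuple of its sizes.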

Exponential

class Exponential(rate=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['rate']
arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}
support = GreaterThan(lower_bound=0.0)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
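To make the relationship between cdf and icdf concrete, here is a minimal plain-Python sketch using the standard closed forms for an exponential distribution with the given rate: F(x) = 1 - exp(-rate * x) and its inverse -log(1 - q) / rate. The function names are hypothetical, and this is not NumPyro's implementation, which operates on JAX arrays.

```python
import math


def exponential_cdf(value, rate=1.0):
    """P(X <= value) for X ~ Exponential(rate); zero left of the support."""
    if value <= 0.0:
        return 0.0
    # -expm1(-t) computes 1 - exp(-t) accurately for small t.
    return -math.expm1(-rate * value)


def exponential_icdf(q, rate=1.0):
    """Quantile function: the value whose cdf equals q, for q in [0, 1)."""
    if not 0.0 <= q < 1.0:
        raise ValueError("q must lie in [0, 1) for a finite quantile")
    # log1p(-q) computes log(1 - q) accurately for small q.
    return -math.log1p(-q) / rate


median = exponential_icdf(0.5, rate=2.0)
print(median, exponential_cdf(median, rate=2.0))
```

Applying the cdf to the output of the icdf recovers the original quantile, which is the property stated in the Returns entry for icdf.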

Gamma

class Gamma(concentration, rate=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}
support = GreaterThan(lower_bound=0.0)
reparametrized_params = ['concentration', 'rate']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(x)[source]

The cumulative distribution function of this distribution.

Parameters:x – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at x.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.

Gumbel

class Gumbel(loc=0.0, scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = Real()
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.

GaussianRandomWalk

class GaussianRandomWalk(scale=1.0, num_steps=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'scale': GreaterThan(lower_bound=0.0)}
support = IndependentConstraint(Real(), 1)
reparametrized_params = ['scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
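
A Gaussian random walk is conventionally defined as the running sum of independent zero-mean normal increments. Under that assumption, one draw is a vector of length num_steps, consistent with the one-dimensional event in the support above. The sketch below uses only the standard library; gaussian_random_walk is a hypothetical helper, and the seeded random.Random instance merely stands in for a PRNG key. It is not NumPyro's implementation.

```python
import itertools
import random


def gaussian_random_walk(rng, scale=1.0, num_steps=1):
    """Draw one walk: the running sum of iid Normal(0, scale) increments."""
    steps = [rng.gauss(0.0, scale) for _ in range(num_steps)]
    return list(itertools.accumulate(steps))


rng = random.Random(0)  # seeded generator, standing in for a PRNG key
walk = gaussian_random_walk(rng, scale=0.5, num_steps=4)
print(len(walk), walk)
```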

HalfCauchy

class HalfCauchy(scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['scale']
support = GreaterThan(lower_bound=0.0)
arg_constraints = {'scale': GreaterThan(lower_bound=0.0)}
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
mean

Mean of the distribution.

variance

Variance of the distribution.

HalfNormal

class HalfNormal(scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['scale']
support = GreaterThan(lower_bound=0.0)
arg_constraints = {'scale': GreaterThan(lower_bound=0.0)}
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
mean

Mean of the distribution.

variance

Variance of the distribution.

InverseGamma

class InverseGamma(concentration, rate=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

Note

We keep the same notation rate as in Pyro, but it plays the role of the scale parameter of InverseGamma in the literature (e.g. Wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution).
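
The note can be checked numerically, assuming the usual definition that an inverse-gamma variable is 1/X for X distributed as Gamma(concentration, rate). By change of variables, the density of Y = 1/X is f_X(1/y) / y^2, which coincides with the textbook inverse-gamma density in which rate appears as the scale parameter. The function names below are hypothetical and use only the standard library.

```python
import math


def gamma_log_pdf(x, concentration, rate):
    """Log density of Gamma(concentration, rate) at x > 0."""
    return (
        concentration * math.log(rate)
        - math.lgamma(concentration)
        + (concentration - 1.0) * math.log(x)
        - rate * x
    )


def inverse_gamma_log_pdf(y, concentration, scale):
    """Textbook inverse-gamma log density with shape and scale parameters."""
    return (
        concentration * math.log(scale)
        - math.lgamma(concentration)
        - (concentration + 1.0) * math.log(y)
        - scale / y
    )


def reciprocal_of_gamma_log_pdf(y, concentration, rate):
    """Log density of Y = 1/X via change of variables: f_X(1/y) / y**2."""
    return gamma_log_pdf(1.0 / y, concentration, rate) - 2.0 * math.log(y)


concentration, rate = 3.0, 2.0
for y in (0.25, 1.0, 4.0):
    lhs = reciprocal_of_gamma_log_pdf(y, concentration, rate)
    rhs = inverse_gamma_log_pdf(y, concentration, scale=rate)
    print(y, lhs, rhs)  # the two columns agree up to rounding
```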

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['concentration', 'rate']
support = GreaterThan(lower_bound=0.0)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
cdf(x)[source]

The cumulative distribution function of this distribution.

Parameters:x – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at x.

Kumaraswamy

class Kumaraswamy(concentration1, concentration0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['concentration1', 'concentration0']
support = Interval(lower_bound=0.0, upper_bound=1.0)
KL_KUMARASWAMY_BETA_TAYLOR_ORDER = 10
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]

Laplace

class Laplace(loc=0.0, scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = Real()
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.

LKJ

class LKJ(dimension, concentration=1.0, sample_method='onion', *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Consequently, when concentration == 1, we have a uniform distribution over correlation matrices.

When concentration > 1, the distribution favors samples with large determinant. This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small determinant. This is useful when we know a priori that some underlying variables are correlated.
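
For intuition, consider the 2 x 2 case, where a correlation matrix is determined by a single correlation rho and its determinant is 1 - rho^2. The sketch below evaluates the unnormalized density det(M)^(concentration - 1) at a weak and a strong correlation for three concentration values. The function name is hypothetical and the normalizing constant is deliberately omitted.

```python
def lkj_unnormalized_density_2x2(rho, concentration):
    """det(M) ** (concentration - 1) for M = [[1, rho], [rho, 1]]."""
    if not -1.0 < rho < 1.0:
        raise ValueError("rho must lie strictly between -1 and 1")
    det = 1.0 - rho * rho
    return det ** (concentration - 1.0)


for eta in (0.5, 1.0, 4.0):
    weak = lkj_unnormalized_density_2x2(0.1, eta)
    strong = lkj_unnormalized_density_2x2(0.9, eta)
    print(eta, weak, strong)
```

With concentration below 1 the strongly correlated matrix receives more weight, at exactly 1 both receive the same weight, and above 1 the weakly correlated matrix is favored.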

Sample code for using LKJ in the context of multivariate normal sample:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(y):  # y has dimension N x d
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))

    concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
    corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat`
    cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
    return obs
Parameters:
  • dimension (int) – dimension of the matrices
  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['concentration']
support = CorrMatrix()
mean

Mean of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

LKJCholesky

class LKJCholesky(dimension, concentration=1.0, sample_method='onion', *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Consequently, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.

When concentration > 1, the distribution favors samples with large diagonal entries (hence large determinant). This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small diagonal entries (hence small determinant). This is useful when we know a priori that some underlying variables are correlated.

Sample code for using LKJCholesky in the context of multivariate normal sample:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(y):  # y has dimension N x d
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
    # Lower cholesky factor of a correlation matrix
    concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
    L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
    # Lower cholesky factor of the covariance matrix
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
    L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs
Parameters:
  • dimension (int) – dimension of the matrices
  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
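
The scaling step in the sample code relies on two identities: multiplying by diag(sigma) from the left is the same as scaling each row by the matching entry of sigma, and the scaled factor reproduces the covariance sigma_i * sigma_j * C_ij, where C is the correlation matrix. The following sketch checks both on a 2 x 2 example. It uses nested lists and hypothetical helper functions so that it runs without JAX; the variable names L_omega and L_Omega mirror the sample code.

```python
import math


def matmul(a, b):
    """Product of two matrices stored as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Transpose of a matrix stored as nested lists."""
    return [list(col) for col in zip(*a)]


rho = 0.6
sigma = [2.0, 0.5]

# Lower Cholesky factor of the correlation matrix [[1, rho], [rho, 1]].
L_omega = [[1.0, 0.0], [rho, math.sqrt(1.0 - rho * rho)]]
corr = matmul(L_omega, transpose(L_omega))

# Row scaling, the nested-list analogue of `sigma[..., None] * L_omega`.
L_Omega = [[s * entry for entry in row] for s, row in zip(sigma, L_omega)]

# Explicit product diag(sigma) @ L_omega for comparison.
diag_sigma = [[sigma[0], 0.0], [0.0, sigma[1]]]
L_Omega_matmul = matmul(diag_sigma, L_omega)

# Covariance from the scaled factor versus the elementwise formula.
cov_from_factor = matmul(L_Omega, transpose(L_Omega))
cov_direct = [[sigma[i] * sigma[j] * corr[i][j] for j in range(2)] for i in range(2)]
print(cov_from_factor)
print(cov_direct)
```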

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['concentration']
support = CorrCholesky()
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

LogNormal

class LogNormal(loc=0.0, scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = GreaterThan(lower_bound=0.0)
reparametrized_params = ['loc', 'scale']
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
cdf(x)[source]

The cumulative distribution function of this distribution.

Parameters:x – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at x.

LogUniform

class LogUniform(low, high, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'high': GreaterThan(lower_bound=0.0), 'low': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['low', 'high']
support
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
cdf(x)[source]

The cumulative distribution function of this distribution.

Parameters:x – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at x.

Logistic

class Logistic(loc=0.0, scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = Real()
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.

MatrixNormal

class MatrixNormal(loc, scale_tril_row, scale_tril_column, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Matrix variate normal distribution as described in [1] but with a lower-triangular parametrization, i.e. \(U = \mathrm{scale\_tril\_row} \, \mathrm{scale\_tril\_row}^{T}\) and \(V = \mathrm{scale\_tril\_column} \, \mathrm{scale\_tril\_column}^{T}\). The distribution is related to the multivariate normal distribution in the following way: if \(X \sim MN(\mathrm{loc}, U, V)\), then \(\mathrm{vec}(X) \sim MVN(\mathrm{vec}(\mathrm{loc}), \mathrm{kron}(V, U))\).

Parameters:
  • loc (array_like) – Location of the distribution.
  • scale_tril_row (array_like) – Lower Cholesky factor of the row correlation matrix.
  • scale_tril_column (array_like) – Lower Cholesky factor of the column correlation matrix.

References

[1] https://en.wikipedia.org/wiki/Matrix_normal_distribution
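
The Kronecker relation can be verified on a small example. The sketch assumes that vec stacks the columns of a matrix, which is the convention under which the kron(V, U) form holds, and that a draw can be written as loc + A Z B^T with A = scale_tril_row, B = scale_tril_column and Z a matrix of standard normal noise. All helper functions are hypothetical and use nested lists so that the block runs without JAX.

```python
def matmul(a, b):
    """Product of two matrices stored as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Transpose of a matrix stored as nested lists."""
    return [list(col) for col in zip(*a)]


def kron(a, b):
    """Kronecker product: block (i, j) of the result is a[i][j] * b."""
    return [[x * y for x in row_a for y in row_b] for row_a in a for row_b in b]


def vec(a):
    """Stack the columns of a matrix into one flat list."""
    return [entry for col in zip(*a) for entry in col]


scale_tril_row = [[1.0, 0.0], [0.5, 2.0]]  # A, so U = A A^T
scale_tril_column = [[1.0, 0.0, 0.0], [0.3, 1.5, 0.0], [-0.2, 0.4, 0.7]]  # B, so V = B B^T
noise = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # fixed stand-in for the noise matrix Z

# Centered draw X - loc = A Z B^T.
centered = matmul(matmul(scale_tril_row, noise), transpose(scale_tril_column))

# Identity 1: vec(A Z B^T) = kron(B, A) vec(Z).
factor = kron(scale_tril_column, scale_tril_row)
vec_via_kron = [row[0] for row in matmul(factor, [[z] for z in vec(noise)])]

# Identity 2: kron(B, A) kron(B, A)^T = kron(V, U).
U = matmul(scale_tril_row, transpose(scale_tril_row))
V = matmul(scale_tril_column, transpose(scale_tril_column))
cov_from_factor = matmul(factor, transpose(factor))
cov_kron = kron(V, U)

print(vec(centered))
print(vec_via_kron)
```

Together the two identities show that vec(X) is an affine transform of standard normal noise with covariance kron(V, U).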

arg_constraints = {'loc': IndependentConstraint(Real(), 1), 'scale_tril_column': LowerCholesky(), 'scale_tril_row': LowerCholesky()}
support = IndependentConstraint(Real(), 2)
reparametrized_params = ['loc', 'scale_tril_row', 'scale_tril_column']
mean

Mean of the distribution.

sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)

MultivariateNormal

class MultivariateNormal(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': IndependentConstraint(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
support = IndependentConstraint(Real(), 1)
reparametrized_params = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
covariance_matrix[source]
precision_matrix[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
static infer_shapes(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
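
One plausible way to carry out this inference is sketched below. It assumes that loc contributes its trailing dimension as the event shape, that each matrix argument has trailing dimensions (d, d), and that the remaining leading dimensions combine by NumPy-style broadcasting. Both broadcast_shapes and mvn_infer_shapes are hypothetical plain-Python helpers written for illustration; this is not the library's code.

```python
import itertools


def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples, aligned from the right."""
    result = []
    reversed_shapes = [tuple(reversed(s)) for s in shapes]
    for dims in itertools.zip_longest(*reversed_shapes, fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def mvn_infer_shapes(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None):
    """Return (batch_shape, event_shape) from the shapes of the inputs."""
    batch_shape, event_shape = tuple(loc[:-1]), tuple(loc[-1:])
    for matrix in (covariance_matrix, precision_matrix, scale_tril):
        if matrix is not None:
            batch_shape = broadcast_shapes(batch_shape, tuple(matrix[:-2]))
            event_shape = broadcast_shapes(event_shape, tuple(matrix[-1:]))
    return batch_shape, event_shape


print(mvn_infer_shapes(loc=(3,), scale_tril=(5, 3, 3)))
print(mvn_infer_shapes(covariance_matrix=(4, 4)))
```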

MultivariateStudentT

class MultivariateStudentT(df, loc=0.0, scale_tril=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': IndependentConstraint(Real(), 1), 'scale_tril': LowerCholesky()}
support = IndependentConstraint(Real(), 1)
reparametrized_params = ['df', 'loc', 'scale_tril']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
covariance_matrix[source]
precision_matrix[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

static infer_shapes(df, loc, scale_tril)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

LowRankMultivariateNormal

class LowRankMultivariateNormal(loc, cov_factor, cov_diag, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'cov_diag': IndependentConstraint(GreaterThan(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': IndependentConstraint(Real(), 1)}
support = IndependentConstraint(Real(), 1)
reparametrized_params = ['loc', 'cov_factor', 'cov_diag']
mean

Mean of the distribution.

variance[source]
scale_tril[source]
covariance_matrix[source]
precision_matrix[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
entropy()[source]
static infer_shapes(loc, cov_factor, cov_diag)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

Normal

class Normal(loc=0.0, scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = Real()
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
mean

Mean of the distribution.

variance

Variance of the distribution.
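
Python's standard library class statistics.NormalDist provides a scalar cdf and inverse cdf for the normal distribution, so it can serve as an independent reference when checking values. The snippet below is only a cross-check sketch with arbitrarily chosen parameters; it does not use or reproduce NumPyro's implementation.

```python
from statistics import NormalDist

# Reference distribution with loc=1.0 and scale=2.0.
normal = NormalDist(mu=1.0, sigma=2.0)

value = 3.0
q = normal.cdf(value)  # probability mass at or below `value`
recovered = normal.inv_cdf(q)  # the quantile function undoes the cdf

print(q, recovered, normal.mean, normal.variance)
```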

Pareto

class Pareto(scale, alpha, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
reparametrized_params = ['scale', 'alpha']
mean

Mean of the distribution.

variance

Variance of the distribution.

support
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
tree_flatten()[source]
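
For reference, the standard closed forms for a Pareto distribution with the given scale and alpha are F(x) = 1 - (scale / x)^alpha for x at or above scale, with inverse scale / (1 - q)^(1 / alpha). The plain-Python sketch below uses hypothetical function names and scalar inputs; it illustrates the formulas rather than NumPyro's array-based implementation.

```python
def pareto_cdf(value, scale, alpha):
    """P(X <= value) for X ~ Pareto(scale, alpha); zero below the support."""
    if value <= scale:
        return 0.0
    return 1.0 - (scale / value) ** alpha


def pareto_icdf(q, scale, alpha):
    """Quantile function: the value whose cdf equals q, for q in [0, 1)."""
    if not 0.0 <= q < 1.0:
        raise ValueError("q must lie in [0, 1) for a finite quantile")
    return scale / (1.0 - q) ** (1.0 / alpha)


scale, alpha = 2.0, 3.0
quantiles = [pareto_icdf(q, scale, alpha) for q in (0.0, 0.5, 0.9)]
round_trip = [pareto_cdf(x, scale, alpha) for x in quantiles]
print(quantiles, round_trip)
```

Every quantile is at least scale, which reflects the fact that the support depends on the scale parameter.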

RelaxedBernoulli

RelaxedBernoulli(temperature, probs=None, logits=None, *, validate_args=None)[source]

RelaxedBernoulliLogits

class RelaxedBernoulliLogits(temperature, logits, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'logits': Real(), 'temperature': GreaterThan(lower_bound=0.0)}
support = Interval(lower_bound=0.0, upper_bound=1.0)
tree_flatten()[source]

SoftLaplace

class SoftLaplace(loc, scale, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Smooth distribution with Laplace-like tail behavior.

This distribution corresponds to the log-concave density:

z = (value - loc) / scale
log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z)

Like the Laplace density, this density has the heaviest possible tails (asymptotically) while still being log-concave. Unlike the Laplace distribution, this distribution is infinitely differentiable everywhere, and is thus suitable for HMC and Laplace approximation.
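
The log_prob formula can be evaluated directly with the standard library. The sketch below uses hypothetical helper names and rewrites logaddexp(z, -z) as |z| + log1p(exp(-2|z|)) to avoid overflow for large |z|. It then integrates the density with a simple trapezoid rule to confirm that the constant log(2 / pi) normalizes it.

```python
import math


def soft_laplace_log_prob(value, loc=0.0, scale=1.0):
    """Log density: log(2 / pi) - log(scale) - logaddexp(z, -z)."""
    z = (value - loc) / scale
    abs_z = abs(z)
    # Stable form of log(exp(z) + exp(-z)).
    log_add_exp = abs_z + math.log1p(math.exp(-2.0 * abs_z))
    return math.log(2.0 / math.pi) - math.log(scale) - log_add_exp


def trapezoid(f, lower, upper, num_intervals):
    """Composite trapezoid rule on a uniform grid."""
    step = (upper - lower) / num_intervals
    total = 0.5 * (f(lower) + f(upper))
    for i in range(1, num_intervals):
        total += f(lower + i * step)
    return total * step


loc, scale = 1.0, 2.0
total_mass = trapezoid(
    lambda x: math.exp(soft_laplace_log_prob(x, loc, scale)),
    loc - 80.0,
    loc + 80.0,
    40_000,
)
print(total_mass)  # close to 1
```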

Parameters:
  • loc – Location parameter.
  • scale – Scale parameter.
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = Real()
reparametrized_params = ['loc', 'scale']
log_prob(*args, **kwargs)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(value)[source]

The inverse cumulative distribution function of this distribution.

Parameters:value – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal the given quantile values.
mean

Mean of the distribution.

variance

Variance of the distribution.

StudentT

class StudentT(df, loc=0.0, scale=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
support = Real()
reparametrized_params = ['df', 'loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the values whose cdf equals q.

Uniform

class Uniform(low=0.0, high=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'high': Dependent(), 'low': Dependent()}
reparametrized_params = ['low', 'high']
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(value)[source]

The inverse cumulative distribution function of this distribution.

Parameters:value – quantile values q, which should belong to [0, 1].
Returns:the values whose cdf equals q.
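
As an illustration of how cdf and icdf relate, the following plain-Python sketch writes out the textbook formulas for a uniform distribution on [low, high]. It does not call NumPyro; the helper names uniform_cdf and uniform_icdf are hypothetical and chosen only for this example.

```python
def uniform_cdf(value, low=0.0, high=1.0):
    """CDF of Uniform(low, high), clipped to [0, 1] outside the support."""
    fraction = (value - low) / (high - low)
    return min(max(fraction, 0.0), 1.0)


def uniform_icdf(q, low=0.0, high=1.0):
    """Inverse CDF: map a quantile q in [0, 1] back onto [low, high]."""
    return low + q * (high - low)


# Round trip: the 25% quantile of Uniform(2, 6), then its cdf again.
x = uniform_icdf(0.25, low=2.0, high=6.0)
round_trip = uniform_cdf(x, low=2.0, high=6.0)
```

Applying cdf to the result of icdf recovers the original quantile, which is the property the two methods are documented to satisfy.
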
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
static infer_shapes(low=(), high=())[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
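
A minimal sketch of the shape inference described above, assuming that the batch shape of a uniform distribution is the NumPy-style broadcast of the shapes of low and high and that its event shape is empty. The helpers broadcast_shapes and uniform_infer_shapes are hypothetical stand-ins written in plain Python, not NumPyro functions.

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples (hypothetical helper)."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(ndim):
        size = 1
        for shape in shapes:
            # Align shapes on the right; missing leading axes count as 1.
            idx = axis - (ndim - len(shape))
            dim = shape[idx] if idx >= 0 else 1
            if dim != 1:
                if size != 1 and size != dim:
                    raise ValueError(f"incompatible shapes: {shapes}")
                size = dim
        result.append(size)
    return tuple(result)


def uniform_infer_shapes(low=(), high=()):
    """Sketch: batch_shape is the broadcast of the inputs; event_shape is ()."""
    return broadcast_shapes(low, high), ()


shapes = uniform_infer_shapes(low=(3, 1), high=(4,))
```

Only shape tuples are passed in, never arrays, which is why the result cannot depend on the data held by the inputs.
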

Weibull

class Weibull(scale, concentration, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
support = GreaterThan(lower_bound=0.0)
reparametrized_params = ['scale', 'concentration']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
mean

Mean of the distribution.

variance

Variance of the distribution.
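
For reference, the textbook Weibull formulas can be sketched in plain Python as follows, assuming concentration plays the role of the usual shape parameter k. The function names are hypothetical and the code does not call NumPyro.

```python
import math


def weibull_cdf(value, scale, concentration):
    """F(x) = 1 - exp(-(x / scale) ** concentration) for x > 0."""
    return 1.0 - math.exp(-((value / scale) ** concentration))


def weibull_mean(scale, concentration):
    """E[X] = scale * Gamma(1 + 1 / concentration)."""
    return scale * math.gamma(1.0 + 1.0 / concentration)


def weibull_variance(scale, concentration):
    """Var[X] = scale**2 * (Gamma(1 + 2/k) - Gamma(1 + 1/k)**2)."""
    g1 = math.gamma(1.0 + 1.0 / concentration)
    g2 = math.gamma(1.0 + 2.0 / concentration)
    return scale**2 * (g2 - g1**2)
```

With concentration equal to 1 these reduce to the exponential distribution with mean scale, which gives a quick sanity check.
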

Discrete Distributions

Bernoulli

Bernoulli(probs=None, logits=None, *, validate_args=None)[source]

BernoulliLogits

class BernoulliLogits(logits=None, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': Real()}
support = Boolean()
has_enumerate_support = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.
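
The relation between the logits parameterization and the probs property can be sketched in plain Python, assuming the usual convention probs = sigmoid(logits). The helpers sigmoid and bernoulli_log_prob are hypothetical and do not call NumPyro.

```python
import math


def sigmoid(logit):
    """Map a real-valued logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def bernoulli_log_prob(value, logit):
    """log P(X = value) for value in {0, 1}, written in terms of the logit."""
    p = sigmoid(logit)
    return math.log(p) if value == 1 else math.log(1.0 - p)


logit = 0.0
p = sigmoid(logit)
mean, variance = p, p * (1.0 - p)  # Bernoulli moments
support = [0, 1]  # the two values an enumerated support would list
total = sum(math.exp(bernoulli_log_prob(v, logit)) for v in support)
```

Summing the probabilities over the enumerated support gives 1, which is what makes exact enumeration of discrete sites possible.
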

BernoulliProbs

class BernoulliProbs(probs, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
support = Boolean()
has_enumerate_support = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

BetaBinomial

class BetaBinomial(concentration1, concentration0, total_count=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and randomly drawn from a Beta distribution prior to a certain number of Bernoulli trials given by total_count.

Parameters:
  • concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
  • concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
  • total_count (numpy.ndarray) – number of Bernoulli trials.
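
A plain-Python sketch of the resulting marginal pmf, using the standard closed form C(n, k) B(k + alpha, n - k + beta) / B(alpha, beta). The helper names are hypothetical and the code does not call NumPyro.

```python
import math


def log_beta(a, b):
    """Logarithm of the Beta function B(a, b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_binomial_log_prob(k, concentration1, concentration0, total_count):
    """log pmf of k successes after integrating out the success probability."""
    log_binom = (
        math.lgamma(total_count + 1)
        - math.lgamma(k + 1)
        - math.lgamma(total_count - k + 1)
    )
    return (
        log_binom
        + log_beta(k + concentration1, total_count - k + concentration0)
        - log_beta(concentration1, concentration0)
    )


a, b, n = 2.0, 3.0, 10
pmf = [math.exp(beta_binomial_log_prob(k, a, b, n)) for k in range(n + 1)]
mean = sum(k * p for k, p in enumerate(pmf))
```

The pmf sums to 1 over 0..total_count and its mean equals total_count * alpha / (alpha + beta).
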
arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0), 'total_count': IntegerGreaterThan(lower_bound=0)}
has_enumerate_support = True
enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support

Binomial

Binomial(total_count=1, probs=None, logits=None, *, validate_args=None)[source]

BinomialLogits

class BinomialLogits(logits, total_count=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': Real(), 'total_count': IntegerGreaterThan(lower_bound=0)}
has_enumerate_support = True
enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support

BinomialProbs

class BinomialProbs(probs, total_count=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}
has_enumerate_support = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

Categorical

Categorical(probs=None, logits=None, *, validate_args=None)[source]

CategoricalLogits

class CategoricalLogits(logits, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': IndependentConstraint(Real(), 1)}
has_enumerate_support = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

CategoricalProbs

class CategoricalProbs(probs, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': Simplex()}
has_enumerate_support = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

DirichletMultinomial

class DirichletMultinomial(concentration, total_count=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a Dirichlet-multinomial pair. The probability of classes (probs for the Multinomial distribution) is unknown and randomly drawn from a Dirichlet distribution prior to a certain number of Categorical trials given by total_count.

Parameters:
  • concentration (numpy.ndarray) – concentration parameter (alpha) for the Dirichlet distribution.
  • total_count (numpy.ndarray) – number of Categorical trials.
arg_constraints = {'concentration': IndependentConstraint(GreaterThan(lower_bound=0.0), 1), 'total_count': IntegerGreaterThan(lower_bound=0)}
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support
static infer_shapes(concentration, total_count=())[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

DiscreteUniform

class DiscreteUniform(low=0, high=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'high': Dependent(), 'low': Dependent()}
has_enumerate_support = True
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(value)[source]

The inverse cumulative distribution function of this distribution.

Parameters:value – quantile values q, which should belong to [0, 1].
Returns:the values whose cdf equals q.
mean

Mean of the distribution.

variance

Variance of the distribution.

enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

GammaPoisson

class GammaPoisson(concentration, rate=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and randomly drawn from a Gamma distribution.

Parameters:
  • concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
  • rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
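
Integrating the Poisson rate against its Gamma prior gives a closed-form pmf, sketched below in plain Python. The function name gamma_poisson_log_prob is hypothetical and the code does not call NumPyro.

```python
import math


def gamma_poisson_log_prob(k, concentration, rate):
    """log of Gamma(a + k) / (k! Gamma(a)) * (b / (1 + b))**a * (1 + b)**(-k)."""
    return (
        math.lgamma(concentration + k)
        - math.lgamma(k + 1)
        - math.lgamma(concentration)
        + concentration * math.log(rate / (1.0 + rate))
        - k * math.log(1.0 + rate)
    )


alpha, beta = 3.0, 2.0
# The tail decays geometrically, so 200 terms are ample for this check.
pmf = [math.exp(gamma_poisson_log_prob(k, alpha, beta)) for k in range(200)]
total = sum(pmf)
mean = sum(k * p for k, p in enumerate(pmf))
```

The mean of the compound distribution is the mean of the Gamma prior, concentration / rate.
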
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}
support = IntegerGreaterThan(lower_bound=0)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.

Geometric

Geometric(probs=None, logits=None, *, validate_args=None)[source]

GeometricLogits

class GeometricLogits(logits, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': Real()}
support = IntegerGreaterThan(lower_bound=0)
probs[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

GeometricProbs

class GeometricProbs(probs, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
support = IntegerGreaterThan(lower_bound=0)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

Multinomial

Multinomial(total_count=1, probs=None, logits=None, *, validate_args=None)[source]

MultinomialLogits

class MultinomialLogits(logits, total_count=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'total_count': IntegerGreaterThan(lower_bound=0)}
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
static infer_shapes(logits, total_count)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

MultinomialProbs

class MultinomialProbs(probs, total_count=1, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': Simplex(), 'total_count': IntegerGreaterThan(lower_bound=0)}
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
static infer_shapes(probs, total_count)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

OrderedLogistic

class OrderedLogistic(predictor, cutpoints, *, validate_args=None)[source]

Bases: numpyro.distributions.discrete.CategoricalProbs

A categorical distribution with ordered outcomes.

References:

  1. Stan Functions Reference, v2.20 section 12.6, Stan Development Team
Parameters:
  • predictor (numpy.ndarray) – prediction in real domain; typically this is output of a linear model.
  • cutpoints (numpy.ndarray) – positions in real domain to separate categories.
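
The standard ordered-logit construction can be sketched in plain Python: the cumulative probability of falling at or below category k is sigmoid(cutpoints[k] - predictor), and category probabilities are successive differences. The helper ordered_logistic_probs is hypothetical and assumes this convention; it does not call NumPyro.

```python
import math


def sigmoid(x):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-x))


def ordered_logistic_probs(predictor, cutpoints):
    """Category probabilities as differences of cumulative logistic curves."""
    cumulative = [0.0] + [sigmoid(c - predictor) for c in cutpoints] + [1.0]
    return [hi - lo for lo, hi in zip(cumulative[:-1], cumulative[1:])]


# Three ordered cutpoints separate the real line into four categories.
probs = ordered_logistic_probs(predictor=0.0, cutpoints=[-1.0, 0.0, 1.0])
```

Because cutpoints must be increasing, every difference is positive, and n cutpoints yield n + 1 categories.
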
arg_constraints = {'cutpoints': OrderedVector(), 'predictor': Real()}
static infer_shapes(predictor, cutpoints)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

NegativeBinomial

NegativeBinomial(total_count, probs=None, logits=None, *, validate_args=None)[source]

NegativeBinomialLogits

class NegativeBinomialLogits(total_count, logits, *, validate_args=None)[source]

Bases: numpyro.distributions.conjugate.GammaPoisson

arg_constraints = {'logits': Real(), 'total_count': GreaterThan(lower_bound=0.0)}
support = IntegerGreaterThan(lower_bound=0)
log_prob(*args, **kwargs)

NegativeBinomialProbs

class NegativeBinomialProbs(total_count, probs, *, validate_args=None)[source]

Bases: numpyro.distributions.conjugate.GammaPoisson

arg_constraints = {'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThan(lower_bound=0.0)}
support = IntegerGreaterThan(lower_bound=0)

NegativeBinomial2

class NegativeBinomial2(mean, concentration, *, validate_args=None)[source]

Bases: numpyro.distributions.conjugate.GammaPoisson

Another parameterization of GammaPoisson in which rate is replaced by mean.
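
A minimal sketch of the reparameterization, assuming the mapping rate = concentration / mean, so that the GammaPoisson mean concentration / rate equals the requested mean. Both helper names are hypothetical; the code is plain Python and does not call NumPyro.

```python
def negative_binomial2_to_gamma_poisson(mean, concentration):
    """Convert (mean, concentration) to GammaPoisson's (concentration, rate)."""
    rate = concentration / mean
    return concentration, rate


def gamma_poisson_moments(concentration, rate):
    """Mean and variance of a gamma-Poisson compound distribution."""
    mean = concentration / rate
    variance = concentration / rate**2 * (1.0 + rate)
    return mean, variance


conc, rate = negative_binomial2_to_gamma_poisson(mean=4.0, concentration=2.0)
mean, variance = gamma_poisson_moments(conc, rate)
```

Under this mapping the variance is mean + mean**2 / concentration, so a smaller concentration means more overdispersion relative to a Poisson.
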

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'mean': GreaterThan(lower_bound=0.0)}
support = IntegerGreaterThan(lower_bound=0)

Poisson

class Poisson(rate, *, is_sparse=False, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Creates a Poisson distribution parameterized by rate, the rate parameter.

Samples are nonnegative integers, with a pmf given by

\[\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}\]
Parameters:
  • rate (numpy.ndarray) – The rate parameter
  • is_sparse (bool) – Whether to assume value is mostly zero when computing log_prob(), which can speed up computation when data is sparse.
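
The pmf above can be checked numerically with a plain-Python sketch. The function poisson_log_prob is a hypothetical helper, not NumPyro's log_prob, and works in log space to avoid overflow in k!.

```python
import math


def poisson_log_prob(k, rate):
    """log pmf: k * log(rate) - rate - log(k!)."""
    return k * math.log(rate) - rate - math.lgamma(k + 1)


rate = 3.0
# Truncate the infinite support; the omitted tail is negligible here.
pmf = [math.exp(poisson_log_prob(k, rate)) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
variance = sum((k - mean) ** 2 * p for k, p in enumerate(pmf))
```

Both the mean and the variance come out equal to rate, the defining property of the Poisson distribution.
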
arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}
support = IntegerGreaterThan(lower_bound=0)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.

ZeroInflatedDistribution

ZeroInflatedDistribution(base_dist, *, gate=None, gate_logits=None, validate_args=None)[source]

Generic Zero Inflated distribution.

Parameters:
  • base_dist (Distribution) – the base distribution.
  • gate (numpy.ndarray) – probability of extra zeros given via a Bernoulli distribution.
  • gate_logits (numpy.ndarray) – logits of extra zeros given via a Bernoulli distribution.
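
A plain-Python sketch of the usual zero-inflation construction, assuming that with probability gate an extra zero is emitted and otherwise a value is drawn from the base distribution. A Poisson base is used for illustration; the helper names are hypothetical and the code does not call NumPyro.

```python
import math


def poisson_pmf(k, rate):
    """Poisson pmf, evaluated in log space for stability."""
    return math.exp(k * math.log(rate) - rate - math.lgamma(k + 1))


def zero_inflated_pmf(k, gate, base_pmf):
    """With probability gate emit an extra zero; otherwise draw from the base."""
    base = (1.0 - gate) * base_pmf(k)
    return gate + base if k == 0 else base


gate, rate = 0.3, 2.0
pmf = [zero_inflated_pmf(k, gate, lambda j: poisson_pmf(j, rate)) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
```

The extra mass lands only on zero, so the mean shrinks to (1 - gate) times the mean of the base distribution.
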

ZeroInflatedPoisson

class ZeroInflatedPoisson(gate, rate=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.discrete.ZeroInflatedProbs

A Zero Inflated Poisson distribution.

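Parameters:
  • gate (numpy.ndarray) – probability of extra zeros.
  • rate (numpy.ndarray) – rate of the Poisson distribution.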
arg_constraints = {'gate': Interval(lower_bound=0.0, upper_bound=1.0), 'rate': GreaterThan(lower_bound=0.0)}
support = IntegerGreaterThan(lower_bound=0)

ZeroInflatedNegativeBinomial2

ZeroInflatedNegativeBinomial2(mean, concentration, *, gate=None, gate_logits=None, validate_args=None)[source]

Mixture Distributions

Mixture

Mixture(mixing_distribution, component_distributions, *, validate_args=None)[source]

A marginalized finite mixture of component distributions.

The returned distribution will be either a:

  1. MixtureGeneral, when component_distributions is a list, or
  2. MixtureSameFamily, when component_distributions is a single distribution.

More details can be found in the documentation for each of these classes.

Parameters:
  • mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.
  • component_distributions – Either a list of component distributions or a single vectorized distribution. When a list is provided, the number of elements must equal mixture_size. Otherwise, the last batch dimension of the distribution must equal mixture_size.
Returns:

The mixture distribution.
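
Marginalizing over the component index means the mixture density is the weighted sum of component densities. A plain-Python sketch with normal components, evaluated in log space with the log-sum-exp trick, is given below. The helper names are hypothetical and the code does not call NumPyro.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def mixture_log_prob(x, weights, components):
    """log sum_i w_i p_i(x), computed with the log-sum-exp trick."""
    terms = [
        math.log(w) + normal_log_prob(x, loc, scale)
        for w, (loc, scale) in zip(weights, components)
    ]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


weights = [1 / 3, 1 / 3, 1 / 3]
components = [(0.0, 1.0), (-0.5, 0.3), (0.6, 1.2)]  # (loc, scale) pairs
value = mixture_log_prob(0.2, weights, components)
```

Subtracting the largest term before exponentiating keeps the sum from underflowing when some components assign very low density to x.
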

MixtureSameFamily

class MixtureSameFamily(mixing_distribution, component_distribution, *, validate_args=None)[source]

Bases: numpyro.distributions.mixtures._MixtureBase

A finite mixture of component distributions from the same family

This mixture only supports a mixture of component distributions that are all of the same family. The different components are specified along the last batch dimension of the input component_distribution. If you need a mixture of distributions from different families, use the more general implementation in MixtureGeneral.

Parameters:
  • mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.
  • component_distribution – A single vectorized Distribution, whose last batch dimension equals mixture_size as specified by mixing_distribution.

Example

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
>>> component_dist = dist.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
>>> mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()
component_distribution

Return the vectorized distribution of components being mixed.

Returns:Component distribution
Return type:Distribution
support
is_discrete
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
component_mean
component_variance
component_cdf(samples)[source]
component_sample(key, sample_shape=())[source]
component_log_probs(value)[source]

MixtureGeneral

class MixtureGeneral(mixing_distribution, component_distributions, *, validate_args=None)[source]

Bases: numpyro.distributions.mixtures._MixtureBase

A finite mixture of component distributions from different families

If all of the component distributions are from the same family, the more specific implementation in MixtureSameFamily will be somewhat more efficient.

Parameters:
  • mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.
  • component_distributions – A list of mixture_size Distribution objects.

Example

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
>>> component_dists = [
...     dist.Normal(loc=0.0, scale=1.0),
...     dist.Normal(loc=-0.5, scale=0.3),
...     dist.Normal(loc=0.6, scale=1.2),
... ]
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()
component_distributions

The list of component distributions in the mixture

Returns:The list of component distributions
Return type:List[Distribution]
support
is_discrete
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
component_mean
component_variance
component_cdf(samples)[source]
component_sample(key, sample_shape=())[source]
component_log_probs(value)[source]

Directional Distributions

ProjectedNormal

class ProjectedNormal(concentration, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Projected isotropic normal distribution of arbitrary dimension.

This distribution over directional data is qualitatively similar to the von Mises and von Mises-Fisher distributions, but permits tractable variational inference via reparametrized gradients.
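
The generative idea can be sketched in plain Python, assuming a draw is obtained by adding standard normal noise to concentration and projecting the result onto the unit sphere. The helper projected_normal_sample is hypothetical and uses the standard-library random module rather than a JAX PRNG key.

```python
import math
import random


def projected_normal_sample(concentration, rng):
    """Perturb concentration with unit Gaussian noise, then normalize."""
    noisy = [c + rng.gauss(0.0, 1.0) for c in concentration]
    norm = math.sqrt(sum(v * v for v in noisy))
    return [v / norm for v in noisy]


rng = random.Random(0)  # seeded for reproducibility
direction = projected_normal_sample([2.0, 0.0, 0.0], rng)
squared_norm = sum(v * v for v in direction)
```

Every draw has unit length, and a larger concentration vector pulls draws more tightly around its direction.
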

To use this distribution with autoguides and HMC, use handlers.reparam with a ProjectedNormalReparam reparametrizer in the model, e.g.:

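import jax.numpy as jnp
import numpyro
from numpyro import handlers
from numpyro.distributions import ProjectedNormal
from numpyro.infer.reparam import ProjectedNormalReparam

@handlers.reparam(config={"direction": ProjectedNormalReparam()})
def model():
    direction = numpyro.sample("direction", ProjectedNormal(jnp.zeros(3)))
    ...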

Note

This implements log_prob() only for dimensions {2,3}.

[1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017)
“The General Projected Normal Distribution of Arbitrary Dimension: Modeling and Bayesian Inference” https://projecteuclid.org/euclid.ba/1453211962
arg_constraints = {'concentration': IndependentConstraint(Real(), 1)}
reparametrized_params = ['concentration']
support = Sphere()
mean

Note this is the mean in the sense of a centroid in the submanifold that minimizes expected squared geodesic distance.

mode
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:-self.event_shape]
Return type:numpy.ndarray
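
The shape rules for sample and log_prob reduce to tuple arithmetic. The sketch below is a plain-Python illustration with hypothetical helper names; it reads the log_prob rule as dropping the trailing event dimensions of value.

```python
def sample_output_shape(sample_shape, batch_shape, event_shape):
    """Shape of a draw: sample_shape + batch_shape + event_shape."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)

def log_prob_output_shape(value_shape, event_dim):
    """Shape of log_prob(value): the trailing `event_dim` dimensions are reduced away."""
    # Slicing with len(...) - event_dim also works when event_dim == 0.
    return tuple(value_shape[:len(value_shape) - event_dim])

# Ten iid draws from a batch of five distributions over 3-dimensional unit vectors.
shape = sample_output_shape((10,), (5,), (3,))
lp_shape = log_prob_output_shape(shape, event_dim=1)
```

Here shape is (10, 5, 3) and lp_shape is (10, 5).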
static infer_shapes(concentration)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

SineBivariateVonMises

class SineBivariateVonMises(phi_loc, psi_loc, phi_concentration, psi_concentration, correlation=None, weighted_correlation=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Unimodal distribution of two dependent angles on the 2-torus (\(S^1 \otimes S^1\)) given by

\[C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))\]

and

\[C = (2\pi)^2 \sum_{i=0} {2i \choose i} \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),\]

where \(I_i(\cdot)\) is the modified Bessel function of the first kind, the mu’s are the locations of the distribution, the kappa’s are the concentrations, and rho gives the correlation between angles \(x_1\) and \(x_2\). This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
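
As a numerical illustration of the two formulas, the sketch below evaluates the normalizing constant by truncating the infinite sum and computing the Bessel functions from their power series. It uses only the standard library. The function names and the truncation lengths are assumptions of this sketch, and it is not how NumPyro computes norm_const.

```python
import math

def bessel_i(order, x, terms=40):
    """Modified Bessel function of the first kind, I_order(x), from its power series."""
    return sum((x / 2.0) ** (2 * m + order)
               / (math.factorial(m) * math.factorial(m + order))
               for m in range(terms))

def sine_bvm_log_norm(kappa1, kappa2, rho, terms=30):
    """log C, with the infinite sum over i truncated after `terms` terms."""
    ratio = rho * rho / (4.0 * kappa1 * kappa2)
    total = sum(math.comb(2 * i, i) * ratio ** i
                * bessel_i(i, kappa1) * bessel_i(i, kappa2)
                for i in range(terms))
    return 2.0 * math.log(2.0 * math.pi) + math.log(total)

def sine_bvm_log_prob(x1, x2, mu1, mu2, kappa1, kappa2, rho):
    """Log-density of the sine bivariate von Mises at the angle pair (x1, x2)."""
    unnormalized = (kappa1 * math.cos(x1 - mu1)
                    + kappa2 * math.cos(x2 - mu2)
                    + rho * math.sin(x1 - mu1) * math.sin(x2 - mu2))
    return unnormalized - sine_bvm_log_norm(kappa1, kappa2, rho)
```

With rho = 0 only the i = 0 term survives, so C factorizes into the normalizers of two independent von Mises distributions.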

To infer parameters, use NUTS or HMC with priors that avoid parameterizations where the distribution becomes bimodal; see note below.

Note

Sample efficiency drops as

\[\frac{\rho}{\kappa_1\kappa_2} \rightarrow 1\]

because the distribution becomes increasingly bimodal. To avoid bimodality use the weighted_correlation parameter with a skew away from one (e.g., Beta(1,3)). The weighted_correlation should be in [0,1].

Note

The correlation and weighted_correlation params are mutually exclusive.

Note

In the context of SVI, this distribution can be used as a likelihood but not for latent variables.

References:

  1. Probabilistic model for two dependent circular variables, Singh, H., Hnizdo, V., and Demchuk, E. (2002)
Parameters:
  • phi_loc (np.ndarray) – location of first angle
  • psi_loc (np.ndarray) – location of second angle
  • phi_concentration (np.ndarray) – concentration of first angle
  • psi_concentration (np.ndarray) – concentration of second angle
  • correlation (np.ndarray) – correlation between the two angles
  • weighted_correlation (np.ndarray) – set correlation to weighted_corr * sqrt(phi_conc*psi_conc) to avoid bimodality (see note). The weighted_correlation should be in [0,1].
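
The weighted_correlation rule is a one-line rescaling. The sketch below spells it out in plain Python; correlation_from_weight is a hypothetical helper, not part of NumPyro.

```python
import math

def correlation_from_weight(weighted_correlation, phi_concentration, psi_concentration):
    """correlation = weighted_correlation * sqrt(phi_concentration * psi_concentration)."""
    if not 0.0 <= weighted_correlation <= 1.0:
        raise ValueError("weighted_correlation should be in [0, 1]")
    return weighted_correlation * math.sqrt(phi_concentration * psi_concentration)

rho = correlation_from_weight(0.25, 4.0, 9.0)  # 0.25 * sqrt(36) = 1.5
```

Because the weight lies in [0, 1], the resulting rho squared never exceeds the product of the two concentrations.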
arg_constraints = {'correlation': Real(), 'phi_concentration': GreaterThan(lower_bound=0.0), 'phi_loc': Interval(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 'psi_concentration': GreaterThan(lower_bound=0.0), 'psi_loc': Interval(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)}
support = IndependentConstraint(Interval(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)
max_sample_iter = 1000
norm_const[source]
log_prob(*args, **kwargs)
sample(key, sample_shape=())[source]
References:

  1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions, John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
mean

Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]

SineSkewed

class SineSkewed(base_dist: numpyro.distributions.distribution.Distribution, skewness, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Sine-skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric) base distribution. Torus distributions are distributions with support on products of circles (i.e., \(\otimes S^1\) where \(S^1 = [-\pi,\pi)\)). So, a 0-torus is a point, the 1-torus is a circle, and the 2-torus is commonly associated with the donut shape.

The sine skewed X distribution is parameterized by a weight parameter for each dimension of the event of X. For example, with a von Mises distribution over a circle (1-torus), the sine skewed von Mises distribution has one skew parameter. The skewness parameters can be inferred using HMC or NUTS. For example, the following will produce a prior over skewness for the 2-torus:

from math import pi

import numpyro
from numpyro.distributions import Beta, Normal, SineBivariateVonMises, SineSkewed, VonMises
from numpyro.distributions.transforms import L1BallTransform
from numpyro.infer.reparam import CircularReparam

@numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()})
def model(obs):
    # Sine priors
    phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
    psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
    phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
    psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
    corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

    # Skewing prior
    ball_trans = L1BallTransform()
    skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
    skewness = ball_trans(skewness)  # constrain sum |skewness_i| <= 1

    # obs is assumed to have shape (num_obs, 2): one (phi, psi) pair per row
    with numpyro.plate('obs_plate', obs.shape[0]):
        sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
                                     phi_concentration=70 * phi_conc,
                                     psi_concentration=70 * psi_conc,
                                     weighted_correlation=corr_scale)
        return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

To ensure the skewing does not alter the normalization constant of the (sine bivariate von Mises) base distribution, the skewness parameters are constrained. The constraint requires the sum of the absolute values of skewness to be less than or equal to one. We can use the L1BallTransform to achieve this.

In the context of SVI, this distribution can freely be used as a likelihood, but using it for latent variables will lead to slow inference on tori of dimension 2 and higher. This is because the base_dist cannot be reparameterized.

Note

An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

Note

For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event must be less than or equal to one. See eq. 2.1 in [1].
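
To see why the constraint matters, recall that sine-skewing multiplies the base density by 1 + sum_i skewness_i * sin(x_i - loc_i). The sketch below computes the log of that factor in plain Python, taking this form from eq. 2.1 in [1]. The helper name sine_skew_log_factor is hypothetical, and the code is an illustration rather than NumPyro's log_prob.

```python
import math

def sine_skew_log_factor(angles, locs, skewness):
    """log(1 + sum_i skewness_i * sin(angle_i - loc_i)), the term added to the base log-density."""
    if sum(abs(s) for s in skewness) > 1.0:
        raise ValueError("sum of |skewness_i| must be <= 1")
    skew = sum(s * math.sin(a - m) for a, m, s in zip(angles, locs, skewness))
    return math.log1p(skew)

# At the location of the base distribution the factor is 1, so its log is 0.
at_loc = sine_skew_log_factor([0.3, -1.0], [0.3, -1.0], [0.4, -0.5])
elsewhere = sine_skew_log_factor([1.0, 2.0], [0.3, -1.0], [0.4, -0.5])
```

Since every sine lies in [-1, 1], the L1 bound keeps the factor in [0, 2], so the skewed density never becomes negative.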

References:

  1. Sine-skewed toroidal distributions and their application in protein bioinformatics, Ameijeiras-Alonso, J., Ley, C. (2019)
Parameters:
  • base_dist (numpyro.distributions.Distribution) – base density on a d-dimensional torus. Supported base distributions include: 1D VonMises, SineBivariateVonMises, 1D ProjectedNormal, and Uniform (-pi, pi).
  • skewness (jax.numpy.array) – skewness of the distribution.
arg_constraints = {'skewness': L1Ball()}
support = IndependentConstraint(Interval(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:-self.event_shape]
Return type:numpy.ndarray
mean

Mean of the base distribution

VonMises

class VonMises(loc, concentration, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

The von Mises distribution, also known as the circular normal distribution.
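
For reference, the von Mises log-density is concentration * cos(value - loc) - log(2 * pi * I_0(concentration)). The sketch below evaluates it with the standard library only, using a power series for I_0. The helper names, including wrap_to_pi, are hypothetical and the code is not NumPyro's implementation.

```python
import math

def bessel_i0(x, terms=50):
    """Modified Bessel function of the first kind of order 0, from its power series."""
    return sum((x / 2.0) ** (2 * m) / math.factorial(m) ** 2 for m in range(terms))

def von_mises_log_prob(value, loc, concentration):
    """Log-density of the von Mises distribution at an angle in radians."""
    return (concentration * math.cos(value - loc)
            - math.log(2.0 * math.pi) - math.log(bessel_i0(concentration)))

def wrap_to_pi(angle):
    """Map any angle to the interval [-pi, pi)."""
    return (angle + math.pi) % (2.0 * math.pi) - math.pi
```

The density depends on value only through cos(value - loc), so it is periodic and any angle can be wrapped before evaluation.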

This distribution is supported by a circular constraint from -pi to +pi. By default, the circular support behaves like constraints.interval(-math.pi, math.pi). To avoid issues at the boundaries of this interval during sampling, you should reparameterize this distribution using handlers.reparam with a CircularReparam reparametrizer in the model, e.g.:

@handlers.reparam(config={"direction": CircularReparam()})
def model():
    direction = numpyro.sample("direction", VonMises(0.0, 4.0))
    ...
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'loc': Real()}
reparametrized_params = ['loc']
support = Interval(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)
sample(key, sample_shape=())[source]

Generate sample from von Mises distribution

Parameters:
  • key – random number generator key
  • sample_shape – shape of samples
Returns:

samples from von Mises

log_prob(*args, **kwargs)
mean

Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]

variance

Computes circular variance of distribution
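
Circular statistics differ from their linear counterparts: the circular mean is the direction of the average unit vector, and the circular variance is one minus that vector's length. For a von Mises distribution the standard closed form is 1 - I_1(concentration) / I_0(concentration). The sketch below shows both in plain Python; the helper names are hypothetical.

```python
import math

def bessel_i(order, x, terms=40):
    """Modified Bessel function of the first kind, I_order(x), from its power series."""
    return sum((x / 2.0) ** (2 * m + order)
               / (math.factorial(m) * math.factorial(m + order))
               for m in range(terms))

def von_mises_circular_variance(concentration):
    """Closed form 1 - I_1(k) / I_0(k)."""
    return 1.0 - bessel_i(1, concentration) / bessel_i(0, concentration)

def sample_circular_mean_and_variance(angles):
    """Circular mean and variance of a list of angles in radians."""
    c = sum(math.cos(a) for a in angles) / len(angles)
    s = sum(math.sin(a) for a in angles) / len(angles)
    return math.atan2(s, c), 1.0 - math.hypot(c, s)

# Two angles on either side of the +/-pi seam: the arithmetic mean would be 0,
# but the circular mean correctly points at the seam.
mean_angle, variance = sample_circular_mean_and_variance([math.pi - 0.1, -math.pi + 0.1])
```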

Truncated Distributions

LeftTruncatedDistribution

class LeftTruncatedDistribution(base_dist, low=0.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'low': Real()}
reparametrized_params = ['low']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
mean

Mean of the distribution.

var

RightTruncatedDistribution

class RightTruncatedDistribution(base_dist, high=0.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'high': Real()}
reparametrized_params = ['high']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
mean

Mean of the distribution.

var

TruncatedCauchy

class TruncatedCauchy[source]

TruncatedDistribution

TruncatedDistribution(base_dist, low=None, high=None, *, validate_args=None)[source]

A function to generate a truncated distribution.

Parameters:
  • base_dist – The base distribution to be truncated. This should be a univariate distribution. Currently, only the following distributions are supported: Cauchy, Laplace, Logistic, Normal, and StudentT.
  • low – the value which is used to truncate the base distribution from below. Set this parameter to None to not truncate from below.
  • high – the value which is used to truncate the base distribution from above. Set this parameter to None to not truncate from above.
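
Truncation renormalizes the base density by the probability mass that remains between low and high. The sketch below shows this for a Normal base distribution using only the standard library. truncated_normal_log_prob is a hypothetical helper illustrating the idea, not the NumPyro implementation.

```python
import math

def normal_cdf(x, loc, scale):
    """CDF of Normal(loc, scale)."""
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))

def truncated_normal_log_prob(x, loc, scale, low=None, high=None):
    """Base log-density minus the log of the mass kept between low and high."""
    if (low is not None and x < low) or (high is not None and x > high):
        return -math.inf  # outside the support
    cdf_low = 0.0 if low is None else normal_cdf(low, loc, scale)
    cdf_high = 1.0 if high is None else normal_cdf(high, loc, scale)
    z = (x - loc) / scale
    base = -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)
    return base - math.log(cdf_high - cdf_low)
```

Passing None for a bound leaves that side untruncated, mirroring the low and high parameters described above.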

TruncatedNormal

class TruncatedNormal[source]

TruncatedPolyaGamma

class TruncatedPolyaGamma(batch_shape=(), *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

truncation_point = 2.5
num_log_prob_terms = 7
num_gamma_variates = 8
arg_constraints = {}
support = Interval(lower_bound=0.0, upper_bound=2.5)
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

TwoSidedTruncatedDistribution

class TwoSidedTruncatedDistribution(base_dist, low=0.0, high=1.0, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'high': Dependent(), 'low': Dependent()}
reparametrized_params = ['low', 'high']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
mean

Mean of the distribution.

var

TensorFlow Distributions

Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see its Distribution docs.

BijectorConstraint

class BijectorConstraint(bijector)[source]

A constraint which is the codomain of a TensorFlow bijector.

Parameters:bijector (Bijector) – a TensorFlow bijector

BijectorTransform

class BijectorTransform(bijector)[source]

A wrapper for TensorFlow bijectors to make them compatible with NumPyro’s transforms.

Parameters:bijector (Bijector) – a TensorFlow bijector

TFPDistribution

class TFPDistribution(batch_shape=(), event_shape=(), *, validate_args=None)[source]

A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor has the same signature as the corresponding TFP distribution.

This class can be used to convert a TFP distribution to a NumPyro-compatible one as follows:

d = TFPDistribution[tfd.Normal](0, 1)

Note that typical use cases do not require explicitly invoking this wrapper, since NumPyro wraps TFP distributions automatically under the hood in model code, e.g.:

from tensorflow_probability.substrates.jax import distributions as tfd

def model():
    numpyro.sample("x", tfd.Normal(0, 1))

Constraints

Constraint

class Constraint[source]

Bases: object

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

is_discrete = False
event_dim = 0
check(value)[source]

Returns a byte tensor of sample_shape + batch_shape indicating whether each event in value satisfies this constraint.

feasible_like(prototype)[source]

Get a feasible value which has the same shape and dtype as prototype.
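
The two methods can be pictured with a toy interval constraint over Python lists. IntervalSketch below is a hypothetical stand-in written for this illustration. It is not NumPyro's Interval class, and returning the midpoint from feasible_like is an assumption of the sketch.

```python
class IntervalSketch:
    """Toy constraint: values must lie between lower_bound and upper_bound."""

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def check(self, values):
        """One boolean per value, True where the constraint is satisfied."""
        return [self.lower_bound <= v <= self.upper_bound for v in values]

    def feasible_like(self, prototype):
        """A valid value with the same length as `prototype` (the midpoint here)."""
        midpoint = 0.5 * (self.lower_bound + self.upper_bound)
        return [midpoint for _ in prototype]

unit = IntervalSketch(0.0, 1.0)
flags = unit.check([-0.1, 0.5, 1.0])
start = unit.feasible_like([9.0, 9.0])
```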

boolean

boolean = Boolean()

circular

circular = Interval(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)

corr_cholesky

corr_cholesky = CorrCholesky()

corr_matrix

corr_matrix = CorrMatrix()

dependent

dependent = Dependent()

Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints.

Parameters:
  • is_discrete (bool) – Optional value of .is_discrete in case this can be computed statically. If not provided, access to the .is_discrete attribute will raise a NotImplementedError.
  • event_dim (int) – Optional value of .event_dim in case this can be computed statically. If not provided, access to the .event_dim attribute will raise a NotImplementedError.

greater_than

greater_than(lower_bound)

Constrains to real values greater than lower_bound.

integer_interval

integer_interval(lower_bound, upper_bound)

Constrains to integers between lower_bound and upper_bound.

integer_greater_than

integer_greater_than(lower_bound)

Constrains to integers greater than or equal to lower_bound (compare nonnegative_integer and positive_integer).

interval

interval(lower_bound, upper_bound)

Constrains to real values in the interval from lower_bound to upper_bound.

l1_ball

l1_ball(x)

Constrain to the L1 ball of any dimension.

less_than

less_than(upper_bound)

Constrains to real values less than upper_bound.

lower_cholesky

lower_cholesky = LowerCholesky()

multinomial

multinomial(upper_bound)

Constrains to vectors of non-negative integers that sum to upper_bound.

nonnegative_integer

nonnegative_integer = IntegerGreaterThan(lower_bound=0)

ordered_vector

ordered_vector = OrderedVector()

positive

positive = GreaterThan(lower_bound=0.0)

positive_definite

positive_definite = PositiveDefinite()

positive_integer

positive_integer = IntegerGreaterThan(lower_bound=1)

positive_ordered_vector

positive_ordered_vector = PositiveOrderedVector()

Constrains to a positive real-valued tensor where the elements are monotonically increasing along the event_shape dimension.

real

real = Real()

real_vector

real_vector = IndependentConstraint(Real(), 1)

Wraps a constraint by aggregating over reinterpreted_batch_ndims-many dims in check(), so that an event is valid only if all its independent entries are valid.
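
In other words, the wrapped check is an "all" reduction over the event dimensions. The sketch below shows the idea for a single event dimension in plain Python; check_real and check_independent are hypothetical helpers.

```python
import math

def check_real(value):
    """Elementwise check: the value is a finite real number."""
    return not (math.isnan(value) or math.isinf(value))

def check_independent(base_check, event):
    """An event is valid only if every one of its entries passes the base check."""
    return all(base_check(v) for v in event)

ok = check_independent(check_real, [0.1, -2.0, 3.5])
bad = check_independent(check_real, [0.1, math.nan, 3.5])
```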

scaled_unit_lower_cholesky

scaled_unit_lower_cholesky = ScaledUnitLowerCholesky()

softplus_positive

softplus_positive = SoftplusPositive(lower_bound=0.0)

softplus_lower_cholesky

softplus_lower_cholesky = SoftplusLowerCholesky()

simplex

simplex = Simplex()

sphere

sphere = Sphere()

Constrain to the Euclidean sphere of any dimension.

unit_interval

unit_interval = Interval(lower_bound=0.0, upper_bound=1.0)

Transforms

biject_to

biject_to(constraint)

Transform

class Transform[source]

Bases: object

domain = Real()
codomain = Real()
inv
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
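
The role of log_abs_det_jacobian is easiest to see in the change-of-variables formula log p_Y(y) = log p_X(x) - log|dy/dx|. The sketch below applies it to an exponential map in plain Python. ExpSketch and transformed_log_prob are hypothetical names for this illustration and do not reproduce the Transform interface exactly.

```python
import math

class ExpSketch:
    """Toy transform y = exp(x) with its inverse and log-Jacobian."""

    def __call__(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def log_abs_det_jacobian(self, x, y):
        return x  # log|d exp(x) / dx| = x

def std_normal_log_prob(x):
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

def transformed_log_prob(base_log_prob, transform, y):
    """Density of y = transform(x) when x follows the base distribution."""
    x = transform.inverse(y)
    return base_log_prob(x) - transform.log_abs_det_jacobian(x, y)

# Pushing a standard normal through exp gives a log-normal density.
log_normal_lp = transformed_log_prob(std_normal_log_prob, ExpSketch(), 2.0)
```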

AbsTransform

class AbsTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = Real()
codomain = GreaterThan(lower_bound=0.0)

AffineTransform

class AffineTransform(loc, scale, domain=Real())[source]

Bases: numpyro.distributions.transforms.Transform

Note

When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

codomain
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

CholeskyTransform

class CholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform via the mapping \(y = cholesky(x)\), where x is a positive definite matrix.

domain = PositiveDefinite()
codomain = LowerCholesky()
log_abs_det_jacobian(x, y, intermediates=None)[source]

ComposeTransform

class ComposeTransform(parts)[source]

Bases: numpyro.distributions.transforms.Transform

domain
codomain
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

CorrCholeskyTransform

class CorrCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:

  1. First we convert \(x\) into a lower triangular matrix with the following order:
\[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]

  2. For each row \(X_i\) of the lower triangular part, we apply a signed version of StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:

    1. Scales into the interval \((-1, 1)\) domain: \(r_i = \tanh(X_i)\).
    2. Transforms into an unsigned domain: \(z_i = r_i^2\).
    3. Applies \(s_i = StickBreakingTransform(z_i)\).
    4. Transforms back into signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).
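
The steps above can be sketched in plain Python as follows. The stick-breaking step here treats each \(z\) directly as the fraction of the remaining stick to break off, which is an assumption about what the signed variant does; corr_cholesky_from_vector is a hypothetical helper, not NumPyro's implementation.

```python
import math

def corr_cholesky_from_vector(x, dim):
    """Map a vector of length dim*(dim-1)/2 to a lower triangular matrix with unit-norm rows."""
    if len(x) != dim * (dim - 1) // 2:
        raise ValueError("x must have length dim * (dim - 1) / 2")
    rows, idx = [], 0
    for i in range(dim):
        r = [math.tanh(v) for v in x[idx:idx + i]]   # step 1: scale into (-1, 1)
        idx += i
        z = [r_j * r_j for r_j in r]                 # step 2: unsigned domain
        s, remaining = [], 1.0                       # step 3: stick-breaking
        for z_j in z:
            s.append(z_j * remaining)
            remaining *= 1.0 - z_j
        s.append(remaining)                          # leftover stick goes to the diagonal
        signs = [math.copysign(1.0, r_j) for r_j in r] + [1.0]
        row = [sg * math.sqrt(v) for sg, v in zip(signs, s)]  # step 4: back to signed
        rows.append(row + [0.0] * (dim - i - 1))
    return rows

chol = corr_cholesky_from_vector([0.5, -1.0, 0.3, 2.0, -0.7, 0.1], dim=4)
```

Because the stick pieces in each row sum to one, every row has unit Euclidean norm, so multiplying the factor by its transpose gives a matrix with ones on the diagonal.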
domain = IndependentConstraint(Real(), 1)
codomain = CorrCholesky()
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

CorrMatrixCholeskyTransform

class CorrMatrixCholeskyTransform[source]

Bases: numpyro.distributions.transforms.CholeskyTransform

Transform via the mapping \(y = cholesky(x)\), where x is a correlation matrix.

domain = CorrMatrix()
codomain = CorrCholesky()
log_abs_det_jacobian(x, y, intermediates=None)[source]

ExpTransform

class ExpTransform(domain=Real())[source]

Bases: numpyro.distributions.transforms.Transform

codomain
log_abs_det_jacobian(x, y, intermediates=None)[source]

IdentityTransform

class IdentityTransform[source]

Bases: numpyro.distributions.transforms.Transform

log_abs_det_jacobian(x, y, intermediates=None)[source]

L1BallTransform

class L1BallTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transforms an unconstrained real vector \(x\) into the unit L1 ball.
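
One way to build such a map is to squash each coordinate with tanh and then shrink it by the L1 budget that earlier coordinates have left over. The sketch below implements that stick-breaking style construction in plain Python. Whether NumPyro uses exactly this formula is an assumption, and l1_ball_from_vector is a hypothetical helper.

```python
import math

def l1_ball_from_vector(x):
    """Map an unconstrained vector to a point whose absolute values sum to at most 1."""
    y, remainder = [], 1.0
    for v in x:
        t = math.tanh(v)            # squash into (-1, 1)
        y.append(t * remainder)     # use a fraction |t| of the remaining budget
        remainder *= 1.0 - abs(t)
    return y

point = l1_ball_from_vector([0.5, -2.0, 3.0])
l1_norm = sum(abs(v) for v in point)
```

The L1 norm of the output equals one minus the final remainder, so it cannot exceed 1.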

domain = IndependentConstraint(Real(), 1)
codomain = L1Ball()
log_abs_det_jacobian(x, y, intermediates=None)[source]

LowerCholeskyAffine

class LowerCholeskyAffine(loc, scale_tril)[source]

Bases: numpyro.distributions.transforms.Transform

Transform via the mapping \(y = loc + scale\_tril\ @\ x\).

Parameters:
  • loc – a real vector.
  • scale_tril – a lower triangular matrix with positive diagonal.

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import LowerCholeskyAffine
>>> base = jnp.ones(2)
>>> loc = jnp.zeros(2)
>>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
>>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
>>> affine(base)
DeviceArray([0.3, 1.5], dtype=float32)
domain = IndependentConstraint(Real(), 1)
codomain = IndependentConstraint(Real(), 1)
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

LowerCholeskyTransform

class LowerCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform a real vector to a lower triangular Cholesky factor, where the strictly lower triangular submatrix is unconstrained and the diagonal is parameterized with an exponential transform.
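
A plain-Python sketch of this mapping is shown below. The layout of the input vector (strictly lower entries first, row by row, followed by the n log-diagonal entries) is an assumption made for illustration, and lower_cholesky_from_vector is a hypothetical helper.

```python
import math

def lower_cholesky_from_vector(x):
    """Build an n x n lower triangular matrix with positive diagonal from n*(n+1)/2 reals."""
    n = round((math.sqrt(1 + 8 * len(x)) - 1) / 2)
    if n * (n + 1) // 2 != len(x):
        raise ValueError("length of x must be a triangular number")
    split = len(x) - n
    off_diag, log_diag = x[:split], x[split:]
    rows, idx = [], 0
    for i in range(n):
        # unconstrained strictly lower entries, exp() on the diagonal, zeros above it
        row = list(off_diag[idx:idx + i]) + [math.exp(log_diag[i])] + [0.0] * (n - i - 1)
        idx += i
        rows.append(row)
    return rows

factor = lower_cholesky_from_vector([0.5, -1.0, 2.0, 0.0, 0.0, math.log(2.0)])
```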

domain = IndependentConstraint(Real(), 1)
codomain = LowerCholesky()
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

OrderedTransform

class OrderedTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform a real vector to an ordered vector.
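
A common construction, and one consistent with the numbers in the documented example for this class, keeps the first coordinate and adds exp of each later coordinate cumulatively. The sketch below uses only the standard library; ordered_from_vector is a hypothetical helper, not NumPyro's implementation.

```python
import math

def ordered_from_vector(x):
    """y_0 = x_0 and y_i = y_{i-1} + exp(x_i), which is strictly increasing."""
    y = [x[0]]
    for v in x[1:]:
        y.append(y[-1] + math.exp(v))
    return y

ordered = ordered_from_vector([1.0, 1.0, 1.0])
```

Each increment exp(x_i) is positive, so the output is ordered for any real input.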

References:

  1. Stan Reference Manual v2.20, section 10.6, Stan Development Team

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import OrderedTransform
>>> base = jnp.ones(3)
>>> transform = OrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([1., 3.7182817, 6.4365635]), rtol=1e-3, atol=1e-3)
domain = IndependentConstraint(Real(), 1)
codomain = OrderedVector()
log_abs_det_jacobian(x, y, intermediates=None)[source]

PermuteTransform

class PermuteTransform(permutation)[source]

Bases: numpyro.distributions.transforms.Transform

domain = IndependentConstraint(Real(), 1)
codomain = IndependentConstraint(Real(), 1)
log_abs_det_jacobian(x, y, intermediates=None)[source]

PowerTransform

class PowerTransform(exponent)[source]

Bases: numpyro.distributions.transforms.Transform

domain = GreaterThan(lower_bound=0.0)
codomain = GreaterThan(lower_bound=0.0)
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

ScaledUnitLowerCholeskyTransform

class ScaledUnitLowerCholeskyTransform[source]

Bases: numpyro.distributions.transforms.LowerCholeskyTransform

Like LowerCholeskyTransform, this transform maps a real vector to a lower triangular Cholesky factor. However, it does so via the decomposition

\(y = loc + unit\_scale\_tril\ @\ scale\_diag\ @\ x\).

where \(unit\_scale\_tril\) has ones along the diagonal and \(scale\_diag\) is a diagonal matrix with all positive entries that is parameterized with a softplus transform.

domain = IndependentConstraint(Real(), 1)
codomain = ScaledUnitLowerCholesky()
log_abs_det_jacobian(x, y, intermediates=None)[source]

SigmoidTransform

class SigmoidTransform[source]

Bases: numpyro.distributions.transforms.Transform

codomain = Interval(lower_bound=0.0, upper_bound=1.0)
log_abs_det_jacobian(x, y, intermediates=None)[source]

SimplexToOrderedTransform

class SimplexToOrderedTransform(anchor_point=0.0)[source]

Bases: numpyro.distributions.transforms.Transform

Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints). Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.

Parameters:anchor_point – Anchor point is a nuisance parameter to improve the identifiability of the transform. For simplicity, we assume it is a scalar value, but it is broadcastable to x.shape[:-1]. For more details please refer to Section 2.2 in [1].

References:

  1. Ordinal Regression Case Study, section 2.2, M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
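
With the default anchor point of 0.0, the cutpoints are the logits of the cumulative category probabilities, which reproduces the numbers in the example for this class. The sketch below covers only that default case, using the standard library; ordered_from_simplex is a hypothetical helper.

```python
import math

def ordered_from_simplex(p):
    """Cutpoints c_k = logit(p_1 + ... + p_k) for k = 1, ..., len(p) - 1."""
    cutpoints, running = [], 0.0
    for p_i in p[:-1]:
        running += p_i
        cutpoints.append(math.log(running) - math.log1p(-running))
    return cutpoints

cutpoints = ordered_from_simplex([0.3, 0.1, 0.4, 0.2])
```

The cumulative sums increase, and the logit is monotone, so the cutpoints come out ordered. A simplex with K entries yields K - 1 cutpoints.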

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import SimplexToOrderedTransform
>>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
>>> transform = SimplexToOrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
domain = Simplex()
codomain = OrderedVector()
log_abs_det_jacobian(x, y, intermediates=None)[source]

SoftplusLowerCholeskyTransform

class SoftplusLowerCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

domain = IndependentConstraint(Real(), 1)
codomain = SoftplusLowerCholesky()
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

SoftplusTransform

class SoftplusTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).
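
Evaluating these formulas literally overflows for large inputs, so implementations usually rearrange them. The sketch below shows one numerically stable arrangement in plain Python, together with the log-Jacobian log(sigmoid(x)). The function names are hypothetical and the code is not NumPyro's implementation.

```python
import math

def softplus(x):
    """log(1 + exp(x)), written to avoid overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def softplus_inverse(y):
    """log(exp(y) - 1) for y > 0, rewritten as y + log(1 - exp(-y))."""
    return y + math.log(-math.expm1(-y))

def softplus_log_abs_det_jacobian(x):
    """log of d softplus / dx = log(sigmoid(x)) = -softplus(-x)."""
    return -softplus(-x)
```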

domain = Real()
codomain = SoftplusPositive(lower_bound=0.0)
log_abs_det_jacobian(x, y, intermediates=None)[source]

StickBreakingTransform

class StickBreakingTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = IndependentConstraint(Real(), 1)
codomain = Simplex()
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

Flows

InverseAutoregressiveTransform

class InverseAutoregressiveTransform(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]

Bases: numpyro.distributions.transforms.Transform

An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,

\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\)

where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).

References

  1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
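
The structure of the flow can be sketched in plain Python with a toy conditioner in place of the neural network. toy_conditioner below is a hypothetical stand-in for autoregressive_nn, and clamping the log-scale to the two clip values is an assumption about what log_scale_min_clip and log_scale_max_clip do.

```python
import math

LOG_SCALE_MIN_CLIP, LOG_SCALE_MAX_CLIP = -5.0, 3.0

def toy_conditioner(prefix):
    """Return (mu_t, log_sigma_t) computed only from the earlier inputs x_{<t}."""
    total = sum(prefix)
    log_sigma = min(max(0.1 * total, LOG_SCALE_MIN_CLIP), LOG_SCALE_MAX_CLIP)
    return 0.5 * total, log_sigma

def iaf_forward(x):
    """y_t = mu_t + sigma_t * x_t; the Jacobian is triangular, so log|det| = sum of log sigma_t."""
    y, log_det = [], 0.0
    for t in range(len(x)):
        mu, log_sigma = toy_conditioner(x[:t])
        y.append(mu + math.exp(log_sigma) * x[t])
        log_det += log_sigma
    return y, log_det

def iaf_inverse(y):
    """Recover x one coordinate at a time, since mu_t and sigma_t depend on x_{<t}."""
    x = []
    for t in range(len(y)):
        mu, log_sigma = toy_conditioner(x)
        x.append((y[t] - mu) * math.exp(-log_sigma))
    return x

y, log_det = iaf_forward([0.3, -1.2, 0.8])
```

The forward pass needs only the inputs, so every coordinate can be computed in parallel, while the inverse is inherently sequential.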
domain = IndependentConstraint(Real(), 1)
codomain = IndependentConstraint(Real(), 1)
call_with_intermediates(x)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

Calculates the elementwise determinant of the log jacobian.

Parameters:

BlockNeuralAutoregressiveTransform

class BlockNeuralAutoregressiveTransform(bn_arn)[source]

Bases: numpyro.distributions.transforms.Transform

An implementation of Block Neural Autoregressive flow.

References

  1. Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
domain = IndependentConstraint(Real(), 1)
codomain = IndependentConstraint(Real(), 1)
call_with_intermediates(x)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

Calculates the elementwise determinant of the log jacobian.

Parameters: