Distributions

Base Distribution

Distribution

class Distribution(batch_shape: tuple[int, ...] = (), event_shape: tuple[int, ...] = (), *, validate_args: bool | None = None)[source]

Bases: object

Base class for probability distributions in NumPyro. The design largely follows torch.distributions.

Parameters:
  • batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.

  • event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.

  • validate_args – Whether to enable validation of distribution parameters and of arguments to the .log_prob method.

As an example:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)

arg_constraints: dict[str, Any] = {}
pytree_data_fields: tuple[str, ...] = ()
pytree_aux_fields: tuple[str, ...] = ('_batch_shape', '_event_shape')
property support: Constraint | None

The support of this distribution. Subclasses can override this as a class attribute or as a property.

has_enumerate_support: bool = False
reparametrized_params: list[str] = []
classmethod gather_pytree_data_fields() tuple[str, ...][source]
classmethod gather_pytree_aux_fields() tuple[str, ...][source]
tree_flatten() tuple[tuple[Any, ...], tuple[Any, ...]][source]
classmethod tree_unflatten(aux_data: tuple[Any, ...], params: tuple[Any, ...]) Distribution[source]
static set_default_validate_args(value: bool) None[source]
get_args() dict[str, Any][source]

Get arguments of the distribution.

validate_args(strict: bool = True) None[source]

Validate the arguments of the distribution.

Parameters:

strict – Require strict validation, raising an error if the function is called inside jitted code.

property batch_shape: tuple[int, ...]

Returns the shape over which the distribution parameters are batched.

Returns:

batch shape of the distribution.

Return type:

tuple[int, …]

property event_shape: tuple[int, ...]

Returns the shape of a single sample from the distribution without batching.

Returns:

event shape of the distribution.

Return type:

tuple[int, …]

property event_dim: int
Returns:

Number of dimensions of individual events.

Return type:

int

property has_rsample: bool
rsample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
shape(sample_shape: tuple[int, ...] = ()) tuple[int, ...][source]

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

Parameters:

sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.

Returns:

shape of samples.

Return type:

tuple

sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample_with_intermediates(key: Array | None, sample_shape: tuple[int, ...] = ()) tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, list[Any]][source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

a tuple (sample, intermediates), where sample has shape sample_shape + batch_shape + event_shape and intermediates is a list of intermediate computations.

Return type:

tuple

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

to_event(reinterpreted_batch_ndims: int | None = None) Distribution[source]

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters:

reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims.

Returns:

An instance of Independent distribution.

Return type:

numpyro.distributions.distribution.Independent

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

expand(batch_shape: Sequence[int]) Distribution[source]

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:

batch_shape (tuple) – batch shape to expand to.

Returns:

an instance of ExpandedDistribution.

Return type:

ExpandedDistribution

expand_by(sample_shape: Sequence[int]) Distribution[source]

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters:

sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.

Returns:

An expanded version of this distribution.

Return type:

ExpandedDistribution

mask(mask: Array) MaskedDistribution[source]

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution’s Distribution.batch_shape.

Parameters:

mask (bool or jnp.ndarray) – A boolean or boolean valued array (True includes a site, False excludes a site).

Returns:

A masked copy of this distribution.

Return type:

MaskedDistribution

Example:

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO

>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)


>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))


>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.key(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])

classmethod infer_shapes(*args: Any, **kwargs: Any) tuple[tuple[int, ...], tuple[int, ...]][source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values; each must lie in [0, 1].

Returns:

the values x for which cdf(x) equals q.

property is_discrete: bool

ExpandedDistribution

class ExpandedDistribution(base_dist: Distribution, batch_shape: tuple[int, ...] = ())[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {}
pytree_data_fields: tuple[str, ...] = ('base_dist',)
pytree_aux_fields: tuple[str, ...] = ('_expanded_sizes', '_interstitial_sizes', 'has_enumerate_support')
base_dist: Distribution
property has_rsample: bool
rsample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample_with_intermediates(key: Array | None, sample_shape: tuple[int, ...] = ()) tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, list[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]][source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

a tuple (sample, intermediates), where sample has shape sample_shape + batch_shape + event_shape and intermediates is a list of intermediate computations.

Return type:

tuple

sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, intermediates: list[Any] | None = None) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

FoldedDistribution

class FoldedDistribution(base_dist: Distribution, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

Equivalent to TransformedDistribution(base_dist, AbsTransform()), but additionally supports log_prob().

Parameters:

base_dist (Distribution) – A univariate distribution to reflect.

support = Positive(lower_bound=0.0)
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

ImproperUniform

class ImproperUniform(support: constraints.Constraint, batch_shape: tuple[int, ...], event_shape: tuple[int, ...], *, validate_args: bool | None = None)[source]

Bases: Distribution

A helper distribution with zero log_prob() over the support domain.

Note

The sample method is not implemented for this distribution. In autoguides and MCMC, initial parameters for improper sites are derived from the init_to_uniform or init_to_value strategies.

Usage:

>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))

If you want to set an improper prior over all values greater than a, where a is another random variable, you might use

>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))

or if you want to reparameterize it

>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))

Parameters:
  • support (Constraint) – the support of this distribution.

  • batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=().

  • event_shape (tuple) – event shape of this distribution.

arg_constraints: dict[str, Any] = {}
pytree_data_fields: tuple[str, ...] = ('support',)
support: Constraint = Dependent()
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

Independent

class Independent(base_dist: Distribution, reinterpreted_batch_ndims: int, *, validate_args: bool | None = None)[source]

Bases: Distribution

Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.

From a practical standpoint, this is useful when changing the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]

Parameters:
  • base_dist (numpyro.distributions.Distribution) – a distribution instance.

  • reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.

arg_constraints: dict[str, Any] = {}
pytree_data_fields: tuple[str, ...] = ('base_dist',)
pytree_aux_fields: tuple[str, ...] = ('reinterpreted_batch_ndims', 'has_enumerate_support')
base_dist: Distribution
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

property has_rsample: bool
rsample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

expand(batch_shape: Sequence[int]) Distribution[source]

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:

batch_shape (tuple) – batch shape to expand to.

Returns:

an instance of ExpandedDistribution.

Return type:

ExpandedDistribution

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

MaskedDistribution

class MaskedDistribution(base_dist: Distribution, mask: bool | Array)[source]

Bases: Distribution

Masks a distribution by a boolean array that is broadcastable to the distribution’s Distribution.batch_shape. In the special case where mask is False, computation of log_prob() is skipped and constant zero values are returned instead.

Parameters:

mask (jnp.ndarray or bool) – A boolean or boolean-valued array.

arg_constraints: dict[str, Any] = {}
pytree_data_fields: tuple[str, ...] = ('base_dist', '_mask')
pytree_aux_fields: tuple[str, ...] = ('_mask', 'has_enumerate_support')
base_dist: Distribution
property has_rsample: bool
rsample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

tree_flatten() tuple[tuple[Any, ...], tuple[Any, ...]][source]
classmethod tree_unflatten(aux_data: tuple[Any, ...], params: tuple[Any, ...]) MaskedDistribution[source]

TransformedDistribution

class TransformedDistribution(base_distribution: Distribution, transforms: Transform | list[Transform], *, validate_args: bool | None = None)[source]

Bases: Distribution

Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.

Parameters:
  • base_distribution – the base distribution over which to apply transforms.

  • transforms – a single transform or a list of transforms.

  • validate_args – Whether to enable validation of distribution parameters and of arguments to the .log_prob method.

arg_constraints: dict[str, Any] = {}
pytree_data_fields: tuple[str, ...] = ('base_dist', 'transforms')
transforms: list[Transform]
base_dist: Distribution
property has_rsample: bool
rsample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample_with_intermediates(key: Array | None, sample_shape: tuple[int, ...] = ()) tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, list[Any]][source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

a tuple (sample, intermediates), where sample has shape sample_shape + batch_shape + event_shape and intermediates is a list of intermediate computations.

Return type:

tuple

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, intermediates: list[Any] | None = None) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values; each must lie in [0, 1].

Returns:

the values x for which cdf(x) equals q.

Delta

class Delta(v: ArrayLike = 0.0, log_density: ArrayLike = 0.0, event_dim: int = 0, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'log_density': Real(), 'v': Dependent()}
reparametrized_params: list[str] = ['v', 'log_density']
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

Unit

class Unit(log_factor: ArrayLike, *, validate_args: bool = False)[source]

Bases: Distribution

Trivial unnormalized distribution representing the unit type.

The unit type has a single value with no data, i.e. value.size == 0.

This is used for numpyro.factor() statements.

arg_constraints: dict[str, Any] = {'log_factor': Real()}
support = Real()
sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

Continuous Distributions

AsymmetricLaplace

class AsymmetricLaplace(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, asymmetry: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'asymmetry': Positive(lower_bound=0.0), 'loc': Real(), 'scale': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['loc', 'scale', 'asymmetry']
support = Real()
left_scale()[source]
right_scale()[source]
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.

Return type:

ArrayLike

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to generate the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

value – quantile values; each must lie in [0, 1].

Returns:

the values x for which cdf(x) equals value.

AsymmetricLaplaceQuantile

class AsymmetricLaplaceQuantile(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, quantile: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.5, *, validate_args: bool | None = None)[source]

Bases: Distribution

An alternative parameterization of AsymmetricLaplace commonly applied in Bayesian quantile regression.

AsymmetricLaplace balances the left- and right-hand sides of the distribution with an asymmetry parameter. This class uses a quantile parameter instead, which is the proportion of probability mass that falls to the left of loc.

The scale parameter is also interpreted slightly differently than in AsymmetricLaplace. When loc=0 and scale=1, AsymmetricLaplace(0,1,1) is equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is equivalent to Laplace(0,2).
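The sketch below writes out the check-loss density commonly used in Bayesian quantile regression. We assume that this class implements that density. The helpers are hypothetical, and NumPyro's internal route through AsymmetricLaplace may differ numerically. The code checks two properties: quantile=0.5 with scale=1 reproduces Laplace(0, 2), and the probability mass to the left of loc equals quantile.

```python
import numpy as np


def alq_log_prob(x, loc=0.0, scale=1.0, quantile=0.5):
    """Log density of the quantile-parameterized asymmetric Laplace (sketch).

    f(x) = p (1 - p) / scale * exp(-rho_p(u)), with u = (x - loc) / scale and
    the check loss rho_p(u) = u * (p - 1[u < 0]).
    """
    u = (np.asarray(x, dtype=float) - loc) / scale
    rho = u * (quantile - np.where(u < 0, 1.0, 0.0))
    return np.log(quantile * (1.0 - quantile) / scale) - rho


def laplace_log_prob(x, loc=0.0, scale=1.0):
    return -np.log(2.0 * scale) - np.abs(np.asarray(x, dtype=float) - loc) / scale


def trapezoid(y, x):
    # Plain trapezoidal rule, to avoid depending on a particular NumPy version.
    return np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0


xs = np.linspace(-5.0, 5.0, 11)
# quantile=0.5 with scale=1 matches a Laplace with scale 2.
same_as_laplace = np.allclose(
    alq_log_prob(xs, 0.0, 1.0, 0.5), laplace_log_prob(xs, 0.0, 2.0)
)

# The probability mass to the left of loc equals the quantile parameter.
grid = np.linspace(-200.0, 0.0, 200_001)
left_mass = trapezoid(np.exp(alq_log_prob(grid, 0.0, 1.0, 0.2)), grid)
print(same_as_laplace, left_mass)
```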

arg_constraints: dict[str, Any] = {'loc': Real(), 'quantile': OpenInterval(lower_bound=0.0, upper_bound=1.0), 'scale': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['loc', 'scale', 'quantile']
support = Real()
pytree_data_fields: tuple[str, ...] = ('loc', 'scale', 'quantile', '_ald')
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

value – quantile values q, which should lie in [0, 1].

Returns:

the values x such that cdf(x) = q.

Beta

class Beta(concentration1: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration0: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

Beta distribution parameterized by concentration parameters alpha (concentration1) and beta (concentration0), on the unit interval \([0,1]\).

The probability density function (PDF) is defined as:

\[f(x; \alpha, \beta) = \frac{x^{\alpha - 1} (1 - x)^{\beta - 1}}{\text{B}(\alpha, \beta)}\]

where \(x \in [0, 1]\), \(\alpha > 0\), \(\beta > 0\), and \(\text{B}(\alpha, \beta)\) is the Beta function.

Parameters:
  • concentration1 (ArrayLike) – Alpha parameter (1st shape parameter).

  • concentration0 (ArrayLike) – Beta parameter (2nd shape parameter).

  • validate_args (bool, optional) – Whether to validate input constraints, defaults to None.

arg_constraints: dict[str, Any] = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['concentration1', 'concentration0']
support = UnitInterval(lower_bound=0.0, upper_bound=1.0)
pytree_data_fields: tuple[str, ...] = ('concentration0', 'concentration1', '_dirichlet')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Generates samples from the distribution using the underlying Dirichlet implementation.

Since a \(\mathrm{Beta}(\alpha, \beta)\) distribution is equivalent to a 2-category \(\mathrm{Dirichlet}(\alpha, \beta)\), this method samples from the Dirichlet and slices the first component.

Parameters:
  • key (jax.Array) – JAX PRNGKey for reproducibility.

  • sample_shape (tuple[int, ...]) – The shape of the samples to be generated.

Returns:

Samples from the Beta distribution of shape sample_shape + batch_shape.

Return type:

ArrayLike
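The Dirichlet route can be sketched in NumPy with the standard Gamma construction of a 2-category Dirichlet. `sample_beta` is a hypothetical helper. NumPyro's actual sampler goes through its Dirichlet implementation rather than this code. The sketch compares the sample moments against the analytical mean and variance given for this class.

```python
import numpy as np


def sample_beta(rng, concentration1, concentration0, size):
    """Hypothetical helper: Beta draws via a 2-category Dirichlet.

    With independent G1 ~ Gamma(concentration1, 1) and G0 ~ Gamma(concentration0, 1),
    (G1, G0) / (G1 + G0) is Dirichlet(concentration1, concentration0); its first
    component is Beta(concentration1, concentration0).
    """
    g1 = rng.gamma(concentration1, 1.0, size=size)
    g0 = rng.gamma(concentration0, 1.0, size=size)
    return g1 / (g1 + g0)


rng = np.random.default_rng(42)
alpha, beta = 2.0, 5.0
x = sample_beta(rng, alpha, beta, size=200_000)

mean = alpha / (alpha + beta)
var = alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1))
print(x.mean(), mean)
print(x.var(), var)
```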

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Calculates the log of the probability density function.

To avoid NaN gradients at the boundaries \(x=0\) or \(x=1\), this implementation masks boundary values with a safe constant (0.5) during the differentiation path. The forward pass value is then corrected using stop_gradient() to ensure numerical stability without sacrificing accuracy.

Parameters:

value (ArrayLike) – Values at which to evaluate the log density.

Returns:

Log probability density.

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Calculates the analytical mean.

\[E[X] = \frac{\alpha}{\alpha + \beta}\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Calculates the analytical variance.

\[\mathrm{Var}(X) = \frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}\]
cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative distribution function using the regularized incomplete beta function.

\[I_x(\alpha, \beta) = \frac{\text{B}(x; \alpha, \beta)}{\text{B}(\alpha, \beta)}\]
Parameters:

value (ArrayLike) – Value to evaluate.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse cumulative distribution function (Quantile function).

Parameters:

q (ArrayLike) – Probability value in \([0,1]\).

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Entropy of the Beta distribution.

\[H(X) = \ln \text{B}(\alpha, \beta) - (\alpha - 1)\psi(\alpha) - (\beta - 1)\psi(\beta) + (\alpha + \beta - 2)\psi(\alpha + \beta)\]

where \(\psi\) is the digamma function.

BetaProportion

class BetaProportion(mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Beta

Beta distribution reparameterized in terms of a mean (mean) and a precision (concentration).

Given mean \(\mu\) and precision \(\phi\), the standard Beta parameters are derived as:

\[\alpha = \mu \phi, \quad \beta = (1 - \mu) \phi\]

The resulting PDF is:

\[f(x; \mu, \phi) = \frac{x^{\mu\phi - 1} (1 - x)^{(1 - \mu)\phi - 1}}{\text{B}(\mu\phi, (1 - \mu)\phi)}\]
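The reparameterization is a two-line computation, sketched below with a hypothetical helper name. Plugging the derived \(\alpha, \beta\) back into the Beta mean recovers \(\mu\). The Beta variance simplifies to \(\mu(1-\mu)/(1+\phi)\), which is why \(\phi\) acts as a precision.

```python
import numpy as np


def beta_proportion_to_beta(mean, concentration):
    """Hypothetical helper: (mu, phi) -> (alpha, beta) = (mu * phi, (1 - mu) * phi)."""
    mean = np.asarray(mean, dtype=float)
    concentration = np.asarray(concentration, dtype=float)
    return mean * concentration, (1.0 - mean) * concentration


mu, phi = 0.3, 10.0
alpha, beta = beta_proportion_to_beta(mu, phi)  # alpha = 3, beta = 7
beta_mean = alpha / (alpha + beta)
beta_var = alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1))
print(beta_mean, beta_var, mu * (1 - mu) / (1 + phi))
```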

Reference

Ferrari, Silvia, and Francisco Cribari-Neto. “Beta regression for modelling rates and proportions.” Journal of Applied Statistics 31.7 (2004): 799-815.

Parameters:
  • mean (ArrayLike) – Mean of the distribution, restricted to the open interval (0, 1).

  • concentration (ArrayLike) – Precision parameter (\(\phi\)), must be positive.

  • validate_args (bool, optional) – Whether to validate input constraints, defaults to None.

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'mean': OpenInterval(lower_bound=0.0, upper_bound=1.0)}
reparametrized_params: list[str] = ['mean', 'concentration']
support = UnitInterval(lower_bound=0.0, upper_bound=1.0)
pytree_data_fields: tuple[str, ...] = ('concentration',)

CAR

class CAR(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, correlation: Array, conditional_precision: Array, adj_matrix: Array, *, is_sparse: bool = False, validate_args: bool | None = None)[source]

Bases: Distribution

The Conditional Autoregressive (CAR) distribution is a special case of the multivariate normal in which the precision matrix is structured according to the adjacency matrix of sites. The amount of autocorrelation between sites is controlled by correlation. The distribution is a popular prior for areal spatial data.

Parameters:
  • loc (float or ndarray) – mean of the multivariate normal

  • correlation (float) – autoregression parameter. In most cases the value should lie between 0 (sites are independent and the distribution collapses to an iid multivariate normal) and 1 (perfect autocorrelation between sites), but the specification allows for negative correlations.

  • conditional_precision (float) – positive precision for the multivariate normal

  • adj_matrix (ndarray or scipy.sparse.csr_matrix) – symmetric adjacency matrix where 1 indicates adjacency between sites and 0 otherwise. A jax.numpy.ndarray adj_matrix is supported, but numpy.ndarray or scipy.sparse.spmatrix is recommended.

  • is_sparse (bool) – whether to use a sparse form of adj_matrix in calculations (must be True if adj_matrix is a scipy.sparse.spmatrix)
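The standard proper-CAR precision matrix is \(\tau(D - \alpha W)\). Here \(W\) is the adjacency matrix, \(D\) is the diagonal matrix of neighbor counts, \(\alpha\) is correlation, and \(\tau\) is conditional_precision. We assume, without confirming it from the source here, that precision_matrix() follows this form. `car_precision` is a hypothetical NumPy helper.

```python
import numpy as np


def car_precision(adj_matrix, correlation, conditional_precision):
    """Hypothetical helper: Q = tau * (D - alpha * W).

    W is the adjacency matrix, D the diagonal matrix of neighbor counts,
    alpha the correlation and tau the conditional precision.
    """
    W = np.asarray(adj_matrix, dtype=float)
    D = np.diag(W.sum(axis=-1))
    return conditional_precision * (D - correlation * W)


# Four sites on a line: 0 - 1 - 2 - 3
W = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]])
Q = car_precision(W, correlation=0.9, conditional_precision=2.0)
# With every site having at least one neighbor, Q is positive definite for |alpha| < 1.
print(np.linalg.eigvalsh(Q).min() > 0)
cov = np.linalg.inv(Q)  # implied covariance: neighboring sites are positively correlated
```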

arg_constraints: dict[str, Any] = {'adj_matrix': Dependent(), 'conditional_precision': Positive(lower_bound=0.0), 'correlation': OpenInterval(lower_bound=-1, upper_bound=1), 'loc': RealVector(Real(), 1)}
support = RealVector(Real(), 1)
reparametrized_params: list[str] = ['loc', 'correlation', 'conditional_precision', 'adj_matrix']
pytree_aux_fields: tuple[str, ...] = ('is_sparse', 'adj_matrix')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

precision_matrix()[source]
static infer_shapes(loc, correlation, conditional_precision, adj_matrix)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

Cauchy

class Cauchy(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

Cauchy distribution parameterized by location (loc) and scale (scale).

The probability density function (PDF) is defined as:

\[f(x; x_0, \gamma) = \frac{1}{\pi \gamma \left[1 + \left(\frac{x - x_0}{\gamma}\right)^2\right]}\]

where \(x \in \mathbb{R}\), \(x_0 \in \mathbb{R}\) is the location, and \(\gamma > 0\) is the scale. The Cauchy distribution has no finite mean or variance.

Parameters:
  • loc (ArrayLike) – Location parameter (\(x_0\)).

  • scale (ArrayLike) – Scale parameter (\(\gamma\)).

  • validate_args (bool, optional) – Whether to validate input constraints, defaults to None.

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Real()
reparametrized_params: list[str] = ['loc', 'scale']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Generates samples using the inverse CDF method via cauchy().

Parameters:
  • key (jax.Array) – JAX PRNGKey for reproducibility.

  • sample_shape (tuple[int, ...]) – The shape of the samples to be generated.

Returns:

Samples from the Cauchy distribution of shape sample_shape + batch_shape.

Return type:

ArrayLike

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Calculates the log of the probability density function.

\[\log f(x; x_0, \gamma) = -\log(\pi) - \log(\gamma) - \log\!\left[1 + \left(\frac{x - x_0}{\gamma}\right)^2\right]\]
Parameters:

value (ArrayLike) – Values at which to evaluate the log density.

Returns:

Log probability density.

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of the Cauchy distribution is undefined.

Returns NaN for all batch elements.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of the Cauchy distribution is undefined.

Returns NaN for all batch elements.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative distribution function.

\[F(x; x_0, \gamma) = \frac{1}{\pi}\arctan\!\left(\frac{x - x_0}{\gamma}\right) + \frac{1}{2}\]
Parameters:

value (ArrayLike) – Value to evaluate.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse cumulative distribution function (Quantile function).

\[F^{-1}(q; x_0, \gamma) = x_0 + \gamma \tan\!\left[\pi\!\left(q - \frac{1}{2}\right)\right]\]
Parameters:

q (ArrayLike) – Probability value in \([0,1]\).
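The CDF and quantile formulas above are easy to check numerically. Pushing uniform draws through the quantile function is the inverse-CDF idea mentioned for sample(). This NumPy sketch uses hypothetical helper names and is not NumPyro's code, which calls cauchy().

```python
import numpy as np


def cauchy_cdf(x, loc=0.0, scale=1.0):
    return np.arctan((x - loc) / scale) / np.pi + 0.5


def cauchy_icdf(q, loc=0.0, scale=1.0):
    return loc + scale * np.tan(np.pi * (q - 0.5))


loc, scale = 1.0, 2.0
q = np.array([0.1, 0.25, 0.5, 0.75, 0.9])
x = cauchy_icdf(q, loc, scale)  # quartiles land at loc -/+ scale, i.e. -1 and 3

# Inverse-CDF sampling: push uniform draws through the quantile function.
rng = np.random.default_rng(0)
samples = cauchy_icdf(rng.uniform(size=100_000), loc, scale)
print(np.median(samples))  # close to loc; the sample mean would not settle down
```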

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Entropy of the Cauchy distribution.

\[H(X) = \log(4\pi\gamma)\]

Chi2

class Chi2(df: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Gamma

A chi-square continuous random variable, parameterized by the degrees of freedom \(k\).

The Probability Density Function (PDF) of the chi-square distribution with \(k\) degrees of freedom is defined as:

\[f(x; k) = \frac{x^{k/2 - 1} e^{-x/2}}{2^{k/2}\,\Gamma(k/2)}, \quad x > 0\]

where \(k\) represents the degrees of freedom (df), \(\Gamma(\cdot)\) is the gamma function, and \(x\) is the observed value. The support domain is \(x \in (0, \infty)\).

The chi-square distribution is a special case of the Gamma distribution:

\[\chi^2(k) \equiv \mathrm{Gamma}(k/2,\; 1/2)\]

so this class inherits sampling, log-probability, mean, variance, and entropy implementations from Gamma.
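The equivalence \(\chi^2(k) \equiv \mathrm{Gamma}(k/2, 1/2)\) can be checked directly on the log densities. This standard-library sketch uses hypothetical helper names, with math.lgamma standing in for gammaln().

```python
import math


def gamma_log_prob(x, concentration, rate):
    """log of rate^a x^(a-1) exp(-rate x) / Gamma(a)."""
    return (concentration * math.log(rate) + (concentration - 1) * math.log(x)
            - rate * x - math.lgamma(concentration))


def chi2_log_prob(x, df):
    """log of x^(k/2-1) exp(-x/2) / (2^(k/2) Gamma(k/2)) with k = df."""
    half = df / 2
    return (half - 1) * math.log(x) - x / 2 - half * math.log(2) - math.lgamma(half)


for df in (1.0, 2.5, 7.0):
    for x in (0.3, 1.0, 4.2):
        # Chi2(k) has the same density as Gamma(k / 2, rate=1/2).
        assert math.isclose(chi2_log_prob(x, df), gamma_log_prob(x, df / 2, 0.5),
                            rel_tol=1e-12, abs_tol=1e-12)
```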

Parameters:
  • df – Degrees of freedom parameter \(k > 0\) (df).

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'df': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['df']

CirculantNormal

class CirculantNormal(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, covariance_row: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, covariance_rfft: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

Multivariate normal distribution with covariance matrix \(\mathbf{C}\) that is positive-definite and circulant [1], i.e., has periodic boundary conditions. The density of a sample \(\mathbf{x}\in\mathbb{R}^n\) is the standard multivariate normal density

\[p\left(\mathbf{x}\mid\boldsymbol{\mu},\mathbf{C}\right) = \frac{\left(\mathrm{det}\,\mathbf{C}\right)^{-1/2}}{\left(2\pi\right)^{n / 2}} \exp\left(-\frac{1}{2}\left(\mathbf{x}-\boldsymbol{\mu}\right)^\intercal \mathbf{C}^{-1}\left(\mathbf{x}-\boldsymbol{\mu}\right)\right),\]

where \(\mathrm{det}\) denotes the determinant and \(^\intercal\) the transpose. Circulant matrices can be diagonalized efficiently using the discrete Fourier transform [1], allowing the log likelihood to be evaluated in \(\mathcal{O}(n \log n)\) time for \(n\) observations [2].

Parameters:
  • loc – Mean of the distribution \(\boldsymbol{\mu}\).

  • covariance_row – First row of the circulant covariance matrix \(\mathbf{C}\). Because of periodic boundary conditions, the covariance matrix is fully determined by its first row (see jax.scipy.linalg.toeplitz() for further details).

  • covariance_rfft – Real part of the real fast Fourier transform of covariance_row, the first row of the circulant covariance matrix \(\mathbf{C}\).

References:

  1. Wikipedia. (n.d.). Circulant matrix. Retrieved March 6, 2025, from https://en.wikipedia.org/wiki/Circulant_matrix

  2. Wood, A. T. A., & Chan, G. (1994). Simulation of Stationary Gaussian Processes in \(\left[0, 1\right]^d\). Journal of Computational and Graphical Statistics, 3(4), 409–432. https://doi.org/10.1080/10618600.1994.10474655
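The Fourier trick can be sketched in NumPy and compared against a dense \(\mathcal{O}(n^3)\) evaluation. The sketch assumes a symmetric first row (row[j] == row[n - j]), so the eigenvalues fft(row) are real. It uses the full fft for clarity, whereas the class itself also accepts the rfft form. The helper names are hypothetical, and this is not NumPyro's implementation.

```python
import numpy as np


def circulant_from_row(row):
    """Dense circulant matrix with C[i, j] = row[(j - i) % n]."""
    n = row.shape[-1]
    idx = (np.arange(n)[None, :] - np.arange(n)[:, None]) % n
    return row[idx]


def circulant_normal_log_prob_fft(x, loc, covariance_row):
    """O(n log n) log density, assuming a symmetric covariance_row."""
    n = x.shape[-1]
    eigvals = np.fft.fft(covariance_row).real  # eigenvalues of C (real for symmetric rows)
    resid_hat = np.fft.fft(x - loc)
    # Parseval: r^T C^{-1} r = (1/n) * sum_k |fft(r)_k|^2 / lambda_k
    quad = np.sum(np.abs(resid_hat) ** 2 / eigvals) / n
    logdet = np.sum(np.log(eigvals))
    return -0.5 * (quad + logdet + n * np.log(2.0 * np.pi))


def dense_normal_log_prob(x, loc, cov):
    """O(n^3) reference using a dense solve and log-determinant."""
    n = x.shape[-1]
    resid = x - loc
    _, logdet = np.linalg.slogdet(cov)
    quad = resid @ np.linalg.solve(cov, resid)
    return -0.5 * (quad + logdet + n * np.log(2.0 * np.pi))


row = np.array([2.0, 0.5, 0.0, 0.0, 0.0, 0.5])  # eigenvalues 2 + cos(2 pi k / 6) > 0
rng = np.random.default_rng(0)
x = rng.normal(size=6)
loc = np.zeros(6)
fast = circulant_normal_log_prob_fft(x, loc, row)
slow = dense_normal_log_prob(x, loc, circulant_from_row(row))
print(fast, slow)
```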

arg_constraints: dict[str, Any] = {'covariance_rfft': IndependentConstraint(Positive(lower_bound=0.0), 1), 'covariance_row': PositiveDefiniteCirculantVector(), 'loc': RealVector(Real(), 1)}
support = RealVector(Real(), 1)
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

covariance_row() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
covariance_matrix() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
variance() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Variance of the distribution.

static infer_shapes(loc: tuple[int, ...] = (), covariance_row: tuple[int, ...] | None = None, covariance_rfft: tuple[int, ...] | None = None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

Dagum

class Dagum(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, sharpness: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

The Dagum distribution (or Mielke Beta-Kappa distribution) is a continuous probability distribution defined over positive real numbers. If \(p\), \(a\) and \(b\) are the concentration, sharpness and scale values respectively, then the Dagum density is defined as

\[f(x\mid p,a,b):=\frac{ap}{x} \left(\frac{(x/b)^{ap}}{\left((x/b)^{a}+1\right)^{p+1}}\right)\]

References:

  1. Wikipedia. (n.d.). Dagum distribution. Retrieved March 31, 2025, from https://en.wikipedia.org/wiki/Dagum_distribution
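The sketch below uses the standard closed-form CDF \(F(x) = \left(1 + (x/b)^{-a}\right)^{-p}\) from the Wikipedia reference and its quantile function. The code checks that the derivative of the CDF matches the density formula and that the quantile function inverts the CDF. It is a NumPy sketch with hypothetical helper names, not NumPyro's implementation.

```python
import numpy as np


def dagum_log_prob(x, p, a, b):
    """log f(x) = log(a p / x) + a p log(x / b) - (p + 1) log((x / b)^a + 1)."""
    log_z = np.log(x / b)
    return np.log(a * p / x) + a * p * log_z - (p + 1) * np.log1p(np.exp(a * log_z))


def dagum_cdf(x, p, a, b):
    """F(x) = (1 + (x / b)^(-a))^(-p)."""
    return (1.0 + (x / b) ** (-a)) ** (-p)


def dagum_icdf(q, p, a, b):
    """Solve q = F(x) for x: x = b * (q^(-1/p) - 1)^(-1/a)."""
    return b * (q ** (-1.0 / p) - 1.0) ** (-1.0 / a)


p, a, b = 1.5, 3.0, 2.0
x = np.array([0.5, 1.0, 2.0, 4.0])
h = 1e-5
# The derivative of the CDF should match the density.
numeric_pdf = (dagum_cdf(x + h, p, a, b) - dagum_cdf(x - h, p, a, b)) / (2 * h)
print(np.allclose(numeric_pdf, np.exp(dagum_log_prob(x, p, a, b)), rtol=1e-6))
```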

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'scale': Positive(lower_bound=0.0), 'sharpness': Positive(lower_bound=0.0)}
support = Positive(lower_bound=0.0)
reparametrized_params: list[str] = ['concentration', 'sharpness', 'scale']
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, which should lie in [0, 1].

Returns:

the values x such that cdf(x) = q.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

Dirichlet

class Dirichlet(concentration: Array, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'concentration': IndependentConstraint(Positive(lower_bound=0.0), 1)}
reparametrized_params: list[str] = ['concentration']
support = Simplex()
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray
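A standard way to draw Dirichlet samples is to normalize independent Gamma(concentration_i, 1) draws along the last axis. The NumPy helper below is hypothetical and need not match NumPyro's sampler. It shows how a concentration of shape (2, 3, 4) gives batch_shape (2, 3) and event_shape (4,), and that each draw lies on the simplex.

```python
import numpy as np


def sample_dirichlet(rng, concentration, sample_shape=()):
    """Hypothetical helper: normalize independent Gamma(concentration, 1) draws."""
    concentration = np.asarray(concentration, dtype=float)
    g = rng.gamma(concentration, 1.0, size=tuple(sample_shape) + concentration.shape)
    return g / g.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(1)
concentration = np.ones((2, 3, 4))  # batch_shape (2, 3), event_shape (4,)
x = sample_dirichlet(rng, concentration, sample_shape=(5,))
print(x.shape)  # (5, 2, 3, 4)

# Each event lies on the simplex, and the mean is concentration / concentration.sum().
alpha = np.array([1.0, 2.0, 3.0])
draws = sample_dirichlet(rng, alpha, sample_shape=(100_000,))
print(draws.mean(axis=0), alpha / alpha.sum())
```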

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

static infer_shapes(concentration)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

EulerMaruyama

class EulerMaruyama(t: Array, sde_fn: Callable[[Array, Array], tuple[Array, Array]], init_dist: Distribution, *, validate_args: bool | None = None)[source]

Bases: Distribution

A distribution over paths of a stochastic differential equation (SDE) discretized with the Euler–Maruyama method, a method for the approximate numerical solution of SDEs.

Parameters:
  • t (ndarray) – discretized time

  • sde_fn (callable) – function returning the drift and diffusion coefficients of the SDE

  • init_dist (Distribution) – Distribution for initial values.

References

[1] https://en.wikipedia.org/wiki/Euler-Maruyama_method
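The scheme itself is short. The NumPy sketch below makes three assumptions:

* sde_fn is called as sde_fn(x, t) and returns (drift, diffusion) for a scalar or diagonal-noise SDE.
* A path includes its initial value.
* The helpers are hypothetical.

NumPyro's actual conventions may differ in argument order, vectorization, and how init_dist enters. The sketch also shows why a log density over paths is tractable: each Euler–Maruyama step is a Gaussian transition.

```python
import numpy as np


def euler_maruyama_sample(rng, t, sde_fn, x0):
    """Hypothetical helper: simulate one Euler-Maruyama path on the time grid t.

    sde_fn(x, t) -> (drift, diffusion); this argument order is an assumption.
    Update: x[n+1] = x[n] + drift * dt + diffusion * sqrt(dt) * eps, eps ~ N(0, 1).
    """
    x = np.asarray(x0, dtype=float)
    path = [x]
    for t_prev, t_next in zip(t[:-1], t[1:]):
        dt = t_next - t_prev
        drift, diffusion = sde_fn(x, t_prev)
        eps = rng.standard_normal(np.shape(x))
        x = x + drift * dt + diffusion * np.sqrt(dt) * eps
        path.append(x)
    return np.stack(path)  # shape (len(t),) + np.shape(x0)


def euler_maruyama_transition_log_prob(path, t, sde_fn):
    """Sum of Gaussian transition log densities N(x[n+1]; x[n] + drift*dt, diffusion^2*dt).

    The log density of the initial value under init_dist would be added separately.
    """
    total = 0.0
    for n in range(len(t) - 1):
        dt = t[n + 1] - t[n]
        drift, diffusion = sde_fn(path[n], t[n])
        mean = path[n] + drift * dt
        sd = diffusion * np.sqrt(dt)
        z = (path[n + 1] - mean) / sd
        total += np.sum(-0.5 * z**2 - np.log(sd) - 0.5 * np.log(2.0 * np.pi))
    return total


def ornstein_uhlenbeck(x, t, theta=1.0, sigma=0.5):
    # Drift pulls x toward 0; diffusion is constant.
    return -theta * x, sigma


rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 11)
path = euler_maruyama_sample(rng, t, ornstein_uhlenbeck, x0=1.0)
print(path.shape, euler_maruyama_transition_log_prob(path, t, ornstein_uhlenbeck))
```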

arg_constraints: dict[str, Any] = {'t': OrderedVector()}
pytree_data_fields: tuple[str, ...] = ('t', 'init_dist')
pytree_aux_fields: tuple[str, ...] = ('sde_fn',)
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

Exponential

class Exponential(rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

Exponential distribution parameterized by rate (rate).

The probability density function (PDF) is defined as:

\[f(x; \lambda) = \lambda e^{-\lambda x}\]

where \(x \geq 0\) and \(\lambda > 0\) is the rate parameter.

Parameters:
  • rate (ArrayLike) – Rate parameter (\(\lambda\)), the inverse of the mean.

  • validate_args (bool, optional) – Whether to validate input constraints, defaults to None.

reparametrized_params: list[str] = ['rate']
arg_constraints: dict[str, Any] = {'rate': Positive(lower_bound=0.0)}
support = Positive(lower_bound=0.0)
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Generates samples by scaling standard exponential draws by the inverse rate: \(X = E / \lambda\), where \(E \sim \mathrm{Exp}(1)\).

Parameters:
  • key (jax.Array) – JAX PRNGKey for reproducibility.

  • sample_shape (tuple[int, ...]) – The shape of the samples to be generated.

Returns:

Samples from the Exponential distribution of shape sample_shape + batch_shape.

Return type:

ArrayLike

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Calculates the log of the probability density function.

\[\log f(x; \lambda) = \log \lambda - \lambda x\]
Parameters:

value (ArrayLike) – Values at which to evaluate the log density.

Returns:

Log probability density.

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Calculates the analytical mean.

\[E[X] = \frac{1}{\lambda}\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Calculates the analytical variance.

\[\mathrm{Var}(X) = \frac{1}{\lambda^2}\]
cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative distribution function.

\[F(x; \lambda) = 1 - e^{-\lambda x}\]
Parameters:

value (ArrayLike) – Value to evaluate.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse cumulative distribution function (Quantile function).

\[F^{-1}(q; \lambda) = -\frac{\ln(1 - q)}{\lambda}\]
Parameters:

q (ArrayLike) – Probability value in \([0,1]\).
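The CDF and quantile function above invert each other. Feeding uniform draws through the quantile function yields Exponential samples whose moments match the analytical mean and variance. This NumPy sketch has hypothetical helper names. The class's own sample() scales standard exponential draws instead, as described above.

```python
import numpy as np


def exponential_cdf(x, rate):
    # 1 - exp(-rate * x), written with expm1 for accuracy near x = 0
    return -np.expm1(-rate * x)


def exponential_icdf(q, rate):
    # -log(1 - q) / rate, written with log1p for accuracy near q = 0
    return -np.log1p(-q) / rate


rate = 2.5
q = np.linspace(0.0, 0.99, 12)
rng = np.random.default_rng(0)
samples = exponential_icdf(rng.uniform(size=200_000), rate)
print(samples.mean(), 1 / rate)     # analytical mean 1 / rate
print(samples.var(), 1 / rate**2)   # analytical variance 1 / rate^2
```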

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Entropy of the Exponential distribution.

\[H(X) = 1 - \ln \lambda\]

Gamma

class Gamma(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

Implementation of the Gamma distribution, \(\mathrm{Gamma}(\alpha, \lambda)\), where \(\alpha\) is the concentration and \(\lambda\) is the rate.

Parameters:
  • concentration (ArrayLike) – concentration parameter \(\alpha\) (also known as shape parameter).

  • rate (ArrayLike) – rate parameter \(\lambda\) (inverse scale parameter).

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}
support = Positive(lower_bound=0.0)
reparametrized_params: list[str] = ['concentration', 'rate']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Generates samples \(X \sim \mathrm{Gamma}(\alpha, \lambda)\) using gamma() under the hood.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{Gamma}(\alpha, \lambda)\), then

\[f_X(x\mid \alpha, \lambda) = \frac{\lambda^{\alpha} x^{\alpha - 1} e^{-\lambda x}}{\Gamma(\alpha)}, \quad x > 0\]

It uses gammaln() to compute the logarithm of the gamma function.
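The density formula can be sanity-checked numerically. This NumPy sketch uses hypothetical helper names, with math.lgamma standing in for gammaln(). It integrates the density on a grid and recovers a total mass of 1, the mean \(\alpha/\lambda\), and the variance \(\alpha/\lambda^2\).

```python
import math

import numpy as np


def gamma_log_prob(x, concentration, rate):
    """alpha*log(rate) + (alpha - 1)*log(x) - rate*x - lgamma(alpha)."""
    return (concentration * np.log(rate) + (concentration - 1) * np.log(x)
            - rate * x - math.lgamma(concentration))


def trapezoid(y, x):
    # Plain trapezoidal rule, to avoid depending on a particular NumPy version.
    return np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0


alpha, rate = 3.0, 2.0
x = np.linspace(1e-9, 40.0, 400_001)
dens = np.exp(gamma_log_prob(x, alpha, rate))
total = trapezoid(dens, x)                   # ~1
mean = trapezoid(x * dens, x)                # ~alpha / rate = 1.5
var = trapezoid((x - mean) ** 2 * dens, x)   # ~alpha / rate**2 = 0.75
print(total, mean, var)
```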

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

If \(X \sim \mathrm{Gamma}(\alpha, \lambda)\), then

\[\mathbb{E}[X] = \frac{\alpha}{\lambda}\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

If \(X \sim \mathrm{Gamma}(\alpha, \lambda)\), then

\[\mathrm{Var}(X) = \frac{\alpha}{\lambda^2}\]
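A quick Monte Carlo check of the mean and variance follows. One assumption to flag: NumPy's Generator.gamma takes a shape and a *scale*, so a rate \(\lambda\) must be passed as scale = 1/λ. Mixing the two conventions is a common source of bugs.

```python
import numpy as np

alpha, lam = 3.0, 2.0
rng = np.random.default_rng(2)
# NumPy's gamma sampler takes (shape, scale); scale = 1 / rate
x = rng.gamma(shape=alpha, scale=1.0 / lam, size=200_000)

print(x.mean(), alpha / lam)      # E[X] = alpha / lambda = 1.5
print(x.var(), alpha / lam**2)    # Var(X) = alpha / lambda^2 = 0.75
```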
cdf(x)[source]

If \(X \sim \mathrm{Gamma}(\alpha, \lambda)\), then

\[F_X(x \mid \alpha, \lambda) = \frac{1}{\Gamma(\alpha)} \gamma\left(\alpha, \lambda x\right)\]

where \(\gamma(\cdot,\cdot)\) is the lower incomplete gamma function. This method evaluates the regularized lower incomplete gamma function \(\gamma(\alpha, \lambda x)/\Gamma(\alpha)\), implemented as gammainc().

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{Gamma}(\alpha, \lambda)\), then

\[F^{-1}_X(q \mid \alpha, \lambda) = \frac{1}{\lambda} \gamma^{-1}\left(\alpha, q \Gamma(\alpha)\right)\]

where \(\gamma^{-1}(\alpha,\cdot)\) is the inverse of the lower incomplete gamma function in its second argument. This method uses the inverse of the regularized lower incomplete gamma function, implemented as gammaincinv().
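The sketch below makes the two identities concrete. It evaluates the regularized lower incomplete gamma function with its standard power series, \(\gamma(a, x) = x^a e^{-x} \sum_{n \ge 0} x^n / (a(a+1)\cdots(a+n))\), and inverts the CDF by bisection. This is a deliberately simple substitute technique, suitable only for moderate arguments. It is not how NumPyro computes gammainc() or gammaincinv(), and all helper names are hypothetical.

```python
import math

def regularized_lower_gamma(a, x, tol=1e-15, max_terms=10_000):
    """P(a, x) = gamma(a, x) / Gamma(a) via the power series (moderate x only)."""
    if x <= 0.0:
        return 0.0
    term = 1.0 / a            # n = 0 term: 1 / a
    total = term
    for n in range(1, max_terms):
        term *= x / (a + n)   # next term: x^n / (a (a+1) ... (a+n))
        total += term
        if term < tol * total:
            break
    return total * math.exp(a * math.log(x) - x - math.lgamma(a))

def gamma_cdf(x, concentration, rate):
    # F(x) = P(alpha, lambda * x)
    return regularized_lower_gamma(concentration, rate * x)

def gamma_icdf(q, concentration, rate, iters=200):
    """Invert the CDF by bisection (a sketch, not a production quantile routine)."""
    if not 0.0 < q < 1.0:
        raise ValueError("q must lie strictly between 0 and 1")
    lo, hi = 0.0, 1.0
    while gamma_cdf(hi, concentration, rate) < q:   # grow the bracket
        hi *= 2.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if gamma_cdf(mid, concentration, rate) < q:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

print(gamma_cdf(0.8, 1.0, 2.0), 1 - math.exp(-1.6))  # alpha = 1 recovers the Exponential CDF
x = gamma_icdf(0.3, 2.5, 1.5)
print(gamma_cdf(x, 2.5, 1.5))                         # approximately 0.3
```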

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{Gamma}(\alpha, \lambda)\), then

\[H[X] = \alpha - \ln(\lambda) + \ln\Gamma(\alpha) + (1 - \alpha) \psi(\alpha)\]

where \(\psi(\cdot)\) is the digamma function, implemented as digamma().
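As a worked check, setting \(\alpha = 1\) turns the Gamma into an Exponential with rate \(\lambda\). Since \(\ln\Gamma(1) = 0\) and the digamma term carries the factor \(1 - \alpha = 0\), the entropy formula reduces to the Exponential entropy:

\[H[X]\big|_{\alpha = 1} = 1 - \ln(\lambda) + \ln\Gamma(1) + (1 - 1)\,\psi(1) = 1 - \ln \lambda\]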

GaussianCopula

class GaussianCopula(marginal_dist: Distribution, correlation_matrix: Array | None = None, correlation_cholesky: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: Distribution

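A distribution that couples the last batch axis of the marginal distribution marginal_dist through a multivariate Gaussian copula, which models the correlation between the elements along that axis. The remaining axes, batch_shape[:-1], are treated as batch dimensions.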

Parameters:
  • marginal_dist (Distribution) – Distribution whose last batch axis is to be coupled.

  • correlation_matrix (array_like) – Correlation matrix of the coupling multivariate normal distribution.

  • correlation_cholesky (array_like) – Cholesky factor of the correlation matrix of the coupling multivariate normal distribution.
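The general Gaussian-copula construction works in three steps. First, draw correlated standard normals. Second, push them through the normal CDF to obtain dependent uniforms. Third, map each uniform through a marginal quantile function. The NumPy sketch below follows those steps with Exponential margins, chosen because their quantile function is closed-form. It illustrates the technique and does not reproduce NumPyro's code, which works with marginal_dist directly. The helper names are hypothetical.

```python
import math
import numpy as np

# Standard normal CDF, vectorized over NumPy arrays
_erf = np.vectorize(math.erf)

def std_normal_cdf(z):
    return 0.5 * (1.0 + _erf(z / math.sqrt(2.0)))

def gaussian_copula_sample(rng, corr, marginal_icdf, n):
    """Draw n samples with Gaussian-copula dependence `corr` and margins given
    by `marginal_icdf` (applied elementwise to uniforms)."""
    chol = np.linalg.cholesky(corr)
    z = rng.standard_normal((n, corr.shape[0])) @ chol.T   # rows ~ MVN(0, corr)
    u = std_normal_cdf(z)                                  # Uniform(0, 1) margins
    return marginal_icdf(u)                                # target margins

corr = np.array([[1.0, 0.8, 0.0],
                 [0.8, 1.0, 0.0],
                 [0.0, 0.0, 1.0]])
rate = np.array([1.0, 2.0, 4.0])   # one Exponential rate per coupled axis
x = gaussian_copula_sample(
    np.random.default_rng(3), corr, lambda u: -np.log1p(-u) / rate, n=50_000
)
print(x.mean(axis=0))              # should be close to 1 / rate
print(np.corrcoef(x.T)[0, 1])      # strongly positive; axis 2 is roughly uncorrelated
```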

arg_constraints: dict[str, Any] = {'correlation_cholesky': CorrCholesky(), 'correlation_matrix': CorrMatrix()}
reparametrized_params: list[str] = ['correlation_matrix', 'correlation_cholesky']
pytree_data_fields: tuple[str, ...] = ('marginal_dist', 'base_dist')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

correlation_matrix() Array[source]
correlation_cholesky() Array[source]

GaussianCopulaBeta

class GaussianCopulaBeta(concentration1: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration0: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, correlation_matrix: Array | None = None, correlation_cholesky: Array | None = None, *, validate_args: bool = False)[source]

Bases: GaussianCopula

arg_constraints: dict[str, Any] = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0), 'correlation_cholesky': CorrCholesky(), 'correlation_matrix': CorrMatrix()}
support = IndependentConstraint(UnitInterval(lower_bound=0.0, upper_bound=1.0), 1)
pytree_data_fields: tuple[str, ...] = ('concentration1', 'concentration0')

GaussianRandomWalk

class GaussianRandomWalk(scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, num_steps: int = 1, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'scale': Positive(lower_bound=0.0)}
support = RealVector(Real(), 1)
reparametrized_params: list[str] = ['scale']
pytree_aux_fields: tuple[str, ...] = ('num_steps',)
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

GaussianStateSpace

class GaussianStateSpace(num_steps: int, transition_matrix: Array, covariance_matrix: Array | None = None, precision_matrix: Array | None = None, scale_tril: Array | None = None, initial_value: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

Gaussian state space model.

\[\begin{split}\mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\ &= \mathbf{A}^t \mathbf{z}_0 + \sum_{k=1}^{t} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_k,\end{split}\]

where \(\mathbf{z}_t\) is the state vector at step \(t\), \(\mathbf{A}\) is the transition matrix, \(\mathbf{z}_0\) is the initial value, and \(\boldsymbol\epsilon\) is the innovation noise.

Parameters:
  • num_steps – Number of steps.

  • transition_matrix – State space transition matrix \(\mathbf{A}\).

  • covariance_matrix – Covariance of the innovation noise \(\boldsymbol\epsilon\).

  • precision_matrix – Precision matrix of the innovation noise \(\boldsymbol\epsilon\).

  • scale_tril – Lower-triangular Cholesky factor of the covariance of the innovation noise \(\boldsymbol\epsilon\).

  • initial_value – Initial state vector \(\mathbf{z}_0\). If None, defaults to zero.
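The recursion and its unrolled form can be checked numerically. The NumPy sketch below uses an arbitrary 2×2 transition matrix and standard normal innovations; these are illustrative values, not defaults. The stacked states form a (num_steps, state_dim) array, which is consistent with the matrix-valued support. Because the innovations have zero mean, the same unrolling gives \(\mathbb{E}[\mathbf{z}_t] = \mathbf{A}^t \mathbf{z}_0\).

```python
import numpy as np

rng = np.random.default_rng(4)
A = np.array([[0.9, 0.1],
              [0.0, 0.8]])          # transition matrix
z0 = np.array([1.0, -1.0])          # initial value
T = 5                               # number of steps
eps = rng.standard_normal((T, 2))   # eps[k - 1] is the innovation at step k

# Recursive form: z_t = A z_{t-1} + eps_t
z = z0
states = []
for t in range(T):
    z = A @ z + eps[t]
    states.append(z)
recursive = np.stack(states)        # shape (num_steps, state_dim)

# Unrolled form: z_t = A^t z_0 + sum_{k=1}^t A^{t-k} eps_k
def unrolled(t):
    out = np.linalg.matrix_power(A, t) @ z0
    for k in range(1, t + 1):
        out = out + np.linalg.matrix_power(A, t - k) @ eps[k - 1]
    return out

closed = np.stack([unrolled(t) for t in range(1, T + 1)])
assert np.allclose(recursive, closed)
```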

arg_constraints: dict[str, Any] = {'covariance_matrix': PositiveDefinite(), 'initial_value': RealVector(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky(), 'transition_matrix': RealMatrix(Real(), 2)}
support = RealMatrix(Real(), 2)
pytree_data_fields: tuple[str, ...] = ('transition_matrix', '_initial_value', 'scale_tril')
pytree_aux_fields: tuple[str, ...] = ('num_steps',)
property initial_value: Array
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

covariance_matrix()[source]
precision_matrix()[source]

Gompertz

class Gompertz(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

Gompertz Distribution.

The Gompertz distribution is a distribution with support on the positive real line that is closely related to the Gumbel distribution. This implementation follows the notation used in the Wikipedia entry for the Gompertz distribution. See https://en.wikipedia.org/wiki/Gompertz_distribution.

However, we call the parameter “eta” a concentration parameter and the parameter “b” a rate parameter (as opposed to a scale parameter, as in the Wikipedia description).

The CDF, in terms of concentration (con) and rate, is

\[F(x) = 1 - \exp\left\{ -\text{con} \cdot \left[ \exp\{x \cdot \text{rate}\} - 1 \right] \right\}\]
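Solving \(F(x) = q\) for \(x\) gives the quantile function \(x = \ln\left(1 - \ln(1 - q)/\text{con}\right)/\text{rate}\). Below is a plain-Python sketch of both directions. The helper names are hypothetical, and log1p and expm1 are used to keep precision near zero.

```python
import math

def gompertz_cdf(x, con, rate):
    # F(x) = 1 - exp(-con * (exp(rate * x) - 1)); expm1 keeps small x accurate
    return -math.expm1(-con * math.expm1(rate * x))

def gompertz_icdf(q, con, rate):
    # Solve F(x) = q:  exp(rate * x) - 1 = -ln(1 - q) / con
    return math.log1p(-math.log1p(-q) / con) / rate

con, rate = 0.5, 1.3
for q in (0.05, 0.5, 0.95):
    x = gompertz_icdf(q, con, rate)
    assert math.isclose(gompertz_cdf(x, con, rate), q, rel_tol=1e-10)
```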
arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}
support = Positive(lower_bound=0.0)
reparametrized_params: list[str] = ['concentration', 'rate']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, which should lie in [0, 1].

Returns:

the values whose CDF equals q.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

Gumbel

class Gumbel(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

The Gumbel (maximum) distribution, a continuous real-valued distribution parameterized by location \(\mu\) and scale \(\beta > 0\). It is the limiting distribution of the maximum of a large number of i.i.d. samples from an exponential-tailed distribution.

The Probability Density Function (PDF) is:

\[f(x \mid \mu, \beta) = \frac{1}{\beta} \exp\!\left( -\frac{x - \mu}{\beta} - \exp\!\left(-\frac{x - \mu}{\beta}\right) \right), \quad x \in \mathbb{R}\]

where \(\mu \in \mathbb{R}\) is the location (loc) and \(\beta > 0\) is the scale (scale).

Parameters:
  • loc – Location parameter \(\mu \in \mathbb{R}\). Defaults to 0.0.

  • scale – Scale parameter \(\beta > 0\). Defaults to 1.0.

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Real()
reparametrized_params: list[str] = ['loc', 'scale']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the Gumbel distribution via the location-scale transform \(X = \mu + \beta Z\), where \(Z \sim \mathrm{Gumbel}(0, 1)\) is drawn from gumbel().

Parameters:
  • key – A JAX PRNG key.

  • sample_shape – Sample dimensions to prepend to the batch shape.

Returns:

Real-valued samples from the Gumbel distribution.
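The location-scale recipe can be sketched in NumPy. In this sketch the standard Gumbel draw comes from inverse-transform sampling, \(Z = -\ln(-\ln U)\) with \(U \sim \mathrm{Uniform}(0, 1)\), rather than from gumbel(). The helper name is hypothetical. The sample moments should match the mean \(\mu + \beta\gamma\) and the variance \(\pi^2\beta^2/6\).

```python
import numpy as np

def gumbel_sample(rng, loc, scale, size):
    # Standard Gumbel via inverse CDF: Z = -ln(-ln U), U ~ Uniform(0, 1)
    u = rng.uniform(size=size)
    z = -np.log(-np.log(u))
    # Location-scale transform X = mu + beta * Z
    return loc + scale * z

loc, scale = 1.0, 2.0
x = gumbel_sample(np.random.default_rng(5), loc, scale, 200_000)
print(x.mean(), loc + scale * np.euler_gamma)   # E[X] = mu + beta * gamma
print(x.var(), np.pi**2 / 6 * scale**2)         # Var(X) = pi^2 / 6 * beta^2
```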

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability density function at value.

Letting \(z = (x - \mu)/\beta\),

\[\ln f(x \mid \mu, \beta) = -z - e^{-z} - \ln \beta\]
Parameters:

value – Real-valued point \(x\) at which to evaluate the log PDF.

Returns:

Log probability density evaluated under the Gumbel distribution.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the Gumbel distribution:

\[\mathbb{E}[X] = \mu + \beta \gamma\]

where \(\gamma \approx 0.5772\) is the Euler-Mascheroni constant, available as euler_gamma.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the Gumbel distribution:

\[\mathrm{Var}(X) = \frac{\pi^2}{6} \beta^2\]
cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative Distribution Function (CDF) of the Gumbel distribution:

\[F(x \mid \mu, \beta) = \exp\!\left(-\exp\!\left(-\frac{x - \mu}{\beta}\right)\right)\]
Parameters:

value – Real-valued point \(x\) at which to evaluate the CDF.

Returns:

CDF values in \([0, 1]\).

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse CDF (quantile function) of the Gumbel distribution:

\[F^{-1}(q \mid \mu, \beta) = \mu - \beta \ln(-\ln q), \quad q \in (0, 1)\]
Parameters:

q – Quantile values in \((0, 1)\).

Returns:

Real-valued quantiles of the Gumbel distribution at q.
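Below is a plain-Python sketch that ties the three formulas together. It checks that the quantile function inverts the CDF, and that the exponentiated log density matches a central finite-difference derivative of the CDF. The helper names are hypothetical.

```python
import math

def gumbel_log_prob(x, mu, beta):
    z = (x - mu) / beta
    return -z - math.exp(-z) - math.log(beta)

def gumbel_cdf(x, mu, beta):
    return math.exp(-math.exp(-(x - mu) / beta))

def gumbel_icdf(q, mu, beta):
    return mu - beta * math.log(-math.log(q))

mu, beta = 0.5, 1.5
for q in (0.1, 0.5, 0.9):
    x = gumbel_icdf(q, mu, beta)
    assert math.isclose(gumbel_cdf(x, mu, beta), q, rel_tol=1e-10)

# The density is the derivative of the CDF (central finite difference)
x, h = 1.2, 1e-5
fd = (gumbel_cdf(x + h, mu, beta) - gumbel_cdf(x - h, mu, beta)) / (2 * h)
assert math.isclose(fd, math.exp(gumbel_log_prob(x, mu, beta)), rel_tol=1e-6)
```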

HalfCauchy

class HalfCauchy(scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

reparametrized_params: list[str] = ['scale']
support = Positive(lower_bound=0.0)
arg_constraints: dict[str, Any] = {'scale': Positive(lower_bound=0.0)}
pytree_data_fields: tuple[str, ...] = ('_cauchy', 'scale')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, which should lie in [0, 1].

Returns:

the values whose CDF equals q.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

HalfNormal

class HalfNormal(scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

reparametrized_params: list[str] = ['scale']
support = Positive(lower_bound=0.0)
arg_constraints: dict[str, Any] = {'scale': Positive(lower_bound=0.0)}
pytree_data_fields: tuple[str, ...] = ('_normal', 'scale')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, which should lie in [0, 1].

Returns:

the values whose CDF equals q.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

HurdleGamma

class HurdleGamma(gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: HurdleProbs

A hurdle Gamma distribution: a two-part model in which a structural zero occurs with probability \(g\) and, conditional on a positive outcome, the magnitude is drawn from \(\mathrm{Gamma}(\alpha, \lambda)\). The hurdle and the magnitude (given a positive value) are conditionally independent; see HurdleProbs for the full mechanism and assumptions.

Because \(P(X = 0) = 0\) under a Gamma density, no truncation factor is needed and the PDF is

\[P(X = 0) = g, \qquad f(x) = (1 - g) \, \frac{\lambda^{\alpha} x^{\alpha - 1} e^{-\lambda x}}{\Gamma(\alpha)} \;\text{for } x > 0.\]
Parameters:
  • gate (ArrayLike) – probability of a structural zero, \(g \in [0, 1]\).

  • concentration (ArrayLike) – shape parameter \(\alpha > 0\) of the Gamma.

  • rate (ArrayLike) – rate parameter \(\lambda > 0\) of the Gamma.

References:

  1. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables with Application to the Demand for Durable Goods. Econometrica, 39(5), 829-844.

  2. Mullahy, J. (1986). Specification and testing of some modified count data models. Journal of Econometrics, 33(3), 341-365.
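The following plain-Python sketch illustrates the two-part mechanism. Sampling first decides the hurdle and only then draws a Gamma magnitude. The log density is \(\ln g\) at zero and \(\ln(1 - g)\) plus the Gamma log density for \(x > 0\). This is an illustration of the formula above, not NumPyro's implementation, and the helper names are hypothetical. Note that random.gammavariate takes a shape and a *scale*.

```python
import math
import random

def hurdle_gamma_log_prob(x, gate, concentration, rate):
    """ln g at x = 0; ln(1 - g) + Gamma log density for x > 0."""
    if x == 0.0:
        return math.log(gate)
    a, lam = concentration, rate
    gamma_lp = a * math.log(lam) + (a - 1.0) * math.log(x) - lam * x - math.lgamma(a)
    return math.log1p(-gate) + gamma_lp

def hurdle_gamma_sample(rng, gate, concentration, rate):
    # Hurdle step: a structural zero with probability g
    if rng.random() < gate:
        return 0.0
    # Magnitude step: gammavariate takes (shape, scale), so scale = 1 / rate
    return rng.gammavariate(concentration, 1.0 / rate)

gate, alpha, lam = 0.3, 2.0, 1.5
rng = random.Random(6)
draws = [hurdle_gamma_sample(rng, gate, alpha, lam) for _ in range(100_000)]
zero_frac = sum(d == 0.0 for d in draws) / len(draws)
mean = sum(draws) / len(draws)
print(zero_frac)                                     # should be close to gate = 0.3
print(mean, (1 - gate) * alpha / lam)                # E[X] = (1 - g) * alpha / lambda
print(hurdle_gamma_log_prob(0.0, gate, alpha, lam))  # equals ln(0.3)
```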

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'gate': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'rate': Positive(lower_bound=0.0)}
support = Nonnegative(lower_bound=0.0)
pytree_data_fields: tuple[str, ...] = ('concentration', 'rate')

HurdleLogNormal

class HurdleLogNormal(gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: HurdleProbs

A hurdle Log-Normal distribution: a two-part model in which a structural zero occurs with probability \(g\) and, conditional on a positive outcome, the magnitude is drawn from \(\mathrm{LogNormal}(\mu, \sigma)\). The hurdle and the magnitude (given a positive value) are conditionally independent; see HurdleProbs for the full mechanism and assumptions.

Because \(P(X = 0) = 0\) under a Log-Normal density, no truncation factor is needed and the PDF is

\[P(X = 0) = g, \qquad f(x) = (1 - g) \, \frac{1}{x \sigma \sqrt{2 \pi}} \exp\!\left( -\frac{(\ln x - \mu)^2}{2 \sigma^2} \right) \;\text{for } x > 0.\]
Parameters:
  • gate (ArrayLike) – probability of a structural zero, \(g \in [0, 1]\).

  • loc (ArrayLike) – location parameter \(\mu \in \mathbb{R}\) (mean of \(\ln X\) given \(X > 0\)).

  • scale (ArrayLike) – scale parameter \(\sigma > 0\) (standard deviation of \(\ln X\) given \(X > 0\)).

References:

  1. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables with Application to the Demand for Durable Goods. Econometrica, 39(5), 829-844.

  2. Mullahy, J. (1986). Specification and testing of some modified count data models. Journal of Econometrics, 33(3), 341-365.
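The following NumPy sketch illustrates the two-part sampling mechanism with illustrative parameters. A uniform draw decides the hurdle. Positive outcomes then come from NumPy's lognormal sampler, whose mean and sigma arguments are \(\mu\) and \(\sigma\) of \(\ln X\). The zero fraction should approach \(g\), and the overall mean should approach \((1 - g)\,e^{\mu + \sigma^2/2}\).

```python
import numpy as np

gate, mu, sigma = 0.4, 0.2, 0.5
rng = np.random.default_rng(7)
n = 200_000
is_positive = rng.uniform(size=n) >= gate    # clear the hurdle with probability 1 - g
x = np.where(is_positive, rng.lognormal(mean=mu, sigma=sigma, size=n), 0.0)

print((x == 0).mean())                               # should be close to gate
print(x.mean(), (1 - gate) * np.exp(mu + sigma**2 / 2))
print(np.log(x[x > 0]).mean())                       # should be close to mu
```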

arg_constraints: dict[str, Any] = {'gate': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Nonnegative(lower_bound=0.0)
pytree_data_fields: tuple[str, ...] = ('loc', 'scale')

InverseGamma

class InverseGamma(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

Note

We keep the parameter name rate, as in Pyro, but it plays the role of the scale parameter of InverseGamma in the literature (e.g. Wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution).
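The sketch below shows why rate acts as a scale. It assumes, consistent with the TransformedDistribution base, that an InverseGamma draw is the reciprocal of a \(\mathrm{Gamma}(\alpha, \lambda)\) draw with rate \(\lambda\). The reciprocal then follows the inverse-gamma law with *scale* \(\lambda\). That law has mean \(\lambda/(\alpha - 1)\) for \(\alpha > 1\) and variance \(\lambda^2/((\alpha - 1)^2(\alpha - 2))\) for \(\alpha > 2\).

```python
import numpy as np

alpha, rate = 5.0, 2.0
rng = np.random.default_rng(8)
# Gamma with rate `rate` (NumPy takes scale = 1 / rate) ...
y = rng.gamma(shape=alpha, scale=1.0 / rate, size=200_000)
x = 1.0 / y   # ... whose reciprocal is inverse-gamma with scale `rate`

print(x.mean(), rate / (alpha - 1))                       # 0.5
print(x.var(), rate**2 / ((alpha - 1)**2 * (alpha - 2)))  # about 0.0833
```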

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['concentration', 'rate']
support = Positive(lower_bound=0.0)
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

InverseWishart

class InverseWishart(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, scale_matrix: Array | None = None, rate_matrix: Array | None = None, scale_tril: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

Inverse Wishart distribution for covariance matrices.

The Inverse Wishart distribution is the conjugate prior for the covariance matrix of a multivariate normal distribution. If \(\mathbf{X} \sim W^{-1}(\mathbf{\Psi}, \nu)\), then \(\mathbf{X}^{-1} \sim W(\mathbf{\Psi}^{-1}, \nu)\) (Wishart distribution).

\[p(\mathbf{X} \mid \mathbf{\Psi}, \nu) = \frac{|\mathbf{\Psi}|^{\nu/2}}{2^{\nu p/2} \Gamma_p(\nu/2)} |\mathbf{X}|^{-(\nu + p + 1)/2} \exp\left( -\frac{1}{2} \mathrm{tr}(\mathbf{\Psi} \mathbf{X}^{-1}) \right)\]

where \(p\) is the dimension of the matrix, \(\nu > p - 1\) is the degrees of freedom, and \(\mathbf{\Psi}\) is the positive definite scale matrix.

Parameters:
  • concentration – Degrees of freedom parameter (often denoted \(\nu\)). Must be greater than p - 1 where p is the dimension of the scale matrix.

  • scale_matrix – Positive definite scale matrix \(\mathbf{\Psi}\), analogous to the inverse rate of a Gamma distribution.

  • rate_matrix – Inverse of the scale matrix, analogous to the rate of a Gamma distribution.

  • scale_tril – Cholesky decomposition of the scale matrix.

Properties

  • Mean: \(\frac{\mathbf{\Psi}}{\nu - p - 1}\) for \(\nu > p + 1\)

  • Mode: \(\frac{\mathbf{\Psi}}{\nu + p + 1}\)
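For integer \(\nu\), a Wishart draw with scale \(\mathbf{\Psi}^{-1}\) can be built as a sum of \(\nu\) outer products of Gaussian vectors with covariance \(\mathbf{\Psi}^{-1}\). Inverting such a draw gives an Inverse Wishart sample with scale matrix \(\mathbf{\Psi}\), which allows a Monte Carlo check of the mean formula. The following NumPy sketch uses that construction. It is not NumPyro's sampler.

```python
import numpy as np

p, nu = 2, 8
psi = np.array([[2.0, 0.5],
                [0.5, 1.0]])          # scale matrix Psi
rng = np.random.default_rng(9)
n = 20_000

# W = sum_k g_k g_k^T with g_k ~ N(0, Psi^{-1}) is Wishart(Psi^{-1}, nu)
g = rng.multivariate_normal(np.zeros(p), np.linalg.inv(psi), size=(n, nu))  # (n, nu, p)
W = np.einsum("nki,nkj->nij", g, g)
X = np.linalg.inv(W)                  # X = W^{-1} ~ InverseWishart(Psi, nu)

print(X.mean(axis=0))                 # should be close to Psi / (nu - p - 1) = Psi / 5
```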

References

[1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution

arg_constraints: dict[str, Any] = {'concentration': Dependent(), 'rate_matrix': PositiveDefinite(), 'scale_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
support = PositiveDefinite()
reparametrized_params: list[str] = ['scale_matrix', 'rate_matrix', 'scale_tril']
concentration()[source]
scale_matrix()[source]
rate_matrix()[source]
scale_tril()[source]
mean() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Mean of the distribution.

mode() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
variance() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Variance of the distribution.

static infer_shapes(concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
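The hypothetical helper below illustrates the kind of rule infer_shapes encodes. It assumes that the batch shape is formed by broadcasting the concentration shape against the leading dimensions of the matrix argument, and that the trailing \(p \times p\) dimensions form the event shape. This is a sketch of that assumption, not NumPyro's code.

```python
import numpy as np

def infer_wishart_like_shapes(concentration=(), matrix_shape=(3, 3)):
    """Hypothetical helper: broadcast the concentration shape with the leading
    dims of the matrix argument; the trailing (p, p) dims are the event shape."""
    batch_shape = np.broadcast_shapes(tuple(concentration), tuple(matrix_shape[:-2]))
    event_shape = tuple(matrix_shape[-2:])
    return batch_shape, event_shape

print(infer_wishart_like_shapes((4,), (4, 3, 3)))   # ((4,), (3, 3))
print(infer_wishart_like_shapes((), (2, 5, 5)))     # ((2,), (5, 5))
```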

InverseWishartCholesky

class InverseWishartCholesky(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, scale_matrix: Array | None = None, rate_matrix: Array | None = None, scale_tril: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: Distribution

Cholesky factor of an Inverse Wishart distribution for covariance matrices.

This distribution samples the Cholesky factor \(\mathbf{L}\) such that \(\mathbf{X} = \mathbf{L} \mathbf{L}^T \sim W^{-1}(\mathbf{\Psi}, \nu)\).

Parameters:
  • concentration – Degrees of freedom parameter (often denoted \(\nu\)). Must be greater than p - 1 where p is the dimension of the scale matrix.

  • scale_matrix – Positive definite scale matrix \(\mathbf{\Psi}\), analogous to the inverse rate of a Gamma distribution.

  • rate_matrix – Inverse of the scale matrix, analogous to the rate of a Gamma distribution.

  • scale_tril – Cholesky decomposition of the scale matrix.

References

[1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution

arg_constraints: dict[str, Any] = {'concentration': Dependent(), 'rate_matrix': PositiveDefinite(), 'scale_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
support = LowerCholesky()
reparametrized_params: list[str] = ['scale_matrix', 'rate_matrix', 'scale_tril']
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

scale_matrix()[source]
rate_matrix()[source]
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

mean() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Mean of the distribution.

variance() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Variance of the distribution.

static infer_shapes(concentration: tuple[int, ...] = (), scale_matrix: tuple[int, ...] | None = None, rate_matrix: tuple[int, ...] | None = None, scale_tril: tuple[int, ...] | None = None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

Kumaraswamy

class Kumaraswamy(concentration1: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration0: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['concentration1', 'concentration0']
support = UnitInterval(lower_bound=0.0, upper_bound=1.0)
KL_KUMARASWAMY_BETA_TAYLOR_ORDER = 10
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

Laplace

class Laplace(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

The Laplace (double-exponential) distribution, a continuous real-valued distribution parameterized by location \(\mu\) and scale \(b > 0\). It is the distribution of the difference of two i.i.d. exponential variates and has heavier tails than the Normal distribution.

The Probability Density Function (PDF) is:

\[f(x \mid \mu, b) = \frac{1}{2 b} \exp\!\left(-\frac{|x - \mu|}{b}\right), \quad x \in \mathbb{R}\]

where \(\mu \in \mathbb{R}\) is the location (loc) and \(b > 0\) is the scale (scale).

Parameters:
  • loc – Location parameter \(\mu \in \mathbb{R}\). Defaults to 0.0.

  • scale – Scale parameter \(b > 0\). Defaults to 1.0.

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Real()
reparametrized_params: list[str] = ['loc', 'scale']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples via the location-scale transform \(X = \mu + b Z\), where \(Z \sim \mathrm{Laplace}(0, 1)\) is drawn from laplace().

Parameters:
  • key – A JAX PRNG key.

  • sample_shape – Sample dimensions to prepend to the batch shape.

Returns:

Real-valued samples from the Laplace distribution.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability density function at value:

\[\ln f(x \mid \mu, b) = -\frac{|x - \mu|}{b} - \ln(2 b)\]
Parameters:

value – Real-valued point \(x\) at which to evaluate the log PDF.

Returns:

Log probability density evaluated under the Laplace distribution.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the Laplace distribution:

\[\mathbb{E}[X] = \mu\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the Laplace distribution:

\[\mathrm{Var}(X) = 2 b^2\]
cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative Distribution Function (CDF) of the Laplace distribution. Letting \(z = (x - \mu)/b\),

\[F(x \mid \mu, b) = \frac{1}{2} - \frac{1}{2}\, \operatorname{sgn}(z)\,\left(e^{-|z|} - 1\right)\]

The implementation uses expm1() for numerical stability near \(z = 0\).

Parameters:

value – Real-valued point \(x\) at which to evaluate the CDF.

Returns:

CDF values in \([0, 1]\).

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse CDF (quantile function) of the Laplace distribution:

\[F^{-1}(q \mid \mu, b) = \mu - b\,\mathrm{sgn}\left(q - \frac{1}{2}\right)\, \ln\!\left(1 - 2 \left| q - \frac{1}{2} \right| \right), \quad q \in [0, 1]\]
Parameters:

q – Quantile values in \([0, 1]\).

Returns:

Real-valued quantiles of the Laplace distribution at q.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Differential entropy of the Laplace distribution:

\[H(X) = \ln(2 b) + 1\]

Levy

class Levy(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

The Lévy distribution is a special case of the Lévy alpha-stable distribution, with stability index \(\alpha = 1/2\) and skewness \(\beta = 1\). Its probability density function is given by

\[f(x\mid \mu, c) = \sqrt{\frac{c}{2\pi(x-\mu)^{3}}} \exp\left(-\frac{c}{2(x-\mu)}\right), \qquad x > \mu\]

where \(\mu\) is the location parameter and \(c\) is the scale parameter.

Parameters:
  • loc – Location parameter.

  • scale – Scale parameter.

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
property support: Constraint

The support of this distribution: the half-line \(x > \mu\), i.e. values greater than loc.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Compute the log probability density function of the Lévy distribution.

\[\log f(x\mid \mu, c) = \frac{1}{2}\log\left(\frac{c}{2\pi}\right) - \frac{c}{2(x-\mu)} - \frac{3}{2}\log(x-\mu), \qquad x > \mu\]
Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

numpy.ndarray

sample(key: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of the Lévy distribution is given by

\[F^{-1}(q\mid \mu, c) = \mu + c\left(\Phi^{-1}(1-q/2)\right)^{-2}\]

where \(\Phi^{-1}\) is the inverse of the standard normal cumulative distribution function.

Parameters:

q – quantile values, should belong to [0, 1].

Returns:

the values whose CDF equals q.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of the Lévy distribution is given by

\[F(x\mid \mu, c) = 2 - 2\Phi\left(\sqrt{\frac{c}{x-\mu}}\right)\]

where \(\Phi\) is the standard normal cumulative distribution function.

Parameters:

value – samples from Lévy distribution.

Returns:

output of the cumulative distribution function evaluated at value.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution. The Lévy distribution has no finite mean (it is infinite).

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution. The Lévy distribution has infinite variance.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \text{Levy}(\mu, c)\), then the differential entropy of \(X\) is given by

\[H(X) = \frac{1}{2}+\frac{3}{2}\gamma+\frac{1}{2}\ln{\left(16\pi c^2\right)}\]

where \(\gamma\) is the Euler–Mascheroni constant.

LKJ

class LKJ(dimension: int, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, sample_method: Literal['onion', 'cvine'] = 'onion', *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

LKJ distribution for correlation matrices. The distribution is controlled by concentration parameter \(\eta\) to make the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over correlation matrices.

When concentration > 1, the distribution favors samples with a large determinant. This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with a small determinant. This is useful when we know a priori that some underlying variables are correlated.

Example of using LKJ as the prior on the correlation matrix of a multivariate normal likelihood:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(y):  # y has shape (N, d)
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))

    concentration = 1.0  # Implies a uniform distribution over correlation matrices
    corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat`
    cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
    return obs

Parameters:
  • dimension (int) – dimension of the matrices

  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)

  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['concentration']
support = CorrMatrix()
pytree_aux_fields: tuple[str, ...] = ('dimension', 'sample_method')
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

LKJCholesky

class LKJCholesky(dimension: int, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, sample_method: Literal['onion', 'cvine'] = 'onion', *, validate_args: bool | None = None)[source]

Bases: Distribution

LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by concentration parameter \(\eta\) to make the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.

When concentration > 1, the distribution favors samples with large diagonal entries (hence a large determinant). This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small diagonal entries (hence a small determinant). This is useful when we know a priori that some underlying variables are correlated.

Example of using LKJCholesky as the prior on the Cholesky factor of the correlation matrix of a multivariate normal likelihood:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(y):  # y has shape (N, d)
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
    # Lower Cholesky factor of a correlation matrix
    concentration = 1.0  # Implies a uniform distribution over correlation matrices
    L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
    # Lower Cholesky factor of the covariance matrix
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
    L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs

Parameters:
  • dimension (int) – dimension of the matrices

  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)

  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['concentration']
support = CorrCholesky()
pytree_data_fields: tuple[str, ...] = ('_beta', 'concentration')
pytree_aux_fields: tuple[str, ...] = ('dimension', 'sample_method')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

LogNormal

class LogNormal(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Positive(lower_bound=0.0)
reparametrized_params: list[str] = ['loc', 'scale']
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

LogUniform

class LogUniform(low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

arg_constraints: dict[str, Any] = {'high': Positive(lower_bound=0.0), 'low': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['low', 'high']
pytree_data_fields: tuple[str, ...] = ('low', 'high', '_support')
property support: Constraint

The support of this distribution: the interval [low, high].

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

Logistic

class Logistic(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

The Logistic distribution, a continuous real-valued distribution parameterized by location \(\mu\) and scale \(s > 0\). Its CDF is the logistic sigmoid evaluated at the standardized value \((x - \mu)/s\), which makes it the natural latent noise distribution underlying logistic regression.

The Probability Density Function (PDF) is:

\[f(x \mid \mu, s) = \frac{ \exp\!\left(-\displaystyle\frac{x - \mu}{s}\right) }{ s \left(1 + \exp\!\left(-\displaystyle\frac{x - \mu}{s}\right)\right)^{2} }, \quad x \in \mathbb{R}\]

where \(\mu \in \mathbb{R}\) is the location (loc) and \(s > 0\) is the scale (scale).

Parameters:
  • loc – Location parameter \(\mu \in \mathbb{R}\). Defaults to 0.0.

  • scale – Scale parameter \(s > 0\). Defaults to 1.0.

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Real()
reparametrized_params: list[str] = ['loc', 'scale']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples via the location-scale transform \(X = \mu + s Z\), where \(Z \sim \mathrm{Logistic}(0, 1)\) is drawn from logistic().

Parameters:
  • key – A JAX PRNG key.

  • sample_shape – Sample dimensions to prepend to the batch shape.

Returns:

Real-valued samples from the Logistic distribution.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability density function at value.

Letting \(u = (\mu - x)/s\), the log PDF is

\[\ln f(x \mid \mu, s) = u - \ln s - 2 \ln(1 + e^{u})\]

The implementation uses softplus() for \(\ln(1 + e^{u})\), which is numerically stable for both large positive and large negative values of \(u\).

Parameters:

value – Real-valued point \(x\) at which to evaluate the log PDF.

Returns:

Log probability density evaluated under the Logistic distribution.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the Logistic distribution:

\[\mathbb{E}[X] = \mu\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the Logistic distribution:

\[\mathrm{Var}(X) = \frac{\pi^2 s^2}{3}\]
cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative Distribution Function (CDF) of the Logistic distribution. Letting \(z = (x - \mu)/s\),

\[F(x \mid \mu, s) = \sigma(z) = \frac{1}{1 + e^{-z}}\]

where \(\sigma\) is the logistic sigmoid, computed via expit().

Parameters:

value – Real-valued point \(x\) at which to evaluate the CDF.

Returns:

CDF values in \([0, 1]\).

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse CDF (quantile function) of the Logistic distribution:

\[F^{-1}(q \mid \mu, s) = \mu + s\,\operatorname{logit}(q), \quad q \in [0, 1]\]

where \(\operatorname{logit}(q) = \ln(q / (1 - q))\).

Parameters:

q – Quantile values in \([0, 1]\).

Returns:

Real-valued quantiles of the Logistic distribution at q.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Differential entropy of the Logistic distribution:

\[H(X) = \ln(s) + 2\]

LowRankMultivariateNormal

class LowRankMultivariateNormal(loc: Array, cov_factor: Array, cov_diag: Array, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'cov_diag': IndependentConstraint(Positive(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': RealVector(Real(), 1)}
support = RealVector(Real(), 1)
reparametrized_params: list[str] = ['loc', 'cov_factor', 'cov_diag']
pytree_data_fields: tuple[str, ...] = ('loc', 'cov_factor', 'cov_diag', '_capacitance_tril')
property mean: Array

Mean of the distribution.

variance() Array[source]

Variance of the distribution.

scale_tril() Array[source]
covariance_matrix() Array[source]
precision_matrix() Array[source]
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

static infer_shapes(loc, cov_factor, cov_diag)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

MatrixNormal

class MatrixNormal(loc: Array, scale_tril_row: Array, scale_tril_column: Array, *, validate_args: bool | None = None)[source]

Bases: Distribution

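Matrix variate normal distribution as described in [1], but with a lower-triangular parametrization, i.e. \(U = L_r L_r^{T}\) and \(V = L_c L_c^{T}\), where \(L_r\) is scale_tril_row and \(L_c\) is scale_tril_column. The distribution is related to the multivariate normal distribution as follows: if \(X \sim \mathcal{MN}(M, U, V)\) with \(M\) = loc, then \(\operatorname{vec}(X) \sim \mathcal{N}(\operatorname{vec}(M), V \otimes U)\), where \(\operatorname{vec}\) stacks the columns of a matrix and \(\otimes\) is the Kronecker product.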

Parameters:
  • loc (array_like) – Location of the distribution.

  • scale_tril_row (array_like) – Lower Cholesky factor of the row covariance matrix.

  • scale_tril_column (array_like) – Lower Cholesky factor of the column covariance matrix.

References

[1] https://en.wikipedia.org/wiki/Matrix_normal_distribution

arg_constraints: dict[str, Any] = {'loc': RealVector(Real(), 1), 'scale_tril_column': LowerCholesky(), 'scale_tril_row': LowerCholesky()}
support = RealMatrix(Real(), 2)
reparametrized_params: list[str] = ['loc', 'scale_tril_row', 'scale_tril_column']
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(values)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

values – A batch of samples from the distribution.

Returns:

an array of shape values.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

MultivariateNormal

class MultivariateNormal(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, covariance_matrix: Array | None = None, precision_matrix: Array | None = None, scale_tril: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'covariance_matrix': PositiveDefinite(), 'loc': RealVector(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
support = RealVector(Real(), 1)
reparametrized_params: list[str] = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

covariance_matrix()[source]
precision_matrix()[source]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

static infer_shapes(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

MultivariateStudentT

class MultivariateStudentT(df: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale_tril: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'df': Positive(lower_bound=0.0), 'loc': RealVector(Real(), 1), 'scale_tril': LowerCholesky()}
support = RealVector(Real(), 1)
reparametrized_params: list[str] = ['df', 'loc', 'scale_tril']
pytree_data_fields: tuple[str, ...] = ('df', 'loc', 'scale_tril', '_chi2')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

covariance_matrix()[source]
precision_matrix() Array[source]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

static infer_shapes(df, loc, scale_tril)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

Normal

class Normal(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

Normal (Gaussian) distribution parameterized by mean (loc) and standard deviation (scale).

The probability density function (PDF) is defined as:

\[f(x; \mu, \sigma) = \frac{1}{\sigma \sqrt{2\pi}} \exp\!\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)\]

where \(x \in \mathbb{R}\), \(\mu \in \mathbb{R}\) is the mean, and \(\sigma > 0\) is the standard deviation.

Parameters:
  • loc (ArrayLike) – Mean of the distribution (\(\mu\)).

  • scale (ArrayLike) – Standard deviation of the distribution (\(\sigma\)).

  • validate_args (bool, optional) – Whether to validate input constraints, defaults to None.

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Real()
reparametrized_params: list[str] = ['loc', 'scale']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Generates samples via the reparameterization trick: \(X = \mu + \sigma \epsilon\), where \(\epsilon \sim \mathcal{N}(0,1)\).

Parameters:
  • key (jax.Array) – JAX PRNGKey for reproducibility.

  • sample_shape (tuple[int, ...]) – The shape of the samples to be generated.

Returns:

Samples from the Normal distribution of shape sample_shape + batch_shape.

Return type:

ArrayLike

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Calculates the log of the probability density function.

\[\log f(x; \mu, \sigma) = -\frac{(x - \mu)^2}{2\sigma^2} - \log(\sigma \sqrt{2\pi})\]
Parameters:

value (ArrayLike) – Values at which to evaluate the log density.

Returns:

Log probability density.

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative distribution function.

\[F(x; \mu, \sigma) = \Phi\!\left(\frac{x-\mu}{\sigma}\right)\]

where \(\Phi\) is the cumulative distribution function of the standard normal distribution. Implementation uses jax.scipy.special.ndtr() for \(\Phi\).

Parameters:

value (ArrayLike) – Value to evaluate.

log_cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Log of the cumulative distribution function. Implementation calls jax.scipy.stats.norm.logcdf().

Parameters:

value (ArrayLike) – Value to evaluate.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse cumulative distribution function (Quantile function).

\[F^{-1}(q; \mu, \sigma) = \mu + \sigma\,\Phi^{-1}(q)\]

where \(\Phi^{-1}\) is the inverse cumulative distribution function of the standard normal distribution. Implementation uses jax.scipy.special.ndtri() for \(\Phi^{-1}\).

Parameters:

q (ArrayLike) – Probability value in \([0,1]\).

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Calculates the analytical mean.

\[E[X] = \mu\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Calculates the analytical variance.

\[\mathrm{Var}(X) = \sigma^2\]
entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Entropy of the Normal distribution.

\[H(X) = \frac{1}{2} \ln(2\pi e \sigma^2)\]

Pareto

class Pareto(scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, alpha: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

arg_constraints: dict[str, Any] = {'alpha': Positive(lower_bound=0.0), 'scale': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['scale', 'alpha']
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

RelaxedBernoulli

RelaxedBernoulli(temperature, probs=None, logits=None, *, validate_args: bool | None = None)[source]

RelaxedBernoulliLogits

class RelaxedBernoulliLogits(temperature: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

arg_constraints: dict[str, Any] = {'logits': Real(), 'temperature': Positive(lower_bound=0.0)}
support = UnitInterval(lower_bound=0.0, upper_bound=1.0)

SoftLaplace

class SoftLaplace(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

Smooth distribution with Laplace-like tail behavior.

This distribution corresponds to the log-convex density:

z = (value - loc) / scale
log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z)

Like the Laplace density, this density has the heaviest possible tails (asymptotically) while still being log-convex. Unlike the Laplace distribution, this distribution is infinitely differentiable everywhere, and is thus suitable for HMC and Laplace approximation.

Parameters:
  • loc – Location parameter.

  • scale – Scale parameter.

arg_constraints: dict[str, Any] = {'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Real()
reparametrized_params: list[str] = ['loc', 'scale']
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape when value has shape sample_shape + batch_shape + event_shape (the event dimensions are collapsed)

Return type:

ArrayLike

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the samples.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

value – quantile values (probabilities), which should lie in [0, 1].

Returns:

the points x whose cdf equals value.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

StudentT

class StudentT(df: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'df': Positive(lower_bound=0.0), 'loc': Real(), 'scale': Positive(lower_bound=0.0)}
support = Real()
reparametrized_params: list[str] = ['df', 'loc', 'scale']
pytree_data_fields: tuple[str, ...] = ('df', 'loc', 'scale', '_chi2')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the samples.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape when value has shape sample_shape + batch_shape + event_shape (the event dimensions are collapsed)

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, should belong to [0, 1].

Returns:

the points x whose cdf equals q.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

Uniform

class Uniform(low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'high': Dependent(), 'low': Dependent()}
reparametrized_params: list[str] = ['low', 'high']
pytree_data_fields: tuple[str, ...] = ('low', 'high', '_support')
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the samples.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape when value has shape sample_shape + batch_shape + event_shape (the event dimensions are collapsed)

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

value – quantile values (probabilities), which should lie in [0, 1].

Returns:

the points x whose cdf equals value.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

static infer_shapes(low: tuple[int, ...] = (), high: tuple[int, ...] = ()) tuple[tuple[int, ...], tuple[int, ...]][source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

Weibull

class Weibull(scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'scale': Positive(lower_bound=0.0)}
support = Positive(lower_bound=0.0)
reparametrized_params: list[str] = ['scale', 'concentration']
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the samples.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape when value has shape sample_shape + batch_shape + event_shape (the event dimensions are collapsed)

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

Wishart

class Wishart(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, scale_matrix: Array | None = None, rate_matrix: Array | None = None, scale_tril: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

Wishart distribution for covariance matrices.

Parameters:
  • concentration – Positive concentration parameter analogous to the concentration of a Gamma distribution. The concentration must be larger than the dimensionality of the scale matrix.

  • scale_matrix – Scale matrix analogous to the inverse rate of a Gamma distribution.

  • rate_matrix – Rate matrix analogous to the rate of a Gamma distribution.

  • scale_tril – Cholesky decomposition of the scale_matrix.

arg_constraints: dict[str, Any] = {'concentration': Dependent(), 'rate_matrix': PositiveDefinite(), 'scale_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
support = PositiveDefinite()
reparametrized_params: list[str] = ['scale_matrix', 'rate_matrix', 'scale_tril']
concentration()[source]
scale_matrix()[source]
rate_matrix()[source]
scale_tril()[source]
mean() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Mean of the distribution.

variance() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Variance of the distribution.

static infer_shapes(concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.

WishartCholesky

class WishartCholesky(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, scale_matrix: Array | None = None, rate_matrix: Array | None = None, scale_tril: Array | None = None, *, validate_args: bool | None = None)[source]

Bases: Distribution

Cholesky factor of a Wishart distribution for covariance matrices.

Parameters:
  • concentration – Positive concentration parameter analogous to the concentration of a Gamma distribution. The concentration must be larger than the dimensionality of the scale matrix.

  • scale_matrix – Scale matrix analogous to the inverse rate of a Gamma distribution.

  • rate_matrix – Rate matrix analogous to the rate of a Gamma distribution.

  • scale_tril – Cholesky decomposition of the scale_matrix.

arg_constraints: dict[str, Any] = {'concentration': Dependent(), 'rate_matrix': PositiveDefinite(), 'scale_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
support = LowerCholesky()
reparametrized_params: list[str] = ['scale_matrix', 'rate_matrix', 'scale_tril']
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape when value has shape sample_shape + batch_shape + event_shape (the event dimensions are collapsed)

Return type:

ArrayLike

scale_matrix()[source]
rate_matrix()[source]
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the samples.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

mean() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Mean of the distribution.

variance() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Variance of the distribution.

static infer_shapes(concentration: tuple[int, ...] = (), scale_matrix: tuple[int, ...] | None = None, rate_matrix: tuple[int, ...] | None = None, scale_tril: tuple[int, ...] | None = None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

ZeroSumNormal

class ZeroSumNormal(scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, event_shape: tuple[int, ...], *, validate_args: bool | None = None)[source]

Bases: TransformedDistribution

Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or more axes are constrained to sum to zero (the last axis by default).

\[ZSN(\sigma) = \mathcal{N}\!\left(0, \sigma^2 \left(I - \tfrac{1}{n} J\right)\right)\]

where, for a single zero-sum axis of length \(n\), \(J\) is the \(n \times n\) matrix of ones (\(J_{ij} = 1\)).

Parameters:
  • scale (array_like) – Standard deviation of the underlying normal distribution before the zero-sum constraint is enforced.

  • event_shape (tuple) – The event shape of the distribution, the axes of which get constrained to sum to zero.

Example:

>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS

>>> N = 1000
>>> n_categories = 20
>>> rng_key = random.key(0)
>>> key1, key2, key3 = random.split(rng_key, 3)
>>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,))
>>> beta = random.normal(key2, shape=(n_categories,))
>>> beta -= beta.mean(-1)
>>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,))

>>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories
...     N = len(category_ind)
...     alpha = numpyro.sample("alpha", dist.Normal(0, 2.5))
...     beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,)))
...     sigma = numpyro.sample("sigma", dist.Exponential(1))
...     with numpyro.plate("observations", N):
...         mu = alpha + beta[category_ind]
...         obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
...     return obs

>>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9)
>>> mcmc = MCMC(
...     sampler=nuts_kernel,
...     num_samples=1_000, num_warmup=1_000, num_chains=4
... )
>>> mcmc.run(random.key(0), category_ind=category_ind, y=y)
>>> posterior_samples = mcmc.get_samples()
>>> # Confirm everything along last axis sums to zero
>>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3)

References

[1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637

[2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html

[3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/

arg_constraints: dict[str, Any] = {'scale': Positive(lower_bound=0.0)}
reparametrized_params: list[str] = ['scale']
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

Discrete Distributions

Bernoulli

Bernoulli(probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, *, validate_args: bool | None = None) BernoulliProbs | BernoulliLogits[source]

Factory function to create a Bernoulli distribution instance from either probability or log-odds parameterization.

Parameters:
  • probs – The success probability parameter in the unit interval \([0, 1]\), defaults to None

  • logits – The log-odds parameter, defaults to None

  • validate_args – Whether to check the distribution parameters against their constraints when the instance is constructed. Default is None.

Returns:

The created Bernoulli distribution instance.

BernoulliLogits

class BernoulliLogits(logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

A Bernoulli discrete random variable parameterized by log-odds (logits).

The Probability Mass Function (PMF) of the Bernoulli distribution is:

\[P(X = k | \alpha) = \sigma(\alpha)^k (1 - \sigma(\alpha))^{1-k}, \quad k \in \{0, 1\}\]

where \(\alpha = \text{logits}\) is the log-odds parameter and \(\sigma(\alpha) = 1/(1 + \exp{(-\alpha)})\) is the sigmoid function.

Parameters:
  • logits – Log-odds parameter spanning the full real line \(\alpha \in \mathbb{R}\).

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'logits': Real()}
support = Boolean()

The support of the Bernoulli distribution is the set of binary outcomes \(\{0, 1\}\).

has_enumerate_support: bool = True
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the Bernoulli distribution.

The method first converts logits to probabilities via the sigmoid function (accessed via the lazy property probs), then invokes bernoulli() for sampling.

Parameters:
  • key – A JAX random number generator key (PRNG state).

  • sample_shape – Desired sample dimensions to prepend to the batch shape.

Returns:

Binary-valued samples (0 or 1) drawn from the Bernoulli distribution.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability mass function at specified binary configurations.

The log probability mass function leverages the numerically-stable binary_cross_entropy_with_logits() primitive, which computes the Bernoulli negative log-likelihood directly in log-odds space:

\[\ln P(X = k | \alpha) = -\mathrm{BCEWithLogits}(\alpha, k) = k \ln(\sigma(\alpha)) + (1-k) \ln(1 - \sigma(\alpha))\]

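This formulation avoids the failure modes of evaluating \(\sigma(\alpha)\) first and taking its logarithm afterwards. It protects against overflow, where \(e^{|\alpha|} \to \infty\) for large \(|\alpha|\). It also protects against underflow, where \(\sigma(\alpha)\) or \(1 - \sigma(\alpha)\) rounds to 0 and its logarithm becomes \(-\infty\).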

Parameters:

value – Binary observation(s) to score (\(k \in \{0, 1\}\)).

Returns:

Log probability scores evaluated under the Bernoulli PMF.

probs() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The success probability parameter of the Bernoulli distribution is given by the sigmoid of the log-odds parameter:

\[p = \sigma(\alpha) = \frac{1}{1 + e^{-\alpha}}\]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of the Bernoulli distribution is given by the sigmoid of the log-odds parameter:

\[E[X] = \sigma(\alpha) = \frac{1}{1 + e^{-\alpha}}\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of the Bernoulli distribution is given by:

\[\mathrm{Var}[X] = \sigma(\alpha) (1 - \sigma(\alpha))\]
enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The entropy of the Bernoulli distribution is given by:

\[H[X] = -p \ln p - (1-p) \ln (1-p)\]

where \(p = \sigma(\alpha)\) is the mean of the distribution.

The implementation uses the following algebraically equivalent form for numerical stability:

\[H[X] = \frac{(1 + e^{-\alpha}) \ln(1 + e^{-\alpha}) + e^{-\alpha} \alpha}{1 + e^{-\alpha}}\]

BernoulliProbs

class BernoulliProbs(probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

A Bernoulli discrete random variable parameterized by the probability of success.

The Probability Mass Function (PMF) of the Bernoulli distribution is defined as:

\[P(X = k | p) = p^k (1 - p)^{1-k}, \quad k \in \{0, 1\}\]

where \(p\) is the success probability parameter (probs) and \(k\) is the observed binary outcome (value), with support \(k \in \{0, 1\}\).

Parameters:
  • probs – Success probability in the interval \([0, 1]\).

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0)}
support = Boolean()

The support of the Bernoulli distribution is the set of binary outcomes \(\{0, 1\}\).

has_enumerate_support: bool = True
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the Bernoulli distribution.

This method invokes bernoulli() directly, which generates binary samples from the Bernoulli parametrization. Samples are mapped across the specified batch dimensions and sample dimensions via shape broadcasting.

Parameters:
  • key – A JAX random number generator key (PRNG state).

  • sample_shape – Desired sample dimensions to prepend to the batch shape.

Returns:

Binary-valued samples (0 or 1) drawn from the Bernoulli distribution.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability mass function at specified binary configurations.

\[\ln P(X=k | p) = k\ln(p) + (1-k)\ln(1-p)\]

The log probability mass function is evaluated in log space. The success probability is first clamped away from exactly 0 and 1. The two terms are then computed with the primitives xlogy() and xlog1py(), which handle edge cases gracefully:

  • xlogy(x, y) and xlog1py(x, y) return 0 whenever x = 0. The term for the outcome that did not occur therefore never evaluates a logarithmic singularity, even when \(p = 0\) or \(p = 1\).

  • The clamped probability values keep \(\ln(p)\) and \(\ln(1-p)\) finite, so even an impossible outcome receives a large negative but finite score.

Parameters:

value – Binary observation(s) to score (\(k \in \{0, 1\}\)).

Returns:

Log probability scores evaluated under the Bernoulli PMF.

logits() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The log-odds (logits) parameter of the Bernoulli distribution is given by the logit transformation of the success probability:

\[\alpha = \text{logit}(p) = \ln\left(\frac{p}{1-p}\right)\]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of the Bernoulli distribution is given by the success probability parameter:

\[E[X] = p\]
Returns:

The mean of the Bernoulli distribution, which is equal to the success probability probs.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of the Bernoulli distribution is given by:

\[\mathrm{Var}[X] = p (1 - p)\]
Returns:

The variance of the Bernoulli distribution, which is the product of the success probability and its complement.

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The entropy of the Bernoulli distribution is given by:

\[H[X] = -p \ln p - (1-p) \ln (1-p)\]

BetaBinomial

class BetaBinomial(concentration1: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration0: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, *, validate_args: bool | None = None)[source]

Bases: Distribution

Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown. It is drawn at random from a Beta distribution before a number of Bernoulli trials given by total_count is carried out.

Parameters:
  • concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.

  • concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.

  • total_count (numpy.ndarray) – number of Bernoulli trials.

arg_constraints: dict[str, Any] = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0), 'total_count': IntegerNonnegative(lower_bound=0)}
has_enumerate_support: bool = True
enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Returns an array with shape len(support) x batch_shape containing all values in the support.

pytree_data_fields: tuple[str, ...] = ('concentration1', 'concentration0', 'total_count', '_beta')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability mass for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

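The mean of the Beta-Binomial distribution is given by:

\[E[X] = \frac{n\alpha}{\alpha + \beta}\]

Here \(\alpha\) is concentration1, \(\beta\) is concentration0, and \(n\) is total_count.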

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

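The variance of the Beta-Binomial distribution is given by:

\[\mathrm{Var}[X] = \frac{n\alpha\beta(\alpha + \beta + n)}{(\alpha + \beta)^2(\alpha + \beta + 1)}\]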
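The Beta-Binomial PMF is \(P(X = k) = \binom{n}{k} B(k + \alpha, n - k + \beta) / B(\alpha, \beta)\). It can be checked numerically against the standard closed forms \(E[X] = n\alpha/(\alpha+\beta)\) and \(\mathrm{Var}[X] = n\alpha\beta(\alpha+\beta+n)/((\alpha+\beta)^2(\alpha+\beta+1))\). The sketch below is written directly with jax.scipy.special.gammaln rather than NumPyro's own log_prob. The helper names log_beta and beta_binomial_log_pmf are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def log_beta(a, b):
    # ln B(a, b) expressed with log-gamma functions
    return gammaln(a) + gammaln(b) - gammaln(a + b)

def beta_binomial_log_pmf(k, alpha, beta, n):
    """Hypothetical sketch: ln C(n,k) + ln B(k+alpha, n-k+beta) - ln B(alpha, beta)."""
    log_binom = gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)
    return log_binom + log_beta(k + alpha, n - k + beta) - log_beta(alpha, beta)

alpha, beta, n = 2.0, 3.0, 10.0
k = jnp.arange(11.0)  # the full support {0, ..., n}
pmf = jnp.exp(beta_binomial_log_pmf(k, alpha, beta, n))
mean = jnp.sum(k * pmf)
var = jnp.sum((k - mean) ** 2 * pmf)
```

For these parameters the closed forms give a mean of 10 * 2 / 5 = 4 and a variance of 10 * 2 * 3 * 15 / (25 * 6) = 6.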

property support: Constraint

The support of the Beta-Binomial distribution is the set of integer counts from 0 to total_count.

BetaNegativeBinomial

class BetaNegativeBinomial(concentration1: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration0: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, n: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

Compound distribution comprising a beta-negative-binomial pair. The probs parameter for the NegativeBinomialProbs distribution is unknown. It is drawn at random from a Beta distribution before the negative binomial counting process.

The Beta Negative Binomial is a heavy-tailed discrete distribution useful for modeling overdispersed count data. It arises as the marginal distribution when integrating out the success probability in a negative binomial model with a beta prior.

Parameters:
  • concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.

  • concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.

  • n (numpy.ndarray) – positive number of successes parameter for the negative binomial distribution.

References

[1] https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution

[2] https://mc-stan.org/docs/functions-reference/unbounded_discrete_distributions.html#beta-neg-binomial

arg_constraints: dict[str, Any] = {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0), 'n': Positive(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
pytree_data_fields: tuple[str, ...] = ('concentration1', 'concentration0', 'n', '_beta')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)\), then the sampling procedure is:

\[\begin{split}\begin{align*} p &\sim \mathrm{Beta}(\alpha, \beta) \\ X \mid p &\sim \mathrm{NegativeBinomial}(n, p) \end{align*}\end{split}\]

It uses Beta to generate samples from the Beta distribution and NegativeBinomialProbs to generate samples from the Negative Binomial distribution.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)\), then the probability mass function is as follows (log_prob returns its logarithm):

\[P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}\]

To ensure differentiability, the binomial coefficient is computed using gamma functions.
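The log-PMF can be sketched with log-gamma functions, where \(\ln\binom{n+k-1}{k} = \ln\Gamma(n+k) - \ln\Gamma(k+1) - \ln\Gamma(n)\) remains defined for non-integer \(n > 0\). The helper names are hypothetical and this is not NumPyro's code. The heavy right tail is truncated at a large \(k\). With \(\alpha = 2\), \(\beta = 5\), \(n = 3\), the probabilities should sum to about 1. They should also reproduce the mean \(n\alpha/(\beta-1) = 1.5\) and the variance \(n\alpha(n+\beta-1)(\alpha+\beta-1)/((\beta-1)^2(\beta-2)) = 5.25\).

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def log_beta(a, b):
    # ln B(a, b) expressed with log-gamma functions
    return gammaln(a) + gammaln(b) - gammaln(a + b)

def beta_neg_binomial_log_pmf(k, alpha, beta, n):
    """Hypothetical sketch of ln P(X = k) for BetaNegativeBinomial(alpha, beta, n)."""
    # ln C(n + k - 1, k) through log-gamma, valid for real-valued n > 0
    log_binom = gammaln(n + k) - gammaln(k + 1.0) - gammaln(n)
    return log_binom + log_beta(alpha + k, beta + n) - log_beta(alpha, beta)

alpha, beta, n = 2.0, 5.0, 3.0
k = jnp.arange(2000.0)  # truncate the heavy (polynomially decaying) tail
pmf = jnp.exp(beta_neg_binomial_log_pmf(k, alpha, beta, n))
```

Because the tail decays only polynomially (roughly like \(k^{-(\beta+1)}\)), the truncation point has to be generous. Moments beyond order \(\beta\) do not exist at all.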

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

If \(X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)\) and \(\beta > 1\), then the mean is:

\[\mathbb{E}[X] = \frac{n\alpha}{\beta - 1},\]

otherwise, it is undefined.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

If \(X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)\) and \(\beta > 2\), then the variance is:

\[\mathrm{Var}[X] = \frac{n\alpha (n + \beta - 1)(\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)},\]

otherwise, it is undefined.

Binomial

Binomial(total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, *, validate_args: bool | None = None) BinomialProbs | BinomialLogits[source]

Factory function to create a Binomial distribution instance from either the probability or the log-odds parameterization. One of probs or logits must be specified.

Parameters:
  • total_count – Number of trials (non-negative integer), defaults to 1

  • probs – The success probability parameter in the unit interval \([0, 1]\), defaults to None

  • logits – The log-odds parameter, defaults to None

  • validate_args – Whether to validate the distribution parameters and the arguments to log_prob. Default is None

Returns:

A Binomial distribution instance corresponding to the specified parameterization.

BinomialLogits

class BinomialLogits(logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, *, validate_args: bool | None = None)[source]

Bases: Distribution

A Binomial discrete random variable parameterized by log-odds (logits).

The Probability Mass Function (PMF) of the Binomial distribution is:

\[P(X = k | n, \alpha) = \binom{n}{k} \sigma(\alpha)^k (1 - \sigma(\alpha))^{n-k}\]

Where \(\alpha = \text{logits}\) and \(\sigma(\alpha) = 1/(1 + \exp(-\alpha))\).

Parameters:
  • logits – Log-odds parameter spanning \(\mathbb{R}\).

  • total_count – Number of trials (non-negative integer).

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'logits': Real(), 'total_count': IntegerNonnegative(lower_bound=0)}
has_enumerate_support: bool = True
enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Returns an array with shape len(support) x batch_shape containing all values in the support.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the Binomial distribution.

The method first converts logits to success probabilities with the sigmoid function, through the lazily evaluated probs property. It then draws counts with the internal binomial() utility. Because the sigmoid saturates smoothly to 0 or 1, extreme log-odds values still yield valid probabilities rather than overflowing.

Parameters:
  • key – A JAX random number generator key (PRNG state).

  • sample_shape – Desired sample dimensions to prepend to the batch shape.

Returns:

Non-negative integer samples representing success counts.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability mass function at specified count configurations.

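The log probability mass function is computed entirely in log-space. It uses a numerically stable formulation that avoids sigmoid underflow and overflow:

\[\ln P(X = k | n, \alpha) = \ln \binom{n}{k} + k\alpha - n\max(\alpha, 0) - n \ln\left(1 + e^{-|\alpha|}\right)\]

The last two terms together equal \(n \ln(1 + e^{\alpha})\). They are split so that the exponential never receives a large positive argument.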

The binomial coefficient in log-space is computed using the log-gamma function:

\[\ln \binom{n}{k} = \ln\Gamma(n + 1) - \ln\Gamma(k + 1) - \ln\Gamma(n - k + 1)\]

This approach using gammaln() avoids computing factorials explicitly.

Parameters:

value – Count observation(s) in the range \([0, n]\).

Returns:

Log probability scores evaluated under the Binomial PMF.
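The stable formula can be sketched with jax.numpy and gammaln and compared with a naive version that goes through the sigmoid first. Both helper names are hypothetical and neither is NumPyro's code. For moderate logits the two agree. At \(\alpha = 100\), however, the sigmoid rounds to exactly 1 in float32, so the naive \(\ln(1-p)\) becomes \(-\infty\) while the stable form stays finite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def binomial_logits_log_pmf(k, n, logits):
    """Hypothetical sketch: ln C(n,k) + k*a - n*max(a, 0) - n*ln(1 + exp(-|a|))."""
    log_binom = gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)
    # n * softplus(a), arranged so exp() only ever sees a non-positive argument
    normalizer = n * jnp.maximum(logits, 0.0) + n * jnp.log1p(jnp.exp(-jnp.abs(logits)))
    return log_binom + k * logits - normalizer

def binomial_naive_log_pmf(k, n, logits):
    # Converts to p first; ln(1 - p) breaks down once p rounds to exactly 1.
    p = jax.nn.sigmoid(logits)
    log_binom = gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)
    return log_binom + k * jnp.log(p) + (n - k) * jnp.log1p(-p)
```

For k = 9, n = 10, α = 100, the exact value is \(\ln 10 - 100\).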

probs() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The success probability per trial of the Binomial distribution is given by the sigmoid of the log-odds parameter:

\[p = \sigma(\alpha) = \frac{1}{1 + e^{-\alpha}}\]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of the Binomial distribution is given by:

\[E[X] = n \sigma(\alpha)\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of the Binomial distribution is given by:

\[\mathrm{Var}[X] = n \sigma(\alpha) (1 - \sigma(\alpha))\]
property support: Constraint

The support of the Binomial distribution is the set of integer counts from 0 to the total count.

BinomialProbs

class BinomialProbs(probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, *, validate_args: bool | None = None)[source]

Bases: Distribution

A Binomial discrete random variable counting the number of successes in total_count independent trials. It is parameterized by the success probability per trial (probs).

The Probability Mass Function (PMF) of the Binomial distribution is defined as:

\[P(X = k | n, p) = \binom{n}{k} p^k (1 - p)^{n-k}, \quad k \in \{0, 1, \dots, n\}\]

Here, \(n\) is the number of trials (total_count), \(p\) is the success probability per trial (probs), and \(k\) is the observed count of successes (value).

Parameters:
  • probs – Success probability per trial in \([0, 1]\).

  • total_count – Number of trials (non-negative integer).

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerNonnegative(lower_bound=0)}
has_enumerate_support: bool = True
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the Binomial distribution.

This method uses the internal binomial() utility function to generate count samples.

Parameters:
  • key – A JAX random number generator key (PRNG state).

  • sample_shape – Desired sample dimensions to prepend to the batch shape.

Returns:

Non-negative integer samples representing success counts.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability mass function at specified count configurations.

The log probability mass function is fully evaluated in log-space to prevent factorial overflow and underflow:

\[\ln P(X = k | n, p) = \ln \binom{n}{k} + k \ln p + (n-k) \ln(1-p)\]

The binomial coefficient in log-space is computed using the log-gamma function:

\[\ln \binom{n}{k} = \ln\Gamma(n + 1) - \ln\Gamma(k + 1) - \ln\Gamma(n - k + 1)\]

This approach using gammaln() avoids computing factorials explicitly. The probability terms are evaluated using xlogy() and xlog1py() to handle boundary cases gracefully (\(p = 0\), \(p = 1\), etc.).

Parameters:

value – Count observation(s) in the range \([0, n]\).

Returns:

Log probability scores evaluated under the Binomial PMF.

logits() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The log-odds (logits) parameter of the Binomial distribution is given by the logit transformation of the success probability:

\[\alpha = \text{logit}(p) = \ln\left(\frac{p}{1-p}\right)\]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of the Binomial distribution is given by:

\[E[X] = n p\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of the Binomial distribution is given by:

\[\mathrm{Var}[X] = n p (1 - p)\]
property support: Constraint

The support of the Binomial distribution is the set of integer counts from 0 to the total count.

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

Categorical

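Categorical(probs=None, logits=None, *, validate_args: bool | None = None)[source]

Factory function to create a Categorical distribution instance from either a probability vector (CategoricalProbs) or logits (CategoricalLogits). One of probs or logits must be specified.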

CategoricalLogits

class CategoricalLogits(logits: Array, *, validate_args: bool | None = None)[source]

Bases: Distribution

A Categorical discrete random variable over \(K\) mutually exclusive outcomes, parameterized by unnormalized log-probabilities (logits).

The Probability Mass Function (PMF) of the Categorical distribution is defined, via the softmax transformation of the logits, as:

\[P(X = k \mid \boldsymbol{\alpha}) = \frac{\exp(\alpha_k)}{\sum_{j=0}^{K-1} \exp(\alpha_j)}, \quad k \in \{0, 1, \dots, K-1\}\]

Here, \(\boldsymbol{\alpha} = (\alpha_0, \alpha_1, \dots, \alpha_{K-1})\) is the real-valued logits vector (logits), \(K\) is the number of categories (the size of the trailing dimension of logits), and \(k\) is the observed category index (value), with support \(k \in \{0, 1, \dots, K-1\}\).

Parameters:
  • logits – Real-valued logits vector; the trailing dimension indexes the \(K\) categories. Logits are unnormalized and converted to probabilities via the softmax function.

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'logits': RealVector(Real(), 1)}
has_enumerate_support: bool = True
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the Categorical distribution.

This method invokes categorical() directly, which samples in logit-space using the Gumbel-max trick and therefore avoids materializing the softmax-normalized probabilities.

Parameters:
  • key – A JAX random number generator key (PRNG state).

  • sample_shape – Desired sample dimensions to prepend to the batch shape.

Returns:

Integer-valued samples in \(\{0, 1, \dots, K-1\}\) drawn from the Categorical distribution.
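The Gumbel-max trick returns \(\arg\max_k (\alpha_k + G_k)\) with i.i.d. \(G_k \sim \mathrm{Gumbel}(0, 1)\). It can be sketched directly with jax.random.gumbel. The helper name gumbel_max_sample is hypothetical, and this sketch does not call NumPyro's categorical() utility. Empirical frequencies from many draws should approach softmax(logits).

```python
import jax
import jax.numpy as jnp

def gumbel_max_sample(key, logits, sample_shape=()):
    """Hypothetical sketch of Gumbel-max sampling from unnormalized logits."""
    logits = jnp.asarray(logits)
    shape = sample_shape + logits.shape  # one Gumbel perturbation per category
    gumbels = jax.random.gumbel(key, shape)
    # No softmax is ever computed: argmax over perturbed logits is enough.
    return jnp.argmax(logits + gumbels, axis=-1)

key = jax.random.PRNGKey(0)
logits = jnp.array([1.0, 2.0, 0.5])
samples = gumbel_max_sample(key, logits, sample_shape=(20000,))
freqs = jnp.bincount(samples, length=3) / samples.shape[0]
```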

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability mass function at specified category indices.

\[\ln P(X = k \mid \boldsymbol{\alpha}) = \alpha_k - \ln\!\sum_{j=0}^{K-1} \exp(\alpha_j)\]

The normalizing log-partition is computed via logsumexp(), which uses the standard max-subtraction trick to guarantee numerical stability in the presence of large or widely-spread logit magnitudes. After normalization, the relevant log-probability is gathered with take_along_axis().

Parameters:

value – Category index/indices to score (\(k \in \{0, 1, \dots, K-1\}\)).

Returns:

Log probability scores evaluated under the Categorical PMF.
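The normalize-then-gather pattern can be sketched with jax.scipy.special.logsumexp and jnp.take_along_axis. The helper name is hypothetical, and the sketch assumes value has exactly the batch shape of logits (no extra sample dimensions). Even with logits spread over \(\pm 1000\), the result stays finite.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def categorical_logits_log_prob(logits, value):
    """Hypothetical sketch: alpha_k - logsumexp(alpha), gathered at k = value."""
    logits = jnp.asarray(logits)
    # Normalized log-probabilities; logsumexp subtracts the max internally.
    log_probs = logits - logsumexp(logits, axis=-1, keepdims=True)
    # Assumes value has the batch shape of logits; add a trailing axis to gather.
    idx = jnp.asarray(value)[..., None]
    return jnp.take_along_axis(log_probs, idx, axis=-1)[..., 0]
```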

probs() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The probability vector of the Categorical distribution is given by the softmax of the logits:

\[p_k = \frac{\exp(\alpha_k)}{\sum_{j=0}^{K-1} \exp(\alpha_j)}, \quad k \in \{0, 1, \dots, K-1\}\]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of a Categorical distribution over arbitrary unordered categories is not well-defined. This property therefore returns NaN.

Returns:

An array of NaNs with shape equal to batch_shape.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of a Categorical distribution over arbitrary unordered categories is not well-defined. This property therefore returns NaN.

Returns:

An array of NaNs with shape equal to batch_shape.

property support: Constraint

The support of the Categorical distribution is the set of integers \(\{0, 1, \dots, K-1\}\), where \(K\) is the number of categories inferred from the trailing dimension of logits.

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Enumerate all values in the support of the Categorical distribution.

Parameters:

expand – Whether to broadcast the enumerated values across the batch shape. Default is True.

Returns:

An array of integer category indices \(\{0, 1, \dots, K-1\}\), optionally broadcast across the batch dimensions.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The entropy of the Categorical distribution is given by:

\[H[X] = -\sum_{k=0}^{K-1} p_k \ln p_k = \ln\!\sum_{j=0}^{K-1} \exp(\alpha_j) - \sum_{k=0}^{K-1} p_k\, \alpha_k\]

where \(p_k = \mathrm{softmax}(\boldsymbol{\alpha})_k\). The implementation uses logsumexp() for the log-partition term, ensuring numerical stability for large or widely-spread logits.

Returns:

The entropy of the Categorical distribution.
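The logit-space form of the entropy can be sketched as follows. The helper name is hypothetical, and the sketch assumes all logits are finite, because a \(-\infty\) logit would produce \(0 \cdot (-\infty)\). Uniform logits over \(K\) categories should give \(\ln K\). Shifting all logits by a constant should leave the entropy unchanged.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def categorical_entropy_from_logits(logits):
    """Hypothetical sketch: H = logsumexp(alpha) - sum_k p_k alpha_k (finite logits assumed)."""
    logits = jnp.asarray(logits)
    probs = jax.nn.softmax(logits, axis=-1)
    return logsumexp(logits, axis=-1) - jnp.sum(probs * logits, axis=-1)
```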

CategoricalProbs

class CategoricalProbs(probs: Array, *, validate_args: bool | None = None)[source]

Bases: Distribution

A Categorical discrete random variable over \(K\) mutually exclusive outcomes, parameterized by a probability vector on the simplex.

The Probability Mass Function (PMF) of the Categorical distribution is defined as:

\[P(X = k \mid \mathbf{p}) = p_k, \quad k \in \{0, 1, \dots, K-1\}\]

Here, \(\mathbf{p} = (p_0, p_1, \dots, p_{K-1})\) is the category probability vector (probs), satisfying \(p_k \ge 0\) and \(\sum_{k=0}^{K-1} p_k = 1\); \(K\) is the number of categories (the size of the trailing dimension of probs); and \(k\) is the observed category index (value), with support \(k \in \{0, 1, \dots, K-1\}\).

Parameters:
  • probs – Category probability vector on the simplex; the trailing dimension indexes the \(K\) categories and must sum to one.

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'probs': Simplex()}
has_enumerate_support: bool = True
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the Categorical distribution.

This method delegates to the internal categorical() utility, which draws category indices directly from the probability vector probs.

Parameters:
  • key – A JAX random number generator key (PRNG state).

  • sample_shape – Desired sample dimensions to prepend to the batch shape.

Returns:

Integer-valued samples in \(\{0, 1, \dots, K-1\}\) drawn from the Categorical distribution.
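Inverse-CDF sampling is one common way to draw directly from a probability vector. It draws \(u \sim \mathrm{Uniform}(0, 1)\) and counts how many cumulative probabilities lie strictly below \(u\). The sketch below assumes this technique for illustration only; the description above does not say which method the internal utility uses. The helper name is hypothetical. A category with zero probability is never returned.

```python
import jax
import jax.numpy as jnp

def inverse_cdf_categorical(key, probs, sample_shape=()):
    """Hypothetical sketch: index = number of cumulative probabilities below u."""
    probs = jnp.asarray(probs)
    cdf = jnp.cumsum(probs, axis=-1)
    batch_shape = probs.shape[:-1]
    # One uniform per draw, with a trailing axis to compare against every category.
    u = jax.random.uniform(key, sample_shape + batch_shape + (1,))
    idx = jnp.sum(cdf < u, axis=-1)
    # Guard against round-off leaving cdf[-1] slightly below 1.
    return jnp.minimum(idx, probs.shape[-1] - 1)
```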

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability mass function at specified category indices.

\[\ln P(X = k \mid \mathbf{p}) = \ln p_k\]

The implementation gathers the log-probabilities from logits (which are already normalized log-probabilities computed from probs) at the positions indicated by value using take_along_axis().

Parameters:

value – Category index/indices to score (\(k \in \{0, 1, \dots, K-1\}\)).

Returns:

Log probability scores evaluated under the Categorical PMF.

logits() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The log-probability (logits) parameter of the Categorical distribution is the (already-normalized) log of the category probabilities:

\[\alpha_k = \ln p_k, \quad k \in \{0, 1, \dots, K-1\}\]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of a Categorical distribution over arbitrary unordered categories is not well-defined. This property therefore returns NaN.

Returns:

An array of NaNs with shape equal to batch_shape.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of a Categorical distribution over arbitrary unordered categories is not well-defined. This property therefore returns NaN.

Returns:

An array of NaNs with shape equal to batch_shape.

property support: Constraint

The support of the Categorical distribution is the set of integers \(\{0, 1, \dots, K-1\}\), where \(K\) is the number of categories inferred from the trailing dimension of probs.

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Enumerate all values in the support of the Categorical distribution.

Parameters:

expand – Whether to broadcast the enumerated values across the batch shape. Default is True.

Returns:

An array of integer category indices \(\{0, 1, \dots, K-1\}\), optionally broadcast across the batch dimensions.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The entropy of the Categorical distribution is given by:

\[H[X] = -\sum_{k=0}^{K-1} p_k \ln p_k\]
Returns:

The entropy of the Categorical distribution.

DirichletMultinomial

class DirichletMultinomial(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, *, total_count_max: int | None = None, validate_args: bool | None = None)[source]

Bases: Distribution

Compound distribution comprising a Dirichlet-multinomial pair. The class probabilities (probs for the Multinomial distribution) are unknown. They are drawn at random from a Dirichlet distribution before a number of Categorical trials given by total_count is carried out.

Parameters:
  • concentration (numpy.ndarray) – concentration parameter (alpha) for the Dirichlet distribution.

  • total_count (numpy.ndarray) – number of Categorical trials.

  • total_count_max (int) – the maximum number of trials, i.e. max(total_count)
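The compound process can be sketched in two stages: draw \(\mathbf{p} \sim \mathrm{Dirichlet}(\boldsymbol{\alpha})\), then tally total_count categorical draws from \(\mathbf{p}\). The sketch uses jax.random.dirichlet and jax.random.categorical rather than NumPyro's sampler. The helper name is hypothetical, and total_count must be a concrete Python int here. Averaged over many draws, the counts should approach \(n\,\alpha_i / \sum_j \alpha_j\).

```python
import jax
import jax.numpy as jnp

def sample_dirichlet_multinomial(key, concentration, total_count):
    """Hypothetical sketch: p ~ Dirichlet(concentration), then tally total_count draws."""
    concentration = jnp.asarray(concentration)
    key_p, key_x = jax.random.split(key)
    p = jax.random.dirichlet(key_p, concentration)
    # total_count independent category indices drawn from p
    idx = jax.random.categorical(key_x, jnp.log(p), shape=(total_count,))
    # Count how often each of the K categories occurred
    return jnp.bincount(idx, length=concentration.shape[-1])
```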

arg_constraints: dict[str, Any] = {'concentration': IndependentConstraint(Positive(lower_bound=0.0), 1), 'total_count': IntegerNonnegative(lower_bound=0)}
pytree_data_fields: tuple[str, ...] = ('concentration', '_dirichlet')
pytree_aux_fields: tuple[str, ...] = ('total_count', 'total_count_max')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability mass for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array whose shape is value.shape with the trailing len(self.event_shape) dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

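The mean of each category count is given by:

\[E[X_i] = n\frac{\alpha_i}{\alpha_0}, \quad \alpha_0 = \sum_j \alpha_j\]

Here \(\boldsymbol{\alpha}\) is concentration and \(n\) is total_count.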

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

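The marginal variance of each category count is given by:

\[\mathrm{Var}[X_i] = n\frac{\alpha_i}{\alpha_0}\left(1 - \frac{\alpha_i}{\alpha_0}\right)\frac{n + \alpha_0}{1 + \alpha_0}\]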

property support: Constraint

The support of the Dirichlet-Multinomial distribution is the set of non-negative integer count vectors, with one entry per category, whose entries sum to total_count.

static infer_shapes(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, total_count=()) tuple[tuple[int, ...], tuple[int, ...]][source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes that the distribution shape depends only on the shapes of the tensor inputs, not on the data contained in those inputs.

Parameters:
  • concentration – Shape of the concentration argument, given as a tuple of ints.

  • total_count – Shape of the total_count argument, given as a tuple of ints. Default is ().

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

DiscreteUniform

class DiscreteUniform(low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, *, validate_args: bool | None = None)[source]

Bases: Distribution

A discrete uniform random variable over the inclusive integer interval \(\{a, a+1, \dots, b\}\), where \(a\) (low) is the inclusive lower bound and \(b\) (high) is the inclusive upper bound of the support.

The Probability Mass Function (PMF) of the discrete uniform distribution is defined as:

\[P(X = k \mid a, b) = \frac{1}{b - a + 1}, \quad k \in \{a, a+1, \dots, b\}\]

Here, \(k\) is the observed integer value (value).

Parameters:
  • low – Inclusive lower bound of the integer support. Default is 0.

  • high – Inclusive upper bound of the integer support. Must satisfy high >= low. Default is 1.

  • validate_args – If True, enforce domain constraints during initialization.

arg_constraints: dict[str, Any] = {'high': Dependent(), 'low': Dependent()}
has_enumerate_support: bool = True
pytree_data_fields: tuple[str, ...] = ('low', 'high', '_support')
property support: Constraint

The support of the discrete uniform distribution is the set of integers \(\{\text{low}, \text{low}+1, \dots, \text{high}\}\).

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Draw samples from the discrete uniform distribution.

This method invokes randint() directly, which generates uniformly distributed integers in the half-open interval [low, high + 1), equivalent to the inclusive interval \(\{a, \dots, b\}\).

Parameters:
  • key – A JAX random number generator key (PRNG state).

  • sample_shape – Desired sample dimensions to prepend to the batch shape.

Returns:

Integer-valued samples drawn uniformly from \(\{a, \dots, b\}\).
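Because jax.random.randint excludes its upper bound, the inclusive support \(\{a, \dots, b\}\) corresponds to maxval = high + 1. A minimal sketch follows; the helper name is hypothetical, and taking the batch shape as the broadcast of low and high is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp

def sample_discrete_uniform(key, low, high, sample_shape=()):
    """Hypothetical sketch: randint over [low, high + 1), i.e. {low, ..., high}."""
    low, high = jnp.asarray(low), jnp.asarray(high)
    batch_shape = jnp.broadcast_shapes(low.shape, high.shape)
    # maxval is exclusive in jax.random.randint, hence high + 1
    return jax.random.randint(key, sample_shape + batch_shape, low, high + 1)
```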

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the log probability mass function at specified integer values.

\[\ln P(X = k \mid a, b) = -\ln(b - a + 1)\]

The log-PMF is constant over the support, so the implementation simply broadcasts the negative log of the support cardinality to the requested shape.

Parameters:

value – Integer observation(s) to score (\(k \in \{a, \dots, b\}\)).

Returns:

Log probability scores evaluated under the discrete uniform PMF.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the cumulative distribution function (CDF) of the discrete uniform distribution.

\[F(x) = \frac{\lfloor x \rfloor + 1 - a}{b - a + 1}, \quad \text{clipped to } [0, 1]\]
Parameters:

value – Point(s) at which to evaluate the CDF.

Returns:

The CDF evaluated at value, clipped to the unit interval.
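A small sketch of the CDF formula above, with a hypothetical helper name. The floor makes the function a step function between integers. The clip handles points below \(a\) and above \(b\). For \(a = 0\), \(b = 4\), it gives \(F(2.5) = 3/5 = 0.6\).

```python
import jax.numpy as jnp

def discrete_uniform_cdf(x, low, high):
    """Hypothetical sketch: F(x) = (floor(x) + 1 - low) / (high - low + 1), clipped to [0, 1]."""
    x = jnp.asarray(x, dtype=jnp.float32)
    raw = (jnp.floor(x) + 1.0 - low) / (high - low + 1.0)
    return jnp.clip(raw, 0.0, 1.0)
```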

icdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluate the inverse cumulative distribution function (quantile function) of the discrete uniform distribution.

\[F^{-1}(u) = a + u\,(b - a + 1) - 1, \quad u \in [0, 1]\]
Parameters:

value – Quantile level(s) \(u \in [0, 1]\).

Returns:

The inverse CDF evaluated at value.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The mean of the discrete uniform distribution is the midpoint of the support:

\[E[X] = \frac{a + b}{2}\]
Returns:

The mean of the discrete uniform distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

The variance of the discrete uniform distribution is given by:

\[\mathrm{Var}[X] = \frac{(b - a + 1)^2 - 1}{12}\]
Returns:

The variance of the discrete uniform distribution.

enumerate_support(expand: bool = True) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Enumerate all values in the support of the discrete uniform distribution.

Both low and high must be concrete (non-JAX-tracer) values and homogeneous across the batch shape; otherwise a NotImplementedError is raised.

Parameters:

expand – Whether to broadcast the enumerated values across the batch shape. Default is True.

Returns:

An array of integer values \(\{a, a+1, \dots, b\}\), optionally broadcast across the batch dimensions.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The entropy of the discrete uniform distribution is given by:

\[H[X] = \ln(b - a + 1)\]
Returns:

The entropy of the discrete uniform distribution.

GammaPoisson

class GammaPoisson(concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

Compound distribution comprising a Gamma-Poisson pair, also referred to as a Gamma-Poisson mixture. The rate of the Poisson distribution is not fixed; it is itself drawn from a Gamma distribution.

Parameters:
  • concentration (numpy.ndarray) – shape parameter \(\alpha\) of the Gamma distribution.

  • rate (numpy.ndarray) – rate parameter \(\lambda\) of the Gamma distribution.

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
pytree_data_fields: tuple[str, ...] = ('concentration', 'rate', '_gamma')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the sampling procedure is:

\[\begin{split}\begin{align*} \theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\ X \mid \theta &\sim \mathrm{Poisson}(\theta) \end{align*}\end{split}\]

It uses Gamma to generate samples from the Gamma distribution and Poisson to generate samples from the Poisson distribution.
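The two-stage procedure can be sketched with the standard library alone. This is an illustration, not NumPyro's sampler: Python's random.gammavariate takes a *scale* argument, so the rate \(\lambda\) is passed as 1 / rate, and the Poisson stage uses Knuth's multiplication method (a swapped-in technique adequate for moderate rates).

```python
import math
import random


def sample_poisson(rng: random.Random, theta: float) -> int:
    """Knuth's multiplication method; adequate for moderate theta."""
    threshold = math.exp(-theta)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def sample_gamma_poisson(rng: random.Random, alpha: float, rate: float) -> int:
    # Stage 1: theta ~ Gamma(alpha, rate); gammavariate expects scale = 1 / rate.
    theta = rng.gammavariate(alpha, 1.0 / rate)
    # Stage 2: X | theta ~ Poisson(theta).
    return sample_poisson(rng, theta)


rng = random.Random(0)
alpha, rate = 2.0, 0.5
draws = [sample_gamma_poisson(rng, alpha, rate) for _ in range(20_000)]
sample_mean = sum(draws) / len(draws)  # should land near alpha / rate = 4.0
```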

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the probability mass function is:

\[p_{X}(k) = \frac{\lambda^\alpha}{(\alpha + k)(1+\lambda)^{\alpha + k}\mathrm{B}(\alpha, k + 1)}\]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the mean is:

\[\mathbb{E}[X] = \frac{\alpha}{\lambda}\]
property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the variance is:

\[\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)\]
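As a numerical sanity check, the PMF above (written with the Beta function) should sum to one, and its first two moments should match the stated mean and variance. The sketch truncates the infinite support at a point where the remaining tail is negligible for the chosen parameters; it is an illustration, not NumPyro code.

```python
import math


def gamma_poisson_pmf(k: int, alpha: float, rate: float) -> float:
    """The PMF above: rate^alpha / ((alpha + k) (1 + rate)^(alpha + k) B(alpha, k + 1))."""
    log_beta = math.lgamma(alpha) + math.lgamma(k + 1) - math.lgamma(alpha + k + 1)
    log_p = (alpha * math.log(rate)
             - math.log(alpha + k)
             - (alpha + k) * math.log1p(rate)
             - log_beta)
    return math.exp(log_p)


alpha, rate = 2.5, 0.8
ks = range(400)  # truncate the infinite support; the tail beyond k = 400 is negligible here
pmf = [gamma_poisson_pmf(k, alpha, rate) for k in ks]
total = sum(pmf)
mean = sum(k * p for k, p in zip(ks, pmf))
var = sum((k - mean) ** 2 * p for k, p in zip(ks, pmf))
```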
cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{GammaPoisson}(\alpha, \lambda)\), then the cumulative distribution function is:

\[F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, x + 1)} \int_{0}^{\frac{\lambda}{1 + \lambda}} t^{\alpha - 1} (1 - t)^{x} dt\]

which is the regularized incomplete beta function. This implementation uses betainc().
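The identity between the CDF and the regularized incomplete beta function \(I_{\lambda/(1+\lambda)}(\alpha, x + 1)\) can be checked by comparing a partial sum of the PMF with a numerical integral. The sketch below uses composite Simpson's rule as a stand-in for a library betainc(); it assumes \(\alpha \ge 1\) so the integrand is finite at 0.

```python
import math


def gamma_poisson_pmf(k: int, alpha: float, rate: float) -> float:
    # Negative-binomial form of the PMF, with q = rate / (1 + rate)
    q = rate / (1.0 + rate)
    log_coef = math.lgamma(alpha + k) - math.lgamma(alpha) - math.lgamma(k + 1)
    return math.exp(log_coef + alpha * math.log(q) + k * math.log1p(-q))


def regularized_incomplete_beta(z: float, a: float, b: float, n: int = 2000) -> float:
    """I_z(a, b) by composite Simpson's rule (n even, a >= 1); a stand-in for betainc()."""
    def f(t: float) -> float:
        return t ** (a - 1) * (1.0 - t) ** (b - 1)

    h = z / n
    s = f(0.0) + f(z)
    s += 4 * sum(f((2 * i - 1) * h) for i in range(1, n // 2 + 1))  # odd nodes
    s += 2 * sum(f(2 * i * h) for i in range(1, n // 2))            # interior even nodes
    log_beta = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (s * h / 3) / math.exp(log_beta)


alpha, rate, x = 3.0, 0.5, 4
cdf_by_summation = sum(gamma_poisson_pmf(k, alpha, rate) for k in range(x + 1))
cdf_by_betainc = regularized_incomplete_beta(rate / (1 + rate), alpha, x + 1)
```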

Geometric

Geometric(probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, *, validate_args: bool | None = None) GeometricProbs | GeometricLogits[source]

GeometricLogits

class GeometricLogits(logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'logits': Real()}
support = IntegerNonnegative(lower_bound=0)
probs() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.
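Working in logit space avoids forming probs explicitly. Assuming, as the nonnegative-integer support suggests, that the distribution counts failures before the first success, with \(p = \mathrm{sigmoid}(\mathrm{logits})\), the log PMF is \(k \log(1-p) + \log p\). Both logarithms can be written with a softplus that stays finite for large |logits|. This is a plain-Python illustration; the softplus helper is written here and is not a NumPyro function.

```python
import math


def softplus(x: float) -> float:
    """log(1 + exp(x)), written so that it does not overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def geometric_log_prob_from_logits(k: int, logits: float) -> float:
    # With p = sigmoid(logits): log p = -softplus(-logits), log(1 - p) = -softplus(logits)
    return -k * softplus(logits) - softplus(-logits)


logits = 0.3
p = 1.0 / (1.0 + math.exp(-logits))
naive = [k * math.log1p(-p) + math.log(p) for k in range(5)]
stable = [geometric_log_prob_from_logits(k, logits) for k in range(5)]
```

For very large logits the naive route rounds p to exactly 1.0, so log1p(-p) fails, while the softplus form still returns a finite value.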

GeometricProbs

class GeometricProbs(probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0)}
support = IntegerNonnegative(lower_bound=0)
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the event dimensions removed

Return type:

ArrayLike

logits() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns the entropy of the distribution.
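For reference, the standard closed forms for a geometric distribution that counts failures before the first success (the convention assumed here from the nonnegative-integer support) are mean \((1-p)/p\), variance \((1-p)/p^2\) and entropy \([-(1-p)\ln(1-p) - p \ln p]/p\). The sketch checks them against a truncated sum over the PMF \((1-p)^k p\); it is not NumPyro's implementation.

```python
import math


def geometric_pmf(k: int, p: float) -> float:
    # Probability of k failures before the first success
    return (1.0 - p) ** k * p


p = 0.3
ks = range(500)  # the tail beyond k = 500 is negligible for this p
pmf = [geometric_pmf(k, p) for k in ks]
mean = sum(k * q for k, q in zip(ks, pmf))
var = sum((k - mean) ** 2 * q for k, q in zip(ks, pmf))
entropy = -sum(q * math.log(q) for q in pmf)

# Standard closed forms for this parameterization
assert math.isclose(mean, (1 - p) / p)
assert math.isclose(var, (1 - p) / p**2)
assert math.isclose(entropy, (-(1 - p) * math.log(1 - p) - p * math.log(p)) / p)
```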

HurdleDistribution

HurdleDistribution(base_dist: Distribution, *, gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, gate_logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, validate_args: bool | None = None) HurdleProbs | HurdleLogits[source]

Generic hurdle distribution.

A hurdle model is a two-part model: a Bernoulli “hurdle” selects between an exact zero (with probability gate) and a positive draw from the (zero-truncated, for discrete bases) base distribution. Returns a HurdleProbs if gate is supplied, or a HurdleLogits if gate_logits is supplied. Exactly one of the two must be provided. See HurdleProbs for the full mechanism, assumptions, and PMF/PDF.

Parameters:
  • base_dist (Distribution) – the base distribution.

  • gate (ArrayLike) – probability of a structural zero.

  • gate_logits (ArrayLike) – log-odds of a structural zero.

References:

  1. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables with Application to the Demand for Durable Goods. Econometrica, 39(5), 829-844.

  2. Mullahy, J. (1986). Specification and testing of some modified count data models. Journal of Econometrics, 33(3), 341-365.

HurdleLogits

class HurdleLogits(base_dist: Distribution, gate_logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: HurdleProbs

Hurdle distribution parameterized by gate_logits (the log-odds of the structural zero) instead of a probability.

Like HurdleProbs, this is a two-part model where a Bernoulli “hurdle” - here parameterized in logit space - selects between an exact zero and a positive draw from the (zero-truncated, for discrete bases) base distribution. See HurdleProbs for the full mechanism, assumptions, and underlying PMF/PDF.

Parameters:
  • base_dist (Distribution) – the base distribution.

  • gate_logits (ArrayLike) – log-odds of a structural zero, \(\mathrm{logit}(g) = \log\frac{g}{1 - g}\).

References:

  1. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables with Application to the Demand for Durable Goods. Econometrica, 39(5), 829-844.

  2. Mullahy, J. (1986). Specification and testing of some modified count data models. Journal of Econometrics, 33(3), 341-365.

arg_constraints: dict[str, Any] = {'gate_logits': Real()}
pytree_data_fields: tuple[str, ...] = ('base_dist', 'gate_logits')

HurdleNegativeBinomial2

HurdleNegativeBinomial2(gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None) HurdleProbs[source]

A hurdle Negative Binomial distribution (NB2 / mean-dispersion parameterization): a two-part model in which structural zeros are produced by a Bernoulli “hurdle” with probability \(g\) and positive counts follow a zero-truncated \(\mathrm{NegativeBinomial2}(\mu, \alpha)\). The hurdle and the magnitude (given a positive count) are conditionally independent; see HurdleProbs for the full mechanism and assumptions. Compared to a Hurdle Poisson, NB2 accommodates count data that is over-dispersed (variance greater than the mean).

The probability mass function is

\[P(X = 0) = g, \qquad P(X = k) = (1 - g) \, \frac{\mathrm{NB2}(k\mid\mu, \alpha)} {1 - \mathrm{NB2}(0\mid\mu, \alpha)} \;\text{for } k \geq 1,\]

where \(\mathrm{NB2}(\cdot\mid\mu, \alpha)\) is the PMF of a Negative Binomial distribution with mean \(\mu\) and concentration (dispersion) \(\alpha\).
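The PMF above can be evaluated directly, writing \(\mathrm{NB2}(\mu, \alpha)\) as \(\mathrm{GammaPoisson}(\alpha, \alpha/\mu)\), whose PMF is \(\frac{\Gamma(k+\alpha)}{\Gamma(\alpha)\,k!}\left(\frac{\alpha}{\alpha+\mu}\right)^{\alpha}\left(\frac{\mu}{\alpha+\mu}\right)^{k}\). The sketch below is plain Python for illustration; it checks that the probabilities sum to one and that the mean equals \((1-g)\,\mu / (1 - \mathrm{NB2}(0\mid\mu,\alpha))\).

```python
import math


def nb2_log_pmf(k: int, mu: float, alpha: float) -> float:
    # NegativeBinomial2(mu, alpha), i.e. GammaPoisson(alpha, alpha / mu)
    log_coef = math.lgamma(k + alpha) - math.lgamma(alpha) - math.lgamma(k + 1)
    return (log_coef
            + alpha * math.log(alpha / (alpha + mu))
            + k * math.log(mu / (alpha + mu)))


def hurdle_nb2_pmf(k: int, gate: float, mu: float, alpha: float) -> float:
    if k == 0:
        return gate  # every zero is structural
    nb2_zero = math.exp(nb2_log_pmf(0, mu, alpha))  # NB2(0 | mu, alpha)
    return (1.0 - gate) * math.exp(nb2_log_pmf(k, mu, alpha)) / (1.0 - nb2_zero)


gate, mu, alpha = 0.4, 3.0, 1.5
pmf = [hurdle_nb2_pmf(k, gate, mu, alpha) for k in range(600)]
```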

Parameters:
  • gate (ArrayLike) – probability of a structural zero, \(g \in [0, 1]\).

  • mean (ArrayLike) – mean \(\mu > 0\) of the underlying NegativeBinomial2.

  • concentration (ArrayLike) – concentration \(\alpha > 0\).

References:

  1. Mullahy, J. (1986). Specification and testing of some modified count data models. Journal of Econometrics, 33(3), 341-365.

  2. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables with Application to the Demand for Durable Goods. Econometrica, 39(5), 829-844.

HurdlePoisson

class HurdlePoisson(gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: HurdleProbs

A hurdle Poisson distribution: a two-part model in which structural zeros are produced by a Bernoulli “hurdle” with probability \(g\) and positive counts follow a zero-truncated \(\mathrm{Poisson}(\lambda)\). The hurdle and the magnitude (given a positive count) are conditionally independent; see HurdleProbs for the full mechanism and assumptions.

The probability mass function is

\[P(X = 0) = g, \qquad P(X = k) = (1 - g) \, \frac{\lambda^{k} e^{-\lambda} / k!}{1 - e^{-\lambda}} \;\text{for } k \geq 1.\]
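In log space the truncation factor \(1 - e^{-\lambda}\) loses precision for small \(\lambda\); computing it as -expm1(-rate) avoids that. The following is a plain-Python sketch of the log PMF above (not NumPyro's code), assuming \(0 < g < 1\) and \(\lambda > 0\). It checks normalization and the mean \((1-g)\,\lambda / (1 - e^{-\lambda})\).

```python
import math


def hurdle_poisson_log_pmf(k: int, gate: float, rate: float) -> float:
    """Log PMF of the hurdle Poisson above, assuming 0 < gate < 1 and rate > 0."""
    if k == 0:
        return math.log(gate)
    # log of the zero-truncated Poisson PMF; -expm1(-rate) equals 1 - exp(-rate)
    # but stays accurate when rate is small.
    log_truncated = (k * math.log(rate) - rate - math.lgamma(k + 1)
                     - math.log(-math.expm1(-rate)))
    return math.log1p(-gate) + log_truncated


gate, rate = 0.25, 2.0
pmf = [math.exp(hurdle_poisson_log_pmf(k, gate, rate)) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
```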
Parameters:
  • gate (ArrayLike) – probability of a structural zero, \(g \in [0, 1]\).

  • rate (ArrayLike) – rate \(\lambda > 0\) of the underlying Poisson.

References:

  1. Mullahy, J. (1986). Specification and testing of some modified count data models. Journal of Econometrics, 33(3), 341-365.

  2. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables with Application to the Demand for Durable Goods. Econometrica, 39(5), 829-844.

arg_constraints: dict[str, Any] = {'gate': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'rate': Positive(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
pytree_data_fields: tuple[str, ...] = ('rate',)

HurdleProbs

class HurdleProbs(base_dist: Distribution, gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

Generic hurdle distribution parameterized by a probability \(g\) (gate) of the structural zero and an arbitrary base distribution.

Hurdle mechanism. A hurdle model is a two-part model. A Bernoulli “hurdle” decides whether the outcome is zero (with probability \(g\), the gate) or strictly positive (with probability \(1 - g\)). Conditional on the outcome being positive, the magnitude is drawn from the base distribution - zero-truncated in the discrete case so the base distribution cannot itself produce a zero. With \(B\) denoting the base PMF/PDF:

\[P(X = 0) = g, \qquad P(X = k) = (1 - g) \, \frac{B(k)}{1 - B(0)} \;\text{for } k \geq 1 \;\text{(discrete base)}\]

For a continuous base on \(\mathbb{R}_{>0}\) no truncation is needed, since \(P(\text{base} = 0) = 0\): the distribution is a point mass at 0 with weight \(g\) mixed with the density \((1 - g) \, B(x)\) on \(x > 0\).
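The two-part mechanism can be sketched as a sampler. This is an illustration under stated assumptions, not how HurdleProbs is implemented: zero truncation of a discrete base is done here by simple rejection (redraw until positive), the Poisson draw uses Knuth's method, and the helper names are hypothetical.

```python
import math
import random
from typing import Callable


def sample_poisson(rng: random.Random, rate: float) -> int:
    """Knuth's multiplication method; fine for moderate rates."""
    threshold, k, prod = math.exp(-rate), 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def zero_truncated_poisson(rate: float) -> Callable[[random.Random], int]:
    def draw(rng: random.Random) -> int:
        # Rejection: redraw until the base produces a positive count.
        while True:
            k = sample_poisson(rng, rate)
            if k > 0:
                return k
    return draw


def sample_hurdle(rng: random.Random, gate: float,
                  draw_positive: Callable[[random.Random], float]) -> float:
    # Part 1: the Bernoulli hurdle yields a structural zero with probability `gate`.
    if rng.random() < gate:
        return 0
    # Part 2: otherwise the magnitude comes from the (zero-truncated) base.
    return draw_positive(rng)


rng = random.Random(1)
counts = [sample_hurdle(rng, 0.3, zero_truncated_poisson(1.5)) for _ in range(20_000)]
# A continuous base on (0, inf) needs no truncation.
amounts = [sample_hurdle(rng, 0.3, lambda r: r.expovariate(2.0)) for _ in range(20_000)]
```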

Assumptions.

  1. All zeros are structural - they originate exclusively from the hurdle process. This contrasts with zero-inflated models, which mix structural zeros with sampling zeros from the base distribution.

  2. The hurdle decision (zero vs. positive) and the magnitude (given positive) are conditionally independent given the parameters.

  3. For a discrete base, \(P(\text{base} = 0) < 1\) so the truncation factor \(1 - B(0)\) is well-defined. For a continuous base supported on \(\mathbb{R}_{>0}\), \(P(\text{base} = 0) = 0\) and no truncation is needed.

Note

gate is the probability of a structural zero. This matches the convention used by ZeroInflatedDistribution, and corresponds to 1 - psi in PyMC’s hurdle distributions.

Parameters:
  • base_dist (Distribution) – the base distribution.

  • gate (ArrayLike) – probability of a structural zero, in \([0, 1]\).

References:

  1. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables with Application to the Demand for Durable Goods. Econometrica, 39(5), 829-844.

  2. Mullahy, J. (1986). Specification and testing of some modified count data models. Journal of Econometrics, 33(3), 341-365.

arg_constraints: dict[str, Any] = {'gate': UnitInterval(lower_bound=0.0, upper_bound=1.0)}
pytree_data_fields: tuple[str, ...] = ('base_dist', 'gate')
pytree_aux_fields: tuple[str, ...] = ('_is_discrete',)
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the event dimensions removed

Return type:

ArrayLike

mean() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Mean of the distribution.

variance() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Variance of the distribution.

Multinomial

Multinomial(total_count=1, probs: Array = None, logits: Array = None, *, total_count_max: int | None = None, validate_args: bool | None = None) MultinomialProbs | MultinomialLogits[source]

Multinomial distribution.

Parameters:
  • total_count – number of trials. If this is a JAX array, it is required to specify total_count_max.

  • probs – event probabilities

  • logits – event log probabilities, needed only up to an additive constant (unnormalized)

  • total_count_max (int) – the maximum number of trials, i.e. max(total_count)
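The multinomial log PMF is \(\log n! - \sum_i \log x_i! + \sum_i x_i \log p_i\), and logits relate to probs through a softmax, which is unchanged when a constant is added to every logit. The sketch below shows both facts in plain Python (helper names are illustrative, not NumPyro functions) and checks that the PMF sums to one over all count vectors with the given total.

```python
import math
from itertools import product


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def multinomial_log_pmf(counts: list[int], probs: list[float]) -> float:
    n = sum(counts)
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(c + 1) for c in counts)
    # Skip c == 0 terms so that a zero-probability category with zero count contributes 0.
    return log_coef + sum(c * math.log(p) for c, p in zip(counts, probs) if c > 0)


logits = [0.0, 1.0, -0.5]
probs = softmax(logits)
total_count = 4
support = [c for c in product(range(total_count + 1), repeat=len(probs))
           if sum(c) == total_count]
total_mass = sum(math.exp(multinomial_log_pmf(list(c), probs)) for c in support)
```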

MultinomialLogits

class MultinomialLogits(logits: Array, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, *, total_count_max: int | None = None, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'logits': RealVector(Real(), 1), 'total_count': IntegerNonnegative(lower_bound=0)}
pytree_data_fields: tuple[str, ...] = ('logits',)
pytree_aux_fields: tuple[str, ...] = ('total_count', 'total_count_max')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the event dimensions removed

Return type:

ArrayLike

probs() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

static infer_shapes(logits: Array, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) tuple[tuple[int, ...], tuple[int, ...]][source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
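A plain-Python sketch of the shape rule, assuming that the trailing axis of logits is the event (category) axis and that the remaining axes broadcast against total_count to form the batch shape. broadcast_shapes and multinomial_infer_shapes are hypothetical helpers written for illustration; like the method, they operate on shape tuples rather than arrays.

```python
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
    """NumPy-style broadcasting of shape tuples (right-aligned, size-1 axes stretch)."""
    ndim = max((len(s) for s in shapes), default=0)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


def multinomial_infer_shapes(logits_shape: tuple[int, ...],
                             total_count_shape: tuple[int, ...]
                             ) -> tuple[tuple[int, ...], tuple[int, ...]]:
    # The trailing axis of logits indexes the categories and is the event shape;
    # the remaining axes broadcast against total_count to give the batch shape.
    batch_shape = broadcast_shapes(logits_shape[:-1], total_count_shape)
    event_shape = logits_shape[-1:]
    return batch_shape, event_shape
```

For example, logits of shape (2, 3, 4) with total_count of shape (3,) give batch_shape (2, 3) and event_shape (4,).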

MultinomialProbs

class MultinomialProbs(probs: Array, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1, *, total_count_max: int | None = None, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'probs': Simplex(), 'total_count': IntegerNonnegative(lower_bound=0)}
pytree_data_fields: tuple[str, ...] = ('probs',)
pytree_aux_fields: tuple[str, ...] = ('total_count', 'total_count_max')
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the event dimensions removed

Return type:

ArrayLike

logits() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

static infer_shapes(probs: Array, total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) tuple[tuple[int, ...], tuple[int, ...]][source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

OrderedLogistic

class OrderedLogistic(predictor: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, cutpoints: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: CategoricalProbs

A categorical distribution with ordered outcomes.

References:

  1. Stan Functions Reference, v2.20 section 12.6, Stan Development Team

Parameters:
  • predictor (numpy.ndarray) – prediction in real domain; typically this is output of a linear model.

  • cutpoints (numpy.ndarray) – positions in real domain to separate categories.

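  • probs – not a constructor argument: the category probability vector on the simplex, whose trailing dimension indexes the \(K\) categories (one more than the number of cutpoints), is computed from predictor and cutpoints and passed to CategoricalProbs.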

  • validate_args – If True, enforce domain constraints during initialization.
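Following the ordered-logistic definition in the cited Stan reference, the cumulative probabilities are \(P(Y \le k) = \mathrm{sigmoid}(c_k - \eta)\) for predictor \(\eta\) and ordered cutpoints \(c_k\), and the category probabilities are their successive differences after padding with 0 and 1. This is a plain-Python sketch of that construction for a scalar predictor; the helper names are illustrative, not NumPyro functions.

```python
import math


def expit(x: float) -> float:
    """Logistic sigmoid, computed without overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def ordered_logistic_probs(predictor: float, cutpoints: list[float]) -> list[float]:
    # Cumulative probabilities P(Y <= k) = expit(c_k - predictor), padded with 0 and 1.
    cumulative = [0.0] + [expit(c - predictor) for c in cutpoints] + [1.0]
    # Successive differences give the K = len(cutpoints) + 1 category probabilities.
    return [hi - lo for lo, hi in zip(cumulative[:-1], cumulative[1:])]


probs_low = ordered_logistic_probs(-1.0, [-1.0, 0.0, 2.0])
probs_high = ordered_logistic_probs(3.0, [-1.0, 0.0, 2.0])
```

A larger predictor shifts probability mass toward the higher categories, which is why the cutpoints must be ordered.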

arg_constraints: dict[str, Any] = {'cutpoints': OrderedVector(), 'predictor': Real()}
static infer_shapes(predictor, cutpoints)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

entropy() Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The entropy of the Categorical distribution is given by:

\[H[X] = -\sum_{k=0}^{K-1} p_k \ln p_k\]
Returns:

The entropy of the Categorical distribution.

NegativeBinomial

NegativeBinomial(total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, *, validate_args: bool | None = None) GammaPoisson[source]

Factory function for Negative Binomial distribution.

Parameters:
  • total_count (int) – Number of failures \(r\) at which the trials stop; the distribution counts the successes observed before the \(r\)-th failure.

  • probs (Optional[ArrayLike]) – Probability of success for each trial, by default None

  • logits (Optional[ArrayLike]) – Log-odds of success for each trial, by default None

  • validate_args (Optional[bool]) – Whether to validate the parameters, by default None

Returns:

An instance of NegativeBinomialProbs or NegativeBinomialLogits depending on the provided parameters.

Return type:

GammaPoisson

Raises:

ValueError – If neither probs nor logits is specified.

NegativeBinomialLogits

class NegativeBinomialLogits(total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: GammaPoisson

Negative Binomial distribution parameterized by total_count (\(r\)) and logits (\(\displaystyle\mathrm{logits}(p)=\log \frac{p}{1-p}\)). It is implemented as a \(\mathrm{GammaPoisson}(r, \exp(-\mathrm{logits}(p)))\) distribution.

Parameters:
  • total_count – Number of failures \(r\) at which the trials stop; the distribution counts the successes observed before the \(r\)-th failure.

  • logits – Log-odds of success for each trial (\(\ln \frac{p}{1-p}\)).

arg_constraints: dict[str, Any] = {'logits': Real(), 'total_count': Positive(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

If \(X \sim \mathrm{NegativeBinomial}(r, \mathrm{logits}(p))\), then the log probability mass function is:

\[\ln P(X = k) = -r \ln(1+\exp(\mathrm{logits}(p))) - k \ln(1+\exp(-\mathrm{logits}(p))) - \ln\Gamma(1 + k) - \ln\Gamma(r) + \ln\Gamma(k + r)\]
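The logit form of the log PMF can be cross-checked against the direct negative binomial PMF \(\frac{\Gamma(k+r)}{\Gamma(r)\,k!}\,p^k (1-p)^r\) with \(p = \mathrm{sigmoid}(\mathrm{logits})\), which is what \(\mathrm{GammaPoisson}(r, \exp(-\mathrm{logits}))\) reduces to. The identities used are \(\ln(1+e^{l}) = -\ln(1-p)\) and \(\ln(1+e^{-l}) = -\ln p\). This is a plain-Python sketch; the softplus helper is written here for illustration.

```python
import math


def softplus(x: float) -> float:
    """log(1 + exp(x)) without overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def nb_log_prob_logits(k: int, r: float, logits: float) -> float:
    # The log PMF above, term by term
    return (-r * softplus(logits) - k * softplus(-logits)
            - math.lgamma(1 + k) - math.lgamma(r) + math.lgamma(k + r))


def nb_log_prob_probs(k: int, r: float, p: float) -> float:
    # Direct form: Gamma(k + r) / (Gamma(r) k!) * p^k * (1 - p)^r
    return (math.lgamma(k + r) - math.lgamma(r) - math.lgamma(k + 1)
            + k * math.log(p) + r * math.log1p(-p))


r, logits = 2.5, 0.4
p = 1.0 / (1.0 + math.exp(-logits))  # probs = sigmoid(logits)
pairs = [(nb_log_prob_logits(k, r, logits), nb_log_prob_probs(k, r, p)) for k in range(10)]
```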

NegativeBinomialProbs

class NegativeBinomialProbs(total_count: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, probs: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: GammaPoisson

Negative Binomial distribution parameterized by total_count (\(r\)) and probs (\(p\)). It is implemented as a \(\displaystyle\mathrm{GammaPoisson}(r, \frac{1}{p} - 1)\) distribution.

Parameters:
  • total_count – Number of failures (\(r\)) at which the trials stop; the distribution counts the successes observed before the \(r\)-th failure.

  • probs – Probability of success for each trial (\(p\)).

arg_constraints: dict[str, Any] = {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': Positive(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
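NegativeBinomialProbs and NegativeBinomialLogits reduce to the same GammaPoisson when probs and logits describe the same \(p\), because \(1/p - 1 = (1-p)/p = \exp(-\mathrm{logits}(p))\). The sketch below restates the two mappings from the class descriptions in plain Python (the function names are hypothetical) and uses the GammaPoisson mean \(\alpha/\lambda\) to recover \(r\,p/(1-p)\).

```python
import math


def nb_probs_to_gamma_poisson(total_count: float, probs: float) -> tuple[float, float]:
    # NegativeBinomialProbs(r, p) -> GammaPoisson(r, 1/p - 1)
    return total_count, 1.0 / probs - 1.0


def nb_logits_to_gamma_poisson(total_count: float, logits: float) -> tuple[float, float]:
    # NegativeBinomialLogits(r, logits) -> GammaPoisson(r, exp(-logits))
    return total_count, math.exp(-logits)


r, p = 3.0, 0.3
logits = math.log(p / (1 - p))
conc_p, rate_p = nb_probs_to_gamma_poisson(r, p)
conc_l, rate_l = nb_logits_to_gamma_poisson(r, logits)
mean = conc_p / rate_p  # GammaPoisson mean alpha / lambda
```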

NegativeBinomial2

class NegativeBinomial2(mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: GammaPoisson

If \(X \sim \mathrm{NegativeBinomial2}(\mu, \alpha)\), then \(X \sim \mathrm{GammaPoisson}(\alpha, \frac{\alpha}{\mu})\).

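Parameters:
  • mean (ArrayLike) – mean \(\mu > 0\) of the distribution.

  • concentration (ArrayLike) – concentration (dispersion) \(\alpha > 0\); the variance is \(\mu + \mu^2 / \alpha\).
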
arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'mean': Positive(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
pytree_data_fields: tuple[str, ...] = ('concentration',)

Poisson

class Poisson(rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, is_sparse: bool = False, validate_args: bool | None = None)[source]

Bases: Distribution

Creates a Poisson distribution parameterized by rate, the rate parameter.

Samples are nonnegative integers, with a pmf given by

\[\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}\]
Parameters:
  • rate (numpy.ndarray) – The rate parameter

  • is_sparse (bool) – Whether to assume value is mostly zero when computing log_prob(), which can speed up computation when data is sparse.
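The reason mostly-zero data can be cheaper to score: at \(k = 0\) the log PMF \(k \log \mathrm{rate} - \mathrm{rate} - \log k!\) reduces to \(-\mathrm{rate}\), so only the nonzero entries need the full expression. The plain-Python sketch below illustrates that idea and checks that it agrees with the dense computation. How NumPyro exploits it on JAX arrays is not shown here and may differ; rate > 0 is assumed.

```python
import math


def poisson_log_prob_dense(values: list[int], rate: float) -> list[float]:
    # log PMF: k log(rate) - rate - log(k!), evaluated for every entry
    return [k * math.log(rate) - rate - math.lgamma(k + 1) for k in values]


def poisson_log_prob_sparse(values: list[int], rate: float) -> list[float]:
    # For k == 0 the log PMF is just -rate, so only the nonzero entries need the full formula.
    out = [-rate] * len(values)
    for i, k in enumerate(values):
        if k != 0:
            out[i] = k * math.log(rate) - rate - math.lgamma(k + 1)
    return out


counts = [0, 0, 3, 0, 1, 0, 0, 7]
dense = poisson_log_prob_dense(counts, 1.3)
sparse = poisson_log_prob_sparse(counts, 1.3)
```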

arg_constraints: dict[str, Any] = {'rate': GreaterThanEq(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
pytree_aux_fields: tuple[str, ...] = ('is_sparse',)
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the PRNG key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

ZeroInflatedDistribution

ZeroInflatedDistribution(base_dist: Distribution, *, gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, gate_logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, validate_args: bool | None = None) ZeroInflatedProbs | ZeroInflatedLogits[source]

Generic Zero Inflated distribution.

Parameters:
  • base_dist (Distribution) – the base distribution.

  • gate (numpy.ndarray) – probability of extra zeros given via a Bernoulli distribution.

  • gate_logits (numpy.ndarray) – logits of extra zeros given via a Bernoulli distribution.
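Unlike a hurdle model, a zero-inflated model keeps the base distribution's own zeros. The extra zeros from the gate are added on top, so \(P(X = 0) = g + (1-g)\,B(0)\) and \(P(X = k) = (1-g)\,B(k)\) for \(k \ge 1\). The sketch below uses a Poisson base in plain Python for illustration; it is not NumPyro's implementation. It checks normalization and the mean \((1-g)\,\lambda\).

```python
import math


def poisson_pmf(k: int, rate: float) -> float:
    return math.exp(k * math.log(rate) - rate - math.lgamma(k + 1))


def zero_inflated_poisson_pmf(k: int, gate: float, rate: float) -> float:
    base = poisson_pmf(k, rate)
    if k == 0:
        # Extra zeros from the gate plus sampling zeros from the base.
        return gate + (1.0 - gate) * base
    return (1.0 - gate) * base


gate, rate = 0.2, 2.5
pmf = [zero_inflated_poisson_pmf(k, gate, rate) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
```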

ZeroInflatedPoisson

class ZeroInflatedPoisson(gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, rate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: ZeroInflatedProbs

A Zero Inflated Poisson distribution.

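Parameters:
  • gate (ArrayLike) – probability of extra zeros, \(g \in [0, 1]\).

  • rate (ArrayLike) – rate \(\lambda > 0\) of the underlying Poisson.
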
arg_constraints: dict[str, Any] = {'gate': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'rate': Positive(lower_bound=0.0)}
support = IntegerNonnegative(lower_bound=0)
pytree_data_fields: tuple[str, ...] = ('rate',)

ZeroInflatedNegativeBinomial2

ZeroInflatedNegativeBinomial2(mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, gate: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, gate_logits: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, validate_args: bool | None = None)[source]

Mixture Distributions

Mixture

Mixture(mixing_distribution: CategoricalProbs | CategoricalLogits, component_distributions: list[Distribution] | Distribution, *, validate_args: bool | None = None)[source]

A marginalized finite mixture of component distributions.

The returned distribution will be either a:

  1. MixtureGeneral, when component_distributions is a list, or

  2. MixtureSameFamily, when component_distributions is a single distribution.

More details can be found in the documentation for each of these classes.

Parameters:
  • mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.

  • component_distributions – Either a list of component distributions or a single vectorized distribution. When a list is provided, the number of elements must equal mixture_size. Otherwise, the last batch dimension of the distribution must equal mixture_size.

Returns:

The mixture distribution.

MixtureSameFamily

class MixtureSameFamily(mixing_distribution: CategoricalProbs | CategoricalLogits, component_distribution: Distribution, *, validate_args: bool | None = None)[source]

Bases: _MixtureBase

A finite mixture of component distributions from the same family.

This mixture only supports a mixture of component distributions that are all of the same family. The different components are specified along the last batch dimension of the input component_distribution. If you need a mixture of distributions from different families, use the more general implementation in MixtureGeneral.

Parameters:
  • mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.

  • component_distribution – A single vectorized Distribution, whose last batch dimension equals mixture_size as specified by mixing_distribution.

Example

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
>>> component_dist = dist.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
>>> mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
>>> mixture.sample(jax.random.key(42)).shape
()
pytree_data_fields: tuple[str, ...] = ('_mixing_distribution', '_component_distribution')
pytree_aux_fields: tuple[str, ...] = ('_mixture_size',)
property component_distribution: Distribution

Return the vectorized distribution of components being mixed.

Returns:

Component distribution

Return type:

Distribution

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

property is_discrete: bool
property component_mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray
property component_variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray
component_cdf(samples: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
component_sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
component_log_probs(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

MixtureGeneral

class MixtureGeneral(mixing_distribution: CategoricalProbs | CategoricalLogits, component_distributions: list[Distribution], *, support: Constraint | None = None, validate_args: bool | None = None)[source]

Bases: _MixtureBase

A finite mixture of component distributions from different families.

If all of the component distributions are from the same family, the more specific implementation in MixtureSameFamily will be somewhat more efficient.

Parameters:
  • mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.

  • component_distributions – A list of mixture_size Distribution objects.

  • support – A Constraint object specifying the support of the mixture distribution. If not provided, the support will be inferred from the component distributions.

Example

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
>>> component_dists = [
...     dist.Normal(loc=0.0, scale=1.0),
...     dist.Normal(loc=-0.5, scale=0.3),
...     dist.Normal(loc=0.6, scale=1.2),
... ]
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists)
>>> mixture.sample(jax.random.key(42)).shape
()
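
The support can also be passed explicitly, as in this mixture of a Normal and a HalfNormal component:

>>> import jax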
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(2) / 2.)
>>> component_dists = [
...     dist.Normal(loc=0.0, scale=1.0),
...     dist.HalfNormal(scale=0.3),
... ]
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists, support=dist.constraints.real)
>>> mixture.sample(jax.random.key(42)).shape
()
pytree_data_fields: tuple[str, ...] = ('_mixing_distribution', '_component_distributions', '_support')
pytree_aux_fields: tuple[str, ...] = ('_mixture_size',)
property component_distributions: list[Distribution]

The list of component distributions in the mixture.

Returns:

The list of component distributions

Return type:

list[Distribution]

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

property is_discrete: bool
property component_mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray
property component_variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray
component_cdf(samples: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array[source]
component_sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
component_log_probs(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Directional Distributions

ProjectedNormal

class ProjectedNormal(concentration: Array, *, validate_args: bool | None = None)[source]

Bases: Distribution

Projected isotropic normal distribution of arbitrary dimension.

This distribution over directional data is qualitatively similar to the von Mises and von Mises-Fisher distributions, but permits tractable variational inference via reparametrized gradients.
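The name describes the construction: shift an isotropic normal draw by concentration and project it onto the unit sphere. A minimal pure-Python sketch of that construction, using a hypothetical helper sample_projected_normal (this is not NumPyro's sample() method):

```python
import math
import random

def sample_projected_normal(concentration, rng):
    """Hypothetical helper: draw z ~ Normal(concentration, I) and project z onto the unit sphere."""
    z = [rng.gauss(mu, 1.0) for mu in concentration]
    norm = math.sqrt(sum(zi * zi for zi in z))
    return [zi / norm for zi in z]

rng = random.Random(0)
direction = sample_projected_normal([0.0, 0.0, 5.0], rng)
print(direction)  # a unit-length 3-vector, typically pointing near (0, 0, 1)
```

For fixed Gaussian noise the draw is a differentiable function of concentration, which is what makes reparametrized gradients available.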

To use this distribution with autoguides and HMC, use handlers.reparam with a ProjectedNormalReparam reparametrizer in the model, e.g.:

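import jax.numpy as jnp
import numpyro
from numpyro import handlers
from numpyro.distributions import ProjectedNormal
from numpyro.infer.reparam import ProjectedNormalReparam

@handlers.reparam(config={"direction": ProjectedNormalReparam()})
def model():
    direction = numpyro.sample("direction", ProjectedNormal(jnp.zeros(3)))
    ...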

Note

This implements log_prob() only for dimensions {2,3}.

[1] Hernandez-Stumpfhauser, D., Breidt, F. J., and van der Woerd, M. J. (2017). “The General Projected Normal Distribution of Arbitrary Dimension: Modeling and Bayesian Inference.” https://projecteuclid.org/euclid.ba/1453211962

arg_constraints: dict[str, Any] = {'concentration': RealVector(Real(), 1)}
reparametrized_params: list[str] = ['concentration']
support = Sphere()
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Note that this is the mean in the sense of a centroid on the submanifold, i.e. the point that minimizes the expected squared geodesic distance.

property mode
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

static infer_shapes(concentration)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

SineBivariateVonMises

class SineBivariateVonMises(phi_loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, psi_loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, phi_concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, psi_concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, correlation: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, weighted_correlation: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, *, validate_args: bool | None = None)[source]

Bases: Distribution

Unimodal distribution of two dependent angles on the 2-torus (\(S^1 \otimes S^1\)) given by

\[C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))\]

and

\[C = (2\pi)^2 \sum_{i=0}^{\infty} {2i \choose i} \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),\]

where \(I_i(\cdot)\) is the modified Bessel function of the first kind, \(\mu_1, \mu_2\) are the locations of the distribution, \(\kappa_1, \kappa_2\) are the concentrations, and \(\rho\) gives the correlation between angles \(x_1\) and \(x_2\). This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
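The series for C can be checked numerically. The sketch below is plain Python with hypothetical helpers: it sums the series (truncated to a fixed number of terms, with a power-series Bessel function) and compares it with a direct grid integration of the unnormalized density over the 2-torus. The locations are set to zero because C does not depend on them. It is an illustration, not how NumPyro's norm_const() is implemented.

```python
import math

def bessel_i(order, x, terms=40):
    # Power series for the modified Bessel function of the first kind, I_order(x)
    return sum((x / 2) ** (2 * m + order) / (math.factorial(m) * math.factorial(m + order))
               for m in range(terms))

def norm_const_series(kappa1, kappa2, rho, terms=30):
    # C = (2 pi)^2 * sum_i binom(2i, i) (rho^2 / (4 kappa1 kappa2))^i I_i(kappa1) I_i(kappa2),
    # truncated after `terms` terms
    ratio = rho ** 2 / (4 * kappa1 * kappa2)
    return (2 * math.pi) ** 2 * sum(
        math.comb(2 * i, i) * ratio ** i * bessel_i(i, kappa1) * bessel_i(i, kappa2)
        for i in range(terms)
    )

def norm_const_grid(kappa1, kappa2, rho, n=128):
    # Periodic rectangle rule over [0, 2 pi)^2 for the unnormalized density with mu_1 = mu_2 = 0
    h = 2 * math.pi / n
    xs = [j * h for j in range(n)]
    total = 0.0
    for x1 in xs:
        for x2 in xs:
            total += math.exp(kappa1 * math.cos(x1) + kappa2 * math.cos(x2)
                              + rho * math.sin(x1) * math.sin(x2))
    return total * h * h

series = norm_const_series(1.0, 2.0, 0.5)
grid = norm_const_grid(1.0, 2.0, 0.5)
print(series, grid)  # the two values agree to many significant digits
```

With \(\rho = 0\) the series reduces to \((2\pi)^2 I_0(\kappa_1) I_0(\kappa_2)\), the product of two independent von Mises normalizers.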

To infer parameters, use NUTS or HMC with priors that avoid parameterizations where the distribution becomes bimodal; see note below.

Note

Sample efficiency drops as

\[\frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1\]

because the distribution becomes increasingly bimodal. To avoid inefficient sampling, use the weighted_correlation parameter with a prior that keeps its mass away from ±1 (e.g., TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))). The weighted_correlation should be in [-1,1].

Note

The correlation and weighted_correlation params are mutually exclusive.

Note

In the context of SVI, this distribution can be used as a likelihood but not for latent variables.

Note

Normalization remains accurate for concentrations up to 10,000. Unlike Pyro, there is no assertion to verify this during initialization, as JIT-compilation would invalidate such a check.

References:
  1. Singh, H., Hnizdo, V., and Demchuck, E. (2002). Probabilistic model for two dependent circular variables.

Parameters:
  • phi_loc (np.ndarray) – location of first angle

  • psi_loc (np.ndarray) – location of second angle

  • phi_concentration (np.ndarray) – concentration of first angle

  • psi_concentration (np.ndarray) – concentration of second angle

  • correlation (np.ndarray) – correlation between the two angles

  • weighted_correlation (np.ndarray) – set correlation to weighted_correlation * sqrt(phi_concentration * psi_concentration) to avoid bimodality (see note). The weighted_correlation should be in [-1,1].
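From the weighted_correlation description, \(\rho = w\sqrt{\kappa_1\kappa_2}\), so the ratio \(\rho^2/(\kappa_1\kappa_2)\) from the sampling-efficiency note reduces to \(w^2\). A small sketch with hypothetical helper names makes this concrete:

```python
import math

def correlation_from_weighted(weighted_correlation, phi_concentration, psi_concentration):
    # rho = w * sqrt(kappa_1 * kappa_2), per the weighted_correlation description
    return weighted_correlation * math.sqrt(phi_concentration * psi_concentration)

def bimodality_ratio(correlation, phi_concentration, psi_concentration):
    # rho^2 / (kappa_1 * kappa_2); sampling efficiency drops as this approaches 1
    return correlation ** 2 / (phi_concentration * psi_concentration)

for w in (0.2, 0.5, 0.9, -0.9):
    rho = correlation_from_weighted(w, 70.0, 35.0)
    print(w, rho, bimodality_ratio(rho, 70.0, 35.0))  # ratio equals w**2 up to rounding
```

Keeping |w| away from 1 therefore keeps the ratio away from 1 regardless of the concentrations.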

arg_constraints: dict[str, Any] = {'correlation': Real(), 'phi_concentration': Positive(lower_bound=0.0), 'phi_loc': Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 'psi_concentration': Positive(lower_bound=0.0), 'psi_loc': Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)}
support = IndependentConstraint(Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)
max_sample_iter = 1000
norm_const() Array[source]
log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
References:
  1. Kent, J. T., Ganeiber, A. M., and Mardia, K. V. (2018). A New Unified Approach for the Simulation of a Wide Class of Directional Distributions.

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]

SineSkewed

class SineSkewed(base_dist: Distribution, skewness: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

Sine-skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric) base distribution. Torus distributions are distributions with support on products of circles (i.e., \(\otimes S^1\) where \(S^1 = [-\pi, \pi)\)). So, a 0-torus is a point, the 1-torus is a circle, and the 2-torus is commonly associated with the donut shape.

The sine skewed X distribution is parameterized by a weight parameter for each dimension of the event of X. For example, with a von Mises distribution over a circle (1-torus), the sine skewed von Mises distribution has one skew parameter. The skewness parameters can be inferred using HMC or NUTS. For example, the following will produce a prior over skewness for the 2-torus:

from math import pi

import numpyro
from numpyro.distributions import Beta, Normal, SineBivariateVonMises, SineSkewed, VonMises
from numpyro.distributions.transforms import L1BallTransform
from numpyro.infer.reparam import CircularReparam

@numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()})
def model(obs):
    # Sine priors
    phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
    psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
    phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
    psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
    corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

    # Skewing prior
    ball_trans = L1BallTransform()
    skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
    skewness = ball_trans(skewness)  # constrain sum |skewness_i| <= 1

    # obs has shape (num_obs, 2): one (phi, psi) pair per observation
    with numpyro.plate('obs_plate', obs.shape[0]):
        sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
                                     phi_concentration=70 * phi_conc,
                                     psi_concentration=70 * psi_conc,
                                     weighted_correlation=corr_scale)
        return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

To ensure that the skewing does not alter the normalization constant of the (sine bivariate von Mises) base distribution, the skewness parameters are constrained: the sum of the absolute values of skewness must be less than or equal to one. We can use the L1BallTransform to achieve this.

In the context of SVI, this distribution can freely be used as a likelihood, but using it for latent variables will lead to slow inference on tori of dimension 2 and higher. This is because the base_dist cannot be reparameterized.

Note

An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

Note

For the skewness parameter, the sum of the absolute values of its weights for an event must be less than or equal to one. See eq. 2.1 in [1].

References:
  1. Ameijeiras-Alonso, J., and Ley, C. (2019). Sine-skewed toroidal distributions and their application in protein bioinformatics.

Parameters:
  • base_dist (numpyro.distributions.Distribution) – base density on a d-dimensional torus. Supported base distributions include: 1D VonMises, SineBivariateVonMises, 1D ProjectedNormal, and Uniform (-pi, pi).

  • skewness (jax.numpy.array) – skewness of the distribution.

arg_constraints: dict[str, Any] = {'skewness': L1Ball()}
pytree_data_fields: tuple[str, ...] = ('base_dist', 'skewness')
support = IndependentConstraint(Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the base distribution

VonMises

class VonMises(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, concentration: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

The von Mises distribution, also known as the circular normal distribution.

This distribution is supported by a circular constraint from -pi to +pi. By default, the circular support behaves like constraints.interval(-math.pi, math.pi). To avoid issues at the boundaries of this interval during sampling, you should reparameterize this distribution using handlers.reparam with a CircularReparam reparametrizer in the model, e.g.:

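import numpyro
from numpyro.distributions import VonMises
from numpyro.infer.reparam import CircularReparam

@numpyro.handlers.reparam(config={"direction": CircularReparam()})
def model():
    direction = numpyro.sample("direction", VonMises(0.0, 4.0))
    ...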

von Mises distribution for sampling directions.

Parameters:
  • loc – center of distribution

  • concentration – concentration of distribution

arg_constraints: dict[str, Any] = {'concentration': Positive(lower_bound=0.0), 'loc': Real()}
reparametrized_params: list[str] = ['loc']
support = Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Generate sample from von Mises distribution

Parameters:
  • key – random number generator key

  • sample_shape – shape of samples

Returns:

samples from von Mises

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Computes circular variance of distribution
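As an illustration of circular variance, the sketch below assumes the common definition 1 - R, where R = E[cos(x - loc)] = I_1(κ)/I_0(κ) is the mean resultant length. It compares that closed form with a direct numerical average over the circle. The helper names are hypothetical and this is not necessarily how NumPyro computes variance.

```python
import math

def bessel_i(order, x, terms=40):
    # Power series for the modified Bessel function of the first kind, I_order(x)
    return sum((x / 2) ** (2 * m + order) / (math.factorial(m) * math.factorial(m + order))
               for m in range(terms))

def von_mises_circular_variance(concentration):
    # 1 - R, with mean resultant length R = I_1(kappa) / I_0(kappa)
    return 1.0 - bessel_i(1, concentration) / bessel_i(0, concentration)

def numeric_circular_variance(concentration, n=512):
    # 1 - E[cos(x - loc)] via a periodic rectangle rule (loc = 0; the normalizer cancels)
    xs = [2 * math.pi * j / n for j in range(n)]
    weights = [math.exp(concentration * math.cos(x)) for x in xs]
    return 1.0 - sum(w * math.cos(x) for w, x in zip(weights, xs)) / sum(weights)

print(von_mises_circular_variance(4.0), numeric_circular_variance(4.0))
```

The variance shrinks toward 0 as the concentration grows, matching the intuition that the distribution tightens around loc.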

Truncated Distributions

DoublyTruncatedPowerLaw

class DoublyTruncatedPowerLaw(alpha: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

Power law distribution with \(\alpha\) index, and lower and upper bounds. We can define the power law distribution as,

\[f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)},\]

where \(a\) and \(b\) are the lower and upper bounds respectively, and \(Z(\alpha, a, b)\) is the normalization constant. It is defined as,

\[\begin{split}Z(\alpha, a, b) = \begin{cases} \log(b) - \log(a) & \text{if } \alpha = -1, \\ \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}. \end{cases}\end{split}\]
Parameters:
  • alpha – index of the power law distribution

  • low – lower bound of the distribution

  • high – upper bound of the distribution

arg_constraints: dict[str, Any] = {'alpha': Real(), 'high': GreaterThan(lower_bound=0), 'low': GreaterThanEq(lower_bound=0)}
reparametrized_params: list[str] = ['alpha', 'low', 'high']
pytree_aux_fields: tuple[str, ...] = ('_support',)
pytree_data_fields: tuple[str, ...] = ('alpha', 'low', 'high')
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Log of the probability density function, where the density is

for \(\alpha \neq -1\):

\[\frac{(\alpha + 1)x^\alpha}{b^{\alpha + 1} - a^{\alpha + 1}}\]

for \(\alpha = -1\):

\[\frac{x^\alpha}{\log(b) - \log(a)}\]

The derivations were done with Wolfram Alpha.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Cumulative distribution function. For \(\alpha \neq -1\):

\[\frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{\alpha + 1}}\]

For \(\alpha = -1\):

\[\frac{\log(x) - \log(a)}{\log(b) - \log(a)}\]

The derivations were done with Wolfram Alpha.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Inverse cumulative distribution function. For \(\alpha \neq -1\):

\[\left(a^{1 + \alpha} + q (b^{1 + \alpha} - a^{1 + \alpha})\right)^{\frac{1}{1 + \alpha}}\]

For \(\alpha = -1\):

\[a \left(\frac{b}{a}\right)^{q}\]

The derivations were done with Wolfram Alpha.
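A minimal pure-Python sketch of the two cdf/icdf branches above, with hypothetical helper names. It checks that the icdf inverts the cdf for both the \(\alpha = -1\) case and the general case; it is an illustration, not NumPyro's implementation.

```python
import math

def power_law_cdf(x, alpha, low, high):
    if alpha == -1:
        return (math.log(x) - math.log(low)) / (math.log(high) - math.log(low))
    p = alpha + 1
    return (x ** p - low ** p) / (high ** p - low ** p)

def power_law_icdf(q, alpha, low, high):
    if alpha == -1:
        # log-uniform case: x = a * (b / a)**q
        return low * (high / low) ** q
    p = alpha + 1
    return (low ** p + q * (high ** p - low ** p)) ** (1 / p)

for alpha in (-1.0, 2.0, -2.5):
    x = power_law_icdf(0.3, alpha, 1.0, 10.0)
    print(alpha, x, power_law_cdf(x, alpha, 1.0, 10.0))  # cdf recovers q = 0.3 up to rounding
```

Passing uniform draws on [0, 1] through power_law_icdf yields inverse-transform samples from the distribution.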

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

LeftTruncatedDistribution

class LeftTruncatedDistribution(base_dist: Cauchy | Laplace | Logistic | Normal | SoftLaplace | StudentT, low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'low': Real()}
reparametrized_params: list[str] = ['low']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
pytree_data_fields: tuple[str, ...] = ('base_dist', 'low', '_support')
base_dist: Cauchy | Laplace | Logistic | Normal | SoftLaplace | StudentT
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, should belong to [0, 1].

Returns:

the values whose cdf equals q.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property var: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

LowerTruncatedPowerLaw

class LowerTruncatedPowerLaw(alpha: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool | None = None)[source]

Bases: Distribution

Lower truncated power law distribution with \(\alpha\) index. We can define the power law distribution as,

\[f(x; \alpha, a) = (-\alpha-1)a^{-\alpha - 1}x^{\alpha}, \qquad x \geq a, \qquad \alpha < -1,\]

where \(a\) is the lower bound. The cdf of the distribution is given by,

\[F(x; \alpha, a) = 1 - \left(\frac{x}{a}\right)^{1+\alpha}.\]

The k-th moment of the distribution is given by,

\[\begin{split}E[X^k] = \begin{cases} \frac{-\alpha-1}{-\alpha-1-k}a^k & \text{if } k < -\alpha-1, \\ \infty & \text{otherwise}. \end{cases}\end{split}\]
Parameters:
  • alpha – index of the power law distribution

  • low – lower bound of the distribution
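Inverting the cdf above gives \(x = a(1 - q)^{1/(1+\alpha)}\), which supports inverse-transform sampling. A minimal pure-Python sketch with hypothetical helper names (not NumPyro's implementation):

```python
import random

def lower_power_law_cdf(x, alpha, low):
    # F(x) = 1 - (x / low)^(1 + alpha), valid for x >= low and alpha < -1
    return 1.0 - (x / low) ** (1.0 + alpha)

def lower_power_law_icdf(q, alpha, low):
    # Solve F(x) = q for x: x = low * (1 - q)^(1 / (1 + alpha))
    return low * (1.0 - q) ** (1.0 / (1.0 + alpha))

rng = random.Random(0)
alpha, low = -3.0, 2.0
samples = [lower_power_law_icdf(rng.random(), alpha, low) for _ in range(1000)]
print(min(samples) >= low)  # True: every draw lies at or above the lower bound
```

With alpha = -3, the moment formula gives \(E[X] = 2a\) for k = 1, while \(E[X^2]\) is infinite because k = 2 is not less than \(-\alpha - 1 = 2\).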

arg_constraints: dict[str, Any] = {'alpha': LessThan(upper_bound=-1.0), 'low': GreaterThan(lower_bound=0.0)}
reparametrized_params: list[str] = ['alpha', 'low']
pytree_aux_fields: tuple[str, ...] = ('_support',)
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed

Return type:

ArrayLike

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, should belong to [0, 1].

Returns:

the values whose cdf equals q.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

RightTruncatedDistribution

class RightTruncatedDistribution(base_dist: Cauchy | Laplace | Logistic | Normal | SoftLaplace | StudentT, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'high': Real()}
reparametrized_params: list[str] = ['high']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
pytree_data_fields: tuple[str, ...] = ('base_dist', 'high', '_support')
base_dist: Cauchy | Laplace | Logistic | Normal | SoftLaplace | StudentT
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, should belong to [0, 1].

Returns:

the values whose cdf equals q.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property var: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

TruncatedCauchy

class TruncatedCauchy(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, validate_args: bool | None = None)[source]

Bases:

TruncatedDistribution

TruncatedDistribution(base_dist: Cauchy | Laplace | Logistic | Normal | SoftLaplace | StudentT, low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, *, validate_args: bool | None = None)[source]

A function to generate a truncated distribution.

Parameters:
  • base_dist – The base distribution to be truncated. This should be a univariate distribution. Currently, only the following distributions are supported: Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.

  • low – the value at which the base distribution is truncated from below. Set this parameter to None to leave the lower tail untruncated.

  • high – the value at which the base distribution is truncated from above. Set this parameter to None to leave the upper tail untruncated.

TruncatedNormal

class TruncatedNormal(loc: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, scale: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | None = None, validate_args: bool | None = None)[source]

Bases:

TruncatedPolyaGamma

class TruncatedPolyaGamma(batch_shape: tuple[int, ...] = (), *, validate_args: bool | None = None)[source]

Bases: Distribution

truncation_point = 2.5
num_log_prob_terms = 7
num_gamma_variates = 8
arg_constraints: dict[str, Any] = {}
support = Interval(lower_bound=0.0, upper_bound=2.5)
sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape

Return type:

ArrayLike

TwoSidedTruncatedDistribution

class TwoSidedTruncatedDistribution(base_dist: Cauchy | Laplace | Logistic | Normal | SoftLaplace | StudentT, low: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0, high: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 1.0, *, validate_args: bool | None = None)[source]

Bases: Distribution

arg_constraints: dict[str, Any] = {'high': Dependent(), 'low': Dependent()}
reparametrized_params: list[str] = ['low', 'high']
supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)
pytree_data_fields: tuple[str, ...] = ('base_dist', 'low', 'high', '_support')
base_dist: Cauchy | Laplace | Logistic | Normal | SoftLaplace | StudentT
property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

sample(key: Array, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

icdf(q: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The inverse cumulative distribution function of this distribution.

Parameters:

q – quantile values, should belong to [0, 1].

Returns:

the values whose cdf equals q.

cdf(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

The cumulative distribution function of this distribution.

Parameters:

value – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape

Return type:

ArrayLike

property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property var: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.

Censored Distributions

LeftCensoredDistribution

class LeftCensoredDistribution(base_dist: Distribution, censored: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = False, *, validate_args: bool = False)[source]

Bases: Distribution

Distribution wrapper for left-censored outcomes.

This distribution augments a base distribution with left-censoring, so that the likelihood contribution depends on the censoring indicator.

Parameters:
  • base_dist (numpyro.distributions.Distribution) – Parametric distribution for the uncensored values (e.g., Exponential, Weibull, LogNormal, Normal, etc.). This distribution must implement a cdf method.

  • censored (array-like of {0,1}) – Censoring indicator per observation: 0 → the value is observed exactly; 1 → the observation is left-censored at the reported value (the true value occurred on or before the reported value).

Note

The log_prob(value) method expects value to be the observed upper bound for each observation. The contribution to the log-likelihood is:

log f(value)    if censored == 0
log F(value)    if censored == 1

where f is the density and F the cumulative distribution function of base_dist.

This is commonly used in survival analysis, where event times are positive, but the approach is more general and can be applied to any distribution with a cumulative distribution function, regardless of support.

In R’s survival package notation, this corresponds to Surv(time, event, type = 'left').

Example:

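Surv(time = c(2, 4, 6), event = c(0, 1, 0), type = 'left')

means:

subject 1 is left-censored at t=2 (the event occurred at or before t=2)
subject 2 had an event exactly at t=4
subject 3 is left-censored at t=6

In this notation event = 1 marks an exactly observed value, so R's event indicator is the complement of censored (censored = 1 - event).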

Example:

>>> from jax import numpy as jnp
>>> from numpyro import distributions as dist
>>> base = dist.LogNormal(0., 1.)
>>> surv_dist = dist.LeftCensoredDistribution(base, censored=jnp.array([0, 1, 1]))
>>> loglik = surv_dist.log_prob(jnp.array([2., 4., 6.]))
arg_constraints: dict[str, Any] = {'censored': Boolean()}
pytree_data_fields: tuple[str, ...] = ('base_dist', 'censored', '_support')
sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape

Return type:

ArrayLike

RightCensoredDistribution

class RightCensoredDistribution(base_dist: Distribution, censored: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = False, *, validate_args: bool = False)[source]

Bases: Distribution

Distribution wrapper for right-censored outcomes.

This distribution augments a base distribution with right-censoring, so that the likelihood contribution depends on the censoring indicator.

Parameters:
  • base_dist (numpyro.distributions.Distribution) – Parametric distribution for the uncensored values (e.g., Exponential, Weibull, LogNormal, Normal, etc.). This distribution must implement a cdf method.

  • censored (array-like of {0,1}) – Censoring indicator per observation: 0 → the value is observed exactly; 1 → the observation is right-censored at the reported value (the true value occurred on or after the reported value).

Note

The log_prob(value) method expects value to be the observed lower bound for each observation. The contribution to the log-likelihood is:

log f(value)          if censored == 0
log (1 - F(value))    if censored == 1

where f is the density and F the cumulative distribution function of base_dist.

This is commonly used in survival analysis, where event times are positive, but the approach is more general and can be applied to any distribution with a cumulative distribution function, regardless of support.

In R’s survival package notation, this corresponds to Surv(time, event, type = 'right').

Example:

Surv(time = c(5, 8, 10), event = c(1, 0, 1))

means:

subject 1 had an event at t=5
subject 2 was censored at t=8
subject 3 had an event at t=10

Example:

>>> from jax import numpy as jnp
>>> from numpyro import distributions as dist
>>> base = dist.Exponential(rate=0.1)
>>> surv_dist = dist.RightCensoredDistribution(base, censored=jnp.array([0, 1, 0]))
>>> loglik = surv_dist.log_prob(jnp.array([5., 8., 10.]))
arg_constraints: dict[str, Any] = {'censored': Boolean()}
pytree_data_fields: tuple[str, ...] = ('base_dist', 'censored', '_support')
sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

log_prob(value: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape

Return type:

ArrayLike

IntervalCensoredDistribution

class IntervalCensoredDistribution(base_dist: Distribution, left_censored: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, right_censored: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, *, validate_args: bool = False)[source]

Bases: Distribution

Distribution wrapper for interval-censored outcomes.

This distribution augments a base distribution with interval censoring, so that the likelihood contribution depends on whether the observation is exactly observed, left-censored, right-censored, interval-censored, or doubly-censored (i.e., known to lie outside the observed interval).

Parameters:
  • base_dist (numpyro.distributions.Distribution) – Parametric distribution for the uncensored values (e.g., Exponential, Weibull, LogNormal, Normal, etc.). This distribution must implement a cdf method.

  • left_censored (array-like of {0,1}) – Indicator per observation: 1 → the observation is left-censored at the reported upper bound; 0 → not left-censored.

  • right_censored (array-like of {0,1}) – Indicator per observation: 1 → the observation is right-censored at the reported lower bound; 0 → not right-censored.

Note

The log_prob(value) method expects value to be a two-dimensional array of shape (batch_size, 2), where each row is (lower, upper). The contribution to the log-likelihood is determined as follows:

log F(upper)                       if left_censored == 1 and right_censored == 0
log (1 - F(lower))                 if right_censored == 1 and left_censored == 0
log (F(upper) - F(lower))          if both == 0 (interval-censored)
log (1 - (F(upper) - F(lower)))    if both == 1 (doubly-censored)
log f(value)                       if lower ≈ upper (point interval)

where f is the density and F the cumulative distribution function of base_dist.

This is commonly used in survival analysis, where event times are positive, but the approach is general and can be applied to any distribution with a cumulative distribution function, regardless of support.

In R’s survival package notation, this corresponds to Surv(l, r, type = 'interval2').

Example:

Surv(l = c(2, 4, 6), r = c(5, Inf, 9), type = 'interval2')

means:

subject 1 had an event in (2, 5]
subject 2 was right-censored at 4
subject 3 had an event in (6, 9]

Example:

>>> from jax import numpy as jnp
>>> from numpyro import distributions as dist
>>> base = dist.Weibull(concentration=2.0, scale=3.0)
>>> left_censored = jnp.array([0, 0, 0])
>>> right_censored = jnp.array([0, 1, 0])
>>> surv_dist = dist.IntervalCensoredDistribution(base, left_censored, right_censored)
>>> values = jnp.array([
...     [2.0, 5.0],
...     [4.0, jnp.inf],
...     [6.0, 9.0],
... ])
>>> loglik = surv_dist.log_prob(values)
arg_constraints: dict[str, Any] = {'left_censored': Boolean(), 'right_censored': Boolean()}
pytree_data_fields: tuple[str, ...] = ('base_dist', 'left_censored', 'right_censored', '_support')
sample(key: Array | None, sample_shape: tuple[int, ...] = ()) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.key) – the random number generator key used to draw the sample.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

property support: Constraint

The support of this distribution. Subclasses can override this as a class attribute or as a property.

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape sample_shape + batch_shape

Return type:

ArrayLike

TensorFlow Distributions

Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see its Distribution docs.

BijectorConstraint

class BijectorConstraint(bijector)[source]

A constraint that is the codomain of a TensorFlow bijector.

Parameters:

bijector (Bijector) – a TensorFlow bijector

BijectorTransform

class BijectorTransform(bijector)[source]

A wrapper for TensorFlow bijectors to make them compatible with NumPyro’s transforms.

Parameters:

bijector (Bijector) – a TensorFlow bijector

TFPDistribution

class TFPDistribution(batch_shape: tuple[int, ...] = (), event_shape: tuple[int, ...] = (), *, validate_args: bool | None = None)[source]

A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor has the same signature as the corresponding TFP distribution.

This class can be used to convert a TFP distribution to a NumPyro-compatible one as follows:

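from numpyro.contrib.tfp.distributions import TFPDistribution
from tensorflow_probability.substrates.jax import distributions as tfd

d = TFPDistribution[tfd.Normal](0, 1)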

Note that typical use cases do not require explicitly invoking this wrapper, since NumPyro wraps TFP distributions automatically under the hood in model code, e.g.:

import numpyro
from tensorflow_probability.substrates.jax import distributions as tfd

def model():
    numpyro.sample("x", tfd.Normal(0, 1))

Constraints

Constraint

class Constraint[source]

Bases: Generic[NumLikeT]

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

is_discrete: bool = False
event_dim: int = 0
check(value: NumLikeT) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Returns a boolean array of shape sample_shape + batch_shape indicating whether each event in value satisfies this constraint.

feasible_like(prototype: NumLikeT) NumLikeT[source]

Get a feasible value which has the same shape and dtype as prototype.

eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
classmethod tree_unflatten(aux_data, params)[source]

boolean

boolean = Boolean()

circular

circular = Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)

corr_cholesky

corr_cholesky = CorrCholesky()

corr_matrix

corr_matrix = CorrMatrix()

dependent

dependent: _Dependent = Dependent()

Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints.

Parameters:
  • is_discrete (bool) – Optional value of .is_discrete in case this can be computed statically. If not provided, access to the .is_discrete attribute will raise a NotImplementedError.

  • event_dim (int) – Optional value of .event_dim in case this can be computed statically. If not provided, access to the .event_dim attribute will raise a NotImplementedError.

greater_than

greater_than(lower_bound: ndarray | Array | number | int | float | complex) None

Constrains to real values strictly greater than lower_bound.

integer_interval

integer_interval(lower_bound: ndarray | Array | number | int | float | complex, upper_bound: ndarray | Array | number | int | float | complex) None

Constrains to integers in the closed interval [lower_bound, upper_bound].

integer_greater_than

integer_greater_than(lower_bound: ndarray | Array | number | int | float | complex) None

Constrains to integers greater than or equal to lower_bound.

interval

interval(lower_bound: ndarray | Array | number | int | float | complex, upper_bound: ndarray | Array | number | int | float | complex) None

Constrains to real values in the closed interval [lower_bound, upper_bound].

l1_ball

l1_ball(x: ndarray | Array | number | int | float | complex) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Constrain to the L1 ball of any dimension.

less_than

less_than(upper_bound: ndarray | Array | number | int | float | complex) None

Constrains to real values strictly less than upper_bound.

lower_cholesky

lower_cholesky = LowerCholesky()

multinomial

multinomial(upper_bound: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) None

Constrains to vectors of nonnegative integer counts whose entries sum to upper_bound.

nonnegative_integer

nonnegative_integer = IntegerNonnegative(lower_bound=0)

ordered_vector

ordered_vector = OrderedVector()

positive

positive = Positive(lower_bound=0.0)

positive_definite

positive_definite = PositiveDefinite()

positive_integer

positive_integer = IntegerPositive(lower_bound=1)

positive_ordered_vector

positive_ordered_vector = PositiveOrderedVector()

Constrains to a positive real-valued tensor where the elements are monotonically increasing along the event_shape dimension.

real

real = Real()

real_vector

real_vector = RealVector(Real(), 1)

scaled_unit_lower_cholesky

scaled_unit_lower_cholesky = ScaledUnitLowerCholesky()

softplus_positive

softplus_positive = SoftplusPositive(lower_bound=0.0)

softplus_lower_cholesky

softplus_lower_cholesky = SoftplusLowerCholesky()

simplex

simplex = Simplex()

sphere

sphere = Sphere()

Constrain to the Euclidean sphere of any dimension.

unit_interval

unit_interval = UnitInterval(lower_bound=0.0, upper_bound=1.0)

zero_sum

zero_sum = <class 'numpyro.distributions.constraints._ZeroSum'>

Constrains to real arrays whose entries sum to zero along each event dimension. Unlike most entries in this list, zero_sum is the constraint class itself rather than an instance, so it must be instantiated before use.

Transforms

biject_to

biject_to(constraint)

Transform

class Transform[source]

Bases: Generic[NumLikeT]

property domain: Constraint
property codomain: Constraint
property inv: Transform
log_abs_det_jacobian(x: NumLikeT, y: NumLikeT, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
call_with_intermediates(x: NumLikeT) Tuple[ndarray | Array | number | int | float | complex, Any | None][source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

property sign: ndarray | Array | number | int | float | complex

Sign of the derivative of the transform if it is bijective.

classmethod tree_unflatten(aux_data, params)[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

AbsTransform

class AbsTransform[source]

Bases: ParameterFreeTransform[ndarray | Array | number | int | float | complex]

domain = Real()
codomain = Positive(lower_bound=0.0)

AffineTransform

class AffineTransform(loc: ndarray | Array | number | int | float | complex, scale: ndarray | Array | number | int | float | complex, domain: Constraint = Real())[source]

Bases: Transform[ndarray | Array | number | int | float | complex]

Note

When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

property domain: Constraint
property codomain: Constraint
property sign: ndarray | Array | number | int | float | complex

Sign of the derivative of the transform if it is bijective.

log_abs_det_jacobian(x: ndarray | Array | number | int | float | complex, y: ndarray | Array | number | int | float | complex, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

CholeskyTransform

class CholeskyTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

Transform via the mapping \(y = cholesky(x)\), where x is a positive definite matrix.

domain = PositiveDefinite()
codomain = LowerCholesky()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

ComplexTransform

class ComplexTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

Transforms a pair of real numbers to a complex number.

domain = RealVector(Real(), 1)
codomain = Complex()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

ComposeTransform

class ComposeTransform(parts: Sequence[Transform])[source]

Bases: Transform[ndarray | Array | number | int | float | complex]

property domain: Constraint
property codomain: Constraint
property sign: ndarray | Array | number | int | float | complex

Sign of the derivative of the transform if it is bijective.

log_abs_det_jacobian(x: ndarray | Array | number | int | float | complex, y: ndarray | Array | number | int | float | complex, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
call_with_intermediates(x: ndarray | Array | number | int | float | complex) Tuple[ndarray | Array | number | int | float | complex, Any | None][source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

CorrCholeskyTransform

class CorrCholeskyTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:

  1. First we convert \(x\) into a lower triangular matrix with the following order:

\[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]

  2. For each row \(X_i\) of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:

      1. Scales into the interval \((-1, 1)\) domain: \(r_i = \tanh(X_i)\).

      2. Transforms into an unsigned domain: \(z_i = r_i^2\).

      3. Applies \(s_i = StickBreakingTransform(z_i)\).

      4. Transforms back into signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).

domain = RealVector(Real(), 1)
codomain = CorrCholesky()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

CorrMatrixCholeskyTransform

class CorrMatrixCholeskyTransform[source]

Bases: CholeskyTransform

Transform via the mapping \(y = cholesky(x)\), where x is a correlation matrix.

domain = CorrMatrix()
codomain = CorrCholesky()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

ExpTransform

class ExpTransform(domain: Constraint = Real())[source]

Bases: Transform[ndarray | Array | number | int | float | complex]

sign = 1
property domain: Constraint
property codomain: Constraint
log_abs_det_jacobian(x: ndarray | Array | number | int | float | complex, y: ndarray | Array | number | int | float | complex, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

IdentityTransform

class IdentityTransform[source]

Bases: ParameterFreeTransform[ndarray | Array | number | int | float | complex]

sign = 1
log_abs_det_jacobian(x: ndarray | Array | number | int | float | complex, y: ndarray | Array | number | int | float | complex, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

L1BallTransform

class L1BallTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

Transforms an unconstrained real vector \(x\) into the unit L1 ball.

domain = RealVector(Real(), 1)
codomain = L1Ball()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

LowerCholeskyAffine

class LowerCholeskyAffine(loc: ndarray | Array, scale_tril: ndarray | Array)[source]

Bases: Transform[ndarray | Array]

Transform via the mapping \(y = loc + scale\_tril\ @\ x\).

Parameters:
  • loc – a real vector.

  • scale_tril – a lower triangular matrix with positive diagonal.

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import LowerCholeskyAffine
>>> base = jnp.ones(2)
>>> loc = jnp.zeros(2)
>>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
>>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
>>> affine(base)
Array([0.3, 1.5], dtype=float32)
domain = RealVector(Real(), 1)
codomain = RealVector(Real(), 1)
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

LowerCholeskyTransform

class LowerCholeskyTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

Transform a real vector to a lower triangular Cholesky factor, where the strictly lower triangular submatrix is unconstrained and the diagonal is parameterized with an exponential transform.

domain = RealVector(Real(), 1)
codomain = LowerCholesky()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

OrderedTransform

class OrderedTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

Transform a real vector to an ordered vector.

References:

  1. Stan Reference Manual v2.20, section 10.6, Stan Development Team

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import OrderedTransform
>>> base = jnp.ones(3)
>>> transform = OrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([1., 3.7182817, 6.4365635]), rtol=1e-3, atol=1e-3)
domain = RealVector(Real(), 1)
codomain = OrderedVector()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

PackRealFastFourierCoefficientsTransform

class PackRealFastFourierCoefficientsTransform(transform_shape: tuple[int, ...] | None = None)[source]

Bases: Transform[ndarray | Array]

Transform a real vector to complex coefficients of a real fast Fourier transform.

Parameters:

transform_shape – Shape of the real vector, defaults to the input size.

domain = RealVector(Real(), 1)
codomain = IndependentConstraint(Complex(), 1)
shape: tuple[int, ...] | None
tree_flatten()[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) Array[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

PermuteTransform

class PermuteTransform(permutation: Array)[source]

Bases: Transform[ndarray | Array]

domain = RealVector(Real(), 1)
codomain = RealVector(Real(), 1)
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

PowerTransform

class PowerTransform(exponent: ndarray | Array | number | int | float | complex)[source]

Bases: Transform[ndarray | Array | number | int | float | complex]

domain = Positive(lower_bound=0.0)
codomain = Positive(lower_bound=0.0)
log_abs_det_jacobian(x: ndarray | Array | number | int | float | complex, y: ndarray | Array | number | int | float | complex, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
property sign: ndarray | Array | number | int | float | complex

Sign of the derivative of the transform if it is bijective.

RealFastFourierTransform

class RealFastFourierTransform(transform_shape: tuple[int, ...] | None = None, transform_ndims: int = 1)[source]

Bases: Transform[ndarray | Array]

N-dimensional discrete fast Fourier transform for real input.

Parameters:
  • transform_shape – Length of each transformed axis to use from the input, defaults to the input size.

  • transform_ndims – Number of trailing dimensions to transform.

forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
tree_flatten()[source]
property domain: Constraint
property codomain: Constraint
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

RecursiveLinearTransform

class RecursiveLinearTransform(transition_matrix: ndarray | Array, initial_value: ndarray | Array | None = None)[source]

Bases: Transform[ndarray | Array]

Apply a linear transformation recursively such that \(y_t = A y_{t - 1} + x_t\) for \(t > 0\), where \(x_t\) and \(y_t\) are vectors and \(A\) is a square transition matrix. The series is initialized by \(y_0 = 0\).

Parameters:

transition_matrix – Square transition matrix \(A\) for successive states or a batch of transition matrices.

Example:

>>> from jax import random
>>> from jax import numpy as jnp
>>> import numpyro
>>> from numpyro import distributions as dist
>>>
>>> def cauchy_random_walk():
...     return numpyro.sample(
...         "x",
...         dist.TransformedDistribution(
...             dist.Cauchy(0, 1).expand([10, 1]).to_event(1),
...             dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
...         ),
...     )
>>>
>>> numpyro.handlers.seed(cauchy_random_walk, 0)().shape
(10, 1)
>>>
>>> def rocket_trajectory():
...     scale = numpyro.sample(
...         "scale",
...         dist.HalfCauchy(1).expand([2]).to_event(1),
...     )
...     transition_matrix = jnp.array([[1, 1], [0, 1]])
...     return numpyro.sample(
...         "x",
...         dist.TransformedDistribution(
...             dist.Normal(0, scale).expand([10, 2]).to_event(1),
...             dist.transforms.RecursiveLinearTransform(transition_matrix),
...         ),
...     )
>>>
>>> numpyro.handlers.seed(rocket_trajectory, 0)().shape
(10, 2)
domain = RealMatrix(Real(), 2)
codomain = RealMatrix(Real(), 2)
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

ScaledUnitLowerCholeskyTransform

class ScaledUnitLowerCholeskyTransform[source]

Bases: LowerCholeskyTransform

Like LowerCholeskyTransform, this Transform maps a real vector to a lower triangular Cholesky factor. However, it does so via the decomposition

\(y = loc + unit\_scale\_tril\ @\ scale\_diag\ @\ x\),

where \(unit\_scale\_tril\) has ones along the diagonal and \(scale\_diag\) is a diagonal matrix with all positive entries that is parameterized with a softplus transform.

domain = RealVector(Real(), 1)
codomain = ScaledUnitLowerCholesky()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

SigmoidTransform

class SigmoidTransform[source]

Bases: ParameterFreeTransform[ndarray | Array | number | int | float | complex]

codomain = UnitInterval(lower_bound=0.0, upper_bound=1.0)
sign = 1
log_abs_det_jacobian(x: ndarray | Array | number | int | float | complex, y: ndarray | Array | number | int | float | complex, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

SimplexToOrderedTransform

class SimplexToOrderedTransform(anchor_point: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 0.0)[source]

Bases: Transform[ndarray | Array]

Transform a simplex into an ordered vector (via differences in the logistic CDF between cutpoints). Used in [1] to induce a prior on latent cutpoints by transforming ordered category probabilities.

Parameters:

anchor_point – Anchor point is a nuisance parameter to improve the identifiability of the transform. For simplicity, we assume it is a scalar value, but it is broadcastable to x.shape[:-1]. For more details, please refer to Section 2.2 in [1].

References:

  1. Ordinal Regression Case Study, section 2.2, M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html

Example

>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import SimplexToOrderedTransform
>>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
>>> transform = SimplexToOrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
domain = Simplex()
codomain = OrderedVector()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

SoftplusLowerCholeskyTransform

class SoftplusLowerCholeskyTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

Transform from an unconstrained vector to lower-triangular matrices with positive diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

domain = RealVector(Real(), 1)
codomain = SoftplusLowerCholesky()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

SoftplusTransform

class SoftplusTransform[source]

Bases: ParameterFreeTransform[ndarray | Array | number | int | float | complex]

Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).

domain = Real()
codomain = SoftplusPositive(lower_bound=0.0)
sign = 1
log_abs_det_jacobian(x: ndarray | Array | number | int | float | complex, y: ndarray | Array | number | int | float | complex, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]

StickBreakingTransform

class StickBreakingTransform[source]

Bases: ParameterFreeTransform[ndarray | Array]

domain = RealVector(Real(), 1)
codomain = Simplex()
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) ndarray | Array | number | int | float | complex[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

ZeroSumTransform

class ZeroSumTransform(transform_ndims: int = 1)[source]

Bases: Transform[ndarray | Array]

A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2, 3].

Parameters:

transform_ndims – Number of trailing dimensions to transform.

References:

  1. https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266

  2. https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html

  3. https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/

property domain: Constraint
property codomain: Constraint
extend_axis_rev(array: ndarray | Array, axis: int) ndarray | Array[source]
extend_axis(array: ndarray | Array, axis: int) ndarray | Array[source]
log_abs_det_jacobian(x: ndarray | Array, y: ndarray | Array, intermediates: Any | None = None) Array[source]
forward_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape: tuple[int, ...]) tuple[int, ...][source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

tree_flatten()[source]
eq(other: object, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Flows

InverseAutoregressiveTransform

class InverseAutoregressiveTransform(autoregressive_nn, log_scale_min_clip: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = -5.0, log_scale_max_clip: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray = 3.0)[source]

Bases: Transform

An implementation of Inverse Autoregressive Flow, using Eq. (10) from Kingma et al., 2016:

\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\)

where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).

References

  1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling

Parameters:

autoregressive_nn – an autoregressive neural network whose forward call returns a real-valued mean and log scale as a tuple

domain = RealVector(Real(), 1)
codomain = RealVector(Real(), 1)
call_with_intermediates(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
log_abs_det_jacobian(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, y: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, intermediates=None) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Calculates the log absolute determinant of the Jacobian of the transform.

Parameters:
  • x – the input to the transform.

  • y – the output of the transform.

tree_flatten()[source]
eq(other: Transform, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

BlockNeuralAutoregressiveTransform

class BlockNeuralAutoregressiveTransform(bn_arn)[source]

Bases: Transform

An implementation of Block Neural Autoregressive flow.

References

  1. Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz

domain = RealVector(Real(), 1)
codomain = RealVector(Real(), 1)
call_with_intermediates(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]
log_abs_det_jacobian(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, y: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, intermediates=None) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Calculates the log absolute determinant of the Jacobian of the transform.

Parameters:
  • x – the input to the transform.

  • y – the output of the transform.

tree_flatten()[source]
eq(other: Transform, static: bool = False) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Utilities

log1mexp

log1mexp(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Numerically stable calculation of the quantity \(\log(1 - \exp(x))\), following the algorithm of Mächler 2012.

Returns -jnp.inf when x == 0 and jnp.nan when x > 0.

Parameters:

x – A number or array of numbers.

Returns:

The value of \(\log(1 - \exp(x))\).
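The sketch below shows the two-branch scheme from Mächler 2012 in plain JAX. It is an illustration, not NumPyro's source; log1mexp_sketch is a hypothetical name. Near zero, \(\log(-\mathrm{expm1}(x))\) avoids the cancellation in \(1 - \exp(x)\). Below \(-\log 2\), \(\mathrm{log1p}(-\exp(x))\) is the more accurate form. In float32, a naive \(\log(1 - \exp(-10^{-10}))\) rounds to \(\log(0) = -\infty\), while the sketch returns roughly \(\log(10^{-10}) \approx -23.03\).

```python
import jax.numpy as jnp


def log1mexp_sketch(x):
    """Hypothetical re-implementation of log(1 - exp(x)) for x <= 0."""
    x = jnp.asarray(x, dtype=jnp.float32)
    # Mächler (2012): switch branches at x = -log(2).
    return jnp.where(
        x > -jnp.log(2.0),
        jnp.log(-jnp.expm1(x)),    # accurate for x close to 0
        jnp.log1p(-jnp.exp(x)),    # accurate for large negative x
    )
```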

logdiffexp

logdiffexp(a: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, b: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray[source]

Numerically stable calculation of the quantity \(\log(\exp(a) - \exp(b))\), provided \(+\infty > a \ge b\), following the algorithm of Mächler 2012.

Returns -jnp.inf when a == b, including when a == b == -jnp.inf, since this corresponds to jnp.log(0). Returns jnp.nan when a < b or a == jnp.inf.

Parameters:
  • a – A number or array of numbers.

  • b – A number or array of numbers.

Returns:

The value of \(\log(\exp(a) - \exp(b))\).
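One way to obtain this quantity is the identity \(\log(\exp(a) - \exp(b)) = a + \log(1 - \exp(b - a))\), which reduces it to log1mexp. The sketch below is plain JAX with hypothetical names, not NumPyro's source. It follows that identity and then applies the edge cases listed above. Factoring out \(a\) keeps large inputs finite: for \(a = 1000\), \(b = 999\), a naive evaluation overflows in \(\exp\).

```python
import jax.numpy as jnp


def log1mexp_sketch(x):
    """Hypothetical log(1 - exp(x)) using the Mächler (2012) branch switch."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return jnp.where(x > -jnp.log(2.0), jnp.log(-jnp.expm1(x)), jnp.log1p(-jnp.exp(x)))


def logdiffexp_sketch(a, b):
    """Hypothetical log(exp(a) - exp(b)) for +inf > a >= b."""
    a = jnp.asarray(a, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    out = a + log1mexp_sketch(b - a)                 # b - a <= 0 on the valid domain
    out = jnp.where(a == b, -jnp.inf, out)           # log(0), also for a == b == -inf
    return jnp.where((a < b) | (a == jnp.inf), jnp.nan, out)
```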