Base Distribution

Distribution

class Distribution(batch_shape=(), event_shape=(), validate_args=None)[source]

Bases: object

Base class for probability distributions in NumPyro. The design largely follows from torch.distributions.

Parameters:
  • batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
  • event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
  • validate_args – Whether to enable validation of distribution parameters and of arguments to the .log_prob method.

As an example:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)
arg_constraints = {}
support = None
has_enumerate_support = False
is_discrete = False
reparametrized_params = []
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
static set_default_validate_args(value)[source]
batch_shape

Returns the shape over which the distribution parameters are batched.

Returns: batch shape of the distribution.
Return type: tuple
event_shape

Returns the shape of a single sample from the distribution without batching.

Returns: event shape of the distribution.
Return type: tuple
event_dim
Returns: Number of dimensions of individual events.
Return type: int
has_rsample
rsample(key, sample_shape=())[source]
shape(sample_shape=())[source]

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters: sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.
Returns: shape of samples.
Return type: tuple
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample_with_intermediates(key, sample_shape=())[source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

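Parameters: value – A batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its rightmost event dimensions removed (broadcast against batch_shape).
Return type: numpy.ndarray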
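As an illustration of the shape rule above, here is a plain-Python sketch that works on shape tuples only. The helpers broadcast_shapes, sample_shape_of and log_prob_shape are hypothetical and do not call NumPyro; they assume the value being scored has the layout sample_shape + batch_shape + event_shape.

```python
def broadcast_shapes(*shapes):
    # NumPy-style broadcasting on shape tuples: align dims from the right,
    # and let size-1 dims stretch to match.
    ndim = max((len(s) for s in shapes), default=0)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


def sample_shape_of(sample_shape, batch_shape, event_shape):
    # Samples are laid out as sample_shape + batch_shape + event_shape.
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


def log_prob_shape(value_shape, batch_shape, event_shape):
    # Drop the rightmost len(event_shape) dims: log_prob collapses them.
    value_shape = tuple(value_shape)
    leading = value_shape[:len(value_shape) - len(event_shape)]
    # What remains broadcasts against the distribution's batch_shape.
    return broadcast_shapes(leading, tuple(batch_shape))


# Dirichlet-like case: batch_shape=(2, 3), event_shape=(4,)
value_shape = sample_shape_of((5,), (2, 3), (4,))   # (5, 2, 3, 4)
print(log_prob_shape(value_shape, (2, 3), (4,)))     # (5, 2, 3)
# Scalar events: nothing is collapsed
print(log_prob_shape((10, 3), (3,), ()))             # (10, 3)
```

Note that a naive slice such as value.shape[:-0] returns (), which is why the rule is written in terms of value.ndim - len(event_shape) when events are scalars.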
mean

Mean of the distribution.

variance

Variance of the distribution.

to_event(reinterpreted_batch_ndims=None)[source]

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters: reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims. If None (the default), all batch dims are reinterpreted.
Returns: An instance of Independent distribution.
Return type: numpyro.distributions.distribution.Independent
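The effect of to_event (and of Independent) on shapes can be sketched as moving the batch/event boundary to the left while the total shape stays the same. The helper below is hypothetical, works on shape tuples only, and assumes that None means "reinterpret every batch dim".

```python
def to_event_shapes(batch_shape, event_shape, reinterpreted_batch_ndims=None):
    # Hypothetical helper: shapes after reinterpreting the rightmost
    # batch dims as event dims. None is assumed to mean "all batch dims".
    batch_shape, event_shape = tuple(batch_shape), tuple(event_shape)
    n = len(batch_shape) if reinterpreted_batch_ndims is None else reinterpreted_batch_ndims
    if not 0 <= n <= len(batch_shape):
        raise ValueError("cannot reinterpret more dims than batch_shape has")
    split = len(batch_shape) - n
    # batch_shape + event_shape is unchanged; only the boundary moves left by n.
    return batch_shape[:split], batch_shape[split:] + event_shape


print(to_event_shapes((3,), (), 1))        # ((), (3,))
print(to_event_shapes((2, 3), (4,), 1))    # ((2,), (3, 4))
```

For the Independent example further down this page, to_event_shapes((3,), (), 1) gives ((), (3,)), matching diag_normal.batch_shape and diag_normal.event_shape.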
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

expand(batch_shape)[source]

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters: batch_shape (tuple) – batch shape to expand to.
Returns: an instance of ExpandedDistribution.
Return type: ExpandedDistribution
expand_by(sample_shape)[source]

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters: sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.
Returns: An expanded version of this distribution.
Return type: ExpandedDistribution
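A shape-only sketch of the difference between expand() and expand_by(), assuming NumPy-style right-aligned broadcasting for expand(). Both helpers are hypothetical and operate on tuples rather than distribution objects.

```python
def expand_batch_shape(batch_shape, new_batch_shape):
    # Hypothetical sketch of expand(): align from the right; existing dims of
    # size 1 may grow, and new leading dims may be added.
    batch_shape, new_batch_shape = tuple(batch_shape), tuple(new_batch_shape)
    if len(new_batch_shape) < len(batch_shape):
        raise ValueError("cannot drop batch dims")
    offset = len(new_batch_shape) - len(batch_shape)
    for old, new in zip(batch_shape, new_batch_shape[offset:]):
        if old != 1 and old != new:
            raise ValueError(f"cannot expand {batch_shape} to {new_batch_shape}")
    return new_batch_shape


def expand_by_batch_shape(batch_shape, sample_shape):
    # Hypothetical sketch of expand_by(): prepend sample_shape on the left.
    return tuple(sample_shape) + tuple(batch_shape)


print(expand_batch_shape((3, 1), (2, 3, 4)))   # (2, 3, 4)
print(expand_by_batch_shape((3,), (5, 2)))     # (5, 2, 3)
```

Only expand() can enlarge an internal size-1 dim; expand_by() never touches the existing batch dims.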
mask(mask)[source]

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.

Parameters: mask (bool or jnp.ndarray) – A boolean or boolean-valued array (True includes a site, False excludes a site).
Returns: A masked copy of this distribution.
Return type: MaskedDistribution

Example:

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO

>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)


>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))


>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
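The effect of mask on log_prob can be sketched with plain Python lists: entries where the mask is True keep their log-probability, and entries where it is False contribute zero. The helper names are hypothetical, and the Bernoulli log-probability is written out by hand rather than taken from NumPyro.

```python
import math


def bernoulli_log_prob(value, probs):
    # log p(value) for a Bernoulli with success probability probs.
    return math.log(probs) if value == 1 else math.log1p(-probs)


def masked_log_prob(values, probs, mask):
    # Masked-out sites contribute 0 instead of their log-probability,
    # so they drop out of any sum over the plate.
    return [bernoulli_log_prob(v, probs) if m else 0.0 for v, m in zip(values, mask)]


data = [1] * 5 + [0] * 5
mask = [v == 1 for v in data]          # select values equal to one
terms = masked_log_prob(data, 0.9, mask)
total = sum(terms)                     # equals 5 * log(0.9): only the ones count
```

Because only the ones are observed, the masked likelihood increases with the success probability, which is why the inferred mean in the example above moves toward 1.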
classmethod infer_shapes(*args, **kwargs)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes that the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
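For a distribution whose parameters and events are all scalars (event_shape == ()), such as a location-scale family, the inference described above amounts to broadcasting the parameter shapes. The sketch below assumes that case; infer_scalar_event_shapes is a hypothetical helper, not the library classmethod.

```python
def broadcast_shapes(*shapes):
    # NumPy-style broadcasting on shape tuples: align dims from the right,
    # and let size-1 dims stretch to match.
    ndim = max((len(s) for s in shapes), default=0)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


def infer_scalar_event_shapes(*arg_shapes, **kwarg_shapes):
    # Each argument is replaced by its shape tuple, as in infer_shapes().
    batch_shape = broadcast_shapes(*arg_shapes, *kwarg_shapes.values())
    event_shape = ()  # scalar events: every dim is a batch dim
    return batch_shape, event_shape


print(infer_scalar_event_shapes((3, 1), (4,)))        # ((3, 4), ())
print(infer_scalar_event_shapes(loc=(), scale=(2,)))  # ((2,), ())
```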

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters: value – samples from this distribution.
Returns: output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters: q – quantile values, which should lie in [0, 1].
Returns: the values whose CDF equals q.

ExpandedDistribution

class ExpandedDistribution(base_dist, batch_shape=())[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {}
has_enumerate_support

is_discrete

has_rsample
rsample(key, sample_shape=())[source]
support
sample_with_intermediates(key, sample_shape=())[source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

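Parameters: value – A batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its rightmost event dimensions removed (broadcast against batch_shape).
Return type: numpy.ndarray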
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

ImproperUniform

class ImproperUniform(support, batch_shape, event_shape, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

A helper distribution with zero log_prob() over the support domain.

Note

The sample method is not implemented for this distribution. In autoguides and MCMC, initial parameters for improper sites are derived from the init_to_uniform or init_to_value strategies.

Usage:

>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))

To set an improper prior over all values greater than a, where a is another random variable, you might use:

>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))

or if you want to reparameterize it

>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))
Parameters:
  • support (Constraint) – the support of this distribution.
  • batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=().
  • event_shape (tuple) – event shape of this distribution.
arg_constraints = {}
support = <numpyro.distributions.constraints._Dependent object>
log_prob(*args, **kwargs)
tree_flatten()[source]

Independent

class Independent(base_dist, reinterpreted_batch_ndims, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.

From a practical standpoint, this is useful when changing the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
Parameters:
  • base_dist (numpyro.distributions.Distribution) – a distribution instance.
  • reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.
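Reinterpreting a batch dim as an event dim means log_prob sums the base log-densities over that dim. The plain-Python sketch below checks the diagonal-Normal claim above for a length-3 vector: the sum of three univariate Normal log-densities equals the log-density of a multivariate Normal with diagonal covariance. Helper names are hypothetical, and the densities are written out by hand rather than taken from NumPyro.

```python
import math


def normal_log_prob(x, loc, scale):
    # Univariate Normal log-density.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def diag_mvn_log_prob(xs, locs, scales):
    # Multivariate Normal with covariance diag(scales**2):
    # -0.5 * (d * log(2*pi) + log det(Sigma) + squared Mahalanobis distance).
    d = len(xs)
    log_det = sum(2 * math.log(s) for s in scales)
    maha = sum(((x - m) / s) ** 2 for x, m, s in zip(xs, locs, scales))
    return -0.5 * (d * math.log(2 * math.pi) + log_det + maha)


xs, locs, scales = [0.3, -1.2, 2.0], [0.0, 0.0, 0.0], [1.0, 0.5, 2.0]
# Batch view (batch_shape=(3,)): one log-density per element.
per_element = [normal_log_prob(x, m, s) for x, m, s in zip(xs, locs, scales)]
# Event view (event_shape=(3,)): one log-density, summed over the event dim.
joint = sum(per_element)
```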
arg_constraints = {}
support
has_enumerate_support

is_discrete

reparameterized_params
mean

Mean of the distribution.

variance

Variance of the distribution.

has_rsample
rsample(key, sample_shape=())[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

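Parameters: value – A batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its rightmost event dimensions removed (broadcast against batch_shape).
Return type: numpy.ndarray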
expand(batch_shape)[source]

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters: batch_shape (tuple) – batch shape to expand to.
Returns: an instance of ExpandedDistribution.
Return type: ExpandedDistribution
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

MaskedDistribution

class MaskedDistribution(base_dist, mask)[source]

Bases: numpyro.distributions.distribution.Distribution

Masks a distribution by a boolean array that is broadcastable to the distribution's Distribution.batch_shape. In the special case where mask is False, computation of log_prob() is skipped and constant zero values are returned instead.

Parameters: mask (jnp.ndarray or bool) – A boolean or boolean-valued array.
arg_constraints = {}
has_enumerate_support

is_discrete

has_rsample
rsample(key, sample_shape=())[source]
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

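Parameters: value – A batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its rightmost event dimensions removed (broadcast against batch_shape).
Return type: numpy.ndarray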
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

TransformedDistribution

class TransformedDistribution(base_distribution, transforms, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.

Parameters:
  • base_distribution – the base distribution over which to apply transforms.
  • transforms – a single transform or a list of transforms.
  • validate_args – Whether to enable validation of distribution parameters and of arguments to the .log_prob method.
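For an invertible transform y = T(x), the log-density of the transformed distribution follows the change-of-variables formula log p_Y(y) = log p_X(x) - log|dy/dx| with x = T^{-1}(y). The plain-Python sketch below applies it to a Normal pushed through exp (which gives a LogNormal) and checks the result against the closed-form LogNormal density. It is an illustration of the formula, not NumPyro's implementation, and the helper names are hypothetical.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def exp_transformed_log_prob(y, loc=0.0, scale=1.0):
    # y = exp(x)  =>  x = log(y), and log|dy/dx| = log(exp(x)) = x.
    x = math.log(y)
    log_abs_det_jacobian = x
    return normal_log_prob(x, loc, scale) - log_abs_det_jacobian


def lognormal_log_prob(y, loc=0.0, scale=1.0):
    # Closed-form LogNormal log-density, for comparison.
    return (-((math.log(y) - loc) ** 2) / (2 * scale ** 2)
            - math.log(y * scale * math.sqrt(2 * math.pi)))


for y in (0.1, 1.0, 3.5):
    print(y, exp_transformed_log_prob(y, 0.2, 0.7), lognormal_log_prob(y, 0.2, 0.7))
```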
arg_constraints = {}
has_rsample
rsample(key, sample_shape=())[source]
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

sample_with_intermediates(key, sample_shape=())[source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]

Delta

class Delta(v=0.0, log_density=0.0, event_dim=0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'log_density': <numpyro.distributions.constraints._Real object>, 'v': <numpyro.distributions.constraints._Dependent object>}
reparameterized_params = ['v', 'log_density']
is_discrete = True
support
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

Unit

class Unit(log_factor, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Trivial nonnormalized distribution representing the unit type.

The unit type has a single value with no data, i.e. value.size == 0.

This is used for numpyro.factor() statements.

arg_constraints = {'log_factor': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Real object>
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

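Parameters: value – A batch of samples from the distribution.
Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its rightmost event dimensions removed (broadcast against batch_shape).
Return type: numpy.ndarray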

Continuous Distributions

Beta

class Beta(concentration1, concentration0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['concentration1', 'concentration0']
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters: value – samples from this distribution.
Returns: output of the cumulative distribution function evaluated at value.

Cauchy

class Cauchy(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters: value – samples from this distribution.
Returns: output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters: q – quantile values, which should lie in [0, 1].
Returns: the values whose CDF equals q.
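For the Cauchy distribution both functions have closed forms. The sketch below writes them in plain Python using the standard formulas (it is not NumPyro's code, and the function names are hypothetical); icdf inverts cdf on the open interval (0, 1).

```python
import math


def cauchy_cdf(x, loc=0.0, scale=1.0):
    # F(x) = 1/2 + arctan((x - loc) / scale) / pi
    return 0.5 + math.atan((x - loc) / scale) / math.pi


def cauchy_icdf(q, loc=0.0, scale=1.0):
    # F^{-1}(q) = loc + scale * tan(pi * (q - 1/2)), for 0 < q < 1
    return loc + scale * math.tan(math.pi * (q - 0.5))


# The median is loc, and the upper quartile is loc + scale.
print(cauchy_cdf(2.0, loc=2.0, scale=3.0))   # 0.5
print(cauchy_icdf(0.75, loc=1.0, scale=2.0))
```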

Chi2

class Chi2(df, validate_args=None)[source]

Bases: numpyro.distributions.continuous.Gamma

arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['df']

Dirichlet

class Dirichlet(concentration, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>}
reparametrized_params = ['concentration']
support = <numpyro.distributions.constraints._Simplex object>
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

static infer_shapes(concentration)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes that the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
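For Dirichlet, the inferred shapes follow a simple split: the last dim of concentration is the event (simplex) dim and all leading dims are batch dims. A hypothetical plain-Python sketch, consistent with the example at the top of this page where jnp.ones((2, 3, 4)) gives batch_shape (2, 3) and event_shape (4,):

```python
def dirichlet_infer_shapes(concentration_shape):
    # Hypothetical sketch: the last dim indexes the simplex (event),
    # every leading dim is an independent batch dim.
    concentration_shape = tuple(concentration_shape)
    if len(concentration_shape) == 0:
        raise ValueError("concentration must have at least one dim")
    return concentration_shape[:-1], concentration_shape[-1:]


print(dirichlet_infer_shapes((2, 3, 4)))   # ((2, 3), (4,))
```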

Exponential

class Exponential(rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['rate']
arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

Gamma

class Gamma(concentration, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
reparametrized_params = ['concentration', 'rate']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

Gumbel

class Gumbel(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

GaussianRandomWalk

class GaussianRandomWalk(scale=1.0, num_steps=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IndependentConstraint object>
reparametrized_params = ['scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

HalfCauchy

class HalfCauchy(scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['scale']
support = <numpyro.distributions.constraints._GreaterThan object>
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

HalfNormal

class HalfNormal(scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['scale']
support = <numpyro.distributions.constraints._GreaterThan object>
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key to be used for sampling from the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

InverseGamma

class InverseGamma(concentration, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

Note

We keep the same notation rate as in Pyro, but it plays the role of the scale parameter of InverseGamma in the literature (e.g. Wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution).
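Concretely, if X follows a Gamma distribution with the given concentration and rate, then 1/X follows an inverse-gamma distribution whose textbook scale parameter equals that rate, with mean rate / (concentration - 1) for concentration > 1. The Monte Carlo sketch below uses Python's random module rather than NumPyro; note that random.gammavariate takes a scale, i.e. 1 / rate. The sample size and seed are arbitrary choices.

```python
import random


def sample_inverse_gamma(concentration, rate, n, seed=0):
    rng = random.Random(seed)
    # random.gammavariate(alpha, beta) uses beta as a *scale*, so pass 1 / rate,
    # then invert each Gamma draw.
    return [1.0 / rng.gammavariate(concentration, 1.0 / rate) for _ in range(n)]


concentration, rate = 5.0, 2.0
draws = sample_inverse_gamma(concentration, rate, n=200_000)
empirical_mean = sum(draws) / len(draws)
analytic_mean = rate / (concentration - 1)   # 0.5: "rate" acts as the scale
```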

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['concentration', 'rate']
support = <numpyro.distributions.constraints._GreaterThan object>
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]

Laplace

class Laplace(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
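Both Laplace methods have closed forms. With \(\mu\) = loc and \(b\) = scale, the cdf is \(F(x) = \tfrac12 e^{(x-\mu)/b}\) for \(x < \mu\) and \(1 - \tfrac12 e^{-(x-\mu)/b}\) otherwise, and icdf inverts each branch. Below is a scalar, stdlib-only sketch with hypothetical helper names. It is not NumPyro's array implementation, and it checks that icdf(cdf(x)) recovers x.

```python
import math

def laplace_cdf(x, loc=0.0, scale=1.0):
    z = (x - loc) / scale
    if z < 0:
        return 0.5 * math.exp(z)
    return 1.0 - 0.5 * math.exp(-z)

def laplace_icdf(q, loc=0.0, scale=1.0):
    # Valid for q in the open interval (0, 1); each branch inverts one branch of the cdf.
    if q < 0.5:
        return loc + scale * math.log(2.0 * q)
    return loc - scale * math.log(2.0 * (1.0 - q))
```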

LKJ

class LKJ(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Consequently, when concentration == 1, the distribution is uniform over correlation matrices.

When concentration > 1, the distribution favors samples with a large determinant. This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with a small determinant. This is useful when we know a priori that some underlying variables are correlated.
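For a 2x2 correlation matrix with off-diagonal entry \(r\), \(\det(M) = 1 - r^2\), so the unnormalized log-density is \((\eta - 1)\log(1 - r^2)\). The stdlib sketch below uses a hypothetical helper and omits the normalizing constant, so it is no substitute for LKJ.log_prob. It shows how the sign of \(\eta - 1\) decides whether uncorrelated or strongly correlated matrices get more mass.

```python
import math

def lkj_unnormalized_log_density(r, concentration):
    """(eta - 1) * log det(M) for the 2x2 correlation matrix [[1, r], [r, 1]]."""
    det = 1.0 - r * r  # determinant of [[1, r], [r, 1]]
    return (concentration - 1.0) * math.log(det)

for eta in (0.5, 1.0, 2.0):
    weak = lkj_unnormalized_log_density(0.0, eta)    # uncorrelated
    strong = lkj_unnormalized_log_density(0.9, eta)  # strongly correlated
    print(f"eta={eta}: log p(r=0) - log p(r=0.9) = {weak - strong:+.3f}")
```

The printed difference is negative for eta = 0.5, where correlation is favored. It is zero for eta = 1 and positive for eta = 2.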

Parameters:
  • dimension (int) – dimension of the matrices
  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['concentration']
support = <numpyro.distributions.constraints._CorrMatrix object>
mean

Mean of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

LKJCholesky

class LKJCholesky(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Consequently, when concentration == 1, the distribution is uniform over Cholesky factors of correlation matrices.

When concentration > 1, the distribution favors samples with large diagonal entries (hence a large determinant). This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small diagonal entries (hence a small determinant). This is useful when we know a priori that some underlying variables are correlated.
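The link between diagonal entries and the determinant is \(\det(LL^\top) = \prod_i L_{ii}^2\). Each row of a correlation Cholesky factor has unit norm, because the diagonal of \(M\) is all ones, so a large diagonal entry forces small off-diagonal entries in that row. Below is a 2x2 stdlib sketch using hypothetical helpers. It is not NumPyro code.

```python
import math

def corr_cholesky_2x2(r):
    """Lower Cholesky factor of [[1, r], [r, 1]] for |r| < 1."""
    return [[1.0, 0.0], [r, math.sqrt(1.0 - r * r)]]

def matmul_lower_transpose(L):
    """Compute L @ L.T for a small dense matrix given as nested lists."""
    n = len(L)
    return [[sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)] for i in range(n)]

def det_from_cholesky(L):
    # det(L L^T) = det(L)^2 = product of squared diagonal entries
    return math.prod(L[i][i] ** 2 for i in range(len(L)))

L = corr_cholesky_2x2(0.6)
M = matmul_lower_transpose(L)  # recovers the correlation matrix [[1, 0.6], [0.6, 1]]
```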

Parameters:
  • dimension (int) – dimension of the matrices
  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['concentration']
support = <numpyro.distributions.constraints._CorrCholesky object>
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

LogNormal

class LogNormal(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
reparametrized_params = ['loc', 'scale']
mean

Mean of the distribution.

variance

Variance of the distribution.
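LogNormal(loc, scale) is the distribution of \(e^Z\) for \(Z \sim\) Normal(loc, scale), which fits its TransformedDistribution base. Its mean is \(e^{\mu + \sigma^2/2}\) and its variance is \((e^{\sigma^2} - 1)e^{2\mu + \sigma^2}\). Below is a stdlib sketch with hypothetical helper names that checks both by simulation.

```python
import math
import random

def lognormal_mean(loc=0.0, scale=1.0):
    return math.exp(loc + 0.5 * scale ** 2)

def lognormal_variance(loc=0.0, scale=1.0):
    return (math.exp(scale ** 2) - 1.0) * math.exp(2.0 * loc + scale ** 2)

rng = random.Random(0)
loc, scale = 0.5, 0.4
# If Z ~ Normal(loc, scale) then exp(Z) ~ LogNormal(loc, scale).
draws = [math.exp(rng.gauss(loc, scale)) for _ in range(100_000)]
mc_mean = sum(draws) / len(draws)
mc_var = sum((d - mc_mean) ** 2 for d in draws) / (len(draws) - 1)
```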

tree_flatten()[source]

Logistic

class Logistic(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
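The Logistic cdf is the sigmoid of the standardized value, \(F(x) = 1/(1 + e^{-(x-\mu)/s})\). Its icdf is the corresponding rescaled logit, \(\mu + s\log(q/(1-q))\). Below is a scalar stdlib sketch with hypothetical helpers. It has no overflow guards for extreme inputs.

```python
import math

def logistic_cdf(x, loc=0.0, scale=1.0):
    # Sigmoid of the standardized value
    return 1.0 / (1.0 + math.exp(-(x - loc) / scale))

def logistic_icdf(q, loc=0.0, scale=1.0):
    # Rescaled logit of q; valid for q in the open interval (0, 1)
    return loc + scale * math.log(q / (1.0 - q))
```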

MultivariateNormal

class MultivariateNormal(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'covariance_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'loc': <numpyro.distributions.constraints._IndependentConstraint object>, 'precision_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'scale_tril': <numpyro.distributions.constraints._LowerCholesky object>}
support = <numpyro.distributions.constraints._IndependentConstraint object>
reparametrized_params = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
covariance_matrix[source]
precision_matrix[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
static infer_shapes(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
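The three matrix arguments describe the same covariance \(\Sigma\). scale_tril is a lower Cholesky factor \(L\) with \(\Sigma = LL^\top\), and precision_matrix is \(\Sigma^{-1}\). The stdlib sketch below uses hypothetical helpers and plain nested lists instead of arrays. It derives \(L\) from a covariance matrix, then draws samples with the reparametrized form \(x = \mathrm{loc} + Lz\), \(z \sim \mathcal{N}(0, I)\). It illustrates the math rather than NumPyro's code.

```python
import math
import random

def cholesky(a):
    """Lower Cholesky factor of a symmetric positive-definite matrix (nested lists)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def mvn_sample(rng, loc, scale_tril):
    # x = loc + L z with z ~ Normal(0, I) has covariance L L^T
    z = [rng.gauss(0.0, 1.0) for _ in loc]
    return [loc[i] + sum(scale_tril[i][k] * z[k] for k in range(i + 1)) for i in range(len(loc))]

cov = [[4.0, 1.2], [1.2, 1.0]]
L = cholesky(cov)  # [[2.0, 0.0], [0.6, 0.8]]
rng = random.Random(0)
draws = [mvn_sample(rng, [1.0, -2.0], L) for _ in range(100_000)]
```

Because each draw is a differentiable function of loc and \(L\), this form is what makes the parameters reparametrizable.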

LowRankMultivariateNormal

class LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'cov_diag': <numpyro.distributions.constraints._IndependentConstraint object>, 'cov_factor': <numpyro.distributions.constraints._IndependentConstraint object>, 'loc': <numpyro.distributions.constraints._IndependentConstraint object>}
support = <numpyro.distributions.constraints._IndependentConstraint object>
reparametrized_params = ['loc', 'cov_factor', 'cov_diag']
mean

Mean of the distribution.

variance[source]
scale_tril[source]
covariance_matrix[source]
precision_matrix[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
entropy()[source]
static infer_shapes(loc, cov_factor, cov_diag)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
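LowRankMultivariateNormal stores its covariance implicitly as \(\Sigma = WW^\top + \mathrm{diag}(D)\). Here \(W\) = cov_factor has shape (..., n, k) and \(D\) = cov_diag has shape (..., n), which keeps storage and computation cheap when k is much smaller than n. The stdlib sketch below uses a hypothetical helper to materialize \(\Sigma\) for a single unbatched example.

```python
def low_rank_covariance(cov_factor, cov_diag):
    """Sigma = W W^T + diag(D) for W of shape (n, k) and D of length n."""
    n = len(cov_diag)
    k = len(cov_factor[0])
    return [
        [sum(cov_factor[i][r] * cov_factor[j][r] for r in range(k)) + (cov_diag[i] if i == j else 0.0)
         for j in range(n)]
        for i in range(n)
    ]

W = [[1.0], [2.0], [-1.0]]  # event size n = 3, rank k = 1
D = [0.5, 0.5, 0.5]
sigma = low_rank_covariance(W, D)
```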

Normal

class Normal(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
mean

Mean of the distribution.

variance

Variance of the distribution.
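The Normal cdf is \(\Phi((x - \mu)/\sigma) = \tfrac12\left(1 + \mathrm{erf}\left(\frac{x - \mu}{\sigma\sqrt{2}}\right)\right)\), and icdf is its inverse. Python's standard library provides both through statistics.NormalDist, which is a convenient scalar reference when checking values from Normal.cdf and Normal.icdf. This sketch does not use NumPyro.

```python
import math
from statistics import NormalDist

def normal_cdf(x, loc=0.0, scale=1.0):
    # Phi((x - loc) / scale) written with the error function
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))

normal = NormalDist(mu=1.0, sigma=2.0)
points = [-3.0, 0.0, 1.0, 4.5]
quantiles = [normal.cdf(x) for x in points]         # cdf
recovered = [normal.inv_cdf(q) for q in quantiles]  # icdf undoes cdf
```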

Pareto

class Pareto(scale, alpha, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['scale', 'alpha']
mean

Mean of the distribution.

variance

Variance of the distribution.

support
tree_flatten()[source]
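Pareto(scale, alpha) has support \(x \ge\) scale, so its support depends on the parameters. The mean \(\alpha\,\mathrm{scale}/(\alpha - 1)\) is finite only for alpha > 1, and the variance \(\mathrm{scale}^2\alpha/((\alpha-1)^2(\alpha-2))\) is finite only for alpha > 2. Below is a stdlib sketch with hypothetical helpers. Note that random.paretovariate draws a Pareto with scale 1.

```python
import random

def pareto_mean(scale, alpha):
    # Finite only for alpha > 1
    return alpha * scale / (alpha - 1.0) if alpha > 1 else float("inf")

def pareto_variance(scale, alpha):
    # Finite only for alpha > 2
    if alpha <= 2:
        return float("inf")
    return scale ** 2 * alpha / ((alpha - 1.0) ** 2 * (alpha - 2.0))

rng = random.Random(0)
scale, alpha = 2.0, 4.0
# random.paretovariate(alpha) has minimum value 1; multiply to move the support to x >= scale.
draws = [scale * rng.paretovariate(alpha) for _ in range(100_000)]
```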

StudentT

class StudentT(df, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['df', 'loc', 'scale']
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

cdf(value)[source]

The cumulative distribution function of this distribution.

Parameters:value – samples from this distribution.
Returns:output of the cumulative distribution function evaluated at value.
icdf(q)[source]

The inverse cumulative distribution function of this distribution.

Parameters:q – quantile values, should belong to [0, 1].
Returns:the samples whose cdf values equal q.
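A Student-t draw can be built from a standard normal \(Z\) and a chi-square \(V\) with df degrees of freedom as loc + scale · Z / sqrt(V / df). Its variance, scale² · df / (df - 2), is finite only for df > 2. Below is a stdlib sketch with hypothetical helpers. It uses the textbook construction, which is not necessarily how StudentT.sample is implemented.

```python
import math
import random

def student_t_variance(df, scale=1.0):
    # Finite for df > 2, infinite for 1 < df <= 2, undefined (nan) for df <= 1.
    if df > 2:
        return scale ** 2 * df / (df - 2.0)
    return math.inf if df > 1 else math.nan

def student_t_sample(rng, df, loc=0.0, scale=1.0):
    # Normal / sqrt(chi-square / df); chi-square(df) is Gamma(df / 2) with scale 2
    z = rng.gauss(0.0, 1.0)
    v = rng.gammavariate(df / 2.0, 2.0)
    return loc + scale * z / math.sqrt(v / df)

rng = random.Random(0)
draws = [student_t_sample(rng, 10.0, loc=1.0, scale=0.5) for _ in range(100_000)]
mc_mean = sum(draws) / len(draws)
mc_var = sum((d - mc_mean) ** 2 for d in draws) / (len(draws) - 1)
```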

TruncatedCauchy

class TruncatedCauchy(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
support
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]

TruncatedDistribution

class TruncatedDistribution[source]

A function to generate a truncated distribution.

Parameters:
  • base_dist – The base distribution to be truncated. This should be a univariate distribution. Currently, only the following distributions are supported: Cauchy, Laplace, Logistic, Normal, and StudentT.
  • low – the value used to truncate the base distribution from below. Set this parameter to None to disable truncation from below.
  • high – the value used to truncate the base distribution from above. Set this parameter to None to disable truncation from above.
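Truncating a univariate base distribution can be sketched with the inverse-CDF method, which needs only the base distribution's cdf and icdf. The Laplace, Logistic, Normal, and StudentT entries above expose both methods. The helper below is hypothetical and scalar-only, and NumPyro's own sampler may work differently. It draws u uniformly between cdf(low) and cdf(high) and maps it back through icdf.

```python
import random
from statistics import NormalDist

def sample_truncated(rng, cdf, icdf, low=None, high=None):
    """Draw one sample from a univariate distribution restricted to [low, high].

    None means "no truncation on that side". The edge case u == 0 (probability
    about 2**-53 when low is None) is not handled in this sketch.
    """
    lo = cdf(low) if low is not None else 0.0
    hi = cdf(high) if high is not None else 1.0
    u = lo + (hi - lo) * rng.random()
    return icdf(u)

base = NormalDist(mu=0.0, sigma=1.0)
rng = random.Random(0)
draws = [sample_truncated(rng, base.cdf, base.inv_cdf, low=1.0) for _ in range(10_000)]
```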

TruncatedNormal

class TruncatedNormal(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
support
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
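In the signature shown above, TruncatedNormal takes only a lower bound, low. For this one-sided truncation the moments have a standard closed form in terms of the inverse Mills ratio \(\lambda = \phi(a)/(1 - \Phi(a))\), with \(a = (\mathrm{low} - \mu)/\sigma\). Below is a stdlib sketch with hypothetical helpers. It is not NumPyro's mean/variance implementation.

```python
from statistics import NormalDist

def truncated_normal_mean(low=0.0, loc=0.0, scale=1.0):
    """Mean of Normal(loc, scale) restricted to [low, inf)."""
    std = NormalDist()
    a = (low - loc) / scale
    lam = std.pdf(a) / (1.0 - std.cdf(a))  # inverse Mills ratio
    return loc + scale * lam

def truncated_normal_variance(low=0.0, loc=0.0, scale=1.0):
    std = NormalDist()
    a = (low - loc) / scale
    lam = std.pdf(a) / (1.0 - std.cdf(a))
    return scale ** 2 * (1.0 + a * lam - lam ** 2)
```

With low == loc the result reduces to the HalfNormal moments \(\sigma\sqrt{2/\pi}\) and \(\sigma^2(1 - 2/\pi)\), which makes a convenient sanity check.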

TruncatedPolyaGamma

class TruncatedPolyaGamma(batch_shape=(), validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

truncation_point = 2.5
num_log_prob_terms = 7
num_gamma_variates = 8
arg_constraints = {}
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
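The entry above does not describe how sample works. One plausible reading of its constants rests on the standard series representation of a Pólya-Gamma PG(1, 0) variable, \(\omega = \frac{1}{2\pi^2}\sum_{k \ge 1} g_k/(k - \tfrac12)^2\) with \(g_k \sim \mathrm{Exponential}(1)\). Under that assumption, num_gamma_variates would set how many series terms are kept, and truncation_point would clip the result to the interval support. The stdlib sketch below illustrates that assumption only. Its names are hypothetical, and it may not match NumPyro.

```python
import math
import random

TRUNCATION_POINT = 2.5   # mirrors truncation_point above
NUM_GAMMA_VARIATES = 8   # mirrors num_gamma_variates above

def truncated_polya_gamma_sample(rng):
    # Keep the first NUM_GAMMA_VARIATES terms of the PG(1, 0) series, then clip.
    total = sum(rng.expovariate(1.0) / (k - 0.5) ** 2
                for k in range(1, NUM_GAMMA_VARIATES + 1))
    return min(total / (2.0 * math.pi ** 2), TRUNCATION_POINT)

# Mean of the truncated series (the full series has mean 1/4).
series_mean = sum(1.0 / (k - 0.5) ** 2
                  for k in range(1, NUM_GAMMA_VARIATES + 1)) / (2.0 * math.pi ** 2)

rng = random.Random(0)
draws = [truncated_polya_gamma_sample(rng) for _ in range(50_000)]
mc_mean = sum(draws) / len(draws)
```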

Uniform

class Uniform(low=0.0, high=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}
reparametrized_params = ['low', 'high']
support
mean

Mean of the distribution.

variance

Variance of the distribution.

tree_flatten()[source]
classmethod tree_unflatten(aux_data, params)[source]
static infer_shapes(low=(), high=())[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple
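For a scalar distribution such as Uniform, infer_shapes amounts to broadcasting the parameter shapes into a batch shape with an empty event shape. The stdlib sketch below uses hypothetical helpers and assumes NumPy-style broadcasting rules. As in the documented signature, shapes are plain tuples.

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples (right-aligned, size-1 dims stretch)."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(ndim):
        # Collect the size of this (right-aligned) axis from every shape that has it.
        sizes = {s[len(s) - ndim + axis] for s in shapes if len(s) - ndim + axis >= 0}
        non_one = sizes - {1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)

def uniform_infer_shapes(low=(), high=()):
    # Uniform is a scalar distribution: parameters only contribute batch dimensions.
    return broadcast_shapes(low, high), ()
```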

Discrete Distributions

Bernoulli

Bernoulli(probs=None, logits=None, validate_args=None)[source]

BernoulliLogits

class BernoulliLogits(logits=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Boolean object>
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.
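The logits parametrization avoids computing probs explicitly. With probs = sigmoid(logits), the log-mass is value · logits - softplus(logits), which stays finite even for very large logits. Below is a stdlib sketch of that identity with hypothetical helper names. NumPyro's own log_prob may be implemented differently.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_logits_log_prob(value, logits):
    """log p(value) for value in {0, 1} under Bernoulli(sigmoid(logits))."""
    return value * logits - softplus(logits)

def logits_to_probs(logits):
    # Stable sigmoid: avoid exp overflow for large negative logits.
    if logits >= 0:
        return 1.0 / (1.0 + math.exp(-logits))
    e = math.exp(logits)
    return e / (1.0 + e)
```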

BernoulliProbs

class BernoulliProbs(probs, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}
support = <numpyro.distributions.constraints._Boolean object>
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

BetaBinomial

class BetaBinomial(concentration1, concentration0, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and is drawn from a Beta distribution before total_count Bernoulli trials are performed.

Parameters:
  • concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
  • concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
  • total_count (numpy.ndarray) – number of Bernoulli trials.
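Integrating out the Beta-distributed success probability gives a closed-form mass, \(P(K = k) = \binom{n}{k} B(k + \alpha, n - k + \beta) / B(\alpha, \beta)\), with \(\alpha\) = concentration1, \(\beta\) = concentration0, and \(n\) = total_count. Its mean is \(n\alpha/(\alpha + \beta)\). Below is a scalar stdlib sketch with hypothetical helpers. It checks that the mass sums to one and that concentration1 = concentration0 = 1 gives a uniform distribution over 0..n.

```python
import math

def log_beta(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_binomial_pmf(k, concentration1, concentration0, total_count):
    """P(K = k) after integrating the Binomial success probability over Beta(c1, c0)."""
    return math.comb(total_count, k) * math.exp(
        log_beta(k + concentration1, total_count - k + concentration0)
        - log_beta(concentration1, concentration0)
    )

c1, c0, n = 2.0, 3.0, 10
pmf = [beta_binomial_pmf(k, c1, c0, n) for k in range(n + 1)]
mean = sum(k * p for k, p in enumerate(pmf))
```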
arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
has_enumerate_support = True
is_discrete = True
enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support

Binomial

Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

BinomialLogits

class BinomialLogits(logits, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
has_enumerate_support = True
is_discrete = True
enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support

BinomialProbs

class BinomialProbs(probs, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.
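BinomialLogits and BinomialProbs describe the same family, related by probs = sigmoid(logits). Writing the log-mass directly in logits, \(\log\binom{n}{k} + k\,\ell - n\,\mathrm{softplus}(\ell)\), avoids forming probs near 0 or 1. Below is a stdlib sketch with hypothetical helpers that checks the two forms agree.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def binomial_logits_log_prob(k, logits, total_count):
    """log C(n, k) + k * logits - n * softplus(logits), i.e. Binomial(n, sigmoid(logits))."""
    return math.log(math.comb(total_count, k)) + k * logits - total_count * softplus(logits)

def binomial_probs_log_prob(k, probs, total_count):
    return (math.log(math.comb(total_count, k)) + k * math.log(probs)
            + (total_count - k) * math.log1p(-probs))
```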

Categorical

Categorical(probs=None, logits=None, validate_args=None)[source]

CategoricalLogits

class CategoricalLogits(logits, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._IndependentConstraint object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.
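Per its arg_constraints, CategoricalLogits takes a real vector of logits over the last dimension. Assuming unnormalized logits are allowed, the normalized log-probabilities are logits - logsumexp(logits), so adding a constant to every logit leaves the distribution unchanged. Below is a stdlib sketch for one batch element, with hypothetical helpers.

```python
import math

def logsumexp(xs):
    # Subtract the max before exponentiating to avoid overflow.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def categorical_logits_log_prob(value, logits):
    return logits[value] - logsumexp(logits)

def logits_to_probs(logits):
    lse = logsumexp(logits)
    return [math.exp(x - lse) for x in logits]

logits = [2.0, 0.5, -1.0]
probs = logits_to_probs(logits)
shifted = [x + 100.0 for x in logits]  # same distribution as `logits`
```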

CategoricalProbs

class CategoricalProbs(probs, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>}
has_enumerate_support = True
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
enumerate_support(expand=True)[source]

Returns an array with shape len(support) x batch_shape containing all values in the support.

DirichletMultinomial

class DirichletMultinomial(concentration, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a Dirichlet-multinomial pair. The class probabilities (probs for the Multinomial distribution) are unknown and are drawn from a Dirichlet distribution before total_count Categorical trials are performed.

Parameters:
  • concentration (numpy.ndarray) – concentration parameter (alpha) for the Dirichlet distribution.
  • total_count (numpy.ndarray) – number of Categorical trials.
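Marginalizing the Dirichlet-distributed class probabilities gives a closed-form mass over count vectors \(x\) with \(\sum_k x_k = n\), where \(n\) is total_count and \(\alpha_0 = \sum_k \alpha_k\): \(P(x) = \frac{n!\,\Gamma(\alpha_0)}{\Gamma(n + \alpha_0)} \prod_k \frac{\Gamma(x_k + \alpha_k)}{x_k!\,\Gamma(\alpha_k)}\). The stdlib sketch below uses a hypothetical helper and a single unbatched example. It enumerates every outcome for a small total_count and checks that the mass sums to one.

```python
import math
from itertools import product

def dirichlet_multinomial_log_pmf(counts, concentration):
    """log P(counts) for DirichletMultinomial(concentration, total_count=sum(counts))."""
    n = sum(counts)
    alpha0 = sum(concentration)
    result = math.lgamma(n + 1) + math.lgamma(alpha0) - math.lgamma(n + alpha0)
    for x, a in zip(counts, concentration):
        result += math.lgamma(x + a) - math.lgamma(x + 1) - math.lgamma(a)
    return result

concentration = [0.5, 1.0, 2.0]
total_count = 4
# All count vectors of length 3 summing to total_count
outcomes = [c for c in product(range(total_count + 1), repeat=3) if sum(c) == total_count]
total = sum(math.exp(dirichlet_multinomial_log_pmf(c, concentration)) for c in outcomes)
```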
arg_constraints = {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

support
static infer_shapes(concentration, total_count=())[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

GammaPoisson

class GammaPoisson(concentration, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and is randomly drawn from a Gamma distribution.

Parameters:
  • concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
  • rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
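Integrating out the Gamma-distributed rate gives a negative binomial, \(P(K = k) = \frac{\Gamma(k + \alpha)}{k!\,\Gamma(\alpha)} \left(\frac{\beta}{1 + \beta}\right)^{\alpha} \left(\frac{1}{1 + \beta}\right)^{k}\). Its mean is \(\alpha/\beta\) and its variance is \(\alpha/\beta + \alpha/\beta^2\), so it is overdispersed relative to a Poisson with the same mean. Below is a stdlib sketch with a hypothetical helper that checks these moments numerically.

```python
import math

def gamma_poisson_log_pmf(k, concentration, rate):
    """Negative-binomial form of the Gamma-Poisson marginal, log P(K = k)."""
    return (math.lgamma(k + concentration) - math.lgamma(k + 1) - math.lgamma(concentration)
            + concentration * math.log(rate / (1.0 + rate))
            - k * math.log1p(rate))

concentration, rate = 3.0, 0.5
# The tail beyond k = 400 is negligible for these parameters.
pmf = [math.exp(gamma_poisson_log_pmf(k, concentration, rate)) for k in range(400)]
mean = sum(k * p for k, p in enumerate(pmf))
var = sum((k - mean) ** 2 * p for k, p in enumerate(pmf))
```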
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

Geometric

Geometric(probs=None, logits=None, validate_args=None)[source]

GeometricLogits

class GeometricLogits(logits, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
probs[source]
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

GeometricProbs

class GeometricProbs(probs, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

Multinomial

Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

MultinomialLogits

class MultinomialLogits(logits, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._IndependentConstraint object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
probs[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
static infer_shapes(logits, total_count)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

MultinomialProbs

class MultinomialProbs(probs, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
logits[source]
mean

Mean of the distribution.

variance

Variance of the distribution.

support
static infer_shapes(probs, total_count)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

OrderedLogistic

class OrderedLogistic(predictor, cutpoints, validate_args=None)[source]

Bases: numpyro.distributions.discrete.CategoricalProbs

A categorical distribution with ordered outcomes.

References:

  1. Stan Functions Reference, v2.20 section 12.6, Stan Development Team
Parameters:
  • predictor (numpy.ndarray) – prediction in real domain; typically this is output of a linear model.
  • cutpoints (numpy.ndarray) – positions in real domain to separate categories.
arg_constraints = {'cutpoints': <numpyro.distributions.constraints._OrderedVector object>, 'predictor': <numpyro.distributions.constraints._Real object>}
static infer_shapes(predictor, cutpoints)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

Poisson

class Poisson(rate, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean

Mean of the distribution.

variance

Variance of the distribution.

PRNGIdentity

class PRNGIdentity[source]

Bases: numpyro.distributions.distribution.Distribution

Distribution over PRNGKey(). Together with the seed handler, it can be used to draw a batch of PRNG keys. Only the sample method is supported.

is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

ZeroInflatedPoisson

class ZeroInflatedPoisson(gate, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

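A zero-inflated Poisson distribution. With probability gate a draw is an extra zero; otherwise it is drawn from a Poisson distribution with the given rate.

Parameters:
  • gate (numpy.ndarray) – probability of extra zeros.
  • rate (numpy.ndarray) – rate parameter of the Poisson component.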
arg_constraints = {'gate': <numpyro.distributions.constraints._Interval object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
is_discrete = True
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(*args, **kwargs)
mean[source]
variance[source]

Directional Distributions

ProjectedNormal

class ProjectedNormal(concentration, *, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Projected isotropic normal distribution of arbitrary dimension.

This distribution over directional data is qualitatively similar to the von Mises and von Mises-Fisher distributions, but permits tractable variational inference via reparametrized gradients.

To use this distribution with autoguides and HMC, use handlers.reparam with a ProjectedNormalReparam reparametrizer in the model, e.g.:

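import jax.numpy as jnp
import numpyro
from numpyro import handlers
from numpyro.distributions import ProjectedNormal
from numpyro.infer.reparam import ProjectedNormalReparam

@handlers.reparam(config={"direction": ProjectedNormalReparam()})
def model():
    direction = numpyro.sample("direction", ProjectedNormal(jnp.zeros(3)))
    ...  # rest of the model uses `direction`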

Note

This implements log_prob() only for dimensions {2,3}.

[1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017)
“The General Projected Normal Distribution of Arbitrary Dimension: Modeling and Bayesian Inference” https://projecteuclid.org/euclid.ba/1453211962
arg_constraints = {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>}
reparametrized_params = ['concentration']
support = <numpyro.distributions.constraints._Sphere object>
mean

Note: this is the mean in the sense of a centroid on the submanifold, i.e. the point that minimizes the expected squared geodesic distance.

mode
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

numpy.ndarray

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters:value – A batch of samples from the distribution.
Returns:an array with shape value.shape[:-len(self.event_shape)]
Return type:numpy.ndarray
static infer_shapes(concentration)[source]

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
  • **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

VonMises

class VonMises(loc, concentration, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>}
reparametrized_params = ['loc']
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())[source]

Generates samples from the von Mises distribution.

Parameters:
  • key – random number generator key
  • sample_shape – shape of samples
Returns:

samples from the von Mises distribution

log_prob(*args, **kwargs)
mean

Computes the circular mean of the distribution. Note: this is the same as loc when loc is mapped to the support [-pi, pi].

variance

Computes the circular variance of the distribution.

TensorFlow Distributions

Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see its Distribution docs.

BijectorConstraint

class BijectorConstraint(bijector)[source]

A constraint representing the codomain of a TensorFlow bijector.

Parameters:bijector (Bijector) – a TensorFlow bijector

BijectorTransform

class BijectorTransform(bijector)[source]

A wrapper for TensorFlow bijectors to make them compatible with NumPyro’s transforms.

Parameters:bijector (Bijector) – a TensorFlow bijector

TFPDistributionMixin

class TFPDistributionMixin(batch_shape=(), event_shape=(), validate_args=None)[source]

A mixin layer that makes TensorFlow Probability (TFP) distributions compatible with NumPyro internals.

Autoregressive

class Autoregressive(distribution_fn, sample0=None, num_steps=None, validate_args=False, allow_nan_stats=True, name='Autoregressive')

Wraps tensorflow_probability.substrates.jax.distributions.autoregressive.Autoregressive with TFPDistributionMixin.

BatchBroadcast

class BatchBroadcast(distribution, with_shape=None, *, to_shape=None, validate_args=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.batch_broadcast.BatchBroadcast with TFPDistributionMixin.

BatchConcat

class BatchConcat(distributions, axis, validate_args=False, allow_nan_stats=True, name='BatchConcat')

Wraps tensorflow_probability.substrates.jax.distributions.batch_concat.BatchConcat with TFPDistributionMixin.

BatchReshape

class BatchReshape(distribution, batch_shape, validate_args=False, allow_nan_stats=True, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.batch_reshape.BatchReshape with TFPDistributionMixin.

Bates

class Bates(total_count, low=0.0, high=1.0, validate_args=False, allow_nan_stats=True, name='Bates')

Wraps tensorflow_probability.substrates.jax.distributions.bates.Bates with TFPDistributionMixin.

Bernoulli

class Bernoulli(logits=None, probs=None, dtype=jax.numpy.int32, validate_args=False, allow_nan_stats=True, name='Bernoulli')

Wraps tensorflow_probability.substrates.jax.distributions.bernoulli.Bernoulli with TFPDistributionMixin.

Beta

class Beta(concentration1, concentration0, validate_args=False, allow_nan_stats=True, name='Beta')

Wraps tensorflow_probability.substrates.jax.distributions.beta.Beta with TFPDistributionMixin.

BetaBinomial

class BetaBinomial(total_count, concentration1, concentration0, validate_args=False, allow_nan_stats=True, name='BetaBinomial')

Wraps tensorflow_probability.substrates.jax.distributions.beta_binomial.BetaBinomial with TFPDistributionMixin.

BetaQuotient

class BetaQuotient(concentration1_numerator, concentration0_numerator, concentration1_denominator, concentration0_denominator, validate_args=False, allow_nan_stats=True, name='BetaQuotient')

Wraps tensorflow_probability.substrates.jax.distributions.beta_quotient.BetaQuotient with TFPDistributionMixin.

Binomial

class Binomial(total_count, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.binomial.Binomial with TFPDistributionMixin.

Blockwise

class Blockwise(distributions, dtype_override=None, validate_args=False, allow_nan_stats=False, name='Blockwise')

Wraps tensorflow_probability.substrates.jax.distributions.blockwise.Blockwise with TFPDistributionMixin.

Categorical

class Categorical(logits=None, probs=None, dtype=jax.numpy.int32, validate_args=False, allow_nan_stats=True, name='Categorical')

Wraps tensorflow_probability.substrates.jax.distributions.categorical.Categorical with TFPDistributionMixin.

Cauchy

class Cauchy(loc, scale, validate_args=False, allow_nan_stats=True, name='Cauchy')

Wraps tensorflow_probability.substrates.jax.distributions.cauchy.Cauchy with TFPDistributionMixin.

Chi

class Chi(df, validate_args=False, allow_nan_stats=True, name='Chi')

Wraps tensorflow_probability.substrates.jax.distributions.chi.Chi with TFPDistributionMixin.

Chi2

class Chi2(df, validate_args=False, allow_nan_stats=True, name='Chi2')

Wraps tensorflow_probability.substrates.jax.distributions.chi2.Chi2 with TFPDistributionMixin.

CholeskyLKJ

class CholeskyLKJ(dimension, concentration, validate_args=False, allow_nan_stats=True, name='CholeskyLKJ')

Wraps tensorflow_probability.substrates.jax.distributions.cholesky_lkj.CholeskyLKJ with TFPDistributionMixin.

ContinuousBernoulli

class ContinuousBernoulli(logits=None, probs=None, lims=(0.499, 0.501), dtype=jax.numpy.float32, validate_args=False, allow_nan_stats=True, name='ContinuousBernoulli')

Wraps tensorflow_probability.substrates.jax.distributions.continuous_bernoulli.ContinuousBernoulli with TFPDistributionMixin.

DeterminantalPointProcess

class DeterminantalPointProcess(eigenvalues, eigenvectors, validate_args=False, allow_nan_stats=False, name='DeterminantalPointProcess')

Wraps tensorflow_probability.substrates.jax.distributions.dpp.DeterminantalPointProcess with TFPDistributionMixin.

Deterministic

class Deterministic(loc, atol=None, rtol=None, validate_args=False, allow_nan_stats=True, name='Deterministic')

Wraps tensorflow_probability.substrates.jax.distributions.deterministic.Deterministic with TFPDistributionMixin.

Dirichlet

class Dirichlet(concentration, validate_args=False, allow_nan_stats=True, name='Dirichlet')

Wraps tensorflow_probability.substrates.jax.distributions.dirichlet.Dirichlet with TFPDistributionMixin.

DirichletMultinomial

class DirichletMultinomial(total_count, concentration, validate_args=False, allow_nan_stats=True, name='DirichletMultinomial')

Wraps tensorflow_probability.substrates.jax.distributions.dirichlet_multinomial.DirichletMultinomial with TFPDistributionMixin.

DoublesidedMaxwell

class DoublesidedMaxwell(loc, scale, validate_args=False, allow_nan_stats=True, name='doublesided_maxwell')

Wraps tensorflow_probability.substrates.jax.distributions.doublesided_maxwell.DoublesidedMaxwell with TFPDistributionMixin.

Empirical

class Empirical(samples, event_ndims=0, validate_args=False, allow_nan_stats=True, name='Empirical')

Wraps tensorflow_probability.substrates.jax.distributions.empirical.Empirical with TFPDistributionMixin.

ExpGamma

class ExpGamma(concentration, rate=None, log_rate=None, validate_args=False, allow_nan_stats=True, name='ExpGamma')

Wraps tensorflow_probability.substrates.jax.distributions.exp_gamma.ExpGamma with TFPDistributionMixin.

ExpInverseGamma

class ExpInverseGamma(concentration, scale=None, log_scale=None, validate_args=False, allow_nan_stats=True, name='ExpInverseGamma')

Wraps tensorflow_probability.substrates.jax.distributions.exp_gamma.ExpInverseGamma with TFPDistributionMixin.

ExpRelaxedOneHotCategorical

class ExpRelaxedOneHotCategorical(temperature, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='ExpRelaxedOneHotCategorical')

Wraps tensorflow_probability.substrates.jax.distributions.relaxed_onehot_categorical.ExpRelaxedOneHotCategorical with TFPDistributionMixin.

Exponential

class Exponential(rate, force_probs_to_zero_outside_support=False, validate_args=False, allow_nan_stats=True, name='Exponential')

Wraps tensorflow_probability.substrates.jax.distributions.exponential.Exponential with TFPDistributionMixin.

ExponentiallyModifiedGaussian

class ExponentiallyModifiedGaussian(loc, scale, rate, validate_args=False, allow_nan_stats=True, name='ExponentiallyModifiedGaussian')

Wraps tensorflow_probability.substrates.jax.distributions.exponentially_modified_gaussian.ExponentiallyModifiedGaussian with TFPDistributionMixin.

FiniteDiscrete

class FiniteDiscrete(outcomes, logits=None, probs=None, rtol=None, atol=None, validate_args=False, allow_nan_stats=True, name='FiniteDiscrete')

Wraps tensorflow_probability.substrates.jax.distributions.finite_discrete.FiniteDiscrete with TFPDistributionMixin.

Gamma

class Gamma(concentration, rate=None, log_rate=None, validate_args=False, allow_nan_stats=True, name='Gamma')

Wraps tensorflow_probability.substrates.jax.distributions.gamma.Gamma with TFPDistributionMixin.

GammaGamma

class GammaGamma(concentration, mixing_concentration, mixing_rate, validate_args=False, allow_nan_stats=True, name='GammaGamma')

Wraps tensorflow_probability.substrates.jax.distributions.gamma_gamma.GammaGamma with TFPDistributionMixin.

GaussianProcess

class GaussianProcess(kernel, index_points=None, mean_fn=None, observation_noise_variance=0.0, marginal_fn=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='GaussianProcess')

Wraps tensorflow_probability.substrates.jax.distributions.gaussian_process.GaussianProcess with TFPDistributionMixin.

GaussianProcessRegressionModel

class GaussianProcessRegressionModel(kernel, index_points=None, observation_index_points=None, observations=None, observation_noise_variance=0.0, predictive_noise_variance=None, mean_fn=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='GaussianProcessRegressionModel')

Wraps tensorflow_probability.substrates.jax.distributions.gaussian_process_regression_model.GaussianProcessRegressionModel with TFPDistributionMixin.

GeneralizedExtremeValue

class GeneralizedExtremeValue(loc, scale, concentration, validate_args=False, allow_nan_stats=True, name='GeneralizedExtremeValue')

Wraps tensorflow_probability.substrates.jax.distributions.gev.GeneralizedExtremeValue with TFPDistributionMixin.

GeneralizedNormal

class GeneralizedNormal(loc, scale, power, validate_args=False, allow_nan_stats=True, name='GeneralizedNormal')

Wraps tensorflow_probability.substrates.jax.distributions.generalized_normal.GeneralizedNormal with TFPDistributionMixin.

GeneralizedPareto

class GeneralizedPareto(loc, scale, concentration, validate_args=False, allow_nan_stats=True, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.generalized_pareto.GeneralizedPareto with TFPDistributionMixin.

Geometric

class Geometric(logits=None, probs=None, force_probs_to_zero_outside_support=False, validate_args=False, allow_nan_stats=True, name='Geometric')

Wraps tensorflow_probability.substrates.jax.distributions.geometric.Geometric with TFPDistributionMixin.

Gumbel

class Gumbel(loc, scale, validate_args=False, allow_nan_stats=True, name='Gumbel')

Wraps tensorflow_probability.substrates.jax.distributions.gumbel.Gumbel with TFPDistributionMixin.

HalfCauchy

class HalfCauchy(loc, scale, validate_args=False, allow_nan_stats=True, name='HalfCauchy')

Wraps tensorflow_probability.substrates.jax.distributions.half_cauchy.HalfCauchy with TFPDistributionMixin.

HalfNormal

class HalfNormal(scale, validate_args=False, allow_nan_stats=True, name='HalfNormal')

Wraps tensorflow_probability.substrates.jax.distributions.half_normal.HalfNormal with TFPDistributionMixin.

HalfStudentT

class HalfStudentT(df, loc, scale, validate_args=False, allow_nan_stats=True, name='HalfStudentT')

Wraps tensorflow_probability.substrates.jax.distributions.half_student_t.HalfStudentT with TFPDistributionMixin.

HiddenMarkovModel

class HiddenMarkovModel(initial_distribution, transition_distribution, observation_distribution, num_steps, validate_args=False, allow_nan_stats=True, time_varying_transition_distribution=False, time_varying_observation_distribution=False, name='HiddenMarkovModel')

Wraps tensorflow_probability.substrates.jax.distributions.hidden_markov_model.HiddenMarkovModel with TFPDistributionMixin.

Horseshoe

class Horseshoe(scale, validate_args=False, allow_nan_stats=True, name='Horseshoe')

Wraps tensorflow_probability.substrates.jax.distributions.horseshoe.Horseshoe with TFPDistributionMixin.

Independent

class Independent(distribution, reinterpreted_batch_ndims=None, validate_args=False, experimental_use_kahan_sum=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.independent.Independent with TFPDistributionMixin.

InverseGamma

class InverseGamma(concentration, scale=None, validate_args=False, allow_nan_stats=True, name='InverseGamma')[source]

Wraps tensorflow_probability.substrates.jax.distributions.inverse_gamma.InverseGamma with TFPDistributionMixin.

InverseGaussian

class InverseGaussian(loc, concentration, validate_args=False, allow_nan_stats=True, name='InverseGaussian')

Wraps tensorflow_probability.substrates.jax.distributions.inverse_gaussian.InverseGaussian with TFPDistributionMixin.

JohnsonSU

class JohnsonSU(skewness, tailweight, loc, scale, validate_args=False, allow_nan_stats=True, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.johnson_su.JohnsonSU with TFPDistributionMixin.

JointDistribution

class JointDistribution(dtype, reparameterization_type, validate_args, allow_nan_stats, parameters=None, graph_parents=None, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution.JointDistribution with TFPDistributionMixin.

JointDistributionCoroutine

class JointDistributionCoroutine(model, sample_dtype=None, validate_args=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_coroutine.JointDistributionCoroutine with TFPDistributionMixin.

JointDistributionCoroutineAutoBatched

class JointDistributionCoroutineAutoBatched(model, sample_dtype=None, batch_ndims=0, use_vectorized_map=True, validate_args=False, experimental_use_kahan_sum=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched with TFPDistributionMixin.

JointDistributionNamed

class JointDistributionNamed(model, validate_args=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_named.JointDistributionNamed with TFPDistributionMixin.

JointDistributionNamedAutoBatched

class JointDistributionNamedAutoBatched(model, batch_ndims=0, use_vectorized_map=True, validate_args=False, experimental_use_kahan_sum=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_auto_batched.JointDistributionNamedAutoBatched with TFPDistributionMixin.

JointDistributionSequential

class JointDistributionSequential(model, validate_args=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_sequential.JointDistributionSequential with TFPDistributionMixin.

JointDistributionSequentialAutoBatched

class JointDistributionSequentialAutoBatched(model, batch_ndims=0, use_vectorized_map=True, validate_args=False, experimental_use_kahan_sum=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.joint_distribution_auto_batched.JointDistributionSequentialAutoBatched with TFPDistributionMixin.

Kumaraswamy

class Kumaraswamy(concentration1=1.0, concentration0=1.0, validate_args=False, allow_nan_stats=True, name='Kumaraswamy')

Wraps tensorflow_probability.substrates.jax.distributions.kumaraswamy.Kumaraswamy with TFPDistributionMixin.

LKJ

class LKJ(dimension, concentration, input_output_cholesky=False, validate_args=False, allow_nan_stats=True, name='LKJ')

Wraps tensorflow_probability.substrates.jax.distributions.lkj.LKJ with TFPDistributionMixin.

Laplace

class Laplace(loc, scale, validate_args=False, allow_nan_stats=True, name='Laplace')

Wraps tensorflow_probability.substrates.jax.distributions.laplace.Laplace with TFPDistributionMixin.

LinearGaussianStateSpaceModel

class LinearGaussianStateSpaceModel(num_timesteps, transition_matrix, transition_noise, observation_matrix, observation_noise, initial_state_prior, initial_step=0, experimental_parallelize=False, validate_args=False, allow_nan_stats=True, name='LinearGaussianStateSpaceModel')

Wraps tensorflow_probability.substrates.jax.distributions.linear_gaussian_ssm.LinearGaussianStateSpaceModel with TFPDistributionMixin.

LogLogistic

class LogLogistic(loc, scale, validate_args=False, allow_nan_stats=True, name='LogLogistic')

Wraps tensorflow_probability.substrates.jax.distributions.loglogistic.LogLogistic with TFPDistributionMixin.

LogNormal

class LogNormal(loc, scale, validate_args=False, allow_nan_stats=True, name='LogNormal')

Wraps tensorflow_probability.substrates.jax.distributions.lognormal.LogNormal with TFPDistributionMixin.

Logistic

class Logistic(loc, scale, validate_args=False, allow_nan_stats=True, name='Logistic')

Wraps tensorflow_probability.substrates.jax.distributions.logistic.Logistic with TFPDistributionMixin.

LogitNormal

class LogitNormal(loc, scale, num_probit_terms_approx=2, validate_args=False, allow_nan_stats=True, name='LogitNormal')

Wraps tensorflow_probability.substrates.jax.distributions.logitnormal.LogitNormal with TFPDistributionMixin.

MatrixNormalLinearOperator

class MatrixNormalLinearOperator(loc, scale_row, scale_column, validate_args=False, allow_nan_stats=True, name='MatrixNormalLinearOperator')

Wraps tensorflow_probability.substrates.jax.distributions.matrix_normal_linear_operator.MatrixNormalLinearOperator with TFPDistributionMixin.

MatrixTLinearOperator

class MatrixTLinearOperator(df, loc, scale_row, scale_column, validate_args=False, allow_nan_stats=True, name='MatrixTLinearOperator')

Wraps tensorflow_probability.substrates.jax.distributions.matrix_t_linear_operator.MatrixTLinearOperator with TFPDistributionMixin.

MixtureSameFamily

class MixtureSameFamily(mixture_distribution, components_distribution, reparameterize=False, validate_args=False, allow_nan_stats=True, name='MixtureSameFamily')

Wraps tensorflow_probability.substrates.jax.distributions.mixture_same_family.MixtureSameFamily with TFPDistributionMixin.

Moyal

class Moyal(loc, scale, validate_args=False, allow_nan_stats=True, name='Moyal')

Wraps tensorflow_probability.substrates.jax.distributions.moyal.Moyal with TFPDistributionMixin.

Multinomial

class Multinomial(total_count, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='Multinomial')

Wraps tensorflow_probability.substrates.jax.distributions.multinomial.Multinomial with TFPDistributionMixin.

MultivariateNormalDiag

class MultivariateNormalDiag(loc=None, scale_diag=None, scale_identity_multiplier=None, validate_args=False, allow_nan_stats=True, experimental_use_kahan_sum=False, name='MultivariateNormalDiag')

Wraps tensorflow_probability.substrates.jax.distributions.mvn_diag.MultivariateNormalDiag with TFPDistributionMixin.

MultivariateNormalDiagPlusLowRank

class MultivariateNormalDiagPlusLowRank(loc=None, scale_diag=None, scale_perturb_factor=None, scale_perturb_diag=None, validate_args=False, allow_nan_stats=True, name='MultivariateNormalDiagPlusLowRank')

Wraps tensorflow_probability.substrates.jax.distributions.mvn_diag_plus_low_rank.MultivariateNormalDiagPlusLowRank with TFPDistributionMixin.

MultivariateNormalFullCovariance

class MultivariateNormalFullCovariance(loc=None, covariance_matrix=None, validate_args=False, allow_nan_stats=True, name='MultivariateNormalFullCovariance')

Wraps tensorflow_probability.substrates.jax.distributions.mvn_full_covariance.MultivariateNormalFullCovariance with TFPDistributionMixin.

MultivariateNormalLinearOperator

class MultivariateNormalLinearOperator(loc=None, scale=None, validate_args=False, allow_nan_stats=True, experimental_use_kahan_sum=False, name='MultivariateNormalLinearOperator')

Wraps tensorflow_probability.substrates.jax.distributions.mvn_linear_operator.MultivariateNormalLinearOperator with TFPDistributionMixin.

MultivariateNormalTriL

class MultivariateNormalTriL(loc=None, scale_tril=None, validate_args=False, allow_nan_stats=True, experimental_use_kahan_sum=False, name='MultivariateNormalTriL')

Wraps tensorflow_probability.substrates.jax.distributions.mvn_tril.MultivariateNormalTriL with TFPDistributionMixin.

MultivariateStudentTLinearOperator

class MultivariateStudentTLinearOperator(df, loc, scale, validate_args=False, allow_nan_stats=True, name='MultivariateStudentTLinearOperator')

Wraps tensorflow_probability.substrates.jax.distributions.multivariate_student_t.MultivariateStudentTLinearOperator with TFPDistributionMixin.

NegativeBinomial

class NegativeBinomial(total_count, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='NegativeBinomial')

Wraps tensorflow_probability.substrates.jax.distributions.negative_binomial.NegativeBinomial with TFPDistributionMixin.

Normal

class Normal(loc, scale, validate_args=False, allow_nan_stats=True, name='Normal')

Wraps tensorflow_probability.substrates.jax.distributions.normal.Normal with TFPDistributionMixin.

NormalInverseGaussian

class NormalInverseGaussian(loc, scale, tailweight, skewness, validate_args=False, allow_nan_stats=True, name='NormalInverseGaussian')

Wraps tensorflow_probability.substrates.jax.distributions.normal_inverse_gaussian.NormalInverseGaussian with TFPDistributionMixin.

OneHotCategorical

class OneHotCategorical(logits=None, probs=None, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='OneHotCategorical')[source]

Wraps tensorflow_probability.substrates.jax.distributions.onehot_categorical.OneHotCategorical with TFPDistributionMixin.

OrderedLogistic

class OrderedLogistic(cutpoints, loc, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='OrderedLogistic')[source]

Wraps tensorflow_probability.substrates.jax.distributions.ordered_logistic.OrderedLogistic with TFPDistributionMixin.

PERT

class PERT(low, peak, high, temperature=4.0, validate_args=False, allow_nan_stats=False, name='PERT')

Wraps tensorflow_probability.substrates.jax.distributions.pert.PERT with TFPDistributionMixin.

Pareto

class Pareto(concentration, scale=1.0, validate_args=False, allow_nan_stats=True, name='Pareto')[source]

Wraps tensorflow_probability.substrates.jax.distributions.pareto.Pareto with TFPDistributionMixin.

PlackettLuce

class PlackettLuce(scores, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='PlackettLuce')

Wraps tensorflow_probability.substrates.jax.distributions.plackett_luce.PlackettLuce with TFPDistributionMixin.

Poisson

class Poisson(rate=None, log_rate=None, force_probs_to_zero_outside_support=None, interpolate_nondiscrete=True, validate_args=False, allow_nan_stats=True, name='Poisson')

Wraps tensorflow_probability.substrates.jax.distributions.poisson.Poisson with TFPDistributionMixin.

PoissonLogNormalQuadratureCompound

class PoissonLogNormalQuadratureCompound(loc, scale, quadrature_size=8, quadrature_fn=<function quadrature_scheme_lognormal_quantiles>, validate_args=False, allow_nan_stats=True, name='PoissonLogNormalQuadratureCompound')

Wraps tensorflow_probability.substrates.jax.distributions.poisson_lognormal.PoissonLogNormalQuadratureCompound with TFPDistributionMixin.

PowerSpherical

class PowerSpherical(mean_direction, concentration, validate_args=False, allow_nan_stats=True, name='PowerSpherical')

Wraps tensorflow_probability.substrates.jax.distributions.power_spherical.PowerSpherical with TFPDistributionMixin.

ProbitBernoulli

class ProbitBernoulli(probits=None, probs=None, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='ProbitBernoulli')

Wraps tensorflow_probability.substrates.jax.distributions.probit_bernoulli.ProbitBernoulli with TFPDistributionMixin.

QuantizedDistribution

class QuantizedDistribution(distribution, low=None, high=None, validate_args=False, name='QuantizedDistribution')

Wraps tensorflow_probability.substrates.jax.distributions.quantized_distribution.QuantizedDistribution with TFPDistributionMixin.

RelaxedBernoulli

class RelaxedBernoulli(temperature, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='RelaxedBernoulli')

Wraps tensorflow_probability.substrates.jax.distributions.relaxed_bernoulli.RelaxedBernoulli with TFPDistributionMixin.

RelaxedOneHotCategorical

class RelaxedOneHotCategorical(temperature, logits=None, probs=None, validate_args=False, allow_nan_stats=True, name='RelaxedOneHotCategorical')

Wraps tensorflow_probability.substrates.jax.distributions.relaxed_onehot_categorical.RelaxedOneHotCategorical with TFPDistributionMixin.

Sample

class Sample(distribution, sample_shape=(), validate_args=False, experimental_use_kahan_sum=False, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.sample.Sample with TFPDistributionMixin.

SigmoidBeta

class SigmoidBeta(concentration1, concentration0, validate_args=False, allow_nan_stats=True, name='SigmoidBeta')

Wraps tensorflow_probability.substrates.jax.distributions.sigmoid_beta.SigmoidBeta with TFPDistributionMixin.

SinhArcsinh

class SinhArcsinh(loc, scale, skewness=None, tailweight=None, distribution=None, validate_args=False, allow_nan_stats=True, name='SinhArcsinh')

Wraps tensorflow_probability.substrates.jax.distributions.sinh_arcsinh.SinhArcsinh with TFPDistributionMixin.

Skellam

class Skellam(rate1=None, rate2=None, log_rate1=None, log_rate2=None, force_probs_to_zero_outside_support=False, validate_args=False, allow_nan_stats=True, name='Skellam')

Wraps tensorflow_probability.substrates.jax.distributions.skellam.Skellam with TFPDistributionMixin.

SphericalUniform

class SphericalUniform(dimension, batch_shape=(), dtype=<class 'jax._src.numpy.lax_numpy.float32'>, validate_args=False, allow_nan_stats=True, name='SphericalUniform')

Wraps tensorflow_probability.substrates.jax.distributions.spherical_uniform.SphericalUniform with TFPDistributionMixin.

StoppingRatioLogistic

class StoppingRatioLogistic(cutpoints, loc, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, validate_args=False, allow_nan_stats=True, name='StoppingRatioLogistic')

Wraps tensorflow_probability.substrates.jax.distributions.stopping_ratio_logistic.StoppingRatioLogistic with TFPDistributionMixin.

StudentT

class StudentT(df, loc, scale, validate_args=False, allow_nan_stats=True, name='StudentT')

Wraps tensorflow_probability.substrates.jax.distributions.student_t.StudentT with TFPDistributionMixin.

StudentTProcess

class StudentTProcess(df, kernel, index_points=None, mean_fn=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='StudentTProcess')

Wraps tensorflow_probability.substrates.jax.distributions.student_t_process.StudentTProcess with TFPDistributionMixin.

TransformedDistribution

class TransformedDistribution(distribution, bijector, kwargs_split_fn=<function _default_kwargs_split_fn>, validate_args=False, parameters=None, name=None)

Wraps tensorflow_probability.substrates.jax.distributions.transformed_distribution.TransformedDistribution with TFPDistributionMixin.

Triangular

class Triangular(low=0.0, high=1.0, peak=0.5, validate_args=False, allow_nan_stats=True, name='Triangular')

Wraps tensorflow_probability.substrates.jax.distributions.triangular.Triangular with TFPDistributionMixin.

TruncatedCauchy

class TruncatedCauchy(loc, scale, low, high, validate_args=False, allow_nan_stats=True, name='TruncatedCauchy')

Wraps tensorflow_probability.substrates.jax.distributions.truncated_cauchy.TruncatedCauchy with TFPDistributionMixin.

TruncatedNormal

class TruncatedNormal(loc, scale, low, high, validate_args=False, allow_nan_stats=True, name='TruncatedNormal')

Wraps tensorflow_probability.substrates.jax.distributions.truncated_normal.TruncatedNormal with TFPDistributionMixin.

Uniform

class Uniform(low=0.0, high=1.0, validate_args=False, allow_nan_stats=True, name='Uniform')

Wraps tensorflow_probability.substrates.jax.distributions.uniform.Uniform with TFPDistributionMixin.

VariationalGaussianProcess

class VariationalGaussianProcess(kernel, index_points, inducing_index_points, variational_inducing_observations_loc, variational_inducing_observations_scale, mean_fn=None, observation_noise_variance=None, predictive_noise_variance=None, jitter=1e-06, validate_args=False, allow_nan_stats=False, name='VariationalGaussianProcess')

Wraps tensorflow_probability.substrates.jax.distributions.variational_gaussian_process.VariationalGaussianProcess with TFPDistributionMixin.

VectorDeterministic

class VectorDeterministic(loc, atol=None, rtol=None, validate_args=False, allow_nan_stats=True, name='VectorDeterministic')

Wraps tensorflow_probability.substrates.jax.distributions.deterministic.VectorDeterministic with TFPDistributionMixin.

VectorExponentialDiag

class VectorExponentialDiag(loc=None, scale_diag=None, scale_identity_multiplier=None, validate_args=False, allow_nan_stats=True, name='VectorExponentialDiag')

Wraps tensorflow_probability.substrates.jax.distributions.vector_exponential_diag.VectorExponentialDiag with TFPDistributionMixin.

VonMises

class VonMises(loc, concentration, validate_args=False, allow_nan_stats=True, name='VonMises')

Wraps tensorflow_probability.substrates.jax.distributions.von_mises.VonMises with TFPDistributionMixin.

VonMisesFisher

class VonMisesFisher(mean_direction, concentration, validate_args=False, allow_nan_stats=True, name='VonMisesFisher')

Wraps tensorflow_probability.substrates.jax.distributions.von_mises_fisher.VonMisesFisher with TFPDistributionMixin.

Weibull

class Weibull(concentration, scale, validate_args=False, allow_nan_stats=True, name='Weibull')

Wraps tensorflow_probability.substrates.jax.distributions.weibull.Weibull with TFPDistributionMixin.

WishartLinearOperator

class WishartLinearOperator(df, scale, input_output_cholesky=False, validate_args=False, allow_nan_stats=True, name='WishartLinearOperator')

Wraps tensorflow_probability.substrates.jax.distributions.wishart.WishartLinearOperator with TFPDistributionMixin.

WishartTriL

class WishartTriL(df, scale_tril=None, input_output_cholesky=False, validate_args=False, allow_nan_stats=True, name='WishartTriL')

Wraps tensorflow_probability.substrates.jax.distributions.wishart.WishartTriL with TFPDistributionMixin.

Zipf

class Zipf(power, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, force_probs_to_zero_outside_support=None, interpolate_nondiscrete=True, sample_maximum_iterations=100, validate_args=False, allow_nan_stats=False, name='Zipf')

Wraps tensorflow_probability.substrates.jax.distributions.zipf.Zipf with TFPDistributionMixin.

Constraints

Constraint

class Constraint[source]

Bases: object

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

event_dim = 0
check(value)[source]

Returns a boolean array of shape sample_shape + batch_shape indicating whether each event in value satisfies this constraint.

feasible_like(prototype)[source]

Get a feasible value which has the same shape and dtype as prototype.
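To make the check/feasible_like contract concrete, here is a minimal NumPy sketch of a constraint with event_dim = 1. The class name SimplexSketch and its tolerance are hypothetical. NumPyro's own constraint classes are written with jax.numpy and may differ in details; the point is that check reduces over the event axis, so its result has shape sample_shape + batch_shape.

```python
import numpy as np


class SimplexSketch:
    """Hypothetical constraint: nonnegative vectors that sum to one."""

    event_dim = 1  # the last axis forms one event

    def check(self, value, atol=1e-6):
        value = np.asarray(value, dtype=float)
        # Reduce over the event axis: the result has shape sample_shape + batch_shape.
        nonneg = np.all(value >= 0, axis=-1)
        sums_to_one = np.abs(value.sum(axis=-1) - 1.0) < atol
        return nonneg & sums_to_one

    def feasible_like(self, prototype):
        prototype = np.asarray(prototype)
        # Uniform point on the simplex, with the same shape and dtype as prototype.
        return np.full_like(prototype, 1.0 / prototype.shape[-1])


c = SimplexSketch()
batch = np.array([[0.2, 0.3, 0.5], [0.6, 0.6, -0.2]])
print(c.check(batch))                     # [ True False]
print(c.feasible_like(np.zeros((2, 3))))  # every row is [1/3, 1/3, 1/3]
```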

boolean

boolean = <numpyro.distributions.constraints._Boolean object>

corr_cholesky

corr_cholesky = <numpyro.distributions.constraints._CorrCholesky object>

corr_matrix

corr_matrix = <numpyro.distributions.constraints._CorrMatrix object>

dependent

dependent = <numpyro.distributions.constraints._Dependent object>

Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints.

Parameters:
  • is_discrete (bool) – Optional value of .is_discrete in case this can be computed statically. If not provided, access to the .is_discrete attribute will raise a NotImplementedError.
  • event_dim (int) – Optional value of .event_dim in case this can be computed statically. If not provided, access to the .event_dim attribute will raise a NotImplementedError.

greater_than

greater_than(lower_bound)

Constrain to the real half-line (lower_bound, ∞): a value x is valid when x > lower_bound.

integer_interval

integer_interval(lower_bound, upper_bound)

Constrain to the integers in the closed interval [lower_bound, upper_bound].

integer_greater_than

integer_greater_than(lower_bound)

Constrain to the integers greater than or equal to lower_bound.

interval

interval(lower_bound, upper_bound)

Constrain to the closed real interval [lower_bound, upper_bound].

less_than

less_than(upper_bound)

Constrain to the real half-line (-∞, upper_bound): a value x is valid when x < upper_bound.

lower_cholesky

lower_cholesky = <numpyro.distributions.constraints._LowerCholesky object>

multinomial

multinomial(upper_bound)

Constrain to count vectors: entries along the last (event) dimension must be nonnegative and sum to upper_bound.
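The parametrized constraints in this section are predicates on values. The NumPy sketch below restates their membership tests as plain functions. The function names are hypothetical, and it assumes strict bounds for greater_than and less_than, closed bounds for interval and integer_interval, and an equality test on the sum for multinomial.

```python
import numpy as np

# Hypothetical NumPy restatements of the parametrized constraints' membership tests.
def greater_than(x, lower_bound):
    return np.asarray(x) > lower_bound

def less_than(x, upper_bound):
    return np.asarray(x) < upper_bound

def interval(x, lower_bound, upper_bound):
    x = np.asarray(x)
    return (x >= lower_bound) & (x <= upper_bound)

def is_integer(x):
    return np.remainder(np.asarray(x), 1) == 0

def integer_greater_than(x, lower_bound):
    return is_integer(x) & (np.asarray(x) >= lower_bound)

def integer_interval(x, lower_bound, upper_bound):
    return is_integer(x) & interval(x, lower_bound, upper_bound)

def multinomial(x, upper_bound):
    # event_dim = 1: one answer per count vector, reduced over the last axis.
    x = np.asarray(x)
    return np.all(x >= 0, axis=-1) & (x.sum(axis=-1) == upper_bound)

x = np.array([0.0, 0.5, 1.0, 2.0])
print(greater_than(x, 0.0))                    # [False  True  True  True]
print(interval(x, 0.0, 1.0))                   # [ True  True  True False]
print(integer_interval(x, 0, 1))               # [ True False  True False]
print(multinomial([[2, 1, 0], [3, 0, 1]], 3))  # [ True False]
```

Under these definitions, nonnegative_integer and positive_integer (both listed as _IntegerGreaterThan objects) behave like integer_greater_than with lower bounds 0 and 1.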

nonnegative_integer

nonnegative_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>

ordered_vector

ordered_vector = <numpyro.distributions.constraints._OrderedVector object>

positive

positive = <numpyro.distributions.constraints._GreaterThan object>

positive_definite

positive_definite = <numpyro.distributions.constraints._PositiveDefinite object>

positive_integer

positive_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>

positive_ordered_vector

positive_ordered_vector = <numpyro.distributions.constraints._PositiveOrderedVector object>

Constrains to a positive real-valued tensor where the elements are monotonically increasing along the event_shape dimension.

real

real = <numpyro.distributions.constraints._Real object>

real_vector

real_vector = <numpyro.distributions.constraints._IndependentConstraint object>

Constrains to real-valued vectors. It wraps real as an independent constraint with one event dimension, so check() aggregates over the last dimension and a vector is valid only if all of its entries are valid.

simplex

simplex = <numpyro.distributions.constraints._Simplex object>

sphere

sphere = <numpyro.distributions.constraints._Sphere object>

Constrain to the Euclidean sphere of any dimension.

unit_interval

unit_interval = <numpyro.distributions.constraints._Interval object>

Transforms

biject_to

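biject_to(constraint)

Returns a Transform that maps unconstrained real values bijectively onto the region described by constraint.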
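biject_to is typically used to move between an unconstrained space, where optimizers and HMC operate, and a constrained support. As a sketch of the idea (not NumPyro's code), the NumPy functions below build bijections for greater_than(lower_bound) and interval(lower_bound, upper_bound) from an exponential and from a sigmoid followed by an affine map. Which concrete transforms biject_to returns for each constraint is an assumption here, and the function names are hypothetical.

```python
import numpy as np

def to_greater_than(lower_bound):
    """Hypothetical bijection R -> (lower_bound, inf): y = lower_bound + exp(x)."""
    forward = lambda x: lower_bound + np.exp(x)
    inverse = lambda y: np.log(y - lower_bound)
    return forward, inverse

def to_interval(lower_bound, upper_bound):
    """Hypothetical bijection R -> (lower_bound, upper_bound) via a scaled sigmoid."""
    width = upper_bound - lower_bound
    forward = lambda x: lower_bound + width / (1.0 + np.exp(-x))
    inverse = lambda y: np.log((y - lower_bound) / (upper_bound - y))  # logit of rescaled y
    return forward, inverse

x = np.linspace(-3.0, 3.0, 7)  # unconstrained values
fwd, inv = to_interval(2.0, 5.0)
y = fwd(x)
print(y.min() > 2.0, y.max() < 5.0)  # True True
print(np.allclose(inv(y), x))        # True
```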

Transform

class Transform[source]

Bases: object

domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._Real object>
event_dim
inv
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
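The Transform interface pairs a forward map with its inverse (inv) and the log absolute Jacobian determinant used in change-of-variables densities. Here is a minimal NumPy sketch for an exponential map, with a finite-difference check of log_abs_det_jacobian. The class name ExpSketch is hypothetical and it implements only a subset of the interface.

```python
import numpy as np

class ExpSketch:
    """Hypothetical elementwise transform y = exp(x) mapping reals to positives."""

    def __call__(self, x):
        return np.exp(x)

    def inv(self, y):
        # NumPyro lists `inv` as an attribute returning an inverse transform; here it is a plain method.
        return np.log(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # dy/dx = exp(x), so log|dy/dx| = x (elementwise, event_dim = 0).
        return x

    def forward_shape(self, shape):
        return shape  # elementwise: shape is preserved

    inverse_shape = forward_shape


t = ExpSketch()
x = np.array([-1.0, 0.0, 2.0])
y = t(x)
h = 1e-6
numeric = np.log(np.abs((t(x + h) - t(x - h)) / (2 * h)))
print(np.allclose(t.log_abs_det_jacobian(x, y), numeric, atol=1e-6))  # True
print(np.allclose(t.inv(y), x))  # True
```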

AbsTransform

class AbsTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._GreaterThan object>

AffineTransform

class AffineTransform(loc, scale, domain=<numpyro.distributions.constraints._Real object>)[source]

Bases: numpyro.distributions.transforms.Transform

Note

When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

codomain
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
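As a sketch, assuming the elementwise map y = loc + scale * x (the entry above does not spell out the formula), the log absolute Jacobian determinant is log|scale| per element. The note about tracers concerns the sign of scale: a negative scale reverses orientation, so the endpoints of a bounded domain swap in the codomain. The function names below are hypothetical.

```python
import numpy as np

def affine_forward(x, loc, scale):
    return loc + scale * x

def affine_inverse(y, loc, scale):
    return (y - loc) / scale

def affine_log_abs_det_jacobian(x, loc, scale):
    # dy/dx = scale for every element, broadcast to the shape of x.
    return np.broadcast_to(np.log(np.abs(scale)), np.shape(x))

def affine_codomain_bounds(lower, upper, loc, scale):
    # Image of the interval [lower, upper]; a negative scale swaps the endpoints.
    a, b = affine_forward(lower, loc, scale), affine_forward(upper, loc, scale)
    return (a, b) if scale > 0 else (b, a)


x = np.array([0.0, 1.0, 2.0])
print(affine_forward(x, loc=1.0, scale=-2.0))       # [ 1. -1. -3.]
print(affine_codomain_bounds(0.0, 1.0, 1.0, -2.0))  # (-1.0, 1.0)
```

The `scale > 0` branch needs a concrete value; a traced scale cannot be inspected that way, which is consistent with the codomain being computed under the assumption scale > 0.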

CholeskyTransform

class CholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform via the mapping \(y = cholesky(x)\), where x is a positive definite matrix.

domain = <numpyro.distributions.constraints._PositiveDefinite object>
codomain = <numpyro.distributions.constraints._LowerCholesky object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
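A NumPy illustration of the round trip between the mapping y = cholesky(x) and the mapping y = x @ x.T documented for InvCholeskyTransform. It uses numpy.linalg.cholesky as a stand-in for the JAX computation and omits the Jacobian term.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
x = a @ a.T + 3.0 * np.eye(3)  # a positive definite matrix (domain)

y = np.linalg.cholesky(x)      # forward: lower-triangular factor (codomain)
x_back = y @ y.T               # inverse: the mapping used by InvCholeskyTransform

print(np.allclose(np.triu(y, k=1), 0.0))  # True: y is lower triangular
print(np.all(np.diag(y) > 0))            # True: positive diagonal
print(np.allclose(x_back, x))            # True: round trip
```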

ComposeTransform

class ComposeTransform(parts)[source]

Bases: numpyro.distributions.transforms.Transform

domain
codomain
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
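Composition applies the parts in order, and by the chain rule the log absolute Jacobian determinant is the sum of each part's term evaluated at that part's own input. A NumPy sketch composing an exponential map with an affine map, checked against a finite difference; the (forward, log_abs_det_jacobian) pair representation and the compose helper are hypothetical.

```python
import numpy as np

# Each hypothetical part is a pair: (forward function, log|dy/dx| as a function of x).
exp_part = (np.exp, lambda x: x)
affine_part = (lambda x: 2.0 + 3.0 * x, lambda x: np.full_like(x, np.log(3.0)))

def compose(parts):
    def forward(x):
        for f, _ in parts:
            x = f(x)
        return x

    def log_abs_det_jacobian(x):
        total = np.zeros_like(x)
        for f, ldj in parts:
            total = total + ldj(x)  # term for this part, at this part's input
            x = f(x)                # becomes the next part's input
        return total

    return forward, log_abs_det_jacobian

forward, ldj = compose([exp_part, affine_part])  # y = 2 + 3 * exp(x)
x = np.array([-1.0, 0.0, 1.5])
h = 1e-6
numeric = np.log(np.abs((forward(x + h) - forward(x - h)) / (2 * h)))
print(np.allclose(ldj(x), numeric, atol=1e-6))  # True
```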

CorrCholeskyTransform

class CorrCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transforms an unconstrained real vector \(x\) of length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonal entries and unit Euclidean norm for each row. The transform proceeds as follows:

  1. First we convert \(x\) into a lower triangular matrix with the following order (shown for \(D = 4\)):
     \[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]
  2. For each row \(X_i\) of the lower triangular part, we apply a signed version of StickBreakingTransform to transform \(X_i\) into a vector of unit Euclidean length using the following steps:
     1. Scale into the interval \((-1, 1)\): \(r_i = \tanh(X_i)\).
     2. Transform into an unsigned domain: \(z_i = r_i^2\).
     3. Apply \(s_i = StickBreakingTransform(z_i)\).
     4. Transform back into the signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._CorrCholesky object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
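A NumPy sketch of the CorrCholeskyTransform steps for a single unconstrained vector. It folds the unsigned/stick-breaking/signed steps into one loop: since sign(r) * sqrt(r^2) = r, each off-diagonal entry is r_ij times the square root of the squared length still unallocated in that row, and the diagonal takes the remainder. The function name is hypothetical, and NumPyro's vectorized implementation (including any clipping for numerical safety) may differ.

```python
import numpy as np

def corr_cholesky_sketch(x, dim):
    """Hypothetical: map a real vector of length dim*(dim-1)/2 to a correlation Cholesky factor."""
    x = np.asarray(x, dtype=float)
    assert x.shape == (dim * (dim - 1) // 2,)
    r = np.zeros((dim, dim))
    rows, cols = np.tril_indices(dim, k=-1)  # row-major: (1,0), (2,0), (2,1), (3,0), ...
    r[rows, cols] = np.tanh(x)               # squash into (-1, 1)
    L = np.zeros((dim, dim))
    for i in range(dim):
        remaining = 1.0  # squared length of the row not yet allocated
        for j in range(i):
            L[i, j] = r[i, j] * np.sqrt(remaining)
            remaining *= 1.0 - r[i, j] ** 2
        L[i, i] = np.sqrt(remaining)  # diagonal takes what is left, so the row has unit norm
    return L

L = corr_cholesky_sketch(np.array([0.5, -1.0, 2.0, 0.1, 0.0, -0.3]), dim=4)
corr = L @ L.T
print(np.allclose(np.linalg.norm(L, axis=1), 1.0))  # True
print(np.allclose(np.diag(corr), 1.0))              # True
```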

CorrMatrixCholeskyTransform

class CorrMatrixCholeskyTransform[source]

Bases: numpyro.distributions.transforms.CholeskyTransform

Transform via the mapping \(y = cholesky(x)\), where x is a correlation matrix.

domain = <numpyro.distributions.constraints._CorrMatrix object>
codomain = <numpyro.distributions.constraints._CorrCholesky object>
log_abs_det_jacobian(x, y, intermediates=None)[source]

ExpTransform

class ExpTransform(domain=<numpyro.distributions.constraints._Real object>)[source]

Bases: numpyro.distributions.transforms.Transform

codomain
log_abs_det_jacobian(x, y, intermediates=None)[source]

IdentityTransform

class IdentityTransform[source]

Bases: numpyro.distributions.transforms.Transform

log_abs_det_jacobian(x, y, intermediates=None)[source]

InvCholeskyTransform

class InvCholeskyTransform(domain=<numpyro.distributions.constraints._LowerCholesky object>)[source]

Bases: numpyro.distributions.transforms.Transform

Transform via the mapping \(y = x @ x.T\), where x is a lower triangular matrix with positive diagonal.

codomain
log_abs_det_jacobian(x, y, intermediates=None)[source]

LowerCholeskyAffine

class LowerCholeskyAffine(loc, scale_tril)[source]

Bases: numpyro.distributions.transforms.Transform

Transform via the mapping \(y = loc + scale\_tril\ @\ x\).

Parameters:
  • loc – a real vector.
  • scale_tril – a lower triangular matrix with positive diagonal.
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._IndependentConstraint object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
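A NumPy sketch of the mapping y = loc + scale_tril @ x. The Jacobian of this map is scale_tril itself, and the determinant of a triangular matrix is the product of its diagonal, so the log absolute determinant is the sum of the log-diagonal. The parameter values and function names are hypothetical.

```python
import numpy as np

# Hypothetical example parameters.
loc = np.array([1.0, -2.0])
scale_tril = np.array([[2.0, 0.0],
                       [0.5, 0.3]])  # lower triangular with positive diagonal


def lca_forward(x):
    # y = loc + scale_tril @ x, applied along the last axis so a batch of vectors also works.
    return loc + x @ scale_tril.T


def lca_inverse(y):
    # Solve scale_tril @ x = y - loc (a triangular solver could exploit the structure).
    return np.linalg.solve(scale_tril, (y - loc).T).T


def lca_log_abs_det_jacobian(x):
    # One value per vector: the sum of the log-diagonal of scale_tril.
    return np.full(x.shape[:-1], np.sum(np.log(np.diag(scale_tril))))


x = np.array([[0.0, 0.0], [1.0, -1.0], [0.3, 2.0]])
y = lca_forward(x)
print(np.allclose(lca_inverse(y), x))  # True
```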

LowerCholeskyTransform

class LowerCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._LowerCholesky object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

OrderedTransform

class OrderedTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform a real vector to an ordered vector.

References:

  1. Stan Reference Manual v2.20, section 10.6, Stan Development Team
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._OrderedVector object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
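The construction in the cited Stan manual section keeps the first coordinate and adds positive increments exp(x_k) for the rest, so the output is strictly increasing. Its Jacobian is lower triangular with diagonal (1, exp(x_1), ..., exp(x_{K-1})), giving log|det J| = x_1 + ... + x_{K-1}. A NumPy sketch follows; the function names are hypothetical, and treating this as exactly NumPyro's formula is an assumption.

```python
import numpy as np

def ordered_forward(x):
    # y_0 = x_0 and y_k = y_{k-1} + exp(x_k): positive increments guarantee ordering.
    x = np.asarray(x, dtype=float)
    increments = np.concatenate([x[..., :1], np.exp(x[..., 1:])], axis=-1)
    return np.cumsum(increments, axis=-1)

def ordered_inverse(y):
    y = np.asarray(y, dtype=float)
    return np.concatenate([y[..., :1], np.log(np.diff(y, axis=-1))], axis=-1)

def ordered_log_abs_det_jacobian(x):
    # Sum of the log-diagonal of the triangular Jacobian.
    return np.sum(np.asarray(x)[..., 1:], axis=-1)

x = np.array([0.3, -2.0, 0.0, 1.0])
y = ordered_forward(x)
print(np.all(np.diff(y) > 0))              # True
print(np.allclose(ordered_inverse(y), x))  # True
```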

PermuteTransform

class PermuteTransform(permutation)[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._IndependentConstraint object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
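A permutation only reorders the last axis, so its Jacobian is a permutation matrix with determinant ±1 and the log absolute determinant is zero. A NumPy sketch, assuming the convention y[..., i] = x[..., permutation[i]]; the names are hypothetical.

```python
import numpy as np

permutation = np.array([2, 0, 1])
inverse_permutation = np.argsort(permutation)  # undoes the reordering

def permute_forward(x):
    return x[..., permutation]  # y[..., i] = x[..., permutation[i]]

def permute_inverse(y):
    return y[..., inverse_permutation]

def permute_log_abs_det_jacobian(x):
    return np.zeros(np.shape(x)[:-1])  # |det| of a permutation matrix is 1

x = np.array([[10.0, 20.0, 30.0], [1.0, 2.0, 3.0]])
print(permute_forward(x))  # rows become [30, 10, 20] and [3, 1, 2]
```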

PowerTransform

class PowerTransform(exponent)[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._GreaterThan object>
codomain = <numpyro.distributions.constraints._GreaterThan object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
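The entry above gives no formula. Assuming the elementwise map y = x ** exponent on positive reals (consistent with its positive domain and codomain), the derivative is exponent * x ** (exponent - 1), so log|dy/dx| = log|exponent| + (exponent - 1) * log(x). A NumPy sketch with hypothetical names:

```python
import numpy as np

exponent = 0.5  # hypothetical example value

def power_forward(x):
    return np.power(x, exponent)

def power_inverse(y):
    return np.power(y, 1.0 / exponent)

def power_log_abs_det_jacobian(x):
    # dy/dx = exponent * x**(exponent - 1), well defined for x > 0.
    return np.log(np.abs(exponent)) + (exponent - 1.0) * np.log(x)

x = np.array([0.25, 1.0, 4.0])
print(power_forward(x))  # values 0.5, 1.0, 2.0
```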

SigmoidTransform

class SigmoidTransform[source]

Bases: numpyro.distributions.transforms.Transform

codomain = <numpyro.distributions.constraints._Interval object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
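SigmoidTransform maps the reals into the unit interval (its codomain is listed as an _Interval). Assuming the logistic sigmoid y = 1 / (1 + exp(-x)), here is a NumPy sketch of the forward map, its logit inverse, and log|dy/dx| = log y + log(1 - y), written with logaddexp so it stays finite for large |x|. The function names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def logit(y):
    return np.log(y) - np.log1p(-y)

def sigmoid_log_abs_det_jacobian(x):
    # dy/dx = y * (1 - y); log y = -softplus(-x) and log(1 - y) = -softplus(x).
    return -np.logaddexp(0.0, -x) - np.logaddexp(0.0, x)

x = np.array([-30.0, -1.0, 0.0, 2.0])
y = sigmoid(x)
print(np.allclose(logit(y), x))  # True
```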

SoftplusLowerCholeskyTransform

class SoftplusLowerCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform from an unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._SoftplusLowerCholesky object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.

SoftplusTransform

class SoftplusTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).

domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._SoftplusPositive object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
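A NumPy sketch of the softplus pair stated above. The inverse log(exp(y) - 1) is rewritten as y + log(1 - exp(-y)) so that large y does not overflow, and since dy/dx = sigmoid(x), log|dy/dx| = -softplus(-x). The function names are hypothetical.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)) computed stably.
    return np.logaddexp(0.0, x)

def softplus_inverse(y):
    # x = log(exp(y) - 1) = y + log(1 - exp(-y)), avoiding overflow for large y.
    return y + np.log(-np.expm1(-y))

def softplus_log_abs_det_jacobian(x):
    # dy/dx = sigmoid(x), so log|dy/dx| = -softplus(-x).
    return -softplus(-x)

x = np.array([-5.0, 0.0, 3.0, 50.0])
y = softplus(x)
print(np.allclose(softplus_inverse(y), x))  # True
```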

StickBreakingTransform

class StickBreakingTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._Simplex object>
log_abs_det_jacobian(x, y, intermediates=None)[source]
forward_shape(shape)[source]

Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

inverse_shape(shape)[source]

Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
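As a sketch of the standard stick-breaking construction (assumed here, since the entry above has no prose description), each coordinate of a length K-1 vector decides through a sigmoid what fraction of the remaining stick to break off, and the last of the K entries takes what is left. The log(K-1-k) offset, which centres x = 0 on the uniform point of the simplex, is also an assumption about the parametrization; the function is hypothetical and handles a single vector only.

```python
import numpy as np

def stick_breaking(x):
    """Hypothetical: map a real vector of length K-1 to a point on the K-simplex."""
    x = np.asarray(x, dtype=float)
    n = x.shape[-1]
    # Assumed offset so that x = 0 lands on the uniform point (1/K, ..., 1/K).
    z = 1.0 / (1.0 + np.exp(-(x - np.log(n - np.arange(n)))))  # fraction of remaining stick
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - z)])   # stick left before each break
    return np.concatenate([z, [1.0]]) * remaining

y = stick_breaking(np.zeros(3))
print(y)  # four entries, each 0.25
```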

Flows

InverseAutoregressiveTransform

class InverseAutoregressiveTransform(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]

Bases: numpyro.distributions.transforms.Transform

An implementation of Inverse Autoregressive Flow, using Eq. (10) from Kingma et al. (2016):

\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\)

where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).

References

  1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._IndependentConstraint object>
call_with_intermediates(x)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

Calculates the log absolute determinant of the Jacobian, one value per batch element.

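Parameters:
  • x (numpy.ndarray) – the input to the transform.
  • y (numpy.ndarray) – the output of the transform.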
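A NumPy sketch of Eq. (10) with a toy autoregressive "network": two strictly lower-triangular linear maps stand in for the autoregressive network that produces \(\mu_t\) and \(\log\sigma_t\) (an assumption; a real flow would use a masked neural network passed as autoregressive_nn). The log-scale is clipped to the log_scale_min_clip and log_scale_max_clip defaults from the signature, which keeps \(\sigma_t\) positive and bounded. Because output i depends on x only through x_i and earlier entries, the Jacobian is triangular with diagonal \(\sigma_t\), so its log absolute determinant is the sum of the log-scales.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4
# Strictly lower-triangular weights: output i only sees inputs x_0..x_{i-1} (autoregressive).
W_mu = np.tril(rng.normal(size=(D, D)), k=-1)
W_ls = np.tril(rng.normal(size=(D, D)), k=-1)

def iaf_forward(x, log_scale_min_clip=-5.0, log_scale_max_clip=3.0):
    mu = W_mu @ x
    log_scale = np.clip(W_ls @ x, log_scale_min_clip, log_scale_max_clip)
    y = mu + np.exp(log_scale) * x  # y = mu_t + sigma_t * x
    return y, np.sum(log_scale)     # log|det J| = sum of log sigma_t

x = rng.normal(size=D)
y, ldj = iaf_forward(x)
print(y.shape, float(ldj))
```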

BlockNeuralAutoregressiveTransform

class BlockNeuralAutoregressiveTransform(bn_arn)[source]

Bases: numpyro.distributions.transforms.Transform

An implementation of Block Neural Autoregressive flow.

References

  1. Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._IndependentConstraint object>
call_with_intermediates(x)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

Calculates the log absolute determinant of the Jacobian, one value per batch element.

Parameters: