Base Distribution

Distribution

class Distribution(batch_shape=(), event_shape=(), validate_args=None)[source]

Bases: object

Base class for probability distributions in NumPyro. The design largely follows torch.distributions.

Parameters:
  • batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
  • event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
  • validate_args – Whether to enable validation of distribution parameters and of arguments to the .log_prob method.

As an example:

>>> import jax.numpy as np
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(np.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)
arg_constraints = {}
support = None
reparametrized_params = []
batch_shape

Returns the shape over which the distribution parameters are batched.

Returns: batch shape of the distribution.
Return type: tuple
event_shape

Returns the shape of a single sample from the distribution without batching.

Returns: event shape of the distribution.
Return type: tuple
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the rng key to be used for the distribution.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns:

a numpy.ndarray of shape sample_shape + batch_shape + event_shape
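The shape rule above amounts to tuple concatenation. A minimal sketch with a hypothetical helper (not part of NumPyro), using the Dirichlet shapes from the earlier example:

```python
def sample_result_shape(sample_shape, batch_shape, event_shape):
    """Shape of the array returned by ``sample`` (hypothetical helper)."""
    # Leading dims hold iid draws, followed by the batch dims, then the event dims.
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


# Dirichlet(np.ones((2, 3, 4))) has batch_shape (2, 3) and event_shape (4,).
print(sample_result_shape((10,), (2, 3), (4,)))  # (10, 2, 3, 4)
```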

log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

Parameters: value – A batch of samples from the distribution.
Returns: a numpy.ndarray with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with the trailing event dimensions removed.
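Because log_prob reduces over the event dimensions, its output shape is the value's shape with the trailing len(event_shape) dimensions dropped. A hypothetical pure-Python helper illustrating the rule (it ignores broadcasting between value and batch_shape):

```python
def log_prob_result_shape(value_shape, event_shape):
    """Shape of log_prob(value) for a value of shape value_shape (hypothetical helper)."""
    value_shape, event_shape = tuple(value_shape), tuple(event_shape)
    split = len(value_shape) - len(event_shape)
    # The trailing dimensions of value must match event_shape.
    if split < 0 or value_shape[split:] != event_shape:
        raise ValueError("trailing dimensions of value must equal event_shape")
    # Event dims are reduced away; sample and batch dims remain.
    # Slicing with [:split] (not [:-len(event_shape)]) also handles event_shape == ().
    return value_shape[:split]


print(log_prob_result_shape((10, 2, 3, 4), (4,)))  # (10, 2, 3)
print(log_prob_result_shape((10, 2), ()))          # (10, 2)
```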
mean

Mean of the distribution.

variance

Variance of the distribution.

TransformedDistribution

class TransformedDistribution(base_distribution, transforms, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Returns a distribution instance obtained by applying a sequence of transforms to a base distribution. For examples, see LogNormal and HalfNormal.

Parameters:
  • base_distribution – the base distribution over which to apply transforms.
  • transforms – a single transform or a list of transforms.
  • validate_args – Whether to enable validation of distribution parameters and of arguments to the .log_prob method.
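Applying a transform y = f(x) to a base distribution changes the log density by a log-Jacobian term: log p_Y(y) = log p_X(f^{-1}(y)) - log|f'(f^{-1}(y))|. A minimal scalar sketch of that change of variables, with exp as the transform and a Normal base (which gives a LogNormal); the helper names are hypothetical and this is not NumPyro's implementation:

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def transformed_log_prob(y, base_log_prob, inverse, log_abs_det_jacobian):
    """Log density of y = f(x) with x ~ base (scalar case)."""
    x = inverse(y)
    # Change of variables: subtract log|df/dx| evaluated at x = f^{-1}(y).
    return base_log_prob(x) - log_abs_det_jacobian(x)


def lognormal_log_prob(y, loc=0.0, scale=1.0):
    # f(x) = exp(x): the inverse is log, and log|df/dx| = log(exp(x)) = x.
    return transformed_log_prob(
        y,
        lambda x: normal_log_prob(x, loc, scale),
        math.log,
        lambda x: x,
    )


print(lognormal_log_prob(2.0, loc=0.3, scale=1.5))
```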
arg_constraints = {}
support
is_reparametrized
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean
variance

Continuous Distributions

Beta

class Beta(concentration1, concentration0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()
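Beta has no prose description here. Assuming concentration1 and concentration0 play the roles of the usual alpha and beta (the torch.distributions.Beta convention, which this module largely follows), the textbook density on (0, 1), mean and variance can be sketched in pure Python (hypothetical helper names, not NumPyro's code):

```python
import math


def beta_log_prob(x, concentration1, concentration0):
    # log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)
    log_beta_fn = (math.lgamma(concentration1) + math.lgamma(concentration0)
                   - math.lgamma(concentration1 + concentration0))
    return ((concentration1 - 1) * math.log(x)
            + (concentration0 - 1) * math.log1p(-x) - log_beta_fn)


def beta_mean(concentration1, concentration0):
    return concentration1 / (concentration1 + concentration0)


def beta_variance(concentration1, concentration0):
    total = concentration1 + concentration0
    return concentration1 * concentration0 / (total ** 2 * (total + 1))
```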

Cauchy

class Cauchy(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Chi2

class Chi2(df, validate_args=None)[source]

Bases: numpyro.distributions.continuous.Gamma

arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>}
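Chi2 subclasses Gamma because a chi-squared distribution with df degrees of freedom is a Gamma distribution with concentration df / 2 and rate 1 / 2. A pure-Python sketch that checks this identity numerically with textbook densities (hypothetical helper names):

```python
import math


def gamma_log_prob(x, concentration, rate):
    # Gamma density in the concentration/rate parameterization.
    return (concentration * math.log(rate) + (concentration - 1) * math.log(x)
            - rate * x - math.lgamma(concentration))


def chi2_log_prob(x, df):
    # Textbook chi-squared density with df degrees of freedom.
    k = df / 2
    return (k - 1) * math.log(x) - x / 2 - k * math.log(2) - math.lgamma(k)


print(chi2_log_prob(1.0, 3.0), gamma_log_prob(1.0, 1.5, 0.5))  # equal values
```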

Dirichlet

class Dirichlet(concentration, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Simplex object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Exponential

class Exponential(rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['rate']
arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Gamma

class Gamma(concentration, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

GaussianRandomWalk

class GaussianRandomWalk(scale=1.0, num_steps=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'num_steps': <numpyro.distributions.constraints._IntegerGreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()
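Assuming, as the name suggests, that a GaussianRandomWalk sample is a path of num_steps values whose increments are iid Normal(0, scale), with the walk starting from an implicit 0, sampling is a cumulative sum and the log density is the sum of the increments' Normal log densities. A pure-Python sketch under that assumption (hypothetical helper names, not NumPyro's implementation):

```python
import math
import random


def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def sample_random_walk(rng, scale, num_steps):
    """Cumulative sum of num_steps iid Normal(0, scale) increments."""
    path, position = [], 0.0
    for _ in range(num_steps):
        position += rng.gauss(0.0, scale)
        path.append(position)
    return path


def random_walk_log_prob(path, scale):
    """Sum of increment log densities, assuming the walk starts at 0."""
    total, previous = 0.0, 0.0
    for value in path:
        total += normal_log_prob(value, previous, scale)
        previous = value
    return total


rng = random.Random(0)
path = sample_random_walk(rng, scale=0.5, num_steps=5)
print(len(path), random_walk_log_prob(path, scale=0.5))
```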

HalfCauchy

class HalfCauchy(scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

reparametrized_params = ['scale']
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

HalfNormal

class HalfNormal(scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

reparametrized_params = ['scale']
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()
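HalfNormal is the distribution of |X| for X ~ Normal(0, scale), so on x >= 0 its density is twice the Normal density. A pure-Python sketch of the textbook log density, mean and variance (hypothetical helper names, not NumPyro's code):

```python
import math


def half_normal_log_prob(x, scale=1.0):
    if x < 0:
        return -math.inf  # outside the support
    # log(2) plus the Normal(0, scale) log density.
    return (math.log(2.0) - 0.5 * (x / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))


def half_normal_mean(scale=1.0):
    return scale * math.sqrt(2.0 / math.pi)


def half_normal_variance(scale=1.0):
    # E[X^2] = scale^2, so Var = scale^2 - mean^2.
    return scale ** 2 * (1.0 - 2.0 / math.pi)
```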

LKJCholesky

class LKJCholesky(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, the distribution over correlation matrices is uniform.

When concentration > 1, the distribution favors samples with large diagonal entries (hence a large determinant). This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small diagonal entries (hence a small determinant). This is useful when we know a priori that some underlying variables are correlated.
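Since M = L L^T for a lower-triangular L with positive diagonal, det(M) is the product of the squared diagonal entries of L, so the factor det(M)^(eta - 1) can be computed from the Cholesky diagonal alone. A pure-Python sketch of this unnormalized log density over correlation matrices (a hypothetical helper; NumPyro's log_prob is a density over Cholesky factors, so it also includes a change-of-variables term and a normalizing constant, both omitted here):

```python
import math


def lkj_unnormalized_log_density(cholesky, concentration):
    """(eta - 1) * log det(M) for M = L @ L.T, with L lower triangular."""
    log_det_m = 2.0 * sum(math.log(cholesky[i][i]) for i in range(len(cholesky)))
    return (concentration - 1.0) * log_det_m


# 2x2 correlation matrix with off-diagonal rho: L = [[1, 0], [rho, sqrt(1 - rho^2)]].
rho = 0.6
chol = [[1.0, 0.0], [rho, math.sqrt(1.0 - rho ** 2)]]
identity = [[1.0, 0.0], [0.0, 1.0]]
# With eta = 2 the uncorrelated identity scores higher; with eta = 0.5 the correlated matrix does.
print(lkj_unnormalized_log_density(chol, 2.0), lkj_unnormalized_log_density(identity, 2.0))
```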

Parameters:
  • dimension (int) – dimension of the matrices
  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._CorrCholesky object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

LogNormal

class LogNormal(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['loc', 'scale']
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()
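If log(Y) ~ Normal(loc, scale), the textbook moments are mean = exp(loc + scale^2 / 2) and variance = (exp(scale^2) - 1) * exp(2 loc + scale^2). A pure-Python sketch (hypothetical helper names) with a Monte Carlo sanity check using the standard library's random.lognormvariate:

```python
import math
import random


def lognormal_mean(loc=0.0, scale=1.0):
    return math.exp(loc + scale ** 2 / 2)


def lognormal_variance(loc=0.0, scale=1.0):
    return (math.exp(scale ** 2) - 1) * math.exp(2 * loc + scale ** 2)


rng = random.Random(0)
draws = [rng.lognormvariate(0.1, 0.5) for _ in range(100_000)]
# The empirical mean should be close to the closed form.
print(sum(draws) / len(draws), lognormal_mean(0.1, 0.5))
```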

Normal

class Normal(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Pareto

class Pareto(alpha, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support
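Pareto is declared as a TransformedDistribution. One standard construction, assumed here (check the source for the exact transforms NumPyro uses), is X = scale * exp(E) with E ~ Exponential(alpha), which gives P(X > x) = (scale / x)^alpha for x >= scale. A pure-Python sketch of sampling through that construction, plus the textbook mean and variance, which are finite only for alpha > 1 and alpha > 2 respectively (hypothetical helper names):

```python
import math
import random


def sample_pareto(rng, alpha, scale=1.0):
    # E ~ Exponential(rate=alpha), so exp(E) ~ Pareto(scale=1, alpha).
    return scale * math.exp(rng.expovariate(alpha))


def pareto_mean(alpha, scale=1.0):
    return alpha * scale / (alpha - 1) if alpha > 1 else math.inf


def pareto_variance(alpha, scale=1.0):
    if alpha <= 2:
        return math.inf
    return scale ** 2 * alpha / ((alpha - 1) ** 2 * (alpha - 2))


rng = random.Random(0)
draws = [sample_pareto(rng, alpha=3.0, scale=1.0) for _ in range(100_000)]
# Empirical P(X > 2) versus the exact tail (1 / 2) ** 3.
print(sum(d > 2.0 for d in draws) / len(draws))
```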

StudentT

class StudentT(df, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

TruncatedCauchy

class TruncatedCauchy(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

TruncatedNormal

class TruncatedNormal(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support
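TruncatedNormal takes only a lower bound, low. Assuming low is on the same scale as the value (not standardized), the left-truncated density renormalizes the Normal density by the mass above low: log p(x) = log N(x; loc, scale) - log(1 - Phi((low - loc) / scale)) for x > low. A pure-Python sketch (hypothetical helper name) that uses math.erfc for the upper-tail probability:

```python
import math


def truncated_normal_log_prob(x, low=0.0, loc=0.0, scale=1.0):
    if x < low:
        return -math.inf  # outside the support (low, inf)
    z = (x - loc) / scale
    normal_lp = -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)
    # Mass of Normal(loc, scale) above low: 1 - Phi(a) = erfc(a / sqrt(2)) / 2.
    a = (low - loc) / scale
    return normal_lp - math.log(0.5 * math.erfc(a / math.sqrt(2.0)))
```

With the defaults (low = loc = 0) this reduces to the HalfNormal density.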

Uniform

class Uniform(low=0.0, high=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}
reparametrized_params = ['low', 'high']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support
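The Dependent constraints on low and high reflect that the valid range of each bound depends on the other; the textbook formulas below assume low < high. A pure-Python sketch of the constant log density on [low, high], the mean and the variance (hypothetical helper names, not NumPyro's code):

```python
import math


def uniform_log_prob(x, low=0.0, high=1.0):
    # Constant density 1 / (high - low) inside the interval, zero outside.
    if low <= x <= high:
        return -math.log(high - low)
    return -math.inf


def uniform_mean(low=0.0, high=1.0):
    return (low + high) / 2


def uniform_variance(low=0.0, high=1.0):
    return (high - low) ** 2 / 12
```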

Discrete Distributions

Bernoulli

Bernoulli(probs=None, logits=None, validate_args=None)[source]
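Unlike the surrounding entries, Bernoulli is listed without "class". Judging by its signature and the BernoulliProbs and BernoulliLogits classes that follow, it selects one of those based on which argument is supplied; Binomial, Categorical and Multinomial have the same signature pattern. A hypothetical pure-Python sketch of that dispatch, with stand-in classes rather than NumPyro's own and with the missing-argument error as an assumption:

```python
import math
from dataclasses import dataclass


@dataclass
class BernoulliProbsSketch:
    probs: float


@dataclass
class BernoulliLogitsSketch:
    logits: float

    @property
    def probs(self):
        # probs = sigmoid(logits)
        return 1.0 / (1.0 + math.exp(-self.logits))


def bernoulli_sketch(probs=None, logits=None):
    """Hypothetical stand-in for the Bernoulli factory."""
    if probs is not None:
        return BernoulliProbsSketch(probs)
    if logits is not None:
        return BernoulliLogitsSketch(logits)
    raise ValueError("One of `probs` or `logits` must be specified.")
```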

BernoulliLogits

class BernoulliLogits(logits=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Boolean object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

BernoulliProbs

class BernoulliProbs(probs, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}
support = <numpyro.distributions.constraints._Boolean object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Binomial

Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

BinomialLogits

class BinomialLogits(logits, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

BinomialProbs

class BinomialProbs(probs, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

Categorical

Categorical(probs=None, logits=None, validate_args=None)[source]

CategoricalLogits

class CategoricalLogits(logits, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support
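CategoricalLogits exposes a probs attribute. Logits map to probabilities through a softmax over the last dimension, so adding the same constant to every logit leaves the distribution unchanged. A pure-Python sketch for a single vector of logits (hypothetical helper names), subtracting the maximum for numerical stability:

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_log_prob(value, logits):
    # log softmax(logits)[value] = logits[value] - logsumexp(logits)
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return logits[value] - log_norm


print(softmax([1.0, 2.0, 3.0]))
```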

CategoricalProbs

class CategoricalProbs(probs, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

Multinomial

Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

MultinomialLogits

class MultinomialLogits(logits, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

MultinomialProbs

class MultinomialProbs(probs, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

Poisson

class Poisson(rate, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(value)[source]

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()
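For a count k >= 0, the Poisson log probability mass is k log(rate) - rate - log(k!), and both the mean and the variance equal rate. A pure-Python sketch (hypothetical helper name) using math.lgamma for log(k!):

```python
import math


def poisson_log_prob(k, rate):
    # log P(K = k) = k * log(rate) - rate - log(k!)
    return k * math.log(rate) - rate - math.lgamma(k + 1)


probs = [math.exp(poisson_log_prob(k, 4.0)) for k in range(100)]
print(sum(probs), sum(k * p for k, p in enumerate(probs)))  # about 1 and about 4
```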