Base Distribution

Distribution

class Distribution(batch_shape=(), event_shape=(), validate_args=None)[source]

Bases: object

Base class for probability distributions in NumPyro. The design largely follows that of torch.distributions.

Parameters:
  • batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
  • event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
  • validate_args – Whether to enable validation of distribution parameters and arguments to the .log_prob method.

As an example:

>>> import jax.numpy as np
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(np.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)
arg_constraints = {}
support = None
reparametrized_params = []
static set_default_validate_args(value)[source]
batch_shape

Returns the shape over which the distribution parameters are batched.

Returns: batch shape of the distribution.
Return type: tuple
event_shape

Returns the shape of a single sample from the distribution without batching.

Returns: event shape of the distribution.
Return type: tuple
sample(key, sample_shape=())[source]

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray
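The shape rule for sample can be illustrated with plain NumPy, using NumPy's `Generator.dirichlet` as a stand-in for NumPyro's sampler (this is not NumPyro code). It only works here because every batch entry in the Dirichlet(np.ones((2, 3, 4))) example above shares the same concentration vector np.ones(4).

```python
import numpy as np

rng = np.random.default_rng(0)

sample_shape = (5,)
batch_shape = (2, 3)   # as in dist.Dirichlet(np.ones((2, 3, 4)))
event_shape = (4,)

# NumPy's dirichlet takes a single 1-D concentration vector; the `size`
# argument supplies the leading sample_shape + batch_shape dimensions.
x = rng.dirichlet(np.ones(4), size=sample_shape + batch_shape)

# Resulting layout: sample_shape + batch_shape + event_shape
print(x.shape)  # (5, 2, 3, 4)
# x[i] is the i-th iid draw of the whole (2, 3) batch of 4-dimensional events,
# and every event (the last axis) lies on the simplex.
print(np.allclose(x.sum(axis=-1), 1.0))
```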

sample_with_intermediates(key, sample_shape=())[source]

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
  • sample_shape (tuple) – the sample shape for the distribution.
Returns: a tuple (sample, intermediates), where sample is an array of shape sample_shape + batch_shape + event_shape and intermediates holds any intermediate values computed while sampling
Return type: tuple

transform_with_intermediates(base_value)[source]
log_prob(value)[source]

Evaluates the log probability density for a batch of samples given by value.

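Parameters: value – A batch of samples from the distribution.
Returns: an array with shape value.shape[:value.ndim - len(self.event_shape)], i.e. the shape of value with its rightmost len(self.event_shape) dimensions removed
Return type: numpy.ndarray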
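The following NumPy sketch shows how the event dimensions are collapsed by log_prob. Both helpers, `log_prob_shape` and `diag_normal_log_prob`, are hypothetical illustrations rather than NumPyro functions; the second treats the last axis as a length-k event of independent normals and sums over it.

```python
import math
import numpy as np

def log_prob_shape(value_shape, event_shape):
    """Hypothetical helper: shape of log_prob(value) after collapsing the event
    dims. Slicing to ndim - len(event_shape) also handles event_shape == (),
    where a slice such as [:-0] would wrongly return an empty tuple."""
    return tuple(value_shape[: len(value_shape) - len(event_shape)])

def diag_normal_log_prob(value, loc, scale):
    """Independent normals along the last axis, summed over that axis."""
    z = (value - loc) / scale
    elementwise = -0.5 * z ** 2 - np.log(scale) - 0.5 * math.log(2 * math.pi)
    return elementwise.sum(axis=-1)  # collapse the single event dimension

value = np.zeros((5, 2, 3, 4))  # sample_shape + batch_shape + event_shape
lp = diag_normal_log_prob(value, loc=0.0, scale=1.0)
print(lp.shape)                                # (5, 2, 3)
print(log_prob_shape(value.shape, (4,)))       # (5, 2, 3)
```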
mean

Mean of the distribution.

variance

Variance of the distribution.

to_event(reinterpreted_batch_ndims=None)[source]

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters: reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims.
Returns: An instance of Independent distribution.
Return type: Independent
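The shape bookkeeping behind to_event amounts to moving the batch/event boundary to the left. The helper below is a hypothetical pure-Python sketch of that bookkeeping only; it does not model the reinterpreted_batch_ndims=None default or any change to log_prob.

```python
def reinterpret_shapes(batch_shape, event_shape, reinterpreted_batch_ndims):
    """Hypothetical helper: move the rightmost `reinterpreted_batch_ndims`
    batch dims into the event shape, as to_event / Independent do."""
    if not 0 <= reinterpreted_batch_ndims <= len(batch_shape):
        raise ValueError("reinterpreted_batch_ndims must be in [0, len(batch_shape)]")
    split = len(batch_shape) - reinterpreted_batch_ndims
    new_batch_shape = tuple(batch_shape[:split])
    new_event_shape = tuple(batch_shape[split:]) + tuple(event_shape)
    return new_batch_shape, new_event_shape

# Dirichlet(np.ones((2, 3, 4))) has batch_shape (2, 3) and event_shape (4,).
print(reinterpret_shapes((2, 3), (4,), 1))  # ((2,), (3, 4))
# Normal(np.zeros(3), np.ones(3)) reinterpreted with 1 dim.
print(reinterpret_shapes((3,), (), 1))      # ((), (3,))
```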

Independent

class Independent(base_dist, reinterpreted_batch_ndims, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.

From a practical standpoint, this is useful for changing the shape of the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:

>>> import jax.numpy as np
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(np.zeros(3), np.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
Parameters:
  • base_dist (numpyro.distributions.Distribution) – a distribution instance.
  • reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.
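To see why the diagonal-Normal interpretation above holds, the NumPy sketch below (not NumPyro code) checks that summing independent univariate Normal log densities over the reinterpreted dimension, which is what Independent(normal, 1) does to log_prob, equals the log density of a multivariate Normal with covariance diag(scale**2). The helper names are hypothetical.

```python
import math
import numpy as np

def normal_log_prob(x, loc, scale):
    # Elementwise univariate Normal log density.
    z = (x - loc) / scale
    return -0.5 * z ** 2 - np.log(scale) - 0.5 * math.log(2 * math.pi)

def mvn_log_prob(x, loc, cov):
    # Multivariate Normal log density for a single event vector x.
    k = x.shape[-1]
    diff = x - loc
    _, logdet = np.linalg.slogdet(cov)
    maha = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (maha + logdet + k * math.log(2 * math.pi))

loc = np.zeros(3)
scale = np.array([0.5, 1.0, 2.0])
x = np.array([0.3, -1.2, 2.5])

# Independent(normal, 1): sum the per-element log densities over the last axis.
indep_lp = normal_log_prob(x, loc, scale).sum(axis=-1)
diag_lp = mvn_log_prob(x, loc, np.diag(scale ** 2))
print(np.isclose(indep_lp, diag_lp))
```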
arg_constraints = {}
support
reparameterized_params
mean
variance
sample(key, sample_shape=())[source]
log_prob(value)[source]

TransformedDistribution

class TransformedDistribution(base_distribution, transforms, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Returns a distribution instance obtained by applying a sequence of transforms to a base distribution. For examples, see LogNormal and InverseGamma.

Parameters:
  • base_distribution – the base distribution over which to apply transforms.
  • transforms – a single transform or a list of transforms.
  • validate_args – Whether to enable validation of distribution parameters and arguments to the .log_prob method.
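The log density of a transformed distribution follows the change-of-variables formula: for y = f(x), log p_Y(y) = log p_X(f^{-1}(y)) - log|det J_f(x)|. The NumPy sketch below applies it to f = exp over a Normal base, which gives a LogNormal, and compares against the closed-form LogNormal density. The function `transformed_log_prob` is hypothetical; its `inv` and `log_abs_det_jacobian` arguments merely mirror the Transform method names listed under Transforms.

```python
import math
import numpy as np

def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

def transformed_log_prob(y, base_log_prob, inv, log_abs_det_jacobian):
    # Change of variables for y = f(x): log p_Y(y) = log p_X(x) - log|f'(x)|.
    x = inv(y)
    return base_log_prob(x) - log_abs_det_jacobian(x, y)

loc, scale = 0.5, 0.8
y = np.array([0.2, 1.0, 3.0])

# f = exp, so f^{-1} = log and log|f'(x)| = log(exp(x)) = x.
lp = transformed_log_prob(
    y,
    base_log_prob=lambda x: normal_log_prob(x, loc, scale),
    inv=np.log,
    log_abs_det_jacobian=lambda x, y: x,
)

# Closed-form LogNormal log density, for comparison.
closed = (-np.log(y * scale * math.sqrt(2 * math.pi))
          - (np.log(y) - loc) ** 2 / (2 * scale ** 2))
print(np.allclose(lp, closed))
```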
arg_constraints = {}
support
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

sample_with_intermediates(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample_with_intermediates()

transform_with_intermediates(base_value)[source]
log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean
variance

Unit

class Unit(log_factor, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Trivial nonnormalized distribution representing the unit type.

The unit type has a single value with no data, i.e. value.size == 0.

This is used for numpyro.factor() statements.

arg_constraints = {'log_factor': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Real object>
sample(key, sample_shape=())[source]
log_prob(value)[source]

Continuous Distributions

Beta

class Beta(concentration1, concentration0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Interval object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Cauchy

class Cauchy(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Chi2

class Chi2(df, validate_args=None)[source]

Bases: numpyro.distributions.continuous.Gamma

arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>}

Dirichlet

class Dirichlet(concentration, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Simplex object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Exponential

class Exponential(rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['rate']
arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Gamma

class Gamma(concentration, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
reparametrized_params = ['rate']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

GaussianRandomWalk

class GaussianRandomWalk(scale=1.0, num_steps=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'num_steps': <numpyro.distributions.constraints._IntegerGreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._RealVector object>
reparametrized_params = ['scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

HalfCauchy

class HalfCauchy(scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['scale']
support = <numpyro.distributions.constraints._GreaterThan object>
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

HalfNormal

class HalfNormal(scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

reparametrized_params = ['scale']
support = <numpyro.distributions.constraints._GreaterThan object>
arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

InverseGamma

class InverseGamma(concentration, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._GreaterThan object>
reparametrized_params = ['rate']
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

LKJ

class LKJ(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of a correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Consequently, when concentration == 1, the distribution is uniform over correlation matrices.

When concentration > 1, the distribution favors samples with a large determinant. This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with a small determinant. This is useful when we know a priori that some underlying variables are correlated.
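The effect of concentration can be checked on a 2 x 2 correlation matrix with off-diagonal rho, whose determinant is 1 - rho**2. The NumPy sketch below evaluates only the unnormalized log density (eta - 1) * log det(M); the normalizing constant, which depends only on eta and the dimension, is omitted, and the helper names are hypothetical.

```python
import numpy as np

def lkj_unnormalized_log_density(corr, eta):
    """(eta - 1) * log det(M) for a correlation matrix M (no normalizer)."""
    sign, logdet = np.linalg.slogdet(corr)
    if sign <= 0:
        raise ValueError("corr must be positive definite")
    return (eta - 1.0) * logdet

def corr2(rho):
    # 2 x 2 correlation matrix with det(M) = 1 - rho**2.
    return np.array([[1.0, rho], [rho, 1.0]])

for eta in (0.5, 1.0, 4.0):
    uncorrelated = lkj_unnormalized_log_density(corr2(0.0), eta)  # det = 1
    correlated = lkj_unnormalized_log_density(corr2(0.9), eta)    # det = 0.19
    # eta = 1: equal (uniform); eta > 1: uncorrelated wins; eta < 1: correlated wins.
    print(eta, uncorrelated, correlated)
```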

Parameters:
  • dimension (int) – dimension of the matrices
  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._CorrMatrix object>
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

LKJCholesky

class LKJCholesky(dimension, concentration=1.0, sample_method='onion', validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Consequently, when concentration == 1, the distribution is uniform over Cholesky factors of correlation matrices.

When concentration > 1, the distribution favors samples with large diagonal entries (hence a large determinant). This is useful when we know a priori that the underlying variables are not correlated.

When concentration < 1, the distribution favors samples with small diagonal entries (hence a small determinant). This is useful when we know a priori that some underlying variables are correlated.
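The link between diagonal entries and the determinant is that, for M = L @ L.T with L lower triangular, log det(M) = 2 * sum(log(diag(L))). Because every row of a correlation Cholesky factor has unit Euclidean norm, each diagonal entry is at most 1, and larger diagonals mean a larger det(M). The NumPy sketch below checks these facts; it evaluates the density of M only, whereas the density over L itself additionally involves a change-of-variables (Jacobian) factor that this sketch omits.

```python
import numpy as np

def log_det_from_cholesky(L):
    # For M = L @ L.T with L lower triangular: log det(M) = 2 * sum(log(diag(L))).
    return 2.0 * np.sum(np.log(np.diag(L)))

corr = np.array([[1.0, 0.3, 0.1],
                 [0.3, 1.0, -0.4],
                 [0.1, -0.4, 1.0]])
L = np.linalg.cholesky(corr)  # lower-triangular Cholesky factor

# diag(M) == 1 implies each row of L has unit Euclidean norm.
print(np.allclose(np.linalg.norm(L, axis=1), 1.0))
print(np.isclose(log_det_from_cholesky(L), np.linalg.slogdet(corr)[1]))

eta = 2.0
# Unnormalized log density of M, written in terms of its Cholesky factor.
print((eta - 1.0) * log_det_from_cholesky(L))
```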

Parameters:
  • dimension (int) – dimension of the matrices
  • concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
  • sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and yield the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.

References

[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._CorrCholesky object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

LogNormal

class LogNormal(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['loc', 'scale']
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

MultivariateNormal

class MultivariateNormal(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'covariance_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'loc': <numpyro.distributions.constraints._RealVector object>, 'precision_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'scale_tril': <numpyro.distributions.constraints._LowerCholesky object>}
support = <numpyro.distributions.constraints._RealVector object>
reparametrized_params = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

covariance_matrix[source]
precision_matrix[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

LowRankMultivariateNormal

class LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'cov_diag': <numpyro.distributions.constraints._GreaterThan object>, 'cov_factor': <numpyro.distributions.constraints._Real object>, 'loc': <numpyro.distributions.constraints._RealVector object>}
support = <numpyro.distributions.constraints._RealVector object>
mean

See numpyro.distributions.distribution.Distribution.mean()

variance[source]

See numpyro.distributions.distribution.Distribution.variance()

scale_tril[source]
covariance_matrix[source]
precision_matrix[source]
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

entropy()[source]
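The entry above lists only attributes. Assuming the usual low-rank-plus-diagonal parameterization, in which the covariance is cov_factor @ cov_factor.T + diag(cov_diag) (the same convention as PyTorch's LowRankMultivariateNormal), the NumPy sketch below builds the covariance and obtains the precision through the Woodbury identity. This only needs a (rank x rank) solve, which is what makes the low-rank form cheap when rank is much smaller than the dimension. It is an illustration, not NumPyro's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank = 5, 2
cov_factor = rng.standard_normal((dim, rank))  # W, shape (dim, rank)
cov_diag = rng.uniform(0.5, 2.0, size=dim)     # D, positive diagonal

# Assumed parameterization: covariance = W @ W.T + diag(D)
covariance = cov_factor @ cov_factor.T + np.diag(cov_diag)

# Woodbury: (D + W W^T)^{-1} = D^{-1} - D^{-1} W (I + W^T D^{-1} W)^{-1} W^T D^{-1}
d_inv = 1.0 / cov_diag
capacitance = np.eye(rank) + (cov_factor.T * d_inv) @ cov_factor  # I + W^T D^{-1} W
w_d = cov_factor * d_inv[:, None]                                 # D^{-1} W
precision = np.diag(d_inv) - w_d @ np.linalg.solve(capacitance, w_d.T)

print(np.allclose(precision, np.linalg.inv(covariance)))
```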

Normal

class Normal(loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

icdf(q)[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Pareto

class Pareto(alpha, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

StudentT

class StudentT(df, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._Real object>
reparametrized_params = ['loc', 'scale']
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

TruncatedCauchy

class TruncatedCauchy(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

TruncatedNormal

class TruncatedNormal(low=0.0, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['low', 'loc', 'scale']
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Uniform

class Uniform(low=0.0, high=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.TransformedDistribution

arg_constraints = {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}
reparametrized_params = ['low', 'high']
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

Discrete Distributions

Bernoulli

Bernoulli(probs=None, logits=None, validate_args=None)[source]

BernoulliLogits

class BernoulliLogits(logits=None, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Boolean object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

BernoulliProbs

class BernoulliProbs(probs, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}
support = <numpyro.distributions.constraints._Boolean object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

BetaBinomial

class BetaBinomial(concentration1, concentration0, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and is randomly drawn from a Beta distribution prior to a number of Bernoulli trials given by total_count.

Parameters:
  • concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
  • concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
  • total_count (numpy.ndarray) – number of Bernoulli trials.
arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]
log_prob(*args, **kwargs)
mean
variance
support
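Marginalizing the Beta-distributed success probability out of a Binomial gives the pmf C(n, k) * B(k + alpha, n - k + beta) / B(alpha, beta), with alpha = concentration1 and beta = concentration0 as in the parameter list above. The plain-Python sketch below (using `math.lgamma`, not NumPyro) evaluates it and checks that it sums to 1 with mean n * alpha / (alpha + beta).

```python
import math

def log_beta(a, b):
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_binomial_log_pmf(k, concentration1, concentration0, total_count):
    # log[ C(n, k) * B(k + alpha, n - k + beta) / B(alpha, beta) ]
    n, a, b = total_count, concentration1, concentration0
    log_binom = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_binom + log_beta(k + a, n - k + b) - log_beta(a, b)

n = 10
pmf = [math.exp(beta_binomial_log_pmf(k, 2.0, 3.0, n)) for k in range(n + 1)]
mean = sum(k * p for k, p in enumerate(pmf))
# Probabilities sum to 1; mean is n * alpha / (alpha + beta) = 10 * 2 / 5.
print(sum(pmf), mean)
```

With concentration1 = concentration0 = 1 the Beta prior is uniform, and the pmf becomes uniform over 0, ..., total_count.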

Binomial

Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

BinomialLogits

class BinomialLogits(logits, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

BinomialProbs

class BinomialProbs(probs, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

Categorical

Categorical(probs=None, logits=None, validate_args=None)[source]

CategoricalLogits

class CategoricalLogits(logits, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

CategoricalProbs

class CategoricalProbs(probs, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

Delta

class Delta(value=0.0, log_density=0.0, event_ndim=0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'log_density': <numpyro.distributions.constraints._Real object>, 'value': <numpyro.distributions.constraints._Real object>}
support = <numpyro.distributions.constraints._Real object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

GammaPoisson

class GammaPoisson(concentration, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and randomly drawn from a Gamma distribution.

Parameters:
  • concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
  • rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
sample(key, sample_shape=())[source]
log_prob(*args, **kwargs)
mean
variance
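Integrating Poisson(k | lam) against Gamma(lam | concentration, rate) yields a negative binomial pmf: Gamma(k + alpha) / (Gamma(alpha) k!) * (beta / (1 + beta))**alpha * (1 / (1 + beta))**k, with mean alpha / beta and variance alpha / beta + alpha / beta**2. The plain-Python sketch below evaluates that closed form; it is an illustration, not NumPyro's implementation.

```python
import math

def gamma_poisson_log_pmf(k, concentration, rate):
    # Negative binomial pmf from the gamma-Poisson mixture.
    a, b = concentration, rate
    return (math.lgamma(k + a) - math.lgamma(a) - math.lgamma(k + 1)
            + a * math.log(b / (1.0 + b)) - k * math.log(1.0 + b))

a, b = 3.0, 0.5
ks = range(200)  # enough terms for the tail to be negligible for these values
pmf = [math.exp(gamma_poisson_log_pmf(k, a, b)) for k in ks]
mean = sum(k * p for k, p in zip(ks, pmf))
var = sum((k - mean) ** 2 * p for k, p in zip(ks, pmf))
# Expected: mean = a / b = 6.0 and variance = a / b + a / b**2 = 18.0,
# i.e. overdispersed relative to a Poisson with the same mean.
print(mean, var)
```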

Multinomial

Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

MultinomialLogits

class MultinomialLogits(logits, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

probs[source]
mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

MultinomialProbs

class MultinomialProbs(probs, total_count=1, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

support

OrderedLogistic

class OrderedLogistic(predictor, cutpoints, validate_args=None)[source]

Bases: numpyro.distributions.discrete.CategoricalProbs

A categorical distribution with ordered outcomes.

References:

  1. Stan Functions Reference, v2.20 section 12.6, Stan Development Team
Parameters:
  • predictor (numpy.ndarray) – prediction in the real domain; typically the output of a linear model.
  • cutpoints (numpy.ndarray) – positions in the real domain that separate the categories.
arg_constraints = {'cutpoints': <numpyro.distributions.constraints._OrderedVector object>, 'predictor': <numpyro.distributions.constraints._Real object>}
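Following the Stan formulation cited above, the cumulative probability of the first k + 1 categories is sigmoid(cutpoints[k] - predictor), and each category probability is the difference of consecutive cumulative probabilities, giving one more category than there are cutpoints. The NumPy sketch below assumes categories are indexed 0, ..., K - 1 (as for the CategoricalProbs base); `ordered_logistic_probs` is a hypothetical helper, not NumPyro's code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def ordered_logistic_probs(predictor, cutpoints):
    # P(y <= k) = sigmoid(cutpoints[k] - predictor), padded with P = 0 below
    # the first category and P = 1 at the last one.
    cdf = sigmoid(np.asarray(cutpoints, dtype=float) - predictor)
    cdf = np.concatenate([[0.0], cdf, [1.0]])
    return np.diff(cdf)  # probabilities of categories 0, ..., K - 1

cutpoints = [-1.0, 0.5, 2.0]  # must be increasing (ordered_vector constraint)
probs = ordered_logistic_probs(predictor=0.4, cutpoints=cutpoints)
print(probs)  # 4 categories; a larger predictor shifts mass to higher ones
```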

Poisson

class Poisson(rate, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
sample(key, sample_shape=())[source]

See numpyro.distributions.distribution.Distribution.sample()

log_prob(*args, **kwargs)

See numpyro.distributions.distribution.Distribution.log_prob()

mean

See numpyro.distributions.distribution.Distribution.mean()

variance

See numpyro.distributions.distribution.Distribution.variance()

PRNGIdentity

class PRNGIdentity[source]

Bases: numpyro.distributions.distribution.Distribution

Distribution over PRNGKey(). This can be used to draw a batch of PRNG keys using the seed handler. Only the sample method is supported.

sample(key, sample_shape=())[source]

ZeroInflatedPoisson

class ZeroInflatedPoisson(gate, rate=1.0, validate_args=None)[source]

Bases: numpyro.distributions.distribution.Distribution

A Zero Inflated Poisson distribution.

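Parameters:
  • gate (numpy.ndarray) – probability of an extra zero, drawn independently of the Poisson component.
  • rate (numpy.ndarray) – rate parameter of the Poisson distribution.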
arg_constraints = {'gate': <numpyro.distributions.constraints._Interval object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}
support = <numpyro.distributions.constraints._IntegerGreaterThan object>
sample(key, sample_shape=())[source]
log_prob(*args, **kwargs)
mean[source]
variance[source]
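Assuming gate is the probability of an extra (structural) zero, the pmf mixes a point mass at 0 with a Poisson: P(0) = gate + (1 - gate) * exp(-rate) and P(k) = (1 - gate) * Poisson(k | rate) for k > 0, so the mean is (1 - gate) * rate. The plain-Python sketch below is an illustration of that mixture, not NumPyro's implementation.

```python
import math

def zip_log_pmf(k, gate, rate):
    # Assumed semantics: with probability `gate` emit 0, otherwise draw Poisson(rate).
    if k == 0:
        return math.log(gate + (1.0 - gate) * math.exp(-rate))
    log_poisson = k * math.log(rate) - rate - math.lgamma(k + 1)
    return math.log1p(-gate) + log_poisson

gate, rate = 0.3, 2.5
pmf = [math.exp(zip_log_pmf(k, gate, rate)) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
# Expected: the pmf sums to 1 and mean == (1 - gate) * rate = 1.75
print(sum(pmf), mean)
```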

Constraints

boolean

boolean = <numpyro.distributions.constraints._Boolean object>

corr_cholesky

corr_cholesky = <numpyro.distributions.constraints._CorrCholesky object>

corr_matrix

corr_matrix = <numpyro.distributions.constraints._CorrMatrix object>

dependent

dependent = <numpyro.distributions.constraints._Dependent object>

greater_than

greater_than(lower_bound)

integer_interval

integer_interval(lower_bound, upper_bound)

integer_greater_than

integer_greater_than(lower_bound)

interval

interval(lower_bound, upper_bound)

lower_cholesky

lower_cholesky = <numpyro.distributions.constraints._LowerCholesky object>

multinomial

multinomial(upper_bound)

nonnegative_integer

nonnegative_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>

ordered_vector

ordered_vector = <numpyro.distributions.constraints._OrderedVector object>

positive

positive = <numpyro.distributions.constraints._GreaterThan object>

positive_definite

positive_definite = <numpyro.distributions.constraints._PositiveDefinite object>

positive_integer

positive_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>

real

real = <numpyro.distributions.constraints._Real object>

real_vector

real_vector = <numpyro.distributions.constraints._RealVector object>

simplex

simplex = <numpyro.distributions.constraints._Simplex object>

unit_interval

unit_interval = <numpyro.distributions.constraints._Interval object>
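The constraints above are listed without definitions. As a rough guide to what a few of them require, here are plain-NumPy membership checks; the function names and tolerances are hypothetical, and NumPyro's own checks may differ in detail (for example in tolerance or strictness).

```python
import numpy as np

def is_simplex(x, atol=1e-6):
    # non-negative entries that sum to one
    x = np.asarray(x, dtype=float)
    return bool(np.all(x >= 0) and np.isclose(x.sum(), 1.0, atol=atol))

def is_ordered_vector(x):
    # strictly increasing entries (strictness is an assumption of this sketch)
    return bool(np.all(np.diff(x) > 0))

def is_lower_cholesky(m):
    # zero above the diagonal, positive on the diagonal
    m = np.asarray(m, dtype=float)
    return bool(np.all(np.triu(m, k=1) == 0) and np.all(np.diag(m) > 0))

def is_corr_cholesky(m, atol=1e-6):
    # a lower Cholesky factor whose rows have unit Euclidean norm,
    # so that m @ m.T has ones on its diagonal
    m = np.asarray(m, dtype=float)
    return is_lower_cholesky(m) and bool(
        np.allclose(np.linalg.norm(m, axis=-1), 1.0, atol=atol))
```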

Transforms

biject_to

biject_to(constraint)

Transform

class Transform[source]

Bases: object

domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._Real object>
event_dim = 0
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]

AbsTransform

class AbsTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._GreaterThan object>
inv(y)[source]

AffineTransform

class AffineTransform(loc, scale, domain=<numpyro.distributions.constraints._Real object>)[source]

Bases: numpyro.distributions.transforms.Transform

codomain
event_dim
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
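AffineTransform carries no description here. The sketch below assumes the usual elementwise map \(y = loc + scale \cdot x\) on the default real domain, whose Jacobian is diagonal with entries `scale`; the helper names are hypothetical.

```python
import numpy as np

def affine_forward(x, loc, scale):
    return loc + scale * x

def affine_inverse(y, loc, scale):
    return (y - loc) / scale

def affine_log_abs_det_jacobian(x, scale):
    # dy/dx = scale elementwise, so log|dy/dx| = log|scale|, broadcast against x
    return np.log(np.abs(scale)) + np.zeros_like(x, dtype=float)
```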

ComposeTransform

class ComposeTransform(parts)[source]

Bases: numpyro.distributions.transforms.Transform

domain
codomain
event_dim
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
call_with_intermediates(x)[source]
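ComposeTransform chains its parts. A minimal sketch, assuming each part is represented as a hypothetical (forward, inverse, log_abs_det_jacobian) triple rather than a NumPyro Transform object: the forward pass applies the parts in order, the inverse applies their inverses in reverse order, and by the chain rule the parts' log-Jacobians are summed, each evaluated at the intermediate value it receives.

```python
import numpy as np

# Hypothetical stand-ins for an affine map y = 1 + 2x and for exp.
affine = (lambda x: 1.0 + 2.0 * x,
          lambda y: (y - 1.0) / 2.0,
          lambda x, y: np.log(2.0) + np.zeros_like(x))
exp = (np.exp, np.log, lambda x, y: x)  # log|d exp(x)/dx| = x

def compose_forward(parts, x):
    for forward, _, _ in parts:
        x = forward(x)
    return x

def compose_inverse(parts, y):
    for _, inverse, _ in reversed(parts):
        y = inverse(y)
    return y

def compose_log_abs_det_jacobian(parts, x):
    total = 0.0
    for forward, _, log_abs_det_jacobian in parts:
        y = forward(x)
        total = total + log_abs_det_jacobian(x, y)  # chain rule: log-Jacobians add
        x = y
    return total
```

With `parts = [affine, exp]` the composite is \(\exp(1 + 2x)\), whose log-Jacobian is \(\log 2 + 1 + 2x\).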

CorrCholeskyTransform

class CorrCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transforms an unconstrained real vector \(x\) of length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with a positive diagonal and unit Euclidean norm in each row. The transform proceeds as follows:

  1. First, we convert \(x\) into a lower triangular matrix with the following order:
     \[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]

  2. For each row \(X_i\) of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform \(X_i\) into a vector of unit Euclidean length, using the following steps:

       1. Scale into the interval \((-1, 1)\): \(r_i = \tanh(X_i)\).
       2. Transform into an unsigned domain: \(z_i = r_i^2\).
       3. Apply \(s_i = StickBreakingTransform(z_i)\).
       4. Transform back into the signed domain: \(y_i = (\mathrm{sign}(r_i), 1) * \sqrt{s_i}\).
domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._CorrCholesky object>
event_dim = 2
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
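The steps above can be sketched in plain NumPy. The hypothetical helper `corr_cholesky_forward` below uses explicit loops for readability and folds steps 2.2 to 2.4 into a running "remaining squared length" per row. It is not NumPyro's vectorized implementation, and it takes \(D\) as an argument rather than inferring it from the length of \(x\).

```python
import numpy as np

def corr_cholesky_forward(x, D):
    """Map an unconstrained vector of length D*(D-1)/2 to the Cholesky factor
    of a D x D correlation matrix (hypothetical sketch)."""
    x = np.asarray(x, dtype=float)
    assert x.shape == (D * (D - 1) // 2,)
    # Steps 1 and 2.1: fill the strictly lower triangle row by row with tanh(x).
    r = np.zeros((D, D))
    r[np.tril_indices(D, k=-1)] = np.tanh(x)
    L = np.zeros((D, D))
    for i in range(D):
        remaining = 1.0  # squared length of this row's "stick" not yet used
        for j in range(i):
            # z = r**2 is the fraction of the remaining squared length given to
            # entry j; the entry keeps the sign of r.
            L[i, j] = r[i, j] * np.sqrt(remaining)
            remaining *= 1.0 - r[i, j] ** 2
        # The diagonal takes the rest, so each row has unit Euclidean norm.
        L[i, i] = np.sqrt(remaining)
    return L
```

Because every row has unit norm, `L @ L.T` has a unit diagonal, which is what makes it a correlation matrix.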

ExpTransform

class ExpTransform(domain=<numpyro.distributions.constraints._Real object>)[source]

Bases: numpyro.distributions.transforms.Transform

codomain
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

IdentityTransform

class IdentityTransform(event_dim=0)[source]

Bases: numpyro.distributions.transforms.Transform

inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

InvCholeskyTransform

class InvCholeskyTransform(domain=<numpyro.distributions.constraints._LowerCholesky object>)[source]

Bases: numpyro.distributions.transforms.Transform

Transform via the mapping \(y = x @ x.T\), where x is a lower triangular matrix with positive diagonal.

event_dim = 2
codomain
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
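A sketch of this map and its inverse in plain NumPy, for a single matrix. The log-Jacobian uses the standard result for \(L \mapsto L L^T\) taken with respect to the lower-triangular entries of both matrices, \(\log|\det J| = D\log 2 + \sum_{i=1}^{D} (D - i + 1)\log L_{ii}\); treating this as exactly what NumPyro's `log_abs_det_jacobian` returns is an assumption, and the helper names are hypothetical.

```python
import numpy as np

def inv_cholesky_forward(L):
    # y = x @ x.T for a lower triangular x with positive diagonal
    return L @ L.T

def inv_cholesky_inverse(y):
    # the Cholesky factorization recovers the unique such lower triangular factor
    return np.linalg.cholesky(y)

def inv_cholesky_log_abs_det_jacobian(L):
    D = L.shape[-1]
    # exponents D, D-1, ..., 1 for diagonal entries L_00, L_11, ..., L_{D-1,D-1}
    exponents = D - np.arange(D)
    return D * np.log(2.0) + np.sum(exponents * np.log(np.diag(L)))
```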

LowerCholeskyTransform

class LowerCholeskyTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._LowerCholesky object>
event_dim = 2
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

MultivariateAffineTransform

class MultivariateAffineTransform(loc, scale_tril)[source]

Bases: numpyro.distributions.transforms.Transform

Transform via the mapping \(y = loc + scale\_tril\ @\ x\).

Parameters:
  • loc – a real vector.
  • scale_tril – a lower triangular matrix with positive diagonal.
domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._RealVector object>
event_dim = 1
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
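A minimal NumPy sketch of this map for a single vector (no batching), with hypothetical helper names. Because `scale_tril` is lower triangular, the Jacobian determinant is the product of its diagonal. `np.linalg.solve` is used for the inverse to stay dependency-free, although a dedicated triangular solve would exploit the structure.

```python
import numpy as np

def mvn_affine_forward(x, loc, scale_tril):
    return loc + scale_tril @ x

def mvn_affine_inverse(y, loc, scale_tril):
    # solve scale_tril @ x = y - loc
    return np.linalg.solve(scale_tril, y - loc)

def mvn_affine_log_abs_det_jacobian(scale_tril):
    # det of a triangular matrix is the product of its (positive) diagonal
    return np.sum(np.log(np.diag(scale_tril)))
```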

OrderedTransform

class OrderedTransform[source]

Bases: numpyro.distributions.transforms.Transform

Transform a real vector to an ordered vector.

References:

  1. Stan Reference Manual v2.20, section 10.6, Stan Development Team
domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._OrderedVector object>
event_dim = 1
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
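Following the construction in the Stan reference cited above, \(y_1 = x_1\) and \(y_k = y_{k-1} + \exp(x_k)\) for \(k > 1\). The NumPy sketch below (hypothetical helper names, single vector) assumes NumPyro uses the same construction. The Jacobian is lower triangular with diagonal \((1, e^{x_2}, \ldots, e^{x_K})\), so its log-determinant is \(\sum_{k>1} x_k\).

```python
import numpy as np

def ordered_forward(x):
    x = np.asarray(x, dtype=float)
    # first entry unchanged, then strictly positive increments exp(x_k)
    return np.cumsum(np.concatenate([x[:1], np.exp(x[1:])]))

def ordered_inverse(y):
    y = np.asarray(y, dtype=float)
    # undo the cumulative sum, then take logs of the positive increments
    return np.concatenate([y[:1], np.log(np.diff(y))])

def ordered_log_abs_det_jacobian(x):
    return np.sum(np.asarray(x, dtype=float)[1:])
```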

PermuteTransform

class PermuteTransform(permutation)[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._RealVector object>
event_dim = 1
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
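PermuteTransform has no description here. Assuming it reorders the last axis so that `y[..., i] = x[..., permutation[i]]`, a sketch looks like this (hypothetical helper names). The inverse indexes with `argsort(permutation)`, and the log-Jacobian is zero because a permutation matrix has determinant ±1.

```python
import numpy as np

def permute_forward(x, permutation):
    return np.asarray(x)[..., permutation]

def permute_inverse(y, permutation):
    # argsort(permutation) is the inverse permutation
    return np.asarray(y)[..., np.argsort(permutation)]

def permute_log_abs_det_jacobian(x):
    # one zero per batch element (the last axis is the event dimension)
    return np.zeros(np.shape(x)[:-1])
```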

PowerTransform

class PowerTransform(exponent)[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._GreaterThan object>
codomain = <numpyro.distributions.constraints._GreaterThan object>
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]

SigmoidTransform

class SigmoidTransform[source]

Bases: numpyro.distributions.transforms.Transform

codomain = <numpyro.distributions.constraints._Interval object>
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
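SigmoidTransform maps the real line into the unit interval. A NumPy sketch with hypothetical helper names, assuming the logistic function \(y = 1/(1 + e^{-x})\): its derivative is \(y(1 - y)\), so the log-Jacobian is \(-\mathrm{softplus}(-x) - \mathrm{softplus}(x)\), computed here with `np.logaddexp` for numerical stability.

```python
import numpy as np

def sigmoid_forward(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_inverse(y):
    # logit(y) = log(y) - log(1 - y)
    return np.log(y) - np.log1p(-y)

def sigmoid_log_abs_det_jacobian(x):
    # log y + log(1 - y), with softplus(t) = logaddexp(0, t)
    return -np.logaddexp(0.0, -x) - np.logaddexp(0.0, x)
```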

StickBreakingTransform

class StickBreakingTransform[source]

Bases: numpyro.distributions.transforms.Transform

domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._Simplex object>
event_dim = 1
inv(y)[source]
log_abs_det_jacobian(x, y, intermediates=None)[source]
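A NumPy sketch of a stick-breaking map from \(\mathbb{R}^K\) to the simplex in \(\mathbb{R}^{K+1}\), for a single unbatched vector. Each coordinate, passed through a sigmoid, decides what fraction of the remaining stick to break off, and the last simplex coordinate is whatever is left. The offset \(\log(K - k)\), which sends \(x = 0\) to the uniform point, follows the balanced variant used in Pyro; whether this version of NumPyro applies it is an assumption, as are the helper names.

```python
import numpy as np

def stick_breaking_forward(x):
    x = np.asarray(x, dtype=float)
    K = x.shape[-1]
    # offset so that x = 0 maps to the uniform point (assumed, as in Pyro)
    x = x - np.log(K - np.arange(K))
    z = 1.0 / (1.0 + np.exp(-x))  # fraction of the remaining stick broken off
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - z)])  # stick left before each break
    return np.concatenate([z, [1.0]]) * remaining

def stick_breaking_inverse(y):
    y = np.asarray(y, dtype=float)
    K = y.shape[-1] - 1
    # stick left before break k is 1 - (sum of the pieces already broken off)
    remaining = 1.0 - np.concatenate([[0.0], np.cumsum(y[:-1])])[:K]
    z = y[:K] / remaining
    return np.log(z) - np.log1p(-z) + np.log(K - np.arange(K))
```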

Flows

InverseAutoregressiveTransform

class InverseAutoregressiveTransform(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]

Bases: numpyro.distributions.transforms.Transform

An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,

\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\)

where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).

References

  1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
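A minimal NumPy sketch of the equation above for a single vector. The `toy_autoregressive_nn` below is a hypothetical stand-in for `autoregressive_nn` (not a NumPyro API): a masked linear map whose outputs at position \(t\) depend only on \(\mathbf{x}_{<t}\). Clipping the log-scale to [log_scale_min_clip, log_scale_max_clip] is an assumption based on the constructor arguments. The forward pass is parallel, while the inverse must recover \(\mathbf{x}\) one coordinate at a time.

```python
import numpy as np

def toy_autoregressive_nn(x, weights_mu, bias_mu, weights_ls, bias_ls):
    """Hypothetical stand-in for `autoregressive_nn`: strictly lower triangular
    weights make mu_t and log_scale_t depend only on x[:t]."""
    mask = np.tril(np.ones((x.shape[-1], x.shape[-1])), k=-1)
    mu = (weights_mu * mask) @ x + bias_mu
    log_scale = (weights_ls * mask) @ x + bias_ls
    return mu, log_scale

def iaf_forward(x, params, log_scale_min_clip=-5.0, log_scale_max_clip=3.0):
    mu, log_scale = toy_autoregressive_nn(x, *params)
    log_scale = np.clip(log_scale, log_scale_min_clip, log_scale_max_clip)
    y = mu + np.exp(log_scale) * x  # y = mu_t + sigma_t * x, all t in parallel
    # dy/dx is lower triangular with diagonal sigma_t, so log|det| = sum(log sigma_t)
    return y, np.sum(log_scale)

def iaf_inverse(y, params, log_scale_min_clip=-5.0, log_scale_max_clip=3.0):
    x = np.zeros_like(y)
    for t in range(y.shape[-1]):
        # x[:t] is already known, and mu_t, sigma_t depend on nothing else
        mu, log_scale = toy_autoregressive_nn(x, *params)
        log_scale = np.clip(log_scale, log_scale_min_clip, log_scale_max_clip)
        x[t] = (y[t] - mu[t]) * np.exp(-log_scale[t])
    return x
```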
domain = <numpyro.distributions.constraints._RealVector object>
codomain = <numpyro.distributions.constraints._RealVector object>
event_dim = 1
call_with_intermediates(x)[source]
inv(y)[source]
Parameters: y (numpy.ndarray) – the output of the transform to be inverted
log_abs_det_jacobian(x, y, intermediates=None)[source]

Calculates the log absolute determinant of the Jacobian, \(\log|\det(\partial\mathbf{y}/\partial\mathbf{x})|\), for each batch element.

Parameters: