Distributions¶
Base Distribution¶
Distribution¶

class
Distribution
(batch_shape=(), event_shape=(), *, validate_args=None)[source]¶ Bases:
object
Base class for probability distributions in NumPyro. The design largely follows torch.distributions.
Parameters:  batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
 event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
 validate_args – Whether to enable validation of distribution parameters and of arguments to the .log_prob method.
As an example:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)

arg_constraints
= {}¶

support
= None¶

has_enumerate_support
= False¶

reparametrized_params
= []¶

batch_shape
¶ Returns the shape over which the distribution parameters are batched.
Returns: batch shape of the distribution. Return type: tuple

event_shape
¶ Returns the shape of a single sample from the distribution without batching.
Returns: event shape of the distribution. Return type: tuple

has_rsample
¶

shape
(sample_shape=())[source]¶ The tensor shape of samples from this distribution.
Samples are of shape:
d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters: sample_shape (tuple) – the size of the iid batch to be drawn from the distribution. Returns: shape of samples. Return type: tuple

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample, except that any intermediate computations are also returned (useful for TransformedDistribution).
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: a pair (sample, intermediates), where sample is an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

to_event
(reinterpreted_batch_ndims=None)[source]¶ Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
Parameters: reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims. Returns: An instance of Independent distribution. Return type: numpyro.distributions.distribution.Independent

enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

expand
(batch_shape)[source]¶ Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters: batch_shape (tuple) – batch shape to expand to. Returns: an instance of ExpandedDistribution. Return type: ExpandedDistribution

expand_by
(sample_shape)[source]¶ Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.
Parameters: sample_shape (tuple) – The size of the iid batch to be drawn from the distribution. Returns: An expanded version of this distribution. Return type: ExpandedDistribution

mask
(mask)[source]¶ Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.
Parameters: mask (bool or jnp.ndarray) – A boolean or boolean-valued array (True includes a site, False excludes a site). Returns: A masked copy of this distribution. Return type: MaskedDistribution
Example:
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO
>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)
>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])

classmethod
infer_shapes
(*args, **kwargs)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, which should lie in [0, 1]. Returns: the values whose CDF equals q.

is_discrete
¶
ExpandedDistribution¶

class
ExpandedDistribution
(base_dist, batch_shape=())[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {}¶

has_enumerate_support
¶

has_rsample
¶

support
¶

sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample, except that any intermediate computations are also returned (useful for TransformedDistribution).
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: a pair (sample, intermediates), where sample is an array of shape sample_shape + batch_shape + event_shape
Return type:

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(value, intermediates=None)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

FoldedDistribution¶

class
FoldedDistribution
(base_dist, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
Equivalent to TransformedDistribution(base_dist, AbsTransform()), but additionally supports log_prob().
Parameters: base_dist (Distribution) – A univariate distribution to reflect.
support
= Positive(lower_bound=0.0)¶

log_prob
(*args, **kwargs)¶

ImproperUniform¶

class
ImproperUniform
(support, batch_shape, event_shape, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
A helper distribution with zero log_prob() over the support domain.
Note
The sample method is not implemented for this distribution. In autoguide and MCMC, initial parameters for improper sites are derived from the init_to_uniform or init_to_value strategies.
Usage:
>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))
If you want to set an improper prior over all values greater than a, where a is another random variable, you might use
>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))
or, if you want to reparameterize it:
>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))
Parameters:  support (Constraint) – the support of this distribution.
 batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=().
 event_shape (tuple) – event shape of this distribution.

arg_constraints
= {}¶

support
= Dependent()¶

log_prob
(*args, **kwargs)¶
Independent¶

class
Independent
(base_dist, reinterpreted_batch_ndims, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.
From a practical standpoint, this is useful when changing the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
Parameters:  base_dist (numpyro.distributions.Distribution) – a distribution instance.
 reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.

arg_constraints
= {}¶

support
¶

has_enumerate_support
¶

reparametrized_params
¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

has_rsample
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

expand
(batch_shape)[source]¶ Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters: batch_shape (tuple) – batch shape to expand to. Returns: an instance of ExpandedDistribution. Return type: ExpandedDistribution
MaskedDistribution¶

class
MaskedDistribution
(base_dist, mask)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Masks a distribution by a boolean array that is broadcastable to the distribution’s Distribution.batch_shape. In the special case mask is False, computation of log_prob() is skipped, and constant zero values are returned instead.
Parameters: mask (jnp.ndarray or bool) – A boolean or boolean-valued array.
arg_constraints
= {}¶

has_enumerate_support
¶

has_rsample
¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

enumerate_support
(expand=True)[source]¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

TransformedDistribution¶

class
TransformedDistribution
(base_distribution, transforms, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For examples, see LogNormal and HalfNormal.
Parameters:  base_distribution – the base distribution over which to apply transforms.
 transforms – a single transform or a list of transforms.
 validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.

arg_constraints
= {}¶

has_rsample
¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

sample_with_intermediates
(key, sample_shape=())[source]¶ Same as sample, except that any intermediate computations are also returned (useful for TransformedDistribution).
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: a pair (sample, intermediates), where sample is an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
Delta¶

class
Delta
(v=0.0, log_density=0.0, event_dim=0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'log_density': Real(), 'v': Dependent()}¶

reparametrized_params
= ['v', 'log_density']¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Unit¶

class
Unit
(log_factor, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Trivial non-normalized distribution representing the unit type.
The unit type has a single value with no data, i.e. value.size == 0.
This is used for numpyro.factor() statements.
arg_constraints
= {'log_factor': Real()}¶

support
= Real()¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

Continuous Distributions¶
AsymmetricLaplace¶

class
AsymmetricLaplace
(loc=0.0, scale=1.0, asymmetry=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'asymmetry': Positive(lower_bound=0.0), 'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['loc', 'scale', 'asymmetry']¶

support
= Real()¶

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

AsymmetricLaplaceQuantile¶

class
AsymmetricLaplaceQuantile
(loc=0.0, scale=1.0, quantile=0.5, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
An alternative parameterization of AsymmetricLaplace commonly applied in Bayesian quantile regression.
Instead of the asymmetry parameter used by AsymmetricLaplace to define the balance between the left- and right-hand sides of the distribution, this class uses a quantile parameter, which describes the proportion of probability mass that falls to the left-hand side of the distribution (that is, to the left of loc).
The scale parameter is also interpreted slightly differently than in AsymmetricLaplace. When loc=0 and scale=1, AsymmetricLaplace(0,1,1) is equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is equivalent to Laplace(0,2).

arg_constraints
= {'loc': Real(), 'quantile': OpenInterval(lower_bound=0.0, upper_bound=1.0), 'scale': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['loc', 'scale', 'quantile']¶

support
= Real()¶

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array of shape value.shape[:value.ndim - len(self.event_shape)], i.e. value.shape with the trailing event dimensions removed. Return type: numpy.ndarray

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Beta¶

class
Beta
(concentration1, concentration0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['concentration1', 'concentration0']¶

support
= UnitInterval(lower_bound=0.0, upper_bound=1.0)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

BetaProportion¶

class
BetaProportion
(mean, concentration, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.continuous.Beta
The BetaProportion distribution is a reparameterization of the conventional Beta distribution in terms of the variate mean and a precision parameter.
 Reference:
 Beta regression for modelling rates and proportions, Ferrari, Silvia, and
 Francisco Cribari-Neto. Journal of Applied Statistics 31.7 (2004): 799-815.

arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'mean': OpenInterval(lower_bound=0.0, upper_bound=1.0)}¶

reparametrized_params
= ['mean', 'concentration']¶

support
= UnitInterval(lower_bound=0.0, upper_bound=1.0)¶
CAR¶

class
CAR
(loc, correlation, conditional_precision, adj_matrix, *, is_sparse=False, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
The Conditional Autoregressive (CAR) distribution is a special case of the multivariate normal in which the precision matrix is structured according to the adjacency matrix of sites. The amount of autocorrelation between sites is controlled by correlation. The distribution is a popular prior for areal spatial data.
Parameters:  loc (float or ndarray) – mean of the multivariate normal
 correlation (float) – autoregression parameter. In most cases, the value should lie between 0 (sites are independent, collapses to an iid multivariate normal) and 1 (perfect autocorrelation between sites), but the specification allows for negative correlations.
 conditional_precision (float) – positive precision for the multivariate normal
 adj_matrix (ndarray or scipy.sparse.csr_matrix) – symmetric adjacency matrix where 1 indicates adjacency between sites and 0 otherwise. A jax.numpy.ndarray adj_matrix is supported but is not recommended over numpy.ndarray or scipy.sparse.spmatrix.
 is_sparse (bool) – whether to use a sparse form of adj_matrix in calculations (must be True if adj_matrix is a scipy.sparse.spmatrix)

arg_constraints
= {'adj_matrix': Dependent(), 'conditional_precision': Positive(lower_bound=0.0), 'correlation': OpenInterval(lower_bound=-1, upper_bound=1), 'loc': RealVector(Real(), 1)}¶

support
= RealVector(Real(), 1)¶

reparametrized_params
= ['loc', 'correlation', 'conditional_precision', 'adj_matrix']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

static
infer_shapes
(loc, correlation, conditional_precision, adj_matrix)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple
Cauchy¶

class
Cauchy
(loc=0.0, scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Real()¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Chi2¶

class
Chi2
(df, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.continuous.Gamma

arg_constraints
= {'df': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['df']¶

Dirichlet¶

class
Dirichlet
(concentration, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration': IndependentConstraint(Positive(lower_bound=0.0), 1)}¶

reparametrized_params
= ['concentration']¶

support
= Simplex()¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

static
infer_shapes
(concentration)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type: tuple

EulerMaruyama¶

class
EulerMaruyama
(t, sde_fn, init_dist, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
The Euler–Maruyama method is a method for the approximate numerical solution of a stochastic differential equation (SDE).
Parameters:  t (ndarray) – discretized time
 sde_fn (callable) – function returning the drift and diffusion coefficients of the SDE
 init_dist (Distribution) – Distribution for initial values.
References
[1] https://en.wikipedia.org/wiki/Euler–Maruyama_method
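The update rule behind the method, x_{k+1} = x_k + f(x_k, t_k) Δt_k + g(x_k, t_k) √Δt_k ε_k with ε_k ~ N(0, 1), can be sketched in plain JAX. This is an illustrative stand-alone integrator, not NumPyro's EulerMaruyama class; the function name euler_maruyama and the (x, t) argument order of drift and diffusion are our own choices:

```python
import jax
import jax.numpy as jnp

def euler_maruyama(key, x0, t, drift, diffusion):
    """Hypothetical helper: simulate one path of dx = f(x, t) dt + g(x, t) dW on grid t."""
    x0 = jnp.asarray(x0, dtype=t.dtype)
    dt = jnp.diff(t)
    # One independent standard-normal increment per step (and per state dim).
    eps = jax.random.normal(key, dt.shape + x0.shape, dtype=t.dtype)

    def step(x, inputs):
        t_k, dt_k, eps_k = inputs
        x_next = x + drift(x, t_k) * dt_k + diffusion(x, t_k) * jnp.sqrt(dt_k) * eps_k
        return x_next, x_next

    _, path = jax.lax.scan(step, x0, (t[:-1], dt, eps))
    # Prepend the initial state so path[k] is the state at time t[k].
    return jnp.concatenate([x0[None], path], axis=0)

# Ornstein-Uhlenbeck-style example: mean-reverting drift, constant noise.
t = jnp.linspace(0.0, 1.0, 11)
path = euler_maruyama(jax.random.PRNGKey(0), 1.0, t, lambda x, s: -x, lambda x, s: 0.5)
print(path.shape)  # (11,)
```

With zero diffusion the scheme reduces to the explicit Euler method for the ODE, which gives an easy deterministic check.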

arg_constraints
= {'t': OrderedVector()}¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶
Exponential¶

class
Exponential
(rate=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

reparametrized_params
= ['rate']¶

arg_constraints
= {'rate': Positive(lower_bound=0.0)}¶

support
= Positive(lower_bound=0.0)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution with shape sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, the leading dimensions (of size sample_shape) of the returned sample are filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Gamma¶

class
Gamma
(concentration, rate=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}¶

support
= Positive(lower_bound=0.0)¶

reparametrized_params
= ['concentration', 'rate']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

GaussianCopula¶

class
GaussianCopula
(marginal_dist, correlation_matrix=None, correlation_cholesky=None, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
A distribution that couples the last batch axis of the marginal distribution marginal_dist through a multivariate Gaussian copula, which models the correlation between the entries along that axis.
Parameters:  marginal_dist (Distribution) – Distribution whose last batch axis is to be coupled.
 correlation_matrix (array_like) – Correlation matrix of the coupling multivariate normal distribution.
 correlation_cholesky (array_like) – Correlation Cholesky factor of the coupling multivariate normal distribution.

arg_constraints
= {'correlation_cholesky': CorrCholesky(), 'correlation_matrix': CorrMatrix()}¶

reparametrized_params
= ['correlation_matrix', 'correlation_cholesky']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶
GaussianCopulaBeta¶

class
GaussianCopulaBeta
(concentration1, concentration0, correlation_matrix=None, correlation_cholesky=None, *, validate_args=False)[source]¶ Bases:
numpyro.distributions.copula.GaussianCopula

arg_constraints
= {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0), 'correlation_cholesky': CorrCholesky(), 'correlation_matrix': CorrMatrix()}¶

support
= IndependentConstraint(UnitInterval(lower_bound=0.0, upper_bound=1.0), 1)¶

GaussianRandomWalk¶

class
GaussianRandomWalk
(scale=1.0, num_steps=1, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'scale': Positive(lower_bound=0.0)}¶

support
= RealVector(Real(), 1)¶

reparametrized_params
= ['scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Gompertz¶

class
Gompertz
(concentration, rate=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Gompertz Distribution.
The Gompertz distribution is a distribution with support on the positive real line that is closely related to the Gumbel distribution. This implementation follows the notation used in the Wikipedia entry for the Gompertz distribution. See https://en.wikipedia.org/wiki/Gompertz_distribution.
However, we call the parameter “eta” a concentration parameter and the parameter “b” a rate parameter (as opposed to a scale parameter, as in the Wikipedia description).
The CDF, in terms of concentration (con) and rate, is
\[F(x) = 1 - \exp \left\{ - \text{con} \cdot \left[ \exp\{x \cdot \text{rate}\} - 1 \right] \right\}\]
arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}¶

support
= Positive(lower_bound=0.0)¶

reparametrized_params
= ['concentration', 'rate']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, which should lie in [0, 1]. Returns: the values whose CDF equals q.

mean
¶ Mean of the distribution.

Gumbel¶

class
Gumbel
(loc=0.0, scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Real()¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

HalfCauchy¶

class
HalfCauchy
(scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

reparametrized_params
= ['scale']¶

support
= Positive(lower_bound=0.0)¶

arg_constraints
= {'scale': Positive(lower_bound=0.0)}¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, which should lie in [0, 1]. Returns: the values whose CDF equals q.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

HalfNormal¶

class
HalfNormal
(scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

reparametrized_params
= ['scale']¶

support
= Positive(lower_bound=0.0)¶

arg_constraints
= {'scale': Positive(lower_bound=0.0)}¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, which should lie in [0, 1]. Returns: the values whose CDF equals q.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

InverseGamma¶

class
InverseGamma
(concentration, rate=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
Note
We keep the same notation rate as in Pyro, but it plays the role of the scale parameter of InverseGamma in the literature (e.g. Wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution).

arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['concentration', 'rate']¶

support
= Positive(lower_bound=0.0)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Kumaraswamy¶

class
Kumaraswamy
(concentration1, concentration0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['concentration1', 'concentration0']¶

support
= UnitInterval(lower_bound=0.0, upper_bound=1.0)¶

KL_KUMARASWAMY_BETA_TAYLOR_ORDER
= 10¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Laplace¶

class
Laplace
(loc=0.0, scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Real()¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

LKJ¶

class
LKJ
(dimension, concentration=1.0, sample_method='onion', *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution
LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over correlation matrices.
When concentration > 1, the distribution favors samples with large determinant. This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small determinant. This is useful when we know a priori that some underlying variables are correlated.
Sample code for using LKJ in the context of a multivariate normal sample:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(y):  # y has dimension N x d
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))

    concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
    corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat`
    cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
    return obs
```
Parameters:  dimension (int) – dimension of the matrices
 concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
 sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and produce the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints
= {'concentration': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['concentration']¶

support
= CorrMatrix()¶

mean
¶ Mean of the distribution.
LKJCholesky¶

class
LKJCholesky
(dimension, concentration=1.0, sample_method='onion', *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.
When concentration > 1, the distribution favors samples with large diagonal entries (hence large determinant). This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small diagonal entries (hence small determinant). This is useful when we know a priori that some underlying variables are correlated.
Sample code for using LKJCholesky in the context of a multivariate normal sample:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(y):  # y has dimension N x d
    d = y.shape[1]
    N = y.shape[0]
    # Vector of variances for each of the d variables
    theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
    # Lower Cholesky factor of a correlation matrix
    concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
    L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
    # Lower Cholesky factor of the covariance matrix
    sigma = jnp.sqrt(theta)
    # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
    L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)

    # Vector of expectations
    mu = jnp.zeros(d)

    with numpyro.plate("observations", N):
        obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
    return obs
```
Parameters:  dimension (int) – dimension of the matrices
 concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
 sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and produce the same distribution over correlation matrices; they differ only in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe

arg_constraints
= {'concentration': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['concentration']¶

support
= CorrCholesky()¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶
LogNormal¶

class
LogNormal
(loc=0.0, scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Positive(lower_bound=0.0)¶

reparametrized_params
= ['loc', 'scale']¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

LogUniform¶

class
LogUniform
(low, high, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'high': Positive(lower_bound=0.0), 'low': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['low', 'high']¶

support
¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Logistic¶

class
Logistic
(loc=0.0, scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Real()¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

LowRankMultivariateNormal¶

class
LowRankMultivariateNormal
(loc, cov_factor, cov_diag, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'cov_diag': IndependentConstraint(Positive(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': RealVector(Real(), 1)}¶

support
= RealVector(Real(), 1)¶

reparametrized_params
= ['loc', 'cov_factor', 'cov_diag']¶

mean
¶ Mean of the distribution.

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

static
infer_shapes
(loc, cov_factor, cov_diag)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type:

MatrixNormal¶

class
MatrixNormal
(loc, scale_tril_row, scale_tril_column, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Matrix variate normal distribution as described in [1] but with a lower-triangular parametrization, i.e. \(U = \text{scale\_tril\_row} \, \text{scale\_tril\_row}^{T}\) and \(V = \text{scale\_tril\_column} \, \text{scale\_tril\_column}^{T}\). The distribution is related to the multivariate normal distribution in the following way: if \(X \sim \mathcal{MN}(\text{loc}, U, V)\), then \(\mathrm{vec}(X) \sim \mathcal{MVN}(\mathrm{vec}(\text{loc}), V \otimes U)\), where \(\mathrm{vec}\) stacks the columns of a matrix.
Parameters:  loc (array_like) – Location of the distribution.
 scale_tril_row (array_like) – Lower Cholesky factor of the row covariance matrix \(U\).
 scale_tril_column (array_like) – Lower Cholesky factor of the column covariance matrix \(V\).
References
[1] https://en.wikipedia.org/wiki/Matrix_normal_distribution

arg_constraints
= {'loc': RealVector(Real(), 1), 'scale_tril_column': LowerCholesky(), 'scale_tril_row': LowerCholesky()}¶

support
= RealMatrix(Real(), 2)¶

reparametrized_params
= ['loc', 'scale_tril_row', 'scale_tril_column']¶

mean
¶ Mean of the distribution.

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶
MultivariateNormal¶

class
MultivariateNormal
(loc=0.0, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'covariance_matrix': PositiveDefinite(), 'loc': RealVector(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}¶

support
= RealVector(Real(), 1)¶

reparametrized_params
= ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

static
infer_shapes
(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type:

MultivariateStudentT¶

class
MultivariateStudentT
(df, loc=0.0, scale_tril=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'df': Positive(lower_bound=0.0), 'loc': RealVector(Real(), 1), 'scale_tril': LowerCholesky()}¶

support
= RealVector(Real(), 1)¶

reparametrized_params
= ['df', 'loc', 'scale_tril']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

static
infer_shapes
(df, loc, scale_tril)[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type:

Normal¶

class
Normal
(loc=0.0, scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Real()¶

reparametrized_params
= ['loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(q)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, which should lie in [0, 1]. Returns: the values whose CDF equals q.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Pareto¶

class
Pareto
(scale, alpha, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'alpha': Positive(lower_bound=0.0), 'scale': Positive(lower_bound=0.0)}¶

reparametrized_params
= ['scale', 'alpha']¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

RelaxedBernoulli¶
RelaxedBernoulliLogits¶

class
RelaxedBernoulliLogits
(temperature, logits, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.TransformedDistribution

arg_constraints
= {'logits': Real(), 'temperature': Positive(lower_bound=0.0)}¶

support
= UnitInterval(lower_bound=0.0, upper_bound=1.0)¶

SoftLaplace¶

class
SoftLaplace
(loc, scale, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Smooth distribution with Laplace-like tail behavior.
This distribution corresponds to the log-convex density:
z = (value - loc) / scale
log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z)
Like the Laplace density, this density has the heaviest possible tails (asymptotically) while still being log-convex. Unlike the Laplace distribution, this distribution is infinitely differentiable everywhere, and is thus suitable for HMC and Laplace approximation.
Parameters:  loc – Location parameter.
 scale – Scale parameter.

arg_constraints
= {'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Real()¶

reparametrized_params
= ['loc', 'scale']¶

log_prob
(*args, **kwargs)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(value)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, which should lie in [0, 1]. Returns: the values whose CDF equals q.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
StudentT¶

class
StudentT
(df, loc=0.0, scale=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'df': Positive(lower_bound=0.0), 'loc': Real(), 'scale': Positive(lower_bound=0.0)}¶

support
= Real()¶

reparametrized_params
= ['df', 'loc', 'scale']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

Uniform¶

class
Uniform
(low=0.0, high=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'high': Dependent(), 'low': Dependent()}¶

reparametrized_params
= ['low', 'high']¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is nonempty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the PRNG key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(value)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: q – quantile values, which should lie in [0, 1]. Returns: the values whose CDF equals q.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

static
infer_shapes
(low=(), high=())[source]¶ Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
Return type:

Weibull¶

class
Weibull
(scale, concentration, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'scale': Positive(lower_bound=0.0)}¶

support
= Positive(lower_bound=0.0)¶

reparametrized_params
= ['scale', 'concentration']¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
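For reference, Weibull(scale, concentration) has CDF 1 − exp(−(x / scale)^concentration) for x ≥ 0, mean scale · Γ(1 + 1/concentration) and variance scale² · [Γ(1 + 2/concentration) − Γ(1 + 1/concentration)²]. A minimal stdlib sketch of these textbook formulas (helper names are hypothetical):

```python
import math

def weibull_cdf(x, scale, concentration):
    """P(X <= x) for a Weibull with the given scale and shape (concentration)."""
    if x <= 0:
        return 0.0
    return 1.0 - math.exp(-((x / scale) ** concentration))

def weibull_mean_variance(scale, concentration):
    """Closed-form moments via the gamma function."""
    g1 = math.gamma(1 + 1 / concentration)
    g2 = math.gamma(1 + 2 / concentration)
    return scale * g1, scale**2 * (g2 - g1**2)

# With concentration = 1 the Weibull reduces to an Exponential with mean `scale`.
print(weibull_mean_variance(2.0, 1.0))  # (2.0, 4.0)
```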

Discrete Distributions¶
BernoulliLogits¶

class
BernoulliLogits
(logits=None, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': Real()}¶

support
= Boolean()¶

has_enumerate_support
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

BernoulliProbs¶

class
BernoulliProbs
(probs, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0)}¶

support
= Boolean()¶

has_enumerate_support
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
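BernoulliLogits and BernoulliProbs describe the same family: logits map to probabilities through the logistic sigmoid, probs = 1 / (1 + exp(−logits)), and both have mean probs and variance probs · (1 − probs). A stdlib sketch of the conversion and of a numerically stable log-pmf written directly in terms of logits (function names are hypothetical, not NumPyro API):

```python
import math

def logits_to_probs(logits):
    """Logistic sigmoid: maps a real-valued logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logits))

def softplus(x):
    """log(1 + exp(x)), computed without overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_log_prob_logits(value, logits):
    """log p(value) for value in {0, 1}: value * logits - log(1 + exp(logits))."""
    return value * logits - softplus(logits)

p = logits_to_probs(0.0)
print(p, p * (1 - p))  # 0.5 0.25  (mean and variance)
print(math.exp(bernoulli_log_prob_logits(1, 2.0)), logits_to_probs(2.0))  # both ~0.881
```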

BetaBinomial¶

class
BetaBinomial
(concentration1, concentration0, total_count=1, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and is drawn from a Beta distribution before total_count Bernoulli trials are performed.
Parameters:  concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
 concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
 total_count (numpy.ndarray) – number of Bernoulli trials.
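Integrating the Beta prior out gives the textbook beta-binomial pmf P(k) = C(n, k) · B(k + α, n − k + β) / B(α, β), with α = concentration1, β = concentration0 and n = total_count, whose mean is n · α / (α + β). A stdlib sketch of that pmf using log-gamma for stability (beta_binomial_pmf is a hypothetical helper):

```python
import math

def log_beta(a, b):
    """log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_binomial_pmf(k, concentration1, concentration0, total_count):
    """P(K = k) after marginalizing p ~ Beta(concentration1, concentration0)."""
    if not 0 <= k <= total_count:
        return 0.0
    log_choose = (math.lgamma(total_count + 1) - math.lgamma(k + 1)
                  - math.lgamma(total_count - k + 1))
    return math.exp(log_choose
                    + log_beta(k + concentration1, total_count - k + concentration0)
                    - log_beta(concentration1, concentration0))

a, b, n = 2.0, 3.0, 10
pmf = [beta_binomial_pmf(k, a, b, n) for k in range(n + 1)]
print(sum(pmf))                               # ~1.0
print(sum(k * p for k, p in enumerate(pmf)))  # ~4.0 = n * a / (a + b)
```

With α = β = 1 the Beta prior is uniform and every count in {0, ..., n} gets probability 1 / (n + 1).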

arg_constraints
= {'concentration0': Positive(lower_bound=0.0), 'concentration1': Positive(lower_bound=0.0), 'total_count': IntegerNonnegative(lower_bound=0)}¶

has_enumerate_support
= True¶

enumerate_support
(expand=True)¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶
BinomialLogits¶

class
BinomialLogits
(logits, total_count=1, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': Real(), 'total_count': IntegerNonnegative(lower_bound=0)}¶

has_enumerate_support
= True¶

enumerate_support
(expand=True)¶ Returns an array with shape len(support) x batch_shape containing all values in the support.

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

BinomialProbs¶

class
BinomialProbs
(probs, total_count=1, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerNonnegative(lower_bound=0)}¶

has_enumerate_support
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

CategoricalLogits¶

class
CategoricalLogits
(logits, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': RealVector(Real(), 1)}¶

has_enumerate_support
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

CategoricalProbs¶

class
CategoricalProbs
(probs, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': Simplex()}¶

has_enumerate_support
= True¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶
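CategoricalLogits and CategoricalProbs are related by the softmax, probs = exp(logits) / sum(exp(logits)), so logits are only defined up to an additive constant. A stdlib sketch of a numerically stable log-softmax and the resulting log-pmf (helper names are hypothetical):

```python
import math

def log_softmax(logits):
    """Normalize logits to log-probabilities, subtracting the max for stability."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]

def categorical_log_prob(value, logits):
    """log p(value) for an integer category index `value`."""
    return log_softmax(logits)[value]

logits = [1.0, 2.0, 3.0]
probs = [math.exp(lp) for lp in log_softmax(logits)]
# Shifting every logit by the same constant leaves the probabilities unchanged.
shifted = [math.exp(lp) for lp in log_softmax([x + 100.0 for x in logits])]
print(probs)  # sums to 1; matches `shifted` up to rounding
```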

DirichletMultinomial¶

class
DirichletMultinomial
(concentration, total_count=1, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a Dirichlet-multinomial pair. The class probabilities (probs for the Multinomial distribution) are unknown and are drawn from a Dirichlet distribution before total_count Categorical trials are performed.
Parameters:  concentration (numpy.ndarray) – concentration parameter (alpha) for the Dirichlet distribution.
 total_count (numpy.ndarray) – number of Categorical trials.
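Marginalizing the Dirichlet gives the textbook closed form: for counts x summing to n = total_count, with α = concentration and α₀ = Σ α_i, P(x) = n! · Γ(α₀) / Γ(n + α₀) · Π_i Γ(x_i + α_i) / (x_i! · Γ(α_i)), and the mean of x_i is n · α_i / α₀. A stdlib sketch that enumerates the support for a small case (the helper is hypothetical):

```python
import math
from itertools import product

def dirichlet_multinomial_log_pmf(counts, concentration):
    """log P(counts) with total_count = sum(counts), after integrating out the Dirichlet."""
    n = sum(counts)
    alpha0 = sum(concentration)
    out = math.lgamma(n + 1) + math.lgamma(alpha0) - math.lgamma(n + alpha0)
    for x, a in zip(counts, concentration):
        out += math.lgamma(x + a) - math.lgamma(x + 1) - math.lgamma(a)
    return out

alpha, n = [1.0, 2.0, 3.0], 4
# Enumerate every count vector over 3 categories that sums to n.
support = [c for c in product(range(n + 1), repeat=3) if sum(c) == n]
pmf = {c: math.exp(dirichlet_multinomial_log_pmf(c, alpha)) for c in support}
mean = [sum(c[i] * p for c, p in pmf.items()) for i in range(3)]
print(sum(pmf.values()), mean)  # ~1.0 and ~[n * a / sum(alpha) for a in alpha]
```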

arg_constraints
= {'concentration': IndependentConstraint(Positive(lower_bound=0.0), 1), 'total_count': IntegerNonnegative(lower_bound=0)}¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

static
infer_shapes
(concentration, total_count=())[source]¶ Infers batch_shape and event_shape given the shapes of the arguments to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping the name of each input arg to a tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes. Return type: tuple
DiscreteUniform¶

class
DiscreteUniform
(low=0, high=1, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'high': Dependent(), 'low': Dependent()}¶

has_enumerate_support
= True¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

cdf
(value)[source]¶ The cumulative distribution function of this distribution.
Parameters: value – samples from this distribution. Returns: output of the cumulative distribution function evaluated at value.

icdf
(value)[source]¶ The inverse cumulative distribution function of this distribution.
Parameters: value – quantile values, which should lie in [0, 1]. Returns: the samples whose CDF values equal the given quantiles.

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
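DiscreteUniform(low, high) puts equal mass 1 / (high − low + 1) on each integer in {low, ..., high}, so its mean is (low + high) / 2 and its variance is ((high − low + 1)² − 1) / 12. A stdlib sketch of the CDF and inverse CDF for that textbook definition (helper names are hypothetical; NumPyro's own edge-case handling may differ):

```python
import math

def discrete_uniform_cdf(x, low, high):
    """P(X <= x): number of integers in [low, floor(x)] over the support size."""
    n = high - low + 1
    return min(max((math.floor(x) - low + 1) / n, 0.0), 1.0)

def discrete_uniform_icdf(q, low, high):
    """Smallest integer k in [low, high] with cdf(k) >= q."""
    n = high - low + 1
    return min(max(low + math.ceil(q * n) - 1, low), high)

low, high = 1, 6  # a fair die
support = range(low, high + 1)
mean = sum(support) / len(support)
var = sum((k - mean) ** 2 for k in support) / len(support)
print(mean, var, ((high - low + 1) ** 2 - 1) / 12)  # 3.5, then the same variance twice
```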

GammaPoisson¶

class
GammaPoisson
(concentration, rate=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a gamma-Poisson pair, also referred to as a gamma-Poisson mixture. The rate parameter for the Poisson distribution is unknown and randomly drawn from a Gamma distribution.
Parameters:  concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
 rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
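Integrating out the Poisson rate λ ~ Gamma(α, β), with α = concentration and β = rate, gives a negative-binomial pmf P(k) = Γ(k + α) / (Γ(α) · k!) · (β / (1 + β))^α · (1 / (1 + β))^k. Its mean is α/β and its variance is α/β + α/β², where the extra α/β² term is the overdispersion contributed by the spread of λ. A stdlib sketch that checks those moments numerically (gamma_poisson_log_pmf is a hypothetical helper):

```python
import math

def gamma_poisson_log_pmf(k, concentration, rate):
    """log P(K = k) for K | lam ~ Poisson(lam), lam ~ Gamma(concentration, rate)."""
    a, b = concentration, rate
    return (math.lgamma(k + a) - math.lgamma(a) - math.lgamma(k + 1)
            + a * math.log(b / (1 + b)) - k * math.log1p(b))

a, b = 3.0, 0.5
# Truncate the infinite support; the tail beyond k = 500 is negligible here.
pmf = [math.exp(gamma_poisson_log_pmf(k, a, b)) for k in range(500)]
mean = sum(k * p for k, p in enumerate(pmf))
var = sum((k - mean) ** 2 * p for k, p in enumerate(pmf))
print(mean, var)  # ~6.0 = a / b and ~18.0 = a / b + a / b**2
```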

arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'rate': Positive(lower_bound=0.0)}¶

support
= IntegerNonnegative(lower_bound=0)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
GeometricLogits¶

class
GeometricLogits
(logits, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': Real()}¶

support
= IntegerNonnegative(lower_bound=0)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

GeometricProbs¶

class
GeometricProbs
(probs, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0)}¶

support
= IntegerNonnegative(lower_bound=0)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
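Because the support is the non-negative integers, these classes count the failures before the first success: P(k) = p · (1 − p)^k for k = 0, 1, 2, ..., with mean (1 − p)/p and variance (1 − p)/p², where p = probs (or sigmoid(logits) for GeometricLogits). A stdlib sketch of that pmf and of an inverse-CDF sampler (helper names are hypothetical, and this is not NumPyro's sampler):

```python
import math
import random

def geometric_pmf(k, probs):
    """P(K = k) = probs * (1 - probs)**k: k failures, then the first success."""
    return probs * (1.0 - probs) ** k

def sample_geometric(rng, probs):
    """Inverse-CDF draw: floor(log(U) / log(1 - probs)) with U ~ Uniform(0, 1]."""
    u = 1.0 - rng.random()  # in (0, 1], so log(u) is finite
    return math.floor(math.log(u) / math.log1p(-probs))

p = 0.25
pmf = [geometric_pmf(k, p) for k in range(400)]
print(sum(k * q for k, q in enumerate(pmf)), (1 - p) / p)  # both ~3.0
rng = random.Random(0)
draws = [sample_geometric(rng, p) for _ in range(20000)]
print(sum(draws) / len(draws))  # close to 3.0
```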

Multinomial¶

Multinomial
(total_count=1, probs=None, logits=None, *, total_count_max=None, validate_args=None)[source]¶ Multinomial distribution.
Parameters:  total_count – number of trials. If this is a JAX array, it is required to specify total_count_max.
 probs – event probabilities
 logits – event log probabilities
 total_count_max (int) – the maximum number of trials, i.e. max(total_count)
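For counts x = (x_1, ..., x_K) summing to n = total_count, the multinomial pmf is n! / (x_1! · ... · x_K!) · Π_i p_i^{x_i}. A stdlib sketch of that formula in log space (the helper is hypothetical and assumes every entry of probs is strictly positive):

```python
import math

def multinomial_log_pmf(counts, probs):
    """log P(counts) with total_count = sum(counts); assumes all probs > 0."""
    n = sum(counts)
    out = math.lgamma(n + 1)
    for x, p in zip(counts, probs):
        out += x * math.log(p) - math.lgamma(x + 1)
    return out

# Two categories reduce to a binomial: C(5, 2) * 0.3**2 * 0.7**3.
print(math.exp(multinomial_log_pmf([2, 3], [0.3, 0.7])))
print(math.comb(5, 2) * 0.3**2 * 0.7**3)
```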
MultinomialLogits¶

class
MultinomialLogits
(logits, total_count=1, *, total_count_max=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'logits': RealVector(Real(), 1), 'total_count': IntegerNonnegative(lower_bound=0)}¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

static
infer_shapes
(logits, total_count)[source]¶ Infers batch_shape and event_shape given the shapes of the arguments to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping the name of each input arg to a tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes. Return type: tuple

MultinomialProbs¶

class
MultinomialProbs
(probs, total_count=1, *, total_count_max=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'probs': Simplex(), 'total_count': IntegerNonnegative(lower_bound=0)}¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.

support
¶

static
infer_shapes
(probs, total_count)[source]¶ Infers batch_shape and event_shape given the shapes of the arguments to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping the name of each input arg to a tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes. Return type: tuple

OrderedLogistic¶

class
OrderedLogistic
(predictor, cutpoints, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.discrete.CategoricalProbs
A categorical distribution with ordered outcomes.
References:
 Stan Functions Reference, v2.20 section 12.6, Stan Development Team
Parameters:  predictor (numpy.ndarray) – prediction in real domain; typically this is output of a linear model.
 cutpoints (numpy.ndarray) – positions in real domain to separate categories.
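Under the standard ordered-logistic construction (as in the cited Stan reference), with predictor η and increasing cutpoints c_1 < ... < c_{K−1}, the cumulative probabilities are P(Y ≤ k) = sigmoid(c_k − η), and the K category probabilities are their successive differences once the cumulative sequence is padded with 0 and 1. A stdlib sketch of that mapping (the helper name is hypothetical; categories are 0-indexed here, as for Categorical):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def ordered_logistic_probs(predictor, cutpoints):
    """Category probabilities from P(Y <= k) = sigmoid(cutpoints[k] - predictor)."""
    cumulative = [0.0] + [sigmoid(c - predictor) for c in cutpoints] + [1.0]
    return [hi - lo for lo, hi in zip(cumulative[:-1], cumulative[1:])]

probs = ordered_logistic_probs(predictor=0.5, cutpoints=[-1.0, 0.0, 1.5])
print(len(probs), sum(probs))  # 4 categories, summing to 1 up to rounding
# A larger predictor shifts mass toward higher categories.
print(ordered_logistic_probs(3.0, [-1.0, 0.0, 1.5])[-1] > probs[-1])  # True
```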

arg_constraints
= {'cutpoints': OrderedVector(), 'predictor': Real()}¶

static
infer_shapes
(predictor, cutpoints)[source]¶ Infers batch_shape and event_shape given the shapes of the arguments to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping the name of each input arg to a tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes. Return type: tuple
NegativeBinomial¶
NegativeBinomialLogits¶
NegativeBinomialProbs¶

class
NegativeBinomialProbs
(total_count, probs, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.conjugate.GammaPoisson

arg_constraints
= {'probs': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': Positive(lower_bound=0.0)}¶

support
= IntegerNonnegative(lower_bound=0)¶

NegativeBinomial2¶

class
NegativeBinomial2
(mean, concentration, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.conjugate.GammaPoisson
Another parameterization of GammaPoisson, in which rate is replaced by mean.
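Under the usual mean/dispersion convention (assumed here), NegativeBinomial2(mean, concentration) corresponds to GammaPoisson(concentration, rate = concentration / mean). That choice keeps the requested mean and gives variance mean + mean² / concentration, so a large concentration approaches a Poisson with the same mean. A stdlib sketch of the mapping (helper names are hypothetical):

```python
def negative_binomial2_to_gamma_poisson(mean, concentration):
    """Map (mean, concentration) to GammaPoisson's (concentration, rate)."""
    return concentration, concentration / mean

def gamma_poisson_moments(concentration, rate):
    """Mean a/b and variance a/b + a/b**2 of the gamma-Poisson mixture."""
    m = concentration / rate
    return m, m + m / rate

mean, concentration = 4.0, 2.0
a, b = negative_binomial2_to_gamma_poisson(mean, concentration)
print(gamma_poisson_moments(a, b))  # (4.0, 12.0): mean, mean + mean**2 / concentration
```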

arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'mean': Positive(lower_bound=0.0)}¶

support
= IntegerNonnegative(lower_bound=0)¶

Poisson¶

class
Poisson
(rate, *, is_sparse=False, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Creates a Poisson distribution parameterized by rate, the rate parameter.
Samples are non-negative integers, with a pmf given by
\[\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}\]
Parameters:  rate (numpy.ndarray) – The rate parameter.
 is_sparse (bool) – Whether to assume value is mostly zero when computing log_prob(), which can speed up computation when data is sparse.
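The pmf above is easiest to evaluate in log space: log p(k) = k · log(rate) − rate − log(k!). A stdlib sketch (poisson_log_prob is a hypothetical helper; the is_sparse shortcut is not modeled):

```python
import math

def poisson_log_prob(k, rate):
    """log(rate**k * exp(-rate) / k!) computed stably with lgamma."""
    return k * math.log(rate) - rate - math.lgamma(k + 1)

rate = 3.5
pmf = [math.exp(poisson_log_prob(k, rate)) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
var = sum((k - mean) ** 2 * p for k, p in enumerate(pmf))
print(sum(pmf), mean, var)  # ~1.0, ~3.5, ~3.5 (mean equals variance)
```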

arg_constraints
= {'rate': Positive(lower_bound=0.0)}¶

support
= IntegerNonnegative(lower_bound=0)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

variance
¶ Variance of the distribution.
ZeroInflatedDistribution¶

ZeroInflatedDistribution
(base_dist, *, gate=None, gate_logits=None, validate_args=None)[source]¶ Generic Zero Inflated distribution.
Parameters:  base_dist (Distribution) – the base distribution.
 gate (numpy.ndarray) – probability of extra zeros given via a Bernoulli distribution.
 gate_logits (numpy.ndarray) – logits of extra zeros given via a Bernoulli distribution.
ZeroInflatedPoisson¶

class
ZeroInflatedPoisson
(gate, rate=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.discrete.ZeroInflatedProbs
A Zero Inflated Poisson distribution.
Parameters:  gate (numpy.ndarray) – probability of extra zeros.
 rate (numpy.ndarray) – rate of Poisson distribution.
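Zero inflation mixes a point mass at zero with the base distribution: P(0) = gate + (1 − gate) · base(0) and P(k) = (1 − gate) · base(k) for k > 0. A stdlib sketch for a Poisson base (the helper is hypothetical; presumably gate_logits enters through gate = sigmoid(gate_logits), which is an assumption here):

```python
import math

def zero_inflated_poisson_pmf(k, gate, rate):
    """Mixture of a point mass at 0 (weight gate) and Poisson(rate) (weight 1 - gate)."""
    base = math.exp(k * math.log(rate) - rate - math.lgamma(k + 1))
    return (gate if k == 0 else 0.0) + (1.0 - gate) * base

gate, rate = 0.3, 2.0
pmf = [zero_inflated_poisson_pmf(k, gate, rate) for k in range(60)]
print(pmf[0], gate + (1 - gate) * math.exp(-rate))  # same value
print(sum(k * p for k, p in enumerate(pmf)), (1 - gate) * rate)  # mean ~1.4
```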

arg_constraints
= {'gate': UnitInterval(lower_bound=0.0, upper_bound=1.0), 'rate': Positive(lower_bound=0.0)}¶

support
= IntegerNonnegative(lower_bound=0)¶
Mixture Distributions¶
Mixture¶

Mixture
(mixing_distribution, component_distributions, *, validate_args=None)[source]¶ A marginalized finite mixture of component distributions.
The returned distribution will be either a:
 MixtureGeneral, when component_distributions is a list, or
 MixtureSameFamily, when component_distributions is a single distribution,
and more details can be found in the documentation for each of these classes.
Parameters:  mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.
 component_distributions – Either a list of component distributions or a single vectorized distribution. When a list is provided, the number of elements must equal mixture_size. Otherwise, the last batch dimension of the distribution must equal mixture_size.
Returns: The mixture distribution.
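Whichever class is returned, the marginal density is p(x) = Σ_k w_k · p_k(x), where the weights w come from mixing_distribution and p_k is the k-th component. In log space this becomes a log-sum-exp over components, sketched below in stdlib Python for Normal components (helper names are hypothetical, not NumPyro API):

```python
import math

def normal_log_prob(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

def mixture_log_prob(x, weights, components):
    """log sum_k w_k p_k(x), using log-sum-exp to avoid underflow."""
    terms = [math.log(w) + normal_log_prob(x, loc, scale)
             for w, (loc, scale) in zip(weights, components)]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))

weights = [0.5, 0.3, 0.2]
components = [(0.0, 1.0), (-0.5, 0.3), (0.6, 1.2)]  # (loc, scale) pairs
direct = sum(w * math.exp(normal_log_prob(1.0, loc, scale))
             for w, (loc, scale) in zip(weights, components))
print(math.exp(mixture_log_prob(1.0, weights, components)), direct)  # equal
```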
MixtureSameFamily¶

class
MixtureSameFamily
(mixing_distribution, component_distribution, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.mixtures._MixtureBase
A finite mixture of component distributions from the same family.
This mixture only supports a mixture of component distributions that are all of the same family. The different components are specified along the last batch dimension of the input component_distribution. If you need a mixture of distributions from different families, use the more general implementation in MixtureGeneral.
Parameters:  mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.
 component_distribution – A single vectorized Distribution, whose last batch dimension equals mixture_size as specified by mixing_distribution.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
>>> component_dist = dist.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
>>> mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()

component_distribution
¶ Return the vectorized distribution of components being mixed.
Returns: Component distribution Return type: Distribution

support
¶

is_discrete
¶

component_mean
¶

component_variance
¶
MixtureGeneral¶

class
MixtureGeneral
(mixing_distribution, component_distributions, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.mixtures._MixtureBase
A finite mixture of component distributions from different families.
If all of the component distributions are from the same family, the more specific implementation in MixtureSameFamily will be somewhat more efficient.
Parameters:  mixing_distribution – A Categorical specifying the weights for each mixture component. The size of this distribution specifies the number of components in the mixture, mixture_size.
 component_distributions – A list of mixture_size Distribution objects.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
>>> component_dists = [
...     dist.Normal(loc=0.0, scale=1.0),
...     dist.Normal(loc=-0.5, scale=0.3),
...     dist.Normal(loc=0.6, scale=1.2),
... ]
>>> mixture = dist.MixtureGeneral(mixing_dist, component_dists)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()
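The component_mean and component_variance properties hint at how a mixture's own moments can be assembled. By the laws of total expectation and total variance, mean = Σ_k w_k · μ_k and variance = Σ_k w_k · (σ_k² + μ_k²) − mean². A stdlib sketch with hypothetical two-component Normal parameters, checked against ancestral sampling (first pick a component, then sample from it):

```python
import random

def mixture_mean_variance(weights, means, variances):
    """Law of total expectation / total variance for a finite mixture."""
    mean = sum(w * m for w, m in zip(weights, means))
    second_moment = sum(w * (v + m * m) for w, m, v in zip(weights, means, variances))
    return mean, second_moment - mean * mean

weights, locs, scales = [0.25, 0.75], [-2.0, 1.0], [1.0, 0.5]
mean, var = mixture_mean_variance(weights, locs, [s * s for s in scales])
print(mean, var)  # 0.25 2.125

rng = random.Random(0)
draws = []
for _ in range(50000):
    k = rng.choices(range(len(weights)), weights=weights)[0]  # pick a component
    draws.append(rng.gauss(locs[k], scales[k]))  # then sample from it
mc_mean = sum(draws) / len(draws)
mc_var = sum((x - mc_mean) ** 2 for x in draws) / len(draws)
print(mc_mean, mc_var)  # close to 0.25 and 2.125
```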

component_distributions
¶ The list of component distributions in the mixture
Returns: The list of component distributions Return type: List[Distribution]

support
¶

is_discrete
¶

component_mean
¶

component_variance
¶
Directional Distributions¶
ProjectedNormal¶

class
ProjectedNormal
(concentration, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Projected isotropic normal distribution of arbitrary dimension.
This distribution over directional data is qualitatively similar to the von Mises and von Mises-Fisher distributions, but permits tractable variational inference via reparametrized gradients.
To use this distribution with autoguides and HMC, use handlers.reparam with a ProjectedNormalReparam reparametrizer in the model, e.g.:
    import jax.numpy as jnp
    import numpyro
    from numpyro.distributions import ProjectedNormal
    from numpyro.infer.reparam import ProjectedNormalReparam
    @numpyro.handlers.reparam(config={"direction": ProjectedNormalReparam()})
    def model():
        direction = numpyro.sample("direction", ProjectedNormal(jnp.zeros(3)))
        ...
Note
This implements log_prob() only for dimensions {2,3}.
References:
 [1] Hernandez-Stumpfhauser, D., Breidt, F. J., and van der Woerd, M. J. (2017). The General Projected Normal Distribution of Arbitrary Dimension: Modeling and Bayesian Inference. https://projecteuclid.org/euclid.ba/1453211962
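Generatively, a projected normal direction can be obtained by drawing x ~ Normal(concentration, I) in R^d and rescaling it to unit length. The direction of concentration is therefore the preferred direction, and its norm roughly controls how tightly samples cluster around it. A stdlib sketch of that construction (the helper is hypothetical and is not NumPyro's sampler):

```python
import math
import random

def sample_projected_normal(rng, concentration):
    """Draw x ~ Normal(concentration, identity) and project it onto the unit sphere."""
    x = [c + rng.gauss(0.0, 1.0) for c in concentration]
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x]

rng = random.Random(0)
concentration = [0.0, 0.0, 10.0]  # strongly concentrated around +z
draws = [sample_projected_normal(rng, concentration) for _ in range(2000)]
avg_z = sum(d[2] for d in draws) / len(draws)
print(avg_z)  # close to 1, since samples cluster near the direction (0, 0, 1)
```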

arg_constraints
= {'concentration': RealVector(Real(), 1)}¶

reparametrized_params
= ['concentration']¶

support
= Sphere()¶

mean
¶ Note this is the mean in the sense of a centroid in the submanifold that minimizes expected squared geodesic distance.

mode
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type: numpy.ndarray

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array with shape value.shape[:-len(self.event_shape)], i.e. the shape of value without its trailing event dimensions. Return type: numpy.ndarray

static
infer_shapes
(concentration)[source]¶ Infers batch_shape and event_shape given the shapes of the arguments to __init__().
Note
This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
Parameters:  *args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
 **kwargs – Keywords mapping the name of each input arg to a tuple representing the sizes of each tensor input.
Returns: A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes. Return type: tuple
SineBivariateVonMises¶

class
SineBivariateVonMises
(phi_loc, psi_loc, phi_concentration, psi_concentration, correlation=None, weighted_correlation=None, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Unimodal distribution of two dependent angles on the 2-torus (\(S^1 \otimes S^1\)) given by
\[C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2-\mu_2) + \rho\sin(x_1-\mu_1)\sin(x_2-\mu_2))\]
and
\[C = (2\pi)^2 \sum_{i=0}^{\infty} {2i \choose i} \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),\]
where \(I_i(\cdot)\) is the modified Bessel function of the first kind, the mu's are the locations of the distribution, the kappa's are the concentrations, and rho gives the correlation between angles \(x_1\) and \(x_2\). This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
To infer parameters, use NUTS or HMC with priors that avoid parameterizations where the distribution becomes bimodal; see the note below.
Note
Sample efficiency drops as
\[\frac{\rho}{\sqrt{\kappa_1\kappa_2}} \rightarrow 1\]
because the distribution becomes increasingly bimodal. To avoid bimodality, use the weighted_correlation parameter with a prior skewed away from one (e.g., Beta(1,3)). The weighted_correlation should be in [0,1].
Note
The correlation and weighted_correlation params are mutually exclusive.
Note
In the context of SVI, this distribution can be used as a likelihood but not for latent variables.
References:
 Singh, H., Hnizdo, V., and Demchuk, E. (2002). Probabilistic model for two dependent circular variables.
Parameters:  phi_loc (np.ndarray) – location of first angle
 psi_loc (np.ndarray) – location of second angle
 phi_concentration (np.ndarray) – concentration of first angle
 psi_concentration (np.ndarray) – concentration of second angle
 correlation (np.ndarray) – correlation between the two angles
 weighted_correlation (np.ndarray) – set correlation to weighted_corr * sqrt(phi_conc*psi_conc) to avoid bimodality (see note). The weighted_correlation should be in [0,1].
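The normalizing constant C above is a rapidly converging series, and weighted_correlation maps to correlation through ρ = weighted_correlation · sqrt(κ₁ κ₂). A stdlib sketch that evaluates the series, computing the modified Bessel function I_i from its power series, and checks it against brute-force integration of the unnormalized density over [−π, π)² (all helper names are hypothetical):

```python
import math

def bessel_i(order, x, terms=60):
    """Modified Bessel function of the first kind I_order(x), from its power series."""
    return sum((x / 2) ** (2 * m + order) / (math.factorial(m) * math.factorial(m + order))
               for m in range(terms))

def sine_bvm_normalizer(k1, k2, rho, terms=30):
    """C = (2 pi)^2 * sum_i binom(2i, i) * (rho^2 / (4 k1 k2))^i * I_i(k1) * I_i(k2)."""
    ratio = rho**2 / (4 * k1 * k2)
    return (2 * math.pi) ** 2 * sum(
        math.comb(2 * i, i) * ratio**i * bessel_i(i, k1) * bessel_i(i, k2)
        for i in range(terms))

def correlation_from_weighted(weighted_correlation, k1, k2):
    """rho = weighted_correlation * sqrt(k1 * k2)."""
    return weighted_correlation * math.sqrt(k1 * k2)

def unnormalized_density(x1, x2, mu1, mu2, k1, k2, rho):
    d1, d2 = x1 - mu1, x2 - mu2
    return math.exp(k1 * math.cos(d1) + k2 * math.cos(d2) + rho * math.sin(d1) * math.sin(d2))

k1, k2 = 2.0, 1.5
rho = correlation_from_weighted(0.6, k1, k2)
n = 200  # the integrand is periodic, so a uniform grid converges very quickly
h = 2 * math.pi / n
grid = [-math.pi + i * h for i in range(n)]
brute = sum(unnormalized_density(a, b, 0.3, -1.0, k1, k2, rho)
            for a in grid for b in grid) * h * h
print(sine_bvm_normalizer(k1, k2, rho), brute)  # agree to many digits
```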

arg_constraints
= {'correlation': Real(), 'phi_concentration': Positive(lower_bound=0.0), 'phi_loc': Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 'psi_concentration': Positive(lower_bound=0.0), 'psi_loc': Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)}¶

support
= IndependentConstraint(Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)¶

max_sample_iter
= 1000¶

log_prob
(*args, **kwargs)¶

sample
(key, sample_shape=())[source]¶ References:
 Kent, J. T., Ganeiber, A. M., and Mardia, K. V. (2018). A New Unified Approach for the Simulation of a Wide Class of Directional Distributions.

mean
¶ Computes the circular mean of the distribution. Note: this is the same as the location when mapped to the support [-pi, pi].
SineSkewed¶

class
SineSkewed
(base_dist: numpyro.distributions.distribution.Distribution, skewness, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
Sine-skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric) base distribution. Torus distributions are distributions with support on products of circles (i.e., \(\otimes S^1\) where \(S^1 = [-\pi,\pi)\)). So, a 0-torus is a point, the 1-torus is a circle, and the 2-torus is commonly associated with the donut shape.
The sine skewed X distribution is parameterized by a weight parameter for each dimension of the event of X. For example, with a von Mises distribution over a circle (1-torus), the sine skewed von Mises distribution has one skewness parameter. The skewness parameters can be inferred using HMC or NUTS. For example, the following will produce a prior over skewness for the 2-torus:

    from math import pi

    import numpyro
    from numpyro.distributions import Beta, Normal, SineBivariateVonMises, SineSkewed, VonMises
    from numpyro.distributions.transforms import L1BallTransform
    from numpyro.infer.reparam import CircularReparam

    @numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()})
    def model(obs):
        # Sine priors
        phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
        psi_loc = numpyro.sample('psi_loc', VonMises(pi / 2, 2.))
        phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
        psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
        corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

        # Skewing prior
        ball_trans = L1BallTransform()
        skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
        skewness = ball_trans(skewness)  # constraint: sum_i |skewness_i| <= 1

        # numpyro.plate needs a size: one entry per observed (phi, psi) pair
        with numpyro.plate('obs_plate', obs.shape[0]):
            sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
                                         phi_concentration=70 * phi_conc,
                                         psi_concentration=70 * psi_conc,
                                         weighted_correlation=corr_scale)
            return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)
To ensure that the skewing does not alter the normalization constant of the (sine bivariate von Mises) base distribution, the skewness parameters are constrained: the sum of the absolute values of the skewness weights must be less than or equal to one. The L1BallTransform can be used to achieve this.
In the context of SVI, this distribution can be freely used as a likelihood, but using it for latent variables will lead to slow inference on 2- and higher-dimensional tori, because the base_dist cannot be reparameterized.
Note
An event in the base distribution must be on a d-torus, so the event_shape must be (d,).
Note
For the skewness parameter, the sum of the absolute values of its weights for an event must be less than or equal to one. See eq. 2.1 in [1].
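To see why the skewness constraint matters, here is a minimal 1-D sketch in plain Python (standard library only, not NumPyro's implementation). It assumes the sine-skewed density of [1], the base density multiplied by (1 + skewness * sin(x - loc)), with a von Mises base; all helper names are hypothetical. Numerically integrating over [-pi, pi) shows that the total mass stays 1, because the sine term integrates to zero against the symmetric base density, while |skewness| <= 1 keeps the density non-negative.

```python
import math


def bessel_i0(x, terms=60):
    # Power series I0(x) = sum_m (x / 2)^(2m) / (m!)^2
    return sum((x / 2) ** (2 * m) / math.factorial(m) ** 2 for m in range(terms))


def von_mises_pdf(x, loc, concentration):
    return math.exp(concentration * math.cos(x - loc)) / (2 * math.pi * bessel_i0(concentration))


def sine_skewed_von_mises_pdf(x, loc, concentration, skewness):
    # Sine-skewing multiplies the base density by (1 + skewness * sin(x - loc));
    # |skewness| <= 1 keeps the result non-negative.
    return von_mises_pdf(x, loc, concentration) * (1 + skewness * math.sin(x - loc))


def integrate_circle(f, n=2000):
    # Midpoint rule over [-pi, pi); very accurate for smooth periodic integrands
    h = 2 * math.pi / n
    return h * sum(f(-math.pi + (i + 0.5) * h) for i in range(n))


total = integrate_circle(lambda x: sine_skewed_von_mises_pdf(x, 0.3, 2.0, 0.8))
print(round(total, 6))  # 1.0
```

For a d-torus the same idea applies with one skewness weight per coordinate and the sum of their absolute values bounded by one.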
References:
 [1] Sine-skewed toroidal distributions and their application in protein bioinformatics, Ameijeiras-Alonso, J., Ley, C. (2019)
Parameters:  base_dist (numpyro.distributions.Distribution) – base density on a d-dimensional torus. Supported base distributions include: 1D VonMises, SineBivariateVonMises, 1D ProjectedNormal, and Uniform(-pi, pi).
 skewness (jax.numpy.array) – skewness of the distribution.

arg_constraints
= {'skewness': L1Ball()}¶

support
= IndependentConstraint(Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793), 1)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(value)[source]¶ Evaluates the log probability density for a batch of samples given by value.
Parameters: value – A batch of samples from the distribution. Returns: an array whose shape is value.shape with the trailing event_shape dimensions removed Return type: numpy.ndarray

mean
¶ Mean of the base distribution
VonMises¶

class
VonMises
(loc, concentration, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution
The von Mises distribution, also known as the circular normal distribution.
This distribution is supported by a circular constraint from -pi to +pi. By default, the circular support behaves like constraints.interval(-math.pi, math.pi). To avoid issues at the boundaries of this interval during sampling, you should reparameterize this distribution using handlers.reparam with a CircularReparam reparametrizer in the model, e.g.:

    import numpyro
    from numpyro import handlers
    from numpyro.distributions import VonMises
    from numpyro.infer.reparam import CircularReparam

    @handlers.reparam(config={"direction": CircularReparam()})
    def model():
        direction = numpyro.sample("direction", VonMises(0.0, 4.0))
        ...

arg_constraints
= {'concentration': Positive(lower_bound=0.0), 'loc': Real()}¶

reparametrized_params
= ['loc']¶

support
= Circular(lower_bound=-3.141592653589793, upper_bound=3.141592653589793)¶

sample
(key, sample_shape=())[source]¶ Generate sample from von Mises distribution
Parameters:  key – random number generator key
 sample_shape – shape of samples
Returns: samples from von Mises

log_prob
(*args, **kwargs)¶

mean
¶ Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]

variance
¶ Computes circular variance of distribution
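The circular mean and variance above can be illustrated without NumPyro. This plain-Python sketch assumes the common definition of circular variance as one minus the mean resultant length, which for a von Mises distribution equals 1 - I1(k) / I0(k); the Bessel functions are evaluated by their power series, and all helper names are hypothetical. wrap_angle shows what "mapped to support [-pi, pi]" means for the mean.

```python
import math


def bessel_i(order, x, terms=60):
    # Power series for the modified Bessel function of the first kind I_order(x)
    return sum(
        (x / 2) ** (2 * m + order) / (math.factorial(m) * math.factorial(m + order))
        for m in range(terms)
    )


def wrap_angle(theta):
    # Map an angle onto [-pi, pi), e.g. a location of 3*pi/2 becomes -pi/2
    return (theta + math.pi) % (2 * math.pi) - math.pi


def von_mises_circular_variance(concentration):
    # Circular variance = 1 - E[cos(X - loc)] = 1 - I1(k) / I0(k)
    return 1.0 - bessel_i(1, concentration) / bessel_i(0, concentration)


print(wrap_angle(3 * math.pi / 2))  # approximately -pi/2
print(von_mises_circular_variance(0.0))  # 1.0 (uniform on the circle)
```

The variance shrinks toward 0 as the concentration grows, matching the intuition that a highly concentrated von Mises distribution behaves like a narrow normal.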

Truncated Distributions¶
LeftTruncatedDistribution¶

class
LeftTruncatedDistribution
(base_dist, low=0.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'low': Real()}¶

reparametrized_params
= ['low']¶

supported_types
= (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

var
¶

RightTruncatedDistribution¶

class
RightTruncatedDistribution
(base_dist, high=0.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'high': Real()}¶

reparametrized_params
= ['high']¶

supported_types
= (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

var
¶

TruncatedDistribution¶

TruncatedDistribution
(base_dist, low=None, high=None, *, validate_args=None)[source]¶ A function to generate a truncated distribution.
Parameters:  base_dist – The base distribution to be truncated. This should be a univariate distribution. Currently, only the following distributions are supported: Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
 low – the value used to truncate the base distribution from below. Set this parameter to None to disable truncation from below.
 high – the value used to truncate the base distribution from above. Set this parameter to None to disable truncation from above.
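As an illustration of what truncation means, here is a plain-Python sketch of a truncated Normal using inverse-CDF sampling, a standard technique that is not necessarily how NumPyro implements it. It uses the standard library's statistics.NormalDist in place of JAX, and the helper names are hypothetical. As in the parameters above, low=None or high=None disables truncation on that side; the log density is the base log density renormalized by the probability mass that lies inside [low, high].

```python
import math
import random
from statistics import NormalDist


def _cdf_bounds(base, low, high):
    # CDF values at the truncation points; None disables truncation on that side
    cdf_low = 0.0 if low is None else base.cdf(low)
    cdf_high = 1.0 if high is None else base.cdf(high)
    return cdf_low, cdf_high


def truncated_normal_sample(loc, scale, low=None, high=None, rng=random):
    base = NormalDist(loc, scale)
    cdf_low, cdf_high = _cdf_bounds(base, low, high)
    # Keep the uniform draw strictly inside (0, 1): NormalDist.inv_cdf rejects 0 and 1
    v = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    x = base.inv_cdf(cdf_low + (cdf_high - cdf_low) * v)
    # Guard against floating-point round-off pushing x just outside the bounds
    if low is not None:
        x = max(x, low)
    if high is not None:
        x = min(x, high)
    return x


def truncated_normal_log_prob(x, loc, scale, low=None, high=None):
    if (low is not None and x < low) or (high is not None and x > high):
        return -math.inf
    base = NormalDist(loc, scale)
    cdf_low, cdf_high = _cdf_bounds(base, low, high)
    # Renormalize the base density by the probability mass kept inside [low, high]
    return math.log(base.pdf(x)) - math.log(cdf_high - cdf_low)


rng = random.Random(0)
draws = [truncated_normal_sample(0.0, 1.0, low=0.5, high=2.0, rng=rng) for _ in range(1000)]
print(min(draws) >= 0.5, max(draws) <= 2.0)  # True True
```

Plain inverse-CDF sampling loses precision when the kept interval lies far in a tail; this sketch makes no attempt to handle that case.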
TruncatedPolyaGamma¶

class
TruncatedPolyaGamma
(batch_shape=(), *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

truncation_point
= 2.5¶

num_log_prob_terms
= 7¶

num_gamma_variates
= 8¶

arg_constraints
= {}¶

support
= Interval(lower_bound=0.0, upper_bound=2.5)¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

TwoSidedTruncatedDistribution¶

class
TwoSidedTruncatedDistribution
(base_dist, low=0.0, high=1.0, *, validate_args=None)[source]¶ Bases:
numpyro.distributions.distribution.Distribution

arg_constraints
= {'high': Dependent(), 'low': Dependent()}¶

reparametrized_params
= ['low', 'high']¶

supported_types
= (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)¶

support
¶

sample
(key, sample_shape=())[source]¶ Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
Parameters:  key (jax.random.PRNGKey) – the random number generator key used to draw the sample.
 sample_shape (tuple) – the sample shape for the distribution.
Returns: an array of shape sample_shape + batch_shape + event_shape
Return type:

log_prob
(*args, **kwargs)¶

mean
¶ Mean of the distribution.

var
¶

TensorFlow Distributions¶
Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see its Distribution docs.
BijectorConstraint¶
BijectorTransform¶
TFPDistribution¶

class
TFPDistribution
(batch_shape=(), event_shape=(), *, validate_args=None)[source]¶ A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor has the same signature as the corresponding TFP distribution.
This class can be used to convert a TFP distribution to a NumPyro-compatible one as follows:

    from numpyro.contrib.tfp.distributions import TFPDistribution
    from tensorflow_probability.substrates.jax import distributions as tfd

    d = TFPDistribution[tfd.Normal](0, 1)

Note that typical use cases do not require explicitly invoking this wrapper, since NumPyro wraps TFP distributions automatically under the hood in model code, e.g.:

    import numpyro
    from tensorflow_probability.substrates.jax import distributions as tfd

    def model():
        numpyro.sample("x", tfd.Normal(0, 1))
Constraints¶
Constraint¶

class
Constraint
[source]¶ Bases:
object
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

is_discrete
= False¶

event_dim
= 0¶

check
(value)[source]¶ Returns a boolean array of shape sample_shape + batch_shape indicating whether each event in value satisfies this constraint.
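The role of event_dim in check can be sketched in plain Python. The class below is a hypothetical stand-in for an interval constraint, not NumPyro's implementation, and it only handles a flat batch of scalars (event_dim=0) or a batch of vectors (event_dim=1). With event_dim=1 the trailing dimension is collapsed: an event is valid only if every coordinate is.

```python
class IntervalConstraintSketch:
    """Hypothetical stand-in for an interval constraint; not NumPyro's class."""

    def __init__(self, lower_bound, upper_bound, event_dim=0):
        if event_dim not in (0, 1):
            raise ValueError("this sketch only supports event_dim 0 or 1")
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.event_dim = event_dim  # number of trailing dims that form one event

    def _ok(self, v):
        return self.lower_bound <= v <= self.upper_bound

    def check(self, value):
        # event_dim=0: value is a list of scalars, one result per scalar
        if self.event_dim == 0:
            return [self._ok(v) for v in value]
        # event_dim=1: value is a list of vectors, one result per vector
        return [all(self._ok(v) for v in event) for event in value]


print(IntervalConstraintSketch(0.0, 1.0).check([0.2, 1.5, -0.1]))  # [True, False, False]
print(IntervalConstraintSketch(0.0, 1.0, event_dim=1).check([[0.2, 0.9], [0.5, 1.2]]))  # [True, False]
```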

dependent¶

dependent
= Dependent()¶ Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints.
Parameters:  is_discrete (bool) – Optional value of .is_discrete in case this can be computed statically. If not provided, access to the .is_discrete attribute will raise a NotImplementedError.
 event_dim (int) – Optional value of .event_dim in case this can be computed statically. If not provided, access to the .event_dim attribute will raise a NotImplementedError.
greater_than¶

greater_than
(lower_bound)¶ Constraint for real values strictly greater than lower_bound.
integer_interval¶

integer_interval
(lower_bound, upper_bound)¶ Constraint for integer values in the closed interval [lower_bound, upper_bound].
integer_greater_than¶

integer_greater_than
(lower_bound)¶ Constraint for integer values greater than or equal to lower_bound.
interval¶

interval
(lower_bound, upper_bound)¶ Constraint for real values in the closed interval [lower_bound, upper_bound].
less_than¶

less_than
(upper_bound)¶ Constraint for real values strictly less than upper_bound.
multinomial¶

multinomial
(upper_bound)¶ Constraint for vectors of non-negative counts that sum to upper_bound.
Transforms¶
Transform¶

class
Transform
[source]¶ Bases:
object

domain
= Real()¶

codomain
= Real()¶

inv
¶

forward_shape
(shape)[source]¶ Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

AbsTransform¶
AffineTransform¶

class
AffineTransform
(loc, scale, domain=Real())[source]¶ Bases:
numpyro.distributions.transforms.Transform
Note
When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.

codomain
¶

forward_shape
(shape)[source]¶ Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
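The note about assuming scale > 0 is easiest to see on a scalar example. This plain-Python sketch (hypothetical helpers, not NumPyro's code) implements the affine map y = loc + scale * x, its inverse, and its log absolute Jacobian determinant, and shows that a negative scale swaps the endpoints of an interval domain. That is why the codomain cannot be computed when the sign of scale is unknown, as with a JAX tracer.

```python
import math


def affine_forward(x, loc, scale):
    return loc + scale * x


def affine_inverse(y, loc, scale):
    return (y - loc) / scale


def affine_log_abs_det_jacobian(scale):
    # dy/dx = scale, so log|det J| = log|scale| for a scalar transform
    return math.log(abs(scale))


def affine_interval_codomain(lower, upper, loc, scale):
    # Image of [lower, upper]; a negative scale swaps the endpoints,
    # which is why the sign of scale must be known to compute the codomain
    a = affine_forward(lower, loc, scale)
    b = affine_forward(upper, loc, scale)
    return (min(a, b), max(a, b))


print(affine_interval_codomain(0.0, 1.0, loc=2.0, scale=3.0))   # (2.0, 5.0)
print(affine_interval_codomain(0.0, 1.0, loc=2.0, scale=-3.0))  # (-1.0, 2.0)
```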

CholeskyTransform¶
ComposeTransform¶

class
ComposeTransform
(parts)[source]¶ Bases:
numpyro.distributions.transforms.Transform

domain
¶

codomain
¶

forward_shape
(shape)[source]¶ Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
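A composition applies its parts in order, and by the chain rule its log absolute Jacobian determinant is the sum of the parts' log-determinants, each evaluated at that part's own input. The sketch below is a hypothetical plain-Python illustration: each part is a (forward, log_abs_det_jacobian) pair of scalar functions, which is a simplification and not NumPyro's Transform interface.

```python
import math


class ComposeSketch:
    """Hypothetical composition of (forward, log_abs_det_jacobian) pairs, applied in order."""

    def __init__(self, parts):
        self.parts = parts

    def __call__(self, x):
        for forward, _ in self.parts:
            x = forward(x)
        return x

    def log_abs_det_jacobian(self, x):
        # Chain rule: log|d(f_n o ... o f_1)/dx| = sum_k log|f_k'(x_k)|,
        # where x_k is the input to the k-th part
        total = 0.0
        for forward, log_det in self.parts:
            total += log_det(x)
            x = forward(x)
        return total


exp_part = (math.exp, lambda x: x)                           # log|d/dx exp(x)| = x
affine_part = (lambda x: 1.0 + 2.0 * x, lambda x: math.log(2.0))
t = ComposeSketch([exp_part, affine_part])                   # y = 1 + 2 * exp(x)
print(t(0.0))                       # 3.0
print(t.log_abs_det_jacobian(0.0))  # math.log(2.0)
```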

CorrCholeskyTransform¶

class
CorrCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.ParameterFreeTransform
Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:
 1. First, we convert \(x\) into a lower triangular matrix with the following order:
\[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]
 2. For each row \(X_i\) of the lower triangular part, we apply a signed version of the StickBreakingTransform class to transform \(X_i\) into a unit Euclidean length vector using the following steps:
  a. Scales into the interval \((-1, 1)\) domain: \(r_i = \tanh(X_i)\).
  b. Transforms into an unsigned domain: \(z_i = r_i^2\).
  c. Applies \(s_i = StickBreakingTransform(z_i)\).
  d. Transforms back into the signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).
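The steps above can be sketched in plain Python. This is a hypothetical illustration, not NumPyro's vectorized implementation: for each row, every off-diagonal entry takes a fraction r**2 of the squared length still left (stick-breaking on z = r**2) and keeps the sign of r, and the diagonal receives whatever remains. Every row therefore has unit Euclidean norm and a positive diagonal. Clipping tanh away from plus or minus 1 is an added safeguard against saturation.

```python
import math
import sys


def corr_cholesky_sketch(x):
    """Map a vector of length D*(D-1)/2 to the Cholesky factor of a D x D correlation matrix."""
    n = len(x)
    d = round((1 + math.sqrt(1 + 8 * n)) / 2)
    if d * (d - 1) // 2 != n:
        raise ValueError("len(x) must equal D*(D-1)/2 for some integer D")
    eps = sys.float_info.epsilon
    L = [[0.0] * d for _ in range(d)]
    k = 0  # next unread entry of x; rows are filled in the order of step 1
    for i in range(d):
        remaining = 1.0  # squared length of row i not yet allocated (the "stick")
        for j in range(i):
            # Step 2a: squash into (-1, 1); clipping keeps the diagonal strictly positive
            r = min(max(math.tanh(x[k]), -1 + eps), 1 - eps)
            k += 1
            # Steps 2b-2d: the entry takes a fraction r**2 of the remaining squared
            # length and keeps the sign of r
            L[i][j] = r * math.sqrt(remaining)
            remaining *= 1.0 - r * r
        L[i][i] = math.sqrt(remaining)  # the diagonal takes the length that is left
    return L


L = corr_cholesky_sketch([0.5, -1.0, 2.0, 0.1, 0.0, -0.3])
print([round(sum(v * v for v in row), 12) for row in L])  # [1.0, 1.0, 1.0, 1.0]
```

Multiplying L by its transpose would give a matrix with ones on the diagonal, i.e. a valid correlation matrix.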

domain
= RealVector(Real(), 1)¶

codomain
= CorrCholesky()¶
CorrMatrixCholeskyTransform¶
ExpTransform¶
IdentityTransform¶
L1BallTransform¶
LowerCholeskyAffine¶

class
LowerCholeskyAffine
(loc, scale_tril)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = loc + scale\_tril\ @\ x\).
Parameters:  loc – a real vector.
 scale_tril – a lower triangular matrix with positive diagonal.
Example
>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import LowerCholeskyAffine
>>> base = jnp.ones(2)
>>> loc = jnp.zeros(2)
>>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
>>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
>>> affine(base)
Array([0.3, 1.5], dtype=float32)

domain
= RealVector(Real(), 1)¶

codomain
= RealVector(Real(), 1)¶

forward_shape
(shape)[source]¶ Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
LowerCholeskyTransform¶

class
LowerCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.ParameterFreeTransform
Transform a real vector to a lower triangular Cholesky factor, where the strictly lower triangular submatrix is unconstrained and the diagonal is parameterized with an exponential transform.
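A plain-Python sketch of this idea follows: the strictly lower entries are copied unchanged and the diagonal is the exponential of unconstrained values, so it is always positive. The layout of the input vector (strictly lower triangle first, row by row, then the n log-diagonal entries last) is an assumption made for illustration and may not match NumPyro's ordering; the helper name is hypothetical.

```python
import math


def lower_cholesky_sketch(x):
    """Map a vector of length n*(n+1)/2 to a lower triangular matrix with positive diagonal.

    Assumed (hypothetical) layout: the first n*(n-1)/2 entries fill the strictly lower
    triangle row by row, and the last n entries are the log-diagonal.
    """
    m = len(x)
    n = round((math.sqrt(1 + 8 * m) - 1) / 2)
    if n * (n + 1) // 2 != m:
        raise ValueError("len(x) must equal n*(n+1)/2 for some integer n")
    off_diag, log_diag = x[: m - n], x[m - n:]
    L = [[0.0] * n for _ in range(n)]
    k = 0
    for i in range(n):
        for j in range(i):
            L[i][j] = off_diag[k]  # strictly lower entries are unconstrained
            k += 1
        L[i][i] = math.exp(log_diag[i])  # exp keeps the diagonal positive
    return L


print(lower_cholesky_sketch([0.5, 0.0, 0.0]))  # [[1.0, 0.0], [0.5, 1.0]]
```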

domain
= RealVector(Real(), 1)¶

codomain
= LowerCholesky()¶

OrderedTransform¶

class
OrderedTransform
[source]¶ Bases:
numpyro.distributions.transforms.ParameterFreeTransform
Transform a real vector to an ordered vector.
References:
 Stan Reference Manual v2.20, section 10.6, Stan Development Team
Example
>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import OrderedTransform
>>> base = jnp.ones(3)
>>> transform = OrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([1., 3.7182817, 6.4365635]), rtol=1e-3, atol=1e-3)
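The values in the example (1, 1 + e, 1 + 2e) are consistent with the ordered-vector transform from the Stan manual: keep the first coordinate and add the exponential of each later coordinate to the running value. The plain-Python sketch below (hypothetical helper names, not NumPyro's code) implements that mapping and its inverse.

```python
import math


def ordered_forward(x):
    """y_0 = x_0 and y_i = y_{i-1} + exp(x_i); x must be non-empty."""
    y = [x[0]]
    for xi in x[1:]:
        y.append(y[-1] + math.exp(xi))  # exp(.) > 0, so y is strictly increasing
    return y


def ordered_inverse(y):
    # x_0 = y_0 and x_i = log(y_i - y_{i-1})
    return [y[0]] + [math.log(b - a) for a, b in zip(y, y[1:])]


y = ordered_forward([1.0, 1.0, 1.0])
print(y)  # approximately [1.0, 3.718, 6.437]
```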

domain
= RealVector(Real(), 1)¶

codomain
= OrderedVector()¶
PermuteTransform¶
PowerTransform¶

class
PowerTransform
(exponent)[source]¶ Bases:
numpyro.distributions.transforms.Transform

domain
= Positive(lower_bound=0.0)¶

codomain
= Positive(lower_bound=0.0)¶

forward_shape
(shape)[source]¶ Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.

ScaledUnitLowerCholeskyTransform¶

class
ScaledUnitLowerCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.LowerCholeskyTransform
Like LowerCholeskyTransform, this Transform transforms a real vector to a lower triangular Cholesky factor. However, it does so via a decomposition
\(y = loc + unit\_scale\_tril\ @\ scale\_diag\ @\ x\),
where \(unit\_scale\_tril\) has ones along the diagonal and \(scale\_diag\) is a diagonal matrix with all positive entries that is parameterized with a softplus transform.

domain
= RealVector(Real(), 1)¶

codomain
= ScaledUnitLowerCholesky()¶

SigmoidTransform¶
SimplexToOrderedTransform¶

class
SimplexToOrderedTransform
(anchor_point=0.0)[source]¶ Bases:
numpyro.distributions.transforms.Transform
Transform a simplex into an ordered vector (via differences in the Logistic CDF between cutpoints). Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.
Parameters: anchor_point – Anchor point is a nuisance parameter to improve the identifiability of the transform. For simplicity, we assume it is a scalar value, but it is broadcastable to x.shape[:-1]. For more details, please refer to Section 2.2 in [1]. References:
 [1] Ordinal Regression Case Study, section 2.2, M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
Example
>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import SimplexToOrderedTransform
>>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
>>> transform = SimplexToOrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
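The example values match applying the logit (the inverse Logistic CDF) to the cumulative probabilities of all but the last category: logit(0.3), logit(0.4), and logit(0.8). The plain-Python sketch below reproduces that computation. The helper name is hypothetical, and adding anchor_point to every cutpoint is an assumed sign convention, since the example uses the default of 0.

```python
import math


def simplex_to_ordered(p, anchor_point=0.0):
    # Cumulative probabilities of all but the last category, mapped through the
    # logit and shifted by anchor_point (assumed convention), give increasing cutpoints
    cumulative, cutpoints = 0.0, []
    for pk in p[:-1]:
        cumulative += pk
        cutpoints.append(math.log(cumulative / (1.0 - cumulative)) + anchor_point)
    return cutpoints


print(simplex_to_ordered([0.3, 0.1, 0.4, 0.2]))  # about [-0.847, -0.405, 1.386]
```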

domain
= Simplex()¶

codomain
= OrderedVector()¶
SoftplusLowerCholeskyTransform¶

class
SoftplusLowerCholeskyTransform
[source]¶ Bases:
numpyro.distributions.transforms.ParameterFreeTransform
Transform from an unconstrained vector to lower-triangular matrices with non-negative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

domain
= RealVector(Real(), 1)¶

codomain
= SoftplusLowerCholesky()¶

SoftplusTransform¶

class
SoftplusTransform
[source]¶ Bases:
numpyro.distributions.transforms.ParameterFreeTransform
Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).
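Evaluated literally, both formulas overflow for large inputs. A minimal plain-Python sketch of numerically stable equivalents follows (hypothetical helper names, not NumPyro's code): softplus is rewritten as max(x, 0) + log1p(exp(-|x|)), and the inverse log(exp(y) - 1) is rewritten as y + log(-expm1(-y)), which is algebraically identical for y > 0.

```python
import math


def softplus(x):
    # log(1 + exp(x)) computed without overflow for large |x|
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y):
    # log(exp(y) - 1) == y + log(1 - exp(-y)); requires y > 0
    return y + math.log(-math.expm1(-y))


print(softplus(1000.0))  # 1000.0, where the naive formula would overflow
```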

domain
= Real()¶

codomain
= SoftplusPositive(lower_bound=0.0)¶

StickBreakingTransform¶
Flows¶
InverseAutoregressiveTransform¶

class
InverseAutoregressiveTransform
(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]¶ Bases:
numpyro.distributions.transforms.Transform
An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,
\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\), where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).
References
 Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
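The equation above can be sketched in plain Python. Here toy_autoregressive_net is a hypothetical stand-in for autoregressive_nn, assumed to return a (mean, log_scale) pair whose i-th entries depend only on x[:i], and the clip bounds are illustrative. Because of that autoregressive structure the Jacobian is triangular with diagonal sigma_t, so the log-determinant is the sum of the clipped log-scales. The forward pass is a single network call, while inversion has to proceed one coordinate at a time.

```python
import math


def toy_autoregressive_net(x):
    """Hypothetical stand-in for autoregressive_nn: output i depends only on x[:i]."""
    mean, log_scale = [], []
    for i in range(len(x)):
        context = sum(x[:i])
        mean.append(0.5 * context)
        log_scale.append(0.1 * context - 0.2)
    return mean, log_scale


def _clip(s, lo, hi):
    return min(max(s, lo), hi)


def iaf_forward(x, net=toy_autoregressive_net, log_scale_min_clip=-5.0, log_scale_max_clip=3.0):
    mean, log_scale = net(x)
    # Clip log-scales so that sigma_t = exp(log_scale) stays in a numerically safe range
    log_scale = [_clip(s, log_scale_min_clip, log_scale_max_clip) for s in log_scale]
    y = [m + math.exp(s) * xi for m, s, xi in zip(mean, log_scale, x)]
    # Triangular Jacobian with diagonal exp(log_scale): log|det J| = sum(log_scale)
    return y, sum(log_scale)


def iaf_inverse(y, net=toy_autoregressive_net, log_scale_min_clip=-5.0, log_scale_max_clip=3.0):
    # Sequential inversion: x[i] needs x[:i], which has already been recovered
    x = [0.0] * len(y)
    for i in range(len(y)):
        mean, log_scale = net(x)
        s = _clip(log_scale[i], log_scale_min_clip, log_scale_max_clip)
        x[i] = (y[i] - mean[i]) * math.exp(-s)
    return x


y, log_det = iaf_forward([0.3, -1.2, 2.0, 0.7])
print(iaf_inverse(y))  # recovers approximately [0.3, -1.2, 2.0, 0.7]
```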

domain
= RealVector(Real(), 1)¶

codomain
= RealVector(Real(), 1)¶

log_abs_det_jacobian
(x, y, intermediates=None)[source]¶ Calculates the elementwise log absolute determinant of the Jacobian.
Parameters:  x (numpy.ndarray) – the input to the transform
 y (numpy.ndarray) – the output of the transform
BlockNeuralAutoregressiveTransform¶

class
BlockNeuralAutoregressiveTransform
(bn_arn)[source]¶ Bases:
numpyro.distributions.transforms.Transform
An implementation of Block Neural Autoregressive flow.
References
 Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz

domain
= RealVector(Real(), 1)¶

codomain
= RealVector(Real(), 1)¶

log_abs_det_jacobian
(x, y, intermediates=None)[source]¶ Calculates the elementwise log absolute determinant of the Jacobian.
Parameters:  x (numpy.ndarray) – the input to the transform
 y (numpy.ndarray) – the output of the transform