Source code for numpyro.distributions.continuous

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# The implementation largely follows the design in PyTorch's `torch.distributions`
#
# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU                      (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.


from typing import Callable, Literal, Optional

import numpy as np

import jax
from jax import Array, lax, vmap
from jax.experimental.sparse import BCOO
from jax.lax import scan
import jax.nn as nn
import jax.numpy as jnp
import jax.random as random
from jax.scipy.linalg import cho_solve, solve_triangular, toeplitz
from jax.scipy.special import (
    betaln,
    digamma,
    expi,
    expit,
    gammainc,
    gammaln,
    logit,
    multigammaln,
    ndtr,
    ndtri,
    xlog1py,
    xlogy,
)
from jax.scipy.stats import norm as jax_norm
from jax.typing import ArrayLike

from numpyro.distributions import constraints
from numpyro.distributions.discrete import HurdleProbs, _to_logits_bernoulli
from numpyro.distributions.distribution import (
    Distribution,
    TransformedDistribution,
)
from numpyro.distributions.transforms import (
    AffineTransform,
    CholeskyTransform,
    CorrMatrixCholeskyTransform,
    ExpTransform,
    PackRealFastFourierCoefficientsTransform,
    PowerTransform,
    RealFastFourierTransform,
    RecursiveLinearTransform,
    SigmoidTransform,
    ZeroSumTransform,
)
from numpyro.distributions.util import (
    _reshape,
    add_diag,
    assert_one_of,
    betainc,
    betaincinv,
    cholesky_of_inverse,
    gammaincinv,
    lazy_property,
    matrix_to_tril_vec,
    multidigamma,
    promote_shapes,
    signed_stick_breaking_tril,
    tri_logabsdet,
    validate_sample,
    vec_to_tril_matrix,
)
from numpyro.util import is_prng_key


class AsymmetricLaplace(Distribution):
    r"""Asymmetric Laplace distribution.

    A two-sided exponential density centred at :attr:`loc`. Values below
    ``loc`` decay with :attr:`left_scale` (``scale * asymmetry``) and values
    above ``loc`` decay with :attr:`right_scale` (``scale / asymmetry``).

    :param loc: Location parameter.
    :type loc: ArrayLike
    :param scale: Positive scale parameter.
    :type scale: ArrayLike
    :param asymmetry: Positive asymmetry parameter.
    :type asymmetry: ArrayLike
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    support = constraints.real
    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "asymmetry": constraints.positive,
    }
    reparametrized_params = ["loc", "scale", "asymmetry"]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        asymmetry: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        param_shapes = (jnp.shape(loc), jnp.shape(scale), jnp.shape(asymmetry))
        batch_shape = lax.broadcast_shapes(*param_shapes)
        promoted = promote_shapes(loc, scale, asymmetry, shape=batch_shape)
        self.loc, self.scale, self.asymmetry = promoted
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    @lazy_property
    def left_scale(self):
        """Scale of the exponential tail below ``loc``: ``scale * asymmetry``."""
        return self.scale * self.asymmetry

    @lazy_property
    def right_scale(self):
        """Scale of the exponential tail above ``loc``: ``scale / asymmetry``."""
        return self.scale / self.asymmetry

    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Evaluate the log density at ``value``.

        The result is ``-|value - loc| / s - log(left_scale + right_scale)``
        where ``s`` is the left scale for values below ``loc`` and the right
        scale otherwise.
        """
        if self._validate_args:
            self._validate_sample(value)
        offset = value - self.loc
        # Pick the tail scale according to which side of `loc` the value is on.
        side_scale = jnp.where(offset < 0, self.left_scale, self.right_scale)
        exponent = -jnp.abs(offset) / side_scale
        log_normalizer = jnp.log(self.left_scale + self.right_scale)
        return exponent - log_normalizer

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw samples as ``loc`` plus a scaled difference of two exponentials."""
        assert is_prng_key(key)
        draw_shape = (2,) + sample_shape + self.batch_shape + self.event_shape
        # A single call yields both exponential draws, stacked on the leading axis.
        exp_draws = random.exponential(key, shape=draw_shape)
        left_draw, right_draw = exp_draws
        return self.loc - self.left_scale * left_draw + self.right_scale * right_draw

    @property
    def mean(self) -> ArrayLike:
        """Mean of the distribution, broadcast to ``batch_shape``."""
        total_scale = self.left_scale + self.right_scale
        squared_gap = self.right_scale**2 - self.left_scale**2
        result = self.loc + squared_gap / total_scale
        return jnp.broadcast_to(result, self.batch_shape)

    @property
    def variance(self) -> ArrayLike:
        """Variance of the distribution, broadcast to ``batch_shape``."""
        lo, hi = self.left_scale, self.right_scale
        span = lo + hi
        # Relative weights of the two tails.
        w_lo = lo / span
        w_hi = hi / span
        result = w_lo * lo**2 + w_hi * hi**2 + w_lo * w_hi * span**2
        return jnp.broadcast_to(result, self.batch_shape)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        """Cumulative distribution function evaluated at ``value``."""
        offset = value - self.loc
        k = self.asymmetry
        upper_branch = 1 - (1 / (1 + k**2)) * jnp.exp(
            -jnp.abs(offset) / self.right_scale
        )
        lower_branch = k**2 / (1 + k**2) * jnp.exp(-jnp.abs(offset) / self.left_scale)
        return jnp.where(offset >= 0, upper_branch, lower_branch)

    def icdf(self, value: ArrayLike) -> ArrayLike:
        """Quantile function; ``value`` is a probability."""
        k = self.asymmetry
        # Probability mass lying below `loc`.
        mass_below = k**2 / (1 + k**2)
        lower_quantile = self.loc + self.left_scale * jnp.log(value / mass_below)
        upper_quantile = self.loc - self.right_scale * jnp.log(
            (1 + k**2) * (1 - value)
        )
        return jnp.where(value <= mass_below, lower_quantile, upper_quantile)
class Beta(Distribution):
    r"""Beta distribution on the unit interval :math:`[0, 1]`.

    Parameterized by two concentrations, alpha (:attr:`concentration1`) and
    beta (:attr:`concentration0`), with density

    .. math::
        f(x; \alpha, \beta) =
        \frac{x^{\alpha - 1} (1 - x)^{\beta - 1}}{\text{B}(\alpha, \beta)}

    where :math:`\text{B}(\alpha, \beta)` is the Beta function,
    :math:`\alpha > 0` and :math:`\beta > 0`.

    :param concentration1: Alpha parameter (1st shape parameter).
    :type concentration1: ArrayLike
    :param concentration0: Beta parameter (2nd shape parameter).
    :type concentration0: ArrayLike
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    support = constraints.unit_interval
    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    reparametrized_params = ["concentration1", "concentration0"]
    pytree_data_fields = ("concentration0", "concentration1", "_dirichlet")

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        promoted = promote_shapes(concentration1, concentration0)
        self.concentration1, self.concentration0 = promoted
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration1), jnp.shape(concentration0)
        )
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)
        # The two-category Dirichlet that backs sampling and log_prob.
        stacked = jnp.stack(
            [
                jnp.broadcast_to(concentration1, batch_shape),
                jnp.broadcast_to(concentration0, batch_shape),
            ],
            axis=-1,
        )
        self._dirichlet = Dirichlet(stacked)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Draw samples through the underlying two-category Dirichlet.

        The Dirichlet is sampled and its first component is returned.

        :param key: JAX PRNGKey for reproducibility.
        :type key: jax.Array
        :param sample_shape: The shape of the samples to be generated.
        :type sample_shape: tuple[int, ...]
        :return: Samples of shape ``sample_shape + batch_shape``.
        :rtype: ArrayLike
        """
        assert is_prng_key(key)
        simplex_draws = self._dirichlet.sample(key, sample_shape)
        return simplex_draws[..., 0]

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Evaluate the log density at ``value``.

        Boundary inputs (exactly 0 or 1) are swapped for 0.5 on the
        differentiable path so gradients stay finite; the forward value at
        those points is then taken from a closed-form expression wrapped in
        :func:`~jax.lax.stop_gradient`.

        :param value: Values at which to evaluate the log density.
        :type value: ArrayLike
        :return: Log probability density.
        :rtype: ArrayLike
        """
        # Double-where trick, see
        # https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
        at_edge = (value == 0.0) | (value == 1.0)

        # Differentiable path: evaluate the Dirichlet density at an interior point.
        interior_x = jnp.where(at_edge, 0.5, value)
        interior_complement = jnp.where(at_edge, 0.5, 1.0 - value)
        simplex_point = jnp.stack([interior_x, interior_complement], axis=-1)
        differentiable_lp = self._dirichlet.log_prob(simplex_point)

        # Forward-only path: xlogy treats 0 * log(0) as 0 at the edges.
        edge_lp = (
            xlogy(self.concentration1 - 1.0, value)
            + xlogy(self.concentration0 - 1.0, 1.0 - value)
            - betaln(self.concentration1, self.concentration0)
        )
        edge_lp = jax.lax.stop_gradient(edge_lp)

        return jnp.where(at_edge, edge_lp, differentiable_lp)

    @property
    def mean(self) -> ArrayLike:
        r"""Analytical mean.

        .. math::
            E[X] = \frac{\alpha}{\alpha + \beta}
        """
        denominator = self.concentration1 + self.concentration0
        return self.concentration1 / denominator

    @property
    def variance(self) -> ArrayLike:
        r"""Analytical variance.

        .. math::
            Var(X) = \frac{\alpha \beta}{(\alpha + \beta)^2 (\alpha + \beta + 1)}
        """
        conc_sum = self.concentration1 + self.concentration0
        denominator = conc_sum**2 * (conc_sum + 1)
        return self.concentration1 * self.concentration0 / denominator

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative distribution function.

        Computed with the regularized incomplete beta function
        :math:`I_x(\alpha, \beta)`.

        :param value: Value to evaluate.
        :type value: ArrayLike
        """
        alpha, beta = self.concentration1, self.concentration0
        return betainc(alpha, beta, value)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse cumulative distribution function (quantile function).

        :param q: Probability value in :math:`[0,1]`.
        :type q: ArrayLike
        """
        alpha, beta = self.concentration1, self.concentration0
        return betaincinv(alpha, beta, q)

    def entropy(self) -> ArrayLike:
        r"""Differential entropy.

        .. math::
            H(X) = \ln \text{B}(\alpha, \beta)
            - (\alpha - 1)\psi(\alpha) - (\beta - 1)\psi(\beta)
            + (\alpha + \beta - 2)\psi(\alpha + \beta)

        where :math:`\psi` is the digamma function.
        """
        c0, c1 = self.concentration0, self.concentration1
        conc_sum = c0 + c1
        return (
            betaln(c0, c1)
            - (c0 - 1) * digamma(c0)
            - (c1 - 1) * digamma(c1)
            + (conc_sum - 2) * digamma(conc_sum)
        )
class Cauchy(Distribution):
    r"""Cauchy distribution with location :attr:`loc` and scale :attr:`scale`.

    .. math::
        f(x; x_0, \gamma) =
        \frac{1}{\pi \gamma \left[1 + \left(\frac{x - x_0}{\gamma}\right)^2\right]}

    with :math:`x, x_0 \in \mathbb{R}` and :math:`\gamma > 0`. The
    distribution has no finite mean or variance.

    :param loc: Location parameter (:math:`x_0`).
    :type loc: ArrayLike
    :param scale: Scale parameter (:math:`\gamma`).
    :type scale: ArrayLike
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    support = constraints.real
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    reparametrized_params = ["loc", "scale"]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        promoted = promote_shapes(loc, scale)
        self.loc, self.scale = promoted
        shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super().__init__(batch_shape=shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Draw samples by shifting and scaling :func:`~jax.random.cauchy` draws.

        :param key: JAX PRNGKey for reproducibility.
        :type key: jax.Array
        :param sample_shape: The shape of the samples to be generated.
        :type sample_shape: tuple[int, ...]
        :return: Samples of shape ``sample_shape + batch_shape``.
        :rtype: ArrayLike
        """
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape
        standard_draws = random.cauchy(key, shape=draw_shape)
        return self.loc + standard_draws * self.scale

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Evaluate the log density at ``value``.

        .. math::
            \log f(x; x_0, \gamma) = -\log(\pi) - \log(\gamma)
            - \log\!\left[1 + \left(\frac{x - x_0}{\gamma}\right)^2\right]

        :param value: Values at which to evaluate the log density.
        :type value: ArrayLike
        :return: Log probability density.
        :rtype: ArrayLike
        """
        standardized = (value - self.loc) / self.scale
        return -jnp.log(jnp.pi) - jnp.log(self.scale) - jnp.log1p(standardized**2)

    @property
    def mean(self) -> ArrayLike:
        r"""Undefined for the Cauchy distribution; ``NaN`` per batch element."""
        undefined = jnp.nan
        return jnp.full(self.batch_shape, undefined)

    @property
    def variance(self) -> ArrayLike:
        r"""Undefined for the Cauchy distribution; ``NaN`` per batch element."""
        undefined = jnp.nan
        return jnp.full(self.batch_shape, undefined)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative distribution function.

        .. math::
            F(x; x_0, \gamma) =
            \frac{1}{\pi}\arctan\!\left(\frac{x - x_0}{\gamma}\right) + \frac{1}{2}

        :param value: Value to evaluate.
        :type value: ArrayLike
        """
        standardized = (value - self.loc) / self.scale
        angle_fraction = jnp.arctan(standardized) / jnp.pi
        return angle_fraction + 0.5

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse cumulative distribution function (quantile function).

        .. math::
            F^{-1}(q; x_0, \gamma) =
            x_0 + \gamma \tan\!\left[\pi\!\left(q - \frac{1}{2}\right)\right]

        :param q: Probability value in :math:`[0,1]`.
        :type q: ArrayLike
        """
        standard_quantile = jnp.tan(jnp.pi * (q - 0.5))
        return self.loc + self.scale * standard_quantile

    def entropy(self) -> ArrayLike:
        r"""Differential entropy, broadcast to ``batch_shape``.

        .. math::
            H(X) = \log(4\pi\gamma)
        """
        per_element = jnp.log(4 * np.pi * self.scale)
        return jnp.broadcast_to(per_element, self.batch_shape)
class Dirichlet(Distribution):
    r"""Dirichlet distribution over the probability simplex.

    The last axis of :attr:`concentration` is the event axis; any leading
    axes are batch axes.

    :param concentration: Positive concentration parameters with at least
        one dimension.
    :type concentration: Array
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    :raises ValueError: If ``concentration`` has fewer than one dimension.
    """

    support = constraints.simplex
    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1)
    }
    reparametrized_params = ["concentration"]

    def __init__(
        self,
        concentration: Array,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        if jnp.ndim(concentration) < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        self.concentration = concentration
        full_shape = concentration.shape
        super().__init__(
            batch_shape=full_shape[:-1],
            event_shape=full_shape[-1:],
            validate_args=validate_args,
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw samples with :func:`~jax.random.dirichlet`.

        The draws are clipped to ``[tiny, 1 - eps]`` of their floating-point
        dtype, so exact 0 and 1 are never returned.
        """
        assert is_prng_key(key)
        draws = random.dirichlet(
            key, self.concentration, shape=sample_shape + self.batch_shape
        )
        float_info = jnp.finfo(draws.dtype)
        lower = float_info.tiny
        upper = 1 - jnp.finfo(draws.dtype).eps
        return jnp.clip(draws, lower, upper)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Evaluate the log density at ``value`` (a point on the simplex)."""
        alpha = self.concentration
        # Log of the multivariate Beta function of the concentrations.
        log_normalizer = jnp.sum(gammaln(alpha), axis=-1) - gammaln(
            jnp.sum(alpha, axis=-1)
        )
        weighted_logs = jnp.log(value) * (alpha - 1.0)
        return jnp.sum(weighted_logs, axis=-1) - log_normalizer

    @property
    def mean(self) -> ArrayLike:
        """Concentration normalised along the event axis."""
        alpha_sum = jnp.sum(self.concentration, axis=-1, keepdims=True)
        return self.concentration / alpha_sum

    @property
    def variance(self) -> ArrayLike:
        """Per-component variance."""
        alpha = self.concentration
        alpha_sum = jnp.sum(alpha, axis=-1, keepdims=True)
        return alpha * (alpha_sum - alpha) / (alpha_sum**2 * (alpha_sum + 1))

    @staticmethod
    def infer_shapes(concentration):
        """Split the shape of ``concentration`` into batch and event shapes.

        :param concentration: Shape of the concentration parameter.
        :return: ``(batch_shape, event_shape)`` where the event shape is the
            last entry of ``concentration``.
        """
        return concentration[:-1], concentration[-1:]

    def entropy(self) -> ArrayLike:
        """Differential entropy of the distribution."""
        (num_categories,) = self.event_shape
        alpha = self.concentration
        alpha_sum = alpha.sum(axis=-1)
        return (
            gammaln(alpha).sum(axis=-1)
            - gammaln(alpha_sum)
            + (alpha_sum - num_categories) * digamma(alpha_sum)
            - ((alpha - 1) * digamma(alpha)).sum(axis=-1)
        )
class EulerMaruyama(Distribution):
    """
    Euler–Maruyama method is a method for the approximate numerical solution
    of a stochastic differential equation (SDE)

    A path is built by starting from a draw of ``init_dist`` and repeatedly
    applying ``y_next = y + dt * f + sqrt(dt) * g * noise``, where ``(f, g)``
    is the output of ``sde_fn(y, t)`` and ``noise`` is standard normal.

    :param ndarray t: discretized time
    :param callable sde_fn: function returning the drift and diffusion coefficients of SDE
    :param Distribution init_dist: Distribution for initial values.
    :param validate_args: Whether to validate input constraints, defaults to None.
    :raises TypeError: If ``init_dist`` is not a ``Distribution`` instance.

    **References**

    [1] https://en.wikipedia.org/wiki/Euler-Maruyama_method
    """

    arg_constraints = {"t": constraints.ordered_vector}

    pytree_data_fields = ("t", "init_dist")
    pytree_aux_fields = ("sde_fn",)

    def __init__(
        self,
        t: Array,
        sde_fn: Callable[[Array, Array], tuple[Array, Array]],
        init_dist: Distribution,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.t = t
        self.sde_fn = sde_fn
        self.init_dist = init_dist

        if not isinstance(init_dist, Distribution):
            raise TypeError("Init distribution is expected to be Distribution class.")

        # Leading axes of `t` are batch axes; its last axis is the time axis.
        batch_shape_t = jnp.shape(t)[:-1]
        batch_shape = lax.broadcast_shapes(batch_shape_t, init_dist.batch_shape)
        # An event is a whole path: (number of time points,) + shape of one state.
        event_shape = (jnp.shape(t)[-1],) + init_dist.event_shape
        super(EulerMaruyama, self).__init__(
            batch_shape, event_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False)
    def support(self) -> constraints.Constraint:
        # Every entry of a path is an unconstrained real number.
        return constraints.independent(constraints.real, self.event_dim)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Simulate paths of shape ``sample_shape + batch_shape + event_shape``.

        :param key: JAX PRNGKey; it is split once for the noise and once for
            the initial values.
        :param sample_shape: Leading shape of independent paths to draw.
        :return: Simulated paths whose first time entry is the initial draw.
        """
        assert is_prng_key(key)

        batch_shape = sample_shape + self.batch_shape

        def step(y_curr, xs):
            # One Euler-Maruyama update from time t_curr to t_curr + dt_curr.
            noise_curr, t_curr, dt_curr = xs
            f, g = self.sde_fn(y_curr, t_curr)
            mu = y_curr + dt_curr * f
            sigma = jnp.sqrt(dt_curr) * g
            y_next = mu + sigma * noise_curr
            return y_next, y_next

        rng_noise, rng_init = random.split(key)
        # One standard-normal draw per transition, hence one fewer than the
        # number of time points.
        noises = random.normal(
            rng_noise,
            shape=batch_shape + (self.event_shape[0] - 1,) + self.event_shape[1:],
        )
        inits = self.init_dist.expand(batch_shape).sample(rng_init)

        def scan_fn(init, noise, tm1, dt):
            # Roll a single path forward over all of its transitions.
            return scan(step, init, (noise, tm1, dt))

        batch_dim = len(batch_shape)
        if batch_dim:
            # Collapse all batch axes into one leading axis so that a single
            # vmap can run the scan for every path, then restore the shape.
            inits = inits.reshape((-1,) + inits.shape[batch_dim:])
            noises = noises.reshape((-1,) + noises.shape[batch_dim:])
            t = jnp.broadcast_to(self.t, batch_shape + (self.event_shape[0],))
            t = t.reshape((-1,) + t.shape[batch_dim:])
            dt = jnp.diff(t, axis=-1)
            _, sde_out = vmap(scan_fn)(inits, noises, t[..., :-1], dt)
            # Prepend the initial state along the time axis.
            sde_out = jnp.concatenate([inits[:, None], sde_out], axis=1)
            sde_out = jnp.reshape(sde_out, batch_shape + self.event_shape)
        else:
            dt = jnp.diff(self.t, axis=-1)
            _, sde_out = scan_fn(inits, noises, self.t[:-1], dt)
            sde_out = jnp.concatenate([inits[None], sde_out], axis=0)
        return sde_out

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density of a path under the discretised SDE.

        The result is the log density of the first state under ``init_dist``
        plus the log density of each later state under a Normal centred at
        the Euler-Maruyama prediction from the previous state.

        :param value: Path(s) whose trailing axes match ``event_shape``.
        :return: Log density for each batch element.
        """
        sample_shape = lax.broadcast_shapes(
            value.shape[: -self.event_dim], self.batch_shape
        )
        value = jnp.broadcast_to(value, sample_shape + self.event_shape)

        if sample_shape:
            # Flatten the batch axes so sde_fn can be vmapped over
            # (path, time step); results are reshaped back afterwards.
            reshaped_value = value.reshape((-1,) + self.event_shape)
            xtm1, xt = reshaped_value[:, :-1], reshaped_value[:, 1:]
            value0 = reshaped_value[:, 0]

            t = jnp.broadcast_to(self.t, sample_shape + (self.event_shape[0],))
            t = t.reshape((-1,) + (self.event_shape[0],))

            f, g = vmap(vmap(self.sde_fn))(xtm1, t[:, :-1])

            f = f.reshape(sample_shape + f.shape[1:])
            g = g.reshape(sample_shape + g.shape[1:])
            xtm1 = xtm1.reshape(sample_shape + xtm1.shape[1:])
            xt = xt.reshape(sample_shape + xt.shape[1:])
            value0 = value0.reshape(sample_shape + value0.shape[1:])
        else:
            xtm1, xt = value[:-1], value[1:]
            value0 = value[0]

            f, g = vmap(self.sde_fn)(xtm1, self.t[:-1])

        # add missing event dimensions
        # (singleton axes are inserted after the time axis so that f and g
        # have the same rank as xt)
        batch_dim = len(sample_shape)
        f = f.reshape(
            f.shape[: batch_dim + 1]
            + (1,) * (xt.ndim - f.ndim)
            + f.shape[batch_dim + 1 :]
        )
        g = g.reshape(
            g.shape[: batch_dim + 1]
            + (1,) * (xt.ndim - g.ndim)
            + g.shape[batch_dim + 1 :]
        )

        dt = jnp.diff(self.t, axis=-1)
        # Trailing singleton axes let dt broadcast against the state axes.
        dt = dt.reshape(dt.shape + (1,) * (self.event_dim - 1))
        mu = xtm1 + dt * f
        sigma = jnp.sqrt(dt) * g

        sde_log_prob = Normal(mu, sigma).to_event(self.event_dim).log_prob(xt)
        init_log_prob = self.init_dist.log_prob(value0)

        return sde_log_prob + init_log_prob
class Exponential(Distribution):
    r"""Exponential distribution with rate :attr:`rate`.

    .. math::
        f(x; \lambda) = \lambda e^{-\lambda x}

    for :math:`x \geq 0` and rate :math:`\lambda > 0`.

    :param rate: Rate parameter (:math:`\lambda`), the inverse of the mean.
    :type rate: ArrayLike
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    support = constraints.positive
    arg_constraints = {"rate": constraints.positive}
    reparametrized_params = ["rate"]

    def __init__(
        self,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.rate = rate
        shape = jnp.shape(rate)
        super().__init__(batch_shape=shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Draw samples by dividing standard exponential draws by the rate.

        :param key: JAX PRNGKey for reproducibility.
        :type key: jax.Array
        :param sample_shape: The shape of the samples to be generated.
        :type sample_shape: tuple[int, ...]
        :return: Samples of shape ``sample_shape + batch_shape``.
        :rtype: ArrayLike
        """
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape
        unit_draws = random.exponential(key, shape=draw_shape)
        return unit_draws / self.rate

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Evaluate the log density at ``value``.

        .. math::
            \log f(x; \lambda) = \log \lambda - \lambda x

        :param value: Values at which to evaluate the log density.
        :type value: ArrayLike
        :return: Log probability density.
        :rtype: ArrayLike
        """
        log_rate = jnp.log(self.rate)
        return log_rate - self.rate * value

    @property
    def mean(self) -> ArrayLike:
        r"""Analytical mean.

        .. math::
            E[X] = \frac{1}{\lambda}
        """
        lam = self.rate
        return jnp.reciprocal(lam)

    @property
    def variance(self) -> ArrayLike:
        r"""Analytical variance.

        .. math::
            \mathrm{Var}(X) = \frac{1}{\lambda^2}
        """
        squared_rate = self.rate**2
        return jnp.reciprocal(squared_rate)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative distribution function.

        .. math::
            F(x; \lambda) = 1 - e^{-\lambda x}

        :param value: Value to evaluate.
        :type value: ArrayLike
        """
        # expm1 evaluates exp(.) - 1 directly; negating it gives 1 - exp(.).
        survival_minus_one = jnp.expm1(-self.rate * value)
        return -survival_minus_one

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse cumulative distribution function (quantile function).

        .. math::
            F^{-1}(q; \lambda) = -\frac{\ln(1 - q)}{\lambda}

        :param q: Probability value in :math:`[0,1]`.
        :type q: ArrayLike
        """
        log_survival = jnp.log1p(-q)
        return -log_survival / self.rate

    def entropy(self) -> ArrayLike:
        r"""Differential entropy.

        .. math::
            H(X) = 1 - \ln \lambda
        """
        log_rate = jnp.log(self.rate)
        return 1 - log_rate
class Gamma(Distribution):
    r"""Implementation of the `Gamma distribution <https://en.wikipedia.org/wiki/Gamma_distribution>`_,
    :math:`\mathrm{Gamma}(\alpha, \lambda)`, where :math:`\alpha` is the
    concentration and :math:`\lambda` is the rate.

    :param ArrayLike concentration: concentration parameter :math:`\alpha`
        (also known as shape parameter).
    :param ArrayLike rate: rate parameter :math:`\lambda` (inverse scale parameter).
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.positive
    reparametrized_params = ["concentration", "rate"]

    def __init__(
        self,
        concentration: ArrayLike,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.concentration, self.rate = promote_shapes(concentration, rate)
        batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(rate))
        super(Gamma, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Generate samples :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`.

        Unit-rate draws from :func:`~jax.random.gamma` are divided by the rate.

        :param key: JAX PRNGKey for reproducibility.
        :param sample_shape: The shape of the samples to be generated.
        :return: Samples of shape ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        shape = sample_shape + self.batch_shape + self.event_shape
        return random.gamma(key, self.concentration, shape=shape) / self.rate

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

        .. math::
            f_X(x\mid \alpha, \lambda) =
            \frac{\lambda^{\alpha} x^{\alpha - 1} e^{-\lambda x}}{\Gamma(\alpha)},
            \quad x > 0

        :func:`~jax.scipy.special.gammaln` is used to compute the logarithm of
        the gamma function.

        :param value: Values at which to evaluate the log density.
        :return: Log probability density.
        """
        # log Gamma(alpha) - alpha * log(lambda)
        normalize_term = gammaln(self.concentration) - self.concentration * jnp.log(
            self.rate
        )
        return (
            (self.concentration - 1) * jnp.log(value)
            - self.rate * value
            - normalize_term
        )

    @property
    def mean(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

        .. math::
            \mathbb{E}[X] = \frac{\alpha}{\lambda}
        """
        return self.concentration / self.rate

    @property
    def variance(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

        .. math::
            \mathrm{Var}(X) = \frac{\alpha}{\lambda^2}
        """
        return self.concentration / jnp.power(self.rate, 2)

    def cdf(self, x: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

        .. math::
            F_X(x \mid \alpha, \lambda) =
            \frac{1}{\Gamma(\alpha)} \gamma\left(\alpha, \lambda x\right)

        where :math:`\gamma(\cdot,\cdot)` is the `lower incomplete gamma function
        <https://en.wikipedia.org/wiki/Incomplete_gamma_function>`_. This method
        uses the regularized incomplete gamma function, which is implemented as
        :func:`~jax.scipy.special.gammainc`.

        :param x: Value to evaluate.
        :type x: ArrayLike
        """
        return gammainc(self.concentration, self.rate * x)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

        .. math::
            F^{-1}_X(q \mid \alpha, \lambda) =
            \frac{1}{\lambda} \gamma^{-1}\left(\alpha, q \Gamma(\alpha)\right)

        where :math:`\gamma^{-1}(\cdot,\cdot)` is the inverse of the lower
        incomplete gamma function. This method uses the inverse of the
        regularized incomplete gamma function, which is implemented as
        :func:`~numpyro.distributions.util.gammaincinv`.

        :param q: Probability value in :math:`[0,1]`.
        :type q: ArrayLike
        """
        return gammaincinv(self.concentration, q) / self.rate

    def entropy(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

        .. math::
            H[X] = \alpha - \ln(\lambda) + \ln\Gamma(\alpha) + (1 - \alpha) \psi(\alpha)

        where :math:`\psi(\cdot)` is the `digamma function
        <https://en.wikipedia.org/wiki/Digamma_function>`_. This method uses
        :func:`~jax.scipy.special.digamma` to evaluate it.
        """
        return (
            self.concentration
            - jnp.log(self.rate)
            + gammaln(self.concentration)
            + (1 - self.concentration) * digamma(self.concentration)
        )
class Chi2(Gamma):
    r"""Chi-square distribution with :math:`k` degrees of freedom (:attr:`df`).

    .. math::
        f(x; k) = \frac{x^{k/2 - 1} e^{-x/2}}{2^{k/2}\,\Gamma(k/2)}, \quad x > 0

    where :math:`\Gamma(\cdot)` is the gamma function. The distribution is
    the special case

    .. math::
        \chi^2(k) \equiv \mathrm{Gamma}(k/2,\; 1/2)

    of :class:`Gamma`, so every method is inherited from :class:`Gamma` with
    concentration ``df / 2`` and rate ``1 / 2``.
    """

    reparametrized_params = ["df"]
    arg_constraints = {"df": constraints.positive}

    def __init__(self, df: ArrayLike, *, validate_args: Optional[bool] = None) -> None:
        r"""
        :param df: Degrees of freedom parameter :math:`k > 0` (:attr:`df`).
        :param validate_args: Whether to validate input constraints, defaults to None.
        """
        self.df = df
        half_df = 0.5 * df
        super().__init__(half_df, 0.5, validate_args=validate_args)
class GaussianStateSpace(TransformedDistribution):
    r"""
    Gaussian state space model.

    .. math::

        \mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\
        &= \mathbf{A}^t \mathbf{z}_0 + \sum_{k=1}^{t} \mathbf{A}^{t-k}
        \boldsymbol{\epsilon}_k,

    where :math:`\mathbf{z}_t` is the state vector at step :math:`t`,
    :math:`\mathbf{A}` is the transition matrix, :math:`\mathbf{z}_0` is the
    initial value, and :math:`\boldsymbol\epsilon` is the innovation noise.

    The model is built as a multivariate normal over the innovations of all
    steps, pushed through a :class:`RecursiveLinearTransform`.

    :param num_steps: Number of steps.
    :param transition_matrix: State space transition matrix :math:`\mathbf{A}`.
    :param covariance_matrix: Covariance of the innovation noise
        :math:`\boldsymbol\epsilon`.
    :param precision_matrix: Precision matrix of the innovation noise
        :math:`\boldsymbol\epsilon`.
    :param scale_tril: Scale matrix of the innovation noise
        :math:`\boldsymbol\epsilon`.
    :param initial_value: Initial state vector :math:`\mathbf{z}_0`. If
        ``None``, defaults to zero.
    :param validate_args: Whether to validate input constraints, defaults to None.
    """

    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "transition_matrix": constraints.real_matrix,
        "initial_value": constraints.real_vector,
    }
    support = constraints.real_matrix
    pytree_data_fields = ("transition_matrix", "_initial_value", "scale_tril")
    pytree_aux_fields = ("num_steps",)

    def __init__(
        self,
        num_steps: int,
        transition_matrix: Array,
        covariance_matrix: Optional[Array] = None,
        precision_matrix: Optional[Array] = None,
        scale_tril: Optional[Array] = None,
        initial_value: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        assert isinstance(num_steps, int) and num_steps > 0, (
            "`num_steps` argument should be a positive integer."
        )
        self.num_steps = num_steps
        # NOTE(review): only the rank is checked here; squareness is not.
        assert transition_matrix.ndim == 2, (
            "`transition_matrix` argument should be a square matrix"
        )
        self.transition_matrix = transition_matrix
        # Kept as passed (possibly None); the `initial_value` property
        # substitutes zeros when it is None.
        self._initial_value = initial_value
        # Expand the covariance/precision/scale matrices to the right number of steps.
        args = {
            "covariance_matrix": covariance_matrix,
            "precision_matrix": precision_matrix,
            "scale_tril": scale_tril,
        }
        # Only the parameterisations that were actually supplied are forwarded;
        # each gets a new axis of length `num_steps` at position -3.
        args = {
            key: jnp.expand_dims(value, axis=-3).repeat(num_steps, axis=-3)
            for key, value in args.items()
            if value is not None
        }
        base_distribution = MultivariateNormal(**args)
        # The noise scale is the same at every step, so keep the first copy.
        self.scale_tril = base_distribution.scale_tril[..., 0, :, :]
        base_distribution = base_distribution.to_event(1)
        # The base distribution must have at least the same batch shape as the initial
        # value.
        if initial_value is not None:
            batch_shape = initial_value.shape[:-1]
            base_distribution = base_distribution.expand(batch_shape)
        transform = RecursiveLinearTransform(
            transition_matrix, initial_value=initial_value
        )
        super().__init__(base_distribution, transform, validate_args=validate_args)

    @property
    def initial_value(self) -> Array:
        """Initial state vector; zeros of the state dimension if none was given."""
        if self._initial_value is None:
            return jnp.zeros(self.transition_matrix.shape[-1:])
        return self._initial_value

    @property
    def mean(self) -> ArrayLike:
        """Mean of the state at every step."""
        # If there's no initial value, the mean is zero (base distribution mean).
        if self._initial_value is None:
            return self.base_dist.mean

        # Otherwise, compute A^t @ z_0 for each time step t.
        # z_t = A @ z_{t-1} for the deterministic part with z_0 = initial_value
        def propagate(z, _):
            z_next = jnp.einsum("...ij,...j->...i", self.transition_matrix, z)
            return z_next, z_next

        _, means = scan(propagate, self.initial_value, jnp.arange(self.num_steps))
        # means has shape (num_steps, ..., state_dim)
        # We need to move num_steps to axis -2 to match base_dist.mean shape
        return jnp.moveaxis(means, 0, -2)

    @property
    def variance(self) -> ArrayLike:
        """Marginal variance of each state component at every step."""
        # Given z_t = A^t z_0 + \sum_{k=1}^t A^{t-k} \epsilon_k, the covariance of the
        # state vector at step t is Cov[z_t] = \sum_{k,k'}^t A^{t-k}
        # E[\epsilon_k transpose(\epsilon_{k'})] transpose(A^{t-k'}). We only have
        # contributions for k = k' because innovations at different steps are
        # independent such that Cov[z_t] = \sum_k^t A^{t-k} @ covariance_matrix
        # @ transpose(A^{t-k}). The initial value is deterministic, and we don't
        # need to consider it here. Using `scan` is an easy way to evaluate this
        # expression: it yields L, A @ L, A^2 @ L, ... for the noise scale L.
        _, scale_tril = scan(
            lambda carry, _: (self.transition_matrix @ carry, carry),
            self.scale_tril,
            jnp.arange(self.num_steps),
        )
        # Diagonal of each term, accumulated over steps; the step axis is then
        # swapped into position -2.
        return (
            jnp.diagonal(scale_tril @ scale_tril.mT, axis1=-1, axis2=-2)
            .cumsum(axis=0)
            .swapaxes(0, -2)
        )

    @lazy_property
    def covariance_matrix(self):
        """Covariance of the innovation noise, ``scale_tril @ scale_tril.mT``."""
        return self.scale_tril @ self.scale_tril.mT

    @lazy_property
    def precision_matrix(self):
        """Precision of the innovation noise, from a Cholesky solve against identity."""
        identity = jnp.broadcast_to(
            jnp.eye(self.scale_tril.shape[-1]), self.scale_tril.shape
        )
        return cho_solve((self.scale_tril, True), identity)
class GaussianRandomWalk(Distribution):
    """Random walk whose increments are zero-mean Gaussians with a common scale.

    An event is a vector of length ``num_steps`` holding the cumulative sum
    of the increments.

    :param scale: Positive standard deviation of each increment.
    :type scale: ArrayLike
    :param num_steps: Number of steps of the walk; must be a positive ``int``.
    :type num_steps: int
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    support = constraints.real_vector
    arg_constraints = {"scale": constraints.positive}
    reparametrized_params = ["scale"]
    pytree_aux_fields = ("num_steps",)

    def __init__(
        self,
        scale: ArrayLike = 1.0,
        num_steps: int = 1,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        assert isinstance(num_steps, int) and num_steps > 0, (
            "`num_steps` argument should be a positive integer."
        )
        self.scale = scale
        self.num_steps = num_steps
        super().__init__(jnp.shape(scale), (num_steps,), validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw walks as scaled cumulative sums of standard normal increments."""
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape + self.event_shape
        increments = random.normal(key, shape=draw_shape)
        unit_walks = jnp.cumsum(increments, axis=-1)
        return unit_walks * jnp.expand_dims(self.scale, axis=-1)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density of a walk: the first position plus every later transition."""
        first_step_lp = Normal(0.0, self.scale).log_prob(value[..., 0])
        # Each later position is Normal around the position before it.
        previous, current = value[..., :-1], value[..., 1:]
        step_scale = jnp.expand_dims(self.scale, -1)
        transition_lp = Normal(previous, step_scale).log_prob(current)
        return first_step_lp + jnp.sum(transition_lp, axis=-1)

    @property
    def mean(self) -> ArrayLike:
        """Zeros of shape ``batch_shape + event_shape``."""
        full_shape = self.batch_shape + self.event_shape
        return jnp.zeros(full_shape)

    @property
    def variance(self) -> ArrayLike:
        """Per-step variance, ``scale**2 * step`` for steps ``1..num_steps``."""
        step_index = jnp.arange(1, self.num_steps + 1)
        per_step = jnp.expand_dims(self.scale, -1) ** 2 * step_index
        return jnp.broadcast_to(per_step, self.batch_shape + self.event_shape)
class HalfCauchy(Distribution):
    """Distribution on the positive reals built from a ``Cauchy(0, scale)``.

    Samples are absolute values of the Cauchy draws; density, CDF and inverse
    CDF are obtained from the wrapped Cauchy instance.
    """

    reparametrized_params = ["scale"]
    support = constraints.positive
    arg_constraints = {"scale": constraints.positive}
    pytree_data_fields = ("_cauchy", "scale")

    def __init__(
        self,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self._cauchy = Cauchy(0.0, scale)
        self.scale = scale
        super().__init__(batch_shape=jnp.shape(scale), validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        signed_draw = self._cauchy.sample(key, sample_shape)
        return jnp.abs(signed_draw)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # Wrapped Cauchy log-density shifted by log(2).
        cauchy_log_density = self._cauchy.log_prob(value)
        return cauchy_log_density + jnp.log(2)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        cauchy_cdf = self._cauchy.cdf(value)
        return cauchy_cdf * 2 - 1

    def icdf(self, q: ArrayLike) -> ArrayLike:
        cauchy_quantile = (q + 1) / 2
        return self._cauchy.icdf(cauchy_quantile)

    @property
    def mean(self) -> ArrayLike:
        return jnp.full(self.batch_shape, jnp.inf)

    @property
    def variance(self) -> ArrayLike:
        return jnp.full(self.batch_shape, jnp.inf)
class HalfNormal(Distribution):
    """Distribution on the positive reals built from a ``Normal(0, scale)``.

    Samples are absolute values of the Normal draws; density, CDF and inverse
    CDF are obtained from the wrapped Normal instance.
    """

    reparametrized_params = ["scale"]
    support = constraints.positive
    arg_constraints = {"scale": constraints.positive}
    pytree_data_fields = ("_normal", "scale")

    def __init__(
        self,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self._normal = Normal(0.0, scale)
        self.scale = scale
        super().__init__(batch_shape=jnp.shape(scale), validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> Array:
        assert is_prng_key(key)
        signed_draw = self._normal.sample(key, sample_shape)
        return jnp.abs(signed_draw)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # Wrapped Normal log-density shifted by log(2).
        normal_log_density = self._normal.log_prob(value)
        return normal_log_density + jnp.log(2)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        normal_cdf = self._normal.cdf(value)
        return normal_cdf * 2 - 1

    def icdf(self, q: ArrayLike) -> ArrayLike:
        normal_quantile = (q + 1) / 2
        return self._normal.icdf(normal_quantile)

    @property
    def mean(self) -> ArrayLike:
        coefficient = jnp.sqrt(2 / jnp.pi)
        return coefficient * self.scale

    @property
    def variance(self) -> ArrayLike:
        coefficient = 1 - 2 / jnp.pi
        return coefficient * self.scale**2
class InverseGamma(TransformedDistribution):
    """
    Distribution obtained by pushing ``Gamma(concentration, rate)`` through
    ``PowerTransform(-1.0)``.

    .. note:: We keep the same notation `rate` as in Pyro but it plays the role
        of scale parameter of InverseGamma in literatures
        (e.g. wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution)
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    reparametrized_params = ["concentration", "rate"]
    support = constraints.positive

    def __init__(
        self,
        concentration: ArrayLike,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        gamma_base = Gamma(concentration, rate)
        # Re-expose the parameters exactly as stored on the base distribution.
        self.concentration = gamma_base.concentration
        self.rate = gamma_base.rate
        super().__init__(gamma_base, PowerTransform(-1.0), validate_args=validate_args)

    @property
    def mean(self) -> ArrayLike:
        # Reported as inf whenever concentration <= 1.
        finite_mean = self.rate / (self.concentration - 1)
        return jnp.where(self.concentration <= 1, jnp.inf, finite_mean)

    @property
    def variance(self) -> ArrayLike:
        # Reported as inf whenever concentration <= 2.
        squared_mean = (self.rate / (self.concentration - 1)) ** 2
        finite_variance = squared_mean / (self.concentration - 2)
        return jnp.where(self.concentration <= 2, jnp.inf, finite_variance)

    def entropy(self) -> ArrayLike:
        alpha = self.concentration
        return (
            alpha
            + jnp.log(self.rate)
            + gammaln(alpha)
            - (1 + alpha) * digamma(alpha)
        )
class Gompertz(Distribution):
    r"""Gompertz Distribution.

    The Gompertz distribution is a distribution with support on the positive real line
    that is closely related to the Gumbel distribution. This implementation follows the
    notation used in the Wikipedia entry for the Gompertz distribution. See
    https://en.wikipedia.org/wiki/Gompertz_distribution.

    However, we call the parameter "eta" a concentration parameter and the parameter
    "b" a rate parameter (as opposed to scale parameter as in wikipedia description.)

    The CDF, in terms of `concentration` (`con`) and `rate`, is

    .. math::
        F(x) = 1 - \exp \left\{ - \text{con} * \left [ \exp\{x * rate \} - 1 \right ] \right\}
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.positive
    reparametrized_params = ["concentration", "rate"]

    def __init__(
        self,
        concentration: ArrayLike,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.concentration, self.rate = promote_shapes(concentration, rate)
        joint_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(rate))
        super().__init__(batch_shape=joint_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw uniforms and map them through :meth:`icdf`."""
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape + self.event_shape
        uniforms = random.uniform(key, shape=draw_shape)
        return self.icdf(uniforms)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        rate_times_x = value * self.rate
        log_normalizer = jnp.log(self.concentration) + jnp.log(self.rate)
        return (
            log_normalizer
            + rate_times_x
            - self.concentration * jnp.expm1(rate_times_x)
        )

    def cdf(self, value: ArrayLike) -> ArrayLike:
        exponent = -self.concentration * jnp.expm1(value * self.rate)
        return -jnp.expm1(exponent)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        inner = -jnp.log1p(-q) / self.concentration
        return jnp.log1p(inner) / self.rate

    @property
    def mean(self) -> ArrayLike:
        return -jnp.exp(self.concentration) * expi(-self.concentration) / self.rate
class Gumbel(Distribution):
    r"""The Gumbel (maximum) distribution, a continuous real-valued distribution
    parameterized by location :math:`\mu` and scale :math:`\beta > 0`. It is the
    limiting distribution of the maximum of a large number of i.i.d. samples from
    an exponential-tailed distribution.

    The Probability Density Function (PDF) is:

    .. math::
        f(x \mid \mu, \beta) = \frac{1}{\beta}
        \exp\!\left(
            -\frac{x - \mu}{\beta} - \exp\!\left(-\frac{x - \mu}{\beta}\right)
        \right), \quad x \in \mathbb{R}

    where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
    :math:`\beta > 0` is the scale (:attr:`scale`).
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    reparametrized_params = ["loc", "scale"]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        r"""
        :param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
        :param scale: Scale parameter :math:`\beta > 0`. Defaults to ``1.0``.
        :param validate_args: If True, enforce domain constraints during initialization.
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        joint_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super().__init__(batch_shape=joint_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Draw samples via the location-scale transform
        :math:`X = \mu + \beta Z`, with :math:`Z \sim \mathrm{Gumbel}(0, 1)`
        produced by :func:`~jax.random.gumbel`.

        :param key: A JAX PRNG key.
        :param sample_shape: Sample dimensions to prepend to the batch shape.
        :return: Real-valued samples from the Gumbel distribution.
        """
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape + self.event_shape
        unit_draw = random.gumbel(key, shape=draw_shape)
        return self.loc + self.scale * unit_draw

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Evaluate the log probability density function at ``value``.
        With :math:`z = (x - \mu)/\beta`,

        .. math::
            \ln f(x \mid \mu, \beta) = -z - e^{-z} - \ln \beta

        :param value: Real-valued point :math:`x` at which to evaluate the log PDF.
        :return: Log probability density evaluated under the Gumbel distribution.
        """
        standardized = (value - self.loc) / self.scale
        return -(standardized + jnp.exp(-standardized)) - jnp.log(self.scale)

    @property
    def mean(self) -> ArrayLike:
        r"""Mean of the Gumbel distribution:

        .. math::
            \mathbb{E}[X] = \mu + \beta \gamma

        where :math:`\gamma \approx 0.5772\ldots` is the Euler-Mascheroni
        constant, available as :data:`~jax.numpy.euler_gamma`.
        """
        unbroadcast = self.loc + self.scale * jnp.euler_gamma
        return jnp.broadcast_to(unbroadcast, self.batch_shape)

    @property
    def variance(self) -> ArrayLike:
        r"""Variance of the Gumbel distribution:

        .. math::
            \mathrm{Var}(X) = \frac{\pi^2}{6} \beta^2
        """
        unbroadcast = jnp.pi**2 / 6.0 * self.scale**2
        return jnp.broadcast_to(unbroadcast, self.batch_shape)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative Distribution Function (CDF) of the Gumbel distribution:

        .. math::
            F(x \mid \mu, \beta) =
            \exp\!\left(-\exp\!\left(-\frac{x - \mu}{\beta}\right)\right)

        :param value: Real-valued point :math:`x` at which to evaluate the CDF.
        :return: CDF values in :math:`[0, 1]`.
        """
        negated_z = (self.loc - value) / self.scale
        return jnp.exp(-jnp.exp(negated_z))

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse CDF (quantile function) of the Gumbel distribution:

        .. math::
            F^{-1}(q \mid \mu, \beta) = \mu - \beta \ln(-\ln q), \quad q \in (0, 1)

        :param q: Quantile values in :math:`(0, 1)`.
        :return: Real-valued quantiles of the Gumbel distribution at ``q``.
        """
        log_neg_log_q = jnp.log(-jnp.log(q))
        return self.loc - self.scale * log_neg_log_q
class Kumaraswamy(Distribution):
    """Distribution on the unit interval with two positive concentrations,
    ``concentration1`` and ``concentration0``."""

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    reparametrized_params = ["concentration1", "concentration0"]
    support = constraints.unit_interval
    # XXX: This flag is used to approximate the Taylor expansion
    # of KL(Kumaraswamy||Beta) following
    # https://arxiv.org/abs/1605.06197 Formula (12)
    # We follow the paper and set this to 10 but to get more precise KL,
    # we can set this flag to 1000.
    KL_KUMARASWAMY_BETA_TAYLOR_ORDER = 10

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        promoted = promote_shapes(concentration1, concentration0)
        self.concentration1, self.concentration0 = promoted
        joint_shape = lax.broadcast_shapes(
            jnp.shape(concentration1), jnp.shape(concentration0)
        )
        super().__init__(batch_shape=joint_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        float_info = jnp.finfo(jnp.result_type(float))
        tiny, below_one = float_info.tiny, 1 - float_info.eps
        uniforms = random.uniform(
            key, shape=sample_shape + self.batch_shape, minval=tiny
        )
        # Cap u**(1/concentration0) below 1 so log1p(-x) stays finite.
        powered = jnp.clip(uniforms ** (1 / self.concentration0), None, below_one)
        log_draw = jnp.log1p(-powered) / self.concentration1
        return jnp.clip(jnp.exp(log_draw), tiny, below_one)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        float_info = jnp.finfo(jnp.result_type(float))
        log_normalizer = jnp.log(self.concentration0) + jnp.log(self.concentration1)
        # Cap value**concentration1 below 1 before it is fed to xlog1py.
        powered = jnp.clip(value**self.concentration1, None, 1 - float_info.eps)
        value_term = xlogy(self.concentration1 - 1, value)
        complement_term = xlog1py(self.concentration0 - 1, -powered)
        return value_term + complement_term + log_normalizer

    @property
    def mean(self) -> ArrayLike:
        log_b = betaln(1 + 1 / self.concentration1, self.concentration0)
        return self.concentration0 * jnp.exp(log_b)

    @property
    def variance(self) -> ArrayLike:
        log_b = betaln(1 + 2 / self.concentration1, self.concentration0)
        second_moment = self.concentration0 * jnp.exp(log_b)
        return second_moment - jnp.square(self.mean)
class Laplace(Distribution):
    r"""The Laplace (double-exponential) distribution, a continuous real-valued
    distribution parameterized by location :math:`\mu` and scale :math:`b > 0`.
    It is the distribution of the difference of two i.i.d. exponential variates
    and has heavier tails than the Normal distribution.

    The Probability Density Function (PDF) is:

    .. math::
        f(x \mid \mu, b) = \frac{1}{2 b}
        \exp\!\left(-\frac{|x - \mu|}{b}\right), \quad x \in \mathbb{R}

    where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
    :math:`b > 0` is the scale (:attr:`scale`).
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    reparametrized_params = ["loc", "scale"]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        r"""
        :param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
        :param scale: Scale parameter :math:`b > 0`. Defaults to ``1.0``.
        :param validate_args: If True, enforce domain constraints during initialization.
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        joint_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super().__init__(batch_shape=joint_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Draw samples via the location-scale transform :math:`X = \mu + b Z`,
        with :math:`Z \sim \mathrm{Laplace}(0, 1)` produced by
        :func:`~jax.random.laplace`.

        :param key: A JAX PRNG key.
        :param sample_shape: Sample dimensions to prepend to the batch shape.
        :return: Real-valued samples from the Laplace distribution.
        """
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape + self.event_shape
        unit_draw = random.laplace(key, shape=draw_shape)
        return self.loc + unit_draw * self.scale

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Evaluate the log probability density function at ``value``:

        .. math::
            \ln f(x \mid \mu, b) = -\frac{|x - \mu|}{b} - \ln(2 b)

        :param value: Real-valued point :math:`x` at which to evaluate the log PDF.
        :return: Log probability density evaluated under the Laplace distribution.
        """
        log_normalizer = jnp.log(2 * self.scale)
        abs_standardized = jnp.abs(value - self.loc) / self.scale
        return -abs_standardized - log_normalizer

    @property
    def mean(self) -> ArrayLike:
        r"""Mean of the Laplace distribution:

        .. math::
            \mathbb{E}[X] = \mu
        """
        return jnp.broadcast_to(self.loc, self.batch_shape)

    @property
    def variance(self) -> ArrayLike:
        r"""Variance of the Laplace distribution:

        .. math::
            \mathrm{Var}(X) = 2 b^2
        """
        unbroadcast = 2 * self.scale**2
        return jnp.broadcast_to(unbroadcast, self.batch_shape)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative Distribution Function (CDF) of the Laplace distribution.
        With :math:`z = (x - \mu)/b`,

        .. math::
            F(x \mid \mu, b) = \frac{1}{2} - \frac{1}{2}\,
            \operatorname{sgn}(z)\,\left(e^{-|z|} - 1\right)

        The implementation uses :func:`~jax.numpy.expm1` for numerical
        stability near :math:`z = 0`.

        :param value: Real-valued point :math:`x` at which to evaluate the CDF.
        :return: CDF values in :math:`[0, 1]`.
        """
        z = (value - self.loc) / self.scale
        half_sign = 0.5 * jnp.sign(z)
        return 0.5 - half_sign * jnp.expm1(-jnp.abs(z))

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse CDF (quantile function) of the Laplace distribution:

        .. math::
            F^{-1}(q \mid \mu, b) = \mu - b\,\mathrm{sgn}\left(q - \frac{1}{2}\right)\,
            \ln\!\left(1 - 2 \left| q - \frac{1}{2} \right| \right), \quad q \in [0, 1]

        :param q: Quantile values in :math:`[0, 1]`.
        :return: Real-valued quantiles of the Laplace distribution at ``q``.
        """
        centred = q - 0.5
        signed_scale = self.scale * jnp.sign(centred)
        return self.loc - signed_scale * jnp.log1p(-2 * jnp.abs(centred))

    def entropy(self) -> ArrayLike:
        r"""Differential entropy of the Laplace distribution:

        .. math::
            H(X) = \ln(2 b) + 1
        """
        return jnp.log(2 * self.scale) + 1
class LKJ(TransformedDistribution):
    r"""
    LKJ distribution for correlation matrices. The distribution is controlled by ``concentration``
    parameter :math:`\eta` to make the probability of the correlation matrix :math:`M` proportional
    to :math:`\det(M)^{\eta - 1}`. Because of that, when ``concentration == 1``, we have a
    uniform distribution over correlation matrices.

    When ``concentration > 1``, the distribution favors samples with large determinant.
    This is useful when we know a priori that the underlying variables are not correlated.

    When ``concentration < 1``, the distribution favors samples with small determinant. This is
    useful when we know a priori that some underlying variables are correlated.

    Sample code for using LKJ in the context of multivariate normal sample::

        def model(y):  # y has dimension N x d
            d = y.shape[1]
            N = y.shape[0]
            # Vector of variances for each of the d variables
            theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))

            concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
            corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
            sigma = jnp.sqrt(theta)
            # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat`
            cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))

            # Vector of expectations
            mu = jnp.zeros(d)

            with numpyro.plate("observations", N):
                obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
            return obs

    :param int dimension: dimension of the matrices
    :param ndarray concentration: concentration/shape parameter of the
        distribution (often referred to as eta)
    :param str sample_method: Either "cvine" or "onion". Both methods are proposed in [1] and
        offer the same distribution over correlation matrices. But they are different in how
        to generate samples. Defaults to "onion".

    **References**

    [1] `Generating random correlation matrices based on vines and extended onion method`,
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe
    """

    arg_constraints = {"concentration": constraints.positive}
    reparametrized_params = ["concentration"]
    support = constraints.corr_matrix
    pytree_aux_fields = ("dimension", "sample_method")

    def __init__(
        self,
        dimension: int,
        concentration: ArrayLike = 1.0,
        sample_method: Literal["onion", "cvine"] = "onion",
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        cholesky_dist = LKJCholesky(dimension, concentration, sample_method)
        # Mirror the values as stored on the Cholesky-factor distribution.
        self.dimension = cholesky_dist.dimension
        self.concentration = cholesky_dist.concentration
        self.sample_method = sample_method
        to_corr_matrix = CorrMatrixCholeskyTransform().inv
        super().__init__(cholesky_dist, to_corr_matrix, validate_args=validate_args)

    @property
    def mean(self) -> ArrayLike:
        # Identity matrix expanded over the batch shape.
        size = self.dimension
        return jnp.broadcast_to(jnp.identity(size), self.batch_shape + (size, size))
class LKJCholesky(Distribution):
    r"""
    LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is
    controlled by ``concentration`` parameter :math:`\eta` to make the probability of the
    correlation matrix :math:`M` generated from a Cholesky factor proportional to
    :math:`\det(M)^{\eta - 1}`. Because of that, when ``concentration == 1``, we have a
    uniform distribution over Cholesky factors of correlation matrices.

    When ``concentration > 1``, the distribution favors samples with large diagonal entries
    (hence large determinant). This is useful when we know a priori that the underlying
    variables are not correlated.

    When ``concentration < 1``, the distribution favors samples with small diagonal entries
    (hence small determinant). This is useful when we know a priori that some underlying
    variables are correlated.

    Sample code for using LKJCholesky in the context of multivariate normal sample::

        def model(y):  # y has dimension N x d
            d = y.shape[1]
            N = y.shape[0]
            # Vector of variances for each of the d variables
            theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
            # Lower cholesky factor of a correlation matrix
            concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
            L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
            # Lower cholesky factor of the covariance matrix
            sigma = jnp.sqrt(theta)
            # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
            L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)

            # Vector of expectations
            mu = jnp.zeros(d)

            with numpyro.plate("observations", N):
                obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
            return obs

    :param int dimension: dimension of the matrices
    :param ndarray concentration: concentration/shape parameter of the
        distribution (often referred to as eta)
    :param str sample_method: Either "cvine" or "onion". Both methods are proposed in [1] and
        offer the same distribution over correlation matrices. But they are different in how
        to generate samples. Defaults to "onion".
    :raises ValueError: if ``dimension < 2`` or ``sample_method`` is not one of
        ``"onion"`` / ``"cvine"``.

    **References**

    [1] `Generating random correlation matrices based on vines and extended onion method`,
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe
    """

    arg_constraints = {"concentration": constraints.positive}
    reparametrized_params = ["concentration"]
    support = constraints.corr_cholesky
    pytree_data_fields = ("_beta", "concentration")
    pytree_aux_fields = ("dimension", "sample_method")

    def __init__(
        self,
        dimension: int,
        concentration: ArrayLike = 1.0,
        sample_method: Literal["onion", "cvine"] = "onion",
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        if dimension < 2:
            raise ValueError("Dimension must be greater than or equal to 2.")
        self.dimension = dimension
        self.concentration = concentration
        batch_shape = jnp.shape(concentration)
        event_shape = (dimension, dimension)

        # We construct base distributions to generate samples for each method.
        # The purpose of this base distribution is to generate a distribution for
        # correlation matrices which is proportional to `det(M)^{\eta - 1}`.
        # (note that this is not a unique way to define base distribution)
        # Both of the following methods have marginal distribution of each off-diagonal
        # element of sampled correlation matrices is Beta(eta + (D-2) / 2, eta + (D-2) / 2)
        # (up to a linear transform: x -> 2x - 1)
        Dm1 = self.dimension - 1
        marginal_concentration = concentration + 0.5 * (self.dimension - 2)
        offset = 0.5 * jnp.arange(Dm1)
        if sample_method == "onion":
            # The following construction follows from the algorithm in Section 3.2 of [1]:
            # NB: in [1], the method for case k > 1 can also work for the case k = 1.
            beta_concentration0 = (
                jnp.expand_dims(marginal_concentration, axis=-1) - offset
            )
            beta_concentration1 = offset + 0.5
            self._beta = Beta(beta_concentration1, beta_concentration0)
        elif sample_method == "cvine":
            # The following construction follows from the algorithm in Section 2.4 of [1]:
            # offset_tril is [0, 1, 1, 2, 2, 2,...] / 2
            offset_tril = matrix_to_tril_vec(jnp.broadcast_to(offset, (Dm1, Dm1)))
            beta_concentration = (
                jnp.expand_dims(marginal_concentration, axis=-1) - offset_tril
            )
            self._beta = Beta(beta_concentration, beta_concentration)
        else:
            # Name the actual keyword argument so the message is actionable.
            raise ValueError("`sample_method` should be one of 'cvine' or 'onion'.")
        self.sample_method = sample_method

        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def _cvine(self, key: jax.Array, size):
        """Sample Cholesky factors with the C-vine construction."""
        # C-vine method first uses beta_dist to generate partial correlations,
        # then apply signed stick breaking to transform to cholesky factor.
        # Here is an attempt to prove that using signed stick breaking to
        # generate correlation matrices is the same as the C-vine method in [1]
        # for the entry r_32.
        #
        # With notations follow from [1], we define
        #   p: partial correlation matrix,
        #   c: cholesky factor,
        #   r: correlation matrix.
        # From recursive formula (2) in [1], we have
        #   r_32 = p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} + p_21 * p_31 =: I
        # On the other hand, signed stick breaking process gives:
        #   l_21 = p_21, l_31 = p_31, l_22 = sqrt(1 - p_21^2), l_32 = p_32 * sqrt(1 - p_31^2)
        #   r_32 = l_21 * l_31 + l_22 * l_32
        #        = p_21 * p_31 + p_32 * sqrt{(1 - p_21^2)*(1 - p_31^2)} = I
        beta_sample = self._beta.sample(key, size)
        partial_correlation = 2 * beta_sample - 1  # scale to domain to (-1, 1)
        return signed_stick_breaking_tril(partial_correlation)

    def _onion(self, key: jax.Array, size):
        """Sample Cholesky factors with the extended onion construction."""
        key_beta, key_normal = random.split(key)
        # Now we generate w term in Algorithm 3.2 of [1].
        beta_sample = self._beta.sample(key_beta, size)
        # The following Normal distribution is used to create a uniform distribution on
        # a hypersphere (ref: http://mathworld.wolfram.com/HyperspherePointPicking.html)
        normal_sample = random.normal(
            key_normal,
            shape=size
            + self.batch_shape
            + (self.dimension * (self.dimension - 1) // 2,),
        )
        normal_sample = vec_to_tril_matrix(normal_sample, diagonal=0)
        u_hypersphere = normal_sample / jnp.linalg.norm(
            normal_sample, axis=-1, keepdims=True
        )
        w = jnp.expand_dims(jnp.sqrt(beta_sample), axis=-1) * u_hypersphere

        # put w into the off-diagonal triangular part
        cholesky = jnp.zeros(size + self.batch_shape + self.event_shape)
        cholesky = cholesky.at[..., 1:, :-1].set(w)
        # correct the diagonal
        # NB: beta_sample = sum(w ** 2) because norm 2 of u is 1.
        diag = jnp.ones(cholesky.shape[:-1]).at[..., 1:].set(jnp.sqrt(1 - beta_sample))
        return add_diag(cholesky, diag)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw Cholesky factors using the sampler selected at construction."""
        assert is_prng_key(key)
        if self.sample_method == "onion":
            return self._onion(key, sample_shape)
        else:
            return self._cvine(key, sample_shape)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # Note about computing Jacobian of the transformation from Cholesky factor to
        # correlation matrix:
        #
        #   Assume C = L@Lt and L = (1 0 0; a \sqrt(1-a^2) 0; b c \sqrt(1-b^2-c^2)), we have
        #   Then off-diagonal lower triangular vector of L is transformed to the off-diagonal
        #   lower triangular vector of C by the transform:
        #       (a, b, c) -> (a, b, ab + c\sqrt(1-a^2))
        #   Hence, Jacobian = 1 * 1 * \sqrt(1 - a^2) = \sqrt(1 - a^2) = L22, where L22
        #       is the 2nd diagonal element of L
        #   Generally, for a D dimensional matrix, we have:
        #       Jacobian = L22^(D-2) * L33^(D-3) * ... * Ldd^0
        #
        # From [1], we know that probability of a correlation matrix is proportional to
        #   determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
        # On the other hand, Jacobian of the transformation from Cholesky factor to
        # correlation matrix is:
        #   prod(L_ii ^ (D - i))
        # So the probability of a Cholesky factor is proportional to
        #   prod(L_ii ^ (2 * concentration - 2 + D - i)) =: prod(L_ii ^ order_i)
        # with order_i = 2 * concentration - 2 + D - i,
        # i = 2..D (we omit the element i = 1 because L_11 = 1)

        # Compute `order` vector (note that we need to reindex i -> i-2):
        one_to_D = jnp.arange(1, self.dimension)
        order_offset = (3 - self.dimension) + one_to_D
        order = 2 * jnp.expand_dims(self.concentration, axis=-1) - order_offset

        # Compute unnormalized log_prob:
        value_diag = jnp.asarray(value)[..., one_to_D, one_to_D]
        unnormalized = jnp.sum(order * jnp.log(value_diag), axis=-1)

        # Compute normalization constant (on the first proof of page 1999 of [1])
        Dm1 = self.dimension - 1
        alpha = self.concentration + 0.5 * Dm1
        denominator = gammaln(alpha) * Dm1
        numerator = multigammaln(alpha - 0.5, Dm1)
        # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
        # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
        # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
        pi_constant = 0.5 * Dm1 * jnp.log(jnp.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized - normalize_term
class LogNormal(TransformedDistribution):
    """Distribution obtained by pushing ``Normal(loc, scale)`` through
    ``ExpTransform``; its support is the positive reals."""

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.positive
    reparametrized_params = ["loc", "scale"]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        normal_base = Normal(loc, scale)
        # Re-expose the parameters exactly as stored on the base distribution.
        self.loc = normal_base.loc
        self.scale = normal_base.scale
        super().__init__(normal_base, ExpTransform(), validate_args=validate_args)

    @property
    def mean(self) -> ArrayLike:
        exponent = self.loc + self.scale**2 / 2
        return jnp.exp(exponent)

    @property
    def variance(self) -> ArrayLike:
        scale_sq = self.scale**2
        return (jnp.exp(scale_sq) - 1) * jnp.exp(2 * self.loc + scale_sq)

    def entropy(self) -> ArrayLike:
        constant = (1 + jnp.log(2 * jnp.pi)) / 2
        return constant + self.loc + jnp.log(self.scale)
class Logistic(Distribution):
    r"""The Logistic distribution, a continuous real-valued distribution
    parameterized by location :math:`\mu` and scale :math:`s > 0`. Its CDF is the
    standard logistic (sigmoid) function shifted and scaled to :math:`\mu`,
    :math:`s`, which makes it the natural latent distribution underlying
    logistic regression.

    The Probability Density Function (PDF) is:

    .. math::
        f(x \mid \mu, s) = \frac{
            \exp\!\left(-\displaystyle\frac{x - \mu}{s}\right)
        }{
            s \left(1 + \exp\!\left(-\displaystyle\frac{x - \mu}{s}\right)\right)^{2}
        }, \quad x \in \mathbb{R}

    where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
    :math:`s > 0` is the scale (:attr:`scale`).
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    reparametrized_params = ["loc", "scale"]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        r"""
        :param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
        :param scale: Scale parameter :math:`s > 0`. Defaults to ``1.0``.
        :param validate_args: If True, enforce domain constraints during initialization.
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        joint_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super().__init__(batch_shape=joint_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Draw samples via the location-scale transform :math:`X = \mu + s Z`,
        with :math:`Z \sim \mathrm{Logistic}(0, 1)` produced by
        :func:`~jax.random.logistic`.

        :param key: A JAX PRNG key.
        :param sample_shape: Sample dimensions to prepend to the batch shape.
        :return: Real-valued samples from the Logistic distribution.
        """
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape + self.event_shape
        unit_draw = random.logistic(key, shape=draw_shape)
        return self.loc + unit_draw * self.scale

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Evaluate the log probability density function at ``value``.
        With :math:`u = (\mu - x)/s`, the log PDF is

        .. math::
            \ln f(x \mid \mu, s) = u - \ln s - 2 \ln(1 + e^{u})

        The implementation uses :func:`~jax.nn.softplus` for
        :math:`\ln(1 + e^{u})`, which is numerically stable for both large
        positive and large negative values of :math:`u`.

        :param value: Real-valued point :math:`x` at which to evaluate the log PDF.
        :return: Log probability density evaluated under the Logistic distribution.
        """
        u = (self.loc - value) / self.scale
        log_denominator = jnp.log(self.scale) + 2 * nn.softplus(u)
        return u - log_denominator

    @property
    def mean(self) -> ArrayLike:
        r"""Mean of the Logistic distribution:

        .. math::
            \mathbb{E}[X] = \mu
        """
        return jnp.broadcast_to(self.loc, self.batch_shape)

    @property
    def variance(self) -> ArrayLike:
        r"""Variance of the Logistic distribution:

        .. math::
            \mathrm{Var}(X) = \frac{\pi^2 s^2}{3}
        """
        unbroadcast = (self.scale**2) * (jnp.pi**2) / 3
        return jnp.broadcast_to(unbroadcast, self.batch_shape)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative Distribution Function (CDF) of the Logistic distribution.
        With :math:`z = (x - \mu)/s`,

        .. math::
            F(x \mid \mu, s) = \sigma(z) = \frac{1}{1 + e^{-z}}

        where :math:`\sigma` is the logistic sigmoid, computed via
        :func:`~jax.scipy.special.expit`.

        :param value: Real-valued point :math:`x` at which to evaluate the CDF.
        :return: CDF values in :math:`[0, 1]`.
        """
        z = (value - self.loc) / self.scale
        return expit(z)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse CDF (quantile function) of the Logistic distribution:

        .. math::
            F^{-1}(q \mid \mu, s) = \mu + s\,\operatorname{logit}(q), \quad q \in [0, 1]

        where :math:`\operatorname{logit}(q) = \ln(q / (1 - q))`.

        :param q: Quantile values in :math:`[0, 1]`.
        :return: Real-valued quantiles of the Logistic distribution at ``q``.
        """
        log_odds = logit(q)
        return self.loc + self.scale * log_odds

    def entropy(self) -> ArrayLike:
        r"""Differential entropy of the Logistic distribution:

        .. math::
            H(X) = \ln(s) + 2
        """
        unbroadcast = jnp.log(self.scale) + 2
        return jnp.broadcast_to(unbroadcast, self.batch_shape)
class LogUniform(TransformedDistribution):
    """Distribution obtained by pushing ``Uniform(log(low), log(high))``
    through ``ExpTransform``; its support is the interval ``(low, high)``
    built with ``constraints.interval``."""

    arg_constraints = {"low": constraints.positive, "high": constraints.positive}
    reparametrized_params = ["low", "high"]
    pytree_data_fields = ("low", "high", "_support")

    def __init__(
        self,
        low: ArrayLike,
        high: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        log_space_uniform = Uniform(jnp.log(low), jnp.log(high))
        self.low, self.high = promote_shapes(low, high)
        self._support = constraints.interval(self.low, self.high)
        super().__init__(
            log_space_uniform, ExpTransform(), validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> constraints.Constraint:
        return self._support

    @property
    def mean(self) -> ArrayLike:
        width = self.high - self.low
        return width / jnp.log(self.high / self.low)

    @property
    def variance(self) -> ArrayLike:
        second_moment = (
            0.5 * (self.high**2 - self.low**2) / jnp.log(self.high / self.low)
        )
        return second_moment - self.mean**2

    def entropy(self) -> ArrayLike:
        lower, upper = jnp.log(self.low), jnp.log(self.high)
        return (lower + upper) / 2 + jnp.log(upper - lower)
def _batch_solve_triangular(A, B): """ Extende solve_triangular for the case that B.ndim > A.ndim. This is achieved by first flattening the leading B.ndim - A.ndim dimensions of B and then moving the first dimension to the end. :param jnp.ndarray (...,M,M) A: An array with lower triangular structure in the last two dimensions. :param jnp.ndarray (...,M,N) B: Right-hand side matrix in A x = B. :return: Solution of A x = B. """ event_shape = B.shape[-2:] batch_shape = lax.broadcast_shapes(A.shape[:-2], B.shape[-A.ndim : -2]) sample_shape = B.shape[: -A.ndim] n, p = event_shape A = jnp.broadcast_to(A, batch_shape + A.shape[-2:]) B = jnp.broadcast_to(B, sample_shape + batch_shape + event_shape) B_flat = jnp.moveaxis(B.reshape((-1,) + batch_shape + event_shape), 0, -2).reshape( batch_shape + (n,) + (-1,) ) X_flat = solve_triangular(A, B_flat, lower=True) sample_shape_dim = len(sample_shape) src_axes = tuple([-2 - i for i in range(sample_shape_dim)]) src_axes = src_axes[::-1] dest_axes = tuple([i for i in range(sample_shape_dim)]) X = jnp.moveaxis( X_flat.reshape(batch_shape + (n,) + sample_shape + (p,)), src_axes, dest_axes, ) return X def _batch_trace_from_cholesky(L): """Computes the trace of matrix X given it's Cholesky decomposition matrix L. :param jnp.ndarray(..., M, M) L: An array with lower triangular structure in the last two dimensions. :return: Trace of X, where X = L L^T """ return jnp.square(L).sum((-1, -2))
class MatrixNormal(Distribution):
    """
    Matrix variate normal distribution as described in [1], using a
    lower-triangular parametrization: :math:`U=scale_tril_row @ scale_tril_row^{T}`
    and :math:`V=scale_tril_column @ scale_tril_column^{T}`.

    Relation to the multivariate normal distribution: if :math:`X ~ MN(loc,U,V)`
    then :math:`vec(X) ~ MVN(vec(loc), kron(V,U) )`.

    :param array_like loc: Location of the distribution.
    :param array_like scale_tril_row: Lower cholesky of rows covariance matrix.
    :param array_like scale_tril_column: Lower cholesky of columns covariance matrix.

    **References**

    [1] https://en.wikipedia.org/wiki/Matrix_normal_distribution
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "scale_tril_row": constraints.lower_cholesky,
        "scale_tril_column": constraints.lower_cholesky,
    }
    support = constraints.real_matrix
    reparametrized_params = [
        "loc",
        "scale_tril_row",
        "scale_tril_column",
    ]

    def __init__(
        self,
        loc: Array,
        scale_tril_row: Array,
        scale_tril_column: Array,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        event_shape = loc.shape[-2:]
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc)[:-2],
            jnp.shape(scale_tril_row)[:-2],
            jnp.shape(scale_tril_column)[:-2],
        )
        # Give every parameter the full batch rank while keeping its own
        # trailing matrix dimensions.
        self.loc = promote_shapes(loc, shape=batch_shape + loc.shape[-2:])[0]
        self.scale_tril_row = promote_shapes(
            scale_tril_row, shape=batch_shape + scale_tril_row.shape[-2:]
        )[0]
        self.scale_tril_column = promote_shapes(
            scale_tril_column, shape=batch_shape + scale_tril_column.shape[-2:]
        )[0]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @property
    def mean(self) -> ArrayLike:
        return jnp.broadcast_to(self.loc, self.shape())

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw samples as ``loc + L_row @ eps @ L_col^T`` with standard normal ``eps``."""
        noise_shape = sample_shape + self.batch_shape + self.event_shape
        eps = random.normal(key, shape=noise_shape)
        column_factor_t = jnp.swapaxes(self.scale_tril_column, -2, -1)
        return self.loc + self.scale_tril_row @ eps @ column_factor_t

    @validate_sample
    def log_prob(self, values):
        """Log density, assembled from a log-determinant term and a trace term."""
        n, p = self.event_shape
        log_det_term = (
            p * tri_logabsdet(self.scale_tril_row)
            + n * tri_logabsdet(self.scale_tril_column)
            + 0.5 * n * p * jnp.log(2 * jnp.pi)
        )

        # Trace term: whiten the residual along rows, then along columns, and
        # sum the squares of what is left.
        residual = values - self.loc
        row_whitened = _batch_solve_triangular(A=self.scale_tril_row, B=residual)
        fully_whitened = _batch_solve_triangular(
            A=self.scale_tril_column, B=jnp.swapaxes(row_whitened, -2, -1)
        )
        trace_term = _batch_trace_from_cholesky(fully_whitened)
        return -0.5 * trace_term - log_det_term
def _batch_mahalanobis(bL, bx):
    """Compute the sum of squares of ``solve(bL, bx)`` over the last axis.

    ``bL`` is treated as a batch of lower-triangular matrices (it is passed to
    ``solve_triangular`` with ``lower=True``) and ``bx`` as a batch of vectors.
    When the batch shapes already match, a direct batched solve is used.
    Otherwise ``bx`` is reshaped and permuted so that ``bL`` never has to be
    broadcast up to the batch shape of ``bx``.

    :param bL: Array whose last two dimensions hold lower-triangular matrices.
    :param bx: Array whose last dimension holds the vectors to be solved for.
    :return: Array of shape ``bx.shape[:-1]``.
    """
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return jnp.sum(jnp.square(solve_bL_bx), -1)

    # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
    # because we don't want to broadcast bL to the shape (i, j, n, n).

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve
    sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
    out_shape = jnp.shape(bx)[:-1]  # shape of output
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = out_shape[:sample_ndim]
    # Each batch dim of bx is split into (sx // sL, sL) so the part that matches
    # bL's batch size is separated from the part that exceeds it.
    for sL, sx in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1,)
    bx = jnp.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (
        tuple(range(sample_ndim))
        + tuple(range(sample_ndim, bx.ndim - 1, 2))
        + tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
        + (bx.ndim - 1,)
    )
    bx = jnp.transpose(bx, permute_dims)

    # reshape to (-1, i, 1, n)
    xt = jnp.reshape(bx, (-1,) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
    xt = jnp.moveaxis(xt, 0, -1)
    solve_bL_bx = solve_triangular(bL, xt, lower=True)  # shape: (i, 1, n, -1)
    M = jnp.sum(solve_bL_bx**2, axis=-2)  # shape: (i, 1, -1)
    # permute back to (-1, i, 1)
    M = jnp.moveaxis(M, -1, 0)
    # reshape back to (..., 1, j, i, 1)
    M = jnp.reshape(M, bx.shape[:-1])
    # permute back to (..., 1, i, j, 1)
    permute_inv_dims = tuple(range(sample_ndim))
    for i in range(bL.ndim - 2):
        permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
    M = jnp.transpose(M, permute_inv_dims)
    return jnp.reshape(M, out_shape)
class MultivariateNormal(Distribution):
    """Multivariate normal distribution over real vectors.

    The scale may be given as ``covariance_matrix``, ``precision_matrix`` or
    ``scale_tril`` (the choice is checked through ``assert_one_of``). Whatever
    is passed, the lower Cholesky factor ``scale_tril`` is stored and used by
    the sampler and the density.

    :param loc: Mean vector; a scalar is promoted to a length-1 vector.
    :param covariance_matrix: Covariance matrix.
    :param precision_matrix: Precision (inverse covariance) matrix.
    :param scale_tril: Lower Cholesky factor of the covariance matrix.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    reparametrized_params = [
        "loc",
        "covariance_matrix",
        "precision_matrix",
        "scale_tril",
    ]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        covariance_matrix: Optional[Array] = None,
        precision_matrix: Optional[Array] = None,
        scale_tril: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        assert_one_of(
            covariance_matrix=covariance_matrix,
            precision_matrix=precision_matrix,
            scale_tril=scale_tril,
        )
        if jnp.ndim(loc) == 0:
            (loc,) = promote_shapes(loc, shape=(1,))
        # Treat loc as a column matrix for the duration of shape promotion so
        # that it lines up with the (..., n, n) scale argument.
        loc = loc[..., jnp.newaxis]
        if covariance_matrix is not None:
            loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
            self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
        elif precision_matrix is not None:
            loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
            self.scale_tril = cholesky_of_inverse(self.precision_matrix)
        elif scale_tril is not None:
            loc, self.scale_tril = promote_shapes(loc, scale_tril)
        factor_shape = jnp.shape(self.scale_tril)
        batch_shape = lax.broadcast_shapes(jnp.shape(loc)[:-2], factor_shape[:-2])
        event_shape = factor_shape[-1:]
        # Drop the temporary trailing axis again.
        self.loc = loc[..., 0]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw samples as ``loc + scale_tril @ eps`` with standard normal ``eps``."""
        assert is_prng_key(key)
        noise_shape = sample_shape + self.batch_shape + self.event_shape
        eps = random.normal(key, shape=noise_shape)
        correlated = jnp.matmul(self.scale_tril, eps[..., jnp.newaxis])
        return self.loc + jnp.squeeze(correlated, axis=-1)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density from the Mahalanobis term and the Cholesky log-determinant."""
        dim = self.scale_tril.shape[-1]
        M = _batch_mahalanobis(self.scale_tril, value - self.loc)
        half_log_det = tri_logabsdet(self.scale_tril)
        normalize_term = half_log_det + 0.5 * dim * jnp.log(2 * jnp.pi)
        return -0.5 * M - normalize_term

    @lazy_property
    def covariance_matrix(self):
        scale_tril_t = jnp.swapaxes(self.scale_tril, -1, -2)
        return jnp.matmul(self.scale_tril, scale_tril_t)

    @lazy_property
    def precision_matrix(self):
        # Solve against the identity using the stored Cholesky factor.
        dim = self.scale_tril.shape[-1]
        identity = jnp.broadcast_to(jnp.eye(dim), self.scale_tril.shape)
        return cho_solve((self.scale_tril, True), identity)

    @property
    def mean(self) -> ArrayLike:
        return jnp.broadcast_to(self.loc, self.shape())

    @property
    def variance(self) -> ArrayLike:
        # The diagonal of L L^T is the row-wise sum of squares of L.
        marginal = jnp.sum(self.scale_tril**2, axis=-1)
        return jnp.broadcast_to(marginal, self.batch_shape + self.event_shape)

    @staticmethod
    def infer_shapes(
        loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None
    ):
        """Infer ``(batch_shape, event_shape)`` from the shapes of the arguments."""
        assert_one_of(
            covariance_matrix=covariance_matrix,
            precision_matrix=precision_matrix,
            scale_tril=scale_tril,
        )
        batch_shape, event_shape = loc[:-1], loc[-1:]
        candidates = (covariance_matrix, precision_matrix, scale_tril)
        for matrix_shape in (m for m in candidates if m is not None):
            batch_shape = lax.broadcast_shapes(batch_shape, matrix_shape[:-2])
            event_shape = lax.broadcast_shapes(event_shape, matrix_shape[-1:])
        return batch_shape, event_shape

    def entropy(self) -> ArrayLike:
        """Differential entropy: ``n * (log(2*pi) + 1) / 2 + log|scale_tril|``."""
        (n,) = self.event_shape
        half_log_det = tri_logabsdet(self.scale_tril)
        return n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det
def _is_sparse(A): from scipy import sparse return sparse.issparse(A) def _to_sparse(A): from scipy import sparse return sparse.csr_matrix(A)
class CAR(Distribution):
    r"""
    The Conditional Autoregressive (CAR) distribution is a special case of the multivariate
    normal in which the precision matrix is structured according to the adjacency matrix of
    sites. The amount of autocorrelation between sites is controlled by ``correlation``. The
    distribution is a popular prior for areal spatial data.

    :param float or ndarray loc: mean of the multivariate normal
    :param float correlation: autoregression parameter. For most cases, the value should lie
        between 0 (sites are independent, collapses to an iid multivariate normal) and
        1 (perfect autocorrelation between sites), but the specification allows for negative
        correlations.
    :param float conditional_precision: positive precision for the multivariate normal
    :param ndarray or scipy.sparse.csr_matrix adj_matrix: symmetric adjacency matrix where 1
        indicates adjacency between sites and 0 otherwise. :class:`jax.numpy.ndarray`
        ``adj_matrix`` is supported but is **not** recommended over :class:`numpy.ndarray`
        or :class:`scipy.sparse.spmatrix`.
    :param bool is_sparse: whether to use a sparse form of ``adj_matrix`` in calculations (must be
        True if ``adj_matrix`` is a :class:`scipy.sparse.spmatrix`)
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "correlation": constraints.open_interval(-1, 1),
        "conditional_precision": constraints.positive,
        "adj_matrix": constraints.dependent(is_discrete=False, event_dim=2),
    }
    support = constraints.real_vector
    reparametrized_params = [
        "loc",
        "correlation",
        "conditional_precision",
        "adj_matrix",
    ]
    pytree_aux_fields = ("is_sparse", "adj_matrix")

    def __init__(
        self,
        loc: ArrayLike,
        correlation: Array,
        conditional_precision: Array,
        adj_matrix: Array,
        *,
        is_sparse: bool = False,
        validate_args: Optional[bool] = None,
    ) -> None:
        if jnp.ndim(loc) == 0:
            (loc,) = promote_shapes(loc, shape=(1,))

        self.is_sparse = is_sparse

        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc)[:-1],
            jnp.shape(correlation),
            jnp.shape(conditional_precision),
            jnp.shape(adj_matrix)[:-2],
        )
        if self.is_sparse:
            # The two message fragments are joined by implicit string
            # concatenation so the exception carries one readable sentence
            # instead of a tuple of fragments.
            if adj_matrix.ndim != 2:
                raise ValueError(
                    "Currently, we only support 2-dimensional adj_matrix. Please make a feature request"
                    " if you need higher dimensional adj_matrix."
                )
            if not (isinstance(adj_matrix, np.ndarray) or _is_sparse(adj_matrix)):
                raise ValueError(
                    "adj_matrix needs to be a numpy array or a scipy sparse matrix. Please make a feature"
                    " request if you need to support jax ndarrays."
                )
            # TODO: look into future jax sparse csr functionality and other developments
            self.adj_matrix = _to_sparse(adj_matrix)
        else:
            assert not _is_sparse(adj_matrix), (
                "adj_matrix is a sparse matrix so please specify `is_sparse=True`."
            )
            # TODO: look into static jax ndarray representation
            (self.adj_matrix,) = promote_shapes(
                adj_matrix, shape=batch_shape + adj_matrix.shape[-2:]
            )

        event_shape = jnp.shape(self.adj_matrix)[-1:]
        (self.loc,) = promote_shapes(loc, shape=batch_shape + event_shape)
        self.correlation, self.conditional_precision = promote_shapes(
            correlation, conditional_precision, shape=batch_shape
        )

        super(CAR, self).__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

        # Structural checks run only on concrete inputs (numpy array or scipy
        # sparse matrix); a jax ndarray adjacency matrix is not inspected here.
        if self._validate_args and (isinstance(adj_matrix, np.ndarray) or is_sparse):
            assert (self.adj_matrix.sum(axis=-1) > 0).all(), (
                "all sites in adjacency matrix must have neighbours"
            )

            if self.is_sparse:
                assert (self.adj_matrix != self.adj_matrix.T).nnz == 0, (
                    "adjacency matrix must be symmetric"
                )
            else:
                assert np.array_equal(
                    self.adj_matrix, np.swapaxes(self.adj_matrix, -2, -1)
                ), "adjacency matrix must be symmetric"

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Sample by delegating to a dense ``MultivariateNormal`` with the same precision."""
        # TODO: look into a sparse sampling method
        mvn = MultivariateNormal(self.mean, precision_matrix=self.precision_matrix)
        return mvn.sample(key, sample_shape=sample_shape)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density, with the log-determinant taken from the eigenvalues of
        the degree-normalized adjacency matrix."""
        phi = value - self.loc
        adj_matrix = self.adj_matrix

        # D holds the row sums of the adjacency matrix; the scaled matrix is
        # D^{-1/2} A D^{-1/2}.
        if self.is_sparse:
            D = np.asarray(adj_matrix.sum(axis=-1)).squeeze(axis=-1)
            D_rsqrt = D ** (-0.5)
            adj_matrix_scaled = (
                adj_matrix.multiply(D_rsqrt).multiply(D_rsqrt[:, np.newaxis]).toarray()
            )
            adj_matrix = BCOO.from_scipy_sparse(adj_matrix)
        else:
            D = adj_matrix.sum(axis=-1)
            D_rsqrt = D ** (-0.5)
            adj_matrix_scaled = adj_matrix * (
                D_rsqrt[..., None, :] * D_rsqrt[..., None]
            )

        # TODO: look into sparse eigenvalue methods
        if isinstance(adj_matrix_scaled, np.ndarray):
            lam = np.linalg.eigvalsh(adj_matrix_scaled)
        else:
            lam = jnp.linalg.eigvalsh(adj_matrix_scaled)

        n = D.shape[-1]

        logprec = n * jnp.log(self.conditional_precision)
        logdet = jnp.log1p(-jnp.expand_dims(self.correlation, -1) * lam).sum(-1)
        logdet = logdet + jnp.log(D).sum(-1)

        # Quadratic form phi^T (D - correlation * A) phi, scaled by the precision.
        logquad = self.conditional_precision * jnp.sum(
            phi
            * (
                D * phi
                - jnp.expand_dims(self.correlation, -1)
                * (adj_matrix @ phi[..., jnp.newaxis]).squeeze(axis=-1)
            ),
            -1,
        )

        return 0.5 * (-n * jnp.log(2 * jnp.pi) + logprec + logdet - logquad)

    @property
    def mean(self) -> ArrayLike:
        return jnp.broadcast_to(self.loc, self.shape())

    @lazy_property
    def precision_matrix(self):
        """Dense precision matrix ``conditional_precision * (D - correlation * A)``."""
        if self.is_sparse:
            adj_matrix = self.adj_matrix.toarray()
        else:
            adj_matrix = self.adj_matrix
        D = adj_matrix.sum(axis=-1, keepdims=True) * jnp.eye(adj_matrix.shape[-1])
        conditional_precision = jnp.expand_dims(self.conditional_precision, (-2, -1))
        correlation = jnp.expand_dims(self.correlation, (-2, -1))
        return conditional_precision * (D - correlation * adj_matrix)

    @staticmethod
    def infer_shapes(loc, correlation, conditional_precision, adj_matrix):
        """Infer ``(batch_shape, event_shape)`` from the shapes of the arguments."""
        event_shape = adj_matrix[-1:]
        batch_shape = lax.broadcast_shapes(
            loc[:-1], correlation, conditional_precision, adj_matrix[:-2]
        )
        return batch_shape, event_shape

    def tree_flatten(self):
        """Flatten to a pytree, keeping ``adj_matrix`` in exactly one of the
        data / aux slots depending on ``is_sparse``."""
        data, aux = super().tree_flatten()
        adj_matrix_data_idx = type(self).gather_pytree_data_fields().index("adj_matrix")
        adj_matrix_aux_idx = type(self).gather_pytree_aux_fields().index("adj_matrix")

        if not self.is_sparse:
            # Dense adjacency travels as data; blank out the aux copy.
            aux = list(aux)
            aux[adj_matrix_aux_idx] = None
            aux = tuple(aux)
        else:
            # Sparse adjacency travels as aux; blank out the data copy.
            data = list(data)
            data[adj_matrix_data_idx] = None
            data = tuple(data)
        return data, aux

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild the distribution, restoring ``adj_matrix`` from the slot
        that ``tree_flatten`` kept populated."""
        d = super().tree_unflatten(aux_data, params)
        if not d.is_sparse:
            adj_matrix_data_idx = cls.gather_pytree_data_fields().index("adj_matrix")
            setattr(d, "adj_matrix", params[adj_matrix_data_idx])
        else:
            adj_matrix_aux_idx = cls.gather_pytree_aux_fields().index("adj_matrix")
            setattr(d, "adj_matrix", aux_data[adj_matrix_aux_idx])
        return d
class MultivariateStudentT(Distribution):
    """Multivariate Student's t distribution parameterized by degrees of
    freedom, location and the lower Cholesky factor of the scale matrix.

    :param df: Degrees of freedom (positive).
    :param loc: Location vector; a scalar is promoted to a length-1 vector.
    :param scale_tril: Lower Cholesky factor of the scale matrix.
    """

    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    reparametrized_params = ["df", "loc", "scale_tril"]
    pytree_data_fields = ("df", "loc", "scale_tril", "_chi2")

    def __init__(
        self,
        df: ArrayLike,
        loc: ArrayLike = 0.0,
        scale_tril: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        if jnp.ndim(loc) == 0:
            (loc,) = promote_shapes(loc, shape=(1,))
        batch_shape = lax.broadcast_shapes(
            jnp.shape(df), jnp.shape(loc)[:-1], jnp.shape(scale_tril)[:-2]
        )
        self.df = promote_shapes(df, shape=batch_shape)[0]
        self.loc = promote_shapes(loc, shape=batch_shape + loc.shape[-1:])[0]
        self.scale_tril = promote_shapes(
            scale_tril, shape=batch_shape + scale_tril.shape[-2:]
        )[0]
        event_shape = jnp.shape(self.scale_tril)[-1:]
        # Chi-square helper used by the sampler.
        self._chi2 = Chi2(self.df)
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Sample by scaling a standard normal draw with ``sqrt(df / chi2)``
        and applying the affine map ``loc + scale_tril @ y``."""
        assert is_prng_key(key)
        key_normal, key_chi2 = random.split(key)
        std_normal = random.normal(
            key_normal,
            shape=sample_shape + self.batch_shape + self.event_shape,
        )
        z = self._chi2.sample(key_chi2, sample_shape)
        heavy_tail_factor = jnp.expand_dims(jnp.sqrt(self.df / z), -1)
        y = std_normal * heavy_tail_factor
        scaled = jnp.matmul(self.scale_tril, y[..., jnp.newaxis])
        return self.loc + jnp.squeeze(scaled, axis=-1)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        n = self.scale_tril.shape[-1]
        # Log normalizing constant.
        Z = (
            tri_logabsdet(self.scale_tril)
            + 0.5 * n * jnp.log(self.df)
            + 0.5 * n * jnp.log(jnp.pi)
            + gammaln(0.5 * self.df)
            - gammaln(0.5 * (self.df + n))
        )
        M = _batch_mahalanobis(self.scale_tril, value - self.loc)
        kernel = -0.5 * (self.df + n) * jnp.log1p(M / self.df)
        return kernel - Z

    @lazy_property
    def covariance_matrix(self):
        # NB: this is not covariance of this distribution;
        # the actual covariance is df / (df - 2) * covariance_matrix
        scale_tril_t = jnp.swapaxes(self.scale_tril, -1, -2)
        return jnp.matmul(self.scale_tril, scale_tril_t)

    @lazy_property
    def precision_matrix(self) -> Array:
        dim = self.scale_tril.shape[-1]
        identity = jnp.broadcast_to(jnp.eye(dim), self.scale_tril.shape)
        return cho_solve((self.scale_tril, True), identity)

    @property
    def mean(self) -> ArrayLike:
        # for df <= 1. should be jnp.nan (keeping jnp.inf for consistency with scipy)
        df = jnp.expand_dims(self.df, -1)
        return jnp.broadcast_to(jnp.where(df <= 1, jnp.inf, self.loc), self.shape())

    @property
    def variance(self) -> ArrayLike:
        df = jnp.expand_dims(self.df, -1)
        var = jnp.power(self.scale_tril, 2).sum(-1) * (df / (df - 2))
        # inf for df <= 2, then nan for df <= 1
        var = jnp.where(df > 2, var, jnp.inf)
        var = jnp.where(df <= 1, jnp.nan, var)
        return jnp.broadcast_to(var, self.batch_shape + self.event_shape)

    @staticmethod
    def infer_shapes(df, loc, scale_tril):
        """Infer ``(batch_shape, event_shape)`` from the shapes of the arguments."""
        batch_shape = lax.broadcast_shapes(df, loc[:-1], scale_tril[:-2])
        event_shape = (scale_tril[-1],)
        return batch_shape, event_shape
def _batch_mv(bmat: Array, bvec: Array) -> Array: r""" Performs a batched matrix-vector product, with compatible but different batch shapes. This function takes as input `bmat`, containing :math:`n \times n` matrices, and `bvec`, containing length :math:`n` vectors. Both `bmat` and `bvec` may have any number of leading dimensions, which correspond to a batch shape. They are not necessarily assumed to have the same batch shape, just ones which can be broadcasted. """ return jnp.squeeze(jnp.matmul(bmat, jnp.expand_dims(bvec, axis=-1)), axis=-1) def _batch_capacitance_tril(W: Array, D: Array) -> Array: r""" Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W` and a batch of vectors :math:`D`. """ Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2) K = jnp.matmul(Wt_Dinv, W) # could be inefficient return jnp.linalg.cholesky(add_diag(K, 1)) def _batch_lowrank_logdet(W: Array, D: Array, capacitance_tril: Array) -> Array: r""" Uses "matrix determinant lemma":: log|W @ W.T + D| = log|C| + log|D|, where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the log determinant. """ return 2 * tri_logabsdet(capacitance_tril) + jnp.log(D).sum(-1) def _batch_lowrank_mahalanobis( W: Array, D: Array, x: Array, capacitance_tril: Array ) -> Array: r""" Uses "Woodbury matrix identity":: inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D), where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`. """ Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2) Wt_Dinv_x = _batch_mv(Wt_Dinv, x) mahalanobis_term1 = jnp.sum(jnp.square(x) / D, axis=-1) mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x) return mahalanobis_term1 - mahalanobis_term2
class LowRankMultivariateNormal(Distribution):
    """Multivariate normal whose covariance has the low-rank-plus-diagonal
    form ``cov_factor @ cov_factor.T + diag(cov_diag)``.

    :param loc: Mean vector, at least one-dimensional.
    :param cov_factor: Factor part of the covariance; its second-to-last
        dimension must match the event size.
    :param cov_diag: Positive diagonal part of the covariance; its last
        dimension must match the event size.
    :raises ValueError: If the argument shapes are inconsistent.
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    reparametrized_params = ["loc", "cov_factor", "cov_diag"]
    pytree_data_fields = ("loc", "cov_factor", "cov_diag", "_capacitance_tril")

    def __init__(
        self,
        loc: Array,
        cov_factor: Array,
        cov_diag: Array,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        if jnp.ndim(loc) < 1:
            raise ValueError("`loc` must be at least one-dimensional.")
        event_shape = jnp.shape(loc)[-1:]
        if jnp.ndim(cov_factor) < 2:
            raise ValueError(
                "`cov_factor` must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if jnp.shape(cov_factor)[-2:-1] != event_shape:
            raise ValueError(
                "`cov_factor` must be a batch of matrices with shape {} x m".format(
                    event_shape[0]
                )
            )
        if jnp.shape(cov_diag)[-1:] != event_shape:
            # Format the local `event_shape`: the instance attribute is not
            # available until the base-class initializer has run.
            raise ValueError(
                "`cov_diag` must be a batch of vectors with shape {}".format(
                    event_shape
                )
            )

        # Append a trailing axis to the vector arguments so all three line up
        # as matrices during shape promotion.
        loc, cov_factor, cov_diag = promote_shapes(
            loc[..., jnp.newaxis], cov_factor, cov_diag[..., jnp.newaxis]
        )
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc), jnp.shape(cov_factor), jnp.shape(cov_diag)
        )[:-2]
        self.loc = loc[..., 0]
        self.cov_factor = cov_factor
        cov_diag = cov_diag[..., 0]
        self.cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super(LowRankMultivariateNormal, self).__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @property
    def mean(self) -> Array:
        return self.loc

    @lazy_property
    def variance(self) -> Array:
        raw_variance = jnp.square(self.cov_factor).sum(-1) + self.cov_diag
        return jnp.broadcast_to(raw_variance, self.batch_shape + self.event_shape)

    @lazy_property
    def scale_tril(self) -> Array:
        # The following identity is used to increase the numerically computation stability
        # for Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take Cholesky decomposition.
        cov_diag_sqrt_unsqueeze = jnp.expand_dims(jnp.sqrt(self.cov_diag), axis=-1)
        Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
        K = jnp.matmul(Dinvsqrt_W, jnp.swapaxes(Dinvsqrt_W, -1, -2))
        K = add_diag(K, 1)
        scale_tril = cov_diag_sqrt_unsqueeze * jnp.linalg.cholesky(K)
        return scale_tril

    @lazy_property
    def covariance_matrix(self) -> Array:
        covariance_matrix = add_diag(
            jnp.matmul(self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)),
            self.cov_diag,
        )
        return covariance_matrix

    @lazy_property
    def precision_matrix(self) -> Array:
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = jnp.swapaxes(self.cov_factor, -1, -2) / jnp.expand_dims(
            self.cov_diag, axis=-2
        )
        # With C = L @ L.T, A = inv(L) @ W.T @ inv(D) gives
        # A.T @ A = inv(D) @ W @ inv(C) @ W.T @ inv(D). The triangular factor L
        # is the coefficient matrix, so it must be the first argument.
        A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
        inverse_cov_diag = jnp.reciprocal(self.cov_diag)
        return add_diag(-jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Sample as ``loc + cov_factor @ eps_W + sqrt(cov_diag) * eps_D`` with
        independent standard normal ``eps_W`` and ``eps_D``."""
        assert is_prng_key(key)
        key_W, key_D = random.split(key)
        batch_shape = sample_shape + self.batch_shape
        W_shape = batch_shape + self.cov_factor.shape[-1:]
        D_shape = batch_shape + self.cov_diag.shape[-1:]
        eps_W = random.normal(key_W, W_shape)
        eps_D = random.normal(key_D, D_shape)
        return (
            self.loc
            + _batch_mv(self.cov_factor, eps_W)
            + jnp.sqrt(self.cov_diag) * eps_D
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(
            self.cov_factor, self.cov_diag, diff, self._capacitance_tril
        )
        log_det = _batch_lowrank_logdet(
            self.cov_factor, self.cov_diag, self._capacitance_tril
        )
        return -0.5 * (self.loc.shape[-1] * jnp.log(2 * jnp.pi) + log_det + M)

    def entropy(self) -> ArrayLike:
        """Differential entropy ``0.5 * (n * (1 + log(2*pi)) + log|cov|)``,
        broadcast to ``batch_shape``."""
        log_det = _batch_lowrank_logdet(
            self.cov_factor, self.cov_diag, self._capacitance_tril
        )
        H = 0.5 * (self.loc.shape[-1] * (1.0 + jnp.log(2 * jnp.pi)) + log_det)
        return jnp.broadcast_to(H, self.batch_shape)

    @staticmethod
    def infer_shapes(loc, cov_factor, cov_diag):
        """Infer ``(batch_shape, event_shape)`` from the shapes of the arguments."""
        event_shape = loc[-1:]
        batch_shape = lax.broadcast_shapes(loc[:-1], cov_factor[:-2], cov_diag[:-1])
        return batch_shape, event_shape
class Normal(Distribution):
    r"""Normal (Gaussian) distribution with mean :attr:`loc` and standard
    deviation :attr:`scale`.

    Its probability density function is

    .. math::
        f(x; \mu, \sigma) = \frac{1}{\sigma \sqrt{2\pi}}
        \exp\!\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)

    for :math:`x \in \mathbb{R}`, mean :math:`\mu \in \mathbb{R}` and standard
    deviation :math:`\sigma > 0`.

    :param loc: Mean of the distribution (:math:`\mu`).
    :type loc: ArrayLike
    :param scale: Standard deviation of the distribution (:math:`\sigma`).
    :type scale: ArrayLike
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    reparametrized_params = ["loc", "scale"]

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.loc, self.scale = promote_shapes(loc, scale)
        broadcast_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super().__init__(batch_shape=broadcast_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""Draw samples with the reparameterization trick
        :math:`X = \mu + \sigma \epsilon`, :math:`\epsilon \sim \mathcal{N}(0,1)`.

        :param key: JAX PRNGKey for reproducibility.
        :type key: jax.Array
        :param sample_shape: The shape of the samples to be generated.
        :type sample_shape: tuple[int, ...]
        :return: Samples from the Normal distribution of shape
            ``sample_shape + batch_shape``.
        :rtype: ArrayLike
        """
        assert is_prng_key(key)
        noise_shape = sample_shape + self.batch_shape + self.event_shape
        eps = random.normal(key, shape=noise_shape)
        return self.loc + eps * self.scale

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Log of the probability density function.

        .. math::
            \log f(x; \mu, \sigma) = -\frac{(x - \mu)^2}{2\sigma^2}
            - \log(\sigma \sqrt{2\pi})

        :param value: Values at which to evaluate the log density.
        :type value: ArrayLike
        :return: Log probability density.
        :rtype: ArrayLike
        """
        standardized = (value - self.loc) / self.scale
        log_normalizer = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
        return -0.5 * standardized**2 - log_normalizer

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative distribution function.

        .. math::
            F(x; \mu, \sigma) = \Phi\!\left(\frac{x-\mu}{\sigma}\right)

        where :math:`\Phi` is the `cumulative distribution function of standard
        normal distribution
        <https://en.wikipedia.org/wiki/Normal_distribution#Cumulative_distribution_function>`_,
        evaluated with :func:`jax.scipy.special.ndtr`.

        :param value: Value to evaluate.
        :type value: ArrayLike
        """
        return ndtr((value - self.loc) / self.scale)

    def log_cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Log of the cumulative distribution function, computed by
        :func:`jax.scipy.stats.norm.logcdf`.

        :param value: Value to evaluate.
        :type value: ArrayLike
        """
        return jax_norm.logcdf(value, loc=self.loc, scale=self.scale)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse cumulative distribution function (quantile function).

        .. math::
            F^{-1}(q; \mu, \sigma) = \mu + \sigma\,\Phi^{-1}(q)

        where :math:`\mathrm{\Phi^{-1}}` is the inverse `cumulative distribution
        function of standard normal distribution
        <https://en.wikipedia.org/wiki/Normal_distribution#Cumulative_distribution_function>`_,
        evaluated with :func:`jax.scipy.special.ndtri`.

        :param q: Probability value in :math:`[0,1]`.
        :type q: ArrayLike
        """
        standard_quantile = ndtri(q)
        return self.loc + self.scale * standard_quantile

    @property
    def mean(self) -> ArrayLike:
        r"""Analytical mean.

        .. math::
            E[X] = \mu
        """
        return jnp.broadcast_to(self.loc, self.batch_shape)

    @property
    def variance(self) -> ArrayLike:
        r"""Analytical variance.

        .. math::
            \mathrm{Var}(X) = \sigma^2
        """
        return jnp.broadcast_to(self.scale**2, self.batch_shape)

    def entropy(self) -> ArrayLike:
        r"""Entropy of the Normal distribution.

        .. math::
            H(X) = \frac{1}{2} \ln(2\pi e \sigma^2)
        """
        per_element = (jnp.log(2 * np.pi * self.scale**2) + 1) / 2
        return jnp.broadcast_to(per_element, self.batch_shape)
class Pareto(TransformedDistribution):
    """Pareto distribution, built by exponentiating an ``Exponential(alpha)``
    variable and multiplying the result by ``scale``.

    :param scale: Positive scale; the support is everything greater than it.
    :param alpha: Positive shape parameter.
    """

    arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive}
    reparametrized_params = ["scale", "alpha"]

    def __init__(
        self,
        scale: ArrayLike,
        alpha: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.scale, self.alpha = promote_shapes(scale, alpha)
        batch_shape = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(alpha))
        scale = jnp.broadcast_to(scale, batch_shape)
        alpha = jnp.broadcast_to(alpha, batch_shape)
        super().__init__(
            Exponential(alpha),
            [ExpTransform(), AffineTransform(loc=0, scale=scale)],
            validate_args=validate_args,
        )

    @property
    def mean(self) -> ArrayLike:
        # mean is inf for alpha <= 1
        finite_mean = jnp.divide(self.alpha * self.scale, (self.alpha - 1))
        return jnp.where(self.alpha <= 1, jnp.inf, finite_mean)

    @property
    def variance(self) -> ArrayLike:
        # var is inf for alpha <= 2
        numerator = (self.scale**2) * self.alpha
        denominator = (self.alpha - 1) ** 2 * (self.alpha - 2)
        finite_variance = jnp.divide(numerator, denominator)
        return jnp.where(self.alpha <= 2, jnp.inf, finite_variance)

    # override the default behaviour to save computations
    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> constraints.Constraint:
        return constraints.greater_than(self.scale)

    def entropy(self) -> ArrayLike:
        """Differential entropy ``log(scale / alpha) + 1 + 1 / alpha``."""
        inverse_alpha = 1 / self.alpha
        return jnp.log(self.scale / self.alpha) + 1 + inverse_alpha
class RelaxedBernoulliLogits(TransformedDistribution):
    """Relaxed Bernoulli distribution parameterized by logits.

    Implemented as a ``SigmoidTransform`` of
    ``Logistic(logits / temperature, 1 / temperature)``.

    :param temperature: Positive relaxation temperature.
    :param logits: Real-valued logits.
    """

    arg_constraints = {"temperature": constraints.positive, "logits": constraints.real}
    support = constraints.unit_interval

    def __init__(
        self,
        temperature: ArrayLike,
        logits: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.temperature, self.logits = promote_shapes(temperature, logits)
        # The base distribution is built from the raw (unpromoted) arguments.
        logistic_base = Logistic(logits / temperature, 1 / temperature)
        super().__init__(
            logistic_base, [SigmoidTransform()], validate_args=validate_args
        )
def RelaxedBernoulli(
    temperature, probs=None, logits=None, *, validate_args: Optional[bool] = None
):
    """Build a :class:`RelaxedBernoulliLogits` from either ``probs`` or ``logits``.

    When ``probs`` is given it is converted to logits (and takes precedence
    over any ``logits`` argument).

    :param temperature: Relaxation temperature.
    :param probs: Success probabilities, optional.
    :param logits: Log-odds, optional.
    :param validate_args: Forwarded to the distribution constructor.
    :raises ValueError: If neither ``probs`` nor ``logits`` is provided.
    """
    if probs is not None:
        logits = _to_logits_bernoulli(probs)
    elif logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    return RelaxedBernoulliLogits(temperature, logits, validate_args=validate_args)
class SoftLaplace(Distribution):
    """
    Smooth distribution with Laplace-like tail behavior.

    The log density is the log-convex function::

        z = (value - loc) / scale
        log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z)

    As with the Laplace density, the tails are (asymptotically) the heaviest
    possible for a log-convex density. In contrast to the Laplace
    distribution, this density is infinitely differentiable everywhere, which
    makes it suitable for HMC and Laplace approximation.

    :param loc: Location parameter.
    :param scale: Scale parameter.
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    reparametrized_params = ["loc", "scale"]

    def __init__(
        self,
        loc: ArrayLike,
        scale: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.loc, self.scale = promote_shapes(loc, scale)
        broadcast_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super().__init__(batch_shape=broadcast_shape, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        z = (value - self.loc) / self.scale
        log_normalizer = jnp.log(2 / jnp.pi) - jnp.log(self.scale)
        return log_normalizer - jnp.logaddexp(z, -z)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Sample by inverse-CDF transform of a uniform draw."""
        assert is_prng_key(key)
        # The lower bound of the uniform draw is the smallest positive normal
        # float of the default float dtype rather than 0.
        smallest_positive = jnp.finfo(jnp.result_type(float)).tiny
        u = random.uniform(
            key, shape=sample_shape + self.batch_shape, minval=smallest_positive
        )
        return self.icdf(u)

    # TODO: refactor validate_sample to only does validation check and use it here
    def cdf(self, value: ArrayLike) -> ArrayLike:
        z = (value - self.loc) / self.scale
        return jnp.arctan(jnp.exp(z)) * (2 / jnp.pi)

    def icdf(self, value: ArrayLike) -> ArrayLike:
        standard_quantile = jnp.log(jnp.tan(value * (jnp.pi / 2)))
        return standard_quantile * self.scale + self.loc

    @property
    def mean(self) -> ArrayLike:
        return self.loc

    @property
    def variance(self) -> ArrayLike:
        return (jnp.pi / 2 * self.scale) ** 2
class StudentT(Distribution):
    """
    Student's t distribution with ``df`` degrees of freedom, shifted by ``loc``
    and scaled by ``scale``.

    :param df: Degrees of freedom (positive).
    :param loc: Location parameter.
    :param scale: Scale parameter (positive).
    """

    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    support = constraints.real
    reparametrized_params = ["df", "loc", "scale"]
    pytree_data_fields = ("df", "loc", "scale", "_chi2")

    def __init__(
        self,
        df: ArrayLike,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        param_shapes = [jnp.shape(param) for param in (df, loc, scale)]
        batch_shape = lax.broadcast_shapes(*param_shapes)
        self.df, self.loc, self.scale = promote_shapes(
            df, loc, scale, shape=batch_shape
        )
        # The chi-square helper used by ``sample`` receives a fully broadcast df.
        self._chi2 = Chi2(jnp.broadcast_to(df, batch_shape))
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Sample as ``loc + scale * N / sqrt(chi2 / df)`` with independent draws."""
        assert is_prng_key(key)
        normal_key, chi2_key = random.split(key)
        draw_shape = sample_shape + self.batch_shape
        gaussian = random.normal(normal_key, shape=draw_shape)
        chi2_draw = self._chi2.sample(chi2_key, sample_shape)
        standard_t = gaussian * jnp.sqrt(self.df / chi2_draw)
        return self.loc + self.scale * standard_t

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density of the location-scale Student's t distribution."""
        standardized = (value - self.loc) / self.scale
        log_normalizer = (
            jnp.log(self.scale)
            + 0.5 * jnp.log(self.df)
            + 0.5 * jnp.log(jnp.pi)
            + gammaln(0.5 * self.df)
            - gammaln(0.5 * (self.df + 1.0))
        )
        kernel = -0.5 * (self.df + 1.0) * jnp.log1p(standardized**2.0 / self.df)
        return kernel - log_normalizer

    @property
    def mean(self) -> ArrayLike:
        # for df <= 1. should be jnp.nan (keeping jnp.inf for consistency with scipy)
        location = jnp.where(self.df <= 1, jnp.inf, self.loc)
        return jnp.broadcast_to(location, self.batch_shape)

    @property
    def variance(self) -> ArrayLike:
        # Finite only for df > 2, infinite for 1 < df <= 2, undefined for df <= 1.
        finite_var = jnp.divide(self.scale**2 * self.df, self.df - 2.0)
        var = jnp.where(self.df > 2, finite_var, jnp.inf)
        var = jnp.where(self.df <= 1, jnp.nan, var)
        return jnp.broadcast_to(var, self.batch_shape)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        """Cumulative distribution function via the regularized incomplete beta."""
        # Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions
        # X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5)
        scaled = (value - self.loc) / self.scale
        beta_value = self.df / (self.df + scaled * scaled)
        side = jnp.sign(scaled)
        # when scaled < 0, returns 0.5 * Beta(df/2, 0.5).cdf(beta_value)
        # when scaled > 0, returns 1 - 0.5 * Beta(df/2, 0.5).cdf(beta_value)
        return 0.5 * (1 + side - side * betainc(0.5 * self.df, 0.5, beta_value))

    def icdf(self, q: ArrayLike) -> ArrayLike:
        """Quantile function, inverting the beta relation used in :meth:`cdf`."""
        beta_value = betaincinv(0.5 * self.df, 0.5, 1 - jnp.abs(1 - 2 * q))
        magnitude = jnp.sqrt(self.df * (1 / beta_value - 1))
        scaled = jnp.sign(q - 0.5) * magnitude
        return scaled * self.scale + self.loc

    def entropy(self) -> ArrayLike:
        """Differential entropy, broadcast to ``batch_shape``."""
        half_df_plus_one = (self.df + 1) / 2
        value = (
            half_df_plus_one * (digamma((self.df + 1) / 2) - digamma(self.df / 2))
            + jnp.log(self.df) / 2
            + betaln(self.df / 2, 0.5)
            + jnp.log(self.scale)
        )
        return jnp.broadcast_to(value, self.batch_shape)
class Uniform(Distribution):
    """
    Continuous uniform distribution on the interval between ``low`` and ``high``.

    :param low: Lower bound of the interval.
    :param high: Upper bound of the interval.
    """

    arg_constraints = {
        "low": constraints.dependent(is_discrete=False, event_dim=0),
        "high": constraints.dependent(is_discrete=False, event_dim=0),
    }
    reparametrized_params = ["low", "high"]
    pytree_data_fields = ("low", "high", "_support")

    def __init__(
        self,
        low: ArrayLike = 0.0,
        high: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        batch_shape = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
        self.low, self.high = promote_shapes(low, high)
        # The support depends on the parameters, so it is stored per instance.
        self._support = constraints.interval(low, high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> constraints.Constraint:
        return self._support

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw uniform samples of shape ``sample_shape + batch_shape``."""
        return random.uniform(
            key,
            shape=sample_shape + self.batch_shape,
            minval=self.low,
            maxval=self.high,
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Return ``-log(high - low)`` broadcast against ``value``."""
        out_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
        log_width = jnp.log(self.high - self.low)
        return -jnp.broadcast_to(log_width, out_shape)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        """Linear CDF on the interval, clipped to ``[0, 1]`` outside it."""
        fraction = (value - self.low) / (self.high - self.low)
        return jnp.clip(fraction, 0.0, 1.0)

    def icdf(self, value: ArrayLike) -> ArrayLike:
        """Map a probability ``value`` linearly onto the interval."""
        return self.low + value * (self.high - self.low)

    @property
    def mean(self) -> ArrayLike:
        return self.low + (self.high - self.low) / 2.0

    @property
    def variance(self) -> ArrayLike:
        return (self.high - self.low) ** 2 / 12.0

    @staticmethod
    def infer_shapes(
        low: tuple[int, ...] = (), high: tuple[int, ...] = ()
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return ``(batch_shape, event_shape)`` from the parameter shapes."""
        event_shape: tuple[int, ...] = ()
        return lax.broadcast_shapes(low, high), event_shape

    def entropy(self) -> ArrayLike:
        """Differential entropy, ``log(high - low)``."""
        return jnp.log(self.high - self.low)
class Weibull(Distribution):
    """
    Weibull distribution parameterized by ``scale`` and ``concentration``
    (the shape parameter).

    :param scale: Scale parameter (positive).
    :param concentration: Concentration (shape) parameter (positive).
    """

    arg_constraints = {
        "scale": constraints.positive,
        "concentration": constraints.positive,
    }
    support = constraints.positive
    reparametrized_params = ["scale", "concentration"]

    def __init__(
        self,
        scale: ArrayLike,
        concentration: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(scale))
        self.concentration, self.scale = promote_shapes(concentration, scale)
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw samples with ``jax.random.weibull_min``."""
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape
        return random.weibull_min(
            key,
            scale=self.scale,
            concentration=self.concentration,
            shape=draw_shape,
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density: ``-(x/scale)^k + log k + (k - 1) log x - k log scale``."""
        k = self.concentration
        return (
            -jnp.power(value / self.scale, k)
            + jnp.log(k)
            + (k - 1.0) * jnp.log(value)
            - k * jnp.log(self.scale)
        )

    def cdf(self, value: ArrayLike) -> ArrayLike:
        """Cumulative distribution function, ``1 - exp(-(x/scale)^k)``."""
        exponent = (value / self.scale) ** self.concentration
        return 1 - jnp.exp(-exponent)

    @property
    def mean(self) -> ArrayLike:
        return self.scale * jnp.exp(gammaln(1.0 + 1.0 / self.concentration))

    @property
    def variance(self) -> ArrayLike:
        second_moment = jnp.exp(gammaln(1.0 + 2.0 / self.concentration))
        first_moment = jnp.exp(gammaln(1.0 + 1.0 / self.concentration))
        return self.scale**2 * (second_moment - first_moment**2)

    def entropy(self) -> ArrayLike:
        """Differential entropy of the distribution."""
        return (
            jnp.euler_gamma * (1 - 1 / self.concentration)
            + jnp.log(self.scale / self.concentration)
            + 1
        )
class BetaProportion(Beta):
    r"""Beta distribution reparameterized in terms of a mean (:attr:`mean`) and
    a precision (:attr:`concentration`).

    Given mean :math:`\mu` and precision :math:`\phi`, the standard Beta
    parameters are derived as:

    .. math::
        \alpha = \mu \phi, \quad \beta = (1 - \mu) \phi

    The resulting PDF is:

    .. math::
        f(x; \mu, \phi) = \frac{x^{\mu\phi - 1} (1 - x)^{(1 - \mu)\phi - 1}}{\text{B}(\mu\phi, (1 - \mu)\phi)}

    **Reference**

    Ferrari, Silvia, and Francisco Cribari-Neto. "Beta regression for modelling
    rates and proportions." *Journal of Applied Statistics* 31.7 (2004): 799-815.

    :param mean: Mean of the distribution, restricted to the open interval (0, 1).
    :type mean: ArrayLike
    :param concentration: Precision parameter (:math:`\phi`), must be positive.
    :type concentration: ArrayLike
    :param validate_args: Whether to validate input constraints, defaults to None.
    :type validate_args: bool, optional
    """

    arg_constraints = {
        "mean": constraints.open_interval(0.0, 1.0),
        "concentration": constraints.positive,
    }
    reparametrized_params = ["mean", "concentration"]
    support = constraints.unit_interval
    pytree_data_fields = ("concentration",)

    def __init__(
        self,
        mean: ArrayLike,
        concentration: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        # Only the concentration's own shape is used here, so this stores the
        # concentration as an array without broadcasting it against ``mean``.
        own_shape = lax.broadcast_shapes(jnp.shape(concentration))
        self.concentration = jnp.broadcast_to(concentration, own_shape)
        # Translate (mean, precision) into the two standard Beta parameters.
        alpha = mean * concentration
        beta = (1.0 - mean) * concentration
        super().__init__(alpha, beta, validate_args=validate_args)
class AsymmetricLaplaceQuantile(Distribution):
    """An alternative parameterization of AsymmetricLaplace commonly applied in
    Bayesian quantile regression.

    Instead of the `asymmetry` parameter employed by AsymmetricLaplace, to
    define the balance between left- versus right-hand sides of the
    distribution, this class utilizes a `quantile` parameter, which describes
    the proportion of probability density that falls to the left-hand side of
    the distribution.

    The `scale` parameter is also interpreted slightly differently than in
    AsymmetricLaplace. When `loc=0` and `scale=1`, AsymmetricLaplace(0,1,1) is
    equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is
    equivalent to Laplace(0,2).
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "quantile": constraints.open_interval(0.0, 1.0),
    }
    reparametrized_params = ["loc", "scale", "quantile"]
    support = constraints.real
    pytree_data_fields = ("loc", "scale", "quantile", "_ald")

    def __init__(
        self,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        quantile: ArrayLike = 0.5,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        param_shapes = [jnp.shape(param) for param in (loc, scale, quantile)]
        batch_shape = lax.broadcast_shapes(*param_shapes)
        self.loc, self.scale, self.quantile = promote_shapes(
            loc, scale, quantile, shape=batch_shape
        )
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)
        # Convert (scale, quantile) into the classic (scale, asymmetry)
        # parameters of the wrapped AsymmetricLaplace, which does all the work.
        asymmetry = (1 / ((1 / quantile) - 1)) ** 0.5
        classic_scale = scale * asymmetry / quantile
        self._ald = AsymmetricLaplace(
            loc=loc, scale=classic_scale, asymmetry=asymmetry
        )

    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Validate ``value`` if requested, then delegate to the wrapped distribution."""
        if self._validate_args:
            self._validate_sample(value)
        return self._ald.log_prob(value)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Delegate sampling to the wrapped AsymmetricLaplace."""
        return self._ald.sample(key, sample_shape=sample_shape)

    @property
    def mean(self) -> ArrayLike:
        return self._ald.mean

    @property
    def variance(self) -> ArrayLike:
        return self._ald.variance

    def cdf(self, value: ArrayLike) -> ArrayLike:
        """Delegate the CDF to the wrapped AsymmetricLaplace."""
        return self._ald.cdf(value)

    def icdf(self, value: ArrayLike) -> ArrayLike:
        """Delegate the inverse CDF to the wrapped AsymmetricLaplace."""
        return self._ald.icdf(value)
class ZeroSumNormal(TransformedDistribution):
    r"""
    Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3].
    This is a Normal distribution where one or more axes are constrained to sum
    to zero (the last axis by default).

    .. math::
        \begin{align*}
        ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\
        \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
        n = \text{length of the zero-sum axis}
        \end{align*}

    :param array_like scale: Standard deviation of the underlying normal
        distribution before the zerosum constraint is enforced.
    :param tuple event_shape: The event shape of the distribution, the axes of
        which get constrained to sum to zero.

    **Example:**

    .. doctest::

        >>> from numpy.testing import assert_allclose
        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import MCMC, NUTS

        >>> N = 1000
        >>> n_categories = 20
        >>> rng_key = random.key(0)
        >>> key1, key2, key3 = random.split(rng_key, 3)
        >>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,))
        >>> beta = random.normal(key2, shape=(n_categories,))
        >>> beta -= beta.mean(-1)
        >>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,))

        >>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories
        ...     N = len(category_ind)
        ...     alpha = numpyro.sample("alpha", dist.Normal(0, 2.5))
        ...     beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,)))
        ...     sigma = numpyro.sample("sigma", dist.Exponential(1))
        ...     with numpyro.plate("observations", N):
        ...         mu = alpha + beta[category_ind]
        ...         obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        ...     return obs

        >>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9)
        >>> mcmc = MCMC(
        ...     sampler=nuts_kernel,
        ...     num_samples=1_000, num_warmup=1_000, num_chains=4
        ... )
        >>> mcmc.run(random.key(0), category_ind=category_ind, y=y)
        >>> posterior_samples = mcmc.get_samples()
        >>> # Confirm everything along last axis sums to zero
        >>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3)

    **References**

    [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637

    [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html

    [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
    """

    arg_constraints = {"scale": constraints.positive}
    reparametrized_params = ["scale"]

    def __init__(
        self,
        scale: ArrayLike,
        event_shape: tuple[int, ...],
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.scale = scale
        n_event_dims = len(event_shape)
        # The base normal has one element fewer along every zero-sum axis; the
        # transform maps it onto the full event shape.
        base_shape = tuple(size - 1 for size in event_shape)
        base_dist = Normal(0, scale).expand(base_shape).to_event(n_event_dims)
        super().__init__(
            base_dist,
            ZeroSumTransform(n_event_dims),
            validate_args=validate_args,
        )

    @constraints.dependent_property(is_discrete=False)
    def support(self) -> constraints.Constraint:
        return constraints.zero_sum(len(self.event_shape))

    @property
    def mean(self) -> ArrayLike:
        return jnp.zeros(self.batch_shape + self.event_shape)

    @property
    def variance(self) -> ArrayLike:
        # Each zero-sum axis of length n shrinks the marginal variance by (1 - 1/n).
        marginal_var = jnp.square(self.scale)
        for axis_size in self.event_shape:
            marginal_var = marginal_var * (1 - 1 / axis_size)
        return jnp.broadcast_to(marginal_var, self.batch_shape + self.event_shape)
class Wishart(TransformedDistribution):
    """
    Wishart distribution for covariance matrices.

    :param concentration: Positive concentration parameter analogous to the
        concentration of a :class:`Gamma` distribution. The concentration must be larger
        than the dimensionality of the scale matrix.
    :param scale_matrix: Scale matrix analogous to the inverse rate of a :class:`Gamma`
        distribution.
    :param rate_matrix: Rate matrix analogous to the rate of a :class:`Gamma`
        distribution.
    :param scale_tril: Cholesky decomposition of the :code:`scale_matrix`.
    """

    arg_constraints = {
        "concentration": constraints.dependent(is_discrete=False, event_dim=0),
        "scale_matrix": constraints.positive_definite,
        "rate_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.positive_definite
    reparametrized_params = [
        "scale_matrix",
        "rate_matrix",
        "scale_tril",
    ]

    def __init__(
        self,
        concentration: ArrayLike,
        scale_matrix: Optional[Array] = None,
        rate_matrix: Optional[Array] = None,
        scale_tril: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        # The distribution is the Cholesky-factor distribution pushed through
        # the inverse of the Cholesky transform.
        cholesky_dist = WishartCholesky(
            concentration,
            scale_matrix=scale_matrix,
            rate_matrix=rate_matrix,
            scale_tril=scale_tril,
            validate_args=validate_args,
        )
        to_matrix = CholeskyTransform().inv
        super().__init__(cholesky_dist, to_matrix, validate_args=validate_args)

    @lazy_property
    def concentration(self):
        return self.base_dist.concentration

    @lazy_property
    def scale_matrix(self):
        return self.base_dist.scale_matrix

    @lazy_property
    def rate_matrix(self):
        return self.base_dist.rate_matrix

    @lazy_property
    def scale_tril(self):
        return self.base_dist.scale_tril

    @lazy_property
    def mean(self) -> ArrayLike:
        dof = self.concentration[..., None, None]
        return dof * self.scale_matrix

    @lazy_property
    def variance(self) -> ArrayLike:
        # Element-wise: Var(X_ij) = nu * (V_ij^2 + V_ii * V_jj).
        diag = jnp.diagonal(self.scale_matrix, axis1=-1, axis2=-2)
        diag_outer = diag[..., :, None] * diag[..., None, :]
        dof = self.concentration[..., None, None]
        return dof * (self.scale_matrix**2 + diag_outer)

    @staticmethod
    def infer_shapes(
        concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
    ):
        """Delegate shape inference to :class:`WishartCholesky`."""
        return WishartCholesky.infer_shapes(
            concentration,
            scale_matrix=scale_matrix,
            rate_matrix=rate_matrix,
            scale_tril=scale_tril,
        )

    def entropy(self) -> ArrayLike:
        """Differential entropy of the Wishart distribution."""
        p = self.event_shape[-1]
        dof = self.concentration
        return (
            (p + 1) * tri_logabsdet(self.scale_tril)
            + p * (p + 1) / 2 * jnp.log(2)
            + multigammaln(dof / 2, p)
            - (dof - p - 1) / 2 * multidigamma(dof / 2, p)
            + dof * p / 2
        )
class WishartCholesky(Distribution):
    """
    Cholesky factor of a Wishart distribution for covariance matrices.

    :param concentration: Positive concentration parameter analogous to the
        concentration of a :class:`Gamma` distribution. The concentration must be larger
        than the dimensionality of the scale matrix.
    :param scale_matrix: Scale matrix analogous to the inverse rate of a :class:`Gamma`
        distribution.
    :param rate_matrix: Rate matrix analogous to the rate of a :class:`Gamma`
        distribution.
    :param scale_tril: Cholesky decomposition of the :code:`scale_matrix`.
    """

    arg_constraints = {
        "concentration": constraints.dependent(is_discrete=False, event_dim=0),
        "scale_matrix": constraints.positive_definite,
        "rate_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.lower_cholesky
    reparametrized_params = [
        "scale_matrix",
        "rate_matrix",
        "scale_tril",
    ]

    def __init__(
        self,
        concentration: ArrayLike,
        scale_matrix: Optional[Array] = None,
        rate_matrix: Optional[Array] = None,
        scale_tril: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        assert_one_of(
            scale_matrix=scale_matrix,
            rate_matrix=rate_matrix,
            scale_tril=scale_tril,
        )
        # Give the concentration two trailing singleton axes so that it can be
        # shape-promoted together with the matrix-valued parameter.
        concentration = jnp.asarray(concentration)[..., None, None]
        if scale_matrix is not None:
            concentration, self.scale_matrix = promote_shapes(
                concentration, scale_matrix
            )
            self.scale_tril = jnp.linalg.cholesky(self.scale_matrix)
        elif rate_matrix is not None:
            concentration, self.rate_matrix = promote_shapes(concentration, rate_matrix)
            self.scale_tril = cholesky_of_inverse(self.rate_matrix)
        elif scale_tril is not None:
            concentration, self.scale_tril = promote_shapes(
                concentration, jnp.asarray(scale_tril)
            )
        tril_shape = jnp.shape(self.scale_tril)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration)[:-2], tril_shape[:-2]
        )
        # Drop the two helper axes again before storing the concentration.
        self.concentration = concentration[..., 0, 0]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=tril_shape[-2:],
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density of the Cholesky factor ``value`` of a Wishart matrix."""
        # The log density of the Wishart distribution includes a term
        # t = trace(rate_matrix @ cov). Here, value = cholesky(cov) such that
        # t = trace(value.T @ rate_matrix @ value) by the cyclical property of the
        # trace. The rate matrix is the inverse scale matrix with Cholesky decomposition
        # scale_tril. Thus,
        # t = trace(value.T @ inv(scale_tril).T @ inv(scale_tril) @ value), and we can
        # rewrite as t = trace(x.T @ x) for x = inv(scale_tril) @ value which we can
        # obtain easily by solving a triangular system. x is again triangular such that
        # trace(x @ x.T) is equal to the sum of squares of elements.
        whitened = solve_triangular(
            *jnp.broadcast_arrays(self.scale_tril, value), lower=True
        )
        trace = jnp.square(whitened).sum(axis=(-1, -2))
        p = value.shape[-1]
        dof = self.concentration
        # Part of the Jacobian of the Cholesky transformation.
        jacobian_term = jnp.sum(
            jnp.arange(p, 0, -1) * jnp.log(jnp.diagonal(value, axis1=-2, axis2=-1)),
            axis=-1,
        )
        return (
            (dof - p - 1) * tri_logabsdet(value)
            - trace / 2
            + p * (1 - dof / 2) * jnp.log(2)
            - multigammaln(dof / 2, p)
            - dof * tri_logabsdet(self.scale_tril)
            + jacobian_term
        )

    @lazy_property
    def scale_matrix(self):
        return jnp.matmul(self.scale_tril, self.scale_tril.mT)

    @lazy_property
    def rate_matrix(self):
        # Solve scale_matrix @ X = I using the stored lower Cholesky factor.
        dim = self.scale_tril.shape[-1]
        identity = jnp.broadcast_to(jnp.eye(dim), self.scale_tril.shape)
        return cho_solve((self.scale_tril, True), identity)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw Cholesky factors using the Bartlett decomposition."""
        assert is_prng_key(key)
        # Sample using the Bartlett decomposition
        # (https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition).
        diag_key, offdiag_key = random.split(key)
        latent = jnp.zeros(sample_shape + self.batch_shape + self.event_shape)
        p = self.event_shape[-1]
        diag_idx = jnp.arange(p)
        # Diagonal: square roots of chi-square draws with decreasing dof.
        chi2_draws = random.chisquare(
            diag_key, self.concentration[..., None] - diag_idx, latent.shape[:-1]
        )
        latent = latent.at[..., diag_idx, diag_idx].set(jnp.sqrt(chi2_draws))
        # Strictly lower triangle: independent standard normal draws.
        rows, cols = jnp.tril_indices(p, -1)
        assert rows.size == p * (p - 1) // 2
        normal_draws = random.normal(offdiag_key, latent.shape[:-2] + (rows.size,))
        latent = latent.at[..., rows, cols].set(normal_draws)
        return jnp.matmul(*jnp.broadcast_arrays(self.scale_tril, latent))

    @lazy_property
    def mean(self) -> ArrayLike:
        # The mean follows from the Bartlett decomposition sampling. All off-diagonal
        # elements of the latent variable have zero expectation. The diagonal are the
        # expected square roots of chi^2 variables which can be expressed in terms of
        # gamma functions (see
        # https://en.wikipedia.org/wiki/Chi-squared_distribution#Noncentral_moments).
        dim = self.scale_tril.shape[-1]
        dof = self.concentration[..., None] - jnp.arange(dim)
        expected_sqrt_chi2 = jnp.sqrt(2) * jnp.exp(
            gammaln((dof + 1) / 2) - gammaln(dof / 2)
        )
        return self.scale_tril * expected_sqrt_chi2[..., None, :]

    @lazy_property
    def variance(self) -> ArrayLike:
        # We have the same as for the mean except now the lower off-diagonals are one
        # due to the standard normal noise, and the diagonals are equal to the dof of
        # the chi^2 variables.
        diag_idx = jnp.arange(self.scale_tril.shape[-1])
        dof = self.concentration[..., None] - diag_idx
        ones = jnp.ones_like(dof, shape=dof.shape + (dof.shape[-1],))
        second_moments = jnp.tril(ones.at[..., diag_idx, diag_idx].set(dof))
        return jnp.square(self.scale_tril) @ second_moments - jnp.square(self.mean)

    @staticmethod
    def infer_shapes(
        concentration: tuple[int, ...] = (),
        scale_matrix: Optional[tuple[int, ...]] = None,
        rate_matrix: Optional[tuple[int, ...]] = None,
        scale_tril: Optional[tuple[int, ...]] = None,
    ):
        """Return ``(batch_shape, event_shape)`` from the parameter shapes."""
        assert_one_of(
            scale_matrix=scale_matrix,
            rate_matrix=rate_matrix,
            scale_tril=scale_tril,
        )
        for matrix_shape in (scale_matrix, rate_matrix, scale_tril):
            if matrix_shape is None:
                continue
            # The first matrix shape that was supplied fixes both shapes.
            batch_shape = lax.broadcast_shapes(concentration, matrix_shape[:-2])
            return batch_shape, matrix_shape[-2:]
class InverseWishart(TransformedDistribution):
    r"""
    Inverse Wishart distribution for covariance matrices.

    The Inverse Wishart distribution is the conjugate prior for the covariance
    matrix of a multivariate normal distribution. If
    :math:`\mathbf{X} \sim W^{-1}(\mathbf{\Psi}, \nu)`, then
    :math:`\mathbf{X}^{-1} \sim W(\mathbf{\Psi}^{-1}, \nu)` (Wishart distribution).

    .. math::
        p(\mathbf{X} \mid \mathbf{\Psi}, \nu) =
        \frac{|\mathbf{\Psi}|^{\nu/2}}{2^{\nu p/2} \Gamma_p(\nu/2)}
        |\mathbf{X}|^{-(\nu + p + 1)/2}
        \exp\left( -\frac{1}{2} \mathrm{tr}(\mathbf{\Psi} \mathbf{X}^{-1}) \right)

    where :math:`p` is the dimension of the matrix, :math:`\nu > p - 1` is the
    degrees of freedom, and :math:`\mathbf{\Psi}` is the positive definite scale
    matrix.

    :param concentration: Degrees of freedom parameter (often denoted :math:`\nu`).
        Must be greater than `p - 1` where `p` is the dimension of the scale matrix.
    :param scale_matrix: Positive definite scale matrix :math:`\mathbf{\Psi}`, analogous
        to the inverse rate of a :class:`Gamma` distribution.
    :param rate_matrix: Inverse of the scale matrix, analogous to the rate of a
        :class:`Gamma` distribution.
    :param scale_tril: Cholesky decomposition of the scale matrix.

    **Properties**

    - **Mean**: :math:`\frac{\mathbf{\Psi}}{\nu - p - 1}` for :math:`\nu > p + 1`
    - **Mode**: :math:`\frac{\mathbf{\Psi}}{\nu + p + 1}`
    - **Variance** (element-wise, for :math:`\nu > p + 3`):
      :math:`\frac{(\nu - p + 1)\Psi_{ij}^2 + (\nu - p - 1)\Psi_{ii}\Psi_{jj}}
      {(\nu - p)(\nu - p - 1)^2(\nu - p - 3)}`

    **References**

    [1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution
    """

    arg_constraints = {
        "concentration": constraints.dependent(is_discrete=False, event_dim=0),
        "scale_matrix": constraints.positive_definite,
        "rate_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.positive_definite
    reparametrized_params = [
        "scale_matrix",
        "rate_matrix",
        "scale_tril",
    ]

    def __init__(
        self,
        concentration: ArrayLike,
        scale_matrix: Optional[Array] = None,
        rate_matrix: Optional[Array] = None,
        scale_tril: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        # The distribution is the Cholesky-factor distribution pushed through
        # the inverse of the Cholesky transform.
        base_dist = InverseWishartCholesky(
            concentration,
            scale_matrix,
            rate_matrix,
            scale_tril,
            validate_args=validate_args,
        )
        super().__init__(
            base_dist, CholeskyTransform().inv, validate_args=validate_args
        )

    @lazy_property
    def concentration(self):
        return self.base_dist.concentration

    @lazy_property
    def scale_matrix(self):
        return self.base_dist.scale_matrix

    @lazy_property
    def rate_matrix(self):
        return self.base_dist.rate_matrix

    @lazy_property
    def scale_tril(self):
        return self.base_dist.scale_tril

    @lazy_property
    def mean(self) -> ArrayLike:
        # Mean exists only when concentration > p + 1
        p = self.scale_matrix.shape[-1]
        nu = self.concentration[..., None, None]
        return jnp.where(
            nu > p + 1,
            self.scale_matrix / (nu - p - 1),
            jnp.full_like(self.scale_matrix, jnp.nan),
        )

    @lazy_property
    def mode(self) -> ArrayLike:
        p = self.scale_matrix.shape[-1]
        return self.scale_matrix / (self.concentration[..., None, None] + p + 1)

    @lazy_property
    def variance(self) -> ArrayLike:
        # Element-wise variance, defined for nu > p + 3:
        #   Var(X_ij) = ((nu - p + 1) * Psi_ij^2 + (nu - p - 1) * Psi_ii * Psi_jj)
        #               / ((nu - p) * (nu - p - 1)^2 * (nu - p - 3))
        # On the diagonal this reduces to
        #   2 * Psi_ii^2 / ((nu - p - 1)^2 * (nu - p - 3)).
        # Using that reduced form for every entry would overstate the
        # off-diagonal variances, so the general expression is used here.
        p = self.scale_matrix.shape[-1]
        nu = jnp.expand_dims(self.concentration, axis=(-1, -2))
        psi = self.scale_matrix
        psi_diag = jnp.diagonal(psi, axis1=-2, axis2=-1)
        psi_ii = psi_diag[..., :, None]
        psi_jj = psi_diag[..., None, :]
        numer = (nu - p + 1) * psi**2 + (nu - p - 1) * psi_ii * psi_jj
        denom = (nu - p) * (nu - p - 1) ** 2 * (nu - p - 3)
        var = numer / denom
        return jnp.where(nu > p + 3, var, jnp.full_like(var, jnp.nan))

    @staticmethod
    def infer_shapes(
        concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
    ):
        """Delegate shape inference to :class:`InverseWishartCholesky`."""
        return InverseWishartCholesky.infer_shapes(
            concentration, scale_matrix, rate_matrix, scale_tril
        )
class InverseWishartCholesky(Distribution):
    r"""
    Cholesky factor of an Inverse Wishart distribution for covariance matrices.

    This distribution samples the Cholesky factor :math:`\mathbf{L}` such that
    :math:`\mathbf{X} = \mathbf{L} \mathbf{L}^T \sim W^{-1}(\mathbf{\Psi}, \nu)`.

    :param concentration: Degrees of freedom parameter (often denoted :math:`\nu`).
        Must be greater than `p - 1` where `p` is the dimension of the scale matrix.
    :param scale_matrix: Positive definite scale matrix :math:`\mathbf{\Psi}`, analogous
        to the inverse rate of a :class:`Gamma` distribution.
    :param rate_matrix: Inverse of the scale matrix, analogous to the rate of a
        :class:`Gamma` distribution.
    :param scale_tril: Cholesky decomposition of the scale matrix.

    **References**

    [1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution
    """

    arg_constraints = {
        "concentration": constraints.dependent(is_discrete=False, event_dim=0),
        "scale_matrix": constraints.positive_definite,
        "rate_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.lower_cholesky
    reparametrized_params = [
        "scale_matrix",
        "rate_matrix",
        "scale_tril",
    ]

    def __init__(
        self,
        concentration: ArrayLike,
        scale_matrix: Optional[Array] = None,
        rate_matrix: Optional[Array] = None,
        scale_tril: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        assert_one_of(
            scale_matrix=scale_matrix,
            rate_matrix=rate_matrix,
            scale_tril=scale_tril,
        )
        # Give the concentration two trailing singleton axes so that it can be
        # shape-promoted together with the matrix-valued parameter.
        concentration = jnp.asarray(concentration)[..., None, None]
        if scale_matrix is not None:
            concentration, self.scale_matrix = promote_shapes(
                concentration, scale_matrix
            )
            self.scale_tril = jnp.linalg.cholesky(self.scale_matrix)
        elif rate_matrix is not None:
            concentration, self.rate_matrix = promote_shapes(concentration, rate_matrix)
            self.scale_tril = cholesky_of_inverse(self.rate_matrix)
        elif scale_tril is not None:
            concentration, self.scale_tril = promote_shapes(
                concentration, jnp.asarray(scale_tril)
            )
        tril_shape = jnp.shape(self.scale_tril)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration)[:-2], tril_shape[:-2]
        )
        # Drop the two helper axes again before storing the concentration.
        self.concentration = concentration[..., 0, 0]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=tril_shape[-2:],
            validate_args=validate_args,
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density of the Cholesky factor ``value`` of an Inverse Wishart matrix."""
        # L = value (Cholesky factor), X = L @ L^T ~ InverseWishart(Psi, nu)
        # log p(X) = (nu/2) log|Psi| - (nu*p/2) log(2) - log Gamma_p(nu/2)
        #            - ((nu+p+1)/2) log|X| - tr(Psi @ X^{-1}) / 2
        # Trace trick: tr(Psi @ X^{-1}) = ||L^{-1} @ scale_tril||_F^2
        whitened = solve_triangular(
            *jnp.broadcast_arrays(value, self.scale_tril), lower=True
        )
        trace = jnp.square(whitened).sum(axis=(-1, -2))
        p = value.shape[-1]
        log_diag = jnp.log(jnp.diagonal(value, axis1=-2, axis2=-1))
        nu = self.concentration
        # Per-diagonal exponents; these combine the log|X| term of the density
        # with the Jacobian of the map from X to its Cholesky factor.
        diag_exponents = -nu[..., None] - 1 - jnp.arange(p)
        return (
            nu * tri_logabsdet(self.scale_tril)  # (nu/2) log|Psi|
            + p * (1 - nu / 2) * jnp.log(2)  # normalization
            - multigammaln(nu / 2, p)
            + jnp.sum(diag_exponents * log_diag, axis=-1)
            - trace / 2
        )

    @lazy_property
    def scale_matrix(self):
        return jnp.matmul(self.scale_tril, self.scale_tril.mT)

    @lazy_property
    def rate_matrix(self):
        # Solve scale_matrix @ X = I using the stored lower Cholesky factor.
        dim = self.scale_tril.shape[-1]
        identity = jnp.broadcast_to(jnp.eye(dim), self.scale_tril.shape)
        return cho_solve((self.scale_tril, True), identity)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw Cholesky factors by inverting a Bartlett-style triangular draw."""
        assert is_prng_key(key)
        # Sample from standard InverseWishartCholesky using Bartlett decomposition
        # Ref: https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
        diag_key, offdiag_key = random.split(key)
        latent = jnp.zeros(sample_shape + self.batch_shape + self.event_shape)
        p = self.event_shape[-1]
        # Inverse Wishart Bartlett: nu - p + 1, nu - p + 2, ..., nu - 1, nu
        diag_idx = jnp.arange(p)
        chi2_draws = random.chisquare(
            diag_key,
            self.concentration[..., None] + diag_idx - p + 1,
            latent.shape[:-1],
        )
        latent = latent.at[..., diag_idx, diag_idx].set(jnp.sqrt(chi2_draws))
        # Strictly lower triangle: independent standard normal draws.
        rows, cols = jnp.tril_indices(p, -1)
        normal_draws = random.normal(offdiag_key, latent.shape[:-2] + (rows.size,))
        latent = latent.at[..., rows, cols].set(normal_draws)
        # Get Cholesky of InverseWishart(I) by inverting latent
        identity = jnp.broadcast_to(jnp.eye(p), latent.shape)
        standard_factor = solve_triangular(latent, identity, lower=True)
        # Transform to InverseWishart(Psi): L = scale_tril @ L_inv_std
        return jnp.matmul(self.scale_tril, standard_factor)

    @lazy_property
    def mean(self) -> ArrayLike:
        # Approximate: chol(E[X]) where E[X] = Psi / (nu - p - 1) for nu > p + 1
        p = self.scale_tril.shape[-1]
        nu = self.concentration[..., None, None]
        mean_exists = nu > p + 1
        mean_x = jnp.where(
            mean_exists,
            self.scale_matrix / (nu - p - 1),
            jnp.full_like(self.scale_matrix, jnp.nan),
        )
        # Factor the identity wherever the mean is undefined, then mask those
        # entries back to NaN.
        factorable = jnp.where(jnp.isnan(mean_x), jnp.eye(p), mean_x)
        nan_mask = jnp.where(
            mean_exists,
            jnp.ones_like(mean_x),
            jnp.full_like(mean_x, jnp.nan),
        )
        return jnp.linalg.cholesky(factorable) * nan_mask

    @lazy_property
    def variance(self) -> ArrayLike:
        # Variance of Cholesky factor is complex; return NaN for now
        full_shape = self.batch_shape + self.event_shape
        return jnp.full(full_shape, jnp.nan)

    @staticmethod
    def infer_shapes(
        concentration: tuple[int, ...] = (),
        scale_matrix: Optional[tuple[int, ...]] = None,
        rate_matrix: Optional[tuple[int, ...]] = None,
        scale_tril: Optional[tuple[int, ...]] = None,
    ):
        """Return ``(batch_shape, event_shape)`` from the parameter shapes."""
        assert_one_of(
            scale_matrix=scale_matrix,
            rate_matrix=rate_matrix,
            scale_tril=scale_tril,
        )
        for matrix_shape in (scale_matrix, rate_matrix, scale_tril):
            if matrix_shape is None:
                continue
            # The first matrix shape that was supplied fixes both shapes.
            batch_shape = lax.broadcast_shapes(concentration, matrix_shape[:-2])
            return batch_shape, matrix_shape[-2:]
class Levy(Distribution):
    r"""Lévy distribution is a special case of Lévy alpha-stable distribution.
    Its probability density function is given by,

    .. math::
        f(x\mid \mu, c) = \sqrt{\frac{c}{2\pi(x-\mu)^{3}}} \exp\left(-\frac{c}{2(x-\mu)}\right),
        \qquad x > \mu

    where :math:`\mu` is the location parameter and :math:`c` is the scale parameter.

    :param loc: Location parameter.
    :param scale: Scale parameter.
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
    }

    def __init__(
        self,
        loc: ArrayLike,
        scale: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        self.loc, self.scale = promote_shapes(loc, scale)
        # The support depends on ``loc``, so it is stored per instance.
        self._support = constraints.greater_than(loc)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False)
    def support(self) -> constraints.Constraint:
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Compute the log probability density function of the Lévy distribution.

        .. math::
            \log f(x\mid \mu, c) = \frac{1}{2}\log\left(\frac{c}{2\pi}\right)
            - \frac{c}{2(x-\mu)} - \frac{3}{2}\log(x-\mu), \qquad x > \mu

        :param value: A batch of samples from the distribution.
        :return: an array with shape `value.shape[:-self.event_shape]`
        :rtype: numpy.ndarray
        """
        shifted = value - self.loc
        half_terms = jnp.log(2.0 * jnp.pi) - jnp.log(self.scale) + self.scale / shifted
        return -0.5 * half_terms - 1.5 * jnp.log(shifted)

    def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw samples by pushing uniform noise through :meth:`icdf`."""
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape
        uniform_draw = random.uniform(key, shape=draw_shape)
        return self.icdf(uniform_draw)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""
        The inverse cumulative distribution function of Lévy distribution is given by,

        .. math::
            F^{-1}(q\mid \mu, c) = \mu + c\left(\Phi^{-1}(1-q/2)\right)^{-2}

        where :math:`\Phi^{-1}` is the inverse of the standard normal cumulative
        distribution function.

        :param q: quantile values, should belong to [0, 1].
        :return: the samples whose cdf values equals to `q`.
        """
        normal_quantile = ndtri(1 - 0.5 * q)
        return self.loc + self.scale * jnp.power(normal_quantile, -2)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""The cumulative distribution function of Lévy distribution is given by,

        .. math::
            F(x\mid \mu, c) = 2 - 2\Phi\left(\sqrt{\frac{c}{x-\mu}}\right)

        where :math:`\Phi` is the standard normal cumulative distribution function.

        :param value: samples from Lévy distribution.
        :return: output of the cumulative distribution function evaluated at `value`.
        """
        inv_standardized = self.scale / (value - self.loc)
        normal_cdf = ndtr(jnp.sqrt(inv_standardized))
        return 2.0 - 2.0 * normal_cdf

    @property
    def mean(self) -> ArrayLike:
        return jnp.broadcast_to(jnp.inf, self.batch_shape)

    @property
    def variance(self) -> ArrayLike:
        return jnp.broadcast_to(jnp.inf, self.batch_shape)

    def entropy(self) -> ArrayLike:
        r"""If :math:`X \sim \text{Levy}(\mu, c)`, then the entropy of :math:`X` is given by,

        .. math::
            H(X) = \frac{1}{2}+\frac{3}{2}\gamma+\frac{1}{2}\ln{\left(16\pi c^2\right)}
        """
        # The scale-independent part is a constant; log(scale) supplies the
        # remaining 0.5 * log(c^2) term.
        constant = 0.5 + 1.5 * jnp.euler_gamma + 0.5 * jnp.log(16 * jnp.pi)
        return jnp.broadcast_to(constant, self.batch_shape) + jnp.log(self.scale)
class CirculantNormal(TransformedDistribution):
    r"""
    Multivariate normal distribution with covariance matrix :math:`\mathbf{C}` that is
    positive-definite and circulant [1], i.e., has periodic boundary conditions. The
    density of a sample :math:`\mathbf{x}\in\mathbb{R}^n` is the standard multivariate
    normal density

    .. math::

        p\left(\mathbf{x}\mid\boldsymbol{\mu},\mathbf{C}\right) =
        \frac{\left(\mathrm{det}\,\mathbf{C}\right)^{-1/2}}{\left(2\pi\right)^{n / 2}}
        \exp\left(-\frac{1}{2}\left(\mathbf{x}-\boldsymbol{\mu}\right)^\intercal
        \mathbf{C}^{-1}\left(\mathbf{x}-\boldsymbol{\mu}\right)\right),

    where :math:`\mathrm{det}` denotes the determinant and :math:`^\intercal` the
    transpose. Circulant matrices can be diagonalized efficiently using the discrete
    Fourier transform [1], allowing the log likelihood to be evaluated in
    :math:`n \log n` time for :math:`n` observations [2].

    :param loc: Mean of the distribution :math:`\boldsymbol{\mu}`.
    :param covariance_row: First row of the circulant covariance matrix
        :math:`\mathbf{C}`. Because of periodic boundary conditions, the covariance
        matrix is fully determined by its first row (see
        :func:`jax.scipy.linalg.toeplitz` for further details).
    :param covariance_rfft: Real part of the real fast Fourier transform of
        :code:`covariance_row`, the first row of the circulant covariance matrix
        :math:`\mathbf{C}`.

    **References:**

    1. Wikipedia. (n.d.). Circulant matrix. Retrieved March 6, 2025, from
       https://en.wikipedia.org/wiki/Circulant_matrix
    2. Wood, A. T. A., & Chan, G. (1994). Simulation of Stationary Gaussian Processes
       in :math:`\left[0, 1\right]^d`. *Journal of Computational and Graphical
       Statistics*, 3(4), 409--432. https://doi.org/10.1080/10618600.1994.10474655
    """

    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_row": constraints.positive_definite_circulant_vector,
        "covariance_rfft": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector

    def __init__(
        self,
        loc: ArrayLike,
        covariance_row: Optional[ArrayLike] = None,
        covariance_rfft: Optional[ArrayLike] = None,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        # `loc` must carry the event dimension: the event size cannot be
        # recovered from `covariance_rfft` alone.
        assert jnp.ndim(loc) > 0, "Location parameter must have at least one dimension."
        n = jnp.shape(loc)[-1]
        n_rfft = n // 2 + 1

        assert_one_of(covariance_row=covariance_row, covariance_rfft=covariance_rfft)
        if covariance_rfft is not None:
            # `loc` (trailing size n) and `covariance_rfft` (trailing size
            # n_rfft) cannot be promoted directly. Promote the batch shapes
            # only, then re-attach each trailing dimension.
            loc_batch, rfft_batch = promote_shapes(
                loc[..., 0], covariance_rfft[..., 0], return_shapes=True
            )
            loc = _reshape(loc, loc_batch + (n,))
            covariance_rfft = _reshape(covariance_rfft, rfft_batch + (n_rfft,))
        else:
            # Only the first row was supplied: derive its spectrum and keep the
            # row itself so that it does not have to be recomputed.
            assert covariance_row.shape[-1] == n
            loc, covariance_row = promote_shapes(loc, covariance_row)
            covariance_rfft = jnp.fft.rfft(covariance_row).real
            self.covariance_row = covariance_row

        self.loc = loc
        self.covariance_rfft = covariance_rfft

        # Build the per-coefficient variances of the base distribution.
        n_imag = n - n_rfft
        assert self.covariance_rfft.shape[-1] == n_rfft
        # Every entry is n * rfft / 2; the first entry, and for even n also the
        # last entry, is then doubled.
        real_var = (n * covariance_rfft / 2).at[..., 0].mul(2)
        if n % 2 == 0:
            real_var = real_var.at[..., -1].mul(2)
        # Append the `n_imag` entries following the first one, giving n
        # variances in total.
        imag_var = real_var[..., 1 : 1 + n_imag]
        var_rfft = jnp.concatenate([real_var, imag_var], axis=-1)
        assert var_rfft.shape[-1] == n
        base_distribution = Normal(scale=jnp.sqrt(var_rfft)).to_event(1)

        transforms = [
            PackRealFastFourierCoefficientsTransform((n,)),
            RealFastFourierTransform((n,)).inv,
            AffineTransform(loc, scale=1.0),
        ]
        super().__init__(base_distribution, transforms, validate_args=validate_args)

    @property
    def mean(self) -> ArrayLike:
        """Location parameter broadcast to the full distribution shape."""
        full_shape = self.shape()
        return jnp.broadcast_to(self.loc, full_shape)

    @lazy_property
    def covariance_row(self) -> ArrayLike:
        """First row of the covariance matrix, recovered from its spectrum."""
        size = self.event_shape[-1]
        return jnp.fft.irfft(self.covariance_rfft, n=size)

    @lazy_property
    def covariance_matrix(self) -> ArrayLike:
        """Dense covariance matrix built with `toeplitz` from the first row."""
        row = self.covariance_row
        *leading_shape, _ = row.shape
        if not leading_shape:
            return toeplitz(row)
        # `toeplitz` flattens its input, so batched rows are mapped one by one
        # and the leading shape is restored afterwards.
        (size,) = self.event_shape
        flat_rows = row.reshape((-1, size))
        stacked = vmap(toeplitz)(flat_rows)
        return stacked.reshape((*leading_shape, size, size))

    @lazy_property
    def variance(self) -> ArrayLike:
        """First covariance-row entry, broadcast to the distribution shape."""
        first_entry = self.covariance_row[..., 0, None]
        return jnp.broadcast_to(first_entry, self.shape())

    @staticmethod
    def infer_shapes(
        loc: tuple[int, ...] = (),
        covariance_row: Optional[tuple[int, ...]] = None,
        covariance_rfft: Optional[tuple[int, ...]] = None,
    ):
        """Infer ``(batch_shape, event_shape)`` from the argument shapes.

        The batch shape broadcasts the leading dimensions of `loc` with those
        of whichever covariance argument is given; the event shape is the
        trailing dimension of `loc`.
        """
        assert_one_of(covariance_row=covariance_row, covariance_rfft=covariance_rfft)
        # `covariance_rfft` takes precedence, then `covariance_row`.
        cov = covariance_rfft if covariance_rfft is not None else covariance_row
        if cov is None:
            return None
        batch_shape = jnp.broadcast_shapes(loc[:-1], cov[:-1])
        event_shape = loc[-1:]
        return batch_shape, event_shape

    def entropy(self) -> ArrayLike:
        """Entropy of the base distribution plus half a log-Jacobian term."""
        (size,) = self.event_shape
        log_abs_det_jacobian = 2 * jnp.log(2) * ((size - 1) // 2) - jnp.log(size) * size
        base_entropy = self.base_dist.entropy()
        return base_entropy + log_abs_det_jacobian / 2
class Dagum(Distribution):
    r"""The Dagum distribution (or Mielke Beta-Kappa distribution) is a
    continuous probability distribution defined over positive real numbers.
    If :math:`p`, :math:`a` and :math:`b` are concentration, sharpness and
    scale values respectively, then Dagum distribution is defined as,

    .. math::
        f(x\mid p,a,b):=\frac{ap}{x}
        \left(\frac{(x/b)^{ap}}{\left((x/b)^{a}+1\right)^{p+1}}\right)

    :param concentration: concentration parameter :math:`p`.
    :param sharpness: sharpness parameter :math:`a`.
    :param scale: scale parameter :math:`b`.

    **References:**

    1. Wikipedia. (n.d.). Dagum distribution. Retrieved March 31, 2025,
       from https://en.wikipedia.org/wiki/Dagum_distribution
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "sharpness": constraints.positive,
        "scale": constraints.positive,
    }
    support = constraints.positive
    reparametrized_params = ["concentration", "sharpness", "scale"]

    def __init__(
        self,
        concentration: ArrayLike,
        sharpness: ArrayLike,
        scale: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        """Store the promoted parameters and set the broadcast batch shape."""
        promoted = promote_shapes(concentration, sharpness, scale)
        self.concentration, self.sharpness, self.scale = promoted
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration), jnp.shape(sharpness), jnp.shape(scale)
        )
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log density evaluated at `value`."""
        p = self.concentration
        a = self.sharpness
        # log_ratio = a * (ln(x) - ln(b)) = ln((x / b)^a)
        log_ratio = xlogy(a, value) - xlogy(a, self.scale)
        # softplus(log_ratio) = ln((x / b)^a + 1)
        log_denominator = (p + 1.0) * nn.softplus(log_ratio)
        return (
            jnp.log(a) + jnp.log(p) - jnp.log(value) + p * log_ratio - log_denominator
        )

    def cdf(self, value: ArrayLike) -> ArrayLike:
        """Cumulative distribution function evaluated at `value`."""
        a = self.sharpness
        # neg_log_ratio = ln((x / b)^(-a)), so the result is
        # (1 + (x / b)^(-a))^(-p).
        neg_log_ratio = xlogy(a, self.scale) - xlogy(a, value)
        return jnp.exp(-self.concentration * nn.softplus(neg_log_ratio))

    def icdf(self, q: ArrayLike) -> ArrayLike:
        """Inverse of :meth:`cdf`, evaluated at the quantile levels `q`."""
        inv_p = jnp.reciprocal(self.concentration)
        inv_a = jnp.reciprocal(self.sharpness)
        # b * (q^(-1/p) - 1)^(-1/a)
        q_root_p = jnp.power(q, -inv_p)
        return self.scale * jnp.power(q_root_p - 1.0, -inv_a)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> jnp.ndarray:
        """Draw samples by pushing uniform variates through :meth:`icdf`."""
        assert is_prng_key(key)
        uniforms = random.uniform(key, shape=self.shape(sample_shape))
        return self.icdf(uniforms)

    @property
    def mean(self) -> ArrayLike:
        """Mean of the distribution; infinite unless ``sharpness > 1``."""
        is_finite = self.sharpness > 1.0
        # Substitute a placeholder sharpness where the mean is not finite,
        # presumably so that `betaln` only ever sees valid arguments.
        safe_a = jnp.where(is_finite, self.sharpness, 2.0)
        log_beta = betaln(1.0 - 1.0 / safe_a, self.concentration + 1.0 / safe_a)
        finite_mean = (self.scale * self.concentration) * jnp.exp(log_beta)
        return jnp.where(is_finite, finite_mean, jnp.inf)

    @property
    def variance(self) -> ArrayLike:
        """Variance of the distribution; infinite unless ``sharpness > 2``."""
        is_finite = self.sharpness > 2.0
        # Same placeholder trick as in `mean`, with the stricter threshold.
        safe_a = jnp.where(is_finite, self.sharpness, 3.0)
        log_beta = betaln(1.0 - 2.0 / safe_a, self.concentration + 2.0 / safe_a)
        second_moment = (jnp.square(self.scale) * self.concentration) * jnp.exp(
            log_beta
        )
        finite_variance = second_moment - jnp.square(self.mean)
        return jnp.where(is_finite, finite_variance, jnp.inf)
class HurdleGamma(HurdleProbs):
    r"""A hurdle Gamma distribution: a two-part model in which a structural
    zero occurs with probability :math:`g` and, conditional on a positive
    outcome, the magnitude is drawn from :math:`\mathrm{Gamma}(\alpha, \lambda)`.

    The hurdle and the magnitude (given a positive value) are conditionally
    independent; see :class:`HurdleProbs` for the full mechanism and
    assumptions. Because :math:`P(X = 0) = 0` under a Gamma density, no
    truncation factor is needed and the PDF is

    .. math::
        P(X = 0) = g, \qquad
        f(x) = (1 - g) \,
        \frac{\lambda^{\alpha} x^{\alpha - 1} e^{-\lambda x}}{\Gamma(\alpha)}
        \;\text{for } x > 0.

    :param ArrayLike gate: probability of a structural zero,
        :math:`g \in [0, 1]`.
    :param ArrayLike concentration: shape parameter :math:`\alpha > 0` of the
        Gamma.
    :param ArrayLike rate: rate parameter :math:`\lambda > 0` of the Gamma.

    **References:**

    1. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent
       Variables with Application to the Demand for Durable Goods.
       *Econometrica*, 39(5), 829-844.
    2. Mullahy, J. (1986). Specification and testing of some modified count
       data models. *Journal of Econometrics*, 33(3), 341-365.
    """

    arg_constraints = {
        "gate": constraints.unit_interval,
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.nonnegative
    pytree_data_fields = ("concentration", "rate")

    def __init__(
        self,
        gate: ArrayLike,
        concentration: ArrayLike,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        # Promote all three parameters together; only the Gamma parameters are
        # stored here, while `gate` is handed to the parent constructor.
        _gate, promoted_concentration, promoted_rate = promote_shapes(
            gate, concentration, rate
        )
        self.concentration = promoted_concentration
        self.rate = promoted_rate
        # The positive part is built from the arguments as passed by the caller.
        positive_part = Gamma(concentration, rate)
        super().__init__(positive_part, gate, validate_args=validate_args)
class HurdleLogNormal(HurdleProbs):
    r"""A hurdle Log-Normal distribution: a two-part model in which a
    structural zero occurs with probability :math:`g` and, conditional on a
    positive outcome, the magnitude is drawn from
    :math:`\mathrm{LogNormal}(\mu, \sigma)`.

    The hurdle and the magnitude (given a positive value) are conditionally
    independent; see :class:`HurdleProbs` for the full mechanism and
    assumptions. Because :math:`P(X = 0) = 0` under a Log-Normal density, no
    truncation factor is needed and the PDF is

    .. math::
        P(X = 0) = g, \qquad
        f(x) = (1 - g) \, \frac{1}{x \sigma \sqrt{2 \pi}}
        \exp\!\left( -\frac{(\ln x - \mu)^2}{2 \sigma^2} \right)
        \;\text{for } x > 0.

    :param ArrayLike gate: probability of a structural zero,
        :math:`g \in [0, 1]`.
    :param ArrayLike loc: location parameter :math:`\mu \in \mathbb{R}` (mean
        of :math:`\ln X` given :math:`X > 0`).
    :param ArrayLike scale: scale parameter :math:`\sigma > 0` (std-dev of
        :math:`\ln X` given :math:`X > 0`).

    **References:**

    1. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent
       Variables with Application to the Demand for Durable Goods.
       *Econometrica*, 39(5), 829-844.
    2. Mullahy, J. (1986). Specification and testing of some modified count
       data models. *Journal of Econometrics*, 33(3), 341-365.
    """

    arg_constraints = {
        "gate": constraints.unit_interval,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    support = constraints.nonnegative
    pytree_data_fields = ("loc", "scale")

    def __init__(
        self,
        gate: ArrayLike,
        loc: ArrayLike = 0.0,
        scale: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ) -> None:
        # Promote all three parameters together; only the Log-Normal parameters
        # are stored here, while `gate` is handed to the parent constructor.
        _gate, promoted_loc, promoted_scale = promote_shapes(gate, loc, scale)
        self.loc = promoted_loc
        self.scale = promoted_scale
        # The positive part is built from the arguments as passed by the caller.
        positive_part = LogNormal(loc, scale)
        super().__init__(positive_part, gate, validate_args=validate_args)