Source code for numpyro.distributions.conjugate

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from typing import Optional

import jax
from jax import lax, nn, random
import jax.numpy as jnp
from jax.scipy.special import betainc, betaln, gammaln
from jax.typing import ArrayLike

from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
from numpyro.distributions.discrete import (
    BinomialProbs,
    HurdleProbs,
    MultinomialProbs,
    Poisson,
    ZeroInflatedDistribution,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample
from numpyro.util import is_prng_key


def _log_beta_1(alpha, value):
    # Note: support sparse `value`
    return gammaln(1 + value) + gammaln(alpha) - gammaln(value + alpha)


class BetaBinomial(Distribution):
    r"""
    Compound distribution comprising of a beta-binomial pair.

    The per-trial success probability (``probs`` of the
    :class:`~numpyro.distributions.Binomial` distribution) is not fixed. It is
    first drawn from a :class:`~numpyro.distributions.Beta` distribution, and
    then ``total_count`` Bernoulli trials are run with that probability.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "total_count": constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    # The support {0, ..., total_count} is the same as for a binomial, so the
    # enumeration logic is borrowed directly from BinomialProbs.
    enumerate_support = BinomialProbs.enumerate_support
    pytree_data_fields = ("concentration1", "concentration0", "total_count", "_beta")

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        total_count: ArrayLike = 1,
        *,
        validate_args: Optional[bool] = None,
    ):
        # Full broadcast shape of all three parameters.
        shape = lax.broadcast_shapes(
            jnp.shape(concentration1), jnp.shape(concentration0), jnp.shape(total_count)
        )
        self.concentration1, self.concentration0, self.total_count = promote_shapes(
            concentration1, concentration0, total_count
        )
        # The Beta mixing distribution is built with the full batch shape so
        # its samples line up with every trial count.
        self._beta = Beta(
            jnp.broadcast_to(concentration1, shape),
            jnp.broadcast_to(concentration0, shape),
        )
        super().__init__(shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw ``p ~ Beta`` and then ``X | p ~ Binomial(total_count, p)``."""
        assert is_prng_key(key)
        beta_key, binomial_key = random.split(key)
        success_probs = self._beta.sample(beta_key, sample_shape)
        binomial = BinomialProbs(total_count=self.total_count, probs=success_probs)
        return binomial.sample(binomial_key)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log pmf: ``log C(n, k) + log B(k + a1, n - k + a0) - log B(a0, a1)``."""
        failures = self.total_count - value
        # -_log_beta_1(n - k + 1, k) == log C(n, k)
        log_binom_coef = -_log_beta_1(failures + 1, value)
        log_numerator = betaln(
            value + self.concentration1,
            failures + self.concentration0,
        )
        log_normalizer = betaln(self.concentration0, self.concentration1)
        return log_binom_coef + log_numerator - log_normalizer

    @property
    def mean(self) -> ArrayLike:
        # E[X] = n * E[p]
        return self._beta.mean * self.total_count

    @property
    def variance(self) -> ArrayLike:
        # Var[X] = Var[p] * n * (a0 + a1 + n)
        inflation = self.concentration0 + self.concentration1 + self.total_count
        return self._beta.variance * self.total_count * inflation

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self) -> Constraint:
        return constraints.integer_interval(0, self.total_count)
class BetaNegativeBinomial(Distribution):
    r"""
    Compound distribution comprising of a beta-negative-binomial pair. The
    ``probs`` parameter for the
    :class:`~numpyro.distributions.NegativeBinomialProbs` distribution is unknown
    and randomly drawn from a :class:`~numpyro.distributions.Beta` distribution
    prior to the negative binomial counting process.

    The Beta Negative Binomial is a heavy-tailed discrete distribution useful for
    modeling overdispersed count data. It arises as the marginal distribution when
    integrating out the success probability in a negative binomial model with a
    beta prior.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray n: positive number of successes parameter for the negative
        binomial distribution.

    **References**

    [1] https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution

    [2] https://mc-stan.org/docs/functions-reference/unbounded_discrete_distributions.html#beta-neg-binomial
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "n": constraints.positive,
    }
    support = constraints.nonnegative_integer
    pytree_data_fields = ("concentration1", "concentration0", "n", "_beta")

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        n: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.concentration1, self.concentration0, self.n = promote_shapes(
            concentration1, concentration0, n
        )
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration1), jnp.shape(concentration0), jnp.shape(n)
        )
        # Build the Beta mixing distribution at the full batch shape so its
        # samples broadcast against ``n``.
        concentration1 = jnp.broadcast_to(concentration1, batch_shape)
        concentration0 = jnp.broadcast_to(concentration0, batch_shape)
        self._beta = Beta(concentration1, concentration0)
        super(BetaNegativeBinomial, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then
        the sampling procedure is:

        .. math::
            \begin{align*}
                p &\sim \mathrm{Beta}(\alpha, \beta) \\
                X \mid p &\sim \mathrm{NegativeBinomial}(n, p)
            \end{align*}

        It uses :class:`~numpyro.distributions.continuous.Beta` to generate samples
        from the Beta distribution and
        :class:`~numpyro.distributions.conjugate.NegativeBinomialProbs` to generate
        samples from the Negative Binomial distribution.
        """
        assert is_prng_key(key)
        key_beta, key_nb = random.split(key)
        probs = self._beta.sample(key_beta, sample_shape)
        return NegativeBinomialProbs(total_count=self.n, probs=probs).sample(key_nb)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then
        the probability mass function is:

        .. math::
            P(X = k) = \binom{n + k - 1}{k}
                \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}

        and this method returns :math:`\log P(X = k)`. The binomial coefficient is
        evaluated as
        :math:`\log\Gamma(n + k) - \log\Gamma(n) - \log\Gamma(k + 1)`, which keeps
        the result differentiable with respect to a non-integer ``n``.
        """
        return (
            gammaln(self.n + value)
            - gammaln(self.n)
            - gammaln(value + 1)
            + betaln(self.concentration1 + value, self.concentration0 + self.n)
            - betaln(self.concentration1, self.concentration0)
        )

    @property
    def mean(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
        :math:`\beta > 1`, then the mean is:

        .. math::
            \mathbb{E}[X] = \frac{n\alpha}{\beta - 1}.

        For :math:`\beta \le 1` the mean is infinite and ``jnp.inf`` is returned.
        """
        return jnp.where(
            self.concentration0 > 1,
            self.n * self.concentration1 / (self.concentration0 - 1),
            jnp.inf,
        )

    @property
    def variance(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
        :math:`\beta > 2`, then the variance is:

        .. math::
            \mathrm{Var}[X] = \frac{n\alpha (n + \beta - 1)(\alpha + \beta - 1)}
                {(\beta - 1)^2 \cdot (\beta - 2)}.

        For :math:`\beta \le 2` the variance is not finite (it is infinite, or
        undefined when the mean itself is infinite) and ``jnp.inf`` is returned.
        """
        alpha = self.concentration1
        beta = self.concentration0
        n = self.n
        var = (
            n
            * alpha
            * (n + beta - 1)
            * (alpha + beta - 1)
            / (jnp.square(beta - 1) * (beta - 2))
        )
        return jnp.where(beta > 2, var, jnp.inf)
class DirichletMultinomial(Distribution):
    r"""
    Compound distribution comprising of a dirichlet-multinomial pair.

    The class probabilities (``probs`` of the
    :class:`~numpyro.distributions.Multinomial` distribution) are not fixed. They
    are first drawn from a :class:`~numpyro.distributions.Dirichlet`
    distribution, and then ``total_count`` Categorical trials are run with them.

    :param numpy.ndarray concentration: concentration parameter (alpha) for the
        Dirichlet distribution; the last axis indexes the classes.
    :param numpy.ndarray total_count: number of Categorical trials.
    :param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
    """

    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1),
        "total_count": constraints.nonnegative_integer,
    }
    pytree_data_fields = ("concentration", "_dirichlet")
    pytree_aux_fields = ("total_count", "total_count_max")

    def __init__(
        self,
        concentration: ArrayLike,
        total_count: ArrayLike = 1,
        *,
        total_count_max: Optional[int] = None,
        validate_args: Optional[bool] = None,
    ):
        if jnp.ndim(concentration) < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        conc_shape = jnp.shape(concentration)
        # Every axis except the last (class) axis is a batch axis.
        batch_shape = lax.broadcast_shapes(conc_shape[:-1], jnp.shape(total_count))
        full_shape = batch_shape + conc_shape[-1:]

        (self.concentration,) = promote_shapes(concentration, shape=full_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        self.total_count_max = total_count_max

        self._dirichlet = Dirichlet(jnp.broadcast_to(self.concentration, full_shape))
        super().__init__(
            self._dirichlet.batch_shape,
            self._dirichlet.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw ``p ~ Dirichlet`` and then ``X | p ~ Multinomial(total_count, p)``."""
        assert is_prng_key(key)
        dirichlet_key, multinomial_key = random.split(key)
        class_probs = self._dirichlet.sample(dirichlet_key, sample_shape)
        multinomial = MultinomialProbs(
            total_count=self.total_count,
            probs=class_probs,
            total_count_max=self.total_count_max,
        )
        return multinomial.sample(multinomial_key)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Log pmf of the Dirichlet-multinomial.

        It is the aggregate term for (sum(alpha), sum(value)) minus the sum of
        the per-class terms, both expressed through ``_log_beta_1``.
        """
        alpha = self.concentration
        log_total = _log_beta_1(alpha.sum(-1), value.sum(-1))
        log_per_class = _log_beta_1(alpha, value).sum(-1)
        return log_total - log_per_class

    @property
    def mean(self) -> ArrayLike:
        # E[X_i] = n * alpha_i / sum(alpha)
        return self._dirichlet.mean * jnp.expand_dims(self.total_count, -1)

    @property
    def variance(self) -> ArrayLike:
        # Var[X_i] = n * r_i * (1 - r_i) * (n + A) / (1 + A), r_i = alpha_i / A
        counts = jnp.expand_dims(self.total_count, -1)
        alpha_0 = self.concentration.sum(-1, keepdims=True)
        frac = self.concentration / alpha_0
        spread = counts * frac * (1 - frac)
        return spread * (counts + alpha_0) / (1 + alpha_0)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self) -> Constraint:
        return constraints.multinomial(self.total_count)

    @staticmethod
    def infer_shapes(
        concentration: ArrayLike, total_count=()
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        # ``concentration`` and ``total_count`` here are shapes, not arrays.
        event_shape = concentration[-1:]
        batch_shape = lax.broadcast_shapes(concentration[:-1], total_count)
        return batch_shape, event_shape
class GammaPoisson(Distribution):
    r"""
    Compound distribution comprising of a gamma-poisson pair, also referred to as
    a gamma-poisson mixture. The ``rate`` parameter for the
    :class:`~numpyro.distributions.Poisson` distribution is unknown and randomly
    drawn from a :class:`~numpyro.distributions.Gamma` distribution.

    :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma
        distribution.
    :param numpy.ndarray rate: rate parameter (rate) for the Gamma distribution.
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.nonnegative_integer
    pytree_data_fields = ("concentration", "rate", "_gamma")

    def __init__(
        self,
        concentration: ArrayLike,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.concentration, self.rate = promote_shapes(concentration, rate)
        self._gamma = Gamma(concentration, rate)
        super(GammaPoisson, self).__init__(
            self._gamma.batch_shape, validate_args=validate_args
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the
        sampling procedure is:

        .. math::
            \begin{align*}
                \theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\
                X \mid \theta &\sim \mathrm{Poisson}(\theta)
            \end{align*}

        It uses :class:`~numpyro.distributions.continuous.Gamma` to generate
        samples from the Gamma distribution and
        :class:`~numpyro.distributions.discrete.Poisson` to generate samples from
        the Poisson distribution.
        """
        assert is_prng_key(key)
        key_gamma, key_poisson = random.split(key)
        rate = self._gamma.sample(key_gamma, sample_shape)
        return Poisson(rate).sample(key_poisson)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the
        probability mass function is:

        .. math::
            p_{X}(k) = \frac{\lambda^\alpha}
                {(\alpha + k)(1+\lambda)^{\alpha + k}\mathrm{B}(\alpha, k + 1)}

        and this method returns its logarithm.
        """
        post_value = self.concentration + value
        return (
            -betaln(self.concentration, value + 1)
            - jnp.log(post_value)
            + self.concentration * jnp.log(self.rate)
            - post_value * jnp.log1p(self.rate)
        )

    @property
    def mean(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the mean
        is:

        .. math::
            \mathbb{E}[X] = \frac{\alpha}{\lambda}
        """
        return self.concentration / self.rate

    @property
    def variance(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the
        variance is:

        .. math::
            \mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)
        """
        return self.concentration / jnp.square(self.rate) * (1 + self.rate)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the
        cumulative distribution function is, for :math:`x \ge 0`:

        .. math::
            F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, \lfloor x \rfloor + 1)}
                \int_{0}^{\frac{\lambda}{1 + \lambda}}
                t^{\alpha - 1} (1 - t)^{\lfloor x \rfloor} dt

        which is the regularized incomplete beta function, computed with
        :func:`~jax.scipy.special.betainc`. For :math:`x < 0` the CDF is 0.
        Non-integer ``x`` is handled via :math:`\lfloor x \rfloor` because the
        support is the non-negative integers, so the CDF is a step function.
        """
        k = jnp.floor(value) + 1.0
        # ``betainc`` needs b > 0. Below the support (x < 0, i.e. k <= 0) feed
        # it a harmless placeholder and mask the result to 0 afterwards, so no
        # NaN leaks into the output.
        in_support = k > 0
        safe_k = jnp.where(in_support, k, 1.0)
        bt = betainc(self.concentration, safe_k, self.rate / (self.rate + 1.0))
        return jnp.where(in_support, bt, 0.0)
def NegativeBinomial(
    total_count: ArrayLike,
    probs: Optional[ArrayLike] = None,
    logits: Optional[ArrayLike] = None,
    *,
    validate_args: Optional[bool] = None,
) -> GammaPoisson:
    """Build a Negative Binomial distribution from either ``probs`` or ``logits``.

    ``probs`` takes precedence: when it is given, ``logits`` is ignored.

    :param int total_count: Number of successful trials.
    :param Optional[ArrayLike] probs: Probability of success for each trial,
        by default None
    :param Optional[ArrayLike] logits: Log-odds of success for each trial,
        by default None
    :param Optional[bool] validate_args: Whether to validate the parameters,
        by default None
    :return: An instance of :class:`NegativeBinomialProbs` or
        :class:`NegativeBinomialLogits` depending on the provided parameters.
    :rtype: GammaPoisson
    :raises ValueError: If neither :code:`probs` nor :code:`logits` is specified.
    """
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
        return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
    return NegativeBinomialLogits(total_count, logits, validate_args=validate_args)
class NegativeBinomialProbs(GammaPoisson):
    r"""Negative Binomial distribution parameterized by :code:`total_count`
    (:math:`r`) and :code:`probs` (:math:`p`).

    It is implemented as a
    :math:`\displaystyle\mathrm{GammaPoisson}(r, \frac{1}{p} - 1)` distribution.

    :param total_count: Number of successful trials (:math:`r`).
    :param probs: Probability of success for each trial (:math:`p`).
    """

    arg_constraints = {
        "total_count": constraints.positive,
        "probs": constraints.unit_interval,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: ArrayLike,
        probs: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.total_count, self.probs = promote_shapes(total_count, probs)
        # Map (r, p) onto the Gamma mixing distribution: shape r, rate 1/p - 1.
        concentration = total_count
        rate = 1.0 / probs - 1.0
        super().__init__(concentration, rate, validate_args=validate_args)
class NegativeBinomialLogits(GammaPoisson):
    r"""Negative Binomial distribution parameterized by :code:`total_count`
    (:math:`r`) and :code:`logits`
    (:math:`\displaystyle\mathrm{logits}(p)=\log \frac{p}{1-p}`).

    It is implemented as a
    :math:`\mathrm{GammaPoisson}(r, \exp(-\mathrm{logits}(p)))` distribution.

    :param total_count: Number of successful trials (:math:`r`).
    :param logits: Log-odds of success for each trial
        (:math:`\ln \frac{p}{1-p}`).
    """

    arg_constraints = {
        "total_count": constraints.positive,
        "logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: ArrayLike,
        logits: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.total_count, self.logits = promote_shapes(total_count, logits)
        # exp(-logits) == 1/p - 1, the same Gamma rate used by NegativeBinomialProbs.
        concentration = total_count
        rate = jnp.exp(-logits)
        super().__init__(concentration, rate, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{NegativeBinomial}(r, \mathrm{logits}(p))`,
        then the log probability mass function is:

        .. math::
            \ln P(X = k) = -r \ln(1+\exp(\mathrm{logits}(p)))
                - k \ln(1+\exp(-\mathrm{logits}(p)))
                - \ln\Gamma(1 + k) - \ln\Gamma(r) + \ln\Gamma(k + r)

        Writing the first two terms with ``softplus`` directly on the logits
        avoids forming ``exp(-logits)`` and keeps the computation stable for
        large-magnitude logits.
        """
        return -(
            self.total_count * nn.softplus(self.logits)
            + value * nn.softplus(-self.logits)
            + _log_beta_1(self.total_count, value)
        )
class NegativeBinomial2(GammaPoisson):
    r"""Negative Binomial in mean/concentration form.

    If :math:`X \sim \mathrm{NegativeBinomial2}(\mu, \alpha)`, then
    :math:`X \sim \mathrm{GammaPoisson}(\alpha, \frac{\alpha}{\mu})`.

    :param numpy.ndarray mean: mean parameter (:math:`\mu`).
    :param numpy.ndarray concentration: concentration parameter (:math:`\alpha`).
    """

    arg_constraints = {
        "mean": constraints.positive,
        "concentration": constraints.positive,
    }
    support = constraints.nonnegative_integer
    # ``mean`` is not stored as an attribute (GammaPoisson exposes it as a
    # property), so only ``concentration`` is listed as pytree data here.
    pytree_data_fields = ("concentration",)

    def __init__(
        self,
        mean: ArrayLike,
        concentration: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        # rate = alpha / mu, so the GammaPoisson mean alpha / rate equals mu.
        super().__init__(
            concentration, concentration / mean, validate_args=validate_args
        )
def ZeroInflatedNegativeBinomial2(
    mean: ArrayLike,
    concentration: ArrayLike,
    *,
    gate: Optional[ArrayLike] = None,
    gate_logits: Optional[ArrayLike] = None,
    validate_args: Optional[bool] = None,
):
    """Zero-inflated :class:`NegativeBinomial2` distribution.

    Wraps ``NegativeBinomial2(mean, concentration)`` with
    ``ZeroInflatedDistribution``. The zero-inflation gate is passed through
    either as a probability (``gate``) or as log-odds (``gate_logits``).

    :param ArrayLike mean: mean of the underlying NegativeBinomial2.
    :param ArrayLike concentration: concentration of the underlying
        NegativeBinomial2.
    :param Optional[ArrayLike] gate: zero-inflation probability.
    :param Optional[ArrayLike] gate_logits: zero-inflation log-odds.
    :param Optional[bool] validate_args: whether to validate parameters; it is
        forwarded to both the base and the wrapper distribution.
    """
    # NOTE(review): presumably exactly one of gate / gate_logits should be
    # given; any such check is left to ZeroInflatedDistribution.
    base_dist = NegativeBinomial2(mean, concentration, validate_args=validate_args)
    return ZeroInflatedDistribution(
        base_dist,
        gate=gate,
        gate_logits=gate_logits,
        validate_args=validate_args,
    )
def HurdleNegativeBinomial2(
    gate: ArrayLike,
    mean: ArrayLike,
    concentration: ArrayLike,
    *,
    validate_args: Optional[bool] = None,
) -> HurdleProbs:
    r"""A hurdle Negative Binomial distribution in the NB2 (mean-dispersion)
    parameterization.

    This is a two-part model. A Bernoulli "hurdle" with probability :math:`g`
    produces structural zeros. Positive counts follow a zero-truncated
    :math:`\mathrm{NegativeBinomial2}(\mu, \alpha)`. The hurdle and the
    magnitude (given a positive count) are conditionally independent; see
    :class:`HurdleProbs` for the full mechanism and assumptions. Compared to a
    Hurdle Poisson, NB2 accommodates over-dispersed count data (variance greater
    than the mean).

    The probability mass function is

    .. math::
        P(X = 0) = g, \qquad
        P(X = k) = (1 - g) \, \frac{\mathrm{NB2}(k\mid\mu, \alpha)}
        {1 - \mathrm{NB2}(0\mid\mu, \alpha)} \;\text{for } k \geq 1,

    where :math:`\mathrm{NB2}(\cdot\mid\mu, \alpha)` is the PMF of a Negative
    Binomial distribution with mean :math:`\mu` and concentration (dispersion)
    :math:`\alpha`.

    :param ArrayLike gate: probability of a structural zero, :math:`g \in [0, 1]`.
    :param ArrayLike mean: mean :math:`\mu > 0` of the underlying NegativeBinomial2.
    :param ArrayLike concentration: concentration :math:`\alpha > 0`.

    **References:**

    1. Mullahy, J. (1986). Specification and testing of some modified count data
       models. *Journal of Econometrics*, 33(3), 341-365.
    2. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent Variables
       with Application to the Demand for Durable Goods. *Econometrica*, 39(5),
       829-844.
    """
    positive_part = NegativeBinomial2(
        mean, concentration, validate_args=validate_args
    )
    return HurdleProbs(positive_part, gate, validate_args=validate_args)