# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import jax
from jax import lax, nn, random
import jax.numpy as jnp
from jax.scipy.special import betainc, betaln, gammaln
from jax.typing import ArrayLike
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
from numpyro.distributions.discrete import (
BinomialProbs,
HurdleProbs,
MultinomialProbs,
Poisson,
ZeroInflatedDistribution,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample
from numpyro.util import is_prng_key
def _log_beta_1(alpha, value):
# Note: support sparse `value`
return gammaln(1 + value) + gammaln(alpha) - gammaln(value + alpha)
class BetaBinomial(Distribution):
    r"""
    Beta-binomial compound distribution.

    A success probability is first drawn from a
    :class:`~numpyro.distributions.Beta` distribution parameterized by
    ``concentration1`` and ``concentration0``. The observed value is then the
    number of successes in ``total_count`` Bernoulli trials with that
    probability, i.e. a :class:`~numpyro.distributions.Binomial` draw.

    :param numpy.ndarray concentration1: first concentration parameter (alpha)
        of the Beta distribution.
    :param numpy.ndarray concentration0: second concentration parameter (beta)
        of the Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "total_count": constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    # Support is {0, ..., total_count}, exactly as for the binomial, so the
    # binomial enumeration is reused as-is.
    enumerate_support = BinomialProbs.enumerate_support
    pytree_data_fields = ("concentration1", "concentration0", "total_count", "_beta")

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        total_count: ArrayLike = 1,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.concentration1, self.concentration0, self.total_count = promote_shapes(
            concentration1, concentration0, total_count
        )
        shape = lax.broadcast_shapes(
            jnp.shape(concentration1), jnp.shape(concentration0), jnp.shape(total_count)
        )
        # The latent Beta needs the full batch shape, so broadcast explicitly.
        self._beta = Beta(
            jnp.broadcast_to(concentration1, shape),
            jnp.broadcast_to(concentration0, shape),
        )
        super().__init__(shape, validate_args=validate_args)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw ``p ~ Beta`` and then ``X | p ~ Binomial(total_count, p)``."""
        assert is_prng_key(key)
        beta_key, binomial_key = random.split(key)
        success_probs = self._beta.sample(beta_key, sample_shape)
        binomial = BinomialProbs(total_count=self.total_count, probs=success_probs)
        return binomial.sample(binomial_key)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Log PMF:
        :math:`\log\binom{n}{k} + \log B(k + \alpha, n - k + \beta) - \log B(\alpha, \beta)`.
        """
        # -_log_beta_1(n - k + 1, k) == log(binom(n, k))
        log_binom = -_log_beta_1(self.total_count - value + 1, value)
        log_posterior_beta = betaln(
            value + self.concentration1,
            self.total_count - value + self.concentration0,
        )
        return log_binom + log_posterior_beta - betaln(
            self.concentration0, self.concentration1
        )

    @property
    def mean(self) -> ArrayLike:
        """``total_count * alpha / (alpha + beta)``, via the latent Beta mean."""
        return self._beta.mean * self.total_count

    @property
    def variance(self) -> ArrayLike:
        """Beta variance scaled by ``total_count * (alpha + beta + total_count)``."""
        scale = self.concentration0 + self.concentration1 + self.total_count
        return self._beta.variance * self.total_count * scale

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self) -> Constraint:
        return constraints.integer_interval(0, self.total_count)
class BetaNegativeBinomial(Distribution):
    r"""
    Compound distribution comprising of a beta-negative-binomial pair. The ``probs``
    parameter for the :class:`~numpyro.distributions.NegativeBinomialProbs` distribution
    is unknown and randomly drawn from a :class:`~numpyro.distributions.Beta` distribution
    prior to the negative binomial counting process.

    The Beta Negative Binomial is a heavy-tailed discrete distribution useful for modeling
    overdispersed count data. It arises as the marginal distribution when integrating out
    the success probability in a negative binomial model with a beta prior.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray n: positive number of successes parameter for the negative
        binomial distribution.

    **References**

    [1] https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution
    [2] https://mc-stan.org/docs/functions-reference/unbounded_discrete_distributions.html#beta-neg-binomial
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "n": constraints.positive,
    }
    support = constraints.nonnegative_integer
    pytree_data_fields = ("concentration1", "concentration0", "n", "_beta")

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        n: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.concentration1, self.concentration0, self.n = promote_shapes(
            concentration1, concentration0, n
        )
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration1), jnp.shape(concentration0), jnp.shape(n)
        )
        concentration1 = jnp.broadcast_to(concentration1, batch_shape)
        concentration0 = jnp.broadcast_to(concentration0, batch_shape)
        self._beta = Beta(concentration1, concentration0)
        super(BetaNegativeBinomial, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the sampling
        procedure is:

        .. math::
            \begin{align*}
                p &\sim \mathrm{Beta}(\alpha, \beta) \\
                X \mid p &\sim \mathrm{NegativeBinomial}(n, p)
            \end{align*}

        It uses :class:`~numpyro.distributions.Beta` to generate samples
        from the Beta distribution and
        :class:`~numpyro.distributions.NegativeBinomialProbs` to generate samples
        from the Negative Binomial distribution.
        """
        assert is_prng_key(key)
        key_beta, key_nb = random.split(key)
        probs = self._beta.sample(key_beta, sample_shape)
        return NegativeBinomialProbs(total_count=self.n, probs=probs).sample(key_nb)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the
        probability mass function is:

        .. math::
            P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}

        and this method returns its logarithm. To ensure differentiability with
        respect to a non-integer ``n``, the binomial coefficient is computed using
        gamma functions.
        """
        return (
            gammaln(self.n + value)
            - gammaln(self.n)
            - gammaln(value + 1)
            + betaln(self.concentration1 + value, self.concentration0 + self.n)
            - betaln(self.concentration1, self.concentration0)
        )

    @property
    def mean(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
        :math:`\beta > 1`, then the mean is:

        .. math::
            \mathbb{E}[X] = \frac{n\alpha}{\beta - 1},

        otherwise it is undefined and ``jnp.inf`` is returned.
        """
        beta = self.concentration0
        defined = beta > 1
        # Evaluate the formula on a safe beta so the discarded branch of
        # ``jnp.where`` never divides by zero (beta == 1), which would otherwise
        # turn the gradient into NaN.
        safe_beta = jnp.where(defined, beta, 2.0)
        mean = self.n * self.concentration1 / (safe_beta - 1)
        return jnp.where(defined, mean, jnp.inf)

    @property
    def variance(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
        :math:`\beta > 2`, then the variance is:

        .. math::
            \mathrm{Var}[X] =
            \frac{n\alpha (n + \beta - 1)(\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)},

        otherwise it is undefined and ``jnp.inf`` is returned.
        """
        alpha = self.concentration1
        beta = self.concentration0
        n = self.n
        defined = beta > 2
        # Same safe-where trick as in `mean`: keep the denominator nonzero
        # (beta == 1 or beta == 2) in the branch that `jnp.where` discards.
        safe_beta = jnp.where(defined, beta, 3.0)
        var = (
            n
            * alpha
            * (n + safe_beta - 1)
            * (alpha + safe_beta - 1)
            / (jnp.square(safe_beta - 1) * (safe_beta - 2))
        )
        return jnp.where(defined, var, jnp.inf)
class DirichletMultinomial(Distribution):
    r"""
    Compound distribution comprising of a dirichlet-multinomial pair. The probability of
    classes (``probs`` for the :class:`~numpyro.distributions.Multinomial` distribution)
    is unknown and randomly drawn from a :class:`~numpyro.distributions.Dirichlet`
    distribution prior to a certain number of Categorical trials given by
    ``total_count``.

    :param numpy.ndarray concentration: concentration parameter (alpha) for the
        Dirichlet distribution; the last axis indexes the categories.
    :param numpy.ndarray total_count: number of Categorical trials.
    :param int total_count_max: the maximum number of trials, i.e. ``max(total_count)``;
        forwarded to the multinomial sampler in :meth:`sample`.
    """

    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1),
        "total_count": constraints.nonnegative_integer,
    }
    pytree_data_fields = ("concentration", "_dirichlet")
    pytree_aux_fields = ("total_count", "total_count_max")

    def __init__(
        self,
        concentration: ArrayLike,
        total_count: ArrayLike = 1,
        *,
        total_count_max: Optional[int] = None,
        validate_args: Optional[bool] = None,
    ):
        if jnp.ndim(concentration) < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        # Batch dims come from all-but-last axis of `concentration` and from
        # `total_count`; the last axis of `concentration` is the event axis.
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration)[:-1], jnp.shape(total_count)
        )
        concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
        (self.concentration,) = promote_shapes(concentration, shape=concentration_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        self.total_count_max = total_count_max
        concentration = jnp.broadcast_to(self.concentration, concentration_shape)
        self._dirichlet = Dirichlet(concentration)
        super().__init__(
            self._dirichlet.batch_shape,
            self._dirichlet.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """Draw ``probs ~ Dirichlet(concentration)``, then a multinomial count vector."""
        assert is_prng_key(key)
        key_dirichlet, key_multinom = random.split(key)
        probs = self._dirichlet.sample(key_dirichlet, sample_shape)
        return MultinomialProbs(
            total_count=self.total_count,
            probs=probs,
            total_count_max=self.total_count_max,
        ).sample(key_multinom)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Log PMF of a count vector ``value`` (last axis = categories).

        Computed as ``_log_beta_1(sum(alpha), sum(value)) - sum(_log_beta_1(alpha, value))``.
        """
        alpha = self.concentration
        return _log_beta_1(alpha.sum(-1), value.sum(-1)) - _log_beta_1(
            alpha, value
        ).sum(-1)

    @property
    def mean(self) -> ArrayLike:
        """``total_count * alpha / sum(alpha)`` per category."""
        return self._dirichlet.mean * jnp.expand_dims(self.total_count, -1)

    @property
    def variance(self) -> ArrayLike:
        r"""Per-category variance
        :math:`n \frac{\alpha_i}{\alpha_0}\left(1 - \frac{\alpha_i}{\alpha_0}\right)
        \frac{n + \alpha_0}{1 + \alpha_0}` with :math:`\alpha_0 = \sum_j \alpha_j`.
        """
        n = jnp.expand_dims(self.total_count, -1)
        alpha = self.concentration
        alpha_sum = self.concentration.sum(-1, keepdims=True)
        alpha_ratio = alpha / alpha_sum
        return n * alpha_ratio * (1 - alpha_ratio) * (n + alpha_sum) / (1 + alpha_sum)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self) -> Constraint:
        return constraints.multinomial(self.total_count)

    @staticmethod
    def infer_shapes(
        concentration: tuple[int, ...], total_count: tuple[int, ...] = ()
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Infer ``(batch_shape, event_shape)`` from parameter *shapes*.

        :param concentration: shape of the ``concentration`` parameter (its last
            entry is the number of categories).
        :param total_count: shape of the ``total_count`` parameter.
        :return: the broadcast batch shape and the one-dimensional event shape.
        """
        batch_shape = lax.broadcast_shapes(concentration[:-1], total_count)
        event_shape = concentration[-1:]
        return batch_shape, event_shape
class GammaPoisson(Distribution):
    r"""
    Compound distribution comprising of a gamma-poisson pair, also referred to as
    a gamma-poisson mixture. The ``rate`` parameter for the
    :class:`~numpyro.distributions.Poisson` distribution is unknown and randomly
    drawn from a :class:`~numpyro.distributions.Gamma` distribution.

    :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
    :param numpy.ndarray rate: rate parameter (rate) for the Gamma distribution.
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.nonnegative_integer
    pytree_data_fields = ("concentration", "rate", "_gamma")

    def __init__(
        self,
        concentration: ArrayLike,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.concentration, self.rate = promote_shapes(concentration, rate)
        self._gamma = Gamma(concentration, rate)
        super(GammaPoisson, self).__init__(
            self._gamma.batch_shape, validate_args=validate_args
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the sampling
        procedure is:

        .. math::
            \begin{align*}
                \theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\
                X \mid \theta &\sim \mathrm{Poisson}(\theta)
            \end{align*}

        It uses :class:`~numpyro.distributions.continuous.Gamma` to generate samples
        from the Gamma distribution and
        :class:`~numpyro.distributions.discrete.Poisson` to generate samples from the
        Poisson distribution.
        """
        assert is_prng_key(key)
        key_gamma, key_poisson = random.split(key)
        rate = self._gamma.sample(key_gamma, sample_shape)
        return Poisson(rate).sample(key_poisson)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the
        probability mass function is:

        .. math::
            p_{X}(k) = \frac{\lambda^\alpha}{(\alpha + k)(1+\lambda)^{\alpha + k}\mathrm{B}(\alpha, k + 1)}

        and this method returns its logarithm.
        """
        post_value = self.concentration + value
        return (
            -betaln(self.concentration, value + 1)
            - jnp.log(post_value)
            + self.concentration * jnp.log(self.rate)
            - post_value * jnp.log1p(self.rate)
        )

    @property
    def mean(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the mean is:

        .. math::
            \mathbb{E}[X] = \frac{\alpha}{\lambda}
        """
        return self.concentration / self.rate

    @property
    def variance(self) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the variance is:

        .. math::
            \mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)
        """
        return self.concentration / jnp.square(self.rate) * (1 + self.rate)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the cumulative
        distribution function is:

        .. math::
            F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, \lfloor x \rfloor + 1)}
            \int_{0}^{\frac{\lambda}{1 + \lambda}} t^{\alpha - 1} (1 - t)^{\lfloor x \rfloor} dt

        for :math:`x \ge 0`, which is the regularized incomplete beta function, and
        :math:`F_{X}(x) = 0` for :math:`x < 0`. Because the support is integer-valued,
        non-integer ``value`` is floored first.

        This implementation uses :func:`~jax.scipy.special.betainc`.
        """
        k = jnp.floor(value)
        below_support = k < 0
        # betainc needs b = k + 1 > 0; evaluate it on a safe k so the branch
        # discarded by `jnp.where` cannot produce NaN values or NaN gradients.
        safe_k = jnp.where(below_support, 0.0, k)
        bt = betainc(self.concentration, safe_k + 1.0, self.rate / (self.rate + 1.0))
        return jnp.where(below_support, 0.0, bt)
def NegativeBinomial(
    total_count: ArrayLike,
    probs: Optional[ArrayLike] = None,
    logits: Optional[ArrayLike] = None,
    *,
    validate_args: Optional[bool] = None,
) -> GammaPoisson:
    """Build a Negative Binomial distribution from either ``probs`` or ``logits``.

    ``probs`` takes precedence when both are given.

    :param ArrayLike total_count: total count parameter :math:`r`.
    :param Optional[ArrayLike] probs: per-trial probability, by default None.
    :param Optional[ArrayLike] logits: log-odds of ``probs``, by default None.
    :param Optional[bool] validate_args: whether to validate the parameters,
        by default None.
    :return: a :class:`NegativeBinomialProbs` when ``probs`` is given, otherwise a
        :class:`NegativeBinomialLogits`.
    :rtype: GammaPoisson
    :raises ValueError: if neither :code:`probs` nor :code:`logits` is specified.
    """
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
        return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
    return NegativeBinomialLogits(total_count, logits, validate_args=validate_args)
class NegativeBinomialProbs(GammaPoisson):
    r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
    and :code:`probs` (:math:`p`).

    It is implemented as a
    :math:`\displaystyle\mathrm{GammaPoisson}(r, \frac{1}{p} - 1)` distribution,
    so its mean is :math:`r p / (1 - p)`.

    :param total_count: total count parameter :math:`r > 0`; used as the Gamma
        concentration.
    :param probs: per-trial probability :math:`p \in [0, 1]`.
    """

    arg_constraints = {
        "total_count": constraints.positive,
        "probs": constraints.unit_interval,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: ArrayLike,
        probs: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.total_count, self.probs = promote_shapes(total_count, probs)
        # GammaPoisson(concentration=r, rate=1/p - 1)
        super().__init__(total_count, 1.0 / probs - 1.0, validate_args=validate_args)
class NegativeBinomialLogits(GammaPoisson):
    r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
    and :code:`logits` (:math:`\displaystyle\mathrm{logits}(p)=\log \frac{p}{1-p}`).

    It is implemented as a
    :math:`\mathrm{GammaPoisson}(r, \exp(-\mathrm{logits}(p)))` distribution.

    :param total_count: total count parameter :math:`r > 0`.
    :param logits: log-odds :math:`\ln \frac{p}{1-p}` of the per-trial probability.
    """

    arg_constraints = {
        "total_count": constraints.positive,
        "logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: ArrayLike,
        logits: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.total_count, self.logits = promote_shapes(total_count, logits)
        # GammaPoisson(concentration=r, rate=exp(-logits))
        super().__init__(total_count, jnp.exp(-logits), validate_args=validate_args)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""If :math:`X \sim \mathrm{NegativeBinomial}(r, \mathrm{logits}(p))`, then the log
        probability mass function is:

        .. math::
            \ln P(X = k) = -r \ln(1+\exp(\mathrm{logits}(p)))
            - k \ln(1+\exp(-\mathrm{logits}(p)))
            - \ln\Gamma(1 + k) - \ln\Gamma(r) + \ln\Gamma(k + r)

        Written with ``softplus`` so that it stays numerically stable for large
        ``|logits|``.
        """
        neg_log_prob = self.total_count * nn.softplus(self.logits)
        neg_log_prob = neg_log_prob + value * nn.softplus(-self.logits)
        return -(neg_log_prob + _log_beta_1(self.total_count, value))
class NegativeBinomial2(GammaPoisson):
    r"""Mean/concentration parameterization of the Negative Binomial.

    If :math:`X \sim \mathrm{NegativeBinomial2}(\mu, \alpha)`, then
    :math:`X \sim \mathrm{GammaPoisson}(\alpha, \frac{\alpha}{\mu})`.

    :param numpy.ndarray mean: mean parameter (:math:`\mu`).
    :param numpy.ndarray concentration: concentration parameter (:math:`\alpha`).
    """

    arg_constraints = {
        "mean": constraints.positive,
        "concentration": constraints.positive,
    }
    support = constraints.nonnegative_integer
    # `mean` is not stored as an attribute (it is the inherited `mean`
    # property), so only `concentration` is listed as a pytree data field.
    pytree_data_fields = ("concentration",)

    def __init__(
        self,
        mean: ArrayLike,
        concentration: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        super().__init__(
            concentration, concentration / mean, validate_args=validate_args
        )
def ZeroInflatedNegativeBinomial2(
    mean: ArrayLike,
    concentration: ArrayLike,
    *,
    gate: Optional[ArrayLike] = None,
    gate_logits: Optional[ArrayLike] = None,
    validate_args: Optional[bool] = None,
):
    r"""Zero-inflated :class:`NegativeBinomial2` distribution.

    Builds ``NegativeBinomial2(mean, concentration)`` and wraps it with
    ``ZeroInflatedDistribution``, forwarding ``gate`` / ``gate_logits``
    unchanged. NOTE(review): which of ``gate`` / ``gate_logits`` must be
    supplied is decided by ``ZeroInflatedDistribution``; presumably exactly one.

    :param ArrayLike mean: mean :math:`\mu > 0` of the NegativeBinomial2 component.
    :param ArrayLike concentration: concentration :math:`\alpha > 0`.
    :param Optional[ArrayLike] gate: zero-inflation probability.
    :param Optional[ArrayLike] gate_logits: log-odds of the zero-inflation probability.
    :param Optional[bool] validate_args: whether to validate the parameters.
    """
    count_dist = NegativeBinomial2(mean, concentration, validate_args=validate_args)
    return ZeroInflatedDistribution(
        count_dist,
        gate=gate,
        gate_logits=gate_logits,
        validate_args=validate_args,
    )
def HurdleNegativeBinomial2(
    gate: ArrayLike,
    mean: ArrayLike,
    concentration: ArrayLike,
    *,
    validate_args: Optional[bool] = None,
) -> HurdleProbs:
    r"""A hurdle Negative Binomial distribution (NB2 / mean-dispersion
    parameterization).

    This is a two-part model. Structural zeros are produced by a Bernoulli
    "hurdle" with probability :math:`g`. Positive counts follow a
    zero-truncated :math:`\mathrm{NegativeBinomial2}(\mu, \alpha)`. The hurdle
    and the magnitude (given a positive count) are conditionally independent;
    see :class:`HurdleProbs` for the full mechanism and assumptions. Compared
    to a Hurdle Poisson, NB2 accommodates over-dispersed count data (variance
    greater than the mean).

    The probability mass function is

    .. math::
        P(X = 0) = g, \qquad
        P(X = k) = (1 - g) \, \frac{\mathrm{NB2}(k\mid\mu, \alpha)}
        {1 - \mathrm{NB2}(0\mid\mu, \alpha)} \;\text{for } k \geq 1,

    where :math:`\mathrm{NB2}(\cdot\mid\mu, \alpha)` is the PMF of a Negative
    Binomial distribution with mean :math:`\mu` and concentration (dispersion)
    :math:`\alpha`.

    :param ArrayLike gate: probability of a structural zero, :math:`g \in [0, 1]`.
    :param ArrayLike mean: mean :math:`\mu > 0` of the underlying NegativeBinomial2.
    :param ArrayLike concentration: concentration :math:`\alpha > 0`.

    **References:**

    1. Mullahy, J. (1986). Specification and testing of some modified count
       data models. *Journal of Econometrics*, 33(3), 341-365.
    2. Cragg, J. G. (1971). Some Statistical Models for Limited Dependent
       Variables with Application to the Demand for Durable Goods.
       *Econometrica*, 39(5), 829-844.
    """
    positive_part = NegativeBinomial2(mean, concentration, validate_args=validate_args)
    return HurdleProbs(positive_part, gate, validate_args=validate_args)