# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import jax
from jax import lax, nn, random
import jax.numpy as jnp
from jax.scipy.special import betainc, betaln, gammaln
from jax.typing import ArrayLike
from numpyro._typing import ConstraintT
from numpyro.distributions import constraints
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
from numpyro.distributions.discrete import (
BinomialProbs,
MultinomialProbs,
Poisson,
ZeroInflatedDistribution,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample
from numpyro.util import is_prng_key
def _log_beta_1(alpha, value):
    """Return ``log(Gamma(1 + value) * Gamma(alpha) / Gamma(value + alpha))``.

    The result is evaluated in log space with ``gammaln`` and broadcasts over
    ``alpha`` and ``value``.

    :param alpha: first argument of the log-gamma ratio.
    :param value: second argument of the log-gamma ratio (typically a count).
    :return: the elementwise value of
        ``gammaln(1 + value) + gammaln(alpha) - gammaln(value + alpha)``.
    """
    # XXX: support sparse `value`
    return gammaln(1 + value) + gammaln(alpha) - gammaln(value + alpha)
class BetaBinomial(Distribution):
    r"""
    Compound distribution comprising of a beta-binomial pair. The probability of
    success (``probs`` for the :class:`~numpyro.distributions.Binomial` distribution)
    is unknown and randomly drawn from a :class:`~numpyro.distributions.Beta` distribution
    prior to a certain number of Bernoulli trials given by ``total_count``.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "total_count": constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    # Support enumeration is borrowed from the binomial distribution.
    enumerate_support = BinomialProbs.enumerate_support
    pytree_data_fields = ("concentration1", "concentration0", "total_count", "_beta")

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        total_count: int = 1,
        *,
        validate_args: Optional[bool] = None,
    ):
        """Store the parameters and build the internal Beta distribution.

        The batch shape is the broadcast of the shapes of all three parameters.
        """
        self.concentration1, self.concentration0, self.total_count = promote_shapes(
            concentration1, concentration0, total_count
        )
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration1), jnp.shape(concentration0), jnp.shape(total_count)
        )
        # The Beta is built from fully broadcast concentrations so that it
        # carries the same batch shape as this distribution.
        concentration1 = jnp.broadcast_to(concentration1, batch_shape)
        concentration0 = jnp.broadcast_to(concentration0, batch_shape)
        self._beta = Beta(concentration1, concentration0)
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(
        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Draw success probabilities from the Beta, then counts from a Binomial.

        :param key: PRNG key; it is split into one key per sampling stage.
        :param sample_shape: shape prefix passed to the Beta draw.
        :return: samples from ``BinomialProbs`` using the drawn probabilities.
        """
        assert is_prng_key(key)
        key_beta, key_binom = random.split(key)
        probs = self._beta.sample(key_beta, sample_shape)
        return BinomialProbs(total_count=self.total_count, probs=probs).sample(
            key_binom
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Return the log probability mass of ``value``.

        With ``n = total_count``, ``a = concentration1`` and
        ``b = concentration0`` this evaluates
        ``log C(n, value) + betaln(value + a, n - value + b) - betaln(b, a)``.
        """
        return (
            # -_log_beta_1(n - k + 1, k) is the log binomial coefficient
            # gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1).
            -_log_beta_1(self.total_count - value + 1, value)
            + betaln(
                value + self.concentration1,
                self.total_count - value + self.concentration0,
            )
            - betaln(self.concentration0, self.concentration1)
        )

    @property
    def mean(self) -> ArrayLike:
        """Mean of the Beta distribution multiplied by ``total_count``."""
        return self._beta.mean * self.total_count

    @property
    def variance(self) -> ArrayLike:
        """Beta variance times ``total_count * (concentration0 + concentration1 + total_count)``."""
        return (
            self._beta.variance
            * self.total_count
            * (self.concentration0 + self.concentration1 + self.total_count)
        )

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self) -> ConstraintT:
        """Integer interval from 0 to ``total_count``."""
        return constraints.integer_interval(0, self.total_count)
class DirichletMultinomial(Distribution):
    r"""
    Compound distribution comprising of a dirichlet-multinomial pair. The probability of
    classes (``probs`` for the :class:`~numpyro.distributions.Multinomial` distribution)
    is unknown and randomly drawn from a :class:`~numpyro.distributions.Dirichlet`
    distribution prior to a certain number of Categorical trials given by
    ``total_count``.

    :param numpy.ndarray concentration: concentration parameter (alpha) for the
        Dirichlet distribution.
    :param numpy.ndarray total_count: number of Categorical trials.
    :param int total_count_max: the maximum number of trials, i.e. `max(total_count)`
    """

    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1),
        "total_count": constraints.nonnegative_integer,
    }
    pytree_data_fields = ("concentration", "_dirichlet")
    pytree_aux_fields = ("total_count", "total_count_max")

    def __init__(
        self,
        concentration: ArrayLike,
        total_count: int = 1,
        *,
        total_count_max: Optional[int] = None,
        validate_args: Optional[bool] = None,
    ):
        """Store the parameters and build the internal Dirichlet distribution.

        :raises ValueError: if ``concentration`` has fewer than one dimension.
        """
        if jnp.ndim(concentration) < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        # The last axis of `concentration` is the event axis; the leading axes
        # broadcast against `total_count` to give the batch shape.
        batch_shape = lax.broadcast_shapes(
            jnp.shape(concentration)[:-1], jnp.shape(total_count)
        )
        concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
        (self.concentration,) = promote_shapes(concentration, shape=concentration_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        self.total_count_max = total_count_max
        concentration = jnp.broadcast_to(self.concentration, concentration_shape)
        self._dirichlet = Dirichlet(concentration)
        super().__init__(
            self._dirichlet.batch_shape,
            self._dirichlet.event_shape,
            validate_args=validate_args,
        )

    def sample(
        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Draw class probabilities from the Dirichlet, then counts from a Multinomial.

        :param key: PRNG key; it is split into one key per sampling stage.
        :param sample_shape: shape prefix passed to the Dirichlet draw.
        :return: samples from ``MultinomialProbs`` using the drawn probabilities.
        """
        assert is_prng_key(key)
        key_dirichlet, key_multinom = random.split(key)
        probs = self._dirichlet.sample(key_dirichlet, sample_shape)
        return MultinomialProbs(
            total_count=self.total_count,
            probs=probs,
            total_count_max=self.total_count_max,
        ).sample(key_multinom)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Return the log probability mass of the count vector ``value``.

        Computed as ``_log_beta_1`` of the summed concentration and summed
        counts, minus the sum over the last axis of the per-class
        ``_log_beta_1`` terms.
        """
        alpha = self.concentration
        return _log_beta_1(alpha.sum(-1), value.sum(-1)) - _log_beta_1(
            alpha, value
        ).sum(-1)

    @property
    def mean(self) -> ArrayLike:
        """Dirichlet mean scaled by ``total_count`` (broadcast over the event axis)."""
        return self._dirichlet.mean * jnp.expand_dims(self.total_count, -1)

    @property
    def variance(self) -> ArrayLike:
        """Per-class variance.

        With ``p = alpha / sum(alpha)`` and ``n = total_count`` this is
        ``n * p * (1 - p) * (n + sum(alpha)) / (1 + sum(alpha))``.
        """
        n = jnp.expand_dims(self.total_count, -1)
        alpha = self.concentration
        alpha_sum = self.concentration.sum(-1, keepdims=True)
        alpha_ratio = alpha / alpha_sum
        return n * alpha_ratio * (1 - alpha_ratio) * (n + alpha_sum) / (1 + alpha_sum)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self) -> ConstraintT:
        """Multinomial constraint parameterised by ``total_count``."""
        return constraints.multinomial(self.total_count)

    @staticmethod
    def infer_shapes(
        concentration: tuple[int, ...], total_count: tuple[int, ...] = ()
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Infer ``(batch_shape, event_shape)`` from parameter shapes.

        Both arguments are shapes, not arrays: ``concentration`` is sliced and
        passed to ``lax.broadcast_shapes`` together with ``total_count``.

        :param concentration: shape of the ``concentration`` parameter.
        :param total_count: shape of the ``total_count`` parameter.
        :return: the broadcast of ``concentration[:-1]`` with ``total_count``
            as batch shape, and ``concentration[-1:]`` as event shape.
        """
        batch_shape = lax.broadcast_shapes(concentration[:-1], total_count)
        event_shape = concentration[-1:]
        return batch_shape, event_shape
class GammaPoisson(Distribution):
    r"""
    Compound distribution comprising of a gamma-poisson pair, also referred to as
    a gamma-poisson mixture. The ``rate`` parameter for the
    :class:`~numpyro.distributions.Poisson` distribution is unknown and randomly
    drawn from a :class:`~numpyro.distributions.Gamma` distribution.

    :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
    :param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.nonnegative_integer
    pytree_data_fields = ("concentration", "rate", "_gamma")

    def __init__(
        self,
        concentration: ArrayLike,
        rate: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        """Store the parameters and build the internal Gamma distribution.

        The batch shape is taken from the internal Gamma distribution.
        """
        self.concentration, self.rate = promote_shapes(concentration, rate)
        self._gamma = Gamma(concentration, rate)
        super().__init__(self._gamma.batch_shape, validate_args=validate_args)

    def sample(
        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Draw Poisson rates from the Gamma, then counts from a Poisson.

        :param key: PRNG key; it is split into one key per sampling stage.
        :param sample_shape: shape prefix passed to the Gamma draw.
        :return: samples from ``Poisson`` using the drawn rates.
        """
        assert is_prng_key(key)
        key_gamma, key_poisson = random.split(key)
        rate = self._gamma.sample(key_gamma, sample_shape)
        return Poisson(rate).sample(key_poisson)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Return the log probability mass of ``value``.

        With ``a = concentration`` and ``b = rate`` this evaluates
        ``-betaln(a, value + 1) - log(a + value) + a * log(b)
        - (a + value) * log1p(b)``.
        """
        post_value = self.concentration + value
        return (
            -betaln(self.concentration, value + 1)
            - jnp.log(post_value)
            + self.concentration * jnp.log(self.rate)
            - post_value * jnp.log1p(self.rate)
        )

    @property
    def mean(self) -> ArrayLike:
        """``concentration / rate``."""
        return self.concentration / self.rate

    @property
    def variance(self) -> ArrayLike:
        """``concentration / rate**2 * (1 + rate)``."""
        return self.concentration / jnp.square(self.rate) * (1 + self.rate)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        """Return ``betainc(concentration, value + 1, rate / (rate + 1))``.

        NOTE(review): no special handling is applied for ``value < 0`` or for
        non-integer ``value``; confirm callers only pass points in the support.
        """
        return betainc(self.concentration, value + 1.0, self.rate / (self.rate + 1.0))
def NegativeBinomial(
    total_count: int,
    probs: Optional[ArrayLike] = None,
    logits: Optional[ArrayLike] = None,
    *,
    validate_args: Optional[bool] = None,
):
    """Build a negative binomial distribution from ``probs`` or ``logits``.

    ``probs`` takes precedence: when it is given, ``logits`` is ignored.

    :param total_count: forwarded as the first argument of the chosen class.
    :param probs: if not ``None``, a ``NegativeBinomialProbs`` is returned.
    :param logits: used only when ``probs`` is ``None``; a
        ``NegativeBinomialLogits`` is returned.
    :param validate_args: forwarded to the chosen class.
    :raises ValueError: if both ``probs`` and ``logits`` are ``None``.
    """
    if probs is not None:
        return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
    elif logits is not None:
        return NegativeBinomialLogits(total_count, logits, validate_args=validate_args)
    else:
        raise ValueError("One of `probs` or `logits` must be specified.")
class NegativeBinomialProbs(GammaPoisson):
    """Negative binomial parameterised by ``total_count`` and ``probs``.

    Implemented as a :class:`GammaPoisson` with
    ``concentration = total_count`` and ``rate = 1 / probs - 1``.
    """

    arg_constraints = {
        "total_count": constraints.positive,
        "probs": constraints.unit_interval,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: int,
        probs: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        """Store ``total_count`` and ``probs`` and initialise the GammaPoisson base."""
        self.total_count, self.probs = promote_shapes(total_count, probs)
        concentration = total_count
        # Map the probability onto the Gamma rate used by the base class.
        rate = 1.0 / probs - 1.0
        super().__init__(concentration, rate, validate_args=validate_args)
class NegativeBinomialLogits(GammaPoisson):
    """Negative binomial parameterised by ``total_count`` and ``logits``.

    Implemented as a :class:`GammaPoisson` with
    ``concentration = total_count`` and ``rate = exp(-logits)``.
    """

    arg_constraints = {
        "total_count": constraints.positive,
        "logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: int,
        logits: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        """Store ``total_count`` and ``logits`` and initialise the GammaPoisson base."""
        self.total_count, self.logits = promote_shapes(total_count, logits)
        concentration = total_count
        rate = jnp.exp(-logits)
        super().__init__(concentration, rate, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Return the log probability mass of ``value``, computed from the logits.

        Overrides the base implementation with
        ``-(total_count * softplus(logits) + value * softplus(-logits)
        + _log_beta_1(total_count, value))``.
        """
        return -(
            self.total_count * nn.softplus(self.logits)
            + value * nn.softplus(-self.logits)
            + _log_beta_1(self.total_count, value)
        )
class NegativeBinomial2(GammaPoisson):
    """
    Another parameterization of GammaPoisson with `rate` is replaced by `mean`.

    The base class is initialised with ``rate = concentration / mean``.
    """

    arg_constraints = {
        "mean": constraints.positive,
        "concentration": constraints.positive,
    }
    support = constraints.nonnegative_integer
    pytree_data_fields = ("concentration",)

    def __init__(
        self,
        mean: ArrayLike,
        concentration: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        """Convert ``mean`` to a Gamma rate and initialise the GammaPoisson base."""
        rate = concentration / mean
        super().__init__(concentration, rate, validate_args=validate_args)
def ZeroInflatedNegativeBinomial2(
    mean: ArrayLike,
    concentration: ArrayLike,
    *,
    gate: Optional[ArrayLike] = None,
    gate_logits: Optional[ArrayLike] = None,
    validate_args: Optional[bool] = None,
):
    """Wrap a ``NegativeBinomial2`` in a ``ZeroInflatedDistribution``.

    :param mean: forwarded to ``NegativeBinomial2``.
    :param concentration: forwarded to ``NegativeBinomial2``.
    :param gate: forwarded to ``ZeroInflatedDistribution``.
    :param gate_logits: forwarded to ``ZeroInflatedDistribution``.
    :param validate_args: forwarded to both the base distribution and the
        zero-inflated wrapper.
    :return: the ``ZeroInflatedDistribution`` built around the base distribution.
    """
    return ZeroInflatedDistribution(
        NegativeBinomial2(mean, concentration, validate_args=validate_args),
        gate=gate,
        gate_logits=gate_logits,
        validate_args=validate_args,
    )