# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
# The implementation largely follows the design in PyTorch's `torch.distributions`
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import numpy as np
import jax
from jax import lax
from jax.nn import softmax, softplus
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import expit, gammaincc, gammaln, logsumexp, xlog1py, xlogy
from numpyro.distributions import constraints, transforms
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
binary_cross_entropy_with_logits,
binomial,
categorical,
clamp_probs,
is_prng_key,
lazy_property,
multinomial,
promote_shapes,
validate_sample,
)
from numpyro.util import not_jax_tracer
def _to_probs_bernoulli(logits):
return expit(logits)
def _to_logits_bernoulli(probs):
ps_clamped = clamp_probs(probs)
return jnp.log(ps_clamped) - jnp.log1p(-ps_clamped)
def _to_probs_multinom(logits):
return softmax(logits, axis=-1)
def _to_logits_multinom(probs):
minval = jnp.finfo(jnp.result_type(probs)).min
return jnp.clip(jnp.log(probs), a_min=minval)
[docs]class BernoulliProbs(Distribution):
arg_constraints = {"probs": constraints.unit_interval}
support = constraints.boolean
has_enumerate_support = True
def __init__(self, probs, *, validate_args=None):
self.probs = probs
super(BernoulliProbs, self).__init__(
batch_shape=jnp.shape(self.probs), validate_args=validate_args
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
samples = random.bernoulli(
key, self.probs, shape=sample_shape + self.batch_shape
)
return samples.astype(jnp.result_type(samples, int))
@validate_sample
def log_prob(self, value):
ps_clamped = clamp_probs(self.probs)
return xlogy(value, ps_clamped) + xlog1py(1 - value, -ps_clamped)
[docs] @lazy_property
def logits(self):
return _to_logits_bernoulli(self.probs)
@property
def mean(self):
return self.probs
@property
def variance(self):
return self.probs * (1 - self.probs)
[docs] def enumerate_support(self, expand=True):
values = jnp.arange(2).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
[docs]class BernoulliLogits(Distribution):
arg_constraints = {"logits": constraints.real}
support = constraints.boolean
has_enumerate_support = True
def __init__(self, logits=None, *, validate_args=None):
self.logits = logits
super(BernoulliLogits, self).__init__(
batch_shape=jnp.shape(self.logits), validate_args=validate_args
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
samples = random.bernoulli(
key, self.probs, shape=sample_shape + self.batch_shape
)
return samples.astype(jnp.result_type(samples, int))
@validate_sample
def log_prob(self, value):
return -binary_cross_entropy_with_logits(self.logits, value)
[docs] @lazy_property
def probs(self):
return _to_probs_bernoulli(self.logits)
@property
def mean(self):
return self.probs
@property
def variance(self):
return self.probs * (1 - self.probs)
[docs] def enumerate_support(self, expand=True):
values = jnp.arange(2).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
[docs]def Bernoulli(probs=None, logits=None, *, validate_args=None):
if probs is not None:
return BernoulliProbs(probs, validate_args=validate_args)
elif logits is not None:
return BernoulliLogits(logits, validate_args=validate_args)
else:
raise ValueError("One of `probs` or `logits` must be specified.")
[docs]class BinomialProbs(Distribution):
arg_constraints = {
"probs": constraints.unit_interval,
"total_count": constraints.nonnegative_integer,
}
has_enumerate_support = True
def __init__(self, probs, total_count=1, *, validate_args=None):
self.probs, self.total_count = promote_shapes(probs, total_count)
batch_shape = lax.broadcast_shapes(jnp.shape(probs), jnp.shape(total_count))
super(BinomialProbs, self).__init__(
batch_shape=batch_shape, validate_args=validate_args
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return binomial(
key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape
)
@validate_sample
def log_prob(self, value):
log_factorial_n = gammaln(self.total_count + 1)
log_factorial_k = gammaln(value + 1)
log_factorial_nmk = gammaln(self.total_count - value + 1)
probs = clamp_probs(self.probs)
return (
log_factorial_n
- log_factorial_k
- log_factorial_nmk
+ xlogy(value, probs)
+ xlog1py(self.total_count - value, -probs)
)
[docs] @lazy_property
def logits(self):
return _to_logits_bernoulli(self.probs)
@property
def mean(self):
return jnp.broadcast_to(self.total_count * self.probs, self.batch_shape)
@property
def variance(self):
return jnp.broadcast_to(
self.total_count * self.probs * (1 - self.probs), self.batch_shape
)
@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self.total_count)
[docs] def enumerate_support(self, expand=True):
if not_jax_tracer(self.total_count):
total_count = np.amax(self.total_count)
# NB: the error can't be raised if inhomogeneous issue happens when tracing
if np.amin(self.total_count) != total_count:
raise NotImplementedError(
"Inhomogeneous total count not supported" " by `enumerate_support`."
)
else:
total_count = jnp.amax(self.total_count)
values = jnp.arange(total_count + 1).reshape(
(-1,) + (1,) * len(self.batch_shape)
)
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
[docs]class BinomialLogits(Distribution):
arg_constraints = {
"logits": constraints.real,
"total_count": constraints.nonnegative_integer,
}
has_enumerate_support = True
enumerate_support = BinomialProbs.enumerate_support
def __init__(self, logits, total_count=1, *, validate_args=None):
self.logits, self.total_count = promote_shapes(logits, total_count)
batch_shape = lax.broadcast_shapes(jnp.shape(logits), jnp.shape(total_count))
super(BinomialLogits, self).__init__(
batch_shape=batch_shape, validate_args=validate_args
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return binomial(
key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape
)
@validate_sample
def log_prob(self, value):
log_factorial_n = gammaln(self.total_count + 1)
log_factorial_k = gammaln(value + 1)
log_factorial_nmk = gammaln(self.total_count - value + 1)
normalize_term = (
self.total_count * jnp.clip(self.logits, 0)
+ xlog1py(self.total_count, jnp.exp(-jnp.abs(self.logits)))
- log_factorial_n
)
return (
value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
)
[docs] @lazy_property
def probs(self):
return _to_probs_bernoulli(self.logits)
@property
def mean(self):
return jnp.broadcast_to(self.total_count * self.probs, self.batch_shape)
@property
def variance(self):
return jnp.broadcast_to(
self.total_count * self.probs * (1 - self.probs), self.batch_shape
)
@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self.total_count)
[docs]def Binomial(total_count=1, probs=None, logits=None, *, validate_args=None):
if probs is not None:
return BinomialProbs(probs, total_count, validate_args=validate_args)
elif logits is not None:
return BinomialLogits(logits, total_count, validate_args=validate_args)
else:
raise ValueError("One of `probs` or `logits` must be specified.")
[docs]class CategoricalProbs(Distribution):
arg_constraints = {"probs": constraints.simplex}
has_enumerate_support = True
def __init__(self, probs, *, validate_args=None):
if jnp.ndim(probs) < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
self.probs = probs
super(CategoricalProbs, self).__init__(
batch_shape=jnp.shape(self.probs)[:-1], validate_args=validate_args
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return categorical(key, self.probs, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
value = jnp.expand_dims(value, axis=-1)
value = jnp.broadcast_to(value, batch_shape + (1,))
logits = self.logits
log_pmf = jnp.broadcast_to(logits, batch_shape + jnp.shape(logits)[-1:])
return jnp.take_along_axis(log_pmf, value, axis=-1)[..., 0]
[docs] @lazy_property
def logits(self):
return _to_logits_multinom(self.probs)
@property
def mean(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs))
@property
def variance(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs))
@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, jnp.shape(self.probs)[-1] - 1)
[docs] def enumerate_support(self, expand=True):
values = jnp.arange(self.probs.shape[-1]).reshape(
(-1,) + (1,) * len(self.batch_shape)
)
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
[docs]class CategoricalLogits(Distribution):
arg_constraints = {"logits": constraints.real_vector}
has_enumerate_support = True
def __init__(self, logits, *, validate_args=None):
if jnp.ndim(logits) < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
self.logits = logits
super(CategoricalLogits, self).__init__(
batch_shape=jnp.shape(logits)[:-1], validate_args=validate_args
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return random.categorical(
key, self.logits, shape=sample_shape + self.batch_shape
)
@validate_sample
def log_prob(self, value):
batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
value = jnp.expand_dims(value, -1)
value = jnp.broadcast_to(value, batch_shape + (1,))
log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
log_pmf = jnp.broadcast_to(log_pmf, batch_shape + jnp.shape(log_pmf)[-1:])
return jnp.take_along_axis(log_pmf, value, -1)[..., 0]
[docs] @lazy_property
def probs(self):
return _to_probs_multinom(self.logits)
@property
def mean(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits))
@property
def variance(self):
return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits))
@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, jnp.shape(self.logits)[-1] - 1)
[docs] def enumerate_support(self, expand=True):
values = jnp.arange(self.logits.shape[-1]).reshape(
(-1,) + (1,) * len(self.batch_shape)
)
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
[docs]def Categorical(probs=None, logits=None, *, validate_args=None):
if probs is not None:
return CategoricalProbs(probs, validate_args=validate_args)
elif logits is not None:
return CategoricalLogits(logits, validate_args=validate_args)
else:
raise ValueError("One of `probs` or `logits` must be specified.")
[docs]class OrderedLogistic(CategoricalProbs):
"""
A categorical distribution with ordered outcomes.
**References:**
1. *Stan Functions Reference, v2.20 section 12.6*,
Stan Development Team
:param numpy.ndarray predictor: prediction in real domain; typically this is output
of a linear model.
:param numpy.ndarray cutpoints: positions in real domain to separate categories.
"""
arg_constraints = {
"predictor": constraints.real,
"cutpoints": constraints.ordered_vector,
}
def __init__(self, predictor, cutpoints, *, validate_args=None):
if jnp.ndim(predictor) == 0:
(predictor,) = promote_shapes(predictor, shape=(1,))
else:
predictor = predictor[..., None]
predictor, self.cutpoints = promote_shapes(predictor, cutpoints)
self.predictor = predictor[..., 0]
probs = transforms.SimplexToOrderedTransform(self.predictor).inv(self.cutpoints)
super(OrderedLogistic, self).__init__(probs, validate_args=validate_args)
[docs] @staticmethod
def infer_shapes(predictor, cutpoints):
batch_shape = lax.broadcast_shapes(predictor, cutpoints[:-1])
event_shape = ()
return batch_shape, event_shape
[docs]class MultinomialProbs(Distribution):
arg_constraints = {
"probs": constraints.simplex,
"total_count": constraints.nonnegative_integer,
}
def __init__(self, probs, total_count=1, *, validate_args=None):
if jnp.ndim(probs) < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
batch_shape, event_shape = self.infer_shapes(
jnp.shape(probs), jnp.shape(total_count)
)
self.probs = promote_shapes(probs, shape=batch_shape + jnp.shape(probs)[-1:])[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
super(MultinomialProbs, self).__init__(
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args,
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return multinomial(
key, self.probs, self.total_count, shape=sample_shape + self.batch_shape
)
@validate_sample
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
return gammaln(self.total_count + 1) + jnp.sum(
xlogy(value, self.probs) - gammaln(value + 1), axis=-1
)
[docs] @lazy_property
def logits(self):
return _to_logits_multinom(self.probs)
@property
def mean(self):
return self.probs * jnp.expand_dims(self.total_count, -1)
@property
def variance(self):
return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
def support(self):
return constraints.multinomial(self.total_count)
[docs] @staticmethod
def infer_shapes(probs, total_count):
batch_shape = lax.broadcast_shapes(probs[:-1], total_count)
event_shape = probs[-1:]
return batch_shape, event_shape
[docs]class MultinomialLogits(Distribution):
arg_constraints = {
"logits": constraints.real_vector,
"total_count": constraints.nonnegative_integer,
}
def __init__(self, logits, total_count=1, *, validate_args=None):
if jnp.ndim(logits) < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
batch_shape, event_shape = self.infer_shapes(
jnp.shape(logits), jnp.shape(total_count)
)
self.logits = promote_shapes(
logits, shape=batch_shape + jnp.shape(logits)[-1:]
)[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
super(MultinomialLogits, self).__init__(
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=validate_args,
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return multinomial(
key, self.probs, self.total_count, shape=sample_shape + self.batch_shape
)
@validate_sample
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
normalize_term = self.total_count * logsumexp(self.logits, axis=-1) - gammaln(
self.total_count + 1
)
return (
jnp.sum(value * self.logits - gammaln(value + 1), axis=-1) - normalize_term
)
[docs] @lazy_property
def probs(self):
return _to_probs_multinom(self.logits)
@property
def mean(self):
return jnp.expand_dims(self.total_count, -1) * self.probs
@property
def variance(self):
return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
def support(self):
return constraints.multinomial(self.total_count)
[docs] @staticmethod
def infer_shapes(logits, total_count):
batch_shape = lax.broadcast_shapes(logits[:-1], total_count)
event_shape = logits[-1:]
return batch_shape, event_shape
[docs]def Multinomial(total_count=1, probs=None, logits=None, *, validate_args=None):
if probs is not None:
return MultinomialProbs(probs, total_count, validate_args=validate_args)
elif logits is not None:
return MultinomialLogits(logits, total_count, validate_args=validate_args)
else:
raise ValueError("One of `probs` or `logits` must be specified.")
[docs]class Poisson(Distribution):
r"""
Creates a Poisson distribution parameterized by rate, the rate parameter.
Samples are nonnegative integers, with a pmf given by
.. math::
\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
:param numpy.ndarray rate: The rate parameter
:param bool is_sparse: Whether to assume value is mostly zero when computing
:meth:`log_prob`, which can speed up computation when data is sparse.
"""
arg_constraints = {"rate": constraints.positive}
support = constraints.nonnegative_integer
def __init__(self, rate, *, is_sparse=False, validate_args=None):
self.rate = rate
self.is_sparse = is_sparse
super(Poisson, self).__init__(jnp.shape(rate), validate_args=validate_args)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
return random.poisson(key, self.rate, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
if (
self.is_sparse
and not isinstance(value, jax.core.Tracer)
and jnp.size(value) > 1
):
shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value))
rate = jnp.broadcast_to(self.rate, shape).reshape(-1)
nonzero = np.broadcast_to(jax.device_get(value) > 0, shape).reshape(-1)
value = jnp.broadcast_to(value, shape).reshape(-1)
sparse_value = value[nonzero]
sparse_rate = rate[nonzero]
return (
jnp.asarray(-rate, dtype=jnp.result_type(float))
.at[nonzero]
.add(
jnp.log(sparse_rate) * sparse_value - gammaln(sparse_value + 1),
)
.reshape(shape)
)
return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate
@property
def mean(self):
return self.rate
@property
def variance(self):
return self.rate
[docs] def cdf(self, value):
k = jnp.floor(value) + 1
return gammaincc(k, self.rate)
class ZeroInflatedProbs(Distribution):
arg_constraints = {"gate": constraints.unit_interval}
def __init__(self, base_dist, gate, *, validate_args=None):
batch_shape = lax.broadcast_shapes(jnp.shape(gate), base_dist.batch_shape)
(self.gate,) = promote_shapes(gate, shape=batch_shape)
assert base_dist.support.is_discrete
if base_dist.event_shape:
raise ValueError(
"ZeroInflatedProbs expected empty base_dist.event_shape but got {}".format(
base_dist.event_shape
)
)
# XXX: we might need to promote parameters of base_dist but let's keep
# this simplified for now
self.base_dist = base_dist.expand(batch_shape)
super(ZeroInflatedProbs, self).__init__(
batch_shape, validate_args=validate_args
)
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
key_bern, key_base = random.split(key)
shape = sample_shape + self.batch_shape
mask = random.bernoulli(key_bern, self.gate, shape)
samples = self.base_dist(rng_key=key_base, sample_shape=sample_shape)
return jnp.where(mask, 0, samples)
@validate_sample
def log_prob(self, value):
log_prob = jnp.log1p(-self.gate) + self.base_dist.log_prob(value)
return jnp.where(value == 0, jnp.log(self.gate + jnp.exp(log_prob)), log_prob)
@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return self.base_dist.support
@lazy_property
def mean(self):
return (1 - self.gate) * self.base_dist.mean
@lazy_property
def variance(self):
return (1 - self.gate) * (
self.base_dist.mean**2 + self.base_dist.variance
) - self.mean**2
@property
def has_enumerate_support(self):
return self.base_dist.has_enumerate_support
def enumerate_support(self, expand=True):
return self.base_dist.enumerate_support(expand=expand)
class ZeroInflatedLogits(ZeroInflatedProbs):
arg_constraints = {"gate_logits": constraints.real}
def __init__(self, base_dist, gate_logits, *, validate_args=None):
gate = _to_probs_bernoulli(gate_logits)
batch_shape = lax.broadcast_shapes(jnp.shape(gate), base_dist.batch_shape)
(self.gate_logits,) = promote_shapes(gate_logits, shape=batch_shape)
super().__init__(base_dist, gate, validate_args=validate_args)
@validate_sample
def log_prob(self, value):
log_prob_minus_log_gate = -self.gate_logits + self.base_dist.log_prob(value)
log_gate = -softplus(-self.gate_logits)
log_prob = log_prob_minus_log_gate + log_gate
zero_log_prob = softplus(log_prob_minus_log_gate) + log_gate
return jnp.where(value == 0, zero_log_prob, log_prob)
[docs]def ZeroInflatedDistribution(
base_dist, *, gate=None, gate_logits=None, validate_args=None
):
"""
Generic Zero Inflated distribution.
:param Distribution base_dist: the base distribution.
:param numpy.ndarray gate: probability of extra zeros given via a Bernoulli distribution.
:param numpy.ndarray gate_logits: logits of extra zeros given via a Bernoulli distribution.
"""
if (gate is None) == (gate_logits is None):
raise ValueError(
"Either `gate` or `gate_logits` must be specified, but not both."
)
if gate is not None:
return ZeroInflatedProbs(base_dist, gate, validate_args=validate_args)
else:
return ZeroInflatedLogits(base_dist, gate_logits, validate_args=validate_args)
[docs]class ZeroInflatedPoisson(ZeroInflatedProbs):
"""
A Zero Inflated Poisson distribution.
:param numpy.ndarray gate: probability of extra zeros.
:param numpy.ndarray rate: rate of Poisson distribution.
"""
arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive}
support = constraints.nonnegative_integer
# TODO: resolve inconsistent parameter order w.r.t. Pyro
# and support `gate_logits` argument
def __init__(self, gate, rate=1.0, *, validate_args=None):
_, self.rate = promote_shapes(gate, rate)
super().__init__(Poisson(self.rate), gate, validate_args=validate_args)
[docs]class GeometricProbs(Distribution):
arg_constraints = {"probs": constraints.unit_interval}
support = constraints.nonnegative_integer
def __init__(self, probs, *, validate_args=None):
self.probs = probs
super(GeometricProbs, self).__init__(
batch_shape=jnp.shape(self.probs), validate_args=validate_args
)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
probs = self.probs
dtype = jnp.result_type(probs)
shape = sample_shape + self.batch_shape
u = random.uniform(key, shape, dtype)
return jnp.floor(jnp.log1p(-u) / jnp.log1p(-probs))
@validate_sample
def log_prob(self, value):
probs = jnp.where((self.probs == 1) & (value == 0), 0, self.probs)
return value * jnp.log1p(-probs) + jnp.log(probs)
[docs] @lazy_property
def logits(self):
return _to_logits_bernoulli(self.probs)
@property
def mean(self):
return 1.0 / self.probs - 1.0
@property
def variance(self):
return (1.0 / self.probs - 1.0) / self.probs
[docs]class GeometricLogits(Distribution):
arg_constraints = {"logits": constraints.real}
support = constraints.nonnegative_integer
def __init__(self, logits, *, validate_args=None):
self.logits = logits
super(GeometricLogits, self).__init__(
batch_shape=jnp.shape(self.logits), validate_args=validate_args
)
[docs] @lazy_property
def probs(self):
return _to_probs_bernoulli(self.logits)
[docs] def sample(self, key, sample_shape=()):
assert is_prng_key(key)
logits = self.logits
dtype = jnp.result_type(logits)
shape = sample_shape + self.batch_shape
u = random.uniform(key, shape, dtype)
return jnp.floor(jnp.log1p(-u) / -softplus(logits))
@validate_sample
def log_prob(self, value):
return (-value - 1) * softplus(self.logits) + self.logits
@property
def mean(self):
return 1.0 / self.probs - 1.0
@property
def variance(self):
return (1.0 / self.probs - 1.0) / self.probs
[docs]def Geometric(probs=None, logits=None, *, validate_args=None):
if probs is not None:
return GeometricProbs(probs, validate_args=validate_args)
elif logits is not None:
return GeometricLogits(logits, validate_args=validate_args)
else:
raise ValueError("One of `probs` or `logits` must be specified.")