Source code for numpyro.distributions.discrete

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# The implementation largely follows the design in PyTorch's `torch.distributions`
#
# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU                      (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import warnings

import numpy as np

import jax
from jax import lax
from jax.nn import softmax, softplus
import jax.numpy as jnp
from jax.ops import index_add
import jax.random as random
from jax.scipy.special import expit, gammaincc, gammaln, logsumexp, xlog1py, xlogy

from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
    binary_cross_entropy_with_logits,
    binomial,
    categorical,
    clamp_probs,
    is_prng_key,
    lazy_property,
    multinomial,
    promote_shapes,
    validate_sample,
)
from numpyro.util import not_jax_tracer


def _to_probs_bernoulli(logits):
    """Convert Bernoulli ``logits`` to probabilities by applying ``expit``."""
    probs = expit(logits)
    return probs


def _to_logits_bernoulli(probs):
    """Convert Bernoulli ``probs`` to logits, i.e. ``log(p) - log1p(-p)``.

    The probabilities are first passed through ``clamp_probs`` and the
    clamped values are used for both logarithms.
    """
    clamped = clamp_probs(probs)
    log_success = jnp.log(clamped)
    log_failure = jnp.log1p(-clamped)
    return log_success - log_failure


def _to_probs_multinom(logits):
    """Convert categorical ``logits`` to probabilities via softmax on the last axis."""
    probs = softmax(logits, axis=-1)
    return probs


def _to_logits_multinom(probs):
    """Convert categorical ``probs`` to logits (element-wise ``log``).

    The result is bounded below by the most negative finite value of the
    dtype of ``probs`` so that zero probabilities map to a finite logit
    instead of ``-inf``.

    :param probs: array of probabilities.
    :return: array of the same shape holding the lower-bounded log values.
    """
    minval = jnp.finfo(jnp.result_type(probs)).min
    # ``jnp.maximum`` is used instead of ``jnp.clip(..., a_min=minval)``: the
    # ``a_min`` keyword was deprecated and removed in later JAX releases,
    # while a lower-bound-only clip is exactly an element-wise maximum.
    return jnp.maximum(jnp.log(probs), minval)


class BernoulliProbs(Distribution):
    """Bernoulli distribution parameterized by ``probs``.

    The batch shape is taken from the shape of ``probs``.
    """

    arg_constraints = {"probs": constraints.unit_interval}
    support = constraints.boolean
    has_enumerate_support = True

    def __init__(self, probs, validate_args=None):
        self.probs = probs
        super().__init__(
            batch_shape=jnp.shape(self.probs), validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        draws = random.bernoulli(key, self.probs, shape=out_shape)
        # Cast to the dtype obtained by promoting the draws with Python ``int``.
        int_dtype = jnp.result_type(draws, int)
        return draws.astype(int_dtype)

    @validate_sample
    def log_prob(self, value):
        """Return ``xlogy(value, probs) + xlog1py(1 - value, -probs)``."""
        success_term = xlogy(value, self.probs)
        failure_term = xlog1py(1 - value, -self.probs)
        return success_term + failure_term

    @lazy_property
    def logits(self):
        """Logits derived from ``probs`` with ``_to_logits_bernoulli``."""
        return _to_logits_bernoulli(self.probs)

    @property
    def mean(self):
        return self.probs

    @property
    def variance(self):
        return self.probs * (1 - self.probs)

    def enumerate_support(self, expand=True):
        """Return the values 0 and 1 along a new leading axis.

        With ``expand=True`` the result is broadcast over ``batch_shape``;
        otherwise the batch axes are kept as singleton dimensions.
        """
        singleton_batch = (1,) * len(self.batch_shape)
        values = jnp.arange(2).reshape((-1,) + singleton_batch)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
class BernoulliLogits(Distribution):
    """Bernoulli distribution parameterized by ``logits``.

    The batch shape is taken from the shape of ``logits``.
    """

    arg_constraints = {"logits": constraints.real}
    support = constraints.boolean
    has_enumerate_support = True

    def __init__(self, logits=None, validate_args=None):
        self.logits = logits
        super().__init__(
            batch_shape=jnp.shape(self.logits), validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        draws = random.bernoulli(key, self.probs, shape=out_shape)
        # Cast to the dtype obtained by promoting the draws with Python ``int``.
        int_dtype = jnp.result_type(draws, int)
        return draws.astype(int_dtype)

    @validate_sample
    def log_prob(self, value):
        """Return the negated binary cross entropy between ``logits`` and ``value``."""
        cross_entropy = binary_cross_entropy_with_logits(self.logits, value)
        return -cross_entropy

    @lazy_property
    def probs(self):
        """Probabilities derived from ``logits`` with ``_to_probs_bernoulli``."""
        return _to_probs_bernoulli(self.logits)

    @property
    def mean(self):
        return self.probs

    @property
    def variance(self):
        return self.probs * (1 - self.probs)

    def enumerate_support(self, expand=True):
        """Return the values 0 and 1 along a new leading axis.

        With ``expand=True`` the result is broadcast over ``batch_shape``;
        otherwise the batch axes are kept as singleton dimensions.
        """
        singleton_batch = (1,) * len(self.batch_shape)
        values = jnp.arange(2).reshape((-1,) + singleton_batch)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
def Bernoulli(probs=None, logits=None, validate_args=None):
    """Build a Bernoulli distribution from either ``probs`` or ``logits``.

    ``probs`` takes precedence when both are given.

    :raises ValueError: if neither ``probs`` nor ``logits`` is provided.
    """
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
        return BernoulliProbs(probs, validate_args=validate_args)
    return BernoulliLogits(logits, validate_args=validate_args)
class BinomialProbs(Distribution):
    """Binomial distribution parameterized by ``probs`` and ``total_count``.

    The batch shape is the broadcast of the shapes of both parameters.
    """

    arg_constraints = {
        "probs": constraints.unit_interval,
        "total_count": constraints.nonnegative_integer,
    }
    has_enumerate_support = True

    def __init__(self, probs, total_count=1, validate_args=None):
        self.probs, self.total_count = promote_shapes(probs, total_count)
        batch_shape = lax.broadcast_shapes(jnp.shape(probs), jnp.shape(total_count))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        return binomial(key, self.probs, n=self.total_count, shape=out_shape)

    @validate_sample
    def log_prob(self, value):
        """Return the log binomial coefficient plus the success/failure log terms."""
        n = self.total_count
        log_n_fact = gammaln(n + 1)
        log_k_fact = gammaln(value + 1)
        log_n_minus_k_fact = gammaln(n - value + 1)
        success_term = xlogy(value, self.probs)
        failure_term = xlog1py(n - value, -self.probs)
        return (
            log_n_fact - log_k_fact - log_n_minus_k_fact + success_term + failure_term
        )

    @lazy_property
    def logits(self):
        """Logits derived from ``probs`` with ``_to_logits_bernoulli``."""
        return _to_logits_bernoulli(self.probs)

    @property
    def mean(self):
        return jnp.broadcast_to(self.total_count * self.probs, self.batch_shape)

    @property
    def variance(self):
        spread = self.total_count * self.probs * (1 - self.probs)
        return jnp.broadcast_to(spread, self.batch_shape)

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        return constraints.integer_interval(0, self.total_count)

    def enumerate_support(self, expand=True):
        """Return ``0, ..., max(total_count)`` along a new leading axis.

        :raises NotImplementedError: if ``total_count`` is not a JAX tracer
            and its minimum differs from its maximum.
        """
        counts = self.total_count
        if not_jax_tracer(counts):
            upper = np.amax(counts)
            # NB: the error can't be raised if inhomogeneous issue happens when tracing
            if np.amin(counts) != upper:
                raise NotImplementedError(
                    "Inhomogeneous total count not supported by `enumerate_support`."
                )
        else:
            upper = jnp.amax(counts)
        singleton_batch = (1,) * len(self.batch_shape)
        values = jnp.arange(upper + 1).reshape((-1,) + singleton_batch)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
class BinomialLogits(Distribution):
    """Binomial distribution parameterized by ``logits`` and ``total_count``.

    The batch shape is the broadcast of the shapes of both parameters.
    """

    arg_constraints = {
        "logits": constraints.real,
        "total_count": constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    # Reuse the implementation defined on the probs-parameterized class.
    enumerate_support = BinomialProbs.enumerate_support

    def __init__(self, logits, total_count=1, validate_args=None):
        self.logits, self.total_count = promote_shapes(logits, total_count)
        batch_shape = lax.broadcast_shapes(jnp.shape(logits), jnp.shape(total_count))
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        return binomial(key, self.probs, n=self.total_count, shape=out_shape)

    @validate_sample
    def log_prob(self, value):
        """Return the binomial log-probability computed directly from ``logits``."""
        n = self.total_count
        log_n_fact = gammaln(n + 1)
        log_k_fact = gammaln(value + 1)
        log_n_minus_k_fact = gammaln(n - value + 1)
        # Written with clip(logits, 0) and exp(-|logits|) so that the
        # exponential is only ever taken of a non-positive number.
        normalizer = (
            n * jnp.clip(self.logits, 0)
            + xlog1py(n, jnp.exp(-jnp.abs(self.logits)))
            - log_n_fact
        )
        return value * self.logits - log_k_fact - log_n_minus_k_fact - normalizer

    @lazy_property
    def probs(self):
        """Probabilities derived from ``logits`` with ``_to_probs_bernoulli``."""
        return _to_probs_bernoulli(self.logits)

    @property
    def mean(self):
        return jnp.broadcast_to(self.total_count * self.probs, self.batch_shape)

    @property
    def variance(self):
        spread = self.total_count * self.probs * (1 - self.probs)
        return jnp.broadcast_to(spread, self.batch_shape)

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        return constraints.integer_interval(0, self.total_count)
def Binomial(total_count=1, probs=None, logits=None, validate_args=None):
    """Build a Binomial distribution from either ``probs`` or ``logits``.

    ``probs`` takes precedence when both are given.

    :raises ValueError: if neither ``probs`` nor ``logits`` is provided.
    """
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
        return BinomialProbs(probs, total_count, validate_args=validate_args)
    return BinomialLogits(logits, total_count, validate_args=validate_args)
class CategoricalProbs(Distribution):
    """Categorical distribution parameterized by ``probs``.

    The last axis of ``probs`` indexes categories; the remaining leading
    axes form the batch shape.
    """

    arg_constraints = {"probs": constraints.simplex}
    has_enumerate_support = True

    def __init__(self, probs, validate_args=None):
        if jnp.ndim(probs) < 1:
            raise ValueError("`probs` parameter must be at least one-dimensional.")
        self.probs = probs
        super().__init__(
            batch_shape=jnp.shape(self.probs)[:-1], validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        return categorical(key, self.probs, shape=out_shape)

    @validate_sample
    def log_prob(self, value):
        """Look up the logit of each entry of ``value`` along the category axis."""
        full_batch = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
        # Give the indices a trailing axis so they can address the category axis.
        index = jnp.broadcast_to(jnp.expand_dims(value, axis=-1), full_batch + (1,))
        logits = self.logits
        table = jnp.broadcast_to(logits, full_batch + jnp.shape(logits)[-1:])
        picked = jnp.take_along_axis(table, index, axis=-1)
        return picked[..., 0]

    @lazy_property
    def logits(self):
        """Logits derived from ``probs`` with ``_to_logits_multinom``."""
        return _to_logits_multinom(self.probs)

    @property
    def mean(self):
        # Filled with NaN: no numeric mean is provided for this distribution.
        return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs))

    @property
    def variance(self):
        # Filled with NaN: no numeric variance is provided for this distribution.
        return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.probs))

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        last_category = jnp.shape(self.probs)[-1] - 1
        return constraints.integer_interval(0, last_category)

    def enumerate_support(self, expand=True):
        """Return all category indices along a new leading axis.

        With ``expand=True`` the result is broadcast over ``batch_shape``;
        otherwise the batch axes are kept as singleton dimensions.
        """
        singleton_batch = (1,) * len(self.batch_shape)
        values = jnp.arange(self.probs.shape[-1]).reshape((-1,) + singleton_batch)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
class CategoricalLogits(Distribution):
    """Categorical distribution parameterized by ``logits``.

    The last axis of ``logits`` indexes categories; the remaining leading
    axes form the batch shape.
    """

    arg_constraints = {"logits": constraints.real_vector}
    has_enumerate_support = True

    def __init__(self, logits, validate_args=None):
        if jnp.ndim(logits) < 1:
            raise ValueError("`logits` parameter must be at least one-dimensional.")
        self.logits = logits
        super().__init__(
            batch_shape=jnp.shape(logits)[:-1], validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        return random.categorical(key, self.logits, shape=out_shape)

    @validate_sample
    def log_prob(self, value):
        """Look up the normalized log-probability of each entry of ``value``."""
        full_batch = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
        # Give the indices a trailing axis so they can address the category axis.
        index = jnp.broadcast_to(jnp.expand_dims(value, -1), full_batch + (1,))
        # Normalize the logits by subtracting their log-sum-exp.
        normalized = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
        table = jnp.broadcast_to(normalized, full_batch + jnp.shape(normalized)[-1:])
        picked = jnp.take_along_axis(table, index, -1)
        return picked[..., 0]

    @lazy_property
    def probs(self):
        """Probabilities derived from ``logits`` with ``_to_probs_multinom``."""
        return _to_probs_multinom(self.logits)

    @property
    def mean(self):
        # Filled with NaN: no numeric mean is provided for this distribution.
        return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits))

    @property
    def variance(self):
        # Filled with NaN: no numeric variance is provided for this distribution.
        return jnp.full(self.batch_shape, jnp.nan, dtype=jnp.result_type(self.logits))

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        last_category = jnp.shape(self.logits)[-1] - 1
        return constraints.integer_interval(0, last_category)

    def enumerate_support(self, expand=True):
        """Return all category indices along a new leading axis.

        With ``expand=True`` the result is broadcast over ``batch_shape``;
        otherwise the batch axes are kept as singleton dimensions.
        """
        singleton_batch = (1,) * len(self.batch_shape)
        values = jnp.arange(self.logits.shape[-1]).reshape((-1,) + singleton_batch)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
def Categorical(probs=None, logits=None, validate_args=None):
    """Build a Categorical distribution from either ``probs`` or ``logits``.

    ``probs`` takes precedence when both are given.

    :raises ValueError: if neither ``probs`` nor ``logits`` is provided.
    """
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
        return CategoricalProbs(probs, validate_args=validate_args)
    return CategoricalLogits(logits, validate_args=validate_args)
class OrderedLogistic(CategoricalProbs):
    """
    A categorical distribution with ordered outcomes.

    **References:**

    1. *Stan Functions Reference, v2.20 section 12.6*,
       Stan Development Team

    :param numpy.ndarray predictor: prediction in real domain; typically this is output
        of a linear model.
    :param numpy.ndarray cutpoints: positions in real domain to separate categories.
    """

    arg_constraints = {
        "predictor": constraints.real,
        "cutpoints": constraints.ordered_vector,
    }

    def __init__(self, predictor, cutpoints, validate_args=None):
        # Give the predictor a trailing axis so it lines up with the cutpoints axis.
        if jnp.ndim(predictor) != 0:
            predictor = predictor[..., None]
        else:
            (predictor,) = promote_shapes(predictor, shape=(1,))
        predictor, self.cutpoints = promote_shapes(predictor, cutpoints)
        self.predictor = predictor[..., 0]
        cdf = expit(cutpoints - predictor)
        # add two boundary points 0 and 1
        leading_pads = [(0, 0)] * (jnp.ndim(cdf) - 1)
        cdf = jnp.pad(cdf, leading_pads + [(1, 1)], constant_values=(0, 1))
        # Category probabilities are the differences of consecutive cumulative values.
        upper, lower = cdf[..., 1:], cdf[..., :-1]
        probs = upper - lower
        super().__init__(probs, validate_args=validate_args)

    @staticmethod
    def infer_shapes(predictor, cutpoints):
        """Return ``(batch_shape, event_shape)`` from the two parameter shapes.

        The batch shape broadcasts ``predictor`` against ``cutpoints`` without
        its last axis; the event shape is empty.
        """
        batch_shape = lax.broadcast_shapes(predictor, cutpoints[:-1])
        return batch_shape, ()
class PRNGIdentity(Distribution):
    """
    Distribution over :func:`~jax.random.PRNGKey`. This can be used to
    draw a batch of :func:`~jax.random.PRNGKey` using the
    :class:`~numpyro.handlers.seed` handler. Only `sample` method is supported.
    """

    def __init__(self):
        message = (
            "PRNGIdentity distribution is deprecated. To get a random "
            "PRNG key, you can use `numpyro.prng_key()` instead."
        )
        warnings.warn(message, FutureWarning)
        super().__init__(event_shape=(2,))

    def sample(self, key, sample_shape=()):
        """Split ``key`` into ``prod(sample_shape)`` keys shaped ``sample_shape + event_shape``."""
        num_keys = np.prod(sample_shape).astype(np.int32)
        keys = random.split(key, num_keys)
        return jnp.reshape(keys, sample_shape + self.event_shape)
class MultinomialProbs(Distribution):
    """Multinomial distribution parameterized by ``probs`` and ``total_count``.

    The last axis of ``probs`` is the event axis; the batch shape broadcasts
    the leading axes of ``probs`` against the shape of ``total_count``.
    """

    arg_constraints = {
        "probs": constraints.simplex,
        "total_count": constraints.nonnegative_integer,
    }

    def __init__(self, probs, total_count=1, validate_args=None):
        if jnp.ndim(probs) < 1:
            raise ValueError("`probs` parameter must be at least one-dimensional.")
        batch_shape, event_shape = self.infer_shapes(
            jnp.shape(probs), jnp.shape(total_count)
        )
        probs_shape = batch_shape + jnp.shape(probs)[-1:]
        (self.probs,) = promote_shapes(probs, shape=probs_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """Draw samples with leading shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        return multinomial(key, self.probs, self.total_count, shape=out_shape)

    @validate_sample
    def log_prob(self, value):
        """Return ``gammaln(n + 1) + sum(xlogy(value, probs) - gammaln(value + 1))``."""
        # NOTE(review): this method is also wrapped by @validate_sample, so the
        # explicit validation below may be redundant -- confirm before removing.
        if self._validate_args:
            self._validate_sample(value)
        log_n_fact = gammaln(self.total_count + 1)
        per_category = xlogy(value, self.probs) - gammaln(value + 1)
        return log_n_fact + jnp.sum(per_category, axis=-1)

    @lazy_property
    def logits(self):
        """Logits derived from ``probs`` with ``_to_logits_multinom``."""
        return _to_logits_multinom(self.probs)

    @property
    def mean(self):
        return self.probs * jnp.expand_dims(self.total_count, -1)

    @property
    def variance(self):
        return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self):
        return constraints.multinomial(self.total_count)

    @staticmethod
    def infer_shapes(probs, total_count):
        """Return ``(batch_shape, event_shape)`` from the two parameter shapes."""
        batch_shape = lax.broadcast_shapes(probs[:-1], total_count)
        return batch_shape, probs[-1:]
class MultinomialLogits(Distribution):
    """Multinomial distribution parameterized by ``logits`` and ``total_count``.

    The last axis of ``logits`` is the event axis; the batch shape broadcasts
    the leading axes of ``logits`` against the shape of ``total_count``.
    """

    arg_constraints = {
        "logits": constraints.real_vector,
        "total_count": constraints.nonnegative_integer,
    }

    def __init__(self, logits, total_count=1, validate_args=None):
        if jnp.ndim(logits) < 1:
            raise ValueError("`logits` parameter must be at least one-dimensional.")
        batch_shape, event_shape = self.infer_shapes(
            jnp.shape(logits), jnp.shape(total_count)
        )
        logits_shape = batch_shape + jnp.shape(logits)[-1:]
        (self.logits,) = promote_shapes(logits, shape=logits_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """Draw samples with leading shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        out_shape = sample_shape + self.batch_shape
        return multinomial(key, self.probs, self.total_count, shape=out_shape)

    @validate_sample
    def log_prob(self, value):
        """Return the multinomial log-probability computed directly from ``logits``."""
        # NOTE(review): this method is also wrapped by @validate_sample, so the
        # explicit validation below may be redundant -- confirm before removing.
        if self._validate_args:
            self._validate_sample(value)
        normalizer = self.total_count * logsumexp(self.logits, axis=-1) - gammaln(
            self.total_count + 1
        )
        unnormalized = jnp.sum(value * self.logits - gammaln(value + 1), axis=-1)
        return unnormalized - normalizer

    @lazy_property
    def probs(self):
        """Probabilities derived from ``logits`` with ``_to_probs_multinom``."""
        return _to_probs_multinom(self.logits)

    @property
    def mean(self):
        return jnp.expand_dims(self.total_count, -1) * self.probs

    @property
    def variance(self):
        return jnp.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self):
        return constraints.multinomial(self.total_count)

    @staticmethod
    def infer_shapes(logits, total_count):
        """Return ``(batch_shape, event_shape)`` from the two parameter shapes."""
        batch_shape = lax.broadcast_shapes(logits[:-1], total_count)
        return batch_shape, logits[-1:]
def Multinomial(total_count=1, probs=None, logits=None, validate_args=None):
    """Build a Multinomial distribution from either ``probs`` or ``logits``.

    ``probs`` takes precedence when both are given.

    :raises ValueError: if neither ``probs`` nor ``logits`` is provided.
    """
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
        return MultinomialProbs(probs, total_count, validate_args=validate_args)
    return MultinomialLogits(logits, total_count, validate_args=validate_args)
class Poisson(Distribution):
    """Poisson distribution parameterized by ``rate``.

    :param rate: rate parameter; its shape is the batch shape.
    :param bool is_sparse: when true, ``log_prob`` takes a code path that
        only evaluates the ``log``/``gammaln`` terms at entries where
        ``value > 0`` (used only for non-traced values with more than one
        element).
    """

    arg_constraints = {"rate": constraints.positive}
    support = constraints.nonnegative_integer

    def __init__(self, rate, *, is_sparse=False, validate_args=None):
        self.rate = rate
        self.is_sparse = is_sparse
        super(Poisson, self).__init__(jnp.shape(rate), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``."""
        assert is_prng_key(key)
        return random.poisson(key, self.rate, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        """Return ``log(rate) * value - gammaln(value + 1) - rate``."""
        # NOTE(review): this method is also wrapped by @validate_sample, so the
        # explicit validation below may be redundant -- confirm before removing.
        if self._validate_args:
            self._validate_sample(value)
        value = jax.device_get(value)
        if (
            self.is_sparse
            and not isinstance(value, jax.core.Tracer)
            and jnp.size(value) > 1
        ):
            shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value))
            rate = jnp.broadcast_to(self.rate, shape).reshape(-1)
            value = jnp.broadcast_to(value, shape).reshape(-1)
            nonzero = value > 0
            sparse_value = value[nonzero]
            sparse_rate = rate[nonzero]
            # Entries with value == 0 contribute only ``-rate``; the remaining
            # terms are added at the non-zero positions.  ``x.at[idx].add(y)``
            # is the replacement for ``jax.ops.index_add(x, idx, y)``, which
            # was deprecated and later removed from JAX.
            sparse_terms = jnp.log(sparse_rate) * sparse_value - gammaln(
                sparse_value + 1
            )
            return (-rate).at[nonzero].add(sparse_terms).reshape(shape)
        return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate

    @property
    def mean(self):
        return self.rate

    @property
    def variance(self):
        return self.rate

    def cdf(self, value):
        """Return ``gammaincc(floor(value) + 1, rate)``."""
        k = jnp.floor(value) + 1
        return gammaincc(k, self.rate)
class ZeroInflatedProbs(Distribution):
    """Zero-inflated wrapper around a discrete ``base_dist``, parameterized by ``gate``.

    With probability ``gate`` a sample is replaced by 0; otherwise it comes
    from ``base_dist``.
    """

    arg_constraints = {"gate": constraints.unit_interval}

    def __init__(self, base_dist, gate, *, validate_args=None):
        batch_shape = lax.broadcast_shapes(jnp.shape(gate), base_dist.batch_shape)
        self.gate = promote_shapes(gate, shape=batch_shape)[0]
        assert base_dist.is_discrete
        if base_dist.event_shape:
            raise ValueError(
                "ZeroInflatedProbs expected empty base_dist.event_shape but got {}".format(
                    base_dist.event_shape
                )
            )
        # XXX: we might need to promote parameters of base_dist but let's keep
        # this simplified for now
        self.base_dist = base_dist.expand(batch_shape)
        super(ZeroInflatedProbs, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        """Draw base samples and overwrite those selected by a Bernoulli(gate) mask with 0."""
        assert is_prng_key(key)
        # Independent keys for the zero mask and for the base distribution.
        key_bern, key_base = random.split(key)
        out_shape = sample_shape + self.batch_shape
        is_extra_zero = random.bernoulli(key_bern, self.gate, out_shape)
        base_samples = self.base_dist(rng_key=key_base, sample_shape=sample_shape)
        return jnp.where(is_extra_zero, 0, base_samples)

    @validate_sample
    def log_prob(self, value):
        """Return the mixture log-probability, adding ``gate`` mass where ``value == 0``."""
        base_part = jnp.log1p(-self.gate) + self.base_dist.log_prob(value)
        zero_part = jnp.log(self.gate + jnp.exp(base_part))
        return jnp.where(value == 0, zero_part, base_part)

    @constraints.dependent_property(is_discrete=True, event_dim=0)
    def support(self):
        return self.base_dist.support

    @lazy_property
    def mean(self):
        keep_prob = 1 - self.gate
        return keep_prob * self.base_dist.mean

    @lazy_property
    def variance(self):
        # (1 - gate) * E[X^2] of the base distribution, minus the squared mean.
        base_second_moment = self.base_dist.mean ** 2 + self.base_dist.variance
        return (1 - self.gate) * base_second_moment - self.mean ** 2


class ZeroInflatedLogits(ZeroInflatedProbs):
    """Zero-inflated wrapper parameterized by ``gate_logits`` instead of ``gate``."""

    arg_constraints = {"gate_logits": constraints.real}

    def __init__(self, base_dist, gate_logits, *, validate_args=None):
        gate = _to_probs_bernoulli(gate_logits)
        batch_shape = lax.broadcast_shapes(jnp.shape(gate), base_dist.batch_shape)
        self.gate_logits = promote_shapes(gate_logits, shape=batch_shape)[0]
        super().__init__(base_dist, gate, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value):
        """Return the mixture log-probability computed directly from ``gate_logits``."""
        base_minus_log_gate = -self.gate_logits + self.base_dist.log_prob(value)
        log_gate = -softplus(-self.gate_logits)
        nonzero_log_prob = base_minus_log_gate + log_gate
        zero_log_prob = softplus(base_minus_log_gate) + log_gate
        return jnp.where(value == 0, zero_log_prob, nonzero_log_prob)
def ZeroInflatedDistribution(
    base_dist, *, gate=None, gate_logits=None, validate_args=None
):
    """
    Generic Zero Inflated distribution.

    :param Distribution base_dist: the base distribution.
    :param numpy.ndarray gate: probability of extra zeros given via a Bernoulli distribution.
    :param numpy.ndarray gate_logits: logits of extra zeros given via a Bernoulli distribution.
    """
    has_gate = gate is not None
    has_gate_logits = gate_logits is not None
    # Exactly one of the two parameterizations must be supplied.
    if has_gate == has_gate_logits:
        raise ValueError(
            "Either `gate` or `gate_logits` must be specified, but not both."
        )
    if has_gate:
        return ZeroInflatedProbs(base_dist, gate, validate_args=validate_args)
    return ZeroInflatedLogits(base_dist, gate_logits, validate_args=validate_args)
class ZeroInflatedPoisson(ZeroInflatedProbs):
    """
    A Zero Inflated Poisson distribution.

    :param numpy.ndarray gate: probability of extra zeros.
    :param numpy.ndarray rate: rate of Poisson distribution.
    """

    arg_constraints = {"gate": constraints.unit_interval, "rate": constraints.positive}
    support = constraints.nonnegative_integer

    # TODO: resolve inconsistent parameter order w.r.t. Pyro
    # and support `gate_logits` argument
    def __init__(self, gate, rate=1.0, validate_args=None):
        # Only the promoted rate is stored here; the gate is handled by the parent.
        promoted = promote_shapes(gate, rate)
        _, self.rate = promoted
        base_dist = Poisson(self.rate)
        super().__init__(base_dist, gate, validate_args=validate_args)
class GeometricProbs(Distribution):
    """Geometric distribution parameterized by ``probs``.

    The log-probability is ``value * log1p(-probs) + log(probs)``; the batch
    shape is the shape of ``probs``.
    """

    arg_constraints = {"probs": constraints.unit_interval}
    support = constraints.nonnegative_integer

    def __init__(self, probs, validate_args=None):
        self.probs = probs
        super(GeometricProbs, self).__init__(
            batch_shape=jnp.shape(self.probs), validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``.

        Samples are computed as ``floor(log1p(-u) / log1p(-probs))`` from a
        uniform draw ``u``.
        """
        assert is_prng_key(key)
        probs = self.probs
        dtype = jnp.result_type(probs)
        shape = sample_shape + self.batch_shape
        u = random.uniform(key, shape, dtype)
        return jnp.floor(jnp.log1p(-u) / jnp.log1p(-probs))

    @validate_sample
    def log_prob(self, value):
        """Return ``value * log1p(-probs) + log(probs)``.

        :param value: values at which to evaluate the log-probability.
        :return: log-probability, broadcast over ``value`` and ``probs``.
        """
        # When probs == 1 and value == 0 the first term would be 0 * (-inf) = nan.
        # Substituting 0 for probs there makes that term exactly 0.
        safe_probs = jnp.where((self.probs == 1) & (value == 0), 0, self.probs)
        # The log term must use the *unmasked* probs: taking the log of the
        # masked value would give log(0) = -inf for probs == 1, value == 0,
        # whereas the correct log-probability there is log(1) = 0.
        return value * jnp.log1p(-safe_probs) + jnp.log(self.probs)

    @lazy_property
    def logits(self):
        """Logits derived from ``probs`` with ``_to_logits_bernoulli``."""
        return _to_logits_bernoulli(self.probs)

    @property
    def mean(self):
        return 1.0 / self.probs - 1.0

    @property
    def variance(self):
        return (1.0 / self.probs - 1.0) / self.probs
class GeometricLogits(Distribution):
    """Geometric distribution parameterized by ``logits``.

    The batch shape is the shape of ``logits``.
    """

    arg_constraints = {"logits": constraints.real}
    support = constraints.nonnegative_integer

    def __init__(self, logits, validate_args=None):
        self.logits = logits
        super().__init__(
            batch_shape=jnp.shape(self.logits), validate_args=validate_args
        )

    @lazy_property
    def probs(self):
        """Probabilities derived from ``logits`` with ``_to_probs_bernoulli``."""
        return _to_probs_bernoulli(self.logits)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape`` using ``key``.

        Samples are computed as ``floor(log1p(-u) / -softplus(logits))`` from
        a uniform draw ``u``.
        """
        assert is_prng_key(key)
        logits = self.logits
        out_shape = sample_shape + self.batch_shape
        uniform_draws = random.uniform(key, out_shape, jnp.result_type(logits))
        log_survival = jnp.log1p(-uniform_draws)
        return jnp.floor(log_survival / -softplus(logits))

    @validate_sample
    def log_prob(self, value):
        """Return ``(-value - 1) * softplus(logits) + logits``."""
        scaled_softplus = (-value - 1) * softplus(self.logits)
        return scaled_softplus + self.logits

    @property
    def mean(self):
        return 1.0 / self.probs - 1.0

    @property
    def variance(self):
        return (1.0 / self.probs - 1.0) / self.probs
def Geometric(probs=None, logits=None, validate_args=None):
    """Build a Geometric distribution from either ``probs`` or ``logits``.

    ``probs`` takes precedence when both are given.

    :raises ValueError: if neither ``probs`` nor ``logits`` is provided.
    """
    if probs is None and logits is None:
        raise ValueError("One of `probs` or `logits` must be specified.")
    if probs is not None:
        return GeometricProbs(probs, validate_args=validate_args)
    return GeometricLogits(logits, validate_args=validate_args)