# The implementation largely follows the design in PyTorch's `torch.distributions`
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from jax import lax
from jax.lib import xla_bridge
from jax.nn import softmax
import jax.numpy as np
import jax.random as random
from jax.scipy.special import expit, gammaln, logsumexp
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
binary_cross_entropy_with_logits,
binomial,
categorical,
clamp_probs,
get_dtype,
lazy_property,
multinomial,
poisson,
promote_shapes,
sum_rightmost,
validate_sample,
xlog1py,
xlogy
)
from numpyro.util import copy_docs_from

def _to_probs_bernoulli(logits):
return 1 / (1 + np.exp(-logits))

def _to_logits_bernoulli(probs):
ps_clamped = clamp_probs(probs)
return np.log(ps_clamped) - np.log1p(-ps_clamped)

def _to_probs_multinom(logits):
return softmax(logits, axis=-1)

def _to_logits_multinom(probs):
minval = np.finfo(get_dtype(probs)).min
return np.clip(np.log(probs), a_min=minval)
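
# A quick sanity sketch (illustrative comment, not library code): the pairs of
# helpers above are mutual inverses on the open interval (0, 1); `clamp_probs`
# keeps the boundary cases finite before taking logs. For example,
# `_to_logits_bernoulli(_to_probs_bernoulli(2.0))` recovers 2.0 (up to
# floating-point rounding), since logit(sigmoid(x)) == x.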

@copy_docs_from(Distribution)
class BernoulliProbs(Distribution):
arg_constraints = {'probs': constraints.unit_interval}
support = constraints.boolean
def __init__(self, probs, validate_args=None):
self.probs = probs
super(BernoulliProbs, self).__init__(batch_shape=np.shape(self.probs), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
return xlogy(value, self.probs) + xlog1py(1 - value, -self.probs)
@property
def mean(self):
return self.probs
@property
def variance(self):
return self.probs * (1 - self.probs)

@copy_docs_from(Distribution)
class BernoulliLogits(Distribution):
arg_constraints = {'logits': constraints.real}
support = constraints.boolean
def __init__(self, logits=None, validate_args=None):
self.logits = logits
super(BernoulliLogits, self).__init__(batch_shape=np.shape(self.logits), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return random.bernoulli(key, self.probs, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
return -binary_cross_entropy_with_logits(self.logits, value)

    @lazy_property
def probs(self):
return _to_probs_bernoulli(self.logits)
@property
def mean(self):
return self.probs
@property
def variance(self):
return self.probs * (1 - self.probs)

def Bernoulli(probs=None, logits=None, validate_args=None):
if probs is not None:
return BernoulliProbs(probs, validate_args=validate_args)
elif logits is not None:
return BernoulliLogits(logits, validate_args=validate_args)
else:
raise ValueError('One of `probs` or `logits` must be specified.')
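
# Example (illustrative sketch, not part of the module): the two
# parameterizations describe the same distribution, e.g.
#
#     d1 = Bernoulli(probs=0.25)
#     d2 = Bernoulli(logits=np.log(0.25) - np.log(0.75))
#     d1.log_prob(1.)  # == np.log(0.25) == d2.log_prob(1.)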

@copy_docs_from(Distribution)
class BinomialProbs(Distribution):
arg_constraints = {'total_count': constraints.nonnegative_integer,
'probs': constraints.unit_interval}
def __init__(self, probs, total_count=1, validate_args=None):
self.probs, self.total_count = promote_shapes(probs, total_count)
batch_shape = lax.broadcast_shapes(np.shape(probs), np.shape(total_count))
super(BinomialProbs, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return binomial(key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
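        # log C(n, k) + k * log(p) + (n - k) * log(1 - p); xlogy/xlog1py return
        # 0 instead of nan when their first argument is 0, which covers the
        # edge cases p == 0 and p == 1.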
log_factorial_n = gammaln(self.total_count + 1)
log_factorial_k = gammaln(value + 1)
log_factorial_nmk = gammaln(self.total_count - value + 1)
return (log_factorial_n - log_factorial_k - log_factorial_nmk +
xlogy(value, self.probs) + xlog1py(self.total_count - value, -self.probs))
@property
def mean(self):
return np.broadcast_to(self.total_count * self.probs, self.batch_shape)
@property
def variance(self):
return np.broadcast_to(self.total_count * self.probs * (1 - self.probs), self.batch_shape)
@property
def support(self):
return constraints.integer_interval(0, self.total_count)

@copy_docs_from(Distribution)
class BinomialLogits(Distribution):
arg_constraints = {'total_count': constraints.nonnegative_integer,
'logits': constraints.real}
def __init__(self, logits, total_count=1, validate_args=None):
self.logits, self.total_count = promote_shapes(logits, total_count)
batch_shape = lax.broadcast_shapes(np.shape(logits), np.shape(total_count))
super(BinomialLogits, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return binomial(key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
log_factorial_n = gammaln(self.total_count + 1)
log_factorial_k = gammaln(value + 1)
log_factorial_nmk = gammaln(self.total_count - value + 1)
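        # Stable evaluation of total_count * log(1 + exp(logits)), which equals
        # -total_count * log(1 - probs): split as total_count * max(logits, 0)
        # + total_count * log1p(exp(-|logits|)) so that exp never overflows.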
normalize_term = (self.total_count * np.clip(self.logits, 0) +
xlog1py(self.total_count, np.exp(-np.abs(self.logits))) -
log_factorial_n)
return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term

    @lazy_property
def probs(self):
return _to_probs_bernoulli(self.logits)
@property
def mean(self):
return np.broadcast_to(self.total_count * self.probs, self.batch_shape)
@property
def variance(self):
return np.broadcast_to(self.total_count * self.probs * (1 - self.probs), self.batch_shape)
@property
def support(self):
return constraints.integer_interval(0, self.total_count)

def Binomial(total_count=1, probs=None, logits=None, validate_args=None):
if probs is not None:
return BinomialProbs(probs, total_count, validate_args=validate_args)
elif logits is not None:
return BinomialLogits(logits, total_count, validate_args=validate_args)
else:
raise ValueError('One of `probs` or `logits` must be specified.')

@copy_docs_from(Distribution)
class CategoricalProbs(Distribution):
arg_constraints = {'probs': constraints.simplex}
def __init__(self, probs, validate_args=None):
if np.ndim(probs) < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
self.probs = probs
super(CategoricalProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1],
validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return categorical(key, self.probs, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
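        # Broadcast `value` against the batch shape, then gather the log-pmf
        # entry at each category index along the last axis.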
batch_shape = lax.broadcast_shapes(np.shape(value), self.batch_shape)
value = np.expand_dims(value, axis=-1)
value = np.broadcast_to(value, batch_shape + (1,))
logits = _to_logits_multinom(self.probs)
log_pmf = np.broadcast_to(logits, batch_shape + np.shape(logits)[-1:])
return np.take_along_axis(log_pmf, value, axis=-1)[..., 0]
@property
def mean(self):
return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs))
@property
def variance(self):
return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs))
@property
def support(self):
        return constraints.integer_interval(0, np.shape(self.probs)[-1] - 1)

@copy_docs_from(Distribution)
class CategoricalLogits(Distribution):
arg_constraints = {'logits': constraints.real}
def __init__(self, logits, validate_args=None):
if np.ndim(logits) < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
self.logits = logits
super(CategoricalLogits, self).__init__(batch_shape=np.shape(logits)[:-1],
validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return categorical(key, self.probs, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
batch_shape = lax.broadcast_shapes(np.shape(value), self.batch_shape)
value = np.expand_dims(value, -1)
value = np.broadcast_to(value, batch_shape + (1,))
log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
log_pmf = np.broadcast_to(log_pmf, batch_shape + np.shape(log_pmf)[-1:])
return np.take_along_axis(log_pmf, value, -1)[..., 0]

    @lazy_property
def probs(self):
return _to_probs_multinom(self.logits)
@property
def mean(self):
return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.logits))
@property
def variance(self):
return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.logits))
@property
def support(self):
        return constraints.integer_interval(0, np.shape(self.logits)[-1] - 1)

def Categorical(probs=None, logits=None, validate_args=None):
if probs is not None:
return CategoricalProbs(probs, validate_args=validate_args)
elif logits is not None:
return CategoricalLogits(logits, validate_args=validate_args)
else:
raise ValueError('One of `probs` or `logits` must be specified.')
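
# Example (illustrative sketch): logits are unnormalized log-probabilities, so
# shifting them by a constant leaves the distribution unchanged, e.g.
#
#     d1 = Categorical(probs=np.array([0.2, 0.3, 0.5]))
#     d2 = Categorical(logits=np.log(np.array([0.2, 0.3, 0.5])) + 7.)
#     d1.log_prob(2)  # == np.log(0.5) == d2.log_prob(2)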

@copy_docs_from(Distribution)
class Delta(Distribution):
arg_constraints = {'value': constraints.real, 'log_density': constraints.real}
support = constraints.real
def __init__(self, value=0., log_density=0., event_ndim=0, validate_args=None):
if event_ndim > np.ndim(value):
            raise ValueError('Expected event_ndim <= np.ndim(value), actual {} vs {}'
                             .format(event_ndim, np.ndim(value)))
batch_dim = np.ndim(value) - event_ndim
batch_shape = np.shape(value)[:batch_dim]
event_shape = np.shape(value)[batch_dim:]
self.value = lax.convert_element_type(value, xla_bridge.canonicalize_dtype(np.float64))
# NB: following Pyro implementation, log_density should be broadcasted to batch_shape
self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
shape = sample_shape + self.batch_shape + self.event_shape
return np.broadcast_to(self.value, shape)
@validate_sample
def log_prob(self, value):
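        # np.log(value == self.value) is 0. where the value matches and -inf
        # elsewhere; event dimensions are then summed out below.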
log_prob = np.log(value == self.value)
log_prob = sum_rightmost(log_prob, len(self.event_shape))
return log_prob + self.log_density
@property
def mean(self):
return self.value
@property
def variance(self):
return np.zeros(self.batch_shape + self.event_shape)

class OrderedLogistic(CategoricalProbs):
"""
A categorical distribution with ordered outcomes.
**References:**
1. *Stan Functions Reference, v2.20 section 12.6*,
Stan Development Team
    :param numpy.ndarray predictor: prediction in the real domain; typically this is
        the output of a linear model.
    :param numpy.ndarray cutpoints: positions in the real domain that separate the categories.
"""
arg_constraints = {'predictor': constraints.real,
'cutpoints': constraints.ordered_vector}
def __init__(self, predictor, cutpoints, validate_args=None):
predictor, self.cutpoints = promote_shapes(np.expand_dims(predictor, -1), cutpoints)
self.predictor = predictor[..., 0]
cumulative_probs = expit(cutpoints - predictor)
# add two boundary points 0 and 1
pad_width = [(0, 0)] * (np.ndim(cumulative_probs) - 1) + [(1, 1)]
cumulative_probs = np.pad(cumulative_probs, pad_width, constant_values=(0, 1))
probs = cumulative_probs[..., 1:] - cumulative_probs[..., :-1]
super(OrderedLogistic, self).__init__(probs, validate_args=validate_args)
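
# Worked example (illustrative, not part of the module): K cutpoints carve the
# real line into K + 1 ordered categories, so with predictor = 0. and
# cutpoints = [-1., 1.] the category probabilities are the expit differences
#
#     d = OrderedLogistic(predictor=0., cutpoints=np.array([-1., 1.]))
#     d.probs  # ~= [0.269, 0.462, 0.269], i.e.
#              #    [expit(-1) - 0, expit(1) - expit(-1), 1 - expit(1)]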

class PRNGIdentity(Distribution):
"""
Distribution over :func:`~jax.random.PRNGKey`. This can be used to
draw a batch of :func:`~jax.random.PRNGKey` using the :class:`~numpyro.handlers.seed`
    handler. Only the `sample` method is supported.
"""
def __init__(self):
super(PRNGIdentity, self).__init__(event_shape=(2,))

    def sample(self, key, sample_shape=()):
        return np.reshape(random.split(key, np.prod(sample_shape).astype(np.int32)),
sample_shape + self.event_shape)
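
# Example (illustrative sketch): sampling with sample_shape (n,) yields n
# independent keys, each with event_shape (2,), e.g.
#
#     keys = PRNGIdentity().sample(random.PRNGKey(0), (3,))
#     keys.shape  # (3, 2)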

@copy_docs_from(Distribution)
class MultinomialProbs(Distribution):
arg_constraints = {'total_count': constraints.nonnegative_integer,
'probs': constraints.simplex}
def __init__(self, probs, total_count=1, validate_args=None):
if np.ndim(probs) < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
batch_shape = lax.broadcast_shapes(np.shape(probs)[:-1], np.shape(total_count))
self.probs = promote_shapes(probs, shape=batch_shape + np.shape(probs)[-1:])[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
super(MultinomialProbs, self).__init__(batch_shape=batch_shape,
event_shape=np.shape(self.probs)[-1:],
validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return multinomial(key, self.probs, self.total_count, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
return gammaln(self.total_count + 1) \
+ np.sum(xlogy(value, self.probs) - gammaln(value + 1), axis=-1)
@property
def mean(self):
return self.probs * np.expand_dims(self.total_count, -1)
@property
def variance(self):
return np.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs)
@property
def support(self):
return constraints.multinomial(self.total_count)

@copy_docs_from(Distribution)
class MultinomialLogits(Distribution):
arg_constraints = {'total_count': constraints.nonnegative_integer,
'logits': constraints.real}
def __init__(self, logits, total_count=1, validate_args=None):
if np.ndim(logits) < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
batch_shape = lax.broadcast_shapes(np.shape(logits)[:-1], np.shape(total_count))
self.logits = promote_shapes(logits, shape=batch_shape + np.shape(logits)[-1:])[0]
self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
super(MultinomialLogits, self).__init__(batch_shape=batch_shape,
event_shape=np.shape(self.logits)[-1:],
validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return multinomial(key, self.probs, self.total_count, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
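        # normalize_term = total_count * logsumexp(logits) - log(total_count!),
        # i.e. the log-partition function of the unnormalized logits together
        # with the numerator of the multinomial coefficient.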
normalize_term = self.total_count * logsumexp(self.logits, axis=-1) \
- gammaln(self.total_count + 1)
return np.sum(value * self.logits - gammaln(value + 1), axis=-1) - normalize_term

    @lazy_property
def probs(self):
return _to_probs_multinom(self.logits)
@property
def mean(self):
return np.expand_dims(self.total_count, -1) * self.probs
@property
def variance(self):
return np.expand_dims(self.total_count, -1) * self.probs * (1 - self.probs)
@property
def support(self):
return constraints.multinomial(self.total_count)

def Multinomial(total_count=1, probs=None, logits=None, validate_args=None):
if probs is not None:
return MultinomialProbs(probs, total_count, validate_args=validate_args)
elif logits is not None:
return MultinomialLogits(logits, total_count, validate_args=validate_args)
else:
raise ValueError('One of `probs` or `logits` must be specified.')

@copy_docs_from(Distribution)
class Poisson(Distribution):
arg_constraints = {'rate': constraints.positive}
support = constraints.nonnegative_integer
def __init__(self, rate, validate_args=None):
self.rate = rate
super(Poisson, self).__init__(np.shape(rate), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
return poisson(key, self.rate, shape=sample_shape + self.batch_shape)
@validate_sample
def log_prob(self, value):
return (np.log(self.rate) * value) - gammaln(value + 1) - self.rate
@property
def mean(self):
return self.rate
@property
def variance(self):
return self.rate

class ZeroInflatedPoisson(Distribution):
"""
A Zero Inflated Poisson distribution.
:param numpy.ndarray gate: probability of extra zeros.
:param numpy.ndarray rate: rate of Poisson distribution.
"""
arg_constraints = {'gate': constraints.unit_interval, 'rate': constraints.positive}
support = constraints.nonnegative_integer
def __init__(self, gate, rate=1., validate_args=None):
batch_shape = lax.broadcast_shapes(np.shape(gate), np.shape(rate))
self.gate, self.rate = promote_shapes(gate, rate)
super(ZeroInflatedPoisson, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
key_bern, key_poisson = random.split(key)
shape = sample_shape + self.batch_shape
mask = random.bernoulli(key_bern, self.gate, shape)
samples = poisson(key_poisson, self.rate, shape)
return np.where(mask, 0, samples)
@validate_sample
def log_prob(self, value):
log_prob = np.log(self.rate) * value - gammaln(value + 1) + (np.log1p(-self.gate) - self.rate)
return np.where(value == 0, np.logaddexp(np.log(self.gate), log_prob), log_prob)

    @lazy_property
def mean(self):
return (1 - self.gate) * self.rate

    @lazy_property
def variance(self):
return (1 - self.gate) * self.rate * (1 + self.rate * self.gate)
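
# Consistency sketch (illustrative, not part of the module): a draw is zero
# with probability gate + (1 - gate) * exp(-rate) and otherwise follows the
# base Poisson(rate), so e.g.
#
#     d = ZeroInflatedPoisson(gate=0.5, rate=2.)
#     d.mean         # == (1 - 0.5) * 2. == 1.
#     d.log_prob(0)  # == np.log(0.5 + 0.5 * np.exp(-2.))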