# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from numpyro.distributions import constraints
from numpyro.distributions.continuous import (
Cauchy,
Laplace,
Logistic,
Normal,
SoftLaplace,
StudentT,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
lazy_property,
promote_shapes,
validate_sample,
)
from numpyro.util import is_prng_key
class LeftTruncatedDistribution(Distribution):
    """
    Truncate a univariate base distribution from below at ``low``.

    The support of the result is ``greater_than(low)``. The implementation
    assumes the base distribution is symmetric about ``base_dist.loc``: when
    ``low`` lies above ``loc`` the computation is reflected about ``loc`` so
    that probabilities are always evaluated in the lower tail of the base
    distribution.

    :param base_dist: distribution to truncate; must be an instance of one of
        ``supported_types`` and have ``constraints.real`` support.
    :param low: lower truncation point, broadcast against the batch shape of
        ``base_dist``.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    arg_constraints = {"low": constraints.real}
    reparametrized_params = ["low"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "_support")

    def __init__(self, base_dist, low=0.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
        # Give every leaf of the base distribution the rank of the joint batch shape.
        self.base_dist = tree_map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.low,) = promote_shapes(low, shape=batch_shape)
        self._support = constraints.greater_than(low)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high) = 1; otherwise returns 1 - cdf(high) = 0
        return jnp.where(self.low <= self.base_dist.loc, 1.0, 0.0)

    def _log_normalizer(self):
        # Log of the base-distribution mass above ``low``; the sign undoes the
        # reflection applied in the two tail-probability helpers.
        sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
        return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    def _density_at_low(self):
        # Truncated density at the bound itself. ``low`` is not inside the open
        # support ``greater_than(low)``, so this deliberately bypasses the
        # validated ``log_prob``, which may mask the value when argument
        # validation is enabled.
        return jnp.exp(self.base_dist.log_prob(self.low) - self._log_normalizer())

    def sample(self, key, sample_shape=()):
        """
        Draw samples by inverse-CDF sampling of the base distribution.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw; the result has shape
            ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        # Use the smallest positive normal float as the lower bound of ``u``.
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        # sign == 1: icdf(...) as is; sign == -1: reflect back, 2 * loc - icdf(...)
        return (1 - sign) * loc + sign * self.base_dist.icdf(
            (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
        )

    @validate_sample
    def log_prob(self, value):
        """Log density: base log density minus the log of the retained mass."""
        return self.base_dist.log_prob(value) - self._log_normalizer()

    @property
    def mean(self):
        """Mean of the truncated distribution (Normal: closed form, Cauchy: NaN)."""
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at_low()
            return self.base_dist.loc + low_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        """Variance of the truncated distribution (Normal: closed form, Cauchy: NaN)."""
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at_low()
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (low_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
class RightTruncatedDistribution(Distribution):
    """
    Truncate a univariate base distribution from above at ``high``.

    The support of the result is ``less_than(high)``.

    :param base_dist: distribution to truncate; must be an instance of one of
        ``supported_types`` and have ``constraints.real`` support.
    :param high: upper truncation point, broadcast against the batch shape of
        ``base_dist``.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    arg_constraints = {"high": constraints.real}
    reparametrized_params = ["high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "high", "_support")

    def __init__(self, base_dist, high=0.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))
        # Give every leaf of the base distribution the rank of the joint batch shape.
        self.base_dist = tree_map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.less_than(high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _cdf_at_high(self):
        # Base-distribution mass retained by the truncation.
        return self.base_dist.cdf(self.high)

    def _density_at_high(self):
        # Truncated density at the bound itself. ``high`` is not inside the open
        # support ``less_than(high)``, so this deliberately bypasses the
        # validated ``log_prob``, which may mask the value when argument
        # validation is enabled.
        return jnp.exp(self.base_dist.log_prob(self.high) - jnp.log(self._cdf_at_high))

    def sample(self, key, sample_shape=()):
        """
        Draw samples by inverse-CDF sampling of the base distribution.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw; the result has shape
            ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        # Use the smallest positive normal float as the lower bound of ``u``.
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.base_dist.icdf(u * self._cdf_at_high)

    @validate_sample
    def log_prob(self, value):
        """Log density: base log density minus the log of the retained mass."""
        return self.base_dist.log_prob(value) - jnp.log(self._cdf_at_high)

    @property
    def mean(self):
        """Mean of the truncated distribution (Normal: closed form, Cauchy: NaN)."""
        if isinstance(self.base_dist, Normal):
            high_prob = self._density_at_high()
            return self.base_dist.loc - high_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        """Variance of the truncated distribution (Normal: closed form, Cauchy: NaN)."""
        if isinstance(self.base_dist, Normal):
            high_prob = self._density_at_high()
            return (self.base_dist.scale**2) * (
                1
                - (self.high - self.base_dist.loc) * high_prob
                - (high_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
class TwoSidedTruncatedDistribution(Distribution):
    """
    Truncate a univariate base distribution to the interval ``[low, high]``.

    The implementation assumes the base distribution is symmetric about
    ``base_dist.loc``: when ``low`` lies above ``loc`` the computation is
    reflected about ``loc`` so that probabilities are evaluated in the lower
    tail of the base distribution, where they are better conditioned.

    :param base_dist: distribution to truncate; must be an instance of one of
        ``supported_types`` and have ``constraints.real`` support.
    :param low: lower truncation point.
    :param high: upper truncation point.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    arg_constraints = {
        "low": constraints.dependent,
        "high": constraints.dependent,
    }
    reparametrized_params = ["low", "high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "high", "_support")

    def __init__(self, base_dist, low=0.0, high=1.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        batch_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
        )
        # Give every leaf of the base distribution the rank of the joint batch shape.
        self.base_dist = tree_map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.low,) = promote_shapes(low, shape=batch_shape)
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.interval(low, high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high); otherwise returns 1 - cdf(high)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.high))

    @lazy_property
    def _log_diff_tail_probs(self):
        # log(cdf(high) - cdf(low)), evaluated with the same reflection about
        # ``loc`` as the tail-probability helpers:
        #   low <= loc: cdf(high) - cdf(low)
        #   low >  loc: cdf(2 * loc - low) - cdf(2 * loc - high)
        loc = self.base_dist.loc
        keep = loc >= self.low
        sign = jnp.where(keep, 1.0, -1.0)
        # use log_cdf method, if available, to avoid inf's in log_prob
        # fall back to cdf, if log_cdf not available
        log_cdf = getattr(self.base_dist, "log_cdf", None)
        if callable(log_cdf):
            # ``where`` leaves the bounds untouched when no reflection is needed.
            log_tail_high = log_cdf(jnp.where(keep, self.high, 2.0 * loc - self.high))
            log_tail_low = log_cdf(jnp.where(keep, self.low, 2.0 * loc - self.low))
            # Stacking requires identical shapes, so broadcast first.
            log_tail_high, log_tail_low, sign = jnp.broadcast_arrays(
                log_tail_high, log_tail_low, sign
            )
            return logsumexp(
                a=jnp.stack([log_tail_high, log_tail_low], axis=-1),
                axis=-1,
                # sign * (exp(high term) - exp(low term)), which is positive
                b=jnp.stack([sign, -sign], axis=-1),
            )
        return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    def sample(self, key, sample_shape=()):
        """
        Draw samples by inverse-CDF sampling of the base distribution.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw; the result has shape
            ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        # Use the smallest positive normal float as the lower bound of ``u``.
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        # NB: we use a more numerically stable formula for a symmetric base distribution
        #   A = icdf(cdf(low) + (cdf(high) - cdf(low)) * u) = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        # will suffer from precision issues when low is large;
        # If low < loc:
        #   A = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        # Else
        #   A = 2 * loc - icdf[(1 - u) * cdf(2 * loc - low) + u * cdf(2 * loc - high)]
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return (1 - sign) * loc + sign * self.base_dist.icdf(
            clamp_probs((1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high)
        )

    @validate_sample
    def log_prob(self, value):
        """Log density: base log density minus the log of the retained mass."""
        # NB: we use a more numerically stable formula for a symmetric base distribution
        # if low < loc
        #   cdf(high) - cdf(low) = as-is
        # if low > loc
        #   cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)
        return self.base_dist.log_prob(value) - self._log_diff_tail_probs

    @property
    def mean(self):
        """Mean of the truncated distribution (Normal: closed form, Cauchy: NaN)."""
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        """Variance of the truncated distribution (Normal: closed form, Cauchy: NaN)."""
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (self.high - self.base_dist.loc) * high_prob
                - ((low_prob - high_prob) * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
def TruncatedDistribution(base_dist, low=None, high=None, *, validate_args=None):
    """
    Build the truncated distribution that matches the bounds supplied.

    :param base_dist: The base distribution to be truncated. This should be a
        univariate distribution. Currently, only the following distributions are
        supported: Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
    :param low: the value used to truncate the base distribution from below.
        Set this parameter to None to leave the distribution untruncated from below.
    :param high: the value used to truncate the base distribution from above.
        Set this parameter to None to leave the distribution untruncated from above.
    :param validate_args: forwarded to the truncated distribution's constructor.
    :return: ``base_dist`` itself when both bounds are None, otherwise a left,
        right, or two-sided truncated distribution.
    """
    truncate_below = low is not None
    truncate_above = high is not None

    if truncate_below and truncate_above:
        return TwoSidedTruncatedDistribution(
            base_dist, low=low, high=high, validate_args=validate_args
        )
    if truncate_below:
        return LeftTruncatedDistribution(
            base_dist, low=low, validate_args=validate_args
        )
    if truncate_above:
        return RightTruncatedDistribution(
            base_dist, high=high, validate_args=validate_args
        )
    # Nothing to truncate: hand the base distribution back untouched.
    return base_dist
def TruncatedCauchy(loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None):
    """
    Build a Cauchy(``loc``, ``scale``) distribution truncated to the given bounds.

    :param loc: location of the base Cauchy distribution.
    :param scale: scale of the base Cauchy distribution.
    :param low: lower truncation point, or None for no lower truncation.
    :param high: upper truncation point, or None for no upper truncation.
    :param validate_args: forwarded to ``TruncatedDistribution``.
    """
    base = Cauchy(loc, scale)
    return TruncatedDistribution(
        base, low=low, high=high, validate_args=validate_args
    )
def TruncatedNormal(loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None):
    """
    Build a Normal(``loc``, ``scale``) distribution truncated to the given bounds.

    :param loc: location of the base Normal distribution.
    :param scale: scale of the base Normal distribution.
    :param low: lower truncation point, or None for no lower truncation.
    :param high: upper truncation point, or None for no upper truncation.
    :param validate_args: forwarded to ``TruncatedDistribution``.
    """
    base = Normal(loc, scale)
    return TruncatedDistribution(
        base, low=low, high=high, validate_args=validate_args
    )
class TruncatedPolyaGamma(Distribution):
    """
    Parameter-free distribution supported on ``(0, truncation_point)``.

    NOTE(review): going by the name this is presumably a Polya-Gamma
    distribution restricted to a finite interval -- confirm against the
    project documentation.

    Samples are a finite weighted sum of unit-rate gamma variates capped at
    ``truncation_point``; the log density is a finite alternating series.

    :param batch_shape: batch shape of the distribution.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    # An odd count makes the alternating series in ``log_prob`` end on a
    # positive (even-index) term.
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(self, batch_shape=(), *, validate_args=None):
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """
        Draw samples of shape ``sample_shape + batch_shape``.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw.
        """
        assert is_prng_key(key)
        # Weights (k + 1/2)^2 for k = 0 .. num_gamma_variates - 1.
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        # Sample dimensions lead and batch dimensions follow, matching the
        # ``sample_shape + batch_shape`` layout used by the other samplers in
        # this module.
        x = random.gamma(
            key, jnp.ones(sample_shape + self.batch_shape + (self.num_gamma_variates,))
        )
        x = jnp.sum(x / denom, axis=-1)
        # Cap at the truncation point; ``minimum`` is an upper-bound-only clip.
        return jnp.minimum(x * (0.5 / jnp.pi**2), self.truncation_point)

    @validate_sample
    def log_prob(self, value):
        """Log density from the first ``num_log_prob_terms`` terms of an alternating series."""
        value = value[..., None]
        all_indices = jnp.arange(0, self.num_log_prob_terms)
        two_n_plus_one = 2.0 * all_indices + 1.0
        # Log-magnitude of each series term; signs alternate with the index.
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        even_terms = jnp.take(log_terms, all_indices[::2], axis=-1)
        odd_terms = jnp.take(log_terms, all_indices[1::2], axis=-1)
        # Even-index terms enter positively, odd-index terms negatively.
        sum_even = jnp.exp(logsumexp(even_terms, axis=-1))
        sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)