# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from numpyro.distributions import constraints
from numpyro.distributions.continuous import (
Cauchy,
Laplace,
Logistic,
Normal,
SoftLaplace,
StudentT,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
is_prng_key,
lazy_property,
promote_shapes,
validate_sample,
)
class LeftTruncatedDistribution(Distribution):
    """
    Restriction of a univariate base distribution to values above ``low``.

    The implementation reads ``base_dist.loc`` and mirrors computations about
    it, so that probabilities are always evaluated in the lower tail of the
    base distribution.

    :param base_dist: an instance of one of ``supported_types`` whose support
        is ``constraints.real``.
    :param low: the lower truncation point; broadcast against
        ``base_dist.batch_shape``.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    arg_constraints = {"low": constraints.real}
    reparametrized_params = ["low"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)

    def __init__(self, base_dist, low=0.0, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        full_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))

        def _to_full_shape(param):
            return promote_shapes(param, shape=full_shape)[0]

        self.base_dist = tree_map(_to_full_shape, base_dist)
        (self.low,) = promote_shapes(low, shape=full_shape)
        # The constraint stores `low` exactly as passed in (not promoted);
        # tree_flatten inspects its type to decide how to serialise it.
        self._support = constraints.greater_than(low)
        super().__init__(full_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def _mirror_sign(self):
        # +1.0 where low <= loc, -1.0 where low > loc (mirror about loc).
        return jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)

    @lazy_property
    def _tail_prob_at_low(self):
        # cdf(low) when low <= loc; otherwise cdf(2 * loc - low), which the
        # original author documents as 1 - cdf(low).
        loc = self.base_dist.loc
        sign = self._mirror_sign()
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # The upper bound is +infinity: cdf(high) = 1 when low <= loc,
        # and the mirrored value 1 - cdf(high) = 0 otherwise.
        return jnp.where(self.low <= self.base_dist.loc, 1.0, 0.0)

    def sample(self, key, sample_shape=()):
        """
        Draw samples by inverse-CDF sampling between the two tail
        probabilities.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw.
        :return: an array of shape ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        u = random.uniform(key, sample_shape + self.batch_shape)
        loc = self.base_dist.loc
        sign = self._mirror_sign()
        # Interpolate between the tail probabilities at the two bounds.
        prob = (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
        # For sign == -1 this evaluates to 2 * loc - icdf(prob).
        return (1 - sign) * loc + sign * self.base_dist.icdf(prob)

    @validate_sample
    def log_prob(self, value):
        """
        Base log density minus the log of the probability mass kept by the
        truncation.
        """
        sign = self._mirror_sign()
        retained_mass = sign * (self._tail_prob_at_high - self._tail_prob_at_low)
        return self.base_dist.log_prob(value) - jnp.log(retained_mass)

    def tree_flatten(self):
        """
        Flatten into ``(params, aux_data)``.

        A Python ``int``/``float`` bound is stored in ``aux_data``; any other
        bound is stored with the parameters.
        """
        base_flatten, base_aux = self.base_dist.tree_flatten()
        base_cls = type(self.base_dist)
        if not isinstance(self._support.lower_bound, (int, float)):
            return (base_flatten, self.low), (base_cls, base_aux)
        return base_flatten, (base_cls, base_aux, self._support.lower_bound)

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild an instance from the output of :meth:`tree_flatten`."""
        if len(aux_data) != 2:
            # The bound was stored statically in aux_data.
            base_cls, base_aux, low = aux_data
            base_flatten = params
        else:
            base_cls, base_aux = aux_data
            base_flatten, low = params
        base_dist = base_cls.tree_unflatten(base_aux, base_flatten)
        return cls(base_dist, low=low)
class RightTruncatedDistribution(Distribution):
    """
    Restriction of a univariate base distribution to values below ``high``.

    :param base_dist: an instance of one of ``supported_types`` whose support
        is ``constraints.real``.
    :param high: the upper truncation point; broadcast against
        ``base_dist.batch_shape``.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    arg_constraints = {"high": constraints.real}
    reparametrized_params = ["high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)

    def __init__(self, base_dist, high=0.0, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        full_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))

        def _to_full_shape(param):
            return promote_shapes(param, shape=full_shape)[0]

        self.base_dist = tree_map(_to_full_shape, base_dist)
        (self.high,) = promote_shapes(high, shape=full_shape)
        # The constraint stores `high` exactly as passed in (not promoted);
        # tree_flatten inspects its type to decide how to serialise it.
        self._support = constraints.less_than(high)
        super().__init__(full_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _cdf_at_high(self):
        # Probability mass of the base distribution kept by the truncation.
        return self.base_dist.cdf(self.high)

    def sample(self, key, sample_shape=()):
        """
        Draw samples by inverse-CDF sampling on ``[0, cdf(high))``.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw.
        :return: an array of shape ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape
        u = random.uniform(key, draw_shape)
        return self.base_dist.icdf(u * self._cdf_at_high)

    @validate_sample
    def log_prob(self, value):
        """Base log density minus ``log(cdf(high))``."""
        log_retained_mass = jnp.log(self._cdf_at_high)
        return self.base_dist.log_prob(value) - log_retained_mass

    def tree_flatten(self):
        """
        Flatten into ``(params, aux_data)``.

        A Python ``int``/``float`` bound is stored in ``aux_data``; any other
        bound is stored with the parameters.
        """
        base_flatten, base_aux = self.base_dist.tree_flatten()
        base_cls = type(self.base_dist)
        if not isinstance(self._support.upper_bound, (int, float)):
            return (base_flatten, self.high), (base_cls, base_aux)
        return base_flatten, (base_cls, base_aux, self._support.upper_bound)

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild an instance from the output of :meth:`tree_flatten`."""
        if len(aux_data) != 2:
            # The bound was stored statically in aux_data.
            base_cls, base_aux, high = aux_data
            base_flatten = params
        else:
            base_cls, base_aux = aux_data
            base_flatten, high = params
        base_dist = base_cls.tree_unflatten(base_aux, base_flatten)
        return cls(base_dist, high=high)
class TwoSidedTruncatedDistribution(Distribution):
    """
    Restriction of a univariate base distribution to the interval
    ``(low, high)``.

    The implementation reads ``base_dist.loc`` and mirrors computations about
    it, so that probabilities are always evaluated in the lower tail of the
    base distribution.

    :param base_dist: an instance of one of ``supported_types`` whose support
        is ``constraints.real``.
    :param low: the lower truncation point.
    :param high: the upper truncation point.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    arg_constraints = {"low": constraints.dependent, "high": constraints.dependent}
    reparametrized_params = ["low", "high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)

    def __init__(self, base_dist, low=0.0, high=1.0, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        full_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
        )

        def _to_full_shape(param):
            return promote_shapes(param, shape=full_shape)[0]

        self.base_dist = tree_map(_to_full_shape, base_dist)
        (self.low,) = promote_shapes(low, shape=full_shape)
        (self.high,) = promote_shapes(high, shape=full_shape)
        # The constraint stores the bounds exactly as passed in (not promoted);
        # tree_flatten inspects their types to decide how to serialise them.
        self._support = constraints.interval(low, high)
        super().__init__(full_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def _mirror_sign(self):
        # +1.0 where low <= loc, -1.0 where low > loc (mirror about loc).
        return jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)

    @lazy_property
    def _tail_prob_at_low(self):
        # cdf(low) when low <= loc; otherwise cdf(2 * loc - low), which the
        # original author documents as 1 - cdf(low).
        loc = self.base_dist.loc
        sign = self._mirror_sign()
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # cdf(high) when low <= loc; otherwise cdf(2 * loc - high), which the
        # original author documents as 1 - cdf(high).
        loc = self.base_dist.loc
        sign = self._mirror_sign()
        return self.base_dist.cdf(loc - sign * (loc - self.high))

    def sample(self, key, sample_shape=()):
        """
        Draw samples by inverse-CDF sampling between the two tail
        probabilities.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw.
        :return: an array of shape ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        u = random.uniform(key, sample_shape + self.batch_shape)
        # NB: a numerically more stable formula is used, relying on the base
        # distribution being symmetric about loc. The direct form
        #   A = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        # loses precision when low is large, so:
        #   if low < loc: A = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        #   otherwise:    A = 2 * loc
        #                     - icdf[(1 - u) * cdf(2 * loc - low)
        #                            + u * cdf(2 * loc - high)]
        loc = self.base_dist.loc
        sign = self._mirror_sign()
        prob = (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
        return (1 - sign) * loc + sign * self.base_dist.icdf(prob)

    @validate_sample
    def log_prob(self, value):
        """
        Base log density minus the log of the probability mass kept by the
        truncation.
        """
        # NB: a numerically more stable formula is used, relying on the base
        # distribution being symmetric about loc:
        #   if low < loc: cdf(high) - cdf(low) is used as-is
        #   if low > loc: cdf(high) - cdf(low)
        #                   = cdf(2 * loc - low) - cdf(2 * loc - high)
        sign = self._mirror_sign()
        retained_mass = sign * (self._tail_prob_at_high - self._tail_prob_at_low)
        return self.base_dist.log_prob(value) - jnp.log(retained_mass)

    def tree_flatten(self):
        """
        Flatten into ``(params, aux_data)``.

        When both bounds are Python ``int``/``float`` values they are stored
        in ``aux_data``; otherwise both are stored with the parameters.
        """
        base_flatten, base_aux = self.base_dist.tree_flatten()
        base_cls = type(self.base_dist)
        lower = self._support.lower_bound
        upper = self._support.upper_bound
        bounds_are_static = isinstance(lower, (int, float)) and isinstance(
            upper, (int, float)
        )
        if not bounds_are_static:
            return (base_flatten, self.low, self.high), (base_cls, base_aux)
        return base_flatten, (base_cls, base_aux, lower, upper)

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild an instance from the output of :meth:`tree_flatten`."""
        if len(aux_data) != 2:
            # The bounds were stored statically in aux_data.
            base_cls, base_aux, low, high = aux_data
            base_flatten = params
        else:
            base_cls, base_aux = aux_data
            base_flatten, low, high = params
        base_dist = base_cls.tree_unflatten(base_aux, base_flatten)
        return cls(base_dist, low=low, high=high)
def TruncatedDistribution(base_dist, low=None, high=None, validate_args=None):
    """
    A function to generate a truncated distribution.

    :param base_dist: The base distribution to be truncated. This should be a univariate
        distribution. Currently, only the following distributions are supported:
        Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
    :param low: the value which is used to truncate the base distribution from below.
        Set this parameter to None to not truncate from below.
    :param high: the value which is used to truncate the base distribution from above.
        Set this parameter to None to not truncate from above.
    :param validate_args: forwarded to the constructor of the truncated distribution.
        It is not used when both `low` and `high` are None.
    :return: `base_dist` itself when both bounds are None; otherwise a
        `LeftTruncatedDistribution` (only `low` given), a
        `RightTruncatedDistribution` (only `high` given), or a
        `TwoSidedTruncatedDistribution` (both given).
    """
    if low is None and high is None:
        # Nothing to truncate: hand back the base distribution untouched.
        return base_dist
    if high is None:
        return LeftTruncatedDistribution(
            base_dist, low=low, validate_args=validate_args
        )
    if low is None:
        return RightTruncatedDistribution(
            base_dist, high=high, validate_args=validate_args
        )
    return TwoSidedTruncatedDistribution(
        base_dist, low=low, high=high, validate_args=validate_args
    )
class TruncatedCauchy(LeftTruncatedDistribution):
    """
    A Cauchy distribution with location ``loc`` and scale ``scale``,
    truncated from below at ``low``.

    :param low: the lower truncation point.
    :param loc: location of the underlying Cauchy distribution.
    :param scale: scale of the underlying Cauchy distribution.
    :param validate_args: forwarded to ``LeftTruncatedDistribution``.
    """

    arg_constraints = {
        "low": constraints.real,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    reparametrized_params = ["low", "loc", "scale"]

    def __init__(self, low=0.0, loc=0.0, scale=1.0, validate_args=None):
        self.low, self.loc, self.scale = promote_shapes(low, loc, scale)
        base = Cauchy(self.loc, self.scale)
        super().__init__(base, low=self.low, validate_args=validate_args)

    @property
    def mean(self):
        # Reported as NaN for every batch element.
        return jnp.full(self.batch_shape, jnp.nan)

    @property
    def variance(self):
        # Reported as NaN for every batch element.
        return jnp.full(self.batch_shape, jnp.nan)

    def tree_flatten(self):
        """
        Flatten into ``((low, loc, scale), aux_data)``.

        ``aux_data`` is the lower bound of the support when that bound is a
        Python ``int``/``float``, and ``None`` otherwise.
        """
        bound = self._support.lower_bound
        static_bound = bound if isinstance(bound, (int, float)) else None
        return (self.low, self.loc, self.scale), static_bound

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild an instance from the output of :meth:`tree_flatten`."""
        rebuilt = cls(*params)
        if aux_data is not None:
            # Restore the support with the static Python bound.
            rebuilt._support = constraints.greater_than(aux_data)
        return rebuilt
class TruncatedNormal(LeftTruncatedDistribution):
    """
    A Normal distribution with location ``loc`` and scale ``scale``,
    truncated from below at ``low``.

    :param low: the lower truncation point.
    :param loc: location of the underlying Normal distribution.
    :param scale: scale of the underlying Normal distribution.
    :param validate_args: forwarded to ``LeftTruncatedDistribution``.
    """

    arg_constraints = {
        "low": constraints.real,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    reparametrized_params = ["low", "loc", "scale"]

    def __init__(self, low=0.0, loc=0.0, scale=1.0, validate_args=None):
        self.low, self.loc, self.scale = promote_shapes(low, loc, scale)
        super().__init__(
            Normal(self.loc, self.scale), low=self.low, validate_args=validate_args
        )

    def _density_at_low(self):
        # Truncated density evaluated at the boundary point `low`.
        #
        # This deliberately does not call `self.log_prob(self.low)`: that
        # method is wrapped by `validate_sample`, and `low` is not inside the
        # open support `greater_than(low)`, so with argument validation on the
        # boundary evaluation may be rejected or masked instead of returning
        # the density needed by the moment formulas.
        # NOTE(review): confirm against the `validate_sample` implementation.
        base = self.base_dist
        sign = jnp.where(base.loc >= self.low, 1.0, -1.0)
        # Same normaliser as LeftTruncatedDistribution.log_prob.
        log_retained_mass = jnp.log(
            sign * (self._tail_prob_at_high - self._tail_prob_at_low)
        )
        return jnp.exp(base.log_prob(self.low) - log_retained_mass)

    @property
    def mean(self):
        """``loc + scale**2 * p(low)``, with ``p`` the truncated density."""
        low_prob = self._density_at_low()
        return self.loc + low_prob * self.scale ** 2

    @property
    def variance(self):
        """
        ``scale**2 * (1 + (low - loc) * p(low) - (scale * p(low))**2)``,
        with ``p`` the truncated density.
        """
        low_prob = self._density_at_low()
        return (self.scale ** 2) * (
            1 + (self.low - self.loc) * low_prob - (low_prob * self.scale) ** 2
        )

    def tree_flatten(self):
        """
        Flatten into ``((low, loc, scale), aux_data)``.

        ``aux_data`` is the lower bound of the support when that bound is a
        Python ``int``/``float``, and ``None`` otherwise.
        """
        bound = self._support.lower_bound
        aux_data = bound if isinstance(bound, (int, float)) else None
        return (self.low, self.loc, self.scale), aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild an instance from the output of :meth:`tree_flatten`."""
        d = cls(*params)
        if aux_data is not None:
            # Restore the support with the static Python bound.
            d._support = constraints.greater_than(aux_data)
        return d
class TruncatedPolyaGamma(Distribution):
    """
    Parameter-free distribution supported on ``(0, truncation_point)``.

    Samples are built from a finite weighted sum of ``num_gamma_variates``
    unit-shape Gamma draws, clipped at ``truncation_point``. The log density
    is evaluated from the first ``num_log_prob_terms`` terms of an
    alternating series.
    NOTE(review): this looks like a truncated approximation of a
    Polya-Gamma PG(1, 0) variable; confirm against the intended reference.

    :param batch_shape: batch shape of the distribution.
    :param validate_args: forwarded to the ``Distribution`` constructor.
    """

    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    # An odd count means the alternating series ends on a positive term.
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(self, batch_shape=(), validate_args=None):
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """
        Draw samples.

        :param key: a JAX PRNG key.
        :param sample_shape: leading shape of the draw.
        :return: an array of shape ``sample_shape + batch_shape``.
        """
        assert is_prng_key(key)
        # Denominators (k + 0.5) ** 2 for k = 0 .. num_gamma_variates - 1.
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        # Sample dimensions come first, then batch dimensions, matching the
        # other distributions in this module.
        x = random.gamma(
            key, jnp.ones(sample_shape + self.batch_shape + (self.num_gamma_variates,))
        )
        x = jnp.sum(x / denom, axis=-1)
        # Positional bounds (no lower bound) are accepted by both older and
        # newer `jnp.clip` signatures, unlike the `a_max` keyword.
        return jnp.clip(x * (0.5 / jnp.pi ** 2), None, self.truncation_point)

    @validate_sample
    def log_prob(self, value):
        """
        Evaluate the log density from a finite alternating series.

        Term ``n`` is ``(2n + 1) * value**-1.5 * exp(-(2n + 1)**2 / (8 * value))``;
        odd-indexed terms are subtracted from even-indexed ones, and
        ``0.5 * log(2 * pi)`` is subtracted from the log of the result.
        """
        value = value[..., None]
        all_indices = jnp.arange(0, self.num_log_prob_terms)
        two_n_plus_one = 2.0 * all_indices + 1.0
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        even_terms = jnp.take(log_terms, all_indices[::2], axis=-1)
        odd_terms = jnp.take(log_terms, all_indices[1::2], axis=-1)
        sum_even = jnp.exp(logsumexp(even_terms, axis=-1))
        sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)

    def tree_flatten(self):
        """Flatten into ``((), batch_shape)``; there are no array parameters."""
        return (), self.batch_shape

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild an instance from the output of :meth:`tree_flatten`."""
        return cls(batch_shape=aux_data)