# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Union
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import (
Cauchy,
Laplace,
Logistic,
Normal,
SoftLaplace,
StudentT,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
clamp_probs,
lazy_property,
promote_shapes,
validate_sample,
)
from numpyro.util import is_prng_key
[docs]
class LeftTruncatedDistribution(Distribution):
    """A univariate distribution truncated from below at ``low``.

    For a base distribution with density ``f`` and CDF ``F`` the truncated density
    is ``f(x) / (1 - F(low))`` for ``x > low``. The implementation relies on the
    base distribution being symmetric around ``loc`` (so ``1 - F(x) = F(2 * loc - x)``)
    to keep tail computations accurate when ``low`` lies far to the right of ``loc``.

    :param base_dist: the base distribution; an instance of one of
        :attr:`supported_types` with real support.
    :param low: the lower truncation point.
    :param validate_args: whether to validate distribution arguments.
    """

    arg_constraints = {"low": constraints.real}
    reparametrized_params = ["low"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "_support")

    def __init__(
        self,
        base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
        low: ArrayLike = 0.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
        # Broadcast every parameter of the base distribution to the common batch shape.
        self.base_dist: Union[
            Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT
        ] = jax.tree.map(lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
        (self.low,) = promote_shapes(low, shape=batch_shape)
        self._support = constraints.greater_than(low)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high) = 1; otherwise returns 1 - cdf(high) = 0
        return jnp.where(self.low <= self.base_dist.loc, 1.0, 0.0)

    @lazy_property
    def _log_normalizer(self) -> ArrayLike:
        """Log of the probability mass kept by the truncation, ``log(1 - F(low))``."""
        log_cdf = getattr(self.base_dist, "log_cdf", None)
        if callable(log_cdf):
            # By symmetry 1 - F(low) = F(2 * loc - low); log_cdf evaluates this without
            # underflowing to log(0) when low lies far in the right tail.
            return log_cdf(2 * self.base_dist.loc - self.low)
        # Fall back to the (symmetric) tail probabilities when log_cdf is unavailable.
        sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
        return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    def _density_at_low(self) -> ArrayLike:
        """Truncated density evaluated at the boundary ``low``.

        Computed directly instead of via ``log_prob``: ``low`` is outside the open
        support ``x > low``, so with ``validate_args=True`` the ``validate_sample``
        wrapper would mask it to ``-inf``.
        """
        return jnp.exp(self.base_dist.log_prob(self.low) - self._log_normalizer)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        # Avoid u == 0, which would map to the boundary of the support.
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        # For low > loc the quantile is computed on the mirrored distribution and
        # reflected back around loc: 2 * loc - icdf(...).
        ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
            (1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high
        )
        return jnp.where(q < 0, jnp.nan, ppf)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # For left truncated distribution: CDF(x) = (F(x) - F(low)) / (1 - F(low))
        # where F is the base distribution CDF
        base_cdf_value = self.base_dist.cdf(value)
        base_cdf_low = self.base_dist.cdf(self.low)
        truncated_cdf = (base_cdf_value - base_cdf_low) / (1.0 - base_cdf_low)
        # Values below the truncation point have zero mass; clamp rounding noise to [0, 1].
        return jnp.where(value < self.low, 0.0, jnp.clip(truncated_cdf, 0.0, 1.0))

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        return self.base_dist.log_prob(value) - self._log_normalizer

    @property
    def mean(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at_low()
            return self.base_dist.loc + low_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def variance(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at_low()
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (low_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")

    @property
    def var(self) -> ArrayLike:
        """Backward-compatible alias of :attr:`variance`."""
        return self.variance
[docs]
class RightTruncatedDistribution(Distribution):
    """A univariate distribution truncated from above at ``high``.

    For a base distribution with density ``f`` and CDF ``F`` the truncated density
    is ``f(x) / F(high)`` for ``x < high``.

    :param base_dist: the base distribution; an instance of one of
        :attr:`supported_types` with real support.
    :param high: the upper truncation point.
    :param validate_args: whether to validate distribution arguments.
    """

    arg_constraints = {"high": constraints.real}
    reparametrized_params = ["high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "high", "_support")

    def __init__(
        self,
        base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
        high: ArrayLike = 0.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))
        # Broadcast every parameter of the base distribution to the common batch shape.
        self.base_dist: Union[
            Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT
        ] = jax.tree.map(lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.less_than(high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @lazy_property
    def _cdf_at_high(self) -> ArrayLike:
        return self.base_dist.cdf(self.high)

    @lazy_property
    def _log_cdf_at_high(self) -> ArrayLike:
        """Log of the probability mass kept by the truncation, ``log F(high)``."""
        log_cdf = getattr(self.base_dist, "log_cdf", None)
        if callable(log_cdf):
            # log_cdf stays finite where cdf(high) underflows to 0 (high far in the left tail).
            return log_cdf(self.high)
        return jnp.log(self._cdf_at_high)

    def _density_at_high(self) -> ArrayLike:
        """Truncated density evaluated at the boundary ``high``.

        Computed directly instead of via ``log_prob``: ``high`` is outside the open
        support ``x < high``, so with ``validate_args=True`` the ``validate_sample``
        wrapper would mask it to ``-inf``.
        """
        return jnp.exp(self.base_dist.log_prob(self.high) - self._log_cdf_at_high)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        # Avoid u == 0, which would map to -inf for unbounded base distributions.
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        ppf = self.base_dist.icdf(q * self._cdf_at_high)
        # q < 0 already yields NaN from base icdf; q > 1 would land inside the base
        # support, so it is masked explicitly.
        return jnp.where(q > 1, jnp.nan, ppf)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # For right truncated distribution: CDF(x) = F(x) / F(high)
        # where F is the base distribution CDF
        base_cdf_value = self.base_dist.cdf(value)
        truncated_cdf = base_cdf_value / self._cdf_at_high
        # Values above the truncation point have CDF 1; clamp rounding noise to [0, 1].
        return jnp.where(value > self.high, 1.0, jnp.clip(truncated_cdf, 0.0, 1.0))

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        return self.base_dist.log_prob(value) - self._log_cdf_at_high

    @property
    def mean(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            high_prob = self._density_at_high()
            return self.base_dist.loc - high_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def variance(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            high_prob = self._density_at_high()
            return (self.base_dist.scale**2) * (
                1
                - (self.high - self.base_dist.loc) * high_prob
                - (high_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")

    @property
    def var(self) -> ArrayLike:
        """Backward-compatible alias of :attr:`variance`."""
        return self.variance
[docs]
class TwoSidedTruncatedDistribution(Distribution):
    """A univariate distribution truncated to the interval ``[low, high]``.

    For a base distribution with density ``f`` and CDF ``F`` the truncated density
    is ``f(x) / (F(high) - F(low))`` on ``[low, high]``. The base distribution is
    assumed symmetric around ``loc``; when ``low > loc`` the interval is reflected
    around ``loc`` so that tail probabilities are evaluated where they are accurate.

    :param base_dist: the base distribution; an instance of one of
        :attr:`supported_types` with real support.
    :param low: the lower truncation point.
    :param high: the upper truncation point.
    :param validate_args: whether to validate distribution arguments.
    """

    arg_constraints = {
        "low": constraints.dependent(is_discrete=False, event_dim=0),
        "high": constraints.dependent(is_discrete=False, event_dim=0),
    }
    reparametrized_params = ["low", "high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "high", "_support")

    def __init__(
        self,
        base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
        low: ArrayLike = 0.0,
        high: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
        )
        # Broadcast every parameter of the base distribution to the common batch shape.
        self.base_dist: Union[
            Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT
        ] = jax.tree.map(lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
        (self.low,) = promote_shapes(low, shape=batch_shape)
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.interval(low, high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @lazy_property
    def _tail_prob_at_low(self) -> ArrayLike:
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self) -> ArrayLike:
        # if low < loc, returns cdf(high); otherwise returns 1 - cdf(high)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.high))

    @lazy_property
    def _log_diff_tail_probs(self) -> ArrayLike:
        """``log(F(high) - F(low))``, the log mass kept by the truncation."""
        # use log_cdf method, if available, to avoid inf's in log_prob
        # fall back to cdf, if log_cdf not available
        log_cdf = getattr(self.base_dist, "log_cdf", None)
        if callable(log_cdf):
            loc = self.base_dist.loc
            # When low > loc, reflect the interval around loc:
            #   F(high) - F(low) = F(2 * loc - low) - F(2 * loc - high)
            # so both log_cdf evaluations sit in the left tail instead of both
            # rounding to log(1) = 0 and cancelling.
            reflect = loc < self.low
            upper = jnp.where(reflect, 2 * loc - self.low, self.high)
            lower = jnp.where(reflect, 2 * loc - self.high, self.low)
            return logsumexp(
                a=jnp.stack([log_cdf(upper), log_cdf(lower)], axis=-1),
                axis=-1,
                b=jnp.array([1.0, -1.0]),  # subtract lower from upper
            )
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    def _standard_cauchy_moments(self) -> tuple[ArrayLike, ArrayLike]:
        """First and second moments of the standardized truncated Cauchy.

        With ``a = (low - loc) / scale`` and ``b = (high - loc) / scale``:
        ``E[Z] = log((1 + b^2) / (1 + a^2)) / (2 (atan b - atan a))`` and
        ``E[Z^2] = (b - a) / (atan b - atan a) - 1``.
        """
        loc, scale = self.base_dist.loc, self.base_dist.scale
        a = (self.low - loc) / scale
        b = (self.high - loc) / scale
        angle = jnp.arctan(b) - jnp.arctan(a)
        first = (jnp.log1p(b**2) - jnp.log1p(a**2)) / (2.0 * angle)
        second = (b - a) / angle - 1.0
        return first, second

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        # NB: we use a more numerically stable formula for a symmetric base distribution
        # A = icdf(cdf(low) + (cdf(high) - cdf(low)) * q) = icdf[(1 - q) * cdf(low) + q * cdf(high)]
        # will suffer by precision issues when low is large;
        # If low < loc:
        #     A = icdf[(1 - q) * cdf(low) + q * cdf(high)]
        # Else
        #     A = 2 * loc - icdf[(1 - q) * cdf(2*loc-low)) + q * cdf(2*loc - high)]
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
            clamp_probs((1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high)
        )
        return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # For two-sided truncated distribution: CDF(x) = (F(x) - F(low)) / (F(high) - F(low))
        # where F is the base distribution CDF
        base_cdf_value = self.base_dist.cdf(value)
        base_cdf_low = self.base_dist.cdf(self.low)
        base_cdf_high = self.base_dist.cdf(self.high)
        normalization = base_cdf_high - base_cdf_low
        truncated_cdf = (base_cdf_value - base_cdf_low) / normalization
        # Values outside the interval have CDF 0 / 1; clamp rounding noise inside.
        return jnp.where(
            value < self.low,
            0.0,
            jnp.where(value > self.high, 1.0, jnp.clip(truncated_cdf, 0.0, 1.0)),
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # NB: we use a more numerically stable formula for a symmetric base distribution
        # if low < loc
        #     cdf(high) - cdf(low) = as-is
        # if low > loc
        #     cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)
        return self.base_dist.log_prob(value) - self._log_diff_tail_probs

    @property
    def mean(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            # Truncating a Cauchy to a bounded interval makes its moments finite.
            first, _ = self._standard_cauchy_moments()
            return self.base_dist.loc + self.base_dist.scale * first
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def variance(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (self.high - self.base_dist.loc) * high_prob
                - ((low_prob - high_prob) * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            first, second = self._standard_cauchy_moments()
            return self.base_dist.scale**2 * (second - first**2)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")

    @property
    def var(self) -> ArrayLike:
        """Backward-compatible alias of :attr:`variance`."""
        return self.variance
[docs]
def TruncatedDistribution(
    base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
    low: Optional[ArrayLike] = None,
    high: Optional[ArrayLike] = None,
    *,
    validate_args: Optional[bool] = None,
):
    """
    A function to generate a truncated distribution.

    :param base_dist: The base distribution to be truncated. This should be a univariate
        distribution. Currently, only the following distributions are supported:
        Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
    :param low: the value which is used to truncate the base distribution from below.
        Setting this parameter to None to not truncate from below.
    :param high: the value which is used to truncate the base distribution from above.
        Setting this parameter to None to not truncate from above.
    :param validate_args: whether to validate the arguments of the truncated
        distribution. Unused when both ``low`` and ``high`` are None.
    :return: ``base_dist`` unchanged when no bound is given; otherwise a
        :class:`LeftTruncatedDistribution`, :class:`RightTruncatedDistribution`
        or :class:`TwoSidedTruncatedDistribution`.
    """
    if low is None and high is None:
        return base_dist
    if high is None:
        return LeftTruncatedDistribution(
            base_dist, low=low, validate_args=validate_args
        )
    if low is None:
        return RightTruncatedDistribution(
            base_dist, high=high, validate_args=validate_args
        )
    return TwoSidedTruncatedDistribution(
        base_dist, low=low, high=high, validate_args=validate_args
    )
[docs]
def TruncatedCauchy(
    loc: ArrayLike = 0.0,
    scale: ArrayLike = 1.0,
    *,
    low: Optional[ArrayLike] = None,
    high: Optional[ArrayLike] = None,
    validate_args: Optional[bool] = None,
):
    """Cauchy distribution with location ``loc`` and scale ``scale``, truncated.

    :param loc: location of the base Cauchy distribution.
    :param scale: scale of the base Cauchy distribution.
    :param low: optional lower truncation point (None means no lower bound).
    :param high: optional upper truncation point (None means no upper bound).
    :param validate_args: forwarded to :func:`TruncatedDistribution`.
    :return: the result of :func:`TruncatedDistribution` on ``Cauchy(loc, scale)``.
    """
    base = Cauchy(loc, scale)
    return TruncatedDistribution(base, low=low, high=high, validate_args=validate_args)
[docs]
def TruncatedNormal(
    loc: ArrayLike = 0.0,
    scale: ArrayLike = 1.0,
    *,
    low: Optional[ArrayLike] = None,
    high: Optional[ArrayLike] = None,
    validate_args: Optional[bool] = None,
):
    """Normal distribution with mean ``loc`` and standard deviation ``scale``, truncated.

    :param loc: mean of the base Normal distribution.
    :param scale: standard deviation of the base Normal distribution.
    :param low: optional lower truncation point (None means no lower bound).
    :param high: optional upper truncation point (None means no upper bound).
    :param validate_args: forwarded to :func:`TruncatedDistribution`.
    :return: the result of :func:`TruncatedDistribution` on ``Normal(loc, scale)``.
    """
    base = Normal(loc, scale)
    return TruncatedDistribution(base, low=low, high=high, validate_args=validate_args)
[docs]
class TruncatedPolyaGamma(Distribution):
    """Polya-Gamma ``PG(1, 0)`` distribution truncated to ``[0, truncation_point]``.

    ``sample`` sums the first ``num_gamma_variates`` terms of the series
    ``1 / (2 pi^2) * sum_k g_k / (k - 1/2)^2`` with ``g_k ~ Gamma(1, 1)`` and clips
    the result at ``truncation_point``. ``log_prob`` evaluates the first
    ``num_log_prob_terms`` terms of the alternating series
    ``sum_n (-1)^n (2n + 1) / sqrt(2 pi x^3) * exp(-(2n + 1)^2 / (8 x))``.

    :param batch_shape: batch shape of the distribution.
    :param validate_args: whether to validate distribution arguments.
    """

    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    # An odd number of terms ends the alternating series on a positive term.
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(
        self, batch_shape: tuple[int, ...] = (), *, validate_args: Optional[bool] = None
    ):
        super(TruncatedPolyaGamma, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        # Sample dimensions precede batch dimensions (sample_shape + batch_shape).
        shape = sample_shape + self.batch_shape + (self.num_gamma_variates,)
        x = random.gamma(key, jnp.ones(shape))
        x = jnp.sum(x / denom, axis=-1)
        return jnp.clip(x * (0.5 / jnp.pi**2), None, self.truncation_point)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # The new trailing axis indexes the series terms n = 0..num_log_prob_terms-1;
        # expand_dims also accepts Python scalars.
        value = jnp.expand_dims(value, -1)
        two_n_plus_one = 2.0 * jnp.arange(0, self.num_log_prob_terms) + 1.0
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        # Positive (even n) and negative (odd n) terms are accumulated separately in
        # log space before the final subtraction.
        sum_even = jnp.exp(logsumexp(log_terms[..., ::2], axis=-1))
        sum_odd = jnp.exp(logsumexp(log_terms[..., 1::2], axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)
[docs]
class DoublyTruncatedPowerLaw(Distribution):
    r"""Power law distribution with :math:`\alpha` index, and lower and upper bounds.

    We can define the power law distribution as,

    .. math::

        f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)},

    where, :math:`a` and :math:`b` are the lower and upper bounds respectively,
    and :math:`Z(\alpha, a, b)` is the normalization constant. It is defined as,

    .. math::

        Z(\alpha, a, b) = \begin{cases}
            \log(b) - \log(a) & \text{if } \alpha = -1, \\
            \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}.
        \end{cases}

    :param alpha: index of the power law distribution
    :param low: lower bound of the distribution
    :param high: upper bound of the distribution
    """

    arg_constraints = {
        "alpha": constraints.real,
        "low": constraints.greater_than_eq(0),
        "high": constraints.greater_than(0),
    }
    reparametrized_params = ["alpha", "low", "high"]
    pytree_aux_fields = ("_support",)
    pytree_data_fields = ("alpha", "low", "high")

    def __init__(
        self,
        alpha: ArrayLike,
        low: ArrayLike,
        high: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.alpha, self.low, self.high = promote_shapes(alpha, low, high)
        self._support = constraints.interval(low, high)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(alpha), jnp.shape(low), jnp.shape(high)
        )
        super(DoublyTruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Logarithm of the probability density.

        For :math:`\alpha \neq -1` the density is

        .. math::

            \frac{(\alpha + 1)x^\alpha}{b^{\alpha + 1} - a^{\alpha + 1}}

        and for :math:`\alpha = -1` it is

        .. math::

            \frac{x^\alpha}{\log(b) - \log(a)}

        Gradients come from a custom JVP of the log-density. The derivative with
        respect to :math:`\alpha` at :math:`\alpha = -1` is approximated by averaging
        the :math:`\alpha \neq -1` formula at :math:`-1 \pm 10^{-3}`.
        """

        @jax.custom_jvp
        def f(
            x: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike
        ) -> ArrayLike:
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            # Substitute 0 for alpha == -1 so the unselected branch stays finite.
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def neq_neg1_fn() -> ArrayLike:
                one_more_alpha = 1.0 + neq_neg1_alpha
                return jnp.log(
                    jnp.power(x, neq_neg1_alpha)
                    * (one_more_alpha)
                    / (jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha))
                )

            def eq_neg1_fn() -> ArrayLike:
                return -jnp.log(x) - jnp.log(jnp.log(high) - jnp.log(low))

            return jnp.where(neq_neg1_mask, neq_neg1_fn(), eq_neg1_fn())

        @f.defjvp
        def f_jvp(
            primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
            tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
        ) -> tuple[ArrayLike, ArrayLike]:
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            log_x = jnp.log(x)

            # Mask and alpha values
            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            primal_out = f(*primals)

            # Alpha tangent with approximation
            # Variable part for all values alpha unequal -1
            def alpha_tangent_variable(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return jnp.reciprocal(one_more_alpha) + (
                    low_pow_one_more_alpha * log_low
                    - high_pow_one_more_alpha * log_high
                ) / (high_pow_one_more_alpha - low_pow_one_more_alpha)

            # Alpha tangent
            alpha_tangent = jnp.where(
                neq_neg1_mask,
                log_x + alpha_tangent_variable(neq_neg1_alpha),
                # Approximate derivative with right and lefthand approximation
                log_x
                + (
                    alpha_tangent_variable(alpha - delta_eq_neg1)
                    + alpha_tangent_variable(alpha + delta_eq_neg1)
                )
                * 0.5,
            )

            # Low and high tangents for alpha != -1, i.e. derivatives of
            # -log(high**(1 + alpha) - low**(1 + alpha)):
            #   d/dlow  =  (1 + alpha) * low**alpha  / (high**(1+alpha) - low**(1+alpha))
            #   d/dhigh = -(1 + alpha) * high**alpha / (high**(1+alpha) - low**(1+alpha))
            one_more_alpha = 1.0 + neq_neg1_alpha
            change = jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha)
            low_tangent_neq_neg1 = (
                one_more_alpha * jnp.power(low, neq_neg1_alpha) / change
            )
            high_tangent_neq_neg1 = (
                -one_more_alpha * jnp.power(high, neq_neg1_alpha) / change
            )

            # Low and high tangents for alpha == -1, i.e. derivatives of
            # -log(log(high) - log(low)).
            log_range = log_high - log_low
            low_tangent_eq_neg1 = jnp.reciprocal(low * log_range)
            high_tangent_eq_neg1 = -jnp.reciprocal(high * log_range)

            # High and low tangents
            low_tangent = jnp.where(
                neq_neg1_mask, low_tangent_neq_neg1, low_tangent_eq_neg1
            )
            high_tangent = jnp.where(
                neq_neg1_mask, high_tangent_neq_neg1, high_tangent_eq_neg1
            )

            # Final tangents; d/dx log(x**alpha) = alpha / x in both cases.
            tangent_out = (
                alpha / x * x_t
                + alpha_tangent * alpha_t
                + low_tangent * low_t
                + high_tangent * high_t
            )
            return primal_out, tangent_out

        return f(value, self.alpha, self.low, self.high)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative distribution function.

        For :math:`\alpha \neq -1`:

        .. math::

            \frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{\alpha + 1}}

        For :math:`\alpha = -1`:

        .. math::

            \frac{\log(x) - \log(a)}{\log(b) - \log(a)}

        Gradients come from a custom JVP. The derivative with respect to
        :math:`\alpha` at :math:`\alpha = -1` is approximated by averaging the
        :math:`\alpha \neq -1` formula at :math:`-1 \pm 10^{-3}`.
        """

        @jax.custom_jvp
        def f(
            x: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike
        ) -> ArrayLike:
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def cdf_when_alpha_neq_neg1() -> ArrayLike:
                one_more_alpha = 1.0 + neq_neg1_alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                return (jnp.power(x, one_more_alpha) - low_pow_one_more_alpha) / (
                    jnp.power(high, one_more_alpha) - low_pow_one_more_alpha
                )

            def cdf_when_alpha_eq_neg1() -> ArrayLike:
                return jnp.log(x / low) / jnp.log(high / low)

            cdf_val = jnp.where(
                neq_neg1_mask,
                cdf_when_alpha_neq_neg1(),
                cdf_when_alpha_eq_neg1(),
            )
            return jnp.clip(cdf_val, 0.0, 1.0)

        @f.defjvp
        def f_jvp(
            primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
            tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
        ) -> tuple[ArrayLike, ArrayLike]:
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            log_x = jnp.log(x)

            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            # Calculate primal
            primal_out = f(*primals)

            # Tangents for alpha not equals -1
            def x_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                return (one_more_alpha * jnp.power(x, alpha)) / (
                    jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha)
                )

            def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)

                term1 = (
                    x_pow_one_more_alpha * log_x - low_pow_one_more_alpha * log_low
                ) / (high_pow_one_more_alpha - low_pow_one_more_alpha)
                term2 = (
                    (x_pow_one_more_alpha - low_pow_one_more_alpha)
                    * (
                        high_pow_one_more_alpha * log_high
                        - low_pow_one_more_alpha * log_low
                    )
                ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha)

                return term1 - term2

            def low_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)
                change = high_pow_one_more_alpha - low_pow_one_more_alpha

                term2 = one_more_alpha * jnp.power(low, alpha) / change
                term1 = term2 * (x_pow_one_more_alpha - low_pow_one_more_alpha) / change

                return term1 - term2

            def high_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)

                return -(
                    one_more_alpha
                    * jnp.power(high, alpha)
                    * (x_pow_one_more_alpha - low_pow_one_more_alpha)
                ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha)

            # Tangents for alpha equals -1
            def x_eq_neg1() -> ArrayLike:
                return jnp.reciprocal(x * (log_high - log_low))

            def low_eq_neg1() -> ArrayLike:
                return (log_x - log_low) / (
                    jnp.square(log_high - log_low) * low
                ) - jnp.reciprocal((log_high - log_low) * low)

            def high_eq_neg1() -> ArrayLike:
                # Raising high enlarges the denominator, so the CDF decreases.
                return -(log_x - log_low) / (jnp.square(log_high - log_low) * high)

            # Including approximation for alpha = -1
            tangent_out = (
                jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), x_eq_neg1()) * x_t
                + jnp.where(
                    neq_neg1_mask,
                    alpha_neq_neg1(neq_neg1_alpha),
                    (
                        alpha_neq_neg1(alpha - delta_eq_neg1)
                        + alpha_neq_neg1(alpha + delta_eq_neg1)
                    )
                    * 0.5,
                )
                * alpha_t
                + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1())
                * low_t
                + jnp.where(
                    neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1()
                )
                * high_t
            )

            return primal_out, tangent_out

        return f(value, self.alpha, self.low, self.high)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse cumulative distribution function.

        For :math:`\alpha \neq -1`:

        .. math::

            \left(a^{1 + \alpha} + q (b^{1 + \alpha} - a^{1 + \alpha})\right)^{\frac{1}{1 + \alpha}}

        For :math:`\alpha = -1`:

        .. math::

            a \left(\frac{b}{a}\right)^{q}

        Gradients come from a custom JVP. The derivative with respect to
        :math:`\alpha` at :math:`\alpha = -1` is approximated by averaging the
        :math:`\alpha \neq -1` formula at :math:`-1 \pm 10^{-3}`.
        """

        @jax.custom_jvp
        def f(
            q: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike
        ) -> ArrayLike:
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def icdf_alpha_neq_neg1() -> ArrayLike:
                one_more_alpha = 1.0 + neq_neg1_alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return jnp.power(
                    low_pow_one_more_alpha
                    + q * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                    jnp.reciprocal(one_more_alpha),
                )

            def icdf_alpha_eq_neg1() -> ArrayLike:
                return jnp.power(high / low, q) * low

            icdf_val = jnp.where(
                neq_neg1_mask,
                icdf_alpha_neq_neg1(),
                icdf_alpha_eq_neg1(),
            )
            return icdf_val

        @f.defjvp
        def f_jvp(
            primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
            tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
        ) -> tuple[ArrayLike, ArrayLike]:
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            high_over_low = jnp.divide(high, low)

            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            primal_out = f(*primals)

            # Tangents for alpha not equal -1
            def x_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                change = high_pow_one_more_alpha - low_pow_one_more_alpha
                return (
                    change
                    * jnp.power(
                        low_pow_one_more_alpha + x * change,
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                ) / one_more_alpha

            def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                factor0 = low_pow_one_more_alpha + x * (
                    high_pow_one_more_alpha - low_pow_one_more_alpha
                )
                term1 = jnp.power(factor0, jnp.reciprocal(one_more_alpha))
                term2 = (
                    low_pow_one_more_alpha * log_low
                    + x
                    * (
                        high_pow_one_more_alpha * log_high
                        - low_pow_one_more_alpha * log_low
                    )
                ) / (one_more_alpha * factor0)
                term3 = jnp.log(factor0) / jnp.square(one_more_alpha)

                return term1 * (term2 - term3)

            def low_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return (
                    (1.0 - x)
                    * jnp.power(low, alpha)
                    * jnp.power(
                        low_pow_one_more_alpha
                        + x * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                )

            def high_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return (
                    x
                    * jnp.power(high, alpha)
                    * jnp.power(
                        low_pow_one_more_alpha
                        + x * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                )

            # Tangents for alpha equals -1
            def dx_eq_neg1() -> ArrayLike:
                return low * jnp.power(high_over_low, x) * (log_high - log_low)

            def low_eq_neg1() -> ArrayLike:
                return (
                    jnp.power(high_over_low, x)
                    - (high * x * jnp.power(high_over_low, x - 1)) / low
                )

            def high_eq_neg1() -> ArrayLike:
                return x * jnp.power(high_over_low, x - 1)

            # Including approximation for alpha = -1
            tangent_out = (
                jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), dx_eq_neg1()) * x_t
                + jnp.where(
                    neq_neg1_mask,
                    alpha_neq_neg1(neq_neg1_alpha),
                    (
                        alpha_neq_neg1(alpha - delta_eq_neg1)
                        + alpha_neq_neg1(alpha + delta_eq_neg1)
                    )
                    * 0.5,
                )
                * alpha_t
                + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1())
                * low_t
                + jnp.where(
                    neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1()
                )
                * high_t
            )

            return primal_out, tangent_out

        return f(q, self.alpha, self.low, self.high)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        # Inverse-transform sampling.
        u = random.uniform(key, sample_shape + self.batch_shape)
        samples = self.icdf(u)
        return samples
[docs]
class LowerTruncatedPowerLaw(Distribution):
    r"""Lower truncated power law distribution with :math:`\alpha` index.

    We can define the power law distribution as,

    .. math::

        f(x; \alpha, a) = (-\alpha-1)a^{-\alpha - 1}x^{\alpha},
        \qquad x \geq a, \qquad \alpha < -1,

    where, :math:`a` is the lower bound. The cdf of the distribution is given by,

    .. math::

        F(x; \alpha, a) = 1 - \left(\frac{x}{a}\right)^{1+\alpha}.

    The k-th moment of the distribution is given by,

    .. math::

        E[X^k] = \begin{cases}
            \frac{-\alpha-1}{-\alpha-1-k}a^k & \text{if } k < -\alpha-1, \\
            \infty & \text{otherwise}.
        \end{cases}

    :param alpha: index of the power law distribution
    :param low: lower bound of the distribution
    """

    arg_constraints = {
        "alpha": constraints.less_than(-1.0),
        "low": constraints.greater_than(0.0),
    }
    reparametrized_params = ["alpha", "low"]
    pytree_aux_fields = ("_support",)

    def __init__(
        self, alpha: ArrayLike, low: ArrayLike, *, validate_args: Optional[bool] = None
    ):
        self.alpha, self.low = promote_shapes(alpha, low)
        batch_shape = lax.broadcast_shapes(jnp.shape(alpha), jnp.shape(low))
        self._support = constraints.greater_than(low)
        super(LowerTruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # log[(-alpha - 1) * low**(-alpha - 1) * value**alpha]
        one_more_alpha = 1.0 + self.alpha
        return (
            self.alpha * jnp.log(value)
            + jnp.log(-one_more_alpha)
            - one_more_alpha * jnp.log(self.low)
        )

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # 1 - (value / low)**(1 + alpha) above the bound, 0 at or below it.
        cdf_val = jnp.where(
            jnp.less_equal(value, self.low),
            jnp.zeros_like(value),
            1.0 - jnp.power(value / self.low, 1.0 + self.alpha),
        )
        return cdf_val

    def icdf(self, q: ArrayLike) -> ArrayLike:
        # Quantiles are only defined for q in [0, 1]; NaN inputs stay NaN.
        nan_mask = jnp.logical_or(jnp.isnan(q), jnp.less(q, 0.0))
        nan_mask = jnp.logical_or(nan_mask, jnp.greater(q, 1.0))
        return jnp.where(
            nan_mask,
            jnp.nan,
            self.low * jnp.power(1.0 - q, jnp.reciprocal(1.0 + self.alpha)),
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        # Inverse-transform sampling.
        u = random.uniform(key, sample_shape + self.batch_shape)
        samples = self.icdf(u)
        return samples