Source code for numpyro.distributions.truncated

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from typing import Optional, Union

import jax
from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import (
    Cauchy,
    Laplace,
    Logistic,
    Normal,
    SoftLaplace,
    StudentT,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
    clamp_probs,
    lazy_property,
    promote_shapes,
    validate_sample,
)
from numpyro.util import is_prng_key


class LeftTruncatedDistribution(Distribution):
    """
    A base distribution truncated from below at ``low``; the support is
    ``(low, inf)``.

    When ``low`` lies above the base location, tail probabilities are evaluated
    in a reflected frame (``x -> 2 * loc - x``) so that they are lower-tail
    probabilities and do not lose precision as ``cdf(low) -> 1``. This relies on
    the base distribution being symmetric about ``loc``, which is the case for
    every type in :attr:`supported_types`.

    :param base_dist: the distribution to truncate; must be an instance of one
        of :attr:`supported_types` with real support.
    :param low: the lower truncation point.
    """

    arg_constraints = {"low": constraints.real}
    reparametrized_params = ["low"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "_support")

    def __init__(
        self,
        base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
        low: ArrayLike = 0.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
        # Broadcast every parameter of the base distribution to the joint batch shape.
        self.base_dist: Union[
            Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT
        ] = jax.tree.map(lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
        (self.low,) = promote_shapes(low, shape=batch_shape)
        self._support = constraints.greater_than(low)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high) = 1; otherwise returns 1 - cdf(high) = 0
        return jnp.where(self.low <= self.base_dist.loc, 1.0, 0.0)

    def _log_normalizer(self) -> ArrayLike:
        # log P(X > low) under the base distribution, evaluated in the reflected
        # frame so it stays finite when low is far in the upper tail.
        sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
        return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    def _density_at_low(self) -> ArrayLike:
        # Truncated density at the boundary, computed without going through
        # ``log_prob``: the support ``greater_than(low)`` is strict, so with
        # validate_args=True the @validate_sample wrapper would mask
        # log_prob(low) to -inf and mean/var would silently be wrong.
        return jnp.exp(self.base_dist.log_prob(self.low) - self._log_normalizer())

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        # Inverse-CDF sampling; u is kept strictly positive.
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
            (1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high
        )
        return jnp.where(q < 0, jnp.nan, ppf)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # CDF(x) = (F(x) - F(low)) / (1 - F(low)) for x >= low and 0 below low,
        # where F is the base CDF. It is evaluated in the same reflected frame as
        # ``icdf``: with G(x) = F(loc - sign * (loc - x)) the ratio equals
        # (G(x) - G(low)) / (G(inf) - G(low)), which avoids the 0/0 produced by
        # the direct formula when F(low) rounds to 1.
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        reflected_cdf = self.base_dist.cdf(loc - sign * (loc - value))
        truncated_cdf = (reflected_cdf - self._tail_prob_at_low) / (
            self._tail_prob_at_high - self._tail_prob_at_low
        )
        # Clamp to [0, 1] and return 0 below the truncation point.
        return jnp.where(value < self.low, 0.0, jnp.clip(truncated_cdf, 0.0, 1.0))

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        return self.base_dist.log_prob(value) - self._log_normalizer()

    @property
    def mean(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at_low()
            return self.base_dist.loc + low_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at_low()
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (low_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
class RightTruncatedDistribution(Distribution):
    """
    A base distribution truncated from above at ``high``; the support is
    ``(-inf, high)`` and the normalizing mass is the base ``cdf(high)``.

    :param base_dist: the distribution to truncate; must be an instance of one
        of :attr:`supported_types` with real support.
    :param high: the upper truncation point.
    """

    arg_constraints = {"high": constraints.real}
    reparametrized_params = ["high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "high", "_support")

    def __init__(
        self,
        base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
        high: ArrayLike = 0.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))
        # Broadcast every parameter of the base distribution to the joint batch shape.
        self.base_dist: Union[
            Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT
        ] = jax.tree.map(lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.less_than(high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @lazy_property
    def _cdf_at_high(self) -> ArrayLike:
        return self.base_dist.cdf(self.high)

    def _density_at_high(self) -> ArrayLike:
        # Truncated density at the boundary, computed without going through
        # ``log_prob``: the support ``less_than(high)`` is strict, so with
        # validate_args=True the @validate_sample wrapper would mask
        # log_prob(high) to -inf and mean/var would silently be wrong.
        return jnp.exp(self.base_dist.log_prob(self.high) - jnp.log(self._cdf_at_high))

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        # Inverse-CDF sampling; u is kept strictly positive.
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        ppf = self.base_dist.icdf(q * self._cdf_at_high)
        return jnp.where(q > 1, jnp.nan, ppf)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # For right truncated distribution: CDF(x) = F(x) / F(high)
        # where F is the base distribution CDF
        base_cdf_value = self.base_dist.cdf(value)
        base_cdf_high = self._cdf_at_high
        truncated_cdf = base_cdf_value / base_cdf_high
        # Clamp to [0, 1] and return 1 above the truncation point.
        return jnp.where(value > self.high, 1.0, jnp.clip(truncated_cdf, 0.0, 1.0))

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        return self.base_dist.log_prob(value) - jnp.log(self._cdf_at_high)

    @property
    def mean(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            high_prob = self._density_at_high()
            return self.base_dist.loc - high_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            high_prob = self._density_at_high()
            return (self.base_dist.scale**2) * (
                1
                - (self.high - self.base_dist.loc) * high_prob
                - (high_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
class TwoSidedTruncatedDistribution(Distribution):
    """
    A base distribution truncated to the interval ``[low, high]``.

    When ``low`` lies above the base location, tail probabilities are evaluated
    in a reflected frame (``x -> 2 * loc - x``) so that they are lower-tail
    probabilities. This keeps ``cdf``, ``icdf`` and ``log_prob`` accurate for
    windows deep in the upper tail. It relies on the base distribution being
    symmetric about ``loc``, which holds for every type in
    :attr:`supported_types`.

    :param base_dist: the distribution to truncate; must be an instance of one
        of :attr:`supported_types` with real support.
    :param low: the lower truncation point.
    :param high: the upper truncation point.
    """

    arg_constraints = {
        "low": constraints.dependent(is_discrete=False, event_dim=0),
        "high": constraints.dependent(is_discrete=False, event_dim=0),
    }
    reparametrized_params = ["low", "high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "high", "_support")

    def __init__(
        self,
        base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
        low: ArrayLike = 0.0,
        high: ArrayLike = 1.0,
        *,
        validate_args: Optional[bool] = None,
    ):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
        )
        # Broadcast every parameter of the base distribution to the joint batch shape.
        self.base_dist: Union[
            Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT
        ] = jax.tree.map(lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
        (self.low,) = promote_shapes(low, shape=batch_shape)
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.interval(low, high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @lazy_property
    def _tail_prob_at_low(self) -> ArrayLike:
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self) -> ArrayLike:
        # if low < loc, returns cdf(high); otherwise returns 1 - cdf(high)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.high))

    @lazy_property
    def _log_diff_tail_probs(self) -> ArrayLike:
        # log(cdf(high) - cdf(low)). Both bounds are first mapped into the
        # reflected frame, so the two terms are lower-tail probabilities and
        # sign * (tail_high - tail_low) > 0.
        # When the base distribution provides log_cdf, use it to avoid inf's in
        # log_prob far in the lower tail. Doing this on the *unreflected* bounds
        # cancels catastrophically in the upper tail, because both log_cdf values
        # round to ~0 there. Fall back to cdf if log_cdf is not available.
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        log_cdf = getattr(self.base_dist, "log_cdf", None)
        if callable(log_cdf):
            log_tail_low = log_cdf(loc - sign * (loc - self.low))
            log_tail_high = log_cdf(loc - sign * (loc - self.high))
            return logsumexp(
                a=jnp.stack([log_tail_high, log_tail_low], axis=-1),
                axis=-1,
                b=jnp.stack([sign, -sign], axis=-1),  # sign * (high - low)
            )
        return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    def _density_at(self, x: ArrayLike) -> ArrayLike:
        # Truncated density at a boundary point, bypassing the @validate_sample
        # wrapper of log_prob.
        return jnp.exp(self.base_dist.log_prob(x) - self._log_diff_tail_probs)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        # Inverse-CDF sampling; u is kept strictly positive.
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        # NB: we use a more numerically stable formula for a symmetric base distribution
        # A = icdf(cdf(low) + (cdf(high) - cdf(low)) * q) = icdf[(1 - q) * cdf(low) + q * cdf(high)]
        # will suffer by precision issues when low is large;
        # If low < loc:
        #   A = icdf[(1 - q) * cdf(low) + q * cdf(high)]
        # Else
        #   A = 2 * loc - icdf[(1 - q) * cdf(2*loc-low)) + q * cdf(2*loc - high)]
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
            clamp_probs((1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high)
        )
        # clamp_probs would hide out-of-range q, so flag it explicitly.
        return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # CDF(x) = (F(x) - F(low)) / (F(high) - F(low)) inside [low, high], where
        # F is the base CDF. It is evaluated in the reflected frame used by
        # ``icdf``: with G(x) = F(loc - sign * (loc - x)) the ratio equals
        # (G(x) - G(low)) / (G(high) - G(low)), which avoids 0/0 when both bounds
        # lie where F rounds to 1.
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        reflected_cdf = self.base_dist.cdf(loc - sign * (loc - value))
        truncated_cdf = (reflected_cdf - self._tail_prob_at_low) / (
            self._tail_prob_at_high - self._tail_prob_at_low
        )
        # 0 below low, 1 above high, clamped truncated CDF in between.
        return jnp.where(
            value < self.low,
            0.0,
            jnp.where(value > self.high, 1.0, jnp.clip(truncated_cdf, 0.0, 1.0)),
        )

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # NB: we use a more numerically stable formula for a symmetric base distribution
        # if low < loc
        #   cdf(high) - cdf(low) = as-is
        # if low > loc
        #   cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)
        return self.base_dist.log_prob(value) - self._log_diff_tail_probs

    @property
    def mean(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at(self.low)
            high_prob = self._density_at(self.high)
            return self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self) -> ArrayLike:
        if isinstance(self.base_dist, Normal):
            low_prob = self._density_at(self.low)
            high_prob = self._density_at(self.high)
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (self.high - self.base_dist.loc) * high_prob
                - ((low_prob - high_prob) * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
def TruncatedDistribution(
    base_dist: Union[Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT],
    low: Optional[ArrayLike] = None,
    high: Optional[ArrayLike] = None,
    *,
    validate_args: Optional[bool] = None,
):
    """
    Build a truncated version of ``base_dist``.

    The result depends on which bounds are supplied:

    - neither bound: ``base_dist`` is returned unchanged;
    - only ``low``: a :class:`LeftTruncatedDistribution`;
    - only ``high``: a :class:`RightTruncatedDistribution`;
    - both: a :class:`TwoSidedTruncatedDistribution`.

    :param base_dist: The base distribution to be truncated. This should be a
        univariate distribution. Currently, only the following distributions are
        supported: Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
    :param low: the value used to truncate the base distribution from below.
        Pass None to leave it untruncated from below.
    :param high: the value used to truncate the base distribution from above.
        Pass None to leave it untruncated from above.
    """
    has_low = low is not None
    has_high = high is not None

    if has_low and has_high:
        return TwoSidedTruncatedDistribution(
            base_dist, low=low, high=high, validate_args=validate_args
        )
    if has_low:
        return LeftTruncatedDistribution(
            base_dist, low=low, validate_args=validate_args
        )
    if has_high:
        return RightTruncatedDistribution(
            base_dist, high=high, validate_args=validate_args
        )
    return base_dist
def TruncatedCauchy(
    loc: ArrayLike = 0.0,
    scale: ArrayLike = 1.0,
    *,
    low: Optional[ArrayLike] = None,
    high: Optional[ArrayLike] = None,
    validate_args: Optional[bool] = None,
):
    """
    ``Cauchy(loc, scale)`` truncated below at ``low`` and/or above at ``high``.

    Delegates to :func:`TruncatedDistribution`; a bound left as None is not
    applied, so with neither bound the plain Cauchy distribution is returned.
    """
    base = Cauchy(loc, scale)
    return TruncatedDistribution(base, low=low, high=high, validate_args=validate_args)
def TruncatedNormal(
    loc: ArrayLike = 0.0,
    scale: ArrayLike = 1.0,
    *,
    low: Optional[ArrayLike] = None,
    high: Optional[ArrayLike] = None,
    validate_args: Optional[bool] = None,
):
    """
    ``Normal(loc, scale)`` truncated below at ``low`` and/or above at ``high``.

    Delegates to :func:`TruncatedDistribution`; a bound left as None is not
    applied, so with neither bound the plain Normal distribution is returned.
    """
    base = Normal(loc, scale)
    return TruncatedDistribution(base, low=low, high=high, validate_args=validate_args)
class TruncatedPolyaGamma(Distribution):
    """
    Polya-Gamma ``PG(1, 0)`` variable truncated to ``[0, truncation_point]``.

    ``sample`` uses the series ``sum_k g_k / (2 pi^2 (k - 1/2)^2)`` with
    ``g_k ~ Gamma(1, 1)``, truncated to ``num_gamma_variates`` terms and clipped
    at ``truncation_point``. ``log_prob`` evaluates the alternating-series
    density with ``num_log_prob_terms`` terms.

    :param batch_shape: batch shape of the distribution.
    """

    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    # An odd number of terms makes the alternating series end on an even-index
    # (positive) term; presumably this keeps sum_even - sum_odd > 0 in log_prob.
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(
        self, batch_shape: tuple[int, ...] = (), *, validate_args: Optional[bool] = None
    ):
        super(TruncatedPolyaGamma, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        # (k - 1/2)^2 for k = 1..num_gamma_variates
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        # Draws are laid out as sample_shape + batch_shape, the same convention
        # used by every other sampler in this module.
        x = random.gamma(
            key, jnp.ones(sample_shape + self.batch_shape + (self.num_gamma_variates,))
        )
        x = jnp.sum(x / denom, axis=-1)
        return jnp.clip(x * (0.5 / jnp.pi**2), None, self.truncation_point)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        value = value[..., None]
        all_indices = jnp.arange(0, self.num_log_prob_terms)
        two_n_plus_one = 2.0 * all_indices + 1.0
        # log of the n-th series term, without the common 1/sqrt(2 pi) factor
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        even_terms = jnp.take(log_terms, all_indices[::2], axis=-1)
        odd_terms = jnp.take(log_terms, all_indices[1::2], axis=-1)
        sum_even = jnp.exp(logsumexp(even_terms, axis=-1))
        sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)
class DoublyTruncatedPowerLaw(Distribution):
    r"""Power law distribution with :math:`\alpha` index, and lower and upper bounds.
    We can define the power law distribution as,

    .. math::
        f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)},

    where, :math:`a` and :math:`b` are the lower and upper bounds respectively,
    and :math:`Z(\alpha, a, b)` is the normalization constant. It is defined as,

    .. math::
        Z(\alpha, a, b) = \begin{cases}
            \log(b) - \log(a) & \text{if } \alpha = -1, \\
            \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}.
        \end{cases}

    :param alpha: index of the power law distribution
    :param low: lower bound of the distribution
    :param high: upper bound of the distribution
    """

    arg_constraints = {
        "alpha": constraints.real,
        "low": constraints.greater_than_eq(0),
        "high": constraints.greater_than(0),
    }
    reparametrized_params = ["alpha", "low", "high"]
    pytree_aux_fields = ("_support",)
    pytree_data_fields = ("alpha", "low", "high")

    def __init__(
        self,
        alpha: ArrayLike,
        low: ArrayLike,
        high: ArrayLike,
        *,
        validate_args: Optional[bool] = None,
    ):
        self.alpha, self.low, self.high = promote_shapes(alpha, low, high)
        self._support = constraints.interval(low, high)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(alpha), jnp.shape(low), jnp.shape(high)
        )
        super(DoublyTruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        r"""Log of the probability density.

        For :math:`\alpha \neq -1` the density is

        .. math::
            \frac{(\alpha + 1)x^\alpha}{b^{\alpha + 1} - a^{\alpha + 1}}

        and for :math:`\alpha = -1` it is

        .. math::
            \frac{x^\alpha}{\log(b) - \log(a)}

        The custom JVP returns derivatives of the *log* density with respect to
        ``value``, ``alpha``, ``low`` and ``high``. At :math:`\alpha = -1` the
        ``alpha`` derivative is approximated by averaging the
        :math:`\alpha \neq -1` formula at :math:`\alpha \pm 10^{-3}`.
        """

        @jax.custom_jvp
        def f(
            x: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike
        ) -> ArrayLike:
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            # Substitute a safe alpha in the masked-out lanes so both branches are finite.
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def neq_neg1_fn() -> ArrayLike:
                one_more_alpha = 1.0 + neq_neg1_alpha
                return jnp.log(
                    jnp.power(x, neq_neg1_alpha)
                    * (one_more_alpha)
                    / (jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha))
                )

            def eq_neg1_fn() -> ArrayLike:
                return -jnp.log(x) - jnp.log(jnp.log(high) - jnp.log(low))

            return jnp.where(neq_neg1_mask, neq_neg1_fn(), eq_neg1_fn())

        @f.defjvp
        def f_jvp(
            primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
            tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
        ) -> tuple[ArrayLike, ArrayLike]:
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            log_x = jnp.log(x)

            # Mask and alpha values
            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            primal_out = f(*primals)

            # d/dalpha of log density for alpha != -1, minus the log(x) term
            def alpha_tangent_variable(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return jnp.reciprocal(one_more_alpha) + (
                    low_pow_one_more_alpha * log_low
                    - high_pow_one_more_alpha * log_high
                ) / (high_pow_one_more_alpha - low_pow_one_more_alpha)

            # Alpha tangent
            alpha_tangent = jnp.where(
                neq_neg1_mask,
                log_x + alpha_tangent_variable(neq_neg1_alpha),
                # Approximate derivative with right and lefthand approximation
                log_x
                + (
                    alpha_tangent_variable(alpha - delta_eq_neg1)
                    + alpha_tangent_variable(alpha + delta_eq_neg1)
                )
                * 0.5,
            )

            # Low and high tangents for alpha != -1. Only the normalizer depends on
            # low/high: d/dlow [-log(b^(1+a) - a^(1+a))] = (1+a) low^a / change
            # and d/dhigh = -(1+a) high^a / change.
            one_more_alpha = 1.0 + neq_neg1_alpha
            low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
            high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
            change = high_pow_one_more_alpha - low_pow_one_more_alpha
            low_tangent_neq_neg1 = (
                one_more_alpha * jnp.power(low, neq_neg1_alpha) / change
            )
            high_tangent_neq_neg1 = (
                -one_more_alpha * jnp.power(high, neq_neg1_alpha) / change
            )

            # Low and high tangents for alpha == -1:
            # d/dlow [-log(log b - log a)] = 1 / (low * (log b - log a)),
            # d/dhigh = -1 / (high * (log b - log a)).
            log_change = log_high - log_low
            low_tangent_eq_neg1 = jnp.reciprocal(low * log_change)
            high_tangent_eq_neg1 = -jnp.reciprocal(high * log_change)

            # High and low tangents
            low_tangent = jnp.where(
                neq_neg1_mask, low_tangent_neq_neg1, low_tangent_eq_neg1
            )
            high_tangent = jnp.where(
                neq_neg1_mask, high_tangent_neq_neg1, high_tangent_eq_neg1
            )

            # Final tangents
            tangent_out = (
                alpha / x * x_t
                + alpha_tangent * alpha_t
                + low_tangent * low_t
                + high_tangent * high_t
            )
            return primal_out, tangent_out

        return f(value, self.alpha, self.low, self.high)

    def cdf(self, value: ArrayLike) -> ArrayLike:
        r"""Cumulative distribution function, clipped to :math:`[0, 1]`.

        For :math:`\alpha \neq -1`:

        .. math::
            \frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{\alpha + 1}}

        For :math:`\alpha = -1`:

        .. math::
            \frac{\log(x) - \log(a)}{\log(b) - \log(a)}

        Derivations are calculated by Wolfram Alpha via the Jacobian matrix accordingly.
        """

        @jax.custom_jvp
        def f(
            x: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike
        ) -> ArrayLike:
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def cdf_when_alpha_neq_neg1() -> ArrayLike:
                one_more_alpha = 1.0 + neq_neg1_alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                return (jnp.power(x, one_more_alpha) - low_pow_one_more_alpha) / (
                    jnp.power(high, one_more_alpha) - low_pow_one_more_alpha
                )

            def cdf_when_alpha_eq_neg1() -> ArrayLike:
                return jnp.log(x / low) / jnp.log(high / low)

            cdf_val = jnp.where(
                neq_neg1_mask,
                cdf_when_alpha_neq_neg1(),
                cdf_when_alpha_eq_neg1(),
            )
            return jnp.clip(cdf_val, 0.0, 1.0)

        @f.defjvp
        def f_jvp(
            primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
            tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
        ) -> tuple[ArrayLike, ArrayLike]:
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            log_x = jnp.log(x)

            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            # Calculate primal
            primal_out = f(*primals)

            # Tangents for alpha not equals -1
            def x_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                return (one_more_alpha * jnp.power(x, alpha)) / (
                    jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha)
                )

            def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)
                term1 = (
                    x_pow_one_more_alpha * log_x - low_pow_one_more_alpha * log_low
                ) / (high_pow_one_more_alpha - low_pow_one_more_alpha)
                term2 = (
                    (x_pow_one_more_alpha - low_pow_one_more_alpha)
                    * (
                        high_pow_one_more_alpha * log_high
                        - low_pow_one_more_alpha * log_low
                    )
                ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha)
                return term1 - term2

            def low_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)
                change = high_pow_one_more_alpha - low_pow_one_more_alpha
                term2 = one_more_alpha * jnp.power(low, alpha) / change
                term1 = term2 * (x_pow_one_more_alpha - low_pow_one_more_alpha) / change
                return term1 - term2

            def high_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)
                return -(
                    one_more_alpha
                    * jnp.power(high, alpha)
                    * (x_pow_one_more_alpha - low_pow_one_more_alpha)
                ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha)

            # Tangents for alpha equals -1
            def x_eq_neg1() -> ArrayLike:
                return jnp.reciprocal(x * (log_high - log_low))

            def low_eq_neg1() -> ArrayLike:
                return (log_x - log_low) / (
                    jnp.square(log_high - log_low) * low
                ) - jnp.reciprocal((log_high - log_low) * low)

            def high_eq_neg1() -> ArrayLike:
                return (log_x - log_low) / (jnp.square(log_high - log_low) * high)

            # Including approximation for alpha = -1
            tangent_out = (
                jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), x_eq_neg1()) * x_t
                + jnp.where(
                    neq_neg1_mask,
                    alpha_neq_neg1(neq_neg1_alpha),
                    (
                        alpha_neq_neg1(alpha - delta_eq_neg1)
                        + alpha_neq_neg1(alpha + delta_eq_neg1)
                    )
                    * 0.5,
                )
                * alpha_t
                + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1())
                * low_t
                + jnp.where(
                    neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1()
                )
                * high_t
            )

            return primal_out, tangent_out

        return f(value, self.alpha, self.low, self.high)

    def icdf(self, q: ArrayLike) -> ArrayLike:
        r"""Inverse cumulative distribution function.

        For :math:`\alpha \neq -1`:

        .. math::
            \left(a^{1 + \alpha} + q (b^{1 + \alpha} - a^{1 + \alpha})\right)^{\frac{1}{1 + \alpha}}

        For :math:`\alpha = -1`:

        .. math::
            a \left(\frac{b}{a}\right)^{q}

        Derivations are calculated by Wolfram Alpha via the Jacobian matrix accordingly.
        """

        @jax.custom_jvp
        def f(
            q: ArrayLike, alpha: ArrayLike, low: ArrayLike, high: ArrayLike
        ) -> ArrayLike:
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def icdf_alpha_neq_neg1() -> ArrayLike:
                one_more_alpha = 1.0 + neq_neg1_alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return jnp.power(
                    low_pow_one_more_alpha
                    + q * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                    jnp.reciprocal(one_more_alpha),
                )

            def icdf_alpha_eq_neg1() -> ArrayLike:
                return jnp.power(high / low, q) * low

            icdf_val = jnp.where(
                neq_neg1_mask,
                icdf_alpha_neq_neg1(),
                icdf_alpha_eq_neg1(),
            )
            return icdf_val

        @f.defjvp
        def f_jvp(
            primals: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
            tangents: tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike],
        ) -> tuple[ArrayLike, ArrayLike]:
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            high_over_low = jnp.divide(high, low)

            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            primal_out = f(*primals)

            # Tangents for alpha not equal -1
            def x_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                change = high_pow_one_more_alpha - low_pow_one_more_alpha
                return (
                    change
                    * jnp.power(
                        low_pow_one_more_alpha + x * change,
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                ) / one_more_alpha

            def alpha_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                factor0 = low_pow_one_more_alpha + x * (
                    high_pow_one_more_alpha - low_pow_one_more_alpha
                )
                term1 = jnp.power(factor0, jnp.reciprocal(one_more_alpha))
                term2 = (
                    low_pow_one_more_alpha * log_low
                    + x
                    * (
                        high_pow_one_more_alpha * log_high
                        - low_pow_one_more_alpha * log_low
                    )
                ) / (one_more_alpha * factor0)
                term3 = jnp.log(factor0) / jnp.square(one_more_alpha)
                return term1 * (term2 - term3)

            def low_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return (
                    (1.0 - x)
                    * jnp.power(low, alpha)
                    * jnp.power(
                        low_pow_one_more_alpha
                        + x * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                )

            def high_neq_neg1(alpha: ArrayLike) -> ArrayLike:
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return (
                    x
                    * jnp.power(high, alpha)
                    * jnp.power(
                        low_pow_one_more_alpha
                        + x * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                )

            # Tangents for alpha equals -1
            def dx_eq_neg1() -> ArrayLike:
                return low * jnp.power(high_over_low, x) * (log_high - log_low)

            def low_eq_neg1() -> ArrayLike:
                return (
                    jnp.power(high_over_low, x)
                    - (high * x * jnp.power(high_over_low, x - 1)) / low
                )

            def high_eq_neg1() -> ArrayLike:
                return x * jnp.power(high_over_low, x - 1)

            # Including approximation for alpha = -1
            tangent_out = (
                jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), dx_eq_neg1()) * x_t
                + jnp.where(
                    neq_neg1_mask,
                    alpha_neq_neg1(neq_neg1_alpha),
                    (
                        alpha_neq_neg1(alpha - delta_eq_neg1)
                        + alpha_neq_neg1(alpha + delta_eq_neg1)
                    )
                    * 0.5,
                )
                * alpha_t
                + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1())
                * low_t
                + jnp.where(
                    neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1()
                )
                * high_t
            )

            return primal_out, tangent_out

        return f(q, self.alpha, self.low, self.high)

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        # Inverse-CDF sampling from uniform [0, 1) draws.
        u = random.uniform(key, sample_shape + self.batch_shape)
        samples = self.icdf(u)
        return samples
class LowerTruncatedPowerLaw(Distribution):
    r"""Lower truncated power law distribution with :math:`\alpha` index.

    We can define the power law distribution as,

    .. math::
        f(x; \alpha, a) = (-\alpha-1)a^{-\alpha - 1}x^{\alpha},
        \qquad x \geq a, \qquad \alpha < -1,

    where, :math:`a` is the lower bound. The cdf of the distribution is given by,

    .. math::
        F(x; \alpha, a) = 1 - \left(\frac{x}{a}\right)^{1+\alpha}.

    The k-th moment of the distribution is given by,

    .. math::
        E[X^k] = \begin{cases}
            \frac{-\alpha-1}{-\alpha-1-k}a^k & \text{if } k < -\alpha-1, \\
            \infty & \text{otherwise}.
        \end{cases}

    :param alpha: index of the power law distribution
    :param low: lower bound of the distribution
    """

    arg_constraints = {
        "alpha": constraints.less_than(-1.0),
        "low": constraints.greater_than(0.0),
    }
    reparametrized_params = ["alpha", "low"]
    pytree_aux_fields = ("_support",)

    def __init__(
        self, alpha: ArrayLike, low: ArrayLike, *, validate_args: Optional[bool] = None
    ):
        self.alpha, self.low = promote_shapes(alpha, low)
        batch_shape = lax.broadcast_shapes(jnp.shape(alpha), jnp.shape(low))
        self._support = constraints.greater_than(low)
        super(LowerTruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        # log[(-alpha - 1) * low^(-alpha - 1) * value^alpha]
        one_more_alpha = 1.0 + self.alpha
        return (
            self.alpha * jnp.log(value)
            + jnp.log(-one_more_alpha)
            - one_more_alpha * jnp.log(self.low)
        )

    def cdf(self, value: ArrayLike) -> ArrayLike:
        # 0 at or below the lower bound, 1 - (value / low)^(1 + alpha) above it.
        cdf_val = jnp.where(
            jnp.less_equal(value, self.low),
            jnp.zeros_like(value),
            1.0 - jnp.power(value / self.low, 1.0 + self.alpha),
        )
        return cdf_val

    def icdf(self, q: ArrayLike) -> ArrayLike:
        # NaN for q outside [0, 1] (or NaN q); otherwise low * (1 - q)^(1 / (1 + alpha)).
        nan_mask = jnp.logical_or(jnp.isnan(q), jnp.less(q, 0.0))
        nan_mask = jnp.logical_or(nan_mask, jnp.greater(q, 1.0))
        return jnp.where(
            nan_mask,
            jnp.nan,
            self.low * jnp.power(1.0 - q, jnp.reciprocal(1.0 + self.alpha)),
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        assert is_prng_key(key)
        # Inverse-CDF sampling from uniform [0, 1) draws.
        u = random.uniform(key, sample_shape + self.batch_shape)
        samples = self.icdf(u)
        return samples