Source code for numpyro.distributions.truncated

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import jax
from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp

from numpyro.distributions import constraints
from numpyro.distributions.continuous import (
    Cauchy,
    Laplace,
    Logistic,
    Normal,
    SoftLaplace,
    StudentT,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
    clamp_probs,
    lazy_property,
    promote_shapes,
    validate_sample,
)
from numpyro.util import is_prng_key


class LeftTruncatedDistribution(Distribution):
    arg_constraints = {"low": constraints.real}
    reparametrized_params = ["low"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "_support")

    def __init__(self, base_dist, low=0.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
        self.base_dist = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.low,) = promote_shapes(low, shape=batch_shape)
        self._support = constraints.greater_than(low)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high) = 1; otherwise returns 1 - cdf(high) = 0
        return jnp.where(self.low <= self.base_dist.loc, 1.0, 0.0)

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q):
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
            (1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high
        )
        return jnp.where(q < 0, jnp.nan, ppf)

    @validate_sample
    def log_prob(self, value):
        sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
        return self.base_dist.log_prob(value) - jnp.log(
            sign * (self._tail_prob_at_high - self._tail_prob_at_low)
        )

    @property
    def mean(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            return self.base_dist.loc + low_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (low_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")

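The sampler above draws a uniform variate and pushes it through ``icdf``, so a quick way to sanity-check the class is to compare empirical statistics of the draws against the closed-form ``mean`` property. A minimal illustrative sketch (not part of the module; the key and sample size are arbitrary):

import jax.random as random
from numpyro.distributions import Normal
from numpyro.distributions.truncated import LeftTruncatedDistribution

d = LeftTruncatedDistribution(Normal(1.0, 2.0), low=0.0)
samples = d.sample(random.PRNGKey(0), (10_000,))  # icdf applied to uniform draws
print(float(samples.min()) >= 0.0)                # every draw lies above `low`
print(float(samples.mean()), float(d.mean))       # empirical vs. analytic mean
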
class RightTruncatedDistribution(Distribution):
    arg_constraints = {"high": constraints.real}
    reparametrized_params = ["high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "high", "_support")

    def __init__(self, base_dist, high=0.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))
        self.base_dist = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.less_than(high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _cdf_at_high(self):
        return self.base_dist.cdf(self.high)

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q):
        ppf = self.base_dist.icdf(q * self._cdf_at_high)
        return jnp.where(q > 1, jnp.nan, ppf)

    @validate_sample
    def log_prob(self, value):
        return self.base_dist.log_prob(value) - jnp.log(self._cdf_at_high)

    @property
    def mean(self):
        if isinstance(self.base_dist, Normal):
            high_prob = jnp.exp(self.log_prob(self.high))
            return self.base_dist.loc - high_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        if isinstance(self.base_dist, Normal):
            high_prob = jnp.exp(self.log_prob(self.high))
            return (self.base_dist.scale**2) * (
                1
                - (self.high - self.base_dist.loc) * high_prob
                - (high_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")

class TwoSidedTruncatedDistribution(Distribution):
    arg_constraints = {
        "low": constraints.dependent,
        "high": constraints.dependent,
    }
    reparametrized_params = ["low", "high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "high", "_support")

    def __init__(self, base_dist, low=0.0, high=1.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert base_dist.support is constraints.real, (
            "The base distribution should be univariate and have real support."
        )
        batch_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
        )
        self.base_dist = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.low,) = promote_shapes(low, shape=batch_shape)
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.interval(low, high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high); otherwise returns 1 - cdf(high)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.high))

    @lazy_property
    def _log_diff_tail_probs(self):
        # use the log_cdf method, if available, to avoid inf's in log_prob;
        # fall back to cdf if log_cdf is not available
        log_cdf = getattr(self.base_dist, "log_cdf", None)
        if callable(log_cdf):
            return logsumexp(
                a=jnp.stack([log_cdf(self.high), log_cdf(self.low)], axis=-1),
                axis=-1,
                b=jnp.array([1, -1]),  # subtract low from high
            )
        else:
            loc = self.base_dist.loc
            sign = jnp.where(loc >= self.low, 1.0, -1.0)
            return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.icdf(u)

    def icdf(self, q):
        # NB: we use a more numerically stable formula for a symmetric base distribution.
        # A = icdf(cdf(low) + (cdf(high) - cdf(low)) * q) = icdf[(1 - q) * cdf(low) + q * cdf(high)]
        # suffers from precision issues when low is large, so:
        # If low < loc:
        #   A = icdf[(1 - q) * cdf(low) + q * cdf(high)]
        # Else:
        #   A = 2 * loc - icdf[(1 - q) * cdf(2 * loc - low) + q * cdf(2 * loc - high)]
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        ppf = (1 - sign) * loc + sign * self.base_dist.icdf(
            clamp_probs((1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high)
        )
        return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf)

    @validate_sample
    def log_prob(self, value):
        # NB: we use a more numerically stable formula for a symmetric base distribution:
        # if low < loc
        #   cdf(high) - cdf(low) = as-is
        # if low > loc
        #   cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)
        return self.base_dist.log_prob(value) - self._log_diff_tail_probs

    @property
    def mean(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return (
                self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (self.high - self.base_dist.loc) * high_prob
                - ((low_prob - high_prob) * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")

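The two NB comments above are about keeping the normalizer cdf(high) - cdf(low) meaningful when the truncation interval sits far from the base location, either by reflecting the interval onto the opposite tail or by working with log_cdf when the base distribution provides one. A small illustrative sketch (not part of the module; the interval and evaluation point are arbitrary):

import jax.numpy as jnp
from numpyro.distributions import Normal
from numpyro.distributions.truncated import TwoSidedTruncatedDistribution

# Interval deep in the lower tail of the base distribution: both tail
# probabilities are tiny, but the reflection/log_cdf handling above keeps
# the log-normalizer, and hence log_prob, finite.
d = TwoSidedTruncatedDistribution(Normal(0.0, 1.0), low=-9.0, high=-8.0)
print(d.log_prob(jnp.array(-8.5)))
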
def TruncatedDistribution(base_dist, low=None, high=None, *, validate_args=None):
    """
    A function to generate a truncated distribution.

    :param base_dist: The base distribution to be truncated. This should be a
        univariate distribution. Currently, only the following distributions are
        supported: Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
    :param low: the value which is used to truncate the base distribution from below.
        Set this parameter to None to avoid truncation from below.
    :param high: the value which is used to truncate the base distribution from above.
        Set this parameter to None to avoid truncation from above.
    """
    if high is None:
        if low is None:
            return base_dist
        else:
            return LeftTruncatedDistribution(
                base_dist, low=low, validate_args=validate_args
            )
    elif low is None:
        return RightTruncatedDistribution(
            base_dist, high=high, validate_args=validate_args
        )
    else:
        return TwoSidedTruncatedDistribution(
            base_dist, low=low, high=high, validate_args=validate_args
        )

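As a quick illustration of the dispatch logic (an example sketch, not part of the module), the factory returns the base distribution unchanged or one of the three wrapper classes, depending on which bounds are supplied:

from numpyro.distributions import Normal
from numpyro.distributions.truncated import TruncatedDistribution

base = Normal(0.0, 1.0)
print(type(TruncatedDistribution(base)).__name__)                     # Normal (no truncation)
print(type(TruncatedDistribution(base, low=0.0)).__name__)            # LeftTruncatedDistribution
print(type(TruncatedDistribution(base, high=1.0)).__name__)           # RightTruncatedDistribution
print(type(TruncatedDistribution(base, low=0.0, high=1.0)).__name__)  # TwoSidedTruncatedDistribution
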
def TruncatedCauchy(loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None):
    return TruncatedDistribution(
        Cauchy(loc, scale), low=low, high=high, validate_args=validate_args
    )

def TruncatedNormal(loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None):
    return TruncatedDistribution(
        Normal(loc, scale), low=low, high=high, validate_args=validate_args
    )

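TruncatedNormal and TruncatedCauchy are thin wrappers around the factory, so they can be used directly as priors. A minimal sketch under a seed handler (the model and site names are arbitrary, not part of the module):

import numpyro
from numpyro.distributions import TruncatedNormal

def model():
    # non-negative prior obtained by truncating a Normal at zero
    sigma = numpyro.sample("sigma", TruncatedNormal(0.0, 2.0, low=0.0))
    return sigma

with numpyro.handlers.seed(rng_seed=0):
    print(model())
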
class TruncatedPolyaGamma(Distribution):
    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(self, batch_shape=(), *, validate_args=None):
        super(TruncatedPolyaGamma, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        x = random.gamma(
            key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,))
        )
        x = jnp.sum(x / denom, axis=-1)
        return jnp.clip(x * (0.5 / jnp.pi**2), None, self.truncation_point)

    @validate_sample
    def log_prob(self, value):
        value = value[..., None]
        all_indices = jnp.arange(0, self.num_log_prob_terms)
        two_n_plus_one = 2.0 * all_indices + 1.0
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        even_terms = jnp.take(log_terms, all_indices[::2], axis=-1)
        odd_terms = jnp.take(log_terms, all_indices[1::2], axis=-1)
        sum_even = jnp.exp(logsumexp(even_terms, axis=-1))
        sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)

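The sampler above approximates a Polya-Gamma(1, 0) variate by a finite sum of scaled Gamma variates, clipped to the truncation point, and log_prob evaluates a truncated alternating series. A small usage sketch (not part of the module; the key and sample size are arbitrary):

import jax.random as random
from numpyro.distributions.truncated import TruncatedPolyaGamma

d = TruncatedPolyaGamma()
x = d.sample(random.PRNGKey(0), (1000,))
print(float(x.min()) >= 0.0, float(x.max()) <= d.truncation_point)  # support check
print(d.log_prob(x).shape)                                          # (1000,)
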
class DoublyTruncatedPowerLaw(Distribution):
    r"""Power law distribution with :math:`\alpha` index, and lower and upper bounds.
    We can define the power law distribution as,

    .. math::
        f(x; \alpha, a, b) = \frac{x^{\alpha}}{Z(\alpha, a, b)},

    where :math:`a` and :math:`b` are the lower and upper bounds respectively,
    and :math:`Z(\alpha, a, b)` is the normalization constant. It is defined as,

    .. math::
        Z(\alpha, a, b) =
        \begin{cases}
            \log(b) - \log(a) & \text{if } \alpha = -1, \\
            \frac{b^{1 + \alpha} - a^{1 + \alpha}}{1 + \alpha} & \text{otherwise}.
        \end{cases}

    :param alpha: index of the power law distribution
    :param low: lower bound of the distribution
    :param high: upper bound of the distribution
    """

    arg_constraints = {
        "alpha": constraints.real,
        "low": constraints.greater_than_eq(0),
        "high": constraints.greater_than(0),
    }
    reparametrized_params = ["alpha", "low", "high"]
    pytree_aux_fields = ("_support",)
    pytree_data_fields = ("alpha", "low", "high")

    def __init__(self, alpha, low, high, *, validate_args=None):
        self.alpha, self.low, self.high = promote_shapes(alpha, low, high)
        self._support = constraints.interval(low, high)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(alpha), jnp.shape(low), jnp.shape(high)
        )
        super(DoublyTruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @validate_sample
    def log_prob(self, value):
        r"""Log of the probability density, i.e. the logarithm of

        .. math::
            f(x) = \frac{x^{\alpha} (\alpha + 1)}{b^{\alpha + 1} - a^{\alpha + 1}}
            \quad \text{if } \alpha \neq -1,

        .. math::
            f(x) = \frac{x^{\alpha}}{\log(b) - \log(a)}
            \quad \text{if } \alpha = -1.

        Gradients were derived with Wolfram Alpha from the corresponding Jacobian.
        """

        @jax.custom_jvp
        def f(x, alpha, low, high):
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)
            # eq_neg1_alpha = jnp.where(~neq_neg1_mask, alpha, -1.0)

            def neq_neg1_fn():
                one_more_alpha = 1.0 + neq_neg1_alpha
                return jnp.log(
                    jnp.power(x, neq_neg1_alpha)
                    * (one_more_alpha)
                    / (jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha))
                )

            def eq_neg1_fn():
                return -jnp.log(x) - jnp.log(jnp.log(high) - jnp.log(low))

            return jnp.where(neq_neg1_mask, neq_neg1_fn(), eq_neg1_fn())

        @f.defjvp
        def f_jvp(primals, tangents):
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            log_x = jnp.log(x)

            # Mask and alpha values
            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)
            eq_neg1_alpha = jnp.where(jnp.not_equal(alpha, 0.0), alpha, -1.0)

            primal_out = f(*primals)

            # Alpha tangent, with a finite-difference approximation at alpha = -1.
            # Variable part for all values of alpha unequal to -1:
            def alpha_tangent_variable(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return jnp.reciprocal(one_more_alpha) + (
                    low_pow_one_more_alpha * log_low
                    - high_pow_one_more_alpha * log_high
                ) / (high_pow_one_more_alpha - low_pow_one_more_alpha)

            # Alpha tangent
            alpha_tangent = jnp.where(
                neq_neg1_mask,
                log_x + alpha_tangent_variable(neq_neg1_alpha),
                # Approximate the derivative at alpha = -1 by averaging the
                # right- and left-hand finite-difference approximations
                log_x
                + (
                    alpha_tangent_variable(alpha - delta_eq_neg1)
                    + alpha_tangent_variable(alpha + delta_eq_neg1)
                )
                * 0.5,
            )

            # High and low tangents for alpha unequal to -1
            one_more_alpha = 1.0 + neq_neg1_alpha
            low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
            high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
            change_sq = jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha)
            low_tangent_neq_neg1_common = (
                jnp.square(one_more_alpha) * jnp.power(x, neq_neg1_alpha) / change_sq
            )
            low_tangent_neq_neg1 = low_tangent_neq_neg1_common * jnp.power(
                low, neq_neg1_alpha
            )
            high_tangent_neq_neg1 = low_tangent_neq_neg1_common * jnp.power(
                high, neq_neg1_alpha
            )

            # High and low tangents for alpha equal to -1
            low_tangent_eq_neg1_common = jnp.power(x, eq_neg1_alpha) / jnp.square(
                log_high - log_low
            )
            low_tangent_eq_neg1 = low_tangent_eq_neg1_common / low
            high_tangent_eq_neg1 = -low_tangent_eq_neg1_common / high

            # High and low tangents
            low_tangent = jnp.where(
                neq_neg1_mask, low_tangent_neq_neg1, low_tangent_eq_neg1
            )
            high_tangent = jnp.where(
                neq_neg1_mask, high_tangent_neq_neg1, high_tangent_eq_neg1
            )

            # Final tangents
            tangent_out = (
                alpha / x * x_t
                + alpha_tangent * alpha_t
                + low_tangent * low_t
                + high_tangent * high_t
            )
            return primal_out, tangent_out

        return f(value, self.alpha, self.low, self.high)

    def cdf(self, value):
        r"""Cumulative distribution function:

        .. math::
            F(x) = \frac{x^{\alpha + 1} - a^{\alpha + 1}}{b^{\alpha + 1} - a^{\alpha + 1}}
            \quad \text{if } \alpha \neq -1,

        .. math::
            F(x) = \frac{\log(x) - \log(a)}{\log(b) - \log(a)}
            \quad \text{if } \alpha = -1.

        Gradients were derived with Wolfram Alpha from the corresponding Jacobian.
        """

        @jax.custom_jvp
        def f(x, alpha, low, high):
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def cdf_when_alpha_neq_neg1():
                one_more_alpha = 1.0 + neq_neg1_alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                return (jnp.power(x, one_more_alpha) - low_pow_one_more_alpha) / (
                    jnp.power(high, one_more_alpha) - low_pow_one_more_alpha
                )

            def cdf_when_alpha_eq_neg1():
                return jnp.log(x / low) / jnp.log(high / low)

            cdf_val = jnp.where(
                neq_neg1_mask,
                cdf_when_alpha_neq_neg1(),
                cdf_when_alpha_eq_neg1(),
            )
            return jnp.clip(cdf_val, 0.0, 1.0)

        @f.defjvp
        def f_jvp(primals, tangents):
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            log_x = jnp.log(x)

            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            # Calculate primal
            primal_out = f(*primals)

            # Tangents for alpha not equal to -1
            def x_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                return (one_more_alpha * jnp.power(x, alpha)) / (
                    jnp.power(high, one_more_alpha) - jnp.power(low, one_more_alpha)
                )

            def alpha_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)
                term1 = (
                    x_pow_one_more_alpha * log_x - low_pow_one_more_alpha * log_low
                ) / (high_pow_one_more_alpha - low_pow_one_more_alpha)
                term2 = (
                    (x_pow_one_more_alpha - low_pow_one_more_alpha)
                    * (
                        high_pow_one_more_alpha * log_high
                        - low_pow_one_more_alpha * log_low
                    )
                ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha)
                return term1 - term2

            def low_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)
                change = high_pow_one_more_alpha - low_pow_one_more_alpha
                term2 = one_more_alpha * jnp.power(low, alpha) / change
                term1 = term2 * (x_pow_one_more_alpha - low_pow_one_more_alpha) / change
                return term1 - term2

            def high_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                x_pow_one_more_alpha = jnp.power(x, one_more_alpha)
                return -(
                    one_more_alpha
                    * jnp.power(high, alpha)
                    * (x_pow_one_more_alpha - low_pow_one_more_alpha)
                ) / jnp.square(high_pow_one_more_alpha - low_pow_one_more_alpha)

            # Tangents for alpha equal to -1
            def x_eq_neg1():
                return jnp.reciprocal(x * (log_high - log_low))

            def low_eq_neg1():
                return (log_x - log_low) / (
                    jnp.square(log_high - log_low) * low
                ) - jnp.reciprocal((log_high - log_low) * low)

            def high_eq_neg1():
                return (log_x - log_low) / (jnp.square(log_high - log_low) * high)

            # Including approximation for alpha = -1
            tangent_out = (
                jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), x_eq_neg1()) * x_t
                + jnp.where(
                    neq_neg1_mask,
                    alpha_neq_neg1(neq_neg1_alpha),
                    (
                        alpha_neq_neg1(alpha - delta_eq_neg1)
                        + alpha_neq_neg1(alpha + delta_eq_neg1)
                    )
                    * 0.5,
                )
                * alpha_t
                + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1())
                * low_t
                + jnp.where(
                    neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1()
                )
                * high_t
            )
            return primal_out, tangent_out

        return f(value, self.alpha, self.low, self.high)

    def icdf(self, q):
        r"""Inverse cumulative distribution function:

        .. math::
            F^{-1}(q) = \left(a^{1 + \alpha}
            + q \left(b^{1 + \alpha} - a^{1 + \alpha}\right)\right)^{\frac{1}{1 + \alpha}}
            \quad \text{if } \alpha \neq -1,

        .. math::
            F^{-1}(q) = a \left(\frac{b}{a}\right)^{q}
            \quad \text{if } \alpha = -1.

        Gradients were derived with Wolfram Alpha from the corresponding Jacobian.
        """

        @jax.custom_jvp
        def f(q, alpha, low, high):
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            def icdf_alpha_neq_neg1():
                one_more_alpha = 1.0 + neq_neg1_alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return jnp.power(
                    low_pow_one_more_alpha
                    + q * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                    jnp.reciprocal(one_more_alpha),
                )

            def icdf_alpha_eq_neg1():
                return jnp.power(high / low, q) * low

            icdf_val = jnp.where(
                neq_neg1_mask,
                icdf_alpha_neq_neg1(),
                icdf_alpha_eq_neg1(),
            )
            return icdf_val

        @f.defjvp
        def f_jvp(primals, tangents):
            x, alpha, low, high = primals
            x_t, alpha_t, low_t, high_t = tangents

            log_low = jnp.log(low)
            log_high = jnp.log(high)
            high_over_low = jnp.divide(high, low)

            delta_eq_neg1 = 10e-4
            neq_neg1_mask = jnp.not_equal(alpha, -1.0)
            neq_neg1_alpha = jnp.where(neq_neg1_mask, alpha, 0.0)

            primal_out = f(*primals)

            # Tangents for alpha not equal to -1
            def x_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                change = high_pow_one_more_alpha - low_pow_one_more_alpha
                return (
                    change
                    * jnp.power(
                        low_pow_one_more_alpha + x * change,
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                ) / one_more_alpha

            def alpha_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                factor0 = low_pow_one_more_alpha + x * (
                    high_pow_one_more_alpha - low_pow_one_more_alpha
                )
                term1 = jnp.power(factor0, jnp.reciprocal(one_more_alpha))
                term2 = (
                    low_pow_one_more_alpha * log_low
                    + x
                    * (
                        high_pow_one_more_alpha * log_high
                        - low_pow_one_more_alpha * log_low
                    )
                ) / (one_more_alpha * factor0)
                term3 = jnp.log(factor0) / jnp.square(one_more_alpha)
                return term1 * (term2 - term3)

            def low_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return (
                    (1.0 - x)
                    * jnp.power(low, alpha)
                    * jnp.power(
                        low_pow_one_more_alpha
                        + x * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                )

            def high_neq_neg1(alpha):
                one_more_alpha = 1.0 + alpha
                low_pow_one_more_alpha = jnp.power(low, one_more_alpha)
                high_pow_one_more_alpha = jnp.power(high, one_more_alpha)
                return (
                    x
                    * jnp.power(high, alpha)
                    * jnp.power(
                        low_pow_one_more_alpha
                        + x * (high_pow_one_more_alpha - low_pow_one_more_alpha),
                        jnp.reciprocal(one_more_alpha) - 1,
                    )
                )

            # Tangents for alpha equal to -1
            def dx_eq_neg1():
                return low * jnp.power(high_over_low, x) * (log_high - log_low)

            def low_eq_neg1():
                return (
                    jnp.power(high_over_low, x)
                    - (high * x * jnp.power(high_over_low, x - 1)) / low
                )

            def high_eq_neg1():
                return x * jnp.power(high_over_low, x - 1)

            # Including approximation for alpha = -1
            tangent_out = (
                jnp.where(neq_neg1_mask, x_neq_neg1(neq_neg1_alpha), dx_eq_neg1()) * x_t
                + jnp.where(
                    neq_neg1_mask,
                    alpha_neq_neg1(neq_neg1_alpha),
                    (
                        alpha_neq_neg1(alpha - delta_eq_neg1)
                        + alpha_neq_neg1(alpha + delta_eq_neg1)
                    )
                    * 0.5,
                )
                * alpha_t
                + jnp.where(neq_neg1_mask, low_neq_neg1(neq_neg1_alpha), low_eq_neg1())
                * low_t
                + jnp.where(
                    neq_neg1_mask, high_neq_neg1(neq_neg1_alpha), high_eq_neg1()
                )
                * high_t
            )
            return primal_out, tangent_out

        return f(q, self.alpha, self.low, self.high)

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        u = random.uniform(key, sample_shape + self.batch_shape)
        samples = self.icdf(u)
        return samples

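Since sampling for this distribution also goes through the inverse CDF, a simple consistency check is that cdf inverts icdf and that draws land inside [low, high]. An illustrative sketch (not part of the module; the parameter values are arbitrary):

import jax.numpy as jnp
import jax.random as random
from numpyro.distributions.truncated import DoublyTruncatedPowerLaw

d = DoublyTruncatedPowerLaw(alpha=-2.0, low=1.0, high=10.0)
x = d.sample(random.PRNGKey(0), (1000,))
print(float(x.min()) >= 1.0, float(x.max()) <= 10.0)  # draws respect the bounds
q = jnp.linspace(0.1, 0.9, 5)
print(jnp.allclose(d.cdf(d.icdf(q)), q, atol=1e-5))   # cdf inverts icdf
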
class LowerTruncatedPowerLaw(Distribution):
    r"""Lower truncated power law distribution with :math:`\alpha` index.
    We can define the power law distribution as,

    .. math::
        f(x; \alpha, a) = (-\alpha - 1) a^{-\alpha - 1} x^{\alpha},
        \qquad x \geq a, \qquad \alpha < -1,

    where :math:`a` is the lower bound. The cdf of the distribution is given by,

    .. math::
        F(x; \alpha, a) = 1 - \left(\frac{x}{a}\right)^{1 + \alpha}.

    The k-th moment of the distribution is given by,

    .. math::
        E[X^k] =
        \begin{cases}
            \frac{-\alpha - 1}{-\alpha - 1 - k} a^k & \text{if } k < -\alpha - 1, \\
            \infty & \text{otherwise}.
        \end{cases}

    :param alpha: index of the power law distribution
    :param low: lower bound of the distribution
    """

    arg_constraints = {
        "alpha": constraints.less_than(-1.0),
        "low": constraints.greater_than(0.0),
    }
    reparametrized_params = ["alpha", "low"]
    pytree_aux_fields = ("_support",)

    def __init__(self, alpha, low, *, validate_args=None):
        self.alpha, self.low = promote_shapes(alpha, low)
        batch_shape = lax.broadcast_shapes(jnp.shape(alpha), jnp.shape(low))
        self._support = constraints.greater_than(low)
        super(LowerTruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @validate_sample
    def log_prob(self, value):
        one_more_alpha = 1.0 + self.alpha
        return (
            self.alpha * jnp.log(value)
            + jnp.log(-one_more_alpha)
            - one_more_alpha * jnp.log(self.low)
        )

    def cdf(self, value):
        cdf_val = jnp.where(
            jnp.less_equal(value, self.low),
            jnp.zeros_like(value),
            1.0 - jnp.power(value / self.low, 1.0 + self.alpha),
        )
        return cdf_val

    def icdf(self, q):
        nan_mask = jnp.logical_or(jnp.isnan(q), jnp.less(q, 0.0))
        nan_mask = jnp.logical_or(nan_mask, jnp.greater(q, 1.0))
        return jnp.where(
            nan_mask,
            jnp.nan,
            self.low * jnp.power(1.0 - q, jnp.reciprocal(1.0 + self.alpha)),
        )

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        u = random.uniform(key, sample_shape + self.batch_shape)
        samples = self.icdf(u)
        return samples

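A minimal sketch (not part of the module) that checks the k = 1 case of the moment formula from the class docstring against an empirical average; the parameter values and sample size are arbitrary:

import jax.random as random
from numpyro.distributions.truncated import LowerTruncatedPowerLaw

# alpha = -3.5, low = 2.0: the first moment exists because 1 < -alpha - 1 = 2.5
d = LowerTruncatedPowerLaw(alpha=-3.5, low=2.0)
x = d.sample(random.PRNGKey(0), (50_000,))
analytic_mean = (-d.alpha - 1.0) / (-d.alpha - 1.0 - 1.0) * d.low  # E[X^1]
print(float(x.mean()), float(analytic_mean))  # should roughly agree
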