Source code for numpyro.distributions.truncated

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map

from numpyro.distributions import constraints
from numpyro.distributions.continuous import (
    Cauchy,
    Laplace,
    Logistic,
    Normal,
    SoftLaplace,
    StudentT,
)
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
    clamp_probs,
    lazy_property,
    promote_shapes,
    validate_sample,
)
from numpyro.util import is_prng_key


class LeftTruncatedDistribution(Distribution):
    arg_constraints = {"low": constraints.real}
    reparametrized_params = ["low"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "_support")

    def __init__(self, base_dist, low=0.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
        self.base_dist = tree_map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.low,) = promote_shapes(low, shape=batch_shape)
        self._support = constraints.greater_than(low)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high) = 1; otherwise returns 1 - cdf(high) = 0
        return jnp.where(self.low <= self.base_dist.loc, 1.0, 0.0)
    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return (1 - sign) * loc + sign * self.base_dist.icdf(
            (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
        )
    @validate_sample
    def log_prob(self, value):
        sign = jnp.where(self.base_dist.loc >= self.low, 1.0, -1.0)
        return self.base_dist.log_prob(value) - jnp.log(
            sign * (self._tail_prob_at_high - self._tail_prob_at_low)
        )

    @property
    def mean(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            return self.base_dist.loc + low_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (low_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
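# Usage sketch (a minimal illustration, not part of the API; the helper name,
# bound, and sample size below are arbitrary choices): draw from a Normal
# truncated from below and compare the empirical mean of the draws with the
# closed-form ``mean`` property defined above.
def _example_left_truncated():
    d = LeftTruncatedDistribution(Normal(0.0, 1.0), low=1.0)
    samples = d.sample(random.PRNGKey(0), sample_shape=(10_000,))
    assert samples.min() >= 1.0  # every draw lies above the truncation point
    # for Normal(0, 1) truncated at low=1, the analytic mean is ~1.525
    return samples.mean(), d.mean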
class RightTruncatedDistribution(Distribution):
    arg_constraints = {"high": constraints.real}
    reparametrized_params = ["high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "high", "_support")

    def __init__(self, base_dist, high=0.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))
        self.base_dist = tree_map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.less_than(high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _cdf_at_high(self):
        return self.base_dist.cdf(self.high)
    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
        return self.base_dist.icdf(u * self._cdf_at_high)
    @validate_sample
    def log_prob(self, value):
        return self.base_dist.log_prob(value) - jnp.log(self._cdf_at_high)

    @property
    def mean(self):
        if isinstance(self.base_dist, Normal):
            high_prob = jnp.exp(self.log_prob(self.high))
            return self.base_dist.loc - high_prob * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        if isinstance(self.base_dist, Normal):
            high_prob = jnp.exp(self.log_prob(self.high))
            return (self.base_dist.scale**2) * (
                1
                - (self.high - self.base_dist.loc) * high_prob
                - (high_prob * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
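# Usage sketch (illustrative, not part of the API; helper name and values are
# arbitrary): for a unit Normal truncated from above at 0, exactly half the base
# mass remains, so log_prob differs from the base log_prob by log(2).
def _example_right_truncated():
    d = RightTruncatedDistribution(Normal(0.0, 1.0), high=0.0)
    samples = d.sample(random.PRNGKey(1), sample_shape=(1_000,))
    assert samples.max() <= 0.0  # every draw lies below the truncation point
    # cdf(0) = 0.5, hence an offset of -log(0.5) = log(2) ~ 0.693
    return d.log_prob(-0.5) - Normal(0.0, 1.0).log_prob(-0.5)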
class TwoSidedTruncatedDistribution(Distribution):
    arg_constraints = {
        "low": constraints.dependent,
        "high": constraints.dependent,
    }
    reparametrized_params = ["low", "high"]
    supported_types = (Cauchy, Laplace, Logistic, Normal, SoftLaplace, StudentT)
    pytree_data_fields = ("base_dist", "low", "high", "_support")

    def __init__(self, base_dist, low=0.0, high=1.0, *, validate_args=None):
        assert isinstance(base_dist, self.supported_types)
        assert (
            base_dist.support is constraints.real
        ), "The base distribution should be univariate and have real support."
        batch_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
        )
        self.base_dist = tree_map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        (self.low,) = promote_shapes(low, shape=batch_shape)
        (self.high,) = promote_shapes(high, shape=batch_shape)
        self._support = constraints.interval(low, high)
        super().__init__(batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @lazy_property
    def _tail_prob_at_low(self):
        # if low < loc, returns cdf(low); otherwise returns 1 - cdf(low)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.low))

    @lazy_property
    def _tail_prob_at_high(self):
        # if low < loc, returns cdf(high); otherwise returns 1 - cdf(high)
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return self.base_dist.cdf(loc - sign * (loc - self.high))

    @lazy_property
    def _log_diff_tail_probs(self):
        # use the log_cdf method, if available, to avoid inf's in log_prob;
        # fall back to cdf if log_cdf is not available
        log_cdf = getattr(self.base_dist, "log_cdf", None)
        if callable(log_cdf):
            return logsumexp(
                a=jnp.stack([log_cdf(self.high), log_cdf(self.low)], axis=-1),
                axis=-1,
                b=jnp.array([1, -1]),  # subtract low from high
            )
        else:
            loc = self.base_dist.loc
            sign = jnp.where(loc >= self.low, 1.0, -1.0)
            return jnp.log(sign * (self._tail_prob_at_high - self._tail_prob_at_low))
    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        finfo = jnp.finfo(dtype)
        minval = finfo.tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)

        # NB: we use a more numerically stable formula for a symmetric base distribution.
        # The direct formula
        #   A = icdf(cdf(low) + (cdf(high) - cdf(low)) * u)
        #     = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        # suffers from precision issues when low is large, so instead:
        # if low < loc:
        #   A = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        # else:
        #   A = 2 * loc - icdf[(1 - u) * cdf(2 * loc - low) + u * cdf(2 * loc - high)]
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return (1 - sign) * loc + sign * self.base_dist.icdf(
            clamp_probs((1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high)
        )
    @validate_sample
    def log_prob(self, value):
        # NB: the same reflection trick as in `sample` keeps the normalizer
        # numerically stable for a symmetric base distribution:
        # if low < loc:
        #   cdf(high) - cdf(low) is computed as-is
        # if low > loc:
        #   cdf(high) - cdf(low) = cdf(2 * loc - low) - cdf(2 * loc - high)
        return self.base_dist.log_prob(value) - self._log_diff_tail_probs

    @property
    def mean(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return (
                self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("mean only available for Normal and Cauchy")

    @property
    def var(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
            return (self.base_dist.scale**2) * (
                1
                + (self.low - self.base_dist.loc) * low_prob
                - (self.high - self.base_dist.loc) * high_prob
                - ((low_prob - high_prob) * self.base_dist.scale) ** 2
            )
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
            raise NotImplementedError("var only available for Normal and Cauchy")
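# Usage sketch (illustrative; the base distribution, bounds, and grid size are
# arbitrary): the truncated density should integrate to ~1 over [low, high],
# checked here with a crude trapezoidal rule.
def _example_two_sided_truncated():
    d = TwoSidedTruncatedDistribution(Cauchy(0.0, 1.0), low=1.0, high=3.0)
    grid = jnp.linspace(1.0, 3.0, 2001)
    probs = jnp.exp(d.log_prob(grid))
    # trapezoidal approximation of the integral of the density over [1, 3]
    return jnp.sum(0.5 * (probs[1:] + probs[:-1]) * (grid[1:] - grid[:-1]))  # ~1.0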
def TruncatedDistribution(base_dist, low=None, high=None, *, validate_args=None):
    """
    A function to generate a truncated distribution.

    :param base_dist: The base distribution to be truncated. This should be a
        univariate distribution. Currently, only the following distributions are
        supported: Cauchy, Laplace, Logistic, Normal, SoftLaplace, and StudentT.
    :param low: the value used to truncate the base distribution from below.
        Set this parameter to None to skip truncation from below.
    :param high: the value used to truncate the base distribution from above.
        Set this parameter to None to skip truncation from above.
    """
    if high is None:
        if low is None:
            return base_dist
        else:
            return LeftTruncatedDistribution(
                base_dist, low=low, validate_args=validate_args
            )
    elif low is None:
        return RightTruncatedDistribution(
            base_dist, high=high, validate_args=validate_args
        )
    else:
        return TwoSidedTruncatedDistribution(
            base_dist, low=low, high=high, validate_args=validate_args
        )
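# Dispatch sketch (illustrative; the base distribution and bounds are arbitrary):
# which wrapper class the factory returns for each combination of bounds.
def _example_truncated_dispatch():
    base = Laplace(0.0, 1.0)
    assert isinstance(TruncatedDistribution(base, low=-1.0), LeftTruncatedDistribution)
    assert isinstance(TruncatedDistribution(base, high=1.0), RightTruncatedDistribution)
    assert isinstance(
        TruncatedDistribution(base, low=-1.0, high=1.0), TwoSidedTruncatedDistribution
    )
    assert TruncatedDistribution(base) is base  # no bounds: base returned unchanged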
def TruncatedCauchy(loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None):
    return TruncatedDistribution(
        Cauchy(loc, scale), low=low, high=high, validate_args=validate_args
    )
def TruncatedNormal(loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None):
    return TruncatedDistribution(
        Normal(loc, scale), low=low, high=high, validate_args=validate_args
    )
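# Usage sketch (an illustrative model, not part of the module; the prior
# parameters and site names are arbitrary): TruncatedNormal is typically used as
# a prior over a bounded quantity inside a NumPyro model.
def _example_truncated_normal_model(obs=None):
    import numpyro  # local import, only needed for this sketch

    # a scale parameter must be positive, so truncate its Normal prior at 0
    scale = numpyro.sample("scale", TruncatedNormal(1.0, 0.5, low=0.0))
    numpyro.sample("obs", Normal(0.0, scale), obs=obs)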
class TruncatedPolyaGamma(Distribution):
    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(self, batch_shape=(), *, validate_args=None):
        super(TruncatedPolyaGamma, self).__init__(
            batch_shape, validate_args=validate_args
        )
    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        x = random.gamma(
            key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,))
        )
        x = jnp.sum(x / denom, axis=-1)
        return jnp.clip(x * (0.5 / jnp.pi**2), a_max=self.truncation_point)
    @validate_sample
    def log_prob(self, value):
        value = value[..., None]
        all_indices = jnp.arange(0, self.num_log_prob_terms)
        two_n_plus_one = 2.0 * all_indices + 1.0
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        even_terms = jnp.take(log_terms, all_indices[::2], axis=-1)
        odd_terms = jnp.take(log_terms, all_indices[1::2], axis=-1)
        sum_even = jnp.exp(logsumexp(even_terms, axis=-1))
        sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)
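# Usage sketch (illustrative; the grid is arbitrary, and the ~1 figure is only
# approximate because the density is a truncated-series approximation): draws
# stay within (0, truncation_point], and the mass on that interval is close to 1.
def _example_truncated_polya_gamma():
    d = TruncatedPolyaGamma()
    samples = d.sample(random.PRNGKey(0), sample_shape=(1_000,))
    assert samples.max() <= d.truncation_point
    grid = jnp.linspace(0.01, d.truncation_point, 500)
    probs = jnp.exp(d.log_prob(grid))
    # crude trapezoidal integral; ~1 up to truncation/approximation error
    return jnp.sum(0.5 * (probs[1:] + probs[:-1]) * (grid[1:] - grid[:-1]))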