Source code for numpyro.contrib.tfp.distributions

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import lru_cache
import warnings

from multipledispatch import dispatch
import numpy as np

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import bijectors as tfb, distributions as tfd

import numpyro.distributions as numpyro_dist
from numpyro.distributions import (
    Distribution as NumPyroDistribution,
from numpyro.distributions.transforms import Transform, biject_to
from numpyro.util import find_stack_level, not_jax_tracer

def _get_codomain(bijector):
    if bijector.__class__.__name__ == "Sigmoid":
        return constraints.interval(bijector.low, bijector.high)
    elif bijector.__class__.__name__ == "Identity":
        return constraints.real
    elif bijector.__class__.__name__ in ["Exp", "SoftPlus"]:
        return constraints.positive
    elif bijector.__class__.__name__ == "GeneralizedPareto":
        loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
        if not_jax_tracer(concentration) and np.all(np.less(concentration, 0)):
            return constraints.interval(loc, loc + scale / jnp.abs(concentration))
        # XXX: here we suppose concentration > 0
        # which is not true in general, but should cover enough usage cases
            return constraints.greater_than(loc)
    elif bijector.__class__.__name__ == "SoftmaxCentered":
        return constraints.simplex
    elif bijector.__class__.__name__ == "Chain":
        return _get_codomain(bijector.bijectors[-1])
        return constraints.real

[docs] class BijectorConstraint(constraints.Constraint): """ A constraint which is codomain of a TensorFlow bijector. :param ~tensorflow_probability.substrates.jax.bijectors.Bijector bijector: a TensorFlow bijector """ def __init__(self, bijector): self.bijector = bijector @property def event_dim(self): return self.bijector.forward_min_event_ndims def __call__(self, x): return self.codomain(x) # a convenient property to inspect the actual support of a TFP distribution @property def codomain(self): return _get_codomain(self.bijector) def tree_flatten(self): return self.bijector, () @classmethod def tree_unflatten(cls, _, bijector): return cls(bijector)
[docs] class BijectorTransform(Transform): """ A wrapper for TensorFlow bijectors to make them compatible with NumPyro's transforms. :param ~tensorflow_probability.substrates.jax.bijectors.Bijector bijector: a TensorFlow bijector """ def __init__(self, bijector): self.bijector = bijector @property def domain(self): return BijectorConstraint(tfb.Invert(self.bijector)) @property def codomain(self): return BijectorConstraint(self.bijector) def __call__(self, x): return self.bijector.forward(x) def _inverse(self, y): return self.bijector.inverse(y) def log_abs_det_jacobian(self, x, y, intermediates=None): return self.bijector.forward_log_det_jacobian(x, self.domain.event_dim) def forward_shape(self, shape): out_shape = self.bijector.forward_event_shape(shape) in_event_shape = self.bijector.inverse_event_shape(out_shape) batch_shape = shape[: len(shape) - len(in_event_shape)] return batch_shape + out_shape def inverse_shape(self, shape): in_shape = self.bijector.inverse_event_shape(shape) out_event_shape = self.bijector.forward_event_shape(in_shape) batch_shape = shape[: len(shape) - len(out_event_shape)] return batch_shape + in_shape def tree_flatten(self): return self.bijector, () @classmethod def tree_unflatten(cls, _, bijector): return cls(bijector)
@biject_to.register(BijectorConstraint) def _transform_to_bijector_constraint(constraint): return BijectorTransform(constraint.bijector) def _onehot_enumerate_support(self, expand=True): n = self.event_shape[-1] values = jnp.identity(n, dtype=jnp.result_type(self.dtype)) values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,)) if expand: values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,)) return values class _TFPDistributionMeta(type(NumPyroDistribution)): @lru_cache(maxsize=None) def __getitem__(cls, tfd_class): assert issubclass(tfd_class, tfd.Distribution) tfd_class_name = tfd_class.__name__ def init(self, *args, **kwargs): warnings.warn( "Importing distributions from numpyro.contrib.tfp.distributions is " "deprecated. You should import distributions directly from " "tensorflow_probability.substrates.jax.distributions instead.", FutureWarning, stacklevel=find_stack_level(), ) self.tfp_dist = tfd_class(*args, **kwargs) _PyroDist = type(tfd_class_name, (TFPDistribution,), {}) _PyroDist.__init__ = init if tfd_class is tfd.InverseGamma: _PyroDist.arg_constraints = { "concentration": constraints.positive, "scale": constraints.positive, } elif tfd_class is tfd.OneHotCategorical: _PyroDist.arg_constraints = {"logits": constraints.real_vector} _PyroDist.has_enumerate_support = True = constraints.simplex _PyroDist.is_discrete = True _PyroDist.enumerate_support = _onehot_enumerate_support elif tfd_class is tfd.OrderedLogistic: _PyroDist.arg_constraints = { "cutpoints": constraints.ordered_vector, "loc": constraints.real, } elif tfd_class is tfd.Pareto: _PyroDist.arg_constraints = { "concentration": constraints.positive, "scale": constraints.positive, } elif tfd_class is tfd.TruncatedNormal: _PyroDist.arg_constraints = { "low": constraints.real, "high": constraints.real, "loc": constraints.real, "scale": constraints.positive, } elif tfd_class is tfd.TruncatedCauchy: _PyroDist.arg_constraints = { "low": constraints.real, "high": constraints.real, "loc": constraints.real, "scale": constraints.positive, } else: # Mixture distributions in tfp and numpyro are different. if hasattr(numpyro_dist, tfd_class_name) and tfd_class_name != "Mixture": numpyro_dist_class = getattr(numpyro_dist, tfd_class_name) # resolve FooProbs/FooLogits namespaces numpyro_dist_class = getattr( numpyro_dist, f"{tfd_class_name}Logits", numpyro_dist_class ) _PyroDist.arg_constraints = numpyro_dist_class.arg_constraints _PyroDist.has_enumerate_support = ( numpyro_dist_class.has_enumerate_support ) _PyroDist.enumerate_support = numpyro_dist_class.enumerate_support return _PyroDist
[docs] class TFPDistribution(NumPyroDistribution, metaclass=_TFPDistributionMeta): """ A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor has the same signature as the corresponding TFP distribution. This class can be used to convert a TFP distribution to a NumPyro-compatible one as follows:: d = TFPDistribution[tfd.Normal](0, 1) Note that typical use cases do not require explicitly invoking this wrapper, since NumPyro wraps TFP distributions automatically under the hood in model code, e.g.:: from tensorflow_probability.substrates.jax import distributions as tfd def model(): numpyro.sample("x", tfd.Normal(0, 1)) """ def __getattr__(self, name): # return parameters from the constructor if name in self.tfp_dist.parameters: return self.tfp_dist.parameters[name] elif name in ["dtype", "reparameterization_type"]: return getattr(self.tfp_dist, name) raise AttributeError(name) @property def batch_shape(self): # TFP shapes are special tuples that can not be used directly # with lax.broadcast_shapes. So we convert them to tuple. return tuple(self.tfp_dist.batch_shape) @property def event_shape(self): return tuple(self.tfp_dist.event_shape) @property def has_rsample(self): return self.tfp_dist.reparameterization_type is tfd.FULLY_REPARAMETERIZED def sample(self, key, sample_shape=()): return self.tfp_dist.sample(sample_shape=sample_shape, seed=key) def log_prob(self, value): return self.tfp_dist.log_prob(value) @property def mean(self): return self.tfp_dist.mean() @property def variance(self): return self.tfp_dist.variance() def cdf(self, value): return self.tfp_dist.cdf(value) def icdf(self, q): return self.tfp_dist.quantile(q) @property def support(self): bijector = self.tfp_dist._default_event_space_bijector() if bijector is not None: return BijectorConstraint(bijector) else: return None @property def is_discrete(self): # XXX: this should cover most cases return is None def tree_flatten(self): return jax.tree_util.tree_flatten(self.tfp_dist) @classmethod def tree_unflatten(cls, aux_data, params): fn = jax.tree_util.tree_unflatten(aux_data, params) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=FutureWarning) return TFPDistribution[fn.__class__](**fn.parameters)
@dispatch(TFPDistribution, TFPDistribution) def kl_divergence(p, q): # noqa: F811 return tfd.kl_divergence(p.tfp_dist, q.tfp_dist) __all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistribution"] _len_all = len(__all__) for _name, _Dist in tfd.__dict__.items(): if not isinstance(_Dist, type): continue if not issubclass(_Dist, tfd.Distribution): continue if _Dist is tfd.Distribution: continue _PyroDist = TFPDistribution[_Dist] _PyroDist.__module__ = __name__ locals()[_name] = _PyroDist _PyroDist.__doc__ = """ Wraps `{}.{} <{}>`_ with :class:`~numpyro.contrib.tfp.distributions.TFPDistribution`. """.format( _Dist.__module__, _Dist.__name__, _Dist.__name__ ) __all__.append(_name) # Create sphinx documentation. __doc__ = "\n\n".join( [ """ {0} ---------------------------------------------------------------- .. autoclass:: numpyro.contrib.tfp.distributions.{0} """.format( _name ) for _name in __all__[:_len_all] ] )