# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from jax.dtypes import canonicalize_dtype
import jax.numpy as jnp
from tensorflow_probability.substrates.jax import bijectors as tfb
from tensorflow_probability.substrates.jax import distributions as tfd
import numpyro.distributions as numpyro_dist
from numpyro.distributions import Distribution as NumPyroDistribution
from numpyro.distributions import constraints
from numpyro.distributions.transforms import Transform, biject_to
from numpyro.util import not_jax_tracer
def _get_codomain(bijector):
if bijector.__class__.__name__ == "Sigmoid":
return constraints.interval(bijector.low, bijector.high)
elif bijector.__class__.__name__ == "Identity":
return constraints.real
elif bijector.__class__.__name__ in ["Exp", "SoftPlus"]:
return constraints.positive
elif bijector.__class__.__name__ == "GeneralizedPareto":
loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
if not_jax_tracer(concentration) and np.all(concentration < 0):
return constraints.interval(loc, loc + scale / jnp.abs(concentration))
# XXX: here we suppose concentration > 0
# which is not true in general, but should cover enough usage cases
else:
return constraints.greater_than(loc)
elif bijector.__class__.__name__ == "SoftmaxCentered":
return constraints.simplex
elif bijector.__class__.__name__ == "Chain":
return _get_codomain(bijector.bijectors[-1])
else:
return constraints.real
[docs]class BijectorConstraint(constraints.Constraint):
"""
A constraint which is codomain of a TensorFlow bijector.
:param ~tensorflow_probability.substrates.jax.bijectors.Bijector bijector: a TensorFlow bijector
"""
def __init__(self, bijector):
self.bijector = bijector
def __call__(self, x):
return self.codomain(x)
# a convenient property to inspect the actual support of a TFP distribution
@property
def codomain(self):
return _get_codomain(self.bijector)
@biject_to.register(BijectorConstraint)
def _transform_to_bijector_constraint(constraint):
return BijectorTransform(constraint.bijector)
_TFPDistributionMeta = type(tfd.Distribution)
# XXX: we create this mixin class to avoid metaclass conflict between TFP and NumPyro Ditribution
class _TFPMixinMeta(_TFPDistributionMeta, type(NumPyroDistribution)):
def __init__(cls, name, bases, dct):
# XXX: _TFPDistributionMeta.__init__ registers cls as a PyTree
# for some reasons, when defining metaclass of TFPDistributionMixin to be _TFPMixinMeta,
# TFPDistributionMixin will be registered as a PyTree 2 times, which is not allowed
# in JAX, so we skip registering TFPDistributionMixin as a PyTree.
if name == "TFPDistributionMixin":
super(_TFPDistributionMeta, cls).__init__(name, bases, dct)
else:
super(_TFPMixinMeta, cls).__init__(name, bases, dct)
[docs]class TFPDistributionMixin(NumPyroDistribution, metaclass=_TFPMixinMeta):
"""
A mixin layer to make TensorFlow Probability (TFP) distribution compatible
with NumPyro internal.
"""
def __init_subclass__(cls, **kwargs):
# skip register pytree because TFP distributions are already pytrees
super(object, cls).__init_subclass__(**kwargs)
def __call__(self, *args, **kwargs):
key = kwargs.pop('rng_key')
kwargs.pop('sample_intermediates', False)
return self.sample(*args, seed=key, **kwargs)
@property
def support(self):
bijector = self._default_event_space_bijector()
if bijector is not None:
return BijectorConstraint(bijector)
else:
return None
@property
def is_discrete(self):
# XXX: this should cover most cases
return self.support is None
[docs]class InverseGamma(tfd.InverseGamma, TFPDistributionMixin):
arg_constraints = {"concentration": constraints.positive, "scale": constraints.positive}
[docs]class OneHotCategorical(tfd.OneHotCategorical, TFPDistributionMixin):
arg_constraints = {"logits": constraints.real_vector}
has_enumerate_support = True
support = constraints.simplex
is_discrete = True
def enumerate_support(self, expand=True):
n = self.event_shape[-1]
values = jnp.identity(n, dtype=canonicalize_dtype(self.dtype))
values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
if expand:
values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
return values
[docs]class OrderedLogistic(tfd.OrderedLogistic, TFPDistributionMixin):
arg_constraints = {"cutpoints": constraints.ordered_vector, "loc": constraints.real}
[docs]class Pareto(tfd.Pareto, TFPDistributionMixin):
arg_constraints = {"concentration": constraints.positive, "scale": constraints.positive}
__all__ = ['BijectorConstraint', 'BijectorTransform', 'TFPDistributionMixin']
_len_all = len(__all__)
for _name, _Dist in tfd.__dict__.items():
if not isinstance(_Dist, type):
continue
if not issubclass(_Dist, tfd.Distribution):
continue
if _Dist is tfd.Distribution:
continue
try:
_PyroDist = locals()[_name]
except KeyError:
_PyroDist = type(_name, (_Dist, TFPDistributionMixin), {})
_PyroDist.__module__ = __name__
if hasattr(numpyro_dist, _name):
numpyro_dist_class = getattr(numpyro_dist, _name)
# resolve FooProbs/FooLogits namespaces
if type(numpyro_dist_class).__name__ == "function":
if not hasattr(numpyro_dist, _name + "Logits"):
continue
numpyro_dist_class = getattr(numpyro_dist, _name + "Logits")
_PyroDist.arg_constraints = numpyro_dist_class.arg_constraints
_PyroDist.has_enumerate_support = numpyro_dist_class.has_enumerate_support
_PyroDist.enumerate_support = numpyro_dist_class.enumerate_support
locals()[_name] = _PyroDist
_PyroDist.__doc__ = '''
Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/distributions/{}>`_
with :class:`~numpyro.contrib.tfp.distributions.TFPDistributionMixin`.
'''.format(_Dist.__module__, _Dist.__name__, _Dist.__name__)
__all__.append(_name)
# Create sphinx documentation.
__doc__ = '\n\n'.join([
'''
{0}
----------------------------------------------------------------
.. autoclass:: numpyro.contrib.tfp.distributions.{0}
'''.format(_name)
for _name in __all__[:_len_all] + sorted(__all__[_len_all:])
])