# Source code for numpyro.distributions.transforms

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
import warnings
import weakref

import numpy as np

from jax import lax, vmap
from jax.flatten_util import ravel_pytree
from jax.nn import log_sigmoid, softplus
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import expit, logit
from jax.tree_util import register_pytree_node, tree_flatten, tree_map

from numpyro.distributions import constraints
from numpyro.distributions.util import (
    add_diag,
    matrix_to_tril_vec,
    signed_stick_breaking_tril,
    sum_rightmost,
    vec_to_tril_matrix,
)
from numpyro.util import find_stack_level, not_jax_tracer

# Public API of this module, i.e. the names exported by ``import *``.
__all__ = [
    "biject_to",
    "AbsTransform",
    "AffineTransform",
    "CholeskyTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CorrMatrixCholeskyTransform",
    "ExpTransform",
    "IdentityTransform",
    "L1BallTransform",
    "LowerCholeskyTransform",
    "ScaledUnitLowerCholeskyTransform",
    "LowerCholeskyAffine",
    "OrderedTransform",
    "PermuteTransform",
    "PowerTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SimplexToOrderedTransform",
    "SoftplusTransform",
    "SoftplusLowerCholeskyTransform",
    "StickBreakingTransform",
    "Transform",
    "UnpackTransform",
]


def _clipped_expit(x):
    finfo = jnp.finfo(jnp.result_type(x))
    return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1.0 - finfo.eps)


class Transform(object):
    """Base class of all transforms.

    Subclasses provide the forward map (``__call__``), its inverse
    (``_inverse``) and ``log_abs_det_jacobian``. Each subclass is registered
    as a JAX pytree node through its ``tree_flatten``/``tree_unflatten``
    methods when it is defined.
    """

    domain = constraints.real
    codomain = constraints.real
    # Weak reference to the cached inverse transform, or None before the first
    # access of ``inv``.
    _inv = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

    @property
    def inv(self):
        """Inverse transform, created lazily and cached via a weak reference."""
        cached = self._inv() if self._inv is not None else None
        if cached is not None:
            return cached
        # The inverse keeps a strong reference back to ``self``, while ``self``
        # only keeps a weak reference to the inverse.
        cached = _InverseTransform(self)
        self._inv = weakref.ref(cached)
        return cached

    def __call__(self, x):
        raise NotImplementedError

    def _inverse(self, y):
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        raise NotImplementedError

    def call_with_intermediates(self, x):
        """Apply the transform; the base class has no intermediates (``None``)."""
        return self(x), None

    def forward_shape(self, shape):
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape

    def inverse_shape(self, shape):
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape

    # Allow for pickle serialization of transforms.
    def __getstate__(self):
        # Weak references cannot be pickled; they are replaced by None.
        return {
            key: (None if isinstance(value, weakref.ref) else value)
            for key, value in self.__dict__.items()
        }

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild an instance from ``tree_flatten`` output without calling ``__init__``."""
        param_names, static_attrs = aux_data
        instance = cls.__new__(cls)
        for name, value in zip(param_names, params):
            setattr(instance, name, value)
        for name, value in static_attrs.items():
            setattr(instance, name, value)
        return instance
class ParameterFreeTransform(Transform):
    """Base class for transforms that hold no (array) parameters."""

    def tree_flatten(self):
        # No pytree leaves and no static attributes to carry.
        return (), ((), dict())

    def __eq__(self, other):
        return isinstance(other, type(self))


class _InverseTransform(Transform):
    """Inverse view of ``transform``: forward and inverse roles are swapped."""

    def __init__(self, transform):
        super().__init__()
        # A strong reference to the wrapped (forward) transform.
        self._inv = transform

    @property
    def domain(self):
        return self._inv.codomain

    @property
    def codomain(self):
        return self._inv.domain

    @property
    def inv(self):
        return self._inv

    def __call__(self, x):
        return self._inv._inverse(x)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # NB: we don't use intermediates for inverse transform
        forward_logdet = self._inv.log_abs_det_jacobian(y, x, None)
        return -forward_logdet

    def forward_shape(self, shape):
        return self._inv.inverse_shape(shape)

    def inverse_shape(self, shape):
        return self._inv.forward_shape(shape)

    def tree_flatten(self):
        return (self._inv,), (("_inv",), dict())

    def __eq__(self, other):
        return isinstance(other, _InverseTransform) and self._inv == other._inv
class AbsTransform(ParameterFreeTransform):
    """Non-bijective map :math:`y = |x|` from the real line to the positives."""

    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def __call__(self, x):
        return jnp.abs(x)

    def _inverse(self, y):
        # The sign of x is lost, so ``y`` itself is returned (with a warning).
        message = (
            "AbsTransform is not a bijective transform."
            " The inverse of `y` will be `y`."
        )
        warnings.warn(message, stacklevel=find_stack_level())
        return y
class AffineTransform(Transform):
    """
    Elementwise affine map :math:`y = loc + scale * x`.

    .. note:: When `scale` is a JAX tracer, we always assume that `scale > 0` when
        calculating `codomain`.
    """

    def __init__(self, loc, scale, domain=constraints.real):
        self.loc = loc
        self.scale = scale
        self.domain = domain

    @property
    def codomain(self):
        domain = self.domain
        if domain is constraints.real:
            return constraints.real

        def scale_is_negative():
            # A concrete, all-negative scale flips the orientation of bounds;
            # we suppose scale > 0 for any tracer.
            return not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0))

        if isinstance(domain, constraints.greater_than):
            if scale_is_negative():
                return constraints.less_than(self(domain.lower_bound))
            return constraints.greater_than(self(domain.lower_bound))
        if isinstance(domain, constraints.less_than):
            if scale_is_negative():
                return constraints.greater_than(self(domain.upper_bound))
            return constraints.less_than(self(domain.upper_bound))
        if isinstance(domain, constraints.interval):
            lower, upper = domain.lower_bound, domain.upper_bound
            if scale_is_negative():
                lower, upper = upper, lower
            return constraints.interval(self(lower), self(upper))
        raise NotImplementedError

    def __call__(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # The Jacobian is diagonal with entries ``scale``.
        log_abs_scale = jnp.log(jnp.abs(self.scale))
        return jnp.broadcast_to(log_abs_scale, jnp.shape(x))

    def forward_shape(self, shape):
        loc_shape = getattr(self.loc, "shape", ())
        scale_shape = getattr(self.scale, "shape", ())
        return lax.broadcast_shapes(shape, loc_shape, scale_shape)

    def inverse_shape(self, shape):
        loc_shape = getattr(self.loc, "shape", ())
        scale_shape = getattr(self.scale, "shape", ())
        return lax.broadcast_shapes(shape, loc_shape, scale_shape)

    def tree_flatten(self):
        leaves = (self.loc, self.scale, self.domain)
        return leaves, (("loc", "scale", "domain"), dict())

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False
        same_loc = jnp.array_equal(self.loc, other.loc)
        same_scale = jnp.array_equal(self.scale, other.scale)
        same_domain = self.domain == other.domain
        return same_loc & same_scale & same_domain
def _get_compose_transform_input_event_dim(parts): input_event_dim = parts[-1].domain.event_dim for part in parts[len(parts) - 1 :: -1]: input_event_dim = part.domain.event_dim + max( input_event_dim - part.codomain.event_dim, 0 ) return input_event_dim def _get_compose_transform_output_event_dim(parts): output_event_dim = parts[0].codomain.event_dim for part in parts[1:]: output_event_dim = part.codomain.event_dim + max( output_event_dim - part.domain.event_dim, 0 ) return output_event_dim
class ComposeTransform(Transform):
    """
    Composes a sequence of transforms: ``parts[0]`` is applied first and
    ``parts[-1]`` last.

    :param parts: a sequence of :class:`Transform` instances.
    """

    def __init__(self, parts):
        self.parts = parts

    @property
    def domain(self):
        input_event_dim = _get_compose_transform_input_event_dim(self.parts)
        first_input_event_dim = self.parts[0].domain.event_dim
        assert input_event_dim >= first_input_event_dim
        if input_event_dim == first_input_event_dim:
            return self.parts[0].domain
        else:
            # Later parts need more event dims than the first part consumes, so
            # extra batch dims of the first domain are reinterpreted as events.
            return constraints.independent(
                self.parts[0].domain, input_event_dim - first_input_event_dim
            )

    @property
    def codomain(self):
        output_event_dim = _get_compose_transform_output_event_dim(self.parts)
        last_output_event_dim = self.parts[-1].codomain.event_dim
        assert output_event_dim >= last_output_event_dim
        if output_event_dim == last_output_event_dim:
            return self.parts[-1].codomain
        else:
            return constraints.independent(
                self.parts[-1].codomain, output_event_dim - last_output_event_dim
            )

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def _inverse(self, y):
        for part in self.parts[::-1]:
            y = part.inv(y)
        return y

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """Sum of the parts' log-Jacobians, each reduced over its extra event dims.

        :param intermediates: optional output of :meth:`call_with_intermediates`;
            when given, intermediate values are reused instead of recomputed.
        :raises ValueError: if ``intermediates`` does not have one entry per part.
        """
        if intermediates is not None:
            if len(intermediates) != len(self.parts):
                raise ValueError(
                    "Intermediates array has length = {}. Expected = {}.".format(
                        len(intermediates), len(self.parts)
                    )
                )

        result = 0.0
        input_event_dim = self.domain.event_dim
        for i, part in enumerate(self.parts[:-1]):
            y_tmp = part(x) if intermediates is None else intermediates[i][0]
            inter = None if intermediates is None else intermediates[i][1]
            logdet = part.log_abs_det_jacobian(x, y_tmp, intermediates=inter)
            # Dims that are events for the composition but batch for this part.
            batch_ndim = input_event_dim - part.domain.event_dim
            result = result + sum_rightmost(logdet, batch_ndim)
            input_event_dim = part.codomain.event_dim + batch_ndim
            x = y_tmp
        # account the the last transform, where y is available
        inter = None if intermediates is None else intermediates[-1]
        part = self.parts[-1]
        logdet = part.log_abs_det_jacobian(x, y, intermediates=inter)
        result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim)
        return result

    def call_with_intermediates(self, x):
        intermediates = []
        for part in self.parts[:-1]:
            x, inter = part.call_with_intermediates(x)
            intermediates.append([x, inter])
        # NB: we don't need to hold the last output value in `intermediates`
        x, inter = self.parts[-1].call_with_intermediates(x)
        intermediates.append(inter)
        return x, intermediates

    def forward_shape(self, shape):
        for part in self.parts:
            shape = part.forward_shape(shape)
        return shape

    def inverse_shape(self, shape):
        for part in reversed(self.parts):
            shape = part.inverse_shape(shape)
        return shape

    def tree_flatten(self):
        return (self.parts,), (("parts",), {})

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        # ``zip`` alone would silently ignore trailing parts of the longer one.
        if len(self.parts) != len(other.parts):
            return False
        # ``jnp.logical_and`` takes exactly two operands, so fold the pairwise
        # comparisons with ``&`` instead. This works for any number of parts
        # and for both Python and JAX booleans.
        result = True
        for p1, p2 in zip(self.parts, other.parts):
            result = result & (p1 == p2)
        return result
def _matrix_forward_shape(shape, offset=0): # Reshape from (..., N) to (..., D, D). if len(shape) < 1: raise ValueError("Too few dimensions in input") N = shape[-1] D = round((0.25 + 2 * N) ** 0.5 - 0.5) if D * (D + 1) // 2 != N: raise ValueError("Input is not a flattend lower-diagonal number") D = D - offset return shape[:-1] + (D, D) def _matrix_inverse_shape(shape, offset=0): # Reshape from (..., D, D) to (..., N). if len(shape) < 2: raise ValueError("Too few dimensions on input") if shape[-2] != shape[-1]: raise ValueError("Input is not square") D = shape[-1] + offset N = D * (D + 1) // 2 return shape[:-2] + (N,)
class CholeskyTransform(ParameterFreeTransform):
    r"""
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    positive definite matrix.
    """

    domain = constraints.positive_definite
    codomain = constraints.lower_cholesky

    def __call__(self, x):
        return jnp.linalg.cholesky(x)

    def _inverse(self, y):
        # Rebuild the matrix as L @ L^T.
        y_transposed = jnp.swapaxes(y, -2, -1)
        return jnp.matmul(y, y_transposed)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13
        n = jnp.shape(x)[-1]
        # Exponents -n, -(n - 1), ..., -1 for the diagonal entries of y.
        order = -jnp.arange(n, 0, -1)
        log_diag = jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))
        return -n * jnp.log(2) + jnp.sum(order * log_diag, axis=-1)
class CorrCholeskyTransform(ParameterFreeTransform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

        1. First we convert :math:`x` into a lower triangular matrix with the following order:

        .. math::
            \begin{bmatrix}
                1   & 0 & 0 & 0 \\
                x_0 & 1 & 0 & 0 \\
                x_1 & x_2 & 1 & 0 \\
                x_3 & x_4 & x_5 & 1
            \end{bmatrix}

        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
        class :class:`StickBreakingTransform` to transform :math:`X_i` into a
        unit Euclidean length vector using the following steps:

            a. Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
            b. Transforms into an unsigned domain: :math:`z_i = r_i^2`.
            c. Applies :math:`s_i = StickBreakingTransform(z_i)`.
            d. Transforms back into signed domain: :math:`y_i = (sign(r_i), 1) * \sqrt{s_i}`.
    """

    domain = constraints.real_vector
    codomain = constraints.corr_cholesky

    def __call__(self, x):
        # Steps 1 and 2.a are swapped: tanh is elementwise, so it is applied to
        # the flat vector first for better performance.
        return signed_stick_breaking_tril(jnp.tanh(x))

    def _inverse(self, y):
        # Undo stick breaking: remaining squared length after each column.
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
        # Shift right by one column (prepending 1.0) to get the length
        # remaining *before* each entry.
        z1m_cumprod_shifted = jnp.pad(
            z1m_cumprod[..., :-1], pad_width, mode="constant", constant_values=1.0
        )
        numerator = matrix_to_tril_vec(y, diagonal=-1)
        denominator = jnp.sqrt(matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
        # Undo step 2.a (tanh).
        return jnp.arctanh(numerator / denominator)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # NB: domain and codomain have different dimensions, so the Jacobian
        # determinant is not well defined. We return the log|det| of the map
        # from ``x`` to the flattened lower triangular part of ``y``.
        # stick_breaking_logdet = log(y / r) = log(z_cumprod) (modulo right shifted)
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        # diagonal=-2 stands in for the right shift; for a (2 x 2) matrix it
        # gives an empty array, which is fine.
        z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
        stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)
        # log(1 - tanh(x)^2) == -2 * (x + softplus(-2x) - log 2)
        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.0), axis=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        return _matrix_forward_shape(shape, offset=-1)

    def inverse_shape(self, shape):
        return _matrix_inverse_shape(shape, offset=-1)
class CorrMatrixCholeskyTransform(CholeskyTransform):
    r"""
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    correlation matrix.
    """

    domain = constraints.corr_matrix
    codomain = constraints.corr_cholesky

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # NB: see derivation in LKJCholesky implementation
        n = jnp.shape(x)[-1]
        # Exponents -(n - 1), ..., -1, 0 for the diagonal entries of y.
        order = -jnp.arange(n - 1, -1, -1)
        log_diag = jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))
        return jnp.sum(order * log_diag, axis=-1)
class ExpTransform(Transform):
    r"""Elementwise exponential map :math:`y = \exp(x)`.

    :param domain: constraint of the input; determines the ``codomain``.
    """

    # TODO: refine domain/codomain logic through setters, especially when
    # transforms for inverses are supported
    def __init__(self, domain=constraints.real):
        self.domain = domain

    @property
    def codomain(self):
        domain = self.domain
        if domain is constraints.ordered_vector:
            return constraints.positive_ordered_vector
        if domain is constraints.real:
            return constraints.positive
        if isinstance(domain, constraints.greater_than):
            return constraints.greater_than(self(domain.lower_bound))
        if isinstance(domain, constraints.interval):
            lower = self(domain.lower_bound)
            upper = self(domain.upper_bound)
            return constraints.interval(lower, upper)
        raise NotImplementedError

    def __call__(self, x):
        # XXX consider to clamp from below for stability if necessary
        return jnp.exp(x)

    def _inverse(self, y):
        return jnp.log(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # d/dx exp(x) = exp(x), whose log is x itself.
        return x

    def tree_flatten(self):
        return (self.domain,), (("domain",), dict())

    def __eq__(self, other):
        return isinstance(other, ExpTransform) and self.domain == other.domain
class IdentityTransform(ParameterFreeTransform):
    """The identity map: input and output are the same object."""

    def __call__(self, x):
        return x

    def _inverse(self, y):
        return y

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Volume preserving: the log-Jacobian is zero everywhere, shaped like x.
        zeros = jnp.zeros_like(x)
        return zeros
class IndependentTransform(Transform):
    """
    Wraps a transform by aggregating over ``reinterpreted_batch_ndims``-many
    dims in :meth:`check`, so that an event is valid only if all its
    independent entries are valid.
    """

    def __init__(self, base_transform, reinterpreted_batch_ndims):
        assert isinstance(base_transform, Transform)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_transform = base_transform
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def domain(self):
        base = self.base_transform.domain
        return constraints.independent(base, self.reinterpreted_batch_ndims)

    @property
    def codomain(self):
        base = self.base_transform.codomain
        return constraints.independent(base, self.reinterpreted_batch_ndims)

    def __call__(self, x):
        return self.base_transform(x)

    def _inverse(self, y):
        return self.base_transform._inverse(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        ndims = self.reinterpreted_batch_ndims
        result = self.base_transform.log_abs_det_jacobian(
            x, y, intermediates=intermediates
        )
        # The base log-Jacobian must have enough dims to sum over.
        if jnp.ndim(result) < ndims:
            expected = self.domain.event_dim
            raise ValueError(f"Expected x.dim() >= {expected} but got {jnp.ndim(x)}")
        return sum_rightmost(result, ndims)

    def call_with_intermediates(self, x):
        return self.base_transform.call_with_intermediates(x)

    def forward_shape(self, shape):
        return self.base_transform.forward_shape(shape)

    def inverse_shape(self, shape):
        return self.base_transform.inverse_shape(shape)

    def tree_flatten(self):
        leaves = (self.base_transform, self.reinterpreted_batch_ndims)
        names = ("base_transform", "reinterpreted_batch_ndims")
        return leaves, (names, dict())

    def __eq__(self, other):
        if not isinstance(other, IndependentTransform):
            return False
        same_base = self.base_transform == other.base_transform
        same_ndims = self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims
        return same_base & same_ndims
class L1BallTransform(ParameterFreeTransform):
    r"""
    Transforms an unconstrained real vector :math:`x` into the unit L1 ball.
    """

    domain = constraints.real_vector
    codomain = constraints.l1_ball

    def __call__(self, x):
        # transform to (-1, 1) interval
        t = jnp.tanh(x)

        # apply stick-breaking transform
        remainder = jnp.cumprod(1 - jnp.abs(t[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (t.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
        return t * remainder

    def _inverse(self, y):
        # inverse stick-breaking
        remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
        finfo = jnp.finfo(y.dtype)
        # Bounds passed positionally: the ``a_min``/``a_max`` keywords of
        # ``jnp.clip`` are deprecated in recent JAX releases.
        remainder = jnp.clip(remainder, finfo.tiny)
        t = y / remainder

        # inverse of tanh, kept strictly inside (-1, 1) so arctanh stays finite
        t = jnp.clip(t, -1 + finfo.eps, 1 - finfo.eps)
        return jnp.arctanh(t)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # compute stick-breaking logdet
        #   t1 -> t1
        #   t2 -> t2 * (1 - abs(t1))
        #   t3 -> t3 * (1 - abs(t1)) * (1 - abs(t2))
        # hence jacobian is triangular and logdet is the sum of the log
        # of the diagonal part of the jacobian
        one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        eps = jnp.finfo(y.dtype).eps
        one_minus_remainder = jnp.clip(one_minus_remainder, None, 1 - eps)
        # log(remainder) = log1p(remainder - 1)
        stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1)

        # log(1 - tanh(x)^2) == -2 * (x + softplus(-2x) - log 2)
        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.0), axis=-1)
        return stick_breaking_logdet + tanh_logdet
class LowerCholeskyAffine(Transform):
    r"""
    Transform via the mapping :math:`y = loc + scale\_tril\ @\ x`.

    :param loc: a real vector.
    :param scale_tril: a lower triangular matrix with positive diagonal.

    **Example**

    .. doctest::

       >>> import jax.numpy as jnp
       >>> from numpyro.distributions.transforms import LowerCholeskyAffine
       >>> base = jnp.ones(2)
       >>> loc = jnp.zeros(2)
       >>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
       >>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
       >>> affine(base)
       Array([0.3, 1.5], dtype=float32)
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector

    def __init__(self, loc, scale_tril):
        if jnp.ndim(scale_tril) != 2:
            raise ValueError(
                "Only support 2-dimensional scale_tril matrix. "
                "Please make a feature request if you need to "
                "use this transform with batched scale_tril."
            )
        self.loc = loc
        self.scale_tril = scale_tril

    def __call__(self, x):
        return self.loc + jnp.squeeze(
            jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1
        )

    def _inverse(self, y):
        y = y - self.loc
        original_shape = jnp.shape(y)
        # Flatten all batch dims into columns so one triangular solve handles them.
        yt = jnp.reshape(y, (-1, original_shape[-1])).T
        xt = solve_triangular(self.scale_tril, yt, lower=True)
        return jnp.reshape(xt.T, original_shape)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Triangular Jacobian: log|det| is the sum of log diagonal entries.
        return jnp.broadcast_to(
            jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1),
            jnp.shape(x)[:-1],
        )

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        # ``jnp.shape`` also accepts a scalar Python ``loc`` (``.shape`` would not).
        return lax.broadcast_shapes(
            shape, jnp.shape(self.loc), jnp.shape(self.scale_tril)[:-1]
        )

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return lax.broadcast_shapes(
            shape, jnp.shape(self.loc), jnp.shape(self.scale_tril)[:-1]
        )

    def tree_flatten(self):
        return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict())

    def __eq__(self, other):
        if not isinstance(other, LowerCholeskyAffine):
            return False
        return jnp.array_equal(self.loc, other.loc) & jnp.array_equal(
            self.scale_tril, other.scale_tril
        )
class LowerCholeskyTransform(ParameterFreeTransform):
    """
    Maps a real vector to a lower triangular Cholesky factor. The trailing
    ``n`` entries become the diagonal through ``exp``. The leading entries fill
    the strictly lower triangular part unchanged.
    """

    domain = constraints.real_vector
    codomain = constraints.lower_cholesky

    def __call__(self, x):
        size = x.shape[-1]
        n = round((math.sqrt(1 + 8 * size) - 1) / 2)
        strictly_lower = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
        return add_diag(strictly_lower, jnp.exp(x[..., -n:]))

    def _inverse(self, y):
        off_diagonal = matrix_to_tril_vec(y, diagonal=-1)
        log_diagonal = jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))
        return jnp.concatenate([off_diagonal, log_diagonal], axis=-1)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # The Jacobian is diagonal: only the `exp` on the last n entries contributes.
        size = x.shape[-1]
        n = round((math.sqrt(1 + 8 * size) - 1) / 2)
        return x[..., -n:].sum(-1)

    def forward_shape(self, shape):
        return _matrix_forward_shape(shape)

    def inverse_shape(self, shape):
        return _matrix_inverse_shape(shape)
class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform):
    r"""
    Like :class:`LowerCholeskyTransform`, maps a real vector to a lower
    triangular Cholesky factor. The factor is built from a unit lower
    triangular matrix (ones on the diagonal, strictly lower part taken from
    the leading entries of the input). Row ``i`` of that matrix is then scaled
    by ``scale_diag[i]``. ``scale_diag`` is the softplus of the trailing ``n``
    entries, hence positive.
    """

    domain = constraints.real_vector
    codomain = constraints.scaled_unit_lower_cholesky

    def __call__(self, x):
        size = x.shape[-1]
        n = round((math.sqrt(1 + 8 * size) - 1) / 2)
        strictly_lower = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
        scale_diag = softplus(x[..., -n:])
        unit_tril = add_diag(strictly_lower, 1)
        return unit_tril * scale_diag[..., None]

    def _inverse(self, y):
        scale_diag = jnp.diagonal(y, axis1=-2, axis2=-1)
        # Divide each row by its diagonal entry to recover the unit factor.
        unit_off_diag = matrix_to_tril_vec(y / scale_diag[..., None], diagonal=-1)
        return jnp.concatenate([unit_off_diag, _softplus_inv(scale_diag)], axis=-1)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        size = x.shape[-1]
        n = round((math.sqrt(1 + 8 * size) - 1) / 2)
        raw_diag = x[..., -n:]
        scale_diag = jnp.diagonal(y, axis1=-2, axis2=-1)
        # Row i has i off-diagonal entries scaled by scale_diag[i]; the softplus
        # itself contributes log sigmoid(x) = -softplus(-x).
        return (jnp.log(scale_diag) * jnp.arange(n) - softplus(-raw_diag)).sum(-1)
class OrderedTransform(ParameterFreeTransform):
    """
    Transform a real vector to an ordered vector.

    The first entry is kept as is. Each later entry is exponentiated into a
    positive increment, and the increments are cumulatively summed.

    **References:**

    1. *Stan Reference Manual v2.20, section 10.6*,
       Stan Development Team

    **Example**

    .. doctest::

       >>> import jax.numpy as jnp
       >>> from numpyro.distributions.transforms import OrderedTransform
       >>> base = jnp.ones(3)
       >>> transform = OrderedTransform()
       >>> assert jnp.allclose(transform(base), jnp.array([1., 3.7182817, 6.4365635]), rtol=1e-3, atol=1e-3)
    """

    domain = constraints.real_vector
    codomain = constraints.ordered_vector

    def __call__(self, x):
        increments = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1)
        return jnp.cumsum(increments, axis=-1)

    def _inverse(self, y):
        log_gaps = jnp.log(y[..., 1:] - y[..., :-1])
        return jnp.concatenate([y[..., :1], log_gaps], axis=-1)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Triangular Jacobian whose diagonal is (1, exp(x_1), exp(x_2), ...).
        return jnp.sum(x[..., 1:], -1)
class PermuteTransform(Transform):
    """Reorders the last axis of its input by ``permutation``.

    :param permutation: integer index array applied as ``x[..., permutation]``.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector

    def __init__(self, permutation):
        self.permutation = permutation

    def __call__(self, x):
        return x[..., self.permutation]

    def _inverse(self, y):
        size = self.permutation.size
        # Scatter 0..size-1 so that inverse[permutation[i]] == i.
        base = jnp.zeros(size, dtype=jnp.result_type(int))
        inverse_permutation = base.at[self.permutation].set(jnp.arange(size))
        return y[..., inverse_permutation]

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # A permutation preserves volume.
        batch_shape = jnp.shape(x)[:-1]
        return jnp.full(batch_shape, 0.0)

    def tree_flatten(self):
        return (self.permutation,), (("permutation",), dict())

    def __eq__(self, other):
        return isinstance(other, PermuteTransform) and jnp.array_equal(
            self.permutation, other.permutation
        )
class PowerTransform(Transform):
    """Elementwise power map :math:`y = x^{exponent}` on the positive reals."""

    domain = constraints.positive
    codomain = constraints.positive

    def __init__(self, exponent):
        self.exponent = exponent

    def __call__(self, x):
        return jnp.power(x, self.exponent)

    def _inverse(self, y):
        return jnp.power(y, 1 / self.exponent)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # d/dx x^p = p * x^(p - 1) = p * y / x
        derivative = self.exponent * y / x
        return jnp.log(jnp.abs(derivative))

    def forward_shape(self, shape):
        exponent_shape = getattr(self.exponent, "shape", ())
        return lax.broadcast_shapes(shape, exponent_shape)

    def inverse_shape(self, shape):
        exponent_shape = getattr(self.exponent, "shape", ())
        return lax.broadcast_shapes(shape, exponent_shape)

    def tree_flatten(self):
        return (self.exponent,), (("exponent",), dict())

    def __eq__(self, other):
        return isinstance(other, PowerTransform) and jnp.array_equal(
            self.exponent, other.exponent
        )
class SigmoidTransform(ParameterFreeTransform):
    """Logistic sigmoid map from the real line onto the unit interval."""

    codomain = constraints.unit_interval

    def __call__(self, x):
        # Clipped so the output never hits exactly 0 or 1.
        return _clipped_expit(x)

    def _inverse(self, y):
        return logit(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # log sigmoid'(x) = log sigmoid(x) + log sigmoid(-x)
        #                 = -softplus(-x) - softplus(x)
        return -softplus(x) - softplus(-x)
class SimplexToOrderedTransform(Transform):
    """
    Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints)
    Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.

    :param anchor_point: Anchor point is a nuisance parameter to improve the identifiability of the transform.
        For simplicity, we assume it is a scalar value, but it is broadcastable to x.shape[:-1].
        For more details please refer to Section 2.2 in [1]

    **References:**

    1. *Ordinal Regression Case Study, section 2.2*,
       M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html

    **Example**

    .. doctest::

       >>> import jax.numpy as jnp
       >>> from numpyro.distributions.transforms import SimplexToOrderedTransform
       >>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
       >>> transform = SimplexToOrderedTransform()
       >>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
    """

    domain = constraints.simplex
    codomain = constraints.ordered_vector

    def __init__(self, anchor_point=0.0):
        self.anchor_point = anchor_point

    def __call__(self, x):
        s = jnp.cumsum(x[..., :-1], axis=-1)
        y = logit(s) + jnp.expand_dims(self.anchor_point, -1)
        return y

    def _inverse(self, y):
        y = y - jnp.expand_dims(self.anchor_point, -1)
        s = expit(y)
        # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1]
        # add two boundary points 0 and 1
        pad_width = [(0, 0)] * (jnp.ndim(s) - 1) + [(1, 1)]
        s = jnp.pad(s, pad_width, constant_values=(0, 1))
        x = s[..., 1:] - s[..., :-1]
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Forward: s = cumsum(x[..., :-1]) (unit triangular, log|det| = 0), then
        # y = logit(s) + anchor_point, so |dy/ds| = 1 / (s * (1 - s)) with
        # s = expit(y - anchor_point). Since log expit(u) = -softplus(-u),
        # log|dy/ds| = softplus(u) + softplus(-u) with u = y - anchor_point.
        # The anchor must be removed first, otherwise the result is wrong for
        # any nonzero ``anchor_point``.
        u = y - jnp.expand_dims(self.anchor_point, -1)
        J_logdet = (softplus(u) + softplus(-u)).sum(-1)
        return J_logdet

    def tree_flatten(self):
        return (self.anchor_point,), (("anchor_point",), dict())

    def __eq__(self, other):
        if not isinstance(other, SimplexToOrderedTransform):
            return False
        return jnp.array_equal(self.anchor_point, other.anchor_point)

    def forward_shape(self, shape):
        return shape[:-1] + (shape[-1] - 1,)

    def inverse_shape(self, shape):
        return shape[:-1] + (shape[-1] + 1,)
def _softplus_inv(y): return jnp.log(-jnp.expm1(-y)) + y
class SoftplusTransform(ParameterFreeTransform):
    r"""
    Maps the real line onto the positive reals via softplus,
    :math:`y = \log(1 + \exp(x))`, with inverse :math:`x = \log(\exp(y) - 1)`.
    """

    domain = constraints.real
    codomain = constraints.softplus_positive

    def __call__(self, x):
        return softplus(x)

    def _inverse(self, y):
        return _softplus_inv(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # softplus'(x) = sigmoid(x), and log sigmoid(x) = -softplus(-x).
        return -softplus(-x)
class SoftplusLowerCholeskyTransform(ParameterFreeTransform):
    """
    Maps an unconstrained vector to a lower triangular matrix with positive
    diagonal. The trailing ``n`` entries are passed through softplus onto the
    diagonal; the leading entries fill the strictly lower triangular part.
    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """

    domain = constraints.real_vector
    codomain = constraints.softplus_lower_cholesky

    def __call__(self, x):
        size = x.shape[-1]
        n = round((math.sqrt(1 + 8 * size) - 1) / 2)
        strictly_lower = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
        diagonal = softplus(x[..., -n:])
        # Place ``diagonal`` on the main diagonal via broadcasting with I_n.
        return strictly_lower + jnp.expand_dims(diagonal, axis=-1) * jnp.identity(n)

    def _inverse(self, y):
        off_diagonal = matrix_to_tril_vec(y, diagonal=-1)
        raw_diagonal = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1))
        return jnp.concatenate([off_diagonal, raw_diagonal], axis=-1)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # The Jacobian is diagonal; only the softplus entries contribute,
        # each with log sigmoid(x) = -softplus(-x).
        size = x.shape[-1]
        n = round((math.sqrt(1 + 8 * size) - 1) / 2)
        return -softplus(-x[..., -n:]).sum(-1)

    def forward_shape(self, shape):
        return _matrix_forward_shape(shape)

    def inverse_shape(self, shape):
        return _matrix_inverse_shape(shape)
class StickBreakingTransform(ParameterFreeTransform):
    """
    Maps an unconstrained vector of length ``K - 1`` onto the ``K``-simplex by
    stick breaking. Each sigmoid-transformed entry is the fraction of the
    remaining stick taken by that component; the last component receives
    whatever remains.
    """

    domain = constraints.real_vector
    codomain = constraints.simplex

    def __call__(self, x):
        # we shift x to obtain a balanced mapping (0, 0, ..., 0) -> (1/K, 1/K, ..., 1/K)
        x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
        # convert to probabilities (relative to the remaining) of each fraction of the stick
        z = _clipped_expit(x)
        z1m_cumprod = jnp.cumprod(1 - z, axis=-1)
        pad_width = [(0, 0)] * x.ndim
        pad_width[-1] = (0, 1)
        z_padded = jnp.pad(z, pad_width, mode="constant", constant_values=1.0)
        pad_width = [(0, 0)] * x.ndim
        pad_width[-1] = (1, 0)
        z1m_cumprod_shifted = jnp.pad(
            z1m_cumprod, pad_width, mode="constant", constant_values=1.0
        )
        return z_padded * z1m_cumprod_shifted

    def _inverse(self, y):
        y_crop = y[..., :-1]
        # Remaining stick length after each of the first K - 1 components,
        # kept away from 0 so the log below stays finite. The bound is passed
        # positionally since the ``a_min`` keyword is deprecated in recent JAX.
        z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny)
        # z_i = y_i / (stick remaining before i), hence
        # logit(z_i) = log(z_i / (1 - z_i)) = log(y_i / (stick remaining after i)).
        x = jnp.log(y_crop / z1m_cumprod)
        # Undo the balancing shift applied in ``__call__``.
        return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html
        # |det|(J) = Product(y * (1 - sigmoid(x)))
        #          = Product(y * sigmoid(x) * exp(-x))
        x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
        return jnp.sum(jnp.log(y[..., :-1]) + (log_sigmoid(x) - x), axis=-1)

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] - 1,)
class UnpackTransform(Transform):
    """
    Transforms a contiguous array to a pytree of subarrays.

    :param unpack_fn: callable used to unpack a contiguous array.
    """

    domain = constraints.real_vector
    codomain = constraints.dependent

    def __init__(self, unpack_fn):
        self.unpack_fn = unpack_fn

    def __call__(self, x):
        batch_shape = x.shape[:-1]
        if batch_shape:
            # Flatten all batch dims into one, unpack each row with vmap, then
            # restore the original batch dims on every leaf.
            unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:]))
            return tree_map(
                lambda z: jnp.reshape(z, batch_shape + z.shape[1:]), unpacked
            )
        else:
            return self.unpack_fn(x)

    def _inverse(self, y):
        leading_dims = [
            v.shape[0] if jnp.ndim(v) > 0 else 0 for v in tree_flatten(y)[0]
        ]
        # An empty pytree has no leaves, so there is nothing that could look
        # batched; skip the heuristic instead of indexing an empty list.
        if leading_dims:
            d0 = leading_dims[0]
            not_scalar = d0 > 0 or len(leading_dims) > 1
            # Heuristic: all leaves sharing one leading dim suggests a batch.
            if not_scalar and all(d == d0 for d in leading_dims[1:]):
                warnings.warn(
                    "UnpackTransform.inv might lead to an unexpected behavior because it"
                    " cannot transform a batch of unpacked arrays.",
                    stacklevel=find_stack_level(),
                )
        return ravel_pytree(y)[0]

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Unpacking only rearranges values: zero log-det per batch element.
        return jnp.zeros(jnp.shape(x)[:-1])

    def forward_shape(self, shape):
        raise NotImplementedError

    def inverse_shape(self, shape):
        raise NotImplementedError

    def tree_flatten(self):
        # XXX: what if unpack_fn is a parametrized callable pytree?
        return (), ((), {"unpack_fn": self.unpack_fn})

    def __eq__(self, other):
        return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn


def _get_target_shape(shape, forward_shape, inverse_shape):
    """
    Replace the trailing ``inverse_shape`` dims of ``shape`` with ``forward_shape``.

    :raises ValueError: if ``shape`` has fewer dims than ``inverse_shape``.
        Without this check, a negative slice index would silently drop batch dims.
    """
    batch_ndims = len(shape) - len(inverse_shape)
    if batch_ndims < 0:
        raise ValueError("Too few dimensions on input")
    return tuple(shape[:batch_ndims]) + tuple(forward_shape)


class ReshapeTransform(Transform):
    """
    Reshape a sample, leaving batch dimensions unchanged.

    The trailing ``len(inverse_shape)`` dims of the input form one event that
    is reshaped to ``forward_shape``; all leading dims are batch dims.

    :param forward_shape: Shape to transform the sample to.
    :param inverse_shape: Shape of the sample for the inverse transform.
    """

    def __init__(self, forward_shape, inverse_shape) -> None:
        # Normalize to tuples so list inputs concatenate and compare correctly.
        forward_shape = tuple(forward_shape)
        inverse_shape = tuple(inverse_shape)
        forward_size = math.prod(forward_shape)
        inverse_size = math.prod(inverse_shape)
        if forward_size != inverse_size:
            raise ValueError(
                f"forward shape {forward_shape} (size {forward_size}) and inverse "
                f"shape {inverse_shape} (size {inverse_size}) are not compatible"
            )
        self._forward_shape = forward_shape
        self._inverse_shape = inverse_shape

    @property
    def domain(self):
        # The whole input block of shape `inverse_shape` is a single event, so
        # batch/event splits computed from `domain.event_dim` come out right.
        return constraints.independent(constraints.real, len(self._inverse_shape))

    @property
    def codomain(self):
        return constraints.independent(constraints.real, len(self._forward_shape))

    def forward_shape(self, shape):
        return _get_target_shape(shape, self._forward_shape, self._inverse_shape)

    def inverse_shape(self, shape):
        return _get_target_shape(shape, self._inverse_shape, self._forward_shape)

    def __call__(self, x):
        return jnp.reshape(x, self.forward_shape(jnp.shape(x)))

    def _inverse(self, y):
        return jnp.reshape(y, self.inverse_shape(jnp.shape(y)))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Reshaping preserves volume. Return zeros over the batch dims only,
        # matching the event rank declared by `domain`.
        batch_shape = jnp.shape(x)[: jnp.ndim(x) - len(self._inverse_shape)]
        return jnp.zeros(batch_shape)

    def tree_flatten(self):
        aux_data = {
            "_forward_shape": self._forward_shape,
            "_inverse_shape": self._inverse_shape,
        }
        return (), ((), aux_data)

    def __eq__(self, other):
        return (
            isinstance(other, ReshapeTransform)
            and self._forward_shape == other._forward_shape
            and self._inverse_shape == other._inverse_shape
        )


##########################################################
# CONSTRAINT_REGISTRY
##########################################################


class ConstraintRegistry(object):
    """Registry mapping constraint types to factories that build a transform for them."""

    def __init__(self):
        self._registry = {}

    def register(self, constraint, factory=None):
        """
        Register ``factory`` for ``constraint`` (an instance or a constraint class).

        Called without ``factory``, returns a decorator.
        """
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Instances are keyed by their class so that lookup by type works.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Build the transform registered for ``type(constraint)``.

        :raises NotImplementedError: if no factory is registered for that type.
        """
        try:
            factory = self._registry[type(constraint)]
        except KeyError as e:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from e

        return factory(constraint)


biject_to = ConstraintRegistry()


@biject_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    return CorrCholeskyTransform()
@biject_to.register(constraints.corr_matrix)
def _transform_to_corr_matrix(constraint):
    # CorrCholeskyTransform followed by the inverse of CorrMatrixCholeskyTransform.
    to_cholesky = CorrCholeskyTransform()
    cholesky_to_matrix = CorrMatrixCholeskyTransform().inv
    return ComposeTransform([to_cholesky, cholesky_to_matrix])


@biject_to.register(type(constraints.positive))
def _transform_to_positive(constraint):
    transform = ExpTransform()
    return transform


@biject_to.register(constraints.greater_than)
def _transform_to_greater_than(constraint):
    # y = lower_bound + 1 * exp(x)
    positive_map = ExpTransform()
    shift = AffineTransform(constraint.lower_bound, 1, domain=constraints.positive)
    return ComposeTransform([positive_map, shift])


@biject_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    # y = upper_bound - 1 * exp(x)
    positive_map = ExpTransform()
    flip = AffineTransform(constraint.upper_bound, -1, domain=constraints.positive)
    return ComposeTransform([positive_map, flip])


@biject_to.register(type(constraints.real_matrix))
@biject_to.register(type(constraints.real_vector))
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    # Reuse the transform for the base constraint and reinterpret batch dims.
    base_transform = biject_to(constraint.base_constraint)
    ndims = constraint.reinterpreted_batch_ndims
    return IndependentTransform(base_transform, ndims)


@biject_to.register(type(constraints.unit_interval))
def _transform_to_unit_interval(constraint):
    transform = SigmoidTransform()
    return transform


@biject_to.register(type(constraints.circular))
@biject_to.register(constraints.open_interval)
@biject_to.register(constraints.interval)
def _transform_to_interval(constraint):
    # Squash to the unit interval, then stretch to [lower_bound, upper_bound].
    width = constraint.upper_bound - constraint.lower_bound
    squash = SigmoidTransform()
    stretch = AffineTransform(
        constraint.lower_bound, width, domain=constraints.unit_interval
    )
    return ComposeTransform([squash, stretch])


@biject_to.register(constraints.l1_ball)
def _transform_to_l1_ball(constraint):
    transform = L1BallTransform()
    return transform


@biject_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    transform = LowerCholeskyTransform()
    return transform


@biject_to.register(constraints.scaled_unit_lower_cholesky)
def _transform_to_scaled_unit_lower_cholesky(constraint):
    transform = ScaledUnitLowerCholeskyTransform()
    return transform


@biject_to.register(constraints.ordered_vector)
def _transform_to_ordered_vector(constraint):
    transform = OrderedTransform()
    return transform
@biject_to.register(constraints.positive_definite)
def _transform_to_positive_definite(constraint):
    # LowerCholeskyTransform followed by the inverse of CholeskyTransform.
    to_lower_cholesky = LowerCholeskyTransform()
    cholesky_to_matrix = CholeskyTransform().inv
    return ComposeTransform([to_lower_cholesky, cholesky_to_matrix])


@biject_to.register(constraints.positive_ordered_vector)
def _transform_to_positive_ordered_vector(constraint):
    # Build an ordered vector first, then map it elementwise through exp.
    ordering = OrderedTransform()
    positive_map = ExpTransform()
    return ComposeTransform([ordering, positive_map])


@biject_to.register(constraints.real)
def _transform_to_real(constraint):
    transform = IdentityTransform()
    return transform


@biject_to.register(constraints.softplus_positive)
def _transform_to_softplus_positive(constraint):
    transform = SoftplusTransform()
    return transform


@biject_to.register(constraints.softplus_lower_cholesky)
def _transform_to_softplus_lower_cholesky(constraint):
    transform = SoftplusLowerCholeskyTransform()
    return transform


@biject_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    transform = StickBreakingTransform()
    return transform