# Source code for numpyro.distributions.transforms
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Generic, Optional, Sequence, Tuple, Union, cast
import warnings
import weakref
import numpy as np
import jax
from jax import Array, lax, vmap
from jax.nn import log_sigmoid, softplus
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import expit, logit
from jax.tree_util import register_pytree_node
from jax.typing import ArrayLike
from numpyro._typing import (
NonScalarArray,
NumLike,
NumLikeT,
PyTree,
)
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.util import (
add_diag,
array_equiv,
matrix_to_tril_vec,
signed_stick_breaking_tril,
sum_rightmost,
vec_to_tril_matrix,
)
from numpyro.util import find_stack_level, not_jax_tracer
# Public names exported by ``from numpyro.distributions.transforms import *``.
# ``OrderedTransform`` is a documented public transform and is listed here so
# that star-imports expose it like its siblings.
__all__ = [
    "biject_to",
    "AbsTransform",
    "AffineTransform",
    "CholeskyTransform",
    "ComplexTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CorrMatrixCholeskyTransform",
    "ExpTransform",
    "IdentityTransform",
    "L1BallTransform",
    "LowerCholeskyTransform",
    "ScaledUnitLowerCholeskyTransform",
    "LowerCholeskyAffine",
    "OrderedTransform",
    "PackRealFastFourierCoefficientsTransform",
    "PermuteTransform",
    "PowerTransform",
    "RealFastFourierTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SimplexToOrderedTransform",
    "SoftplusTransform",
    "SoftplusLowerCholeskyTransform",
    "StickBreakingTransform",
    "Transform",
    "UnpackTransform",
    "ZeroSumTransform",
]
def _clipped_expit(x: NumLike) -> NumLike:
    """Logistic sigmoid of ``x``, kept strictly inside the open unit interval.

    The result is clamped to ``[tiny, 1 - eps]`` of the dtype of ``x``, so it
    is never exactly 0 or 1 even when ``expit`` saturates.
    """
    info = jnp.finfo(jnp.result_type(x))
    lower, upper = info.tiny, 1.0 - info.eps
    return jnp.clip(expit(x), lower, upper)
class Transform(Generic[NumLikeT]):
    """
    Base class for transforms between (constrained) spaces.

    Subclasses implement ``__call__`` (forward map), ``_inverse`` (inverse
    map) and ``log_abs_det_jacobian``. Subclasses that carry parameters also
    implement ``tree_flatten`` and ``eq``. Every subclass is registered as a
    JAX pytree node when it is created.
    """

    # Cached inverse. On this base class it is stored as a weak reference (see
    # ``inv``), so a transform does not keep its inverse wrapper alive.
    _inv: Optional[Union["Transform", weakref.ref]] = None

    @property
    def domain(self) -> Constraint:
        """Constraint satisfied by inputs (``constraints.real`` by default)."""
        return constraints.real

    @property
    def codomain(self) -> Constraint:
        """Constraint satisfied by outputs (``constraints.real`` by default)."""
        return constraints.real

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register every subclass so transforms can be passed through
        # jit/vmap/grad. This requires ``cls.tree_flatten`` to exist, which the
        # base-class definition below guarantees.
        register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

    @property
    def inv(self) -> "Transform":
        """The inverse transform, created lazily and cached via a weak reference."""
        inv = None
        if (self._inv is not None) and isinstance(self._inv, weakref.ref):
            inv = self._inv()
        if inv is None:
            # The cached wrapper was never built or has been garbage-collected.
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return cast("Transform", inv)

    def __call__(self, x: NumLikeT) -> NumLike:
        raise NotImplementedError()

    def _inverse(self, y: NumLikeT) -> NumLike:
        raise NotImplementedError()

    def log_abs_det_jacobian(
        self, x: NumLikeT, y: NumLikeT, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        """Log absolute determinant of the Jacobian of the forward map at ``x``.

        :param x: input of the forward map.
        :param y: output of the forward map, ``self(x)``.
        :param intermediates: optional values cached by ``call_with_intermediates``.
        """
        raise NotImplementedError()

    def call_with_intermediates(self, x: NumLikeT) -> Tuple[NumLike, Optional[PyTree]]:
        """Apply the transform and also return intermediates (``None`` by default)."""
        return self(x), None

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape

    @property
    def sign(self) -> NumLike:
        """
        Sign of the derivative of the transform if it is bijective.
        """
        raise NotImplementedError(
            f"Transform `{self.__class__.__name__}` does not implement `sign`."
        )

    # Allow for pickle serialization of transforms.
    def __getstate__(self):
        # Weak references cannot be pickled; drop them (the inverse is rebuilt
        # lazily by ``inv`` after unpickling).
        attrs = {}
        for k, v in self.__dict__.items():
            if isinstance(v, weakref.ref):
                attrs[k] = None
            else:
                attrs[k] = v
        return attrs

    def tree_flatten(self):
        """Return ``(leaves, (leaf_attribute_names, static_attributes))``.

        The base class cannot know which attributes are parameters, so
        subclasses must override this. Defining it here lets
        ``__init_subclass__`` register every subclass. A subclass that forgets
        to override it gets this clear error only when it is actually
        flattened, instead of an AttributeError at class-creation time.
        """
        raise NotImplementedError(
            f"Transform `{self.__class__.__name__}` does not implement `tree_flatten`."
        )

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """Rebuild a transform from ``tree_flatten`` output without calling ``__init__``."""
        params_keys, aux_data = aux_data
        self = cls.__new__(cls)
        for k, v in zip(params_keys, params):
            setattr(self, k, v)
        for k, v in aux_data.items():
            setattr(self, k, v)
        return self

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        """Default equality used by ``__eq__``: object identity.

        Subclasses with parameters override this to compare them. ``static``
        selects a Python-bool comparison in those overrides and is unused here.
        """
        return self is other

    def __eq__(self, other: object) -> bool:
        return bool(self.eq(other, static=True))
class ParameterFreeTransform(Transform[NumLikeT]):
    """
    Base class for transforms that hold no array parameters.

    Such transforms flatten to an empty pytree. An instance compares equal to
    ``other`` whenever ``other`` is an instance of the same class (or a subclass).
    """

    def tree_flatten(self):
        # No leaves and no static attributes need to be recorded.
        leaves = ()
        aux = ((), {})
        return leaves, aux

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        return isinstance(other, self.__class__)
class _InverseTransform(Transform[NumLike]):
    """
    Inverse view of another transform, as returned by ``Transform.inv``.

    Domain and codomain are swapped. The forward map is the wrapped
    transform's ``_inverse``, and the inverse map is the wrapped transform
    itself. The wrapped transform is held by a strong reference in ``_inv``.

    :param transform: the transform to invert.
    """

    _inv: Transform

    def __init__(self, transform: Transform):
        super().__init__()
        self._inv = transform

    @property
    def domain(self) -> Constraint:
        return self._inv.codomain

    @property
    def codomain(self) -> Constraint:
        return self._inv.domain

    @property
    def sign(self) -> NumLike:
        # The inverse of a monotone map is monotone in the same direction.
        return self._inv.sign

    @property
    def inv(self) -> Transform:
        return self._inv

    def __call__(self, x: NumLike) -> NumLike:
        return self._inv._inverse(x)

    def _inverse(self, y: NumLike) -> NumLike:
        # Inverting the inverse gives back the wrapped forward map. Without this
        # override, callers that invoke ``_inverse`` directly on an inverse
        # transform hit the base-class NotImplementedError.
        return self._inv(y)

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        # NB: we don't use intermediates for inverse transform
        return -self._inv.log_abs_det_jacobian(y, x, None)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return self._inv.inverse_shape(shape)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return self._inv.forward_shape(shape)

    def tree_flatten(self):
        return (self._inv,), (("_inv",), dict())

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, _InverseTransform):
            return False
        return self._inv.eq(other._inv, static=static)
class AbsTransform(ParameterFreeTransform[NumLike]):
    """
    Elementwise absolute value, mapping the real line onto the positive reals.

    The map is not injective, so it has no true inverse. ``_inverse`` warns and
    returns its input unchanged, i.e. it picks the non-negative preimage.
    """

    domain = constraints.real
    codomain = constraints.positive

    def __call__(self, x: NumLike) -> NumLike:
        return jnp.abs(x)

    def _inverse(self, y: NumLike) -> NumLike:
        # The sign of the original input is lost, so y itself is returned.
        warnings.warn(
            "AbsTransform is not a bijective transform."
            " The inverse of `y` will be `y`.",
            stacklevel=find_stack_level(),
        )
        return y
class AffineTransform(Transform[NumLike]):
    r"""
    Elementwise affine map :math:`y = loc + scale \cdot x`.

    :param loc: additive offset, broadcastable against the input.
    :param scale: multiplicative factor, broadcastable against the input.
    :param domain: constraint satisfied by the input. ``codomain`` is derived
        from it; real, greater_than, less_than and interval are supported.

    .. note:: When `scale` is a JAX tracer, we always assume that `scale > 0`
        when calculating `codomain`.
    """

    def __init__(
        self, loc: NumLike, scale: NumLike, domain: Constraint = constraints.real
    ):
        self.loc = loc
        self.scale = scale
        self._domain = domain

    @property
    def domain(self) -> Constraint:
        return self._domain

    @property
    def codomain(self) -> Constraint:
        """Image of ``domain`` under the map.

        :raises NotImplementedError: for unsupported domain constraints.
        """
        # A concrete, all-negative scale reverses the orientation of the bounds;
        # tracers cannot be inspected, so they are treated as positive.
        if self.domain is constraints.real:
            return constraints.real
        elif isinstance(self.domain, constraints.greater_than):
            if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
                return constraints.less_than(self(self.domain.lower_bound))
            # we suppose scale > 0 for any tracer
            else:
                return constraints.greater_than(self(self.domain.lower_bound))
        elif isinstance(self.domain, constraints.less_than):
            if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
                return constraints.greater_than(self(self.domain.upper_bound))
            # we suppose scale > 0 for any tracer
            else:
                return constraints.less_than(self(self.domain.upper_bound))
        elif isinstance(self.domain, constraints.interval):
            if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
                return constraints.interval(
                    self(self.domain.upper_bound), self(self.domain.lower_bound)
                )
            else:
                return constraints.interval(
                    self(self.domain.lower_bound), self(self.domain.upper_bound)
                )
        else:
            raise NotImplementedError

    @property
    def sign(self) -> NumLike:
        return jnp.sign(self.scale)

    def __call__(self, x: NumLike) -> NumLike:
        return self.loc + jnp.multiply(self.scale, x)

    def _inverse(self, y: NumLike) -> NumLike:
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        """Elementwise ``log|scale|``, broadcast to the shape of ``x`` and ``scale``.

        The Jacobian is diagonal with entries ``scale``. Broadcasting against
        ``scale`` as well as ``x`` means a ``scale`` with extra batch
        dimensions gives a per-element result instead of a broadcasting error.
        When ``scale`` already broadcasts into ``x``, the shape is
        ``jnp.shape(x)`` as before.
        """
        shape = lax.broadcast_shapes(jnp.shape(x), jnp.shape(self.scale))
        return jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), shape)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return lax.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return lax.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )

    def tree_flatten(self):
        return (self.loc, self.scale, self.domain), (
            ("loc", "scale", "_domain"),
            dict(),
        )

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, AffineTransform):
            return False
        return (
            array_equiv(self.loc, other.loc, static=static)
            & array_equiv(self.scale, other.scale, static=static)
            & self.domain.eq(other.domain, static=static)
        )
def _get_compose_transform_input_event_dim(parts):
    """Event dim the input of the composition ``parts`` must have.

    Walks the parts from last to first. Any event dims a later part needs
    beyond what an earlier part's codomain provides are carried back to the
    earlier part's domain.
    """
    event_dim = parts[-1].domain.event_dim
    for transform in reversed(parts[:-1]):
        surplus = event_dim - transform.codomain.event_dim
        event_dim = transform.domain.event_dim + (surplus if surplus > 0 else 0)
    return event_dim
def _get_compose_transform_output_event_dim(parts):
    """Event dim the output of the composition ``parts`` will have.

    Walks the parts from first to last. Any event dims produced beyond what the
    next part's domain consumes are carried forward into its codomain.
    """
    event_dim = parts[0].codomain.event_dim
    for transform in parts[1:]:
        surplus = event_dim - transform.domain.event_dim
        event_dim = transform.codomain.event_dim + max(0, surplus)
    return event_dim
class ComposeTransform(Transform[NumLike]):
    """
    Composition of transforms: ``parts[0]`` is applied first, ``parts[-1]`` last.

    :param parts: the transforms to chain, in application order.
    """

    def __init__(self, parts: Sequence[Transform]) -> None:
        self.parts = parts

    @property
    def domain(self) -> Constraint:
        # The first part's domain, wrapped in ``independent`` when later parts
        # require more event dims than the first part consumes.
        input_event_dim = _get_compose_transform_input_event_dim(self.parts)
        first_input_event_dim = self.parts[0].domain.event_dim
        assert input_event_dim >= first_input_event_dim
        if input_event_dim == first_input_event_dim:
            return self.parts[0].domain
        else:
            return constraints.independent(
                self.parts[0].domain, input_event_dim - first_input_event_dim
            )

    @property
    def codomain(self) -> Constraint:
        # The last part's codomain, wrapped in ``independent`` when earlier parts
        # produce more event dims than the last part's codomain declares.
        output_event_dim = _get_compose_transform_output_event_dim(self.parts)
        last_output_event_dim = self.parts[-1].codomain.event_dim
        assert output_event_dim >= last_output_event_dim
        if output_event_dim == last_output_event_dim:
            return self.parts[-1].codomain
        else:
            return constraints.independent(
                self.parts[-1].codomain, output_event_dim - last_output_event_dim
            )

    @property
    def sign(self) -> NumLike:
        # Chain rule: the sign of a product of derivatives.
        sign: NumLike = 1
        for transform in self.parts:
            sign *= transform.sign
        return sign

    def __call__(self, x: NumLike) -> NumLike:
        for part in self.parts:
            x = part(x)
        return x

    def _inverse(self, y: NumLike) -> NumLike:
        for part in self.parts[::-1]:
            y = part.inv(y)
        return y

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        """Sum of the parts' log-Jacobians, each reduced to the composite event dim.

        :param intermediates: optional output of ``call_with_intermediates``;
            it must have one entry per part.
        :raises ValueError: if ``intermediates`` has the wrong length.
        """
        if intermediates is not None:
            if len(intermediates) != len(self.parts):
                raise ValueError(
                    "Intermediates array has length = {}. Expected = {}.".format(
                        len(intermediates), len(self.parts)
                    )
                )
        result = 0.0
        input_event_dim = self.domain.event_dim
        for i, part in enumerate(self.parts[:-1]):
            y_tmp = part(x) if intermediates is None else intermediates[i][0]
            inter = None if intermediates is None else intermediates[i][1]
            logdet = part.log_abs_det_jacobian(x, y_tmp, intermediates=inter)
            # Dims this part treats as batch but the composite treats as event.
            batch_ndim = input_event_dim - part.domain.event_dim
            result = result + sum_rightmost(logdet, batch_ndim)
            input_event_dim = part.codomain.event_dim + batch_ndim
            x = y_tmp
        # account for the last transform, where y is available
        inter = None if intermediates is None else intermediates[-1]
        part = self.parts[-1]
        logdet = part.log_abs_det_jacobian(x, y, intermediates=inter)
        result = result + sum_rightmost(logdet, input_event_dim - part.domain.event_dim)
        return result

    def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]:
        intermediates: list[Optional[PyTree]] = []
        for part in self.parts[:-1]:
            x, inter = part.call_with_intermediates(x)
            intermediates.append([x, inter])
        # NB: we don't need to hold the last output value in `intermediates`
        x, inter = self.parts[-1].call_with_intermediates(x)
        intermediates.append(inter)
        return x, intermediates

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        for part in self.parts:
            shape = part.forward_shape(shape)
        return shape

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        for part in reversed(self.parts):
            shape = part.inverse_shape(shape)
        return shape

    def tree_flatten(self):
        # The parts are transforms, hence pytrees themselves. Without this
        # method ``Transform.__init_subclass__`` has no flatten function to
        # register for this class.
        return (self.parts,), (("parts",), dict())

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, ComposeTransform):
            return False
        if len(self.parts) != len(other.parts):
            return False
        if static:
            return all(
                p1.eq(p2, static=True) for p1, p2 in zip(self.parts, other.parts)
            )
        result = jnp.array(True)
        for p1, p2 in zip(self.parts, other.parts):
            result = result & p1.eq(p2, static=False)
        return result
def _matrix_forward_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, ...]:
    """Map a flattened-triangle shape ``(..., N)`` to a square shape ``(..., D, D)``.

    ``N`` must equal ``k * (k + 1) // 2`` for some integer ``k``; then
    ``D = k - offset``.

    :raises ValueError: if ``shape`` is empty or ``N`` is not triangular.
    """
    if not len(shape) >= 1:
        raise ValueError("Too few dimensions in input")
    n_entries = shape[-1]
    # Invert N = k(k + 1)/2 via the quadratic formula.
    side = round(math.sqrt(2 * n_entries + 0.25) - 0.5)
    if side * (side + 1) // 2 != n_entries:
        raise ValueError("Input is not a flattened lower-diagonal number")
    side -= offset
    return shape[:-1] + (side, side)
def _matrix_inverse_shape(shape: tuple[int, ...], offset: int = 0) -> tuple[int, ...]:
    """Map a square shape ``(..., D, D)`` to a flattened-triangle shape ``(..., N)``.

    With ``k = D + offset``, ``N = k * (k + 1) // 2``.

    :raises ValueError: if ``shape`` has fewer than two dims or is not square.
    """
    if not len(shape) >= 2:
        raise ValueError("Too few dimensions on input")
    rows, cols = shape[-2], shape[-1]
    if rows != cols:
        raise ValueError("Input is not square")
    side = cols + offset
    return shape[:-2] + (side * (side + 1) // 2,)
class CholeskyTransform(ParameterFreeTransform[NonScalarArray]):
    r"""
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    positive definite matrix.
    """

    domain = constraints.positive_definite
    codomain = constraints.lower_cholesky

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        return jnp.linalg.cholesky(x)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        # Rebuild the matrix as L @ L^T.
        return jnp.matmul(y, jnp.swapaxes(y, -2, -1))

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        """Log-Jacobian of ``x -> cholesky(x)``, computed from the factor ``y``."""
        # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13
        # The map L -> L @ L^T has log|J| = n*log(2) + sum_i (n - i + 1) log(L_ii)
        # (1-based i). The Cholesky map is its inverse, hence the negated terms.
        n = jnp.shape(x)[-1]
        order = -jnp.arange(n, 0, -1)
        return -n * jnp.log(2) + jnp.sum(
            order * jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1)), axis=-1
        )
class CorrCholeskyTransform(ParameterFreeTransform[NonScalarArray]):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

    1. First we convert :math:`x` into a lower triangular matrix with the following order:

    .. math::
        \begin{bmatrix}
            1   & 0 & 0 & 0 \\
            x_0 & 1 & 0 & 0 \\
            x_1 & x_2 & 1 & 0 \\
            x_3 & x_4 & x_5 & 1
        \end{bmatrix}

    2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
    class :class:`StickBreakingTransform` to transform :math:`X_i` into a
    unit Euclidean length vector using the following steps:

        a. Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
        b. Transforms into an unsigned domain: :math:`z_i = r_i^2`.
        c. Applies :math:`s_i = StickBreakingTransform(z_i)`.
        d. Transforms back into signed domain: :math:`y_i = (sign(r_i), 1) * \sqrt{s_i}`.
    """

    domain = constraints.real_vector
    codomain = constraints.corr_cholesky

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        # we interchange step 1 and step 2.a for a better performance
        t = jnp.tanh(x)
        return signed_stick_breaking_tril(t)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        # inverse stick-breaking: 1 - cumsum(y**2) along each row, shifted one
        # column to the right (padding with 1) so each entry is divided by the
        # square root of the remainder before it.
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        pad_width = [(0, 0)] * y.ndim
        pad_width[-1] = (1, 0)
        z1m_cumprod_shifted = jnp.pad(
            z1m_cumprod[..., :-1], pad_width, mode="constant", constant_values=1.0
        )
        t = matrix_to_tril_vec(y, diagonal=-1) / jnp.sqrt(
            matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1)
        )
        # inverse of tanh
        return jnp.arctanh(t)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # NB: because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the
        # flatten lower triangular part of `y`.
        # stick_breaking_logdet = log(y / r) = log(z_cumprod) (modulo right shifted)
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
        z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
        stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)
        # log(1 - tanh(x)^2) written stably as -2 * (x + softplus(-2x) - log 2)
        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.0), axis=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _matrix_forward_shape(shape, offset=-1)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _matrix_inverse_shape(shape, offset=-1)
class CorrMatrixCholeskyTransform(CholeskyTransform):
    r"""
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    correlation matrix.
    """

    domain = constraints.corr_matrix  # type: ignore[assignment]
    codomain = constraints.corr_cholesky  # type: ignore[assignment]

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        """Log-Jacobian of ``x -> cholesky(x)`` restricted to correlation matrices."""
        # NB: see derivation in LKJCholesky implementation
        # Exponents run -(n-1), ..., -1, 0 over the diagonal of the factor.
        n = jnp.shape(x)[-1]
        order = -jnp.arange(n - 1, -1, -1)
        return jnp.sum(order * jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1)), axis=-1)
class ExpTransform(Transform[NumLike]):
    """
    Elementwise exponential :math:`y = exp(x)`.

    :param domain: constraint satisfied by the input (default: real). The
        codomain is derived from it for real, ordered_vector, greater_than
        and interval domains.
    """

    sign = 1

    # TODO: refine domain/codomain logic through setters, especially when
    # transforms for inverses are supported
    def __init__(self, domain: Constraint = constraints.real):
        self._domain = domain

    @property
    def domain(self) -> Constraint:
        return self._domain

    @property
    def codomain(self) -> Constraint:
        """Image of ``domain`` under ``exp``.

        :raises NotImplementedError: for unsupported domain constraints.
        """
        if self.domain is constraints.ordered_vector:
            return constraints.positive_ordered_vector
        elif self.domain is constraints.real:
            return constraints.positive
        elif isinstance(self.domain, constraints.greater_than):
            return constraints.greater_than(self.__call__(self.domain.lower_bound))
        elif isinstance(self.domain, constraints.interval):
            return constraints.interval(
                self.__call__(self.domain.lower_bound),
                self.__call__(self.domain.upper_bound),
            )
        else:
            raise NotImplementedError

    def __call__(self, x: NumLike) -> NumLike:
        # Note: consider to clamp from below for stability if necessary
        return jnp.exp(x)

    def _inverse(self, y: NumLike) -> NumLike:
        return jnp.log(y)

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        # d/dx exp(x) = exp(x), so log|J| = x elementwise.
        return x

    def tree_flatten(self):
        # The domain constraint is the only state. Without this method
        # ``Transform.__init_subclass__`` has no flatten function to register.
        return (self.domain,), (("_domain",), dict())

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, ExpTransform):
            return False
        return self.domain.eq(other.domain, static=static)
class IdentityTransform(ParameterFreeTransform[NumLike]):
    """The identity map ``y = x``; its log-Jacobian is zero everywhere."""

    sign = 1

    def __call__(self, x: NumLike) -> NumLike:
        return x

    def _inverse(self, y: NumLike) -> NumLike:
        return y

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        return jnp.zeros_like(x)
class IndependentTransform(Transform[NumLike]):
    """
    Wraps ``base_transform`` so that its rightmost ``reinterpreted_batch_ndims``
    batch dims are treated as event dims.

    Domain and codomain become ``constraints.independent(...)`` of the base
    constraints, and ``log_abs_det_jacobian`` is summed over those dims.

    :param base_transform: the transform to wrap.
    :param reinterpreted_batch_ndims: number of batch dims to reinterpret (>= 0).
    """

    def __init__(
        self, base_transform: Transform, reinterpreted_batch_ndims: int
    ) -> None:
        assert isinstance(base_transform, Transform)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_transform = base_transform
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def domain(self) -> Constraint:
        return constraints.independent(
            self.base_transform.domain, self.reinterpreted_batch_ndims
        )

    @property
    def codomain(self) -> Constraint:
        return constraints.independent(
            self.base_transform.codomain, self.reinterpreted_batch_ndims
        )

    @property
    def sign(self) -> NumLike:
        # Reinterpreting batch dims as event dims does not change the
        # derivative itself, so its sign is that of the base transform.
        return self.base_transform.sign

    def __call__(self, x: NumLike) -> NumLike:
        return self.base_transform(x)

    def _inverse(self, y: NumLike) -> NumLike:
        # Use the public ``inv`` so any base transform with a working inverse is
        # supported, including inverse transforms that do not override
        # ``_inverse`` themselves.
        return self.base_transform.inv(y)

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        """Base log-Jacobian summed over the reinterpreted dims.

        :raises ValueError: if the base result has fewer dims than are reinterpreted.
        """
        result = self.base_transform.log_abs_det_jacobian(
            x, y, intermediates=intermediates
        )
        if jnp.ndim(result) < self.reinterpreted_batch_ndims:
            expected = self.domain.event_dim
            raise ValueError(f"Expected x.dim() >= {expected} but got {jnp.ndim(x)}")
        return sum_rightmost(result, self.reinterpreted_batch_ndims)

    def call_with_intermediates(self, x: NumLike) -> Tuple[NumLike, Optional[PyTree]]:
        return self.base_transform.call_with_intermediates(x)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return self.base_transform.forward_shape(shape)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return self.base_transform.inverse_shape(shape)

    def tree_flatten(self):
        return (self.base_transform,), (
            ("base_transform",),
            {"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims},
        )

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, IndependentTransform):
            return False
        if self.reinterpreted_batch_ndims != other.reinterpreted_batch_ndims:
            return False
        return self.base_transform.eq(other.base_transform, static=static)
class L1BallTransform(ParameterFreeTransform[NonScalarArray]):
    r"""
    Transforms an unconstrained real vector :math:`x` into the unit L1 ball.

    Each entry is squashed to :math:`(-1, 1)` with ``tanh``. A signed
    stick-breaking step then scales entry :math:`i` by the product of
    :math:`1 - |t_j|` over the preceding entries.
    """

    domain = constraints.real_vector
    codomain = constraints.l1_ball

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        # transform to (-1, 1) interval
        t = jnp.tanh(x)
        # apply stick-breaking transform
        remainder = jnp.cumprod(1 - jnp.abs(t[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (t.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
        return t * remainder

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        # inverse stick-breaking
        remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
        finfo = jnp.finfo(y.dtype)
        # Guard the division against a remainder that rounds to zero.
        remainder = jnp.clip(remainder, finfo.tiny)
        t = y / remainder
        # inverse of tanh; keep strictly inside (-1, 1) so arctanh stays finite
        t = jnp.clip(t, -1 + finfo.eps, 1 - finfo.eps)
        return jnp.arctanh(t)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # compute stick-breaking logdet
        #     t1 -> t1
        #     t2 -> t2 * (1 - abs(t1))
        #     t3 -> t3 * (1 - abs(t1)) * (1 - abs(t2))
        # hence jacobian is triangular and logdet is the sum of the log
        # of the diagonal part of the jacobian
        one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        eps = jnp.finfo(y.dtype).eps
        one_minus_remainder = jnp.clip(one_minus_remainder, None, 1 - eps)
        # log(remainder) = log1p(remainder - 1)
        stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1)
        # log(1 - tanh(x)^2) written stably as -2 * (x + softplus(-2x) - log 2)
        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.0), axis=-1)
        return stick_breaking_logdet + tanh_logdet
class LowerCholeskyAffine(Transform[NonScalarArray]):
    r"""
    Transform via the mapping :math:`y = loc + scale\_tril\ @\ x`.

    :param loc: a real vector.
    :param scale_tril: a lower triangular matrix with positive diagonal.
    :raises ValueError: if ``scale_tril`` is not 2-dimensional.

    **Example**

    .. doctest::

        >>> import jax.numpy as jnp
        >>> from numpyro.distributions.transforms import LowerCholeskyAffine
        >>> base = jnp.ones(2)
        >>> loc = jnp.zeros(2)
        >>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
        >>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
        >>> affine(base)
        Array([0.3, 1.5], dtype=float32)
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector

    def __init__(self, loc: NonScalarArray, scale_tril: NonScalarArray):
        if jnp.ndim(scale_tril) != 2:
            raise ValueError(
                "Only support 2-dimensional scale_tril matrix. "
                "Please make a feature request if you need to "
                "use this transform with batched scale_tril."
            )
        self.loc = loc
        self.scale_tril = scale_tril

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        return self.loc + jnp.squeeze(
            jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1
        )

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        y = y - self.loc
        original_shape = jnp.shape(y)
        # Flatten the batch into columns so a single triangular solve handles all.
        yt = jnp.reshape(y, (-1, original_shape[-1])).T
        xt = solve_triangular(self.scale_tril, yt, lower=True)
        return jnp.reshape(xt.T, original_shape)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # Triangular matrix: log|det| is the sum of the log-diagonal.
        return jnp.broadcast_to(
            jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1),
            jnp.shape(x)[:-1],
        )

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return lax.broadcast_shapes(shape, self.loc.shape, self.scale_tril.shape[:-1])

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return lax.broadcast_shapes(shape, self.loc.shape, self.scale_tril.shape[:-1])

    def tree_flatten(self):
        # Both parameters are array leaves. Without this method
        # ``Transform.__init_subclass__`` has no flatten function to register.
        return (self.loc, self.scale_tril), (("loc", "scale_tril"), dict())

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, LowerCholeskyAffine):
            return False
        return array_equiv(self.loc, other.loc, static=static) & array_equiv(
            self.scale_tril, other.scale_tril, static=static
        )
class LowerCholeskyTransform(ParameterFreeTransform[NonScalarArray]):
    """
    Transform a real vector to a lower triangular cholesky
    factor, where the strictly lower triangular submatrix is
    unconstrained and the diagonal is parameterized with an
    exponential transform.

    The last ``n`` input entries parameterize the diagonal. The remaining
    entries fill the strictly lower triangle.
    """

    domain = constraints.real_vector
    codomain = constraints.lower_cholesky

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        # Matrix side n from the vector length n(n+1)/2.
        n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
        z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
        diag = jnp.exp(x[..., -n:])
        return add_diag(z, diag)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        z = matrix_to_tril_vec(y, diagonal=-1)
        return jnp.concatenate(
            [z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1
        )

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # the jacobian is diagonal, so logdet is the sum of diagonal `exp` transform
        n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
        return x[..., -n:].sum(-1)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _matrix_forward_shape(shape)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _matrix_inverse_shape(shape)
class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform):
    r"""
    Like `LowerCholeskyTransform` this `Transform` transforms
    a real vector to a lower triangular cholesky factor. However
    it does so via a decomposition
    :math:`y = scale\_diag\ @\ unit\_scale\_tril`,
    where :math:`unit\_scale\_tril` has ones along the diagonal
    and :math:`scale\_diag` is a diagonal matrix with all positive
    entries that is parameterized with a softplus transform.

    The last ``n`` input entries parameterize :math:`scale\_diag`. The
    remaining entries fill the strictly lower triangle of
    :math:`unit\_scale\_tril`, so row ``i`` of the output is row ``i`` of
    :math:`unit\_scale\_tril` multiplied by the ``i``-th scale.
    """

    domain = constraints.real_vector
    codomain = constraints.scaled_unit_lower_cholesky

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
        z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
        diag = softplus(x[..., -n:])
        # Unit diagonal, then scale each row by its positive scale.
        return add_diag(z, jnp.array(1)) * diag[..., None]

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        # The diagonal of y equals the scales; divide rows by them to recover
        # the unit-diagonal factor.
        diag = jnp.diagonal(y, axis1=-2, axis2=-1)
        z = matrix_to_tril_vec(y / diag[..., None], diagonal=-1)
        return jnp.concatenate([z, _softplus_inv(diag)], axis=-1)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # Row i has i free off-diagonal entries, each multiplied by scale_i
        # (contributing i * log(scale_i)). The softplus on the diagonal adds
        # log(sigmoid(x)) = -softplus(-x).
        n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
        diag = x[..., -n:]
        diag_softplus = jnp.diagonal(y, axis1=-2, axis2=-1)
        return (jnp.log(diag_softplus) * jnp.arange(n) - softplus(-diag)).sum(-1)
class OrderedTransform(ParameterFreeTransform[NonScalarArray]):
    """
    Transform a real vector to an ordered vector.

    The first entry is kept as is. Each following entry adds ``exp`` of its
    input to the previous output, so consecutive outputs strictly increase.

    **References:**

    1. *Stan Reference Manual v2.20, section 10.6*,
       Stan Development Team

    **Example**

    .. doctest::

        >>> import jax.numpy as jnp
        >>> from numpyro.distributions.transforms import OrderedTransform
        >>> base = jnp.ones(3)
        >>> transform = OrderedTransform()
        >>> assert jnp.allclose(transform(base), jnp.array([1., 3.7182817, 6.4365635]), rtol=1e-3, atol=1e-3)
    """

    domain = constraints.real_vector
    codomain = constraints.ordered_vector

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        z = jnp.concatenate([x[..., :1], jnp.exp(x[..., 1:])], axis=-1)
        return jnp.cumsum(z, axis=-1)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        x = jnp.log(y[..., 1:] - y[..., :-1])
        return jnp.concatenate([y[..., :1], x], axis=-1)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # Triangular Jacobian with diagonal (1, exp(x_1), ..., exp(x_{n-1})).
        return jnp.sum(x[..., 1:], -1)
class PermuteTransform(Transform[NonScalarArray]):
    """
    Reorder the last axis by a fixed permutation: ``y = x[..., permutation]``.

    :param permutation: integer index array holding a permutation of
        ``range(size)``.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector

    def __init__(self, permutation: Array) -> None:
        self.permutation = permutation

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        return x[..., self.permutation]

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        # Scatter 0..size-1 into the permuted positions to build the inverse.
        size = self.permutation.size
        permutation_inv: NonScalarArray = (
            jnp.zeros(size, dtype=jnp.result_type(int))
            .at[self.permutation]
            .set(jnp.arange(size))
        )
        return y[..., permutation_inv]

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # A permutation matrix has |det| = 1.
        return jnp.full(jnp.shape(x)[:-1], 0.0)

    def tree_flatten(self):
        # The permutation is the only state. Without this method
        # ``Transform.__init_subclass__`` has no flatten function to register.
        return (self.permutation,), (("permutation",), dict())

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, PermuteTransform):
            return False
        return array_equiv(self.permutation, other.permutation, static=static)
class PowerTransform(Transform[NumLike]):
    """
    Elementwise power map ``y = x ** exponent`` on the positive reals.

    :param exponent: the power, broadcastable against the input.
    """

    domain = constraints.positive
    codomain = constraints.positive

    def __init__(self, exponent: NumLike) -> None:
        self.exponent = exponent

    def __call__(self, x: NumLike) -> NumLike:
        return jnp.power(x, self.exponent)

    def _inverse(self, y: NumLike) -> NumLike:
        return jnp.power(y, 1 / self.exponent)

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        # d/dx x**p = p * x**(p - 1) = p * y / x
        return jnp.log(jnp.abs(jnp.multiply(self.exponent, y) / x))

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return lax.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

    def tree_flatten(self):
        # The exponent is the only state. Without this method
        # ``Transform.__init_subclass__`` has no flatten function to register.
        return (self.exponent,), (("exponent",), dict())

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, PowerTransform):
            return False
        return array_equiv(self.exponent, other.exponent, static=static)

    @property
    def sign(self) -> NumLike:
        return jnp.sign(self.exponent)
class SigmoidTransform(ParameterFreeTransform[NumLike]):
    """
    Elementwise logistic sigmoid onto the unit interval.

    The forward output is clamped away from exactly 0 and 1 by ``_clipped_expit``.
    """

    codomain = constraints.unit_interval
    sign = 1

    def __call__(self, x: NumLike) -> NumLike:
        return _clipped_expit(x)

    def _inverse(self, y: NumLike) -> NumLike:
        return logit(y)

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        # log sigmoid'(x) = log sigmoid(x) + log sigmoid(-x)
        #                 = -softplus(-x) - softplus(x)
        return -softplus(x) - softplus(-x)
class SimplexToOrderedTransform(Transform[NonScalarArray]):
    """
    Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints)
    Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.

    :param anchor_point: Anchor point is a nuisance parameter to improve the identifiability of the transform.
        For simplicity, we assume it is a scalar value, but it is broadcastable x.shape[:-1].
        For more details please refer to Section 2.2 in [1]

    **References:**

    1. *Ordinal Regression Case Study, section 2.2*,
       M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html

    **Example**

    .. doctest::

        >>> import jax.numpy as jnp
        >>> from numpyro.distributions.transforms import SimplexToOrderedTransform
        >>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
        >>> transform = SimplexToOrderedTransform()
        >>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
    """

    domain = constraints.simplex
    codomain = constraints.ordered_vector

    def __init__(self, anchor_point: ArrayLike = 0.0) -> None:
        self.anchor_point = anchor_point

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        # Cutpoints are the logits of the cumulative probabilities (the last
        # cumulative sum is always 1 and is dropped), shifted by the anchor.
        s = jnp.cumsum(x[..., :-1], axis=-1)
        y = logit(s) + jnp.expand_dims(self.anchor_point, -1)
        return y

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        y = y - jnp.expand_dims(self.anchor_point, -1)
        s = expit(y)
        # x0 = s0, x1 = s1 - s0, x2 = s2 - s1,..., xn = 1 - s[n-1]
        # add two boundary points 0 and 1
        pad_width = [(0, 0)] * (jnp.ndim(s) - 1) + [(1, 1)]
        s = jnp.pad(s, pad_width, constant_values=(0, 1))
        x = s[..., 1:] - s[..., :-1]
        return x

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # The inverse map y -> x has a triangular Jacobian with log|det| =
        # sum(log expit'(y)) = sum(-softplus(y) - softplus(-y)). The forward
        # log-Jacobian is its negation.
        J_logdet = (softplus(y) + softplus(-y)).sum(-1)
        return J_logdet

    def tree_flatten(self):
        # The anchor point is the only state. Without this method
        # ``Transform.__init_subclass__`` has no flatten function to register.
        return (self.anchor_point,), (("anchor_point",), dict())

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, SimplexToOrderedTransform):
            return False
        return array_equiv(self.anchor_point, other.anchor_point, static=static)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return shape[:-1] + (shape[-1] - 1,)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return shape[:-1] + (shape[-1] + 1,)
def _softplus_inv(y: NumLike) -> NumLike:
    """Inverse of softplus, ``log(exp(y) - 1)``.

    Evaluated as ``log(1 - exp(-y)) + y`` with ``expm1`` to avoid overflow of
    ``exp(y)`` for large ``y``.
    """
    one_minus_exp_neg_y = -jnp.expm1(-y)
    return jnp.log(one_minus_exp_neg_y) + y
[docs]
class SoftplusTransform(ParameterFreeTransform[NumLike]):
    r"""
    Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`.
    The inverse is computed as :math:`x = \log(\exp(y) - 1)`.
    """

    domain = constraints.real
    codomain = constraints.softplus_positive
    sign = 1

    def __call__(self, x: NumLike) -> NumLike:
        return softplus(x)

    def _inverse(self, y: NumLike) -> NumLike:
        return _softplus_inv(y)

    def log_abs_det_jacobian(
        self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
    ) -> NumLike:
        # d softplus(x) / dx = sigmoid(x), and log sigmoid(x) = -softplus(-x)
        neg_x = -x
        return -softplus(neg_x)
[docs]
class SoftplusLowerCholeskyTransform(ParameterFreeTransform[NonScalarArray]):
    """
    Transform from unconstrained vector to lower-triangular matrices whose
    diagonal entries are obtained through softplus (and are therefore positive
    up to floating-point underflow). This is useful for parameterizing positive
    definite matrices in terms of their Cholesky factorization.

    The last ``n`` entries of the input vector hold the (pre-softplus) diagonal;
    the remaining entries fill the strictly lower triangle.
    """

    domain = constraints.real_vector
    codomain = constraints.softplus_lower_cholesky

    @staticmethod
    def _num_rows(vec_size: int) -> int:
        # solve vec_size = n * (n + 1) / 2 for the matrix size n
        return round((math.sqrt(1 + 8 * vec_size) - 1) / 2)

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        n = self._num_rows(x.shape[-1])
        strictly_lower = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
        diag = softplus(x[..., -n:])
        return strictly_lower + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        strictly_lower = matrix_to_tril_vec(y, diagonal=-1)
        raw_diag = _softplus_inv(jnp.diagonal(y, axis1=-2, axis2=-1))
        return jnp.concatenate([strictly_lower, raw_diag], axis=-1)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # Only the diagonal block is transformed (elementwise softplus); the
        # strictly-lower entries pass through unchanged, so the log-determinant
        # is the sum of log softplus'(x) = -softplus(-x) over the diagonal part.
        n = self._num_rows(x.shape[-1])
        return -softplus(-x[..., -n:]).sum(-1)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _matrix_forward_shape(shape)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _matrix_inverse_shape(shape)
[docs]
class StickBreakingTransform(ParameterFreeTransform[NonScalarArray]):
    """
    Stick-breaking map from an unconstrained vector of length ``K - 1`` onto
    the ``K``-simplex. The input is shifted so that the zero vector maps to
    the uniform point ``(1/K, ..., 1/K)``.
    """

    domain = constraints.real_vector
    codomain = constraints.simplex

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        num_breaks = x.shape[-1]
        # shift for a balanced mapping (0, 0, ..., 0) -> (1/K, 1/K, ..., 1/K)
        x = x - jnp.log(num_breaks - jnp.arange(num_breaks))
        # fraction of the *remaining* stick taken at each break
        z = _clipped_expit(x)
        remaining = jnp.cumprod(1 - z, axis=-1)
        no_pad = [(0, 0)] * (x.ndim - 1)
        # append a final fraction of 1 (the last piece takes whatever is left)
        z_padded = jnp.pad(
            z, no_pad + [(0, 1)], mode="constant", constant_values=1.0
        )
        # stick length available before each break (1 before the first one)
        remaining_before = jnp.pad(
            remaining, no_pad + [(1, 0)], mode="constant", constant_values=1.0
        )
        return z_padded * remaining_before

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        pieces = y[..., :-1]
        # stick length left after each break: 1 - cumsum(y), clipped away from 0
        remaining = jnp.clip(
            1 - jnp.cumsum(pieces, axis=-1), jnp.finfo(y.dtype).tiny
        )
        # z_i / (1 - z_i) = y_i / (1 - cumsum(y)_i), so logit(z_i) = log of that ratio
        shifted = jnp.log(pieces / remaining)
        num_breaks = shifted.shape[-1]
        return shifted + jnp.log(num_breaks - jnp.arange(num_breaks))

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html
        # |det|(J) = Product(y * (1 - sigmoid(x)))
        #          = Product(y * sigmoid(x) * exp(-x))
        num_breaks = x.shape[-1]
        shifted = x - jnp.log(num_breaks - jnp.arange(num_breaks))
        return jnp.sum(
            jnp.log(y[..., :-1]) + (log_sigmoid(shifted) - shifted), axis=-1
        )

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        if not len(shape) >= 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        if not len(shape) >= 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] - 1,)
class UnpackTransform(Transform[NonScalarArray]):
    """
    Transforms a contiguous array to a pytree of subarrays.

    :param unpack_fn: callable used to unpack a contiguous array.
    :param pack_fn: callable used to pack a pytree into a contiguous array.
    """

    domain = constraints.real_vector
    codomain = constraints.dependent

    def __init__(self, unpack_fn, pack_fn=None):
        self.unpack_fn = unpack_fn
        self.pack_fn = pack_fn

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        batch_shape = x.shape[:-1]
        if not batch_shape:
            return self.unpack_fn(x)
        # collapse all batch dims into one, vmap the unpacking, then restore them
        flat = x.reshape((-1,) + x.shape[-1:])
        unpacked = vmap(self.unpack_fn)(flat)

        def restore_batch(leaf):
            return jnp.reshape(leaf, batch_shape + leaf.shape[1:])

        return jax.tree.map(restore_batch, unpacked)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        if self.pack_fn is None:
            raise NotImplementedError(
                "pack_fn needs to be provided to perform UnpackTransform.inv."
            )
        leaves = jax.tree.flatten(y)[0]
        if not leaves:
            return jnp.array([])
        # size of each leaf's leading axis (0 for scalar leaves)
        leading_dims = [
            leaf.shape[0] if jnp.ndim(leaf) > 0 else 0 for leaf in leaves
        ]
        first, others = leading_dims[0], leading_dims[1:]
        maybe_batched = first > 0 or len(others) > 0
        # a shared leading size across leaves hints that y may be a batch,
        # which pack_fn is not applied over
        if maybe_batched and all(d == first for d in others):
            warnings.warn(
                "UnpackTransform.inv might lead to an unexpected behavior because it"
                " cannot transform a batch of unpacked arrays.",
                stacklevel=find_stack_level(),
            )
        return self.pack_fn(y)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        batch_shape = jnp.shape(x)[:-1]
        return jnp.zeros(batch_shape)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        raise NotImplementedError

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        raise NotImplementedError

    def tree_flatten(self):
        # NOTE: what if unpack_fn is a parametrized callable pytree?
        aux_data = {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn}
        return (), ((), aux_data)

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, UnpackTransform):
            return False
        return self.unpack_fn is other.unpack_fn and self.pack_fn is other.pack_fn
def _get_target_shape(
    shape: tuple[int, ...],
    forward_shape: tuple[int, ...],
    inverse_shape: tuple[int, ...],
) -> tuple[int, ...]:
    """Replace the trailing ``inverse_shape`` dimensions of ``shape`` with ``forward_shape``.

    The leading ``len(shape) - len(inverse_shape)`` dimensions are treated as
    batch dimensions and kept unchanged.

    :param shape: Full shape, batch dimensions followed by the event dimensions.
    :param forward_shape: Event shape to substitute in.
    :param inverse_shape: Event shape currently occupying the trailing dims.
    :raises ValueError: If ``shape`` has fewer dimensions than ``inverse_shape``.
        Without this check, a negative batch rank would silently slice away
        real dimensions and produce a wrong shape.
    """
    batch_ndims = len(shape) - len(inverse_shape)
    if batch_ndims < 0:
        raise ValueError(
            f"Shape {shape} has too few dimensions for event shape {inverse_shape}."
        )
    return shape[:batch_ndims] + forward_shape
class ReshapeTransform(Transform[NonScalarArray]):
    """
    Reshape a sample, leaving batch dimensions unchanged.

    :param forward_shape: Shape to transform the sample to.
    :param inverse_shape: Shape of the sample for the inverse transform.
    :raises ValueError: If the two shapes hold a different number of elements.
    """

    def __init__(
        self,
        forward_shape: tuple[int, ...],
        inverse_shape: tuple[int, ...],
    ) -> None:
        forward_size, inverse_size = math.prod(forward_shape), math.prod(inverse_shape)
        if forward_size != inverse_size:
            raise ValueError(
                f"forward shape {forward_shape} (size {forward_size}) and inverse "
                f"shape {inverse_shape} (size {inverse_size}) are not compatible"
            )
        self._forward_shape = forward_shape
        self._inverse_shape = inverse_shape

    @property
    def domain(self) -> Constraint:
        event_ndims = len(self._inverse_shape)
        return constraints.independent(constraints.real, event_ndims)

    @property
    def codomain(self) -> Constraint:
        event_ndims = len(self._forward_shape)
        return constraints.independent(constraints.real, event_ndims)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _get_target_shape(shape, self._forward_shape, self._inverse_shape)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        return _get_target_shape(shape, self._inverse_shape, self._forward_shape)

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        target = self.forward_shape(jnp.shape(x))
        return jnp.reshape(x, target)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        target = self.inverse_shape(jnp.shape(y))
        return jnp.reshape(y, target)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # reshaping is volume preserving: zero log-det for every batch element
        batch_ndims = x.ndim - len(self._inverse_shape)
        return jnp.zeros_like(x, shape=x.shape[:batch_ndims])

    def tree_flatten(self):
        return (), (
            (),
            {
                "_forward_shape": self._forward_shape,
                "_inverse_shape": self._inverse_shape,
            },
        )

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, ReshapeTransform):
            return False
        return (
            self._forward_shape == other._forward_shape
            and self._inverse_shape == other._inverse_shape
        )
def _normalize_rfft_shape(
    input_shape: tuple[int, ...],
    shape: Optional[tuple[int, ...]],
) -> tuple[int, ...]:
    """Overwrite the trailing dims of ``input_shape`` with ``shape``.

    Returns ``input_shape`` unchanged when ``shape`` is ``None``.
    """
    if shape is not None:
        num_leading = len(input_shape) - len(shape)
        return input_shape[:num_leading] + shape
    return input_shape
[docs]
class RealFastFourierTransform(Transform[NonScalarArray]):
    """
    N-dimensional discrete fast Fourier transform for real input.

    :param transform_shape: Length of each transformed axis to use from the input,
        defaults to the input size.
    :param transform_ndims: Number of trailing dimensions to transform.
    """

    def __init__(
        self,
        transform_shape: Optional[tuple[int, ...]] = None,
        transform_ndims: int = 1,
    ) -> None:
        # a bare int is shorthand for a one-element shape
        if isinstance(transform_shape, int):
            transform_shape = (transform_shape,)
        if transform_shape is not None and len(transform_shape) != transform_ndims:
            raise ValueError(
                f"Length of transform shape ({transform_shape}) does not match number "
                f"of dimensions to transform ({transform_ndims})."
            )
        self.transform_shape = transform_shape
        self.transform_ndims = transform_ndims

    @property
    def _axes(self) -> tuple[int, ...]:
        # the trailing ``transform_ndims`` axes, as negative indices
        return tuple(range(-self.transform_ndims, 0))

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        return jnp.fft.rfftn(x, self.transform_shape, self._axes)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        return jnp.fft.irfftn(y, self.transform_shape, self._axes)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        # only the last transformed axis changes length: n -> n // 2 + 1
        full_shape = _normalize_rfft_shape(shape, self.transform_shape)
        return full_shape[:-1] + (full_shape[-1] // 2 + 1,)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        if not self.transform_shape:
            # without an explicit shape, an even-length signal is assumed
            return shape[:-1] + (2 * (shape[-1] - 1),)
        return _normalize_rfft_shape(shape, self.transform_shape)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        ndims = self.transform_ndims
        batch_shape = jnp.broadcast_shapes(x.shape[:-ndims], y.shape[:-ndims])
        event_shape = x.shape[-ndims:]
        size = math.prod(event_shape)
        # per axis: 2 purely-real coefficients for even length, 1 for odd
        q = math.prod(2 - n % 2 for n in event_shape)
        logdet = (size * jnp.log(size) - jnp.log(2) * (size - q)) / 2
        return jnp.broadcast_to(logdet, batch_shape)

    def tree_flatten(self):
        return (), (
            (),
            {
                "transform_shape": self.transform_shape,
                "transform_ndims": self.transform_ndims,
            },
        )

    @property
    def domain(self) -> Constraint:
        return constraints.independent(constraints.real, self.transform_ndims)

    @property
    def codomain(self) -> Constraint:
        return constraints.independent(constraints.complex, self.transform_ndims)

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, RealFastFourierTransform):
            return False
        return (
            self.transform_ndims == other.transform_ndims
            and self.transform_shape == other.transform_shape
        )
[docs]
class PackRealFastFourierCoefficientsTransform(Transform[NonScalarArray]):
    """
    Transform a real vector to complex coefficients of a real fast Fourier transform.

    For an input of length ``n``, the first ``n // 2 + 1`` entries become the real
    parts of the output, and the remaining ``n - (n // 2 + 1)`` entries become the
    imaginary parts of output entries ``1, 2, ...``. The imaginary part of entry 0
    (and, for even ``n``, of the last entry) is zero.

    :param transform_shape: Shape of the real vector, defaults to the input size.
    """

    domain = constraints.real_vector
    codomain = constraints.independent(constraints.complex, 1)

    def __init__(self, transform_shape: Optional[tuple[int, ...]] = None) -> None:
        assert transform_shape is None or len(transform_shape) == 1, (
            "Packing Fourier coefficients is only implemented for vectors."
        )
        self.shape: Optional[tuple[int, ...]] = transform_shape

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        *batch_shape, n = shape
        assert self.shape is None or self.shape == (n,), (
            f"`shape` must be `None` or `{self.shape}`. Got `{shape}`."
        )
        n_rfft = n // 2 + 1
        return (*batch_shape, n_rfft)

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        *batch_shape, n_rfft = shape
        assert self.shape is not None, (
            "Shape must be specified in `__init__` for inverse transform."
        )
        (n,) = self.shape
        assert n_rfft == n // 2 + 1, (
            f"Expected {n // 2 + 1} coefficients for a vector of length {n}. "
            f"Got `{shape}`."
        )
        return (*batch_shape, n)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> Array:
        # packing only relabels real numbers as real/imaginary parts: log-det 0
        shape = jnp.broadcast_shapes(x.shape[:-1], y.shape[:-1])
        return jnp.zeros_like(x, shape=shape)

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        assert self.shape is None or self.shape == x.shape[-1:], (
            f"Expected trailing shape `{self.shape}`. Got `{x.shape}`."
        )
        n = x.shape[-1]
        n_real = n // 2 + 1
        n_imag = n - n_real
        complex_dtype = jnp.result_type(x.dtype, jnp.complex64)
        # real parts from the head of x, imaginary parts of entries 1..n_imag
        # from the tail
        return (
            jnp.asarray(x)[..., :n_real]
            .astype(complex_dtype)
            .at[..., 1 : 1 + n_imag]
            .add(1j * x[..., n_real:])
        )

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        assert self.shape is not None, (
            "Shape must be specified in `__init__` for inverse transform."
        )
        (n,) = self.shape
        n_real = n // 2 + 1
        n_imag = n - n_real
        return jnp.concatenate([y.real, y.imag[..., 1 : n_imag + 1]], axis=-1)

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        return (
            isinstance(other, PackRealFastFourierCoefficientsTransform)
            and self.shape == other.shape
        )
[docs]
class RecursiveLinearTransform(Transform[NonScalarArray]):
    """
    Apply a linear transformation recursively such that
    :math:`y_t = A y_{t - 1} + x_t` for :math:`t > 0`, where :math:`x_t` and :math:`y_t`
    are vectors and :math:`A` is a square transition matrix. The series is initialized
    by :math:`y_0 = 0` by default, or by ``initial_value`` when one is given; the
    initial state :math:`y_0` itself is not part of the output.

    :param transition_matrix: Square transition matrix :math:`A` for successive states
        or a batch of transition matrices.
    :param initial_value: Optional initial state :math:`y_0`, possibly batched. Its
        trailing dimension must match that of ``transition_matrix``; defaults to zeros.

    **Example:**

    .. doctest::

        >>> from jax import random
        >>> from jax import numpy as jnp
        >>> import numpyro
        >>> from numpyro import distributions as dist
        >>>
        >>> def cauchy_random_walk():
        ...     return numpyro.sample(
        ...         "x",
        ...         dist.TransformedDistribution(
        ...             dist.Cauchy(0, 1).expand([10, 1]).to_event(1),
        ...             dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
        ...         ),
        ...     )
        >>>
        >>> numpyro.handlers.seed(cauchy_random_walk, 0)().shape
        (10, 1)
        >>>
        >>> def rocket_trajectory():
        ...     scale = numpyro.sample(
        ...         "scale",
        ...         dist.HalfCauchy(1).expand([2]).to_event(1),
        ...     )
        ...     transition_matrix = jnp.array([[1, 1], [0, 1]])
        ...     return numpyro.sample(
        ...         "x",
        ...         dist.TransformedDistribution(
        ...             dist.Normal(0, scale).expand([10, 2]).to_event(1),
        ...             dist.transforms.RecursiveLinearTransform(transition_matrix),
        ...         ),
        ...     )
        >>>
        >>> numpyro.handlers.seed(rocket_trajectory, 0)().shape
        (10, 2)
    """

    domain = constraints.real_matrix
    codomain = constraints.real_matrix

    def __init__(
        self,
        transition_matrix: NonScalarArray,
        initial_value: Optional[NonScalarArray] = None,
    ) -> None:
        state_shape = transition_matrix.shape[-1:]
        if initial_value is None:
            initial_value = np.zeros(state_shape)
        assert state_shape == initial_value.shape[-1:], (
            f"Event shape of initial value must be the same as transition matrix, got {state_shape} and"
            f" {initial_value.shape[-1:]}."
        )
        self.initial_value = initial_value
        self.transition_matrix = transition_matrix

    def _get_initial_value(self, sample_shape) -> Array:
        # broadcast y_0 against the sample batch and any batch of matrices / initial values
        init = self.initial_value
        batch_shape = jnp.broadcast_shapes(
            sample_shape, self.transition_matrix.shape[:-2], init.shape[:-1]
        )
        return jnp.broadcast_to(init, batch_shape + init.shape[-1:])

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        batch_shape = x.shape[:-2]
        A = self.transition_matrix

        def step(prev_state, x_t):
            state = jnp.einsum("...ij,...j->...i", A, prev_state) + x_t
            return state, state

        # lax.scan iterates over the leading axis, so bring time to the front
        time_major = jnp.moveaxis(x, -2, 0)
        _, states = lax.scan(step, self._get_initial_value(batch_shape), time_major)
        return jnp.moveaxis(states, 0, -2)

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        batch_shape = y.shape[:-2]
        A = self.transition_matrix
        time_major = jnp.moveaxis(y, -2, 0)
        # previous state for each time step; slot 0 holds the initial value
        previous = (
            jnp.roll(time_major, 1, axis=0)
            .at[0]
            .set(self._get_initial_value(batch_shape))
        )

        def step(current, prev_state):
            x_t = current - jnp.einsum("...ij,...j->...i", A, prev_state)
            return prev_state, x_t

        # walk backwards in time, carrying y_t and recovering x_t = y_t - A y_{t-1}
        _, xs = lax.scan(step, time_major[-1], previous, reverse=True)
        return jnp.moveaxis(xs, 0, -2)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # the Jacobian is block lower-triangular with identity diagonal blocks
        return jnp.zeros_like(x, shape=x.shape[:-2])

    def tree_flatten(self):
        leaves = (self.transition_matrix, self.initial_value)
        names = ("transition_matrix", "initial_value")
        return leaves, (names, {})

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, RecursiveLinearTransform):
            return False
        tm_eq = array_equiv(
            self.transition_matrix, other.transition_matrix, static=static
        )
        mine, theirs = self.initial_value, other.initial_value
        if mine is None or theirs is None:
            iv_eq = mine is None and theirs is None
        else:
            iv_eq = array_equiv(mine, theirs, static=static)
        return tm_eq & iv_eq
[docs]
class ZeroSumTransform(Transform[NonScalarArray]):
    """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]

    The forward map grows each of the trailing ``transform_ndims`` axes by one
    entry so that the result sums to zero along each of them; the inverse drops
    that extra entry again.

    :param transform_ndims: Number of trailing dimensions to transform.

    **References**
    [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266
    [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
    [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
    """

    def __init__(self, transform_ndims: int = 1) -> None:
        self.transform_ndims = transform_ndims

    @property
    def domain(self) -> Constraint:
        return constraints.independent(constraints.real, self.transform_ndims)

    @property
    def codomain(self) -> Constraint:
        return constraints.zero_sum(self.transform_ndims)

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        for axis in range(-self.transform_ndims, 0):
            x = self.extend_axis(x, axis=axis)
        return x

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        for axis in range(-self.transform_ndims, 0):
            y = self.extend_axis_rev(y, axis=axis)
        return y

    def extend_axis_rev(self, array: NonScalarArray, axis: int) -> NonScalarArray:
        """Undo :meth:`extend_axis`: drop the last entry along ``axis`` and re-center."""
        if axis < 0:
            axis = jnp.ndim(array) + axis
        n = array.shape[axis]
        sqrt_n = jnp.sqrt(n)
        last = jnp.take(array, jnp.array([-1]), axis=axis)
        # the dropped entry equals -sum(original) / sqrt(n); recover that sum
        norm = (-last * sqrt_n) / (sqrt_n + n)
        head = array[(slice(None, None),) * axis + (slice(None, -1),)]
        return head + norm

    def extend_axis(self, array: NonScalarArray, axis: int) -> NonScalarArray:
        """Append one entry along ``axis`` and shift so that the axis sums to zero."""
        n = array.shape[axis] + 1
        sqrt_n = jnp.sqrt(n)
        total = array.sum(axis, keepdims=True)
        norm = total / (sqrt_n + n)
        fill = norm - total / sqrt_n
        return jnp.concatenate([array, fill], axis=axis) - norm

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> jnp.ndarray:
        ndims = self.transform_ndims
        shape = jnp.broadcast_shapes(x.shape[:-ndims], y.shape[:-ndims])
        return jnp.zeros_like(x, shape=shape)

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        ndims = self.transform_ndims
        return shape[:-ndims] + tuple(size + 1 for size in shape[-ndims:])

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        ndims = self.transform_ndims
        return shape[:-ndims] + tuple(size - 1 for size in shape[-ndims:])

    def tree_flatten(self):
        return (), ((), {"transform_ndims": self.transform_ndims})

    def eq(self, other: object, static: bool = False) -> ArrayLike:
        if not isinstance(other, ZeroSumTransform):
            return False
        return self.transform_ndims == other.transform_ndims
[docs]
class ComplexTransform(ParameterFreeTransform[NonScalarArray]):
    """
    Transforms a pair of real numbers to a complex number.

    The input's trailing dimension (of size 2) holds the real and imaginary
    parts; the output drops that dimension.
    """

    domain = constraints.real_vector
    codomain = constraints.complex

    def __call__(self, x: NonScalarArray) -> NonScalarArray:
        assert x.shape[-1] == 2, "Input must have a trailing dimension of size 2."
        return lax.complex(x[..., 0], x[..., 1])

    def _inverse(self, y: NonScalarArray) -> NonScalarArray:
        return jnp.stack([y.real, y.imag], axis=-1)

    def log_abs_det_jacobian(
        self,
        x: NonScalarArray,
        y: NonScalarArray,
        intermediates: Optional[PyTree] = None,
    ) -> NumLike:
        # Splitting into (real, imag) is a relabeling of coordinates, so the
        # log-determinant is zero. It must have a *real* dtype: returning
        # zeros_like(y) would be complex and turn downstream log densities
        # complex-valued.
        return jnp.zeros_like(jnp.real(y))

    def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        assert shape[-1] == 2, "Input must have a trailing dimension of size 2."
        return shape[:-1]

    def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        # inverse of forward_shape: restore the trailing (real, imag) dimension
        return tuple(shape) + (2,)
##########################################################
# CONSTRAINT_REGISTRY
##########################################################
class ConstraintRegistry:
    """Registry mapping constraint types to factories that build transforms.

    A factory is a callable taking the constraint instance and returning a
    transform; lookups are keyed on ``type(constraint)``.
    """

    def __init__(self):
        self._registry = {}

    def register(self, constraint, factory=None):
        """Register ``factory`` for ``constraint`` (an instance or a class).

        Called without ``factory``, returns a decorator that registers the
        decorated function. Returns ``factory``.
        """
        if factory is None:

            def decorator(fn):
                return self.register(constraint, fn)

            return decorator
        key = (
            type(constraint)
            if isinstance(constraint, constraints.Constraint)
            else constraint
        )
        self._registry[key] = factory
        return factory

    def __call__(self, constraint):
        """Build the transform registered for ``type(constraint)``.

        :raises NotImplementedError: If no factory is registered for that type.
        """
        try:
            factory = self._registry[type(constraint)]
        except KeyError as e:
            raise NotImplementedError from e
        return factory(constraint)


biject_to = ConstraintRegistry()
@biject_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Factory for ``constraints.corr_cholesky``: a CorrCholeskyTransform."""
    transform = CorrCholeskyTransform()
    return transform
@biject_to.register(constraints.corr_matrix)
def _transform_to_corr_matrix(constraint):
    """Factory for ``constraints.corr_matrix``.

    Builds a correlation Cholesky factor, then maps it to a matrix through the
    inverse of CorrMatrixCholeskyTransform.
    """
    to_cholesky = CorrCholeskyTransform()
    to_matrix = CorrMatrixCholeskyTransform().inv
    return ComposeTransform([to_cholesky, to_matrix])
@biject_to.register(type(constraints.positive))
@biject_to.register(type(constraints.nonnegative))
def _transform_to_positive(constraint):
    """Factory for positive / nonnegative constraints: an ExpTransform."""
    transform = ExpTransform()
    return transform
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Factory for lower-bounded constraints.

    Exponentiates, then applies ``AffineTransform(lower_bound, 1)``.
    """
    exp = ExpTransform()
    shift = AffineTransform(constraint.lower_bound, 1, domain=constraints.positive)
    return ComposeTransform([exp, shift])
@biject_to.register(constraints.less_than)
@biject_to.register(constraints.less_than_eq)
def _transform_to_less_than(constraint):
    """Factory for upper-bounded constraints.

    Exponentiates, then applies ``AffineTransform(upper_bound, -1)``.
    """
    exp = ExpTransform()
    reflect = AffineTransform(
        constraint.upper_bound, -1, domain=constraints.positive
    )
    return ComposeTransform([exp, reflect])
@biject_to.register(type(constraints.real_matrix))
@biject_to.register(type(constraints.real_vector))
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Factory for independent constraints.

    Wraps the base constraint's transform in an IndependentTransform with the
    same number of reinterpreted batch dims.
    """
    base_transform = biject_to(constraint.base_constraint)
    return IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )
@biject_to.register(type(constraints.unit_interval))
def _transform_to_unit_interval(constraint):
    """Factory for ``constraints.unit_interval``: a SigmoidTransform."""
    transform = SigmoidTransform()
    return transform
@biject_to.register(type(constraints.circular))
@biject_to.register(constraints.open_interval)
@biject_to.register(constraints.interval)
def _transform_to_interval(constraint):
    """Factory for bounded-interval constraints.

    Applies a sigmoid, then ``AffineTransform(lower_bound, upper_bound - lower_bound)``.
    """
    lower = constraint.lower_bound
    width = constraint.upper_bound - lower
    squash = SigmoidTransform()
    rescale = AffineTransform(lower, width, domain=constraints.unit_interval)
    return ComposeTransform([squash, rescale])
@biject_to.register(constraints.l1_ball)
def _transform_to_l1_ball(constraint):
    """Factory for ``constraints.l1_ball``: an L1BallTransform."""
    transform = L1BallTransform()
    return transform
@biject_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """Factory for ``constraints.lower_cholesky``: a LowerCholeskyTransform."""
    transform = LowerCholeskyTransform()
    return transform
@biject_to.register(constraints.scaled_unit_lower_cholesky)
def _transform_to_scaled_unit_lower_cholesky(constraint):
    """Factory for ``constraints.scaled_unit_lower_cholesky``."""
    transform = ScaledUnitLowerCholeskyTransform()
    return transform
@biject_to.register(constraints.ordered_vector)
def _transform_to_ordered_vector(constraint):
    """Factory for ``constraints.ordered_vector``: an OrderedTransform."""
    transform = OrderedTransform()
    return transform
@biject_to.register(constraints.positive_definite)
@biject_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """Factory for positive (semi)definite constraints.

    Builds a lower Cholesky factor, then maps it to a matrix through the
    inverse of CholeskyTransform.
    """
    to_cholesky = LowerCholeskyTransform()
    to_matrix = CholeskyTransform().inv
    return ComposeTransform([to_cholesky, to_matrix])
@biject_to.register(constraints.positive_ordered_vector)
def _transform_to_positive_ordered_vector(constraint):
    """Factory for ``constraints.positive_ordered_vector``: ordered, then exponentiated."""
    order = OrderedTransform()
    exp = ExpTransform()
    return ComposeTransform([order, exp])
@biject_to.register(constraints.complex)
def _transform_to_complex(constraint):
    """Factory for ``constraints.complex``: a ComplexTransform."""
    transform = ComplexTransform()
    return transform
@biject_to.register(constraints.real)
def _transform_to_real(constraint):
    """Factory for ``constraints.real``: an IdentityTransform."""
    transform = IdentityTransform()
    return transform
@biject_to.register(constraints.softplus_positive)
def _transform_to_softplus_positive(constraint):
    """Factory for ``constraints.softplus_positive``: a SoftplusTransform."""
    transform = SoftplusTransform()
    return transform
@biject_to.register(constraints.softplus_lower_cholesky)
def _transform_to_softplus_lower_cholesky(constraint):
    """Factory for ``constraints.softplus_lower_cholesky``."""
    transform = SoftplusLowerCholeskyTransform()
    return transform
@biject_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Factory for ``constraints.simplex``: a StickBreakingTransform."""
    transform = StickBreakingTransform()
    return transform
@biject_to.register(constraints.zero_sum)
def _transform_to_zero_sum(constraint):
    """Factory for ``constraints.zero_sum``; transforms ``constraint.event_dim`` trailing dims."""
    ndims = constraint.event_dim
    return ZeroSumTransform(ndims)