# Source code for numpyro.distributions.util

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from collections import namedtuple
import functools as ft
from functools import partial, update_wrapper
import math
import warnings

import numpy as np

import jax
from jax import jit, lax, random, vmap
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import digamma
from jax.typing import ArrayLike

from numpyro.util import not_jax_tracer


def array_equiv(a: ArrayLike, b: ArrayLike, static: bool = False):
    """Compare two values, short-circuiting on identity before comparing contents.

    :param a: First array or value.
    :param b: Second array or value.
    :param static: When True the result is always a Python bool, and it is
        False whenever either input is a JAX array (so no tracer is ever
        concretized). When False (default) the comparison is delegated to
        ``jnp.array_equiv`` if either input is a JAX array and to
        ``np.array_equiv`` otherwise.
    :return: ``True`` if ``a is b``; otherwise the comparison result described
        above.
    """
    if a is b:
        return True

    involves_jax = isinstance(a, jax.Array) or isinstance(b, jax.Array)

    if not static:
        backend = jnp if involves_jax else np
        return backend.array_equiv(a, b)

    # Static mode: refuse to inspect JAX arrays, compare everything else in NumPy.
    if involves_jax:
        return False
    return bool(np.array_equiv(a, b))


# Record holding the constants of the Transformed Rejection with Squeeze (TRS)
# algorithm (page 3 of the paper cited in `_binomial_btrs`), plus the
# precomputed log terms that `_get_tr_params` adds.
_tr_params = namedtuple(
    "tr_params",
    ("c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"),
)


def _get_tr_params(n, p):
    """
    Build the `_tr_params` record for ``n`` trials with success probability ``p``.

    Follows Table 1 of the BTRS reference. In addition to the table entries,
    ``log(p)``, ``log1p(-p)`` and the part of ``log(f(k))`` that depends only on
    ``(n, p, m)`` (bottom of page 5) are computed once here.

    :param n: number of trials; must expose ``.dtype`` (used to cast ``m``).
    :param p: success probability.
    :return: a populated `_tr_params` named tuple.
    """
    mean = n * p
    std = jnp.sqrt(mean * (1 - p))
    b = 1.15 + 2.53 * std
    m = jnp.floor((n + 1) * p).astype(n.dtype)
    log_p = jnp.log(p)
    log1_p = jnp.log1p(-p)

    # The two pieces of the constant term of log(f(k)).
    ratio_term = (m + 0.5) * (
        jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p
    )
    tail_term = stirling_approx_tail(m) + stirling_approx_tail(n - m)

    return _tr_params(
        c=mean + 0.5,
        b=b,
        a=-0.0873 + 0.0248 * b + 0.01 * p,
        alpha=(2.83 + 5.1 / b) * std,
        u_r=0.43,
        v_r=0.92 - 4.2 / b,
        m=m,
        log_p=log_p,
        log1_p=log1_p,
        log_h=ratio_term + tail_term,
    )


def stirling_approx_tail(k):
    """
    Evaluate the Stirling-approximation tail term used by the BTRS sampler.

    For ``k < 10`` the value is read from a ten-entry lookup table; otherwise a
    truncated series in ``1 / (k + 1)`` is used. Both candidates are computed
    and ``jnp.where`` selects between them.

    :param k: integer count(s) used both as table index and series argument.
    :return: the tail value(s), same shape as ``k``.
    """
    small_k_table = jnp.array(
        [
            0.08106146679532726,
            0.04134069595540929,
            0.02767792568499834,
            0.02079067210376509,
            0.01664469118982119,
            0.01387612882307075,
            0.01189670994589177,
            0.01041126526197209,
            0.009255462182712733,
            0.008330563433362871,
        ]
    )
    succ = k + 1
    succ_sq = succ**2
    series = (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / succ_sq) / succ_sq) / succ
    return jnp.where(k < 10, small_k_table[k], series)


# Threshold on the mean ``n * p`` (with ``p`` folded to at most 0.5 by
# `_binomial_dispatch`): below it the inversion sampler is used, at or above it
# the BTRS sampler. Both samplers also read it in their loop conditions.
_binomial_mu_thresh = 10


def _binomial_btrs(key, p, n):
    """
    Based on the transformed rejection sampling algorithm (BTRS) from the
    following reference:

    Hormann, "The Generation of Binomial Random Variates"
    (https://core.ac.uk/download/pdf/11007254.pdf)

    :param key: PRNG key; it is split anew for every proposal.
    :param p: success probability.
    :param n: number of trials; must expose ``.dtype`` (used to cast the
        proposed count).
    :return: the count ``k`` held in the loop state when the loop stops.
    """

    def _btrs_body_fn(val):
        # Draw a fresh proposal; only the key of the previous state is reused.
        _, key, _, _ = val
        key, key_u, key_v = random.split(key, 3)
        u = random.uniform(key_u)
        v = random.uniform(key_v)
        # centre u around zero
        u = u - 0.5
        k = jnp.floor(
            (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
        ).astype(n.dtype)
        return k, key, u, v

    def _btrs_cond_fn(val):
        # Returns True while the current proposal is NOT accepted, i.e. while
        # the loop has to keep going.
        def accept_fn(k, u, v):
            # See acceptance condition in Step 3. (Page 3) of TRS algorithm
            # v <= f(k) * g_grad(u) / alpha

            m = tr_params.m
            log_p = tr_params.log_p
            log1_p = tr_params.log1_p
            # See: formula for log(f(k)) at bottom of Page 5.
            log_f = (
                (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
                + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
                + (stirling_approx_tail(k) - stirling_approx_tail(n - k))
                + tr_params.log_h
            )
            g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
            return jnp.log((v * tr_params.alpha) / g) <= log_f

        k, key, u, v = val
        # cheap squeeze test: accept without evaluating log_f
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        # proposals outside [0, n] can never be accepted
        early_reject = (k < 0) | (k > n)
        # when vmapped _binomial_dispatch will convert the cond condition into
        # an HLO select that will execute both branches. This is a workaround
        # that avoids the resulting infinite loop when p=0. This should also
        # improve performance in less catastrophic cases.
        cond_exclude_small_mu = p * n >= _binomial_mu_thresh
        # Legacy five-argument `lax.cond`: (pred, true_operand, true_fun,
        # false_operand, false_fun). If either shortcut fired, continue only
        # when it was not an early accept; otherwise run the full test.
        cond_main = lax.cond(
            early_accept | early_reject,
            (),
            lambda _: ~early_accept,
            (k, u, v),
            lambda x: ~accept_fn(*x),
        )
        return cond_exclude_small_mu & cond_main

    # closed over by both nested functions above
    tr_params = _get_tr_params(n, p)
    ret = lax.while_loop(
        _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
    )  # use k=-1 initially so that cond_fn returns True
    return ret[0]


def _binomial_inversion(key, p, n):
    """
    Sample a binomial count by accumulating geometric waiting times.

    Each iteration draws one geometric gap and adds it to a running total; the
    loop stops once the total exceeds ``n``. The counter starts at ``-1`` and
    is incremented once per iteration, and its final value is returned.

    :param key: PRNG key, split once per iteration.
    :param p: success probability.
    :param n: number of trials.
    :return: the iteration counter when the loop terminates.
    """
    log_q = jnp.log1p(-p)
    # Guard against p == 0: a zero denominator is replaced by the smallest
    # negative normal number so the division below stays well defined.
    log_q = jnp.where(log_q == 0, -jnp.finfo(log_q.dtype).tiny, log_q)

    def _keep_going(state):
        _, _, waited = state
        # See the note on cond_exclude_small_mu in _binomial_btrs. This check is
        # not needed for correctness, but it avoids wasted iterations when the
        # other sampler is the one being used.
        small_mu = p * n < _binomial_mu_thresh
        return small_mu & (waited <= n)

    def _advance(state):
        count, rng, waited = state
        rng, rng_u = random.split(rng)
        draw = random.uniform(rng_u)
        gap = jnp.floor(jnp.log1p(-draw) / log_q) + 1
        return count + 1, rng, waited + gap

    final_count, _, _ = lax.while_loop(_keep_going, _advance, (-1, key, 0.0))
    return final_count


def _binomial_dispatch(key, p, n):
    """
    Sample one binomial variate, choosing the sampler from ``n * p``.

    :param key: PRNG key for this single draw.
    :param p: success probability.
    :param n: number of trials.
    :return: ``0`` when ``p`` is non-finite, ``p <= 0`` or ``n <= 0``; ``n``
        when ``p >= 1`` (and ``n > 0``); otherwise a draw from the inversion or
        BTRS sampler.
    """

    def dispatch(key, p, n):
        # Work with min(p, 1 - p) and mirror the result afterwards.
        is_le_mid = p <= 0.5
        pq = jnp.where(is_le_mid, p, 1 - p)
        mu = n * pq
        # Legacy five-argument `lax.cond`: (pred, true_operand, true_fun,
        # false_operand, false_fun).
        k = lax.cond(
            mu < _binomial_mu_thresh,
            (key, pq, n),
            lambda x: _binomial_inversion(*x),
            (key, pq, n),
            lambda x: _binomial_btrs(*x),
        )
        return jnp.where(is_le_mid, k, n - k)

    # Degenerate inputs bypass the samplers: non-finite `p`, `p <= 0` or
    # `n <= 0` yield 0 (nan values are not allowed for integer types), and
    # `p >= 1` yields `n`.
    cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
    return lax.cond(
        cond0 & (p < 1),
        (key, p, n),
        lambda x: dispatch(*x),
        (),
        lambda _: jnp.where(cond0, n, 0),
    )


@partial(jit, static_argnums=(3,))
def _binomial(key, p, n, shape):
    """
    Jitted binomial sampler: flatten the inputs, draw one variate per element
    with `_binomial_dispatch`, then restore the requested shape.

    :param key: PRNG key, split into one sub-key per output element.
    :param p: success probabilities.
    :param n: numbers of trials.
    :param shape: static output shape; when empty, the broadcast shape of
        ``p`` and ``n`` is used.
    :return: samples reshaped to ``shape``.
    """
    if not shape:
        shape = lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))

    # Flatten so that the per-element sampler can be mapped over axis 0.
    flat_p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    flat_n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    keys = random.split(key, jnp.size(flat_p))

    if jax.default_backend() == "cpu":
        # sequential map on CPU
        draws = lax.map(lambda args: _binomial_dispatch(*args), (keys, flat_p, flat_n))
    else:
        draws = vmap(_binomial_dispatch)(keys, flat_p, flat_n)
    return jnp.reshape(draws, shape)


def binomial(key, p, n=1, shape=()):
    """
    Draw binomial samples; forwards all arguments to the jitted `_binomial`.

    :param key: PRNG key; `_binomial` splits it into one key per output element.
    :param p: success probability, broadcast against ``n``.
    :param n: number of trials (defaults to 1).
    :param shape: output shape; when empty, the broadcast shape of ``p`` and
        ``n`` is used.
    :return: samples with the resolved shape.
    """
    return _binomial(key, p, n, shape)


@partial(jit, static_argnums=(2,))
def _categorical(key, p, shape):
    """
    Jitted categorical sampler based on comparing one uniform draw per sample
    with the normalized cumulative sums of ``p`` (see
    https://stackoverflow.com/a/34190035).

    This is fast when the event dimension (last axis of ``p``) is small and slow
    otherwise.

    :param key: PRNG key.
    :param p: weights along the last axis; they are normalized by their total.
    :param shape: static sample shape; when empty, ``p.shape[:-1]`` is used.
    :return: integer category indices.
    """
    batch_shape = shape if shape else p.shape[:-1]
    cdf = jnp.cumsum(p, axis=-1)
    # Divide by the total to absorb numerical drift; the last entry (== 1) is
    # dropped because it can never be below the draw.
    thresholds = cdf[..., :-1] / cdf[..., -1:]
    draws = random.uniform(key, shape=batch_shape + (1,))
    # FIXME: replace this linear count by the binary search suggested in the
    # reference above. A while_loop + vmap over a reshaped 2D array would be enough.
    return jnp.sum(thresholds < draws, axis=-1)


def categorical(key, p, shape=()):
    """
    Draw category indices; forwards all arguments to the jitted `_categorical`.

    :param key: PRNG key.
    :param p: weights along the last axis (normalized inside `_categorical`).
    :param shape: sample shape; when empty, ``p.shape[:-1]`` is used.
    :return: integer category indices.
    """
    return _categorical(key, p, shape)


def _scatter_add_one(operand, indices, updates):
    """
    Scatter-add ``updates`` into ``operand`` along axis 0.

    The dimension numbers declare no update window, so every update is a
    single value added at the axis-0 position given by the matching index row.

    :param operand: array to accumulate into.
    :param indices: integer indices with a trailing axis of length one.
    :param updates: values to add, one per index row.
    :return: the result of ``lax.scatter_add``.
    """
    dimension_numbers = lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter_add(operand, indices, updates, dimension_numbers)


@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
    """
    Jitted multinomial sampler: draw ``n_max`` categorical indices per batch
    element and count them per category.

    :param key: PRNG key.
    :param p: category weights along the last axis.
    :param n: total count(s); broadcast against ``p.shape[:-1]``.
    :param n_max: static upper bound on ``n``; the number of categorical draws.
    :param shape: static sample shape; when empty, ``p.shape[:-1]`` is used.
    :return: per-category counts with shape ``shape + p.shape[-1:]``.
    """
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        # nothing to draw: all counts are zero
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # mask[j] is 1 for the first n draws of each batch element, else 0
        mask = promote_shapes(
            jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,)
        )[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws are turned into index 0 below, so category 0 is
        # over-counted by (n_max - n); `excess` removes that surplus.
        # NOTE(review): jnp.zeros uses the default float dtype here, so `excess`
        # (and the returned counts) are presumably promoted to float in this
        # branch - confirm this is intended.
        excess = jnp.concatenate(
            [
                jnp.expand_dims(n_max - n, -1),
                jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,)),
            ],
            -1,
        )
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    # one scatter-add per flattened batch element: add 1 at every drawn index
    samples_2D = vmap(_scatter_add_one, (0, 0, 0))(
        jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype),
    )
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess


def multinomial(key, p, n, shape=(), total_count_max=None):
    """
    Draw multinomial counts, resolving the static draw budget for `_multinomial`.

    :param key: PRNG key.
    :param p: category weights along the last axis.
    :param n: total count(s).
    :param shape: sample shape.
    :param total_count_max: static upper bound on ``n``. When omitted it is
        computed as the maximum of the concrete ``n``.
    :raises ValueError: if ``total_count_max`` is omitted and ``n`` is a tracer,
        whose maximum cannot be read.
    :return: the result of `_multinomial`.
    """
    if total_count_max is not None:
        return _multinomial(key, p, n, total_count_max, shape)

    if isinstance(n, jax.core.Tracer):
        raise ValueError("Please specify total_count_max in Multinomial distribution.")

    # Pull the value to the host so that the bound is a plain Python int.
    largest_count = int(np.max(jax.device_get(n)))
    return _multinomial(key, p, n, largest_count, shape)


def cholesky_of_inverse(matrix):
    """
    Return the lower Cholesky factor of the inverse of ``matrix``.

    Only a triangular matrix is inverted, which is more numerically stable than
    inverting ``matrix`` itself. Refer to:
    https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril

    :param matrix: positive-definite matrix (batch dimensions allowed).
    :return: lower-triangular ``L`` with ``L @ L.T`` equal to the inverse of
        ``matrix``.
    """
    # Cholesky of the matrix with both trailing axes reversed, reversed back,
    # then transposed.
    flipped_chol = jnp.linalg.cholesky(matrix[..., ::-1, ::-1])
    unflipped = flipped_chol[..., ::-1, ::-1]
    tril_inv = jnp.swapaxes(unflipped, -2, -1)

    eye = jnp.broadcast_to(jnp.identity(matrix.shape[-1]), tril_inv.shape)
    return solve_triangular(tril_inv, eye, lower=True)


# TODO: move upstream to jax.nn
def binary_cross_entropy_with_logits(x, y):
    """
    Compute ``-y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))`` from logits.

    Uses the formulation ``max(x, 0) - x * y + log(1 + exp(-|x|))``; see
    https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    :param x: logits.
    :param y: targets.
    :return: element-wise cross entropy.
    """
    positive_part = jnp.clip(x, 0)
    softplus_tail = jnp.log1p(jnp.exp(-jnp.abs(x)))
    return positive_part + softplus_tail - x * y


def _reshape(x, shape):
    """
    Reshape ``x`` with NumPy when it is a Python number or NumPy value, and with
    ``jax.numpy`` otherwise.

    :param x: value to reshape.
    :param shape: target shape.
    :return: the reshaped value.
    """
    host_types = (int, float, np.ndarray, np.generic)
    backend = np if isinstance(x, host_types) else jnp
    return backend.reshape(x, shape)


def promote_shapes(*args, shape=(), return_shapes: bool = False):
    """
    Promote the shapes of arrays so they have the same number of dimensions and are
    broadcastable.

    :param *args: Arrays to promote.
    :param shape: Leading shape to add to arrays.
    :param return_shapes: Return the shapes of arrays instead of the shape-promoted
        arrays.
    """
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return [jnp.shape(arg) for arg in args] if return_shapes else args
    else:
        shapes = [jnp.shape(arg) for arg in args]
        num_dims = len(lax.broadcast_shapes(shape, *shapes))
        shapes = [(1,) * (num_dims - len(shape)) + shape for shape in shapes]
        if return_shapes:
            return shapes
        return [
            _reshape(arg, shape) if jnp.ndim(arg) < len(shape) else arg
            for arg, shape in zip(args, shapes)
        ]


def sum_rightmost(x, dim):
    """
    Sum out the ``dim`` rightmost dimensions of ``x``.

    :param x: input array.
    :param dim: number of trailing dimensions to collapse.
    :return: array keeping the leading ``ndim(x) - dim`` dimensions.
    """
    kept = jnp.ndim(x) - dim
    # The extra trailing axis guarantees there is always something to fold,
    # even when dim == 0.
    folded_shape = jnp.shape(x)[:kept] + (-1,)
    folded = jnp.reshape(jnp.expand_dims(x, -1), folded_shape)
    return jnp.sum(folded, axis=-1)


def matrix_to_tril_vec(x, diagonal=0):
    """
    Gather the lower-triangular entries of the trailing square matrix of ``x``
    into a vector.

    :param x: array whose last two axes form the matrix.
    :param diagonal: diagonal offset passed to ``jnp.tril_indices``.
    :return: array with the selected entries along the last axis.
    """
    rows, cols = jnp.tril_indices(x.shape[-1], diagonal)
    return x[..., rows, cols]


def vec_to_tril_matrix(t, diagonal=0):
    """
    Scatter the last axis of ``t`` into the lower triangle of a square matrix.

    :param t: array whose last axis holds the lower-triangular entries.
    :param diagonal: diagonal offset of the triangle; must be ``<= 0``.
    :return: array of shape ``t.shape[:-1] + (n, n)`` with the same dtype as
        ``t`` and zeros outside the selected triangle.
    """
    # NB: the following formula only works for diagonal <= 0
    n = round((math.sqrt(1 + 8 * t.shape[-1]) - 1) / 2) - diagonal
    n2 = n * n
    # flat positions (in an n * n row-major layout) of the triangle entries
    idx = jnp.reshape(jnp.arange(n2), (n, n))[jnp.tril_indices(n, diagonal)]
    batch_ndim = t.ndim - 1
    x = lax.scatter_add(
        # `lax.scatter_add` requires operand and updates to share a dtype, so
        # the zero canvas follows `t` instead of the default float dtype.
        jnp.zeros(t.shape[:-1] + (n2,), dtype=jnp.result_type(t)),
        jnp.expand_dims(idx, axis=-1),
        t,
        lax.ScatterDimensionNumbers(
            update_window_dims=tuple(range(batch_ndim)),
            inserted_window_dims=(batch_ndim,),
            scatter_dims_to_operand_dims=(batch_ndim,),
        ),
    )
    return jnp.reshape(x, x.shape[:-1] + (n, n))


def cholesky_update(L, x, coef=1):
    """
    Finds cholesky of L @ L.T + coef * x @ x.T.

    **References:**

        1. A more efficient rank-one covariance matrix update for evolution strategies,
           Oswin Krause and Christian Igel

    :param L: lower Cholesky factor; leading dimensions are broadcast against
        those of ``x``.
    :param x: vector(s) of the rank-one term.
    :param coef: scalar weight of the rank-one term.
    :return: the updated factor with the broadcast batch shape.
    """
    batch_shape = lax.broadcast_shapes(L.shape[:-2], x.shape[:-1])
    L = jnp.broadcast_to(L, batch_shape + L.shape[-2:])
    x = jnp.broadcast_to(x, batch_shape + x.shape[-1:])
    diag = jnp.diagonal(L, axis1=-2, axis2=-1)
    # convert to unit diagonal triangular matrix: L @ D @ L.T
    L = L / diag[..., None, :]
    D = jnp.square(diag)

    def scan_fn(carry, val):
        # carry: running scalar `b` and working vector `w`;
        # val: column index j, its diagonal entry Dj and column L_j.
        b, w = carry
        j, Dj, L_j = val
        wj = w[..., j]
        gamma = b * Dj + coef * jnp.square(wj)
        Dj_new = gamma / b
        b = gamma / Dj_new

        # update vectors w and L_j
        w = w - wj[..., None] * L_j
        L_j = L_j + (coef * wj / gamma)[..., None] * w
        return (b, w), (Dj_new, L_j)

    D, L = jnp.moveaxis(D, -1, 0), jnp.moveaxis(L, -1, 0)  # move scan dim to front
    _, (D, L) = lax.scan(
        scan_fn, (jnp.ones(batch_shape), x), (jnp.arange(D.shape[0]), D, L)
    )
    D, L = jnp.moveaxis(D, 0, -1), jnp.moveaxis(L, 0, -1)  # move scan dim back
    # fold the updated diagonal back into the columns
    return L * jnp.sqrt(D)[..., None, :]


def signed_stick_breaking_tril(t):
    """
    Map a vector with entries in ``(-1, 1)`` to a lower-triangular matrix via
    signed stick-breaking.

    :param t: values for the strictly-lower triangle; they are clipped to the
        open interval ``(-1, 1)`` first.
    :return: lower-triangular matrix built from ``t``.
    """
    # keep t strictly inside (-1, 1)
    eps = jnp.finfo(t.dtype).eps
    clipped = jnp.clip(t, -1 + eps, 1 - eps)
    # strictly-lower triangular matrix; the unit diagonal is added at the end
    r = vec_to_tril_matrix(clipped, diagonal=-1)

    # Stick-breaking acts on the squared values z = r**2. Computing
    # s = z * z_cumprod explicitly is skipped thanks to the identity
    #     y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
    remaining_sqrt = jnp.cumprod(jnp.sqrt(1 - r**2), axis=-1)

    # Shift the cumulative product one step to the right along the last axis,
    # filling the first column with ones.
    pad_width = [(0, 0)] * (r.ndim - 1) + [(1, 0)]
    shifted = jnp.pad(
        remaining_sqrt[..., :-1], pad_width, mode="constant", constant_values=1.0
    )
    return add_diag(r, 1) * shifted


def logmatmulexp(x, y):
    """
    Numerically stable version of ``log(exp(x) @ exp(y))``.

    The row-wise maxima of ``x`` and column-wise maxima of ``y`` are subtracted
    before exponentiating and added back afterwards; gradients are not
    propagated through these shifts.

    :param x: left operand in log space.
    :param y: right operand in log space.
    :return: the matrix product in log space.
    """
    row_max = lax.stop_gradient(jnp.amax(x, -1, keepdims=True))
    col_max = lax.stop_gradient(jnp.amax(y, -2, keepdims=True))
    shifted_product = jnp.matmul(jnp.exp(x - row_max), jnp.exp(y - col_max))
    return jnp.log(shifted_product) + row_max + col_max


@jax.custom_jvp
def log1mexp(x: ArrayLike) -> ArrayLike:
    """
    Numerically stable calculation of the quantity
    :math:`\\log(1 - \\exp(x))`, following the algorithm
    of `Mächler 2012`_.

    .. _Mächler 2012: https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    Returns ``-jnp.inf`` when ``x == 0`` and ``jnp.nan``
    when ``x > 0``.

    :param x: A number or array of numbers.
    :return: The value of :math:`\\log(1 - \\exp(x))`.
    """
    # Two algebraically equal forms; each is the accurate one on its side of
    # the cutoff.
    return jnp.where(
        x > -0.6931472,  # approx -log(2)
        jnp.log(-jnp.expm1(x)),
        jnp.log1p(-jnp.exp(x)),
    )


# Custom jvp for log1mexp to handle the gradient when x is near 0.
#
# Inspired by the approach taken here for the function log1mexp(-x):
# https://github.com/google-research/google-research/blob/14e984cdb8630a7e3d210dff8760fc06d490fc4b/diffusion_distillation/diffusion_distillation/utils.py#L364-L370
# That code is (c) 2024 The Google Research Authors and licensed under
# an Apache 2.0 License.
log1mexp.defjvps(lambda t, ans, x: -t / jnp.expm1(-x))
def logdiffexp(a: ArrayLike, b: ArrayLike) -> ArrayLike:
    """
    Numerically stable calculation of the quantity
    :math:`\\log(\\exp(a) - \\exp(b))`, provided
    :math:`+\\infty > a \\ge b`, following the algorithm
    of `Mächler 2012`_.

    .. _Mächler 2012: https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    Returns ``-jnp.inf`` when ``a == b``, including when
    ``a == b == -jnp.inf``, since this corresponds to ``jnp.log(0)``.
    Returns ``jnp.nan`` when ``a < b`` or ``a == jnp.inf``.

    :param a: A number or array of numbers.
    :param b: A number or array of numbers.
    :return: The value of :math:`\\log(\\exp(a) - \\exp(b))`.
    """
    return jnp.where(
        (a < jnp.inf) & (a > b),
        # regular case: log(exp(a) - exp(b)) = a + log(1 - exp(b - a))
        a + log1mexp(b - a),
        # a == b (finite or -inf) is log(0); everything else is undefined
        jnp.where((a < jnp.inf) & (a == b), -jnp.inf, jnp.nan),
    )
def clamp_probs(probs: ArrayLike) -> ArrayLike:
    """
    Clip probabilities into ``[tiny, 1 - eps]`` of their floating-point type.

    :param probs: probabilities to clamp.
    :return: the clipped values.
    """
    finfo = jnp.finfo(jnp.result_type(probs, float))
    return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)


def betainc(a, b, x):
    """
    Regularized incomplete beta function, using TensorFlow Probability when it
    can be imported and ``jax.scipy.special.betainc`` otherwise.

    All arguments are converted to the default float dtype first.
    """
    try:
        from tensorflow_probability.substrates.jax.math import betainc as betainc_fn
    except ImportError:
        from jax.scipy.special import betainc as betainc_fn

    dtype = jnp.result_type(float)
    return betainc_fn(
        jnp.array(a, dtype=dtype),
        jnp.array(b, dtype=dtype),
        jnp.array(x, dtype=dtype),
    )


def betaincinv(a, b, y):
    """
    Inverse of the regularized incomplete beta function (requires TensorFlow
    Probability).

    :raises ImportError: if TensorFlow Probability cannot be imported.
    """
    try:
        from tensorflow_probability.substrates.jax.math import special as tfp_special
    except ImportError as e:
        raise ImportError(
            "Please install `tensorflow_probability>=0.18` for betaincinv."
        ) from e

    dtype = jnp.result_type(float)
    return tfp_special.betaincinv(
        jnp.array(a, dtype=dtype),
        jnp.array(b, dtype=dtype),
        jnp.array(y, dtype=dtype),
    )


def gammaincinv(a, y):
    """
    Inverse of the regularized lower incomplete gamma function (requires
    TensorFlow Probability).

    :raises ImportError: if TensorFlow Probability cannot be imported.
    """
    # Only the import is guarded, mirroring `betaincinv`.
    try:
        from tensorflow_probability.substrates.jax import math as tfp_math
    except ImportError as e:
        raise ImportError(
            "Please install `tensorflow_probability>=0.18` for gammaincinv."
        ) from e

    return tfp_math.igammainv(jnp.array(a), jnp.array(y))


def is_identically_zero(x):
    """
    Check if argument is exactly the number zero. True for the number zero;
    false for other numbers; false for ndarrays.
    """
    if isinstance(x, (int, float)):
        return x == 0
    else:
        return False


def is_identically_one(x):
    """
    Check if argument is exactly the number one. True for the number one;
    false for other numbers; false for ndarrays.
    """
    if isinstance(x, (int, float)):
        return x == 1
    else:
        return False


def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64):
    """Compute centered von Mises samples using rejection sampling from [1] with wrapped Cauchy proposal.

    *** References ***
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
        Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    :param key: random number generator key
    :param concentration: concentration of distribution
    :param shape: shape of samples
    :param dtype: float precision used to choose the correct s cutoff
    :return: centered samples from von Mises
    """
    shape = shape or jnp.shape(concentration)
    dtype = jnp.result_type(dtype)
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)
    return _von_mises_centered(key, concentration, shape, dtype)


@partial(jit, static_argnums=(2, 3))
def _von_mises_centered(key, concentration, shape, dtype):
    """
    Jitted rejection loop behind `von_mises_centered`.

    :param key: PRNG key, split on every iteration.
    :param concentration: concentration already cast and broadcast to ``shape``.
    :param shape: static sample shape.
    :param dtype: static dtype used to look up the cutoff.
    :return: samples; the loop runs at most 100 iterations.
    """
    # Cutoff from TensorFlow probability
    # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
    s_cutoff_map = {
        jnp.dtype(jnp.float16): 1.8e-1,
        jnp.dtype(jnp.float32): 2e-2,
        jnp.dtype(jnp.float64): 1.2e-4,
    }
    s_cutoff = s_cutoff_map.get(dtype)

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho**2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    # below the cutoff the approximation replaces the exact expression
    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """check if all are done or reached max number of iterations"""
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = random.split(key, 3)

        u = random.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # Update where not done

        y = concentration * (s - w)
        v = random.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)

        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)

        return i + 1, key, accept | done, u, w

    init_done = jnp.zeros(shape, dtype=bool)
    # The loop carry must keep the dtype produced by `body_fn`, which follows
    # `concentration`; the default float dtype would differ for e.g. float16.
    init_u = jnp.zeros(shape, dtype=concentration.dtype)
    init_w = jnp.zeros(shape, dtype=concentration.dtype)

    _, _, done, u, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(u) * jnp.arccos(w)


def scale_and_mask(x, scale=None, mask=None):
    """
    Scale and mask a tensor, broadcasting and avoiding unnecessary ops.
    """
    if is_identically_zero(x):
        return x
    if not (scale is None or is_identically_one(scale)):
        x = x * scale
    if mask is None:
        return x
    else:
        return jnp.where(mask, x, 0.0)


# TODO: use funsor implementation
def periodic_repeat(x, size, dim):
    """
    Repeat a ``period``-sized array up to given ``size``.

    :param x: array whose length along ``dim`` is the period.
    :param size: target length along ``dim`` (non-negative int).
    :param dim: axis to extend.
    :return: array with length ``size`` along ``dim``.
    """
    assert isinstance(size, int) and size >= 0
    assert isinstance(dim, int)
    if dim >= 0:
        # normalize to a negative axis so the trailing slices can be counted
        dim -= jnp.ndim(x)

    period = jnp.shape(x)[dim]
    repeats = (size + period - 1) // period
    # NOTE(review): jnp.repeat repeats each element consecutively (a, a, b, b)
    # rather than tiling the whole period (a, b, a, b) - confirm this is the
    # intended notion of "periodic".
    result = jnp.repeat(x, repeats, axis=dim)
    result = result[(Ellipsis, slice(None, size)) + (slice(None),) * (-1 - dim)]
    return result


def safe_normalize(x, *, p=2):
    """
    Safely project a vector onto the sphere wrt the ``p``-norm. This avoids the
    singularity at zero by mapping zero to the uniform unit vector proportional
    to ``[1, 1, ..., 1]``.

    :param numpy.ndarray x: A vector
    :param float p: The norm exponent, defaults to 2 i.e. the Euclidean norm.
    :returns: A normalized version ``x / ||x||_p``.
    :rtype: numpy.ndarray
    """
    assert isinstance(p, (float, int))
    # p == 0 passes this check, but the ``-1 / p`` below then divides by zero
    assert p >= 0
    norm = jnp.linalg.norm(x, p, axis=-1, keepdims=True)
    x = x / jnp.clip(norm, jnp.finfo(x.dtype).tiny)
    # Avoid the singularity.
    mask = jnp.all(x == 0, axis=-1, keepdims=True)
    x = jnp.where(mask, x.shape[-1] ** (-1 / p), x)
    return x


def is_prng_key(key):
    """
    Deprecated: use ``numpyro.util.is_prng_key``.

    :return: True for a scalar typed PRNG key or a ``uint32`` array of shape
        ``(2,)``; False otherwise, including for objects without ``dtype`` or
        ``shape``.
    """
    # stacklevel=2 attributes the warning to the caller of this function
    warnings.warn(
        "Please use numpyro.util.is_prng_key.", DeprecationWarning, stacklevel=2
    )
    try:
        if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
            return key.shape == ()
        return key.shape == (2,) and key.dtype == np.uint32
    except AttributeError:
        return False


def assert_one_of(**kwargs):
    """
    Assert that exactly one of the keyword arguments is not None.

    :raises ValueError: if zero or more than one argument is not None.
    """
    specified = [key for key, value in kwargs.items() if value is not None]
    if len(specified) != 1:
        raise ValueError(
            f"Exactly one of {list(kwargs)} must be specified; got {specified}."
        )


def multidigamma(a: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray:
    """
    Derivative of the log of multivariate gamma.
    """
    return digamma(a[..., None] - 0.5 * jnp.arange(d)).sum(axis=-1)


def tri_logabsdet(a: jnp.ndarray) -> jnp.ndarray:
    """
    Evaluate the `logabsdet` of a triangular positive-definite matrix.
    """
    return jnp.log(jnp.diagonal(a, axis1=-1, axis2=-2)).sum(axis=-1)


# This is sourced from: torch.distributions.util.py
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
class lazy_property(object):
    r"""
    Used as a decorator for lazy loading of class attributes. This uses a
    non-data descriptor that calls the wrapped method to compute the property on
    first call; thereafter replacing the wrapped method into an instance
    attribute.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        update_wrapper(self, wrapped)

    # This is to prevent warnings from sphinx
    def __call__(self, *args, **kwargs):
        return self.wrapped(*args, **kwargs)

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        # cache on the instance only when the value is not a JAX tracer
        if not_jax_tracer(value):
            setattr(instance, self.wrapped.__name__, value)
        return value


def validate_sample(log_prob_fn):
    """
    Decorate a ``log_prob`` method so that, when ``self._validate_args`` is
    set, entries whose value fails ``self._validate_sample`` get ``-inf``.

    The value is taken from the ``value`` keyword if present, otherwise from
    the first positional argument.
    """

    @ft.wraps(log_prob_fn)
    def wrapper(self, *args, **kwargs):
        log_prob = log_prob_fn(self, *args, **kwargs)
        if self._validate_args:
            value = kwargs["value"] if "value" in kwargs else args[0]
            mask = self._validate_sample(value)
            log_prob = jnp.where(mask, log_prob, -jnp.inf)
        return log_prob

    return wrapper


def add_diag(matrix: jnp.ndarray, diag: jnp.ndarray) -> jnp.ndarray:
    """
    Add `diag` to the trailing diagonal of `matrix`.
    """
    idx = jnp.arange(matrix.shape[-1])
    return matrix.at[..., idx, idx].add(diag)