Source code for numpyro.distributions.directional

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
import functools
import math
from math import pi
import operator

from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy import special
from jax.scipy.special import erf, i0e, i1e, logsumexp

from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
    lazy_property,
    promote_shapes,
    safe_normalize,
    validate_sample,
    von_mises_centered,
)
from numpyro.util import is_prng_key, while_loop


def _numel(shape):
    return functools.reduce(operator.mul, shape, 1)


def log_I1(orders: int, value, terms=250):
    r"""Compute the logarithm of the first ``orders`` modified Bessel functions
    of the first kind.

    .. math::

        \log I_v(z) = v\log(z/2) + \log\left(\sum_{k=0}^{\infty}
        \exp\left[2k\log(z/2) - \log\Gamma(k+1) - \log\Gamma(v+k+1)\right]\right)

    :param orders: highest order of the log modified Bessel functions to compute.
    :param value: values to compute the modified Bessel functions for.
    :param terms: truncation point of the summation.
    :return: log modified Bessel functions of the first kind for orders
        0 through ``orders``.
    """
    orders = orders + 1
    if value.ndim == 0:
        vshape = jnp.shape([1])
    else:
        vshape = value.shape
    value = value.reshape(-1, 1)
    flat_vshape = _numel(vshape)

    k = jnp.arange(terms)
    lgammas_all = special.gammaln(jnp.arange(1.0, terms + orders + 1))
    assert lgammas_all.shape == (orders + terms,)  # lgamma(0) = inf => start from 1

    lvalues = jnp.log(value / 2) * k.reshape(1, -1)
    assert lvalues.shape == (flat_vshape, terms)

    lfactorials = lgammas_all[:terms]
    assert lfactorials.shape == (terms,)

    lgammas = jnp.tile(lgammas_all, orders).reshape((orders, -1))
    assert lgammas.shape == (orders, terms + orders)  # lgamma(0) = inf => start from 1

    indices = k[:orders].reshape(-1, 1) + k.reshape(1, -1)
    assert indices.shape == (orders, terms)

    seqs = logsumexp(
        2 * lvalues[None, :, :]
        - lfactorials[None, None, :]
        - jnp.take_along_axis(lgammas, indices, axis=1)[:, None, :],
        -1,
    )
    assert seqs.shape == (orders, flat_vshape)

    i1s = lvalues[..., :orders].T + seqs
    assert i1s.shape == (orders, flat_vshape)
    return i1s.reshape(-1, *vshape)
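
# Usage sketch (not part of the library source): log_I1 stacks the log modified
# Bessel values for orders 0..orders along the leading axis, which can be
# checked against scipy.special.iv.
#
#   import numpy as np
#   from scipy import special as sp
#   z = jnp.array([0.5, 1.0, 2.0])
#   out = log_I1(2, z)  # shape (3, 3): orders 0..2 by values
#   expected = np.log(sp.iv(np.arange(3)[:, None], np.asarray(z)))
#   np.testing.assert_allclose(out, expected, rtol=1e-5)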


class VonMises(Distribution):
    """
    The von Mises distribution, also known as the circular normal distribution.

    This distribution is supported by a circular constraint from -pi to +pi. By
    default, the circular support behaves like
    ``constraints.interval(-math.pi, math.pi)``. To avoid issues at the
    boundaries of this interval during sampling, you should reparameterize this
    distribution using ``handlers.reparam`` with a
    :class:`~numpyro.infer.reparam.CircularReparam` reparametrizer in
    the model, e.g.::

        @handlers.reparam(config={"direction": CircularReparam()})
        def model():
            direction = numpyro.sample("direction", VonMises(0.0, 4.0))
            ...
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    reparametrized_params = ["loc"]
    support = constraints.circular

    def __init__(self, loc, concentration, *, validate_args=None):
        """von Mises distribution for sampling directions.

        :param loc: center of distribution
        :param concentration: concentration of distribution
        """
        self.loc, self.concentration = promote_shapes(loc, concentration)

        batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(loc))

        super(VonMises, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )
    def sample(self, key, sample_shape=()):
        """Generate sample from von Mises distribution

        :param key: random number generator key
        :param sample_shape: shape of samples
        :return: samples from von Mises
        """
        assert is_prng_key(key)
        samples = von_mises_centered(
            key, self.concentration, sample_shape + self.shape()
        )
        samples = samples + self.loc  # VM(0, concentration) -> VM(loc, concentration)
        samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi

        return samples
    @validate_sample
    def log_prob(self, value):
        return -(
            jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration))
        ) + self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1)

    @property
    def mean(self):
        """Computes circular mean of distribution. NOTE: same as location when
        mapped to support [-pi, pi]"""
        return jnp.broadcast_to(
            (self.loc + jnp.pi) % (2.0 * jnp.pi) - jnp.pi, self.batch_shape
        )

    @property
    def variance(self):
        """Computes circular variance of distribution"""
        return jnp.broadcast_to(
            1.0 - i1e(self.concentration) / i0e(self.concentration), self.batch_shape
        )
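
# Example (a minimal sketch, not part of the library source): draw samples and
# evaluate densities; samples always land in the circular support [-pi, pi).
#
#   key = random.PRNGKey(0)
#   vm = VonMises(loc=0.0, concentration=4.0)
#   theta = vm.sample(key, (1000,))
#   assert jnp.all((theta >= -jnp.pi) & (theta < jnp.pi))
#   log_p = vm.log_prob(theta)
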
PhiMarginalState = namedtuple("PhiMarginalState", ["i", "done", "phi", "key"])
class SineSkewed(Distribution):
    r"""Sine-skewing [1] is a procedure for producing a distribution that breaks
    pointwise symmetry on a torus distribution. The new distribution is called
    the Sine Skewed X distribution, where X is the name of the (symmetric) base
    distribution.

    Torus distributions are distributions with support on products of circles
    (i.e., :math:`\otimes S^1` where :math:`S^1 = [-\pi, \pi)`). So, a 0-torus
    is a point, the 1-torus is a circle, and the 2-torus is commonly associated
    with the donut shape.

    The sine skewed X distribution is parameterized by a weight parameter for
    each dimension of the event of X. For example with a von Mises distribution
    over a circle (1-torus), the sine skewed von Mises distribution has one
    skew parameter. The skewness parameters can be inferred using
    :class:`~numpyro.infer.HMC` or :class:`~numpyro.infer.NUTS`. For example,
    the following will produce a prior over skewness for the 2-torus::

        @numpyro.handlers.reparam(config={'phi_loc': CircularReparam(),
                                          'psi_loc': CircularReparam()})
        def model(obs):
            # Sine priors
            phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
            psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
            phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
            psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
            corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

            # Skewing prior
            ball_trans = L1BallTransform()
            skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
            skewness = ball_trans(skewness)  # constraint sum |skewness_i| <= 1

            with numpyro.plate('obs_plate'):
                sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
                                             phi_concentration=70 * phi_conc,
                                             psi_concentration=70 * psi_conc,
                                             weighted_correlation=corr_scale)
                return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

    To ensure the skewing does not alter the normalization constant of the
    (sine bivariate von Mises) base distribution, the skewness parameters are
    constrained. The constraint requires the sum of the absolute values of
    skewness to be less than or equal to one. We can use the
    :class:`~numpyro.distributions.transforms.L1BallTransform` to achieve this.

    In the context of :class:`~numpyro.infer.SVI`, this distribution can freely
    be used as a likelihood, but used as a latent variable it will lead to slow
    inference for 2 and higher dimensional tori. This is because the base_dist
    cannot be reparameterized.

    .. note:: An event in the base distribution must be on a d-torus, so the
        event_shape must be `(d,)`.

    .. note:: For the skewness parameter, it must hold that the sum of the
        absolute values of its weights for an event must be less than or equal
        to one. See eq. 2.1 in [1].

    **References:**

    1. Sine-skewed toroidal distributions and their application in protein
       bioinformatics. Ameijeiras-Alonso, J., Ley, C. (2019)

    :param numpyro.distributions.Distribution base_dist: base density on a
        d-dimensional torus. Supported base distributions include: 1D
        :class:`~numpyro.distributions.VonMises`,
        :class:`~numpyro.distributions.SineBivariateVonMises`, 1D
        :class:`~numpyro.distributions.ProjectedNormal`, and
        :class:`~numpyro.distributions.Uniform` (-pi, pi).
    :param jax.numpy.array skewness: skewness of the distribution.
    """

    arg_constraints = {"skewness": constraints.l1_ball}
    pytree_data_fields = ("base_dist", "skewness")
    support = constraints.independent(constraints.circular, 1)

    def __init__(self, base_dist: Distribution, skewness, *, validate_args=None):
        assert (
            base_dist.event_shape == skewness.shape[-1:]
        ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."

        batch_shape = jnp.broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
        event_shape = skewness.shape[-1:]
        self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
        self.base_dist = base_dist.expand(batch_shape)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def __repr__(self):
        args_string = ", ".join(
            [
                "{}: {}".format(
                    p,
                    getattr(self, p)
                    if jnp.size(getattr(self, p)) == 1
                    else getattr(self, p).shape,
                )
                for p in self.arg_constraints.keys()
            ]
        )
        return (
            self.__class__.__name__
            + "("
            + f"base_density: {str(self.base_dist)}, "
            + args_string
            + ")"
        )
    def sample(self, key, sample_shape=()):
        base_key, skew_key = random.split(key)
        bd = self.base_dist
        ys = bd.sample(base_key, sample_shape)
        u = random.uniform(skew_key, sample_shape + self.batch_shape)

        # Section 2.3 step 3 in [1]
        mask = u <= 0.5 + 0.5 * (
            self.skewness * jnp.sin((ys - bd.mean) % (2 * jnp.pi))
        ).sum(-1)
        mask = mask[..., None]
        samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + jnp.pi) % (
            2 * jnp.pi
        ) - jnp.pi
        return samples
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        if self.base_dist._validate_args:
            self.base_dist._validate_sample(value)

        # Eq. 2.1 in [1]
        skew_prob = jnp.log1p(
            (self.skewness * jnp.sin((value - self.base_dist.mean) % (2 * jnp.pi))).sum(
                -1
            )
        )
        return self.base_dist.log_prob(value) + skew_prob
    @property
    def mean(self):
        """Mean of the base distribution"""
        return self.base_dist.mean
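
# Example (a minimal sketch, not part of the library source): skew a von Mises
# base on the 1-torus. The base must have event_shape (d,), so the scalar
# VonMises is lifted to an event dimension with .expand(...).to_event(1).
#
#   key = random.PRNGKey(0)
#   base = VonMises(0.0, 2.0).expand((1,)).to_event(1)  # event_shape == (1,)
#   skewed = SineSkewed(base, skewness=jnp.array([0.3]))
#   ys = skewed.sample(key, (500,))
#   log_p = skewed.log_prob(ys)
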
class SineBivariateVonMises(Distribution):
    r"""Unimodal distribution of two dependent angles on the 2-torus
    (:math:`S^1 \otimes S^1`) given by

    .. math::

        C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2)
        + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))

    and

    .. math::

        C = (2\pi)^2 \sum_{i=0}^{\infty} {2i \choose i}
        \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),

    where :math:`I_i(\cdot)` is the modified Bessel function of the first kind,
    the :math:`\mu`'s are the locations of the distribution, the
    :math:`\kappa`'s are the concentrations, and :math:`\rho` gives the
    correlation between angles :math:`x_1` and :math:`x_2`. This distribution
    is helpful for modeling coupled angles such as torsion angles in peptide
    chains.

    To infer parameters, use :class:`~numpyro.infer.hmc.NUTS` or
    :class:`~numpyro.infer.hmc.HMC` with priors that avoid parameterizations
    where the distribution becomes bimodal; see note below.

    .. note:: Sample efficiency drops as

        .. math::

            \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1

        because the distribution becomes increasingly bimodal. To avoid
        bimodality use the `weighted_correlation` parameter with a skew away
        from one (e.g., Beta(1,3)). The `weighted_correlation` should be in
        [0,1].

    .. note:: The correlation and weighted_correlation params are mutually
        exclusive.

    .. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this
        distribution can be used as a likelihood but not for latent variables.

    **References:**

    1. Probabilistic model for two dependent circular variables.
       Singh, H., Hnizdo, V., and Demchuk, E. (2002)

    :param np.ndarray phi_loc: location of first angle
    :param np.ndarray psi_loc: location of second angle
    :param np.ndarray phi_concentration: concentration of first angle
    :param np.ndarray psi_concentration: concentration of second angle
    :param np.ndarray correlation: correlation between the two angles
    :param np.ndarray weighted_correlation: set correlation to
        weighted_corr * sqrt(phi_conc * psi_conc) to avoid bimodality (see
        note). The `weighted_correlation` should be in [0,1].
    """

    arg_constraints = {
        "phi_loc": constraints.circular,
        "psi_loc": constraints.circular,
        "phi_concentration": constraints.positive,
        "psi_concentration": constraints.positive,
        "correlation": constraints.real,
    }
    support = constraints.independent(constraints.circular, 1)
    max_sample_iter = 1000

    def __init__(
        self,
        phi_loc,
        psi_loc,
        phi_concentration,
        psi_concentration,
        correlation=None,
        weighted_correlation=None,
        validate_args=None,
    ):
        assert (correlation is None) != (weighted_correlation is None)

        if weighted_correlation is not None:
            correlation = weighted_correlation * jnp.sqrt(
                phi_concentration * psi_concentration
            )

        batch_shape = lax.broadcast_shapes(
            jnp.shape(phi_loc),
            jnp.shape(psi_loc),
            jnp.shape(phi_concentration),
            jnp.shape(psi_concentration),
            jnp.shape(correlation),
        )
        (
            self.phi_loc,
            self.psi_loc,
            self.phi_concentration,
            self.psi_concentration,
            self.correlation,
        ) = promote_shapes(
            phi_loc,
            psi_loc,
            phi_concentration,
            psi_concentration,
            correlation,
            shape=batch_shape,
        )
        super().__init__(batch_shape, (2,), validate_args=validate_args)

        self.phi_loc = jnp.broadcast_to(self.phi_loc, batch_shape)
        self.psi_loc = jnp.broadcast_to(self.psi_loc, batch_shape)
        self.phi_concentration = jnp.broadcast_to(self.phi_concentration, batch_shape)
        self.psi_concentration = jnp.broadcast_to(self.psi_concentration, batch_shape)
        self.correlation = jnp.broadcast_to(self.correlation, batch_shape)
    @lazy_property
    def norm_const(self):
        corr = jnp.reshape(self.correlation, (1, -1))
        conc = jnp.stack(
            (self.phi_concentration, self.psi_concentration), axis=-1
        ).reshape(-1, 2)
        m = jnp.arange(50).reshape(-1, 1)
        num = special.gammaln(2 * m + 1.0)
        den = special.gammaln(m + 1.0)
        lbinoms = num - 2 * den

        fs = lbinoms.reshape(-1, 1) + m * (
            jnp.log(jnp.clip(corr**2, jnp.finfo(jnp.result_type(float)).tiny))
            - jnp.log(4 * jnp.prod(conc, axis=-1))
        )
        fs += log_I1(49, conc, terms=51).sum(-1)
        norm_const = 2 * jnp.log(jnp.array(2 * pi)) + logsumexp(fs, 0)
        return norm_const.reshape(jnp.shape(self.phi_loc))
    @validate_sample
    def log_prob(self, value):
        indv = self.phi_concentration * jnp.cos(
            value[..., 0] - self.phi_loc
        ) + self.psi_concentration * jnp.cos(value[..., 1] - self.psi_loc)
        corr = (
            self.correlation
            * jnp.sin(value[..., 0] - self.phi_loc)
            * jnp.sin(value[..., 1] - self.psi_loc)
        )
        return indv + corr - self.norm_const
    def sample(self, key, sample_shape=()):
        """
        **References:**

        1. A New Unified Approach for the Simulation of a Wide Class of
           Directional Distributions.
           John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
        """
        assert is_prng_key(key)
        phi_key, psi_key = random.split(key)

        corr = self.correlation
        conc = jnp.stack((self.phi_concentration, self.psi_concentration))

        eig = 0.5 * (conc[0] - corr**2 / conc[1])
        eig = jnp.stack((jnp.zeros_like(eig), eig))
        eigmin = jnp.where(eig[1] < 0, eig[1], jnp.zeros_like(eig[1], dtype=eig.dtype))
        eig = eig - eigmin
        b0 = self._bfind(eig)

        total = _numel(sample_shape)
        phi_den = log_I1(0, conc[1]).squeeze(0)
        batch_size = _numel(self.batch_shape)
        phi_shape = (total, 2, batch_size)
        phi_state = SineBivariateVonMises._phi_marginal(
            phi_shape,
            phi_key,
            jnp.reshape(conc, (2, batch_size)),
            jnp.reshape(corr, (batch_size,)),
            jnp.reshape(eig, (2, batch_size)),
            jnp.reshape(b0, (batch_size,)),
            jnp.reshape(eigmin, (batch_size,)),
            jnp.reshape(phi_den, (batch_size,)),
        )

        phi = jnp.arctan2(phi_state.phi[:, 1], phi_state.phi[:, 0])

        alpha = jnp.sqrt(
            conc[1].reshape(-1) ** 2 + (corr.reshape(-1) * jnp.sin(phi)) ** 2
        )
        beta = jnp.arctan(corr.reshape(-1) / conc[1].reshape(-1) * jnp.sin(phi))

        psi = VonMises(beta, alpha).sample(psi_key)

        phi_psi = jnp.stack(
            (
                (phi + jnp.reshape(self.phi_loc, -1) + pi) % (2 * pi) - pi,
                (psi + jnp.reshape(self.psi_loc, -1) + pi) % (2 * pi) - pi,
            ),
            axis=-1,
        )
        return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape)
    @staticmethod
    def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
        def update_fn(curr):
            i, done, phi, key = curr
            phi_key, key = random.split(key)
            accept_key, acg_key, phi_key = random.split(phi_key, 3)

            x = lax.rsqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
            x /= jnp.linalg.norm(
                x, axis=1, keepdims=True
            )  # Angular Central Gaussian distribution

            lf = (
                conc[0] * (x[:, 0] - 1)
                + eigmin
                + log_I1(0, jnp.sqrt(conc[1] ** 2 + (corr * x[:, 1]) ** 2)).squeeze(0)
                - phi_den
            )
            lg_inv = 1.0 - b0 / 2 + jnp.log(b0 / 2 + (eig * x**2).sum(1))
            assert lg_inv.shape == lf.shape

            accepted = random.uniform(accept_key, lf.shape) < jnp.exp(lf + lg_inv)
            accepted = accepted[:, None]

            phi = jnp.where(accepted, x, phi)
            return PhiMarginalState(i + 1, done | accepted, phi, key)

        def cond_fn(curr):
            return jnp.bitwise_and(
                curr.i < SineBivariateVonMises.max_sample_iter,
                jnp.logical_not(jnp.all(curr.done)),
            )

        phi_state = while_loop(
            cond_fn,
            update_fn,
            PhiMarginalState(
                i=jnp.array(0),
                done=jnp.zeros(shape, dtype=bool),
                phi=jnp.empty(shape, dtype=float),
                key=rng_key,
            ),
        )
        return PhiMarginalState(
            phi_state.i, phi_state.done, phi_state.phi, phi_state.key
        )

    @property
    def mean(self):
        """Computes circular mean of distribution. Note: same as location when
        mapped to support [-pi, pi]"""
        mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) + jnp.pi) % (
            2.0 * jnp.pi
        ) - jnp.pi
        return jnp.broadcast_to(mean, (*self.batch_shape, 2))

    def _bfind(self, eig):
        b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype)
        g1 = jnp.sum(1 / (b + 2 * eig) ** 2, axis=0)
        g2 = jnp.sum(-2 / (b + 2 * eig) ** 3, axis=0)
        return jnp.where(jnp.linalg.norm(eig, axis=0) != 0, b - g1 / g2, b)
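
# Example (a minimal sketch, not part of the library source): a weakly
# correlated bivariate density over two torsion-like angles. Passing
# weighted_correlation (rather than correlation) keeps the density unimodal.
#
#   key = random.PRNGKey(0)
#   sine = SineBivariateVonMises(
#       phi_loc=0.0, psi_loc=0.0,
#       phi_concentration=5.0, psi_concentration=5.0,
#       weighted_correlation=0.3,
#   )
#   angles = sine.sample(key, (100,))  # shape (100, 2)
#   log_p = sine.log_prob(angles)
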
class ProjectedNormal(Distribution):
    """
    Projected isotropic normal distribution of arbitrary dimension.

    This distribution over directional data is qualitatively similar to the von
    Mises and von Mises-Fisher distributions, but permits tractable variational
    inference via reparametrized gradients.

    To use this distribution with autoguides and HMC, use ``handlers.reparam``
    with a :class:`~numpyro.infer.reparam.ProjectedNormalReparam`
    reparametrizer in the model, e.g.::

        @handlers.reparam(config={"direction": ProjectedNormalReparam()})
        def model():
            direction = numpyro.sample("direction", ProjectedNormal(zeros(3)))
            ...

    .. note:: This implements :meth:`log_prob` only for dimensions {2,3}.

    [1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017)
        "The General Projected Normal Distribution of Arbitrary Dimension:
        Modeling and Bayesian Inference"
        https://projecteuclid.org/euclid.ba/1453211962
    """

    arg_constraints = {"concentration": constraints.real_vector}
    reparametrized_params = ["concentration"]
    support = constraints.sphere

    def __init__(self, concentration, *, validate_args=None):
        assert jnp.ndim(concentration) >= 1
        self.concentration = concentration
        batch_shape = concentration.shape[:-1]
        event_shape = concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @property
    def mean(self):
        """
        Note this is the mean in the sense of a centroid in the submanifold
        that minimizes expected squared geodesic distance.
        """
        return safe_normalize(self.concentration)

    @property
    def mode(self):
        return safe_normalize(self.concentration)
    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape + self.event_shape
        eps = random.normal(key, shape=shape)
        return safe_normalize(self.concentration + eps)
    def log_prob(self, value):
        if self._validate_args:
            event_shape = value.shape[-1:]
            if event_shape != self.event_shape:
                raise ValueError(
                    f"Expected event shape {self.event_shape}, "
                    f"but got {event_shape}"
                )
            self._validate_sample(value)
        dim = int(self.concentration.shape[-1])
        if dim == 2:
            return _projected_normal_log_prob_2(self.concentration, value)
        if dim == 3:
            return _projected_normal_log_prob_3(self.concentration, value)
        raise NotImplementedError(
            f"ProjectedNormal.log_prob() is not implemented for dim = {dim}. "
            "Consider using handlers.reparam with ProjectedNormalReparam."
        )
    @staticmethod
    def infer_shapes(concentration):
        batch_shape = concentration[:-1]
        event_shape = concentration[-1:]
        return batch_shape, event_shape
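
# Example (a minimal sketch, not part of the library source): a 3-dimensional
# direction; samples lie on the unit sphere, and log_prob is available for
# dim in {2, 3}.
#
#   key = random.PRNGKey(0)
#   pn = ProjectedNormal(jnp.array([2.0, 0.0, 0.0]))
#   x = pn.sample(key, (10,))  # unit vectors, shape (10, 3)
#   assert jnp.allclose(jnp.linalg.norm(x, axis=-1), 1.0, atol=1e-6)
#   log_p = pn.log_prob(x)
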
def _projected_normal_log_prob_2(concentration, value):
    def _dot(x, y):
        return (x[..., None, :] @ y[..., None])[..., 0, 0]

    # We integrate along a ray, factorizing the integrand as a product of:
    # a truncated normal distribution over coordinate t parallel to the ray, and
    # a univariate normal distribution over coordinate r perpendicular to the ray.
    t = _dot(concentration, value)
    t2 = t * t
    r2 = _dot(concentration, concentration) - t2
    perp_part = (-0.5) * r2 - 0.5 * math.log(2 * math.pi)

    # This is the log of a definite integral, computed by mathematica:
    # Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
    # = (t + Sqrt[2/Pi]/E^(t^2/2) + t Erf[t/Sqrt[2]])/2
    para_part = jnp.log(
        (jnp.exp((-0.5) * t2) * ((2 / math.pi) ** 0.5) + t * (1 + erf(t * 0.5**0.5)))
        / 2
    )

    return para_part + perp_part


def _projected_normal_log_prob_3(concentration, value):
    def _dot(x, y):
        return (x[..., None, :] @ y[..., None])[..., 0, 0]

    # We integrate along a ray, factorizing the integrand as a product of:
    # a truncated normal distribution over coordinate t parallel to the ray, and
    # a bivariate normal distribution over coordinate r perpendicular to the ray.
    t = _dot(concentration, value)
    t2 = t * t
    r2 = _dot(concentration, concentration) - t2
    perp_part = (-0.5) * r2 - math.log(2 * math.pi)

    # This is the log of a definite integral, computed by mathematica:
    # Integrate[x^2/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
    # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2
    para_part = jnp.log(
        t * jnp.exp((-0.5) * t2) / (2 * math.pi) ** 0.5
        + (1 + t2) * (1 + erf(t * 0.5**0.5)) / 2
    )

    return para_part + perp_part
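
# Numerical sanity check (a sketch, not part of the library source): the 2D
# helper returns a log density with respect to arc length, so it should
# integrate to ~1 around the unit circle.
#
#   conc = jnp.array([1.0, -0.5])
#   theta = jnp.linspace(-jnp.pi, jnp.pi, 10000, endpoint=False)
#   points = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
#   log_p = _projected_normal_log_prob_2(conc, points)
#   print(jnp.exp(log_p).mean() * 2 * jnp.pi)  # ~= 1.0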