# Source code for numpyro.distributions.directional

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

from jax import lax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import erf, erfc, i0e, i1e

from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
    is_prng_key,
    promote_shapes,
    safe_normalize,
    validate_sample,
    von_mises_centered,
)


class VonMises(Distribution):
    """Von Mises distribution over angles in ``[-pi, pi]``.

    ``loc`` is the centre of the distribution and ``concentration`` (> 0)
    scales the ``cos(value - loc)`` term of the log density, so larger
    values put more mass near ``loc``.
    """

    arg_constraints = {
        "loc": constraints.real,
        "concentration": constraints.positive,
    }
    reparametrized_params = ["loc"]
    support = constraints.interval(-math.pi, math.pi)

    def __init__(self, loc, concentration, validate_args=None):
        """Build a von Mises distribution for sampling directions.

        :param loc: centre of the distribution
        :param concentration: concentration of the distribution
        """
        # Batch shape is the broadcast of the two (un-promoted) argument shapes.
        batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(loc))
        self.loc, self.concentration = promote_shapes(loc, concentration)
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw angles from the von Mises distribution.

        :param key: random number generator key
        :param sample_shape: leading shape of the returned samples
        :return: samples wrapped into ``[-pi, pi)``
        """
        assert is_prng_key(key)
        centered = von_mises_centered(
            key, self.concentration, sample_shape + self.shape()
        )
        # Shift VM(0, concentration) draws to VM(loc, concentration) ...
        shifted = centered + self.loc
        # ... and wrap the result back into [-pi, pi).
        return (shifted + jnp.pi) % (2.0 * jnp.pi) - jnp.pi

    @validate_sample
    def log_prob(self, value):
        # i0e(k) == exp(-k) * I0(k); the extra exp(-k) it introduces is
        # compensated by the "- 1" in the cosine term below, which keeps the
        # normaliser finite for large concentrations.
        log_normalizer = jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration))
        angle = (value - self.loc) % (2 * jnp.pi)
        return -log_normalizer + self.concentration * (jnp.cos(angle) - 1)

    @property
    def mean(self):
        """Circular mean: ``loc`` wrapped into ``[-pi, pi)``, broadcast to the batch shape."""
        wrapped = (self.loc + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
        return jnp.broadcast_to(wrapped, self.batch_shape)

    @property
    def variance(self):
        """Circular variance ``1 - I1(concentration) / I0(concentration)``."""
        # The exponential scalings of i1e and i0e cancel in the ratio.
        bessel_ratio = i1e(self.concentration) / i0e(self.concentration)
        return jnp.broadcast_to(1.0 - bessel_ratio, self.batch_shape)
class ProjectedNormal(Distribution):
    """
    Projected isotropic normal distribution of arbitrary dimension.

    A sample is obtained by drawing ``x ~ Normal(concentration, I)`` and
    normalising ``x`` to unit length. This distribution over directional data
    is qualitatively similar to the von Mises and von Mises-Fisher
    distributions, but permits tractable variational inference via
    reparametrized gradients.

    To use this distribution with autoguides and HMC, use ``handlers.reparam``
    with a :class:`~numpyro.infer.reparam.ProjectedNormalReparam`
    reparametrizer in the model, e.g.::

        @handlers.reparam(config={"direction": ProjectedNormalReparam()})
        def model():
            direction = numpyro.sample("direction",
                                       ProjectedNormal(zeros(3)))
            ...

    .. note:: This implements :meth:`log_prob` only for dimensions {2,3}.

    [1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017)
        "The General Projected Normal Distribution of Arbitrary Dimension:
        Modeling and Bayesian Inference"
        https://projecteuclid.org/euclid.ba/1453211962
    """

    arg_constraints = {"concentration": constraints.real_vector}
    reparametrized_params = ["concentration"]
    support = constraints.sphere

    def __init__(self, concentration, *, validate_args=None):
        assert jnp.ndim(concentration) >= 1
        self.concentration = concentration
        # The last axis is the ambient dimension (event); the rest is batch.
        full_shape = concentration.shape
        super().__init__(
            full_shape[:-1], full_shape[-1:], validate_args=validate_args
        )

    @property
    def mean(self):
        """
        Note this is the mean in the sense of a centroid in the submanifold
        that minimizes expected squared geodesic distance.
        """
        return safe_normalize(self.concentration)

    @property
    def mode(self):
        """Unit vector pointing along ``concentration``."""
        return safe_normalize(self.concentration)

    def sample(self, key, sample_shape=()):
        """Normalise standard-normal noise shifted by ``concentration``."""
        noise_shape = sample_shape + self.batch_shape + self.event_shape
        noise = random.normal(key, shape=noise_shape)
        return safe_normalize(self.concentration + noise)

    def log_prob(self, value):
        """Closed-form log density; only available for dimensions 2 and 3.

        :raises ValueError: if validation is on and ``value`` has the wrong
            event shape.
        :raises NotImplementedError: for any other dimension.
        """
        if self._validate_args:
            event_shape = value.shape[-1:]
            if event_shape != self.event_shape:
                raise ValueError(
                    f"Expected event shape {self.event_shape}, "
                    f"but got {event_shape}"
                )
            self._validate_sample(value)
        dim = int(self.concentration.shape[-1])
        closed_forms = {
            2: _projected_normal_log_prob_2,
            3: _projected_normal_log_prob_3,
        }
        if dim not in closed_forms:
            raise NotImplementedError(
                f"ProjectedNormal.log_prob() is not implemented for dim = {dim}. "
                "Consider using handlers.reparam with ProjectedNormalReparam."
            )
        return closed_forms[dim](self.concentration, value)

    @staticmethod
    def infer_shapes(concentration):
        """Split a ``concentration`` shape tuple into ``(batch_shape, event_shape)``."""
        return concentration[:-1], concentration[-1:]
def _projected_normal_log_prob_2(concentration, value): def _dot(x, y): return (x[..., None, :] @ y[..., None])[..., 0, 0] # We integrate along a ray, factorizing the integrand as a product of: # a truncated normal distribution over coordinate t parallel to the ray, and # a univariate normal distribution over coordinate r perpendicular to the ray. t = _dot(concentration, value) t2 = t * t r2 = _dot(concentration, concentration) - t2 perp_part = (-0.5) * r2 - 0.5 * math.log(2 * math.pi) # This is the log of a definite integral, computed by mathematica: # Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}] # = (t + Sqrt[2/Pi]/E^(t^2/2) + t Erf[t/Sqrt[2]])/2 para_part = jnp.log( (jnp.exp((-0.5) * t2) * ((2 / math.pi) ** 0.5) + t * (1 + erf(t * 0.5 ** 0.5))) / 2 ) return para_part + perp_part def _projected_normal_log_prob_3(concentration, value): def _dot(x, y): return (x[..., None, :] @ y[..., None])[..., 0, 0] # We integrate along a ray, factorizing the integrand as a product of: # a truncated normal distribution over coordinate t parallel to the ray, and # a bivariate normal distribution over coordinate r perpendicular to the ray. t = _dot(concentration, value) t2 = t * t r2 = _dot(concentration, concentration) - t2 perp_part = (-0.5) * r2 - math.log(2 * math.pi) # This is the log of a definite integral, computed by mathematica: # Integrate[x^2/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}] # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2 para_part = jnp.log( t * jnp.exp((-0.5) * t2) / (2 * math.pi) ** 0.5 + (1 + t2) * (1 + erf(t * 0.5 ** 0.5)) / 2 ) return para_part + perp_part