Source code for numpyro.distributions.copula

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from typing import Optional

import jax
from jax import Array, lax, numpy as jnp
from jax.typing import ArrayLike

import numpyro.distributions.constraints as constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample
from numpyro.util import is_prng_key


class GaussianCopula(Distribution):
    """
    Joint distribution whose univariate marginals are given by `marginal_dist`
    and whose dependence structure is a multivariate Gaussian copula.

    The last batch axis of `marginal_dist` is the axis that gets coupled. The
    event shape is taken from the coupling multivariate normal, and the batch
    shape is the broadcast of `marginal_dist.batch_shape[:-1]` with the batch
    shape of that multivariate normal.

    :param Distribution marginal_dist: Univariate distribution (empty event
        shape) whose last batch axis is to be coupled.
    :param array_like correlation_matrix: Correlation matrix of the coupling
        multivariate normal distribution.
    :param array_like correlation_cholesky: Correlation Cholesky factor of the
        coupling multivariate normal distribution.
    :param validate_args: Forwarded unchanged to the base ``Distribution``
        constructor.
    :raises ValueError: If `marginal_dist` has a non-empty event shape.
    """

    arg_constraints = {
        "correlation_matrix": constraints.corr_matrix,
        "correlation_cholesky": constraints.corr_cholesky,
    }
    reparametrized_params = [
        "correlation_matrix",
        "correlation_cholesky",
    ]
    pytree_data_fields = ("marginal_dist", "base_dist")

    def __init__(
        self,
        marginal_dist: Distribution,
        correlation_matrix: Optional[Array] = None,
        correlation_cholesky: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ):
        # The trailing *batch* axis of the marginal is what gets coupled, so
        # the marginal itself must not carry any event dimensions.
        if len(marginal_dist.event_shape) != 0:
            raise ValueError("`marginal_dist` needs to be a univariate distribution.")

        self.marginal_dist = marginal_dist
        # Both parametrizations are handed to MultivariateNormal as-is; the
        # handling of `None` is left to that constructor.
        coupling = MultivariateNormal(
            covariance_matrix=correlation_matrix,
            scale_tril=correlation_cholesky,
        )
        self.base_dist = coupling

        # Drop the coupled axis from the marginal's batch shape before
        # broadcasting against the coupling distribution's batch shape.
        leading_batch = marginal_dist.batch_shape[:-1]
        super().__init__(
            batch_shape=lax.broadcast_shapes(leading_batch, coupling.batch_shape),
            event_shape=coupling.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """
        Draw samples by transforming correlated normal draws.

        :param key: PRNG key used for the underlying multivariate normal draw.
        :param sample_shape: Leading shape prepended to the batch shape.
        :return: Values obtained by applying the marginal inverse CDF to the
            standard-normal CDF of the multivariate normal draws.
        """
        assert is_prng_key(key)
        target_batch = sample_shape + self.batch_shape
        gaussian_draws = self.base_dist.expand(target_batch).sample(key)
        # Standard normal CDF maps each coordinate into the unit interval ...
        unit_interval_draws = Normal().cdf(gaussian_draws)
        # ... and the marginal inverse CDF maps it onto the marginal's scale.
        return self.marginal_dist.icdf(unit_interval_draws)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """
        Log density at `value`: Gaussian copula log density evaluated at the
        normal quantiles of the marginal CDF values, plus the sum of the
        marginal log densities over the last axis.

        :param value: Point(s) at which to evaluate the log density.
        :return: Log density with the last axis of `value` reduced.
        """
        # Ref: https://en.wikipedia.org/wiki/Copula_(probability_theory)#Gaussian_copula
        # see also https://github.com/pyro-ppl/numpyro/pull/1506#discussion_r1037525015
        marginal_lps = self.marginal_dist.log_prob(value)
        cdf_values = self.marginal_dist.cdf(value)
        # NOTE(review): clamp_probs presumably keeps the probabilities away
        # from exactly 0 and 1 so the normal quantiles stay finite - confirm.
        quantiles = Normal().icdf(clamp_probs(cdf_values))

        dim = quantiles.shape[-1]
        # The copula density is the multivariate normal density divided by the
        # product of independent standard normal densities; the two correction
        # terms below add back the log of that product.
        joint_normal_lp = self.base_dist.log_prob(quantiles)
        half_sq_norm = 0.5 * (quantiles**2).sum(-1)
        log_norm_const = 0.5 * jnp.log(2 * jnp.pi) * dim
        copula_lp = joint_normal_lp + half_sq_norm + log_norm_const
        return copula_lp + marginal_lps.sum(axis=-1)

    @property
    def mean(self) -> ArrayLike:
        """Mean of the marginal distribution, broadcast to ``self.shape()``."""
        marginal_mean = self.marginal_dist.mean
        return jnp.broadcast_to(marginal_mean, self.shape())

    @property
    def variance(self) -> ArrayLike:
        """Variance of the marginal distribution, broadcast to ``self.shape()``."""
        marginal_variance = self.marginal_dist.variance
        return jnp.broadcast_to(marginal_variance, self.shape())

    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self) -> Constraint:
        """Marginal support, reinterpreted with one event dimension."""
        marginal_support = self.marginal_dist.support
        return constraints.independent(marginal_support, 1)

    @lazy_property
    def correlation_matrix(self) -> Array:
        """Covariance matrix of the coupling multivariate normal."""
        coupling = self.base_dist
        return coupling.covariance_matrix

    @lazy_property
    def correlation_cholesky(self) -> Array:
        """Lower-triangular scale factor of the coupling multivariate normal."""
        coupling = self.base_dist
        return coupling.scale_tril
class GaussianCopulaBeta(GaussianCopula):
    """
    Gaussian copula distribution whose marginals are
    ``Beta(concentration1, concentration0)``.

    :param array_like concentration1: First concentration parameter of the
        Beta marginals (constrained to be positive).
    :param array_like concentration0: Second concentration parameter of the
        Beta marginals (constrained to be positive).
    :param array_like correlation_matrix: Correlation matrix, forwarded to
        :class:`GaussianCopula`.
    :param array_like correlation_cholesky: Correlation Cholesky factor,
        forwarded to :class:`GaussianCopula`.
    :param bool validate_args: Forwarded to :class:`GaussianCopula`.
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "correlation_matrix": constraints.corr_matrix,
        "correlation_cholesky": constraints.corr_cholesky,
    }
    support = constraints.independent(constraints.unit_interval, 1)
    pytree_data_fields = ("concentration1", "concentration0")

    # NOTE(review): `validate_args` defaults to False here while
    # GaussianCopula defaults to None - confirm the difference is intentional.
    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        correlation_matrix: Optional[Array] = None,
        correlation_cholesky: Optional[Array] = None,
        *,
        validate_args: bool = False,
    ):
        # Store the concentrations before delegating to the parent
        # constructor so that they are already set when arguments are
        # validated.
        self.concentration1 = concentration1
        self.concentration0 = concentration0

        beta_marginals = Beta(concentration1, concentration0)
        super().__init__(
            beta_marginals,
            correlation_matrix,
            correlation_cholesky,
            validate_args=validate_args,
        )