Source code for numpyro.distributions.copula

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from jax import lax, numpy as jnp

import numpyro.distributions.constraints as constraints
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
    clamp_probs,
    is_prng_key,
    lazy_property,
    validate_sample,
)


class GaussianCopula(Distribution):
    """
    Joint distribution obtained by coupling a univariate marginal distribution
    `marginal_dist` through a multivariate Gaussian copula.

    The dependence structure is carried by a :class:`MultivariateNormal` built
    from the given correlation parameterization. The event shape is taken from
    that multivariate normal, and the batch shape is the broadcast of
    `marginal_dist.batch_shape[:-1]` with the multivariate normal's batch shape.

    :param Distribution marginal_dist: Distribution whose last batch axis is to be coupled.
        Its `event_shape` must be empty.
    :param array_like correlation_matrix: Correlation matrix of coupling multivariate
        normal distribution.
    :param array_like correlation_cholesky: Correlation Cholesky factor of coupling
        multivariate normal distribution.
    :raises ValueError: if `marginal_dist` has a non-empty `event_shape`.
    """

    arg_constraints = {
        "correlation_matrix": constraints.corr_matrix,
        "correlation_cholesky": constraints.corr_cholesky,
    }
    reparametrized_params = [
        "correlation_matrix",
        "correlation_cholesky",
    ]

    def __init__(
        self,
        marginal_dist,
        correlation_matrix=None,
        correlation_cholesky=None,
        *,
        validate_args=None,
    ):
        # Only scalar-event marginals can be coupled along an axis.
        if len(marginal_dist.event_shape) != 0:
            raise ValueError("`marginal_dist` needs to be a univariate distribution.")

        self.marginal_dist = marginal_dist
        # The coupling Gaussian; both parameterizations are forwarded as given.
        self.base_dist = MultivariateNormal(
            covariance_matrix=correlation_matrix,
            scale_tril=correlation_cholesky,
        )

        # The last batch axis of the marginal plays the role of the event axis,
        # so it is excluded when computing the joint batch shape.
        coupled_batch_shape = lax.broadcast_shapes(
            self.marginal_dist.batch_shape[:-1],
            self.base_dist.batch_shape,
        )
        super().__init__(
            batch_shape=coupled_batch_shape,
            event_shape=self.base_dist.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """
        Draw samples by pushing Gaussian draws through the standard normal CDF
        and then through the marginal's inverse CDF.

        :param key: PRNG key used to sample the coupling multivariate normal.
        :param tuple sample_shape: Leading shape prepended to `batch_shape`.
        :return: Output of `marginal_dist.icdf` applied to the CDF-transformed draws.
        """
        assert is_prng_key(key)

        expanded_base = self.base_dist.expand(sample_shape + self.batch_shape)
        gaussian_draws = expanded_base.sample(key)
        uniform_draws = Normal().cdf(gaussian_draws)
        return self.marginal_dist.icdf(uniform_draws)

    @validate_sample
    def log_prob(self, value):
        """
        Evaluate the joint log density at `value`.

        The result is the copula log density evaluated at the standard normal
        quantiles of the marginal CDF values, plus the marginal log densities
        summed over the last axis.

        :param value: Point(s) at which to evaluate the log density.
        :return: Copula log density plus the summed marginal log densities.
        """
        # Ref: https://en.wikipedia.org/wiki/Copula_(probability_theory)#Gaussian_copula
        # see also https://github.com/pyro-ppl/numpyro/pull/1506#discussion_r1037525015
        marginal_lps = self.marginal_dist.log_prob(value)

        # Map the observations to standard normal quantiles; the CDF values are
        # passed through `clamp_probs` before the inverse normal CDF is applied.
        uniforms = self.marginal_dist.cdf(value)
        quantiles = Normal().icdf(clamp_probs(uniforms))

        # The two extra terms add back the sum of independent standard normal
        # log densities at the quantiles (sign flipped), leaving the copula density.
        dim = quantiles.shape[-1]
        copula_lp = (
            self.base_dist.log_prob(quantiles)
            + 0.5 * (quantiles**2).sum(axis=-1)
            + 0.5 * jnp.log(2 * jnp.pi) * dim
        )
        return copula_lp + marginal_lps.sum(axis=-1)

    @property
    def mean(self):
        """Mean of the marginal distribution, broadcast to `self.shape()`."""
        marginal_mean = self.marginal_dist.mean
        return jnp.broadcast_to(marginal_mean, self.shape())

    @property
    def variance(self):
        """Variance of the marginal distribution, broadcast to `self.shape()`."""
        marginal_variance = self.marginal_dist.variance
        return jnp.broadcast_to(marginal_variance, self.shape())

    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self):
        """Support of the marginal, with one dimension reinterpreted as an event dimension."""
        return constraints.independent(self.marginal_dist.support, 1)

    @lazy_property
    def correlation_matrix(self):
        """Covariance matrix of the coupling multivariate normal."""
        return self.base_dist.covariance_matrix

    @lazy_property
    def correlation_cholesky(self):
        """Lower Cholesky factor (`scale_tril`) of the coupling multivariate normal."""
        return self.base_dist.scale_tril

    def tree_flatten(self):
        """
        Flatten into pytree children and auxiliary data.

        :return: ``((marginal_children, scale_tril), (marginal_type, marginal_aux))``
            where the marginal parts come from `marginal_dist.tree_flatten()`.
        """
        marginal_children, marginal_aux = self.marginal_dist.tree_flatten()
        children = (marginal_children, self.base_dist.scale_tril)
        aux_data = (type(self.marginal_dist), marginal_aux)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """
        Rebuild an instance from the output of :meth:`tree_flatten`.

        The marginal is reconstructed through its own class's `tree_unflatten`;
        the copula is then built from the Cholesky factor.
        """
        marginal_children, correlation_cholesky = params
        marginal_cls, marginal_aux = aux_data
        marginal_dist = marginal_cls.tree_unflatten(marginal_aux, marginal_children)
        return cls(marginal_dist, correlation_cholesky=correlation_cholesky)
class GaussianCopulaBeta(GaussianCopula):
    """
    Gaussian copula whose marginal is a :class:`Beta` distribution built from
    `concentration1` and `concentration0`.

    :param concentration1: First concentration parameter passed to `Beta`.
    :param concentration0: Second concentration parameter passed to `Beta`.
    :param array_like correlation_matrix: Correlation matrix of coupling multivariate
        normal distribution.
    :param array_like correlation_cholesky: Correlation Cholesky factor of coupling
        multivariate normal distribution.
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "correlation_matrix": constraints.corr_matrix,
        "correlation_cholesky": constraints.corr_cholesky,
    }
    support = constraints.independent(constraints.unit_interval, 1)

    def __init__(
        self,
        concentration1,
        concentration0,
        correlation_matrix=None,
        correlation_cholesky=None,
        *,
        validate_args=False,
    ):
        # Store the concentrations before delegating to the parent constructor
        # so that they are already set when argument validation runs.
        self.concentration1 = concentration1
        self.concentration0 = concentration0

        beta_marginal = Beta(concentration1, concentration0)
        super().__init__(
            beta_marginal,
            correlation_matrix,
            correlation_cholesky,
            validate_args=validate_args,
        )

    def tree_flatten(self):
        """
        Flatten into pytree children and auxiliary data.

        :return: ``(((concentration1, concentration0), scale_tril), None)``.
        """
        concentrations = (self.concentration1, self.concentration0)
        children = (concentrations, self.base_dist.scale_tril)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        """
        Rebuild an instance from the output of :meth:`tree_flatten`.

        `aux_data` is not used; the instance is built from the two
        concentrations and the Cholesky factor.
        """
        concentrations, correlation_cholesky = params
        concentration1, concentration0 = concentrations
        return cls(
            concentration1,
            concentration0,
            correlation_cholesky=correlation_cholesky,
        )