# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import jax
from jax import Array, lax, numpy as jnp
from jax.typing import ArrayLike
import numpyro.distributions.constraints as constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample
from numpyro.util import is_prng_key
class GaussianCopula(Distribution):
    """
    Joint distribution that keeps the univariate marginals of ``marginal_dist``
    and ties them together with a multivariate Gaussian copula.

    The last batch axis of ``marginal_dist`` is the axis being coupled. The
    remaining axes, ``marginal_dist.batch_shape[:-1]``, are broadcast against
    the batch shape of the coupling multivariate normal to give this
    distribution's batch shape; the event shape is taken from the coupling
    multivariate normal.

    :param Distribution marginal_dist: Distribution with an empty event shape
        whose last batch axis is to be coupled.
    :param array_like correlation_matrix: Correlation matrix of the coupling
        multivariate normal distribution (forwarded as its ``covariance_matrix``).
    :param array_like correlation_cholesky: Cholesky factor of the correlation
        matrix of the coupling multivariate normal distribution (forwarded as
        its ``scale_tril``).
    :param validate_args: Forwarded unchanged to ``Distribution.__init__``.
    :raises ValueError: If ``marginal_dist`` has a non-empty event shape.
    """

    arg_constraints = {
        "correlation_matrix": constraints.corr_matrix,
        "correlation_cholesky": constraints.corr_cholesky,
    }
    reparametrized_params = ["correlation_matrix", "correlation_cholesky"]
    pytree_data_fields = ("marginal_dist", "base_dist")

    def __init__(
        self,
        marginal_dist: Distribution,
        correlation_matrix: Optional[Array] = None,
        correlation_cholesky: Optional[Array] = None,
        *,
        validate_args: Optional[bool] = None,
    ):
        # Only scalar-event marginals can be coupled along a batch axis.
        marginal_event_ndim = len(marginal_dist.event_shape)
        if marginal_event_ndim > 0:
            raise ValueError("`marginal_dist` needs to be a univariate distribution.")

        self.marginal_dist = marginal_dist
        # NOTE(review): which of the two parametrizations may be None is decided
        # by MultivariateNormal itself -- confirm against its constructor.
        coupling = MultivariateNormal(
            covariance_matrix=correlation_matrix,
            scale_tril=correlation_cholesky,
        )
        self.base_dist = coupling

        # The marginal's last batch axis becomes the event axis, so it is
        # dropped before broadcasting the batch shapes.
        super().__init__(
            batch_shape=lax.broadcast_shapes(
                marginal_dist.batch_shape[:-1],
                coupling.batch_shape,
            ),
            event_shape=coupling.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
        """
        Draw samples by pushing correlated Gaussian draws through the standard
        normal CDF and then through the marginal inverse CDF.

        :param key: PRNG key used for the draw from the coupling distribution.
        :param sample_shape: Leading shape prepended to ``batch_shape``.
        :return: Result of ``marginal_dist.icdf`` applied to the transformed draws.
        """
        assert is_prng_key(key)

        expanded_base = self.base_dist.expand(sample_shape + self.batch_shape)
        gaussian_draws = expanded_base.sample(key)
        # Standard normal CDF maps each coordinate to a probability ...
        uniforms = Normal().cdf(gaussian_draws)
        # ... which the marginal inverse CDF maps onto the marginal's scale.
        return self.marginal_dist.icdf(uniforms)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """
        Log density of ``value``: Gaussian copula log density plus the sum of
        the marginal log densities over the last axis.

        :param value: Point(s) at which to evaluate the log density.
        :return: Log density with the last axis of ``value`` reduced.
        """
        # Ref: https://en.wikipedia.org/wiki/Copula_(probability_theory)#Gaussian_copula
        # see also https://github.com/pyro-ppl/numpyro/pull/1506#discussion_r1037525015
        marginal = self.marginal_dist
        marginal_lps = marginal.log_prob(value)

        # Map the values to standard-normal quantiles via the marginal CDF.
        # NOTE(review): clamp_probs presumably keeps the probabilities strictly
        # inside (0, 1) so the inverse CDF stays finite -- confirm in
        # numpyro.distributions.util.
        quantiles = Normal().icdf(clamp_probs(marginal.cdf(value)))

        # Copula density = joint normal density / product of standard normal
        # densities. The two correction terms below are exactly minus the sum
        # of the standard normal log densities at `quantiles`.
        joint_normal_lp = self.base_dist.log_prob(quantiles)
        half_squared_norm = 0.5 * jnp.sum(quantiles**2, axis=-1)
        log_normalizer = 0.5 * jnp.log(2 * jnp.pi) * quantiles.shape[-1]
        copula_lp = joint_normal_lp + half_squared_norm + log_normalizer

        return copula_lp + jnp.sum(marginal_lps, axis=-1)

    @property
    def mean(self) -> ArrayLike:
        """Mean of ``marginal_dist`` broadcast to ``self.shape()``."""
        full_shape = self.shape()
        return jnp.broadcast_to(self.marginal_dist.mean, full_shape)

    @property
    def variance(self) -> ArrayLike:
        """Variance of ``marginal_dist`` broadcast to ``self.shape()``."""
        full_shape = self.shape()
        return jnp.broadcast_to(self.marginal_dist.variance, full_shape)

    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self) -> Constraint:
        """Support of ``marginal_dist`` with one event dimension reinterpreted."""
        marginal_support = self.marginal_dist.support
        return constraints.independent(marginal_support, 1)

    @lazy_property
    def correlation_matrix(self) -> Array:
        """Covariance matrix of the coupling multivariate normal."""
        return self.base_dist.covariance_matrix

    @lazy_property
    def correlation_cholesky(self) -> Array:
        """Lower-triangular scale factor of the coupling multivariate normal."""
        return self.base_dist.scale_tril
class GaussianCopulaBeta(GaussianCopula):
    """
    Gaussian copula whose marginals are ``Beta(concentration1, concentration0)``.

    :param concentration1: First concentration parameter of the Beta marginal.
    :param concentration0: Second concentration parameter of the Beta marginal.
    :param array_like correlation_matrix: Correlation matrix of the coupling
        multivariate normal distribution.
    :param array_like correlation_cholesky: Correlation Cholesky factor of the
        coupling multivariate normal distribution.
    :param validate_args: Forwarded to ``GaussianCopula.__init__``.
        NOTE(review): the default here is ``False`` while ``GaussianCopula``
        defaults to ``None`` -- confirm the difference is intentional.
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "correlation_matrix": constraints.corr_matrix,
        "correlation_cholesky": constraints.corr_cholesky,
    }
    # Fixed support (unit interval on every coordinate); replaces the
    # marginal-dependent `support` property of the parent class.
    support = constraints.independent(constraints.unit_interval, 1)
    pytree_data_fields = ("concentration1", "concentration0")

    def __init__(
        self,
        concentration1: ArrayLike,
        concentration0: ArrayLike,
        correlation_matrix: Optional[Array] = None,
        correlation_cholesky: Optional[Array] = None,
        *,
        validate_args: bool = False,
    ):
        # Stored before the parent constructor runs so that the attributes
        # exist when argument validation takes place.
        self.concentration1 = concentration1
        self.concentration0 = concentration0

        beta_marginal = Beta(concentration1, concentration0)
        super().__init__(
            beta_marginal,
            correlation_matrix,
            correlation_cholesky,
            validate_args=validate_args,
        )