# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import warnings
import numpy as np
import jax
from jax import lax
import jax.numpy as jnp
from jax.typing import ArrayLike
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample
from numpyro.util import find_stack_level, not_jax_tracer
[docs]
class LeftCensoredDistribution(Distribution):
    """
    Distribution wrapper for left-censored outcomes.

    This distribution augments a base distribution with left-censoring,
    so that the likelihood contribution depends on the censoring indicator.

    :param base_dist: Parametric distribution for the *uncensored* values
        (e.g., Exponential, Weibull, LogNormal, Normal, etc.).
        This distribution must implement a ``cdf`` method.
    :type base_dist: numpyro.distributions.Distribution
    :param censored: Censoring indicator per observation, broadcast against
        ``base_dist.batch_shape``:

        - 0 → value is observed exactly
        - 1 → observation is left-censored at the reported value
          (true value occurred *on or before* the reported value)
    :type censored: array-like of {0,1}
    :param validate_args: If True, ``base_dist.cdf`` is probed once at a
        feasible point of its support during construction (raising
        ``TypeError`` if it is not implemented), and argument/sample
        validation is enabled.

    .. note::

        The ``log_prob(value)`` method expects ``value`` to be the observed upper bound
        for each observation. The contribution to the log-likelihood is::

            log f(value)   if censored == 0
            log F(value)   if censored == 1

        where f is the density and F the cumulative distribution function of ``base_dist``.
        This is commonly used in survival analysis, where event times are positive,
        but the approach is more general and can be applied to any distribution
        with a cumulative distribution function, regardless of support.

        ``sample`` draws from ``base_dist`` (expanded to the batch shape); it does
        not apply any censoring.

        In R's ``survival`` package notation, this corresponds to
        ``Surv(time, event, type = 'left')``, where ``event = 1`` marks an exactly
        observed event and ``event = 0`` a left-censored one, i.e.
        ``censored = 1 - event``. Example::

            Surv(time = c(2, 4, 6), event = c(1, 0, 1), type = 'left')

        means:

        - subject 1 had an event exactly at t=2
        - subject 2 had an event before or at t=4 (left-censored)
        - subject 3 had an event exactly at t=6

    **Example:**

    .. doctest::

        >>> from jax import numpy as jnp
        >>> from numpyro import distributions as dist
        >>> base = dist.LogNormal(0., 1.)
        >>> surv_dist = dist.LeftCensoredDistribution(base, censored=jnp.array([0, 1, 1]))
        >>> loglik = surv_dist.log_prob(jnp.array([2., 4., 6.]))
    """

    arg_constraints = {"censored": constraints.boolean}
    pytree_data_fields = ("base_dist", "censored", "_support")

    def __init__(
        self,
        base_dist: Distribution,
        censored: ArrayLike = False,
        *,
        validate_args: bool = False,
    ):
        # Test if base_dist has a cdf attribute.
        # NOTE(review): if the Distribution base class defines a default ``cdf``
        # (e.g. one raising NotImplementedError), this check only rejects
        # non-distribution objects; the validate_args probe below is what
        # actually exercises ``cdf``.
        if not hasattr(base_dist, "cdf"):
            raise TypeError(
                f"{type(base_dist).__name__} does not have a 'cdf' method. "
                "Censored distributions require a base distribution with an "
                "implemented cumulative distribution function."
            )
        # Optionally test that cdf actually works (in validate_args mode).
        if validate_args:
            try:
                test_val = base_dist.support.feasible_like(jnp.array(0.0))
                _ = base_dist.cdf(test_val)
            except (NotImplementedError, AttributeError) as e:
                raise TypeError(
                    f"{type(base_dist).__name__}.cdf() is not properly implemented."
                ) from e
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(censored))
        # Promote every array leaf of base_dist (and the indicator) to the rank
        # of the broadcast batch shape so they line up inside log_prob.
        self.base_dist = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        self.censored = jnp.array(
            promote_shapes(censored, shape=batch_shape)[0], dtype=jnp.bool
        )
        self._support = base_dist.support
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(
        self, key: Optional[jax.Array], sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Draw uncensored samples from ``base_dist`` expanded to ``batch_shape``."""
        return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        """Support of the wrapped base distribution."""
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """
        Per-observation log-likelihood.

        :param value: Observed values (upper bounds for censored entries).
        :return: ``log F(value)`` where ``censored`` is set, otherwise
            ``log f(value)``.
        """
        dtype = jnp.result_type(value, float)
        # Floor for F(x) so that log F(x) stays finite far in the left tail.
        minval = 100.0 * jnp.finfo(dtype).tiny

        def log_cdf_censored(x):
            # log(F(x)) with stability
            return jnp.log(jnp.clip(self.base_dist.cdf(x), minval, 1.0))

        return jnp.where(
            self.censored,
            log_cdf_censored(value),  # left-censored observations: log F(t)
            self.base_dist.log_prob(value),  # observed values: log f(t)
        )
[docs]
class RightCensoredDistribution(Distribution):
    """
    Wrap a base distribution so that observations may be right-censored.

    Each observation carries a flag saying whether the reported value is the
    exact outcome or only a lower bound on it. ``log_prob`` picks the matching
    likelihood term for every observation.

    :param base_dist: Distribution of the *uncensored* outcome, e.g.
        Exponential, Weibull, LogNormal or Normal. It must provide ``cdf``.
    :type base_dist: numpyro.distributions.Distribution
    :param censored: Per-observation flag, broadcast against
        ``base_dist.batch_shape``:

        - 0 → the value was observed exactly
        - 1 → the observation is right-censored at the reported value
          (the true outcome is *at or after* it)
    :type censored: array-like of {0,1}
    :param validate_args: When True, ``base_dist.cdf`` is probed once at a
        feasible point of its support during construction (a failing probe
        raises ``TypeError``), and argument/sample validation is enabled.

    .. note::

        ``log_prob(value)`` interprets ``value`` as the observed lower bound of
        each observation and returns::

            log f(value)         where censored == 0
            log (1 - F(value))   where censored == 1

        with f the density and F the cumulative distribution function of
        ``base_dist``. ``1 - F`` is the survival function, which makes this
        wrapper a natural fit for survival analysis. Nothing here requires a
        positive support, however, so any distribution with a CDF can be used.

        ``sample`` draws from ``base_dist`` (expanded to the batch shape) and
        applies no censoring.

        The R ``survival`` equivalent is ``Surv(time, event, type = 'right')``
        with ``event = 1 - censored``. For instance::

            Surv(time = c(5, 8, 10), event = c(1, 0, 1))

        describes an event at t=5, a subject censored at t=8, and an event at
        t=10.

    **Example:**

    .. doctest::

        >>> from jax import numpy as jnp
        >>> from numpyro import distributions as dist
        >>> base = dist.Exponential(rate=0.1)
        >>> surv_dist = dist.RightCensoredDistribution(base, censored=jnp.array([0, 1, 0]))
        >>> loglik = surv_dist.log_prob(jnp.array([5., 8., 10.]))
    """

    arg_constraints = {"censored": constraints.boolean}
    pytree_data_fields = ("base_dist", "censored", "_support")

    def __init__(
        self,
        base_dist: Distribution,
        censored: ArrayLike = False,
        *,
        validate_args: bool = False,
    ):
        dist_name = type(base_dist).__name__

        # Reject objects that expose no cdf at all.
        if not hasattr(base_dist, "cdf"):
            raise TypeError(
                f"{dist_name} does not have a 'cdf' method. "
                "Censored distributions require a base distribution with an "
                "implemented cumulative distribution function."
            )

        # In validation mode, make sure cdf can actually be evaluated once.
        if validate_args:
            try:
                base_dist.cdf(base_dist.support.feasible_like(jnp.array(0.0)))
            except (NotImplementedError, AttributeError) as err:
                raise TypeError(f"{dist_name}.cdf() is not properly implemented.") from err

        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(censored))

        def _promote(leaf):
            # Left-pad `leaf` with singleton dims up to the batch rank.
            return promote_shapes(leaf, shape=batch_shape)[0]

        self.base_dist = jax.tree.map(_promote, base_dist)
        self.censored = jnp.array(_promote(censored), dtype=jnp.bool)
        self._support = base_dist.support
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(
        self, key: Optional[jax.Array], sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Return uncensored draws of ``base_dist`` broadcast to ``batch_shape``."""
        expanded = self.base_dist.expand(self.batch_shape)
        return expanded.sample(key, sample_shape)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> Constraint:
        """The base distribution's support."""
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """
        Per-observation log-likelihood.

        :param value: Observed values (lower bounds for censored entries).
        :return: ``log(1 - F(value))`` where ``censored`` is set, otherwise
            ``log f(value)``.
        """
        # Cap F just below one so that log1p(-F) never evaluates log(0).
        ceiling = 1 - jnp.finfo(jnp.result_type(value, float)).eps
        log_survival = jnp.log1p(-jnp.clip(self.base_dist.cdf(value), 0.0, ceiling))
        log_density = self.base_dist.log_prob(value)
        return jnp.where(self.censored, log_survival, log_density)
[docs]
class IntervalCensoredDistribution(Distribution):
    r"""
    Distribution wrapper for interval-censored outcomes.

    This distribution augments a base distribution with interval censoring,
    so that the likelihood contribution depends on whether the observation is
    exactly observed, left-censored, right-censored, interval-censored, or
    doubly-censored (i.e., known to lie outside the observed interval).

    :param base_dist: Parametric distribution for the *uncensored* values
        (e.g., Exponential, Weibull, LogNormal, Normal, etc.).
        This distribution must implement a ``cdf`` method.
    :type base_dist: numpyro.distributions.Distribution
    :param left_censored: Indicator per observation:

        - 1 → observation is left-censored at the reported upper bound
        - 0 → not left-censored
    :type left_censored: array-like of {0,1}
    :param right_censored: Indicator per observation:

        - 1 → observation is right-censored at the reported lower bound
        - 0 → not right-censored
    :type right_censored: array-like of {0,1}
    :param validate_args: If True, ``base_dist.cdf`` is probed once at a
        feasible point of its support during construction (raising
        ``TypeError`` if it is not implemented), and argument/sample
        validation is enabled.

    .. note::

        The ``log_prob(value)`` method expects ``value`` to be an array of shape
        ``(..., 2)``, where the last axis holds ``(lower, upper)``.
        The contribution to the log-likelihood is determined as follows::

            log F(upper)                     if left_censored == 1 and right_censored == 0
            log (1 - F(lower))               if right_censored == 1 and left_censored == 0
            log (F(upper) - F(lower))        if both == 0 (interval-censored)
            log f(lower)                     if both == 0 and lower ≈ upper (exact observation)
            log (1 - (F(upper) - F(lower)))  if both == 1 (doubly-censored)

        where f is the density and F the cumulative distribution function of ``base_dist``.
        For left-censored rows the lower bound is ignored, and for right-censored
        rows the upper bound is ignored (they may be e.g. ``-inf`` / ``inf``).

        This is commonly used in survival analysis, where event times are positive,
        but the approach is general and can be applied to any distribution
        with a cumulative distribution function, regardless of support.

        In R's ``survival`` package notation, this corresponds to
        ``Surv(l, r, type = 'interval2')``. Example::

            Surv(l = c(2, 4, 6), r = c(5, Inf, 9), type = 'interval2')

        means:

        - subject 1 had an event in (2, 5]
        - subject 2 was right-censored at 4
        - subject 3 had an event in (6, 9]

    **Example:**

    .. doctest::

        >>> from jax import numpy as jnp
        >>> from numpyro import distributions as dist
        >>> base = dist.Weibull(concentration=2.0, scale=3.0)
        >>> left_censored = jnp.array([0, 0, 0])
        >>> right_censored = jnp.array([0, 1, 0])
        >>> surv_dist = dist.IntervalCensoredDistribution(base, left_censored, right_censored)
        >>> values = jnp.array([
        ...     [2.0, 5.0],
        ...     [4.0, jnp.inf],
        ...     [6.0, 9.0],
        ... ])
        >>> loglik = surv_dist.log_prob(values)
    """

    arg_constraints = {
        "left_censored": constraints.boolean,
        "right_censored": constraints.boolean,
    }
    pytree_data_fields = ("base_dist", "left_censored", "right_censored", "_support")

    def __init__(
        self,
        base_dist: Distribution,
        left_censored: ArrayLike,
        right_censored: ArrayLike,
        *,
        validate_args: bool = False,
    ):
        # Same up-front check as the left/right-censored wrappers.
        if not hasattr(base_dist, "cdf"):
            raise TypeError(
                f"{type(base_dist).__name__} does not have a 'cdf' method. "
                "Censored distributions require a base distribution with an "
                "implemented cumulative distribution function."
            )
        # Optionally test that cdf actually works (in validate_args mode).
        if validate_args:
            try:
                test_val = base_dist.support.feasible_like(jnp.array(0.0))
                _ = base_dist.cdf(test_val)
            except (NotImplementedError, AttributeError) as e:
                raise TypeError(
                    f"{type(base_dist).__name__}.cdf() is not properly implemented."
                ) from e
        batch_shape = lax.broadcast_shapes(
            base_dist.batch_shape, jnp.shape(left_censored), jnp.shape(right_censored)
        )
        # Promote every array leaf of base_dist and both indicators to the rank
        # of the broadcast batch shape so they line up inside log_prob.
        self.base_dist = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        self.left_censored = jnp.array(
            promote_shapes(left_censored, shape=batch_shape)[0], dtype=jnp.bool
        )
        self.right_censored = jnp.array(
            promote_shapes(right_censored, shape=batch_shape)[0], dtype=jnp.bool
        )
        self._support = base_dist.support
        super().__init__(batch_shape, event_shape=(2,), validate_args=validate_args)

    def sample(
        self, key: Optional[jax.Array], sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Draw uncensored samples from ``base_dist`` expanded to ``batch_shape``."""
        # NOTE(review): these draws have shape ``sample_shape + batch_shape``,
        # i.e. they lack the trailing ``(2,)`` event dimension declared in
        # ``__init__``, so they cannot be fed directly to ``log_prob``.
        return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)

    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self) -> Constraint:
        """Support of the wrapped base distribution (applies to each bound)."""
        return self._support

    def _get_censoring_masks(self, value):
        """
        Split observations into mutually exclusive censoring categories.

        :param value: Array of shape ``(..., 2)`` holding ``(lower, upper)``.
        :return: Boolean masks ``(m_left, m_right, m_int, m_double, m_point)``.
            ``m_point`` marks uncensored rows whose bounds are numerically
            equal; those rows are removed from ``m_int``.
        """
        x1 = jnp.take(value, 0, axis=-1)  # lower bound
        x2 = jnp.take(value, 1, axis=-1)  # upper bound
        m_left = self.left_censored & (~self.right_censored)  # left-censored only
        m_right = self.right_censored & (~self.left_censored)  # right-censored only
        m_int = (~self.left_censored) & (~self.right_censored)  # interval censored
        m_double = self.left_censored & self.right_censored  # doubly censored
        m_point = jnp.isclose(x1, x2) & m_int  # exact observation
        m_int = m_int & (~m_point)  # exclude exact observations from intervals
        return m_left, m_right, m_int, m_double, m_point

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """
        Per-observation log-likelihood.

        :param value: Array of shape ``(..., 2)`` holding ``(lower, upper)``.
        :return: Log-likelihood with the trailing bound axis reduced away.
        """
        dtype = jnp.result_type(value, float)
        minval = 100.0 * jnp.finfo(dtype).tiny  # floor for CDF values near 0
        eps = jnp.finfo(dtype).eps  # keeps CDF values strictly below 1
        x1 = jnp.take(value, 0, axis=-1)  # lower bound
        x2 = jnp.take(value, 1, axis=-1)  # upper bound
        m_left, m_right, m_int, m_double, m_point = self._get_censoring_masks(value)

        # Replace bounds that are ignored for a row (and may be out of support,
        # e.g. +/-inf) with a finite feasible placeholder BEFORE calling cdf.
        feasible_value = self.support.feasible_like(x1)
        x1_finite = jnp.where(m_left, feasible_value, x1)
        x2_finite = jnp.where(m_right, feasible_value, x2)

        # Left-censored: F1 := 0; right-censored: F2 := 1.
        F1 = jnp.where(m_left, 0.0, self.base_dist.cdf(x1_finite))
        F2 = jnp.where(m_right, 1.0, self.base_dist.cdf(x2_finite))

        # Stabilize against log(0) and log1p(-1).
        F1 = jnp.clip(F1, minval, 1.0 - eps)
        F2 = jnp.clip(F2, minval, 1.0 - eps)
        logF1 = jnp.log(F1)
        logF2 = jnp.log(F2)

        # log(F2 - F1) = logF2 + log(1 - exp(logF1 - logF2)).
        # log1mexp switches to expm1 near 0, so narrow intervals stay accurate
        # and the strictly-negative clip actually keeps the result finite
        # (log1p(-exp(-minval)) would round to log(0)).
        lp_interval = logF2 + log1mexp(jnp.clip(logF1 - logF2, max=-minval))
        # Exact observations use the density. x1_finite equals x1 on these rows,
        # and avoids evaluating log_prob at ignored (possibly infinite) bounds.
        lp_interval = jnp.where(m_point, self.base_dist.log_prob(x1_finite), lp_interval)
        # Doubly censored: the value lies outside the interval -> 1 - P(interval).
        lp_double = log1mexp(lp_interval)
        # Left-censored: log F(x2).
        lp_left = logF2
        # Right-censored: log(1 - F(x1)).
        lp_right = jnp.log1p(-F1)

        # The categories are disjoint and cover every row; exact observations
        # must be selected explicitly since m_point was removed from m_int.
        logp = jnp.zeros_like(logF1)
        logp = jnp.where(m_left, lp_left, logp)
        logp = jnp.where(m_right, lp_right, logp)
        logp = jnp.where(m_int | m_point, lp_interval, logp)
        logp = jnp.where(m_double, lp_double, logp)
        return logp

    def _validate_sample(self, value: ArrayLike) -> ArrayLike:
        """
        Check ``value`` against the censoring pattern and the base support.

        Each rule is evaluated only on the rows it applies to, so a violation
        in one censoring category does not trigger warnings for the others.
        Warnings are emitted only outside a JAX trace.

        :param value: Array of shape ``(..., 2)`` holding ``(lower, upper)``.
        :return: Boolean mask that is False on invalid observations.
        :raises ValueError: if the last dimension of ``value`` is not 2.
        """
        shape = jnp.shape(value)
        if len(shape) == 0 or shape[-1] != 2:
            raise ValueError(
                f"Expected last dimension of `value` to be 2 (lower, upper), but got shape {shape}"
            )
        x1 = jnp.take(value, 0, axis=-1)  # lower bound
        x2 = jnp.take(value, 1, axis=-1)  # upper bound
        m_left, m_right, m_int, m_double, m_point = self._get_censoring_masks(value)

        # Support checks under base_dist; its generic warnings are silenced in
        # favour of the censoring-specific ones below.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x1_ok = self.base_dist._validate_sample(x1)
            x2_ok = self.base_dist._validate_sample(x2)

        # (rows the rule applies to, per-row condition, warning message)
        rules = (
            (
                m_left,
                x2_ok,
                "For left-censored observations, upper bound should be within "
                "the support of base_dist. ",
            ),
            (
                m_right,
                x1_ok,
                "For right-censored observations, lower bound should be within "
                "the support of base_dist. ",
            ),
            (
                m_int | m_double | m_point,
                x1_ok & x2_ok,
                "For interval-censored, doubly-censored, or exact observations, "
                "both bounds should be within the support of base_dist. ",
            ),
            (
                m_int | m_double,
                x2 > x1,
                "For interval-censored and doubly-censored observations, "
                "upper bound should be greater than lower bound. ",
            ),
        )

        mask = jnp.ones_like(x1, dtype=jnp.bool)
        for rows, ok, message in rules:
            valid = (~rows) | ok  # rows outside this category pass trivially
            if not_jax_tracer(valid) and not np.all(valid):
                warnings.warn(message, stacklevel=find_stack_level())
            mask = mask & valid
        return mask