# Source code for numpyro.infer.importance

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Pareto Smoothed Importance Sampling (PSIS) diagnostics for variational inference.

Implements the k-hat diagnostic from:
    Yao, Y., Vehtari, A., Simpson, D., and Gelman, A. (2018).
    Yes, but Did It Work?: Evaluating Variational Inference.
    International Conference on Machine Learning.

    Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2024).
    Pareto smoothed importance sampling.
    Journal of Machine Learning Research, 25(72):1-58.
"""

from __future__ import annotations

from collections.abc import Callable
import math
import warnings

import numpy as np

import jax
from jax import device_get, random

from numpyro.handlers import seed
from numpyro.infer.elbo import get_importance_log_probs

__all__ = ["psis_diagnostic"]


def _fit_generalized_pareto(x: np.ndarray) -> tuple[float, float]:
    """Fit a two-parameter Generalized Pareto Distribution (GPD) to tail samples.

    Checks the shape of ``x`` and then delegates the actual estimation to
    :func:`_fit_generalized_pareto_impl` with NumPy floating-point warnings
    silenced. The estimates are empirical Bayes estimates following Zhang
    and Stephens (2009), with the prior regularization from Vehtari et al.
    (2024).

    References:
        Zhang, J. and Stephens, M.A. (2009). A new and efficient estimation
        method for the generalized Pareto distribution. Technometrics,
        51(3):316-325.

        Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2024).
        Pareto smoothed importance sampling. Journal of Machine Learning
        Research, 25(72):1-58.

    :param numpy.ndarray x: one-dimensional array of positive exceedances
        (tail samples).
    :return: tuple ``(k, sigma)`` with the shape parameter ``k`` and the
        scale parameter ``sigma``.
    :raises ValueError: if ``x`` is not one-dimensional or holds fewer than
        two elements.
    """
    # The dimensionality test comes first so that ``len`` is never called on
    # a zero-dimensional array.
    shape_ok = x.ndim == 1 and len(x) >= 2
    if not shape_ok:
        raise ValueError(
            f"Expected 1-D array with at least 2 elements, got shape {x.shape}."
        )

    # Degenerate inputs (e.g. zeros or identical values) trigger a cascade of
    # floating-point events inside the estimator:
    #   divide:  reciprocals of tail values that are zero, and -k/b
    #   over:    exp() of large profile log-likelihood differences
    #   invalid: follow-up operations on the nan/inf produced above
    # All three are silenced so that nan/inf simply flow through to the
    # result, which is the behavior of the reference implementation.
    with np.errstate(divide="ignore", over="ignore", invalid="ignore"):
        estimates = _fit_generalized_pareto_impl(x)
    return estimates


def _fit_generalized_pareto_impl(x: np.ndarray) -> tuple[float, float]:
    """Compute the GPD estimates ``(k, sigma)`` without validating the input.

    The samples are sorted, a grid of candidate ``b`` values is built, each
    candidate is weighted by its profile log-likelihood, and ``k`` and
    ``sigma`` are derived from the weighted mean of ``b``. Finally ``k`` is
    shrunk towards 0.5.

    This function does not change NumPy's floating-point error handling
    itself; nan/inf produced by degenerate inputs are returned as-is.

    :param numpy.ndarray x: one-dimensional array of exceedances.
    :return: tuple ``(k, sigma)`` as Python floats.
    """
    samples = np.sort(x)
    n = len(samples)
    b_prior = 3
    grid_size = 30 + int(np.sqrt(n))

    # Candidate values of b (Zhang & Stephens grid), scaled by the sample at
    # the first-quartile position and offset by the reciprocal of the
    # largest sample.
    b_grid = 1 - np.sqrt(grid_size / (np.arange(1, grid_size + 1, dtype=float) - 0.5))
    b_grid /= b_prior * samples[int(n / 4 + 0.5) - 1]
    b_grid += 1 / samples[-1]

    # Profile log-likelihood of every candidate: one row per candidate,
    # averaged over the samples.
    k_grid = np.log1p(-b_grid[:, None] * samples).mean(axis=1)
    log_lik = (np.log(-(b_grid / k_grid)) - k_grid - 1) * n

    # Posterior weights. Overflow inside exp is expected and harmless:
    # an overflowed term yields a weight of zero after the reciprocal.
    weights = 1 / np.exp(log_lik - log_lik[:, None]).sum(axis=1)

    # Drop candidates with negligible weight (this also discards nan
    # weights, which fail the comparison), then normalise.
    keep = weights >= 10 * np.finfo(float).eps
    weights = weights[keep]
    b_grid = b_grid[keep]
    weights /= weights.sum()

    # Posterior mean of b.
    b_hat = np.sum(b_grid * weights)

    # Estimate of k (note: negated relative to Zhang & Stephens).
    k_hat = np.log1p(-b_hat * samples).mean()

    # sigma must be computed from the unregularised k.
    sigma_hat = -k_hat / b_hat

    # Weakly informative prior for k (Vehtari et al. 2024, Appendix G):
    # prior mean 0.5 with an effective sample size of 10.
    prior_size = 10
    k_hat = k_hat * n / (n + prior_size) + prior_size * 0.5 / (n + prior_size)

    return float(k_hat), float(sigma_hat)


def _compute_log_weights(
    rng_key: jax.Array,
    param_map: dict[str, jax.Array],
    model: Callable,
    guide: Callable,
    args: tuple,
    kwargs: dict,
) -> jax.Array:
    """Return the log importance weight ``log p(x, z) - log q(z)`` of one particle.

    :param rng_key: random key for this particle; it is split into one key
        for the model and one for the guide.
    :param param_map: parameter values forwarded to
        ``get_importance_log_probs``.
    :param model: model callable.
    :param guide: guide callable.
    :param args: positional arguments forwarded to
        ``get_importance_log_probs``.
    :param kwargs: keyword arguments forwarded to
        ``get_importance_log_probs``.
    :return: summed model log-probabilities minus summed guide
        log-probabilities.
    """
    # The guide needs randomness to sample its latent sites. The model gets
    # an independent key in case it has stochastic structure beyond the
    # latent sites replayed from the guide (e.g. stochastic control flow).
    key_for_model, key_for_guide = random.split(rng_key)

    log_p_terms, log_q_terms = get_importance_log_probs(
        seed(model, key_for_model),
        seed(guide, key_for_guide),
        args,
        kwargs,
        param_map,
    )

    # Collapse every entry to a scalar and add the entries up.
    total_log_p = sum(term.sum() for term in log_p_terms.values())
    total_log_q = sum(term.sum() for term in log_q_terms.values())
    return total_log_p - total_log_q


def _psis_khat(log_weights: np.ndarray) -> float:
    """Estimate the PSIS k-hat statistic from raw log importance weights.

    The weights are shifted so that the largest is zero, the upper tail is
    extracted, and a Generalized Pareto Distribution is fitted to the tail
    exceedances. The input array is left unmodified.

    :param numpy.ndarray log_weights: raw log importance weights.
    :return: the fitted GPD shape parameter, or ``inf`` (after emitting a
        warning) when fewer than five weights lie strictly above the tail
        cutoff.
    """
    # Subtracting the maximum builds a new array, so the caller's data is
    # never touched.
    sorted_lw = np.sort(log_weights - log_weights.max())

    # Tail extraction (Vehtari et al. 2024, Algorithm 1); the paper calls
    # the sample count S and the tail size M.
    num_samples = len(sorted_lw)
    tail_size = math.ceil(min(0.2 * num_samples, 3 * math.sqrt(num_samples)))

    # The cutoff is floored at log(tiny) so that exp(cutoff) cannot
    # underflow to zero.
    floor_lw = np.log(np.finfo(float).tiny)
    lw_cutoff = max(floor_lw, sorted_lw[-(tail_size + 1)])
    lw_tail = sorted_lw[sorted_lw > lw_cutoff]

    if len(lw_tail) < 5:
        warnings.warn(
            "Not enough tail samples for reliable PSIS diagnostic.",
            stacklevel=3,
        )
        return float("inf")

    # Exceedances over the cutoff, on the weight scale rather than the log
    # scale.
    exceedances = np.exp(lw_tail) - np.exp(lw_cutoff)

    # Only the shape parameter is needed; the scale estimate is discarded.
    k_hat, _ = _fit_generalized_pareto(exceedances)
    return float(k_hat)


[docs] def psis_diagnostic( rng_key: jax.Array, param_map: dict[str, jax.Array], model: Callable, guide: Callable, *args, num_particles: int = 1000, chunk_size: int | None = None, **kwargs, ) -> float: r"""Compute the PSIS k-hat diagnostic for a model/guide pair. The k-hat statistic measures the reliability of importance sampling estimates. It is the shape parameter of a Generalized Pareto Distribution (GPD) fitted to the upper tail of the importance weights. Interpretation (Vehtari et al. 2024): - k < 0.5: finite variance, classical CLT applies - 0.5 <= k < 0.7: finite mean, generalized CLT may apply - k >= 0.7: unreliable importance sampling estimates **Example usage**:: >>> from jax import random >>> from numpyro.infer import SVI, Trace_ELBO, psis_diagnostic >>> svi = SVI(model, guide, optimizer, Trace_ELBO()) >>> svi_result = svi.run(random.PRNGKey(0), num_steps, *args) >>> khat = psis_diagnostic( ... random.PRNGKey(1), svi_result.params, model, guide, *args ... ) .. note:: For reliable results, use at least several hundred particles (the default of 1000 is usually sufficient). Very few particles may not provide enough tail samples for GPD fitting. :param jax.random.PRNGKey rng_key: random number generator key. :param dict param_map: dictionary of current parameter values (e.g. ``svi_result.params``). :param Callable model: NumPyro model. :param Callable guide: NumPyro guide. :param args: positional arguments to model and guide. :param int num_particles: number of importance weight samples to draw. :param int chunk_size: maximum particles to evaluate at once (for memory control). If None, all particles are evaluated together. :param kwargs: keyword arguments to model and guide. :return: the estimated k-hat statistic. 
:rtype: float """ if num_particles < 2: raise ValueError("num_particles must be at least 2.") if chunk_size is None: chunk_size = num_particles rng_keys = random.split(rng_key, num_particles) # Compute log weights in batches def compute_fn(key): return _compute_log_weights(key, param_map, model, guide, args, kwargs) log_weights_list = [] for batch_start in range(0, num_particles, chunk_size): batch_keys = rng_keys[batch_start : batch_start + chunk_size] batch_lw = jax.vmap(compute_fn)(batch_keys) log_weights_list.append(batch_lw) log_weights = np.concatenate( [np.asarray(device_get(lw)) for lw in log_weights_list] ) return _psis_khat(log_weights)