Source code for numpyro.diagnostics

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This module provides a small set of utilities in NumPyro for diagnosing posterior samples.
"""

from collections import OrderedDict
from itertools import product

import numpy as np

from jax import device_get
from jax.tree_util import tree_flatten, tree_map

__all__ = [
    "autocorrelation",
    "autocovariance",
    "effective_sample_size",
    "gelman_rubin",
    "hpdi",
    "split_gelman_rubin",
    "print_summary",
]


def _compute_chain_variance_stats(x):
    # compute within-chain variance and variance estimator
    # input has shape C x N x sample_shape
    C, N = x.shape[:2]
    chain_var = x.var(axis=1, ddof=1)
    var_within = chain_var.mean(axis=0)
    var_estimator = var_within * (N - 1) / N
    if x.shape[0] > 1:
        chain_mean = x.mean(axis=1)
        var_between = chain_mean.var(axis=0, ddof=1)
        var_estimator = var_estimator + var_between
    else:
        # with a single chain, return the scaled estimator for both values so
        # that rho_k in effective_sample_size reduces to the autocorrelation
        var_within = var_estimator
    return var_within, var_estimator
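
# Illustrative sketch (not part of the module): with C chains of N draws, the
# helper above returns W, the mean of the per-chain variances (ddof=1), and the
# estimator (N - 1) / N * W + B, where B is the variance of the chain means.
# The shapes and seed below are arbitrary.
import numpy as np

rng = np.random.RandomState(0)
x = rng.normal(size=(4, 100))
W = x.var(axis=1, ddof=1).mean(axis=0)
B = x.mean(axis=1).var(axis=0, ddof=1)
assert np.allclose((100 - 1) / 100 * W + B, _compute_chain_variance_stats(x)[1])
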


def gelman_rubin(x):
    """
    Computes R-hat over chains of samples ``x``, where the first dimension of
    ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
    It is required that ``x.shape[0] >= 2`` and ``x.shape[1] >= 2``.

    :param numpy.ndarray x: the input array.
    :return: R-hat of ``x``.
    :rtype: numpy.ndarray
    """
    assert x.ndim >= 2
    assert x.shape[0] >= 2
    assert x.shape[1] >= 2

    var_within, var_estimator = _compute_chain_variance_stats(x)
    with np.errstate(invalid="ignore", divide="ignore"):
        rhat = np.sqrt(var_estimator / var_within)
    return rhat
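
# Usage sketch (illustrative): well-mixed chains sampling the same distribution
# give R-hat near 1, while chains centered at different values push it well
# above 1. Shapes and seed here are arbitrary.
import numpy as np
from numpyro.diagnostics import gelman_rubin

rng = np.random.RandomState(0)
mixed = rng.normal(size=(4, 1000))     # 4 chains, 1000 draws each
stuck = mixed + np.arange(4)[:, None]  # shift each chain to a different mean
print(gelman_rubin(mixed))  # close to 1.0
print(gelman_rubin(stuck))  # substantially greater than 1.0
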
def split_gelman_rubin(x):
    """
    Computes split R-hat over chains of samples ``x``, where the first dimension
    of ``x`` is chain dimension and the second dimension of ``x`` is draw
    dimension. It is required that ``x.shape[1] >= 4``.

    :param numpy.ndarray x: the input array.
    :return: split R-hat of ``x``.
    :rtype: numpy.ndarray
    """
    assert x.ndim >= 2
    assert x.shape[1] >= 4

    N_half = x.shape[1] // 2
    # split each chain into two halves and treat the halves as separate chains
    new_input = np.concatenate([x[:, :N_half], x[:, -N_half:]], axis=0)
    split_rhat = gelman_rubin(new_input)
    return split_rhat
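
# Usage sketch (illustrative): a single drifting chain cannot be checked with
# plain R-hat (which needs at least two chains), but splitting the chain in
# half exposes the trend.
import numpy as np
from numpyro.diagnostics import split_gelman_rubin

rng = np.random.RandomState(1)
trending = rng.normal(size=(1, 1000)) + np.linspace(0.0, 3.0, 1000)
print(split_gelman_rubin(trending))  # noticeably above 1.0
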
def _fft_next_fast_len(target):
    # find the smallest number >= target whose only prime factors are 2, 3, 5;
    # works just like scipy.fftpack.next_fast_len
    if target <= 2:
        return target
    while True:
        m = target
        while m % 2 == 0:
            m //= 2
        while m % 3 == 0:
            m //= 3
        while m % 5 == 0:
            m //= 5
        if m == 1:
            return target
        target += 1
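
# Illustrative checks of the helper (module-private, shown only to make the
# behavior concrete):
assert _fft_next_fast_len(7) == 8    # 8 = 2**3
assert _fft_next_fast_len(11) == 12  # 12 = 2**2 * 3
assert _fft_next_fast_len(25) == 25  # 25 = 5**2 is already 5-smooth
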
def autocorrelation(x, axis=0):
    """
    Computes the autocorrelation of samples at dimension ``axis``.

    :param numpy.ndarray x: the input array.
    :param int axis: the dimension to calculate autocorrelation.
    :return: autocorrelation of ``x``.
    :rtype: numpy.ndarray
    """
    # Ref: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = x.shape[axis]
    M = _fft_next_fast_len(N)
    M2 = 2 * M

    # transpose axis with -1 for Fourier transform
    x = np.swapaxes(x, axis, -1)

    # centering x
    centered_signal = x - x.mean(axis=-1, keepdims=True)

    # Fourier transform
    freqvec = np.fft.rfft(centered_signal, n=M2, axis=-1)
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec * np.conjugate(freqvec)
    # inverse Fourier transform
    autocorr = np.fft.irfft(freqvec_gram, n=M2, axis=-1)

    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    autocorr = autocorr / np.arange(N, 0.0, -1)
    with np.errstate(invalid="ignore", divide="ignore"):
        autocorr = autocorr / autocorr[..., :1]
    return np.swapaxes(autocorr, axis, -1)
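
# Usage sketch (illustrative): for an AR(1) series with coefficient 0.9, the
# lag-k autocorrelation should decay roughly like 0.9**k.
import numpy as np
from numpyro.diagnostics import autocorrelation

rng = np.random.RandomState(2)
x = np.zeros(5000)
for t in range(1, 5000):
    x[t] = 0.9 * x[t - 1] + rng.normal()
acf = autocorrelation(x)
print(acf[:4])  # approximately [1.0, 0.9, 0.81, 0.73]
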
def autocovariance(x, axis=0):
    """
    Computes the autocovariance of samples at dimension ``axis``.

    :param numpy.ndarray x: the input array.
    :param int axis: the dimension to calculate autocovariance.
    :return: autocovariance of ``x``.
    :rtype: numpy.ndarray
    """
    return autocorrelation(x, axis) * x.var(axis=axis, keepdims=True)
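
# Usage sketch (illustrative): at lag 0 the autocovariance equals the plain
# (ddof=0) variance of the samples, since the lag-0 autocorrelation is 1.
import numpy as np
from numpyro.diagnostics import autocovariance

rng = np.random.RandomState(3)
x = rng.normal(size=1000)
assert np.allclose(autocovariance(x)[0], x.var())
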
def effective_sample_size(x):
    """
    Computes effective sample size of input ``x``, where the first dimension of
    ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.

    **References:**

    1. *Introduction to Markov Chain Monte Carlo*, Charles J. Geyer
    2. *Stan Reference Manual version 2.18*, Stan Development Team

    :param numpy.ndarray x: the input array.
    :return: effective sample size of ``x``.
    :rtype: numpy.ndarray
    """
    x = device_get(x)
    assert x.ndim >= 2
    assert x.shape[1] >= 2

    # find autocovariance for each chain at lag k
    gamma_k_c = autocovariance(x, axis=1)

    # find autocorrelation at lag k (from Stan reference)
    var_within, var_estimator = _compute_chain_variance_stats(x)
    rho_k = 1.0 - (var_within - gamma_k_c.mean(axis=0)) / var_estimator
    # correlation at lag 0 is always 1
    rho_k[0] = 1.0

    # initial positive sequence (formula 1.18 in [1]) applied for autocorrelation
    Rho_k = rho_k[:-1:2, ...] + rho_k[1::2, ...]

    # initial monotone (decreasing) sequence
    Rho_init = Rho_k[:1]
    Rho_k = np.concatenate(
        [
            Rho_init,
            np.minimum.accumulate(np.clip(Rho_k[1:, ...], a_min=0, a_max=None), axis=0),
        ],
        axis=0,
    )

    tau = -1.0 + 2.0 * Rho_k.sum(axis=0)
    n_eff = np.prod(x.shape[:2]) / tau
    return n_eff
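
# Usage sketch (illustrative): independent draws give n_eff on the order of
# the total sample count, while strong positive autocorrelation shrinks it.
import numpy as np
from numpyro.diagnostics import effective_sample_size

rng = np.random.RandomState(4)
iid = rng.normal(size=(4, 1000))
print(effective_sample_size(iid))  # roughly 4000

ar = np.zeros((4, 1000))
for t in range(1, 1000):
    ar[:, t] = 0.9 * ar[:, t - 1] + rng.normal(size=4)
print(effective_sample_size(ar))  # far below 4000
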
def hpdi(x, prob=0.90, axis=0):
    """
    Computes "highest posterior density interval" (HPDI) which is the narrowest
    interval with probability mass ``prob``.

    :param numpy.ndarray x: the input array.
    :param float prob: the probability mass of samples within the interval.
    :param int axis: the dimension to calculate hpdi.
    :return: the lower and upper bounds of the HPDI of ``x``, concatenated
        along ``axis``.
    :rtype: numpy.ndarray
    """
    x = np.swapaxes(x, axis, 0)
    sorted_x = np.sort(x, axis=0)
    mass = x.shape[0]
    index_length = int(prob * mass)
    # each candidate interval spans ``index_length`` sorted samples; pick the
    # narrowest one
    intervals_left = sorted_x[: (mass - index_length)]
    intervals_right = sorted_x[index_length:]
    intervals_length = intervals_right - intervals_left
    index_start = intervals_length.argmin(axis=0)
    index_end = index_start + index_length
    hpd_left = np.take_along_axis(sorted_x, index_start[None, ...], axis=0)
    hpd_left = np.swapaxes(hpd_left, axis, 0)
    hpd_right = np.take_along_axis(sorted_x, index_end[None, ...], axis=0)
    hpd_right = np.swapaxes(hpd_right, axis, 0)
    return np.concatenate([hpd_left, hpd_right], axis=axis)
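
# Usage sketch (illustrative): for a symmetric unimodal distribution the 90%
# HPDI roughly matches the central 90% interval, here about (-1.64, 1.64) for
# a standard normal.
import numpy as np
from numpyro.diagnostics import hpdi

rng = np.random.RandomState(5)
samples = rng.normal(size=10000)
print(hpdi(samples, prob=0.90))
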
def summary(samples, prob=0.90, group_by_chain=True):
    """
    Returns a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics displayed are mean, standard deviation, median,
    the 90% credibility interval :func:`~numpyro.diagnostics.hpdi`,
    :func:`~numpyro.diagnostics.effective_sample_size`, and
    :func:`~numpyro.diagnostics.split_gelman_rubin`.

    :param samples: a collection of input samples whose leftmost dimension is
        the chain dimension and whose second leftmost dimension is the draw
        dimension.
    :type samples: dict or numpy.ndarray
    :param float prob: the probability mass of samples within the HPDI interval.
    :param bool group_by_chain: If True, each variable in ``samples`` will be
        treated as having shape ``num_chains x num_samples x sample_shape``.
        Otherwise, the corresponding shape will be ``num_samples x sample_shape``
        (i.e. without chain dimension).
    """
    if not group_by_chain:
        samples = tree_map(lambda x: x[None, ...], samples)

    if not isinstance(samples, dict):
        samples = {
            "Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0])
        }

    summary_dict = {}
    for name, value in samples.items():
        value = device_get(value)
        value_flat = np.reshape(value, (-1,) + value.shape[2:])
        mean = value_flat.mean(axis=0)
        std = value_flat.std(axis=0, ddof=1)
        median = np.median(value_flat, axis=0)
        hpd = hpdi(value_flat, prob=prob)
        n_eff = effective_sample_size(value)
        r_hat = split_gelman_rubin(value)
        hpd_lower = "{:.1f}%".format(50 * (1 - prob))
        hpd_upper = "{:.1f}%".format(50 * (1 + prob))
        summary_dict[name] = OrderedDict(
            [
                ("mean", mean),
                ("std", std),
                ("median", median),
                (hpd_lower, hpd[0]),
                (hpd_upper, hpd[1]),
                ("n_eff", n_eff),
                ("r_hat", r_hat),
            ]
        )
    return summary_dict
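
# Usage sketch (illustrative): summarizing a dict of per-site draws with shape
# num_chains x num_samples. The site names below are made up.
import numpy as np
from numpyro.diagnostics import summary

rng = np.random.RandomState(6)
samples = {
    "mu": rng.normal(size=(4, 1000)),
    "sigma": rng.lognormal(size=(4, 1000)),
}
stats = summary(samples, prob=0.90)
print(stats["mu"]["mean"], stats["mu"]["n_eff"], stats["mu"]["r_hat"])
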