# Source code for numpyro.contrib.hsgp.approximation

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This module contains the low-rank approximation functions of the Hilbert space Gaussian process.
"""

from __future__ import annotations

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike

import numpyro
from numpyro.contrib.hsgp.laplacian import eigenfunctions, eigenfunctions_periodic
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_matern,
    diag_spectral_density_periodic,
    diag_spectral_density_rational_quadratic,
    diag_spectral_density_squared_exponential,
)
import numpyro.distributions as dist


def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:
    """
    Non-centered HSGP linear model.

    Draws standard-normal basis weights named ``"beta"`` inside a plate named
    ``"basis"``. It rescales them element-wise by ``spd`` and projects them
    onto the eigenfunctions ``phi``.

    :param Array phi: laplacian eigenfunctions
    :param Array spd: square root of the diagonal of the spectral density
    :param int m: size of the ``"basis"`` plate (number of basis functions)
    :return: ``phi @ (spd * beta)``
    :rtype: Array
    """
    standard_prior = dist.Normal(loc=0.0, scale=1.0)
    with numpyro.plate("basis", m):
        raw_weights = numpyro.sample("beta", standard_prior)
    # The spectral scaling is applied deterministically after sampling; this is
    # what makes the parameterization non-centered.
    scaled_weights = spd * raw_weights
    return phi @ scaled_weights


def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
    """
    Centered HSGP linear model.

    Draws basis weights named ``"beta"`` inside a plate named ``"basis"``.
    The weights use a zero-mean normal prior whose scale is ``spd`` itself.
    They are then projected onto the eigenfunctions ``phi``.

    :param Array phi: laplacian eigenfunctions
    :param Array spd: square root of the diagonal of the spectral density, used
        directly as the prior scale of the weights
    :param int m: size of the ``"basis"`` plate (number of basis functions)
    :return: ``phi @ beta``
    :rtype: Array
    """
    weight_prior = dist.Normal(loc=0.0, scale=spd)
    with numpyro.plate("basis", m):
        weights = numpyro.sample("beta", weight_prior)
    return phi @ weights


def linear_approximation(
    phi: Array, spd: Array, m: int, non_centered: bool = True
) -> Array:
    """
    Linear approximation formula of the Hilbert space Gaussian process.

    See Eq. (8) in [1]. The basis weights are sampled under a site named
    ``"beta"`` inside a plate named ``"basis"``. The flag chooses between two
    parameterizations:

    - non-centered: ``phi @ (spd * beta)`` with ``beta ~ Normal(0, 1)``;
    - centered: ``phi @ beta`` with ``beta ~ Normal(0, spd)``.

    **References:**

        1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param Array phi: laplacian eigenfunctions
    :param Array spd: square root of the diagonal of the spectral density evaluated at square
        root of the first `m` eigenvalues.
    :param int m: number of eigenfunctions in the approximation
    :param bool non_centered: whether to use a non-centered parameterization
    :return: The low-rank approximation linear model
    :rtype: Array
    """
    approximation = (
        _non_centered_approximation if non_centered else _centered_approximation
    )
    return approximation(phi, spd, m)


def hsgp_squared_exponential(
    x: ArrayLike,
    alpha: float,
    length: float,
    ell: float | int | list[float | int],
    m: int | list[int],
    non_centered: bool = True,
) -> Array:
    """
    Hilbert space Gaussian process approximation using the squared exponential kernel.

    The main idea of the approach is to combine the associated spectral density of the
    squared exponential kernel and the spectrum of the Dirichlet Laplacian operator to
    obtain a low-rank approximation of the Gram matrix. For more details see [1, 2].

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020).

        2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayLike x: input data
    :param float alpha: amplitude of the squared exponential kernel
    :param float length: length scale of the squared exponential kernel
    :param float | int | list[float | int] ell: positive value that parametrizes the length of the
        D-dimensional box so that the input data lies in the interval
        :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`. We expect the approximation to be
        valid within this interval
    :param int | list[int] m: number of eigenvalues to compute and include in the approximation
        for each dimension (:math:`\\left\\{1, ..., D\\right\\}`). If an integer, the same number
        of eigenvalues is computed in each dimension.
    :param bool non_centered: whether to use a non-centered parameterization. By default, it is
        set to True
    :return: the low-rank approximation linear model
    :rtype: Array
    """
    # Input of rank <= 1 is treated as a single input dimension; otherwise the
    # trailing axis indexes the D input dimensions.
    dim = jnp.shape(x)[-1] if jnp.ndim(x) > 1 else 1
    phi = eigenfunctions(x=x, ell=ell, m=m)
    spd = jnp.sqrt(
        diag_spectral_density_squared_exponential(
            alpha=alpha, length=length, ell=ell, m=m, dim=dim
        )
    )
    # The plate size is the number of basis functions actually produced (the
    # trailing axis of phi), not ``m``, which may be a per-dimension list.
    return linear_approximation(
        phi=phi, spd=spd, m=phi.shape[-1], non_centered=non_centered
    )
def hsgp_matern(
    x: ArrayLike,
    nu: float,
    alpha: float,
    length: float,
    ell: float | int | list[float | int],
    m: int | list[int],
    non_centered: bool = True,
) -> Array:
    """
    Hilbert space Gaussian process approximation using the Matérn kernel.

    The main idea of the approach is to combine the associated spectral density of the
    Matérn kernel and the spectrum of the Dirichlet Laplacian operator to obtain a
    low-rank approximation of the Gram matrix. For more details see [1, 2].

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020).

        2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayLike x: input data
    :param float nu: smoothness parameter
    :param float alpha: amplitude of the Matérn kernel
    :param float length: length scale of the Matérn kernel
    :param float | int | list[float | int] ell: positive value that parametrizes the length of the
        D-dimensional box so that the input data lies in the interval
        :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`. We expect the approximation to be
        valid within this interval
    :param int | list[int] m: number of eigenvalues to compute and include in the approximation
        for each dimension (:math:`\\left\\{1, ..., D\\right\\}`). If an integer, the same number
        of eigenvalues is computed in each dimension.
    :param bool non_centered: whether to use a non-centered parameterization. By default, it is
        set to True.
    :return: the low-rank approximation linear model
    :rtype: Array
    """
    # Input of rank <= 1 is treated as a single input dimension; otherwise the
    # trailing axis indexes the D input dimensions.
    dim = jnp.shape(x)[-1] if jnp.ndim(x) > 1 else 1
    phi = eigenfunctions(x=x, ell=ell, m=m)
    spd = jnp.sqrt(
        diag_spectral_density_matern(
            nu=nu, alpha=alpha, length=length, ell=ell, m=m, dim=dim
        )
    )
    # The plate size is the number of basis functions actually produced (the
    # trailing axis of phi), not ``m``, which may be a per-dimension list.
    return linear_approximation(
        phi=phi, spd=spd, m=phi.shape[-1], non_centered=non_centered
    )
def hsgp_rational_quadratic(
    x: ArrayLike,
    alpha: float,
    length: float,
    scale_mixture: float,
    ell: float | int | list[float | int],
    m: int | list[int],
    non_centered: bool = True,
) -> Array:
    """
    Hilbert space Gaussian process approximation using the Rational Quadratic kernel.

    The Rational Quadratic kernel can be seen as a scale mixture (an infinite sum) of
    squared exponential kernels with different length scales. As the scale mixture
    parameter approaches infinity, the kernel converges to the squared exponential kernel.

    The main idea of the approach is to combine the associated spectral density of the
    Rational Quadratic kernel and the spectrum of the Dirichlet Laplacian operator to
    obtain a low-rank approximation of the Gram matrix. For more details see [1, 2].

    .. note::

        - Due to the heavier tails of the RQ kernel compared to the Squared Exponential kernel,
          the HSGP approximation may require larger ``ell`` values (e.g., 10 instead of 5)
          for accurate results, especially for small ``scale_mixture`` values.
        - The spectral density requires ``scale_mixture > dim/2`` for the approximation at
          :math:`\\omega = 0` to be well-defined. For example, ``scale_mixture > 0.5`` for 1D,
          ``scale_mixture > 1`` for 2D, and ``scale_mixture > 1.5`` for 3D.

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020).

        2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayLike x: input data
    :param float alpha: amplitude of the Rational Quadratic kernel
    :param float length: length scale of the Rational Quadratic kernel (scalar, isotropic only)
    :param float scale_mixture: scale mixture parameter (α in the RQ kernel formula).
        Controls the relative weighting of small-scale and large-scale variations.
        As scale_mixture → ∞, the kernel converges to the squared exponential kernel.
    :param float | int | list[float | int] ell: positive value that parametrizes the length of the
        D-dimensional box so that the input data lies in the interval
        :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`. We expect the approximation to be
        valid within this interval
    :param int | list[int] m: number of eigenvalues to compute and include in the approximation
        for each dimension (:math:`\\left\\{1, ..., D\\right\\}`). If an integer, the same number
        of eigenvalues is computed in each dimension.
    :param bool non_centered: whether to use a non-centered parameterization. By default, it is
        set to True.
    :return: the low-rank approximation linear model
    :rtype: Array
    """
    # Input of rank <= 1 is treated as a single input dimension; otherwise the
    # trailing axis indexes the D input dimensions.
    dim = jnp.shape(x)[-1] if jnp.ndim(x) > 1 else 1
    phi = eigenfunctions(x=x, ell=ell, m=m)
    spd = jnp.sqrt(
        diag_spectral_density_rational_quadratic(
            alpha=alpha,
            length=length,
            scale_mixture=scale_mixture,
            ell=ell,
            m=m,
            dim=dim,
        )
    )
    # The plate size is the number of basis functions actually produced (the
    # trailing axis of phi), not ``m``, which may be a per-dimension list.
    return linear_approximation(
        phi=phi, spd=spd, m=phi.shape[-1], non_centered=non_centered
    )
def hsgp_periodic_non_centered(
    x: ArrayLike, alpha: float, length: float, w0: float, m: int
) -> Array:
    """
    Low rank approximation for the periodic squared exponential kernel in the
    non-centered parametrization.

    See Appendix B in [1]. Cosine weights are sampled as ``"beta_cos"`` inside a
    plate named ``"cos_basis"`` of size ``m``. Sine weights are sampled as
    ``"beta_sin"`` inside a plate named ``"sin_basis"`` of size ``m - 1``. Both
    sets of weights are scaled by the periodic spectral density before being
    projected onto the corresponding eigenfunctions.

    **References:**

        1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param ArrayLike x: input data
    :param float alpha: amplitude
    :param float length: length scale
    :param float w0: frequency of the periodic kernel
    :param int m: number of eigenvalues to compute and include in the approximation
    :return: the low-rank approximation linear model
    :rtype: Array
    """
    q2_diag = diag_spectral_density_periodic(alpha=alpha, length=length, m=m)
    cos_basis, sin_basis = eigenfunctions_periodic(x=x, w0=w0, m=m)

    with numpyro.plate("cos_basis", m):
        cos_weights = numpyro.sample("beta_cos", dist.Normal(0, 1))

    with numpyro.plate("sin_basis", m - 1):
        free_sin_weights = numpyro.sample("beta_sin", dist.Normal(0, 1))

    # The first sine eigenfunction is zero, so a weight on it would not
    # contribute to the approximation. Pinning it to zero (instead of sampling
    # it) keeps the model identified and avoids divergences.
    sin_weights = jnp.concatenate((jnp.array([0.0]), free_sin_weights))

    cos_component = cos_basis @ (q2_diag * cos_weights)
    sin_component = sin_basis @ (q2_diag * sin_weights)
    return cos_component + sin_component