# Source code for numpyro.contrib.hsgp.laplacian

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This module contains functions for computing eigenvalues and eigenfunctions of the Laplace operator.
"""

from __future__ import annotations

import numpy as np

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike


def eigenindices(m: list[int] | int, dim: int) -> Array:
    """Returns the multi-indices of the first :math:`m^\\star` eigenvalues of the laplacian
    operator, as a :math:`D \\times m^\\star` array.

    .. math::

        m^\\star = \\prod_{i=1}^D m_i

    Column :math:`j` of the result holds the :math:`D` per-dimension indices of the
    :math:`j`-th eigenvalue; the indices along dimension :math:`i` run from 1 to :math:`m_i`.
    For more details see Eq. (10) in [1].

    **References:**

        1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
           approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

    :param list[int] | int m: The number of desired eigenvalue indices in each dimension.
        If an integer (Python ``int`` or NumPy integer), the same number of eigenvalues is
        computed in each dimension.
    :param int dim: The dimension of the space.

    :returns: A :math:`D \\times m^\\star` array of eigenvalue indices.
    :rtype: Array
    :raises ValueError: If ``m`` is a sequence whose length differs from ``dim``.

    **Examples:**

    .. code-block:: python

        >>> import jax.numpy as jnp

        >>> from numpyro.contrib.hsgp.laplacian import eigenindices

        >>> m = 10
        >>> S = eigenindices(m, 1)
        >>> assert S.shape == (1, m)
        >>> S
        Array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)

        >>> m = 10
        >>> S = eigenindices(m, 2)
        >>> assert S.shape == (2, 100)

        >>> m = [2, 2, 3]  # Riutort-Mayol et al eq (10)
        >>> S = eigenindices(m, 3)
        >>> assert S.shape == (3, 12)
        >>> S
        Array([[1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
               [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2],
               [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]], dtype=int32)

    """
    # NumPy integers (e.g. the result of np.prod or a value read from a config array)
    # are not instances of ``int`` and have no ``len``; treat them as scalars too.
    if isinstance(m, (int, np.integer)):
        m = [int(m)] * dim
    elif len(m) != dim:
        raise ValueError("The length of m must be equal to the dimension of the space.")
    # Cartesian product of 1..m_i over every dimension ("ij" indexing keeps the last
    # dimension varying fastest), flattened to (m*, D) and transposed to (D, m*).
    return (
        jnp.stack(
            jnp.meshgrid(*[jnp.arange(1, m_ + 1) for m_ in m], indexing="ij"), axis=-1
        )
        .reshape(-1, dim)
        .T
    )


def sqrt_eigenvalues(
    ell: ArrayLike | list[int | float], m: list[int] | int, dim: int
) -> Array:
    """
    Square roots of the first :math:`m^\\star` eigenvalues of the laplacian operator on
    :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]`, one row per dimension.
    See Eq. (56) in [1].

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020)

    :param int | float | list[int | float] ell: The length of the interval in each dimension divided by 2.
        If a float, the same length is used in each dimension.
    :param list[int] | int m: The number of eigenvalues to compute in each dimension.
        If an integer, the same number of eigenvalues is computed in each dimension.
    :param int dim: The dimension of the space.

    :returns: A :math:`D \\times m^\\star` array whose entry :math:`(d, j)` is
        :math:`\\pi \\, S_{d j} / (2 L_d)`, with :math:`S` the eigenvalue indices.
    :rtype: Array
    """
    half_lengths = _convert_ell(ell, dim)  # (D, 1)
    indices = eigenindices(m, dim)  # (D, m*)
    # Broadcasting (D, m*) against (D, 1) divides every index in row d by L_d.
    sqrt_lambda = indices * jnp.pi / 2 / half_lengths
    return sqrt_lambda


def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) -> Array:
    """
    The first :math:`m^\\star` eigenfunctions of the laplacian operator in
    :math:`[-L_1, L_1] \\times ... \\times [-L_D, L_D]` evaluated at values of `x`.
    See Eq. (56) in [1].
    If `x` is 1D, the problem is assumed unidimensional.
    Otherwise, the dimension of the input space is inferred as the size of the last dimension of
    `x`. Other dimensions are treated as batch dimensions.

    **Example:**

    .. code-block:: python

        >>> import jax.numpy as jnp

        >>> from numpyro.contrib.hsgp.laplacian import eigenfunctions

        >>> n = 100
        >>> m = 10

        >>> x = jnp.linspace(-1, 1, n)

        >>> basis = eigenfunctions(x=x, ell=1.2, m=m)

        >>> assert basis.shape == (n, m)

        >>> x = jnp.ones((n, 3))  # 3-dimensional input
        >>> basis = eigenfunctions(x=x, ell=1.2, m=[2, 2, 3])
        >>> assert basis.shape == (n, 12)

    **References:**

        1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression.
           Stat Comput 30, 419-446 (2020)

    :param ArrayLike x: The points at which to evaluate the eigenfunctions. Python lists are
        accepted. If `x` is 1D the problem is assumed unidimensional. Otherwise, the dimension
        of the input space is inferred as the last dimension of `x`. Other dimensions are
        treated as batch dimensions.
    :param float | list[float] ell: The length of the interval in each dimension divided by 2.
        If a float, the same length is used in each dimension.
    :param int | list[int] m: The number of eigenvalues to compute in each dimension.
        If an integer, the same number of eigenvalues is computed in each dimension.

    :returns: The first :math:`m^\\star` eigenfunctions evaluated at `x`, with shape
        ``batch_shape + (m*,)`` (``(n, m*)`` for a 1D `x` of length ``n``).
    :rtype: Array
    """
    # Convert up front: jnp.expand_dims (used on the 1-D path previously) rejects Python
    # lists, while the N-D path already converted them.
    x_ = jnp.asarray(x)
    if x_.ndim == 1:
        # A 1-D input is a set of points in a one-dimensional space.
        x_ = x_[..., None]
    dim = x_.shape[-1]
    # Every axis except the last is a batch axis.
    batch_axes = tuple(range(x_.ndim - 1))
    ell_ = _convert_ell(ell, dim)  # (D, 1)
    a = jnp.expand_dims(ell_, batch_axes)  # (1, ..., 1, D, 1)
    b = jnp.expand_dims(sqrt_eigenvalues(ell_, m, dim), batch_axes)  # (1, ..., 1, D, m*)
    # Each factor is L_d^{-1/2} * sin(sqrt(lambda_{j,d}) * (x_d + L_d)); the product over
    # the D axis (-2) leaves one value per eigenfunction j.
    return jnp.prod(jnp.sqrt(1 / a) * jnp.sin(b * (x_[..., None] + a)), axis=-2)


def eigenfunctions_periodic(x: ArrayLike, w0: float, m: int) -> tuple[Array, Array]:
    """
    Basis functions for the approximation of the periodic kernel.

    For ``k = 0, ..., m - 1`` the returned arrays hold :math:`\\cos(k w_0 x)` and
    :math:`\\sin(k w_0 x)`, stacked along a new trailing axis of size ``m``.

    :param ArrayLike x: The points at which to evaluate the eigenfunctions (a scalar or a
        one-dimensional array; Python lists are accepted).
    :param float w0: The frequency of the periodic kernel.
    :param int m: The number of eigenfunctions to compute.

    :returns: A tuple ``(cosines, sines)``, each of shape ``jnp.shape(x) + (m,)``.
    :rtype: tuple[Array, Array]
    :raises ValueError: If `x` has more than one dimension.

    .. note::
        If you want to parameterize it with respect to the period use `w0 = 2 * jnp.pi / period`.

    .. warning::
        Multidimensional inputs are not supported.
    """
    if jnp.ndim(x) > 1:
        raise ValueError(
            "Multidimensional inputs are not supported by the periodic kernel."
        )
    # jnp.asarray accepts Python lists, which jnp.expand_dims/jnp.tile reject.
    x_ = jnp.asarray(x)
    # Outer product k * w0 * x by broadcasting: (n, 1) * (m,) -> (n, m). This replaces
    # `tile(...) @ diag(arange(m))`, which cost O(n m^2) and, being a matmul, may be
    # evaluated at reduced default matmul precision on accelerators. The float32
    # frequencies keep the result dtype the same as before.
    mw0x = (w0 * x_[..., None]) * jnp.arange(m, dtype=jnp.float32)
    return jnp.cos(mw0x), jnp.sin(mw0x)


def _convert_ell(ell: float | int | list[float | int] | ArrayLike, dim: int) -> Array: """ Process the half-length of the approximation interval and return a `D \\times 1` array. If `ell` is a scalar, it is converted to a list of length dim, then transformed into an Array. :param float | int | list[float | int] | ArrayLike ell: The length of the interval in each dimension divided by 2. If a float or int, the same length is used in each dimension. :param int dim: The dimension of the space. :returns: A `D \\times 1` array of the half-lengths of the approximation interval. :rtype: Array """ ell_ = jnp.empty((dim, 1)) if isinstance(ell, float) | isinstance(ell, int): ell = jnp.array([ell] * dim)[..., None] if isinstance(ell, list): if len(ell) != dim: raise ValueError( "The length of ell must be equal to the dimension of the space." ) ell_ = jnp.array(ell)[..., None] # dim x 1 array elif isinstance(ell, Array) | isinstance(ell, np.ndarray): ell_ = jnp.array(ell) if jnp.shape(ell_) != (dim, 1): raise ValueError("ell must be a scalar or a list of length `dim`.") return ell_