Contributed Code

Nested Sampling

Nested Sampling is a non-MCMC approach that works for arbitrary probability models, and is particularly well suited to complex posteriors:

class NestedSampler(model, *, constructor_kwargs=None, termination_kwargs=None)

Bases: object

(EXPERIMENTAL) A wrapper for jaxns, a nested sampling package based on JAX.

See reference [1] for details on the meaning of each parameter. Please consider citing this reference if you use the nested sampler in your research.

Note

To enumerate over a discrete latent variable, you can add the keyword infer={"enumerate": "parallel"} to the corresponding sample statement.

Note

To improve performance, please consider enabling x64 mode at the beginning of your NumPyro program by calling numpyro.enable_x64().

References

  1. JAXNS: a high-performance nested sampling package based on JAX, Joshua G. Albert (https://arxiv.org/abs/2012.15286)

Parameters:
  • model (callable) – a callable with NumPyro primitives

  • constructor_kwargs (dict) – additional keyword arguments to construct an upstream jaxns.NestedSampler instance.

  • termination_kwargs (dict) – keyword arguments to terminate the sampler. Please refer to the upstream jaxns.NestedSampler.__call__() method.

Example

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.nested_sampling import NestedSampler

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>>
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)
>>>
>>> ns = NestedSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)
>>> samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)
>>> assert jnp.mean(jnp.abs(samples['intercept'])) < 0.05
>>> print(jnp.mean(samples['coefs'], axis=0))  
[0.93661342 1.95034876 2.86123884]

run(rng_key, *args, **kwargs)

Run the nested sampler and collect weighted samples.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.

  • args – The arguments needed by the model.

  • kwargs – The keyword arguments needed by the model.

get_samples(rng_key, num_samples)

Draws samples from the weighted samples collected from the run.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used to draw samples.

  • num_samples (int) – The number of samples.

Returns:

a dict of posterior samples

get_weighted_samples()

Gets weighted samples and their corresponding log weights.
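Conceptually, turning weighted samples into equally weighted posterior draws amounts to resampling with probabilities proportional to the weights. The sketch below illustrates that idea in plain JAX, given samples and log weights of the kind get_weighted_samples() returns. The helper name resample_weighted is hypothetical, and this is not necessarily how get_samples is implemented internally.

```python
import jax
import jax.numpy as jnp
from jax import random


def resample_weighted(rng_key, samples, log_weights, num_samples):
    """Hypothetical helper: turn weighted draws into equally weighted ones.

    samples: dict mapping site names to arrays whose leading axis indexes draws.
    log_weights: unnormalized log weights, one per draw.
    """
    # Normalize the log weights into probabilities (softmax is numerically stable).
    probs = jax.nn.softmax(log_weights)
    # Draw indices with replacement, proportionally to the weights.
    idx = random.choice(rng_key, log_weights.shape[0], shape=(num_samples,), p=probs)
    return {name: value[idx] for name, value in samples.items()}


# Toy weighted draws: the third one carries almost all of the weight.
samples = {"x": jnp.array([0.0, 1.0, 2.0, 3.0])}
log_weights = jnp.array([-50.0, -50.0, 0.0, -50.0])
draws = resample_weighted(random.PRNGKey(0), samples, log_weights, num_samples=100)
print(draws["x"].shape)  # (100,)
```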

print_summary()

Print summary of the result. This is a wrapper of jaxns.utils.summary().

diagnostics()

Plot diagnostics of the result. This is a wrapper of jaxns.plotting.plot_diagnostics() and jaxns.plotting.plot_cornerplot().

Stein Variational Inference

Stein Variational Inference (SteinVI) is a family of VI techniques for approximate Bayesian inference based on Stein’s method (see [1] for an overview). It is gaining popularity as it combines the scalability of traditional VI with the flexibility of non-parametric particle-based methods.

Stein variational gradient descent (SVGD) [2] is a recent SteinVI technique that iteratively moves a set of particles \(\{z_i\}_{i=1}^N\) to approximate a distribution p(z). As a particle-based method, SVGD is well suited to capturing correlations between latent variables. The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope of methods such as Markov chain Monte Carlo (MCMC). SVGD is also good at capturing multi-modality [3][4].
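To make the particle update concrete, here is a minimal sketch of one SVGD step from [2] in plain JAX (it does not use numpyro.contrib.einstein). Each particle is moved by a kernel-weighted average of the scores \(\nabla \log p\) (which pulls particles toward high density) plus the kernel gradient (which pushes particles apart). The Gaussian target, the RBF kernel with a median-heuristic bandwidth, the step size and the iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import random

TARGET_MEAN = jnp.array([1.0, -1.0])


def log_prob(z):
    # Hypothetical target p(z): 2-D Gaussian with mean TARGET_MEAN and identity covariance.
    return -0.5 * jnp.sum((z - TARGET_MEAN) ** 2)


@jax.jit
def svgd_step(particles, step_size=0.1):
    """One SVGD update with an RBF kernel k(x, y) = exp(-||x - y||^2 / h)."""
    n = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]  # (n, n, d), diffs[i, j] = x_i - x_j
    sq_dists = jnp.sum(diffs**2, axis=-1)                 # (n, n)
    h = jnp.median(sq_dists) / jnp.log(n) + 1e-8          # median heuristic bandwidth
    k = jnp.exp(-sq_dists / h)                            # symmetric kernel matrix
    grads = jax.vmap(jax.grad(log_prob))(particles)       # (n, d) score of each particle
    # Driving term: kernel-weighted scores pull particles toward high-density regions.
    drive = k @ grads
    # Repulsive term: sum_j grad_{x_j} k(x_j, x_i) = sum_j (2 / h) (x_i - x_j) k(x_i, x_j).
    repulse = (2.0 / h) * jnp.sum(diffs * k[:, :, None], axis=1)
    return particles + step_size * (drive + repulse) / n


# Start all particles far from the target and let SVGD transport them.
particles = 0.5 * random.normal(random.PRNGKey(0), (50, 2)) - 3.0
for _ in range(1000):
    particles = svgd_step(particles)
print(particles.mean(axis=0))  # close to TARGET_MEAN
print(particles.std(axis=0))   # stays spread out thanks to the repulsive term
```

Without the repulsive term every particle would collapse onto the mode, which is why SVGD behaves like a particle approximation rather than a set of independent MAP searches.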

numpyro.contrib.einstein is a framework for particle-based inference using the Stein mixture algorithm. The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles. Similarly to how SVGD works, Stein mixtures can approximate model posteriors by moving the Stein particles according to the Stein forces. Because the Stein particles parameterize a guide, they capture a neighborhood rather than a single point.

numpyro.contrib.einstein mimics the interface from numpyro.infer.svi, so trying SteinVI requires minimal change to the code for existing models inferred with SVI. For primary usage, see the Bayesian neural network example.

The framework currently supports several kernels, including:

  • RBFKernel

  • LinearKernel

  • RandomFeatureKernel

  • MixtureKernel

  • GraphicalKernel

  • ProbabilityProductKernel

For example usage, see the Bayesian neural network example.

References

1. Stein’s Method Meets Statistics: A Review of Some Recent Developments (2021) Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner, Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton, Christophe Ley, Qiang Liu, Lester Mackey, Chris. J. Oates, Gesine Reinert, Yvik Swan. https://arxiv.org/abs/2105.03481

2. Stein variational gradient descent: A general-purpose Bayesian inference algorithm (2016) Qiang Liu, Dilin Wang. NeurIPS

3. Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models (2019) Dilin Wang, Qiang Liu. PMLR

SteinVI Interface

class SteinVI(model: ~collections.abc.Callable, guide: ~collections.abc.Callable, optim: ~numpyro.optim._NumPyroOptim, kernel_fn: ~numpyro.contrib.einstein.stein_kernels.SteinKernel, num_stein_particles: int = 10, num_elbo_particles: int = 10, loss_temperature: float = 1.0, repulsion_temperature: float = 1.0, non_mixture_guide_params_fn: ~collections.abc.Callable[[str], bool] = <function SteinVI.<lambda>>, enum=True, **static_kwargs)

Variational inference with Stein mixtures.

Example:

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.contrib.einstein import MixtureGuidePredictive, SteinVI, RBFKernel

>>> def model(data):
...     f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
...     with numpyro.plate("N", data.shape[0] if data is not None else 10):
...         numpyro.sample("obs", dist.Bernoulli(f), obs=data)

>>> def guide(data):
...     alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
...                            constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> optimizer = numpyro.optim.Adam(step_size=0.0005)
>>> stein = SteinVI(model, guide, optimizer, kernel_fn=RBFKernel())
>>> stein_result = stein.run(random.PRNGKey(0), 2000, data)
>>> params = stein_result.params
>>> # use guide to make predictive
>>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=1000, guide_sites=stein.guide_sites)
>>> samples = predictive(random.PRNGKey(1), data=None)
Parameters:
  • model (Callable) – Python callable with Pyro primitives for the model.

  • guide – Python callable with Pyro primitives for the guide (recognition network).

  • optim (_NumPyroOptim) – An instance of _NumpyroOptim.

  • kernel_fn (SteinKernel) – Function that produces a logarithm of the statistical kernel to use with Stein mixture inference.

  • num_stein_particles – Number of particles (i.e., mixture components) in the Stein mixture.

  • num_elbo_particles – Number of Monte Carlo draws used to approximate the attractive force gradient. (More particles give better gradient approximations)

  • loss_temperature (Float) – Scaling factor of the attractive force.

  • repulsion_temperature (Float) – Scaling factor of the repulsive force (Non-linear Stein)

  • non_mixture_guide_params_fn (Callable) – Predicate on names of guide parameters that should be optimized classically, without Stein (e.g., parameters of large neural networks or other transformations).

  • static_kwargs – Static keyword arguments for the model / guide, i.e. arguments that remain constant during inference.

SteinVI Kernels

class RBFKernel(mode='norm', matrix_mode='norm_diag', bandwidth_factor: ~collections.abc.Callable[[float], float] = <function RBFKernel.<lambda>>)

Calculates the Gaussian RBF kernel function, from [1], \(k(x,y) = \exp\left(-\frac{1}{h} \|x-y\|^2\right)\), where the bandwidth h is computed using the median heuristic \(h = \frac{1}{\log(n)} \text{med}(\|x-y\|)^2\).

References:

  1. Stein Variational Gradient Descent by Liu and Wang

Parameters:
  • mode (str) – Either ‘norm’ (default) specifying to take the norm of each particle, ‘vector’ to return a component-wise kernel or ‘matrix’ to return a matrix-valued kernel

  • matrix_mode (str) – Either ‘norm_diag’ (default) for diagonal filled with the norm kernel or ‘vector_diag’ for diagonal of vector-valued kernel

  • bandwidth_factor – A multiplier to the bandwidth based on data size n (default 1/log(n))
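As an illustration of the 'norm' mode, the sketch below computes the Gram matrix of the RBF kernel over a set of flattened particles, with the bandwidth chosen by the median heuristic of Liu and Wang, \(h = \text{med}(\|x-y\|)^2 / \log(n)\). The function name rbf_kernel_matrix and the small jitter added to h are assumptions; this is not the RBFKernel implementation itself, which also supports the 'vector' and 'matrix' modes.

```python
import jax.numpy as jnp
from jax import random


def rbf_kernel_matrix(particles, bandwidth_factor=lambda n: 1.0 / jnp.log(n)):
    """Gram matrix K[i, j] = exp(-||x_i - x_j||^2 / h) over flattened particles.

    The bandwidth follows the median heuristic h = bandwidth_factor(n) * med(||x - y||)^2.
    """
    n = particles.shape[0]
    sq_dists = jnp.sum((particles[:, None, :] - particles[None, :, :]) ** 2, axis=-1)
    median_dist = jnp.median(jnp.sqrt(sq_dists))
    h = bandwidth_factor(n) * median_dist**2 + 1e-5  # jitter guards against h == 0
    return jnp.exp(-sq_dists / h)


particles = random.normal(random.PRNGKey(0), (20, 3))  # 20 particles in 3 dimensions
K = rbf_kernel_matrix(particles)
print(K.shape)  # (20, 20)
```

Because the bandwidth scales with the median pairwise distance, the kernel adapts automatically as the particles contract or spread during inference.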

class LinearKernel(mode='norm')

Calculates the linear kernel \(k(x,y) = x \cdot y + 1\) from [1].

References:

  1. Stein Variational Gradient Descent as Moment Matching by Liu and Wang

class RandomFeatureKernel(mode='norm', bandwidth_subset=None, bandwidth_factor: ~collections.abc.Callable[[float], float] = <function RandomFeatureKernel.<lambda>>)

Calculates the random kernel \(k(x,y)= 1/m\sum_{l=1}^{m}\phi(x,w_l)\phi(y,w_l)\) from [1].

References:

  1. Stein Variational Gradient Descent as Moment Matching by Liu and Wang

Parameters:
  • bandwidth_subset – How many particles should be used to calculate the bandwidth? (default None, meaning all particles)

  • random_indices – The set of indices on which to do random feature expansion. (default None, meaning all indices)

  • bandwidth_factor – A multiplier to the bandwidth based on data size n (default 1/log(n))

class MixtureKernel(ws: list[float], kernel_fns: list[SteinKernel], mode='norm')

Calculates a mixture of multiple kernels \(k(x,y) = \sum_i w_ik_i(x,y)\)

References:

  1. Stein Variational Gradient Descent as Moment Matching by Liu and Wang

Parameters:
  • ws – Weight of each kernel in the mixture

  • kernel_fns – Different kernel functions to mix together
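The mixture formula is just a weighted sum of kernel evaluations. The sketch below combines an RBF kernel and the linear kernel \(k(x,y) = x \cdot y + 1\) from LinearKernel using plain functions. The fixed bandwidth h = 1 and the helper names rbf, linear and mixture are hypothetical; the library kernels are objects that compute their bandwidth from the current particles.

```python
import jax.numpy as jnp


def rbf(x, y, h=1.0):
    # RBF kernel with a fixed, illustrative bandwidth h.
    return jnp.exp(-jnp.sum((x - y) ** 2) / h)


def linear(x, y):
    # Linear kernel k(x, y) = x . y + 1.
    return jnp.dot(x, y) + 1.0


def mixture(ws, kernel_fns):
    """Return k(x, y) = sum_i w_i k_i(x, y)."""
    def k(x, y):
        return sum(w * fn(x, y) for w, fn in zip(ws, kernel_fns))
    return k


k = mixture([0.7, 0.3], [rbf, linear])
x = jnp.array([1.0, 0.0])
y = jnp.array([0.0, 1.0])
print(k(x, y))  # 0.7 * exp(-2) + 0.3 * 1
```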

class GraphicalKernel(mode='matrix', local_kernel_fns: dict[str, ~numpyro.contrib.einstein.stein_kernels.SteinKernel] | None = None, default_kernel_fn: ~numpyro.contrib.einstein.stein_kernels.SteinKernel = <numpyro.contrib.einstein.stein_kernels.RBFKernel object>)

Calculates the graphical kernel \(k(x,y) = \operatorname{diag}(\{K_l(x_l,y_l)\})\) for local kernels \(K_l\) from [1][2].

References:

  1. Stein Variational Message Passing for Continuous Graphical Models by Wang, Zheng, and Liu

  2. Stein Variational Gradient Descent with Matrix-Valued Kernels by Wang, Tang, Bajaj, and Liu

Parameters:
  • local_kernel_fns – A mapping between parameters and a choice of kernel function for that parameter (defaults to default_kernel_fn for each parameter)

  • default_kernel_fn – The default choice of kernel function when none is specified for a particular parameter

class ProbabilityProductKernel(guide, scale=1.0)

Stochastic Support

class StochasticSupportInference(model, num_slp_samples, max_slps)

Bases: ABC

Base class for running inference in programs with stochastic support. Each subclass decomposes the input model into so-called straight-line programs (SLPs), which are the different control-flow paths in the model. Inference is then run in each SLP separately, and the results are combined to produce an overall posterior.

Note

This implementation assumes that all stochastic branching is done based on the outcomes of discrete sampling sites that are annotated with infer={"branching": True}. For example,

import numpyro
import numpyro.distributions as dist

def model():
    # Branching site: its value selects which straight-line program is executed.
    model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
    if model1 == 0:
        mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
    else:
        mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
Parameters:
  • model – Python callable containing Pyro primitives.

  • num_slp_samples (int) – Number of samples to draw from the prior to discover the straight-line programs (SLPs).

  • max_slps (int) – Maximum number of SLPs to discover. Inference will not be run on more than max_slps SLPs.

run(rng_key, *args, **kwargs)

Run inference on each SLP separately and combine the results.

Parameters:
  • rng_key (jax.random.PRNGKey) – Random number generator key.

  • args – Arguments to the model.

  • kwargs – Keyword arguments to the model.
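The SLP discovery step can be pictured as running the model forward from the prior and recording the values taken by the branching sites; each distinct tuple of values identifies one SLP. Below is a minimal pure-Python sketch of that idea using num_slp_samples prior draws and a cap of max_slps. The toy simulator sample_branching_choices (including its nested three-way branch) and the policy of keeping the most frequently visited SLPs are assumptions for illustration, not NumPyro's implementation.

```python
import random
from collections import Counter


def sample_branching_choices(rng):
    """Hypothetical prior simulator for the branching sites of a toy model.

    Returns the values of every branching site, in execution order; this tuple
    identifies one straight-line program (SLP).
    """
    model1 = int(rng.random() < 0.5)  # Bernoulli(0.5) branching site
    if model1 == 0:
        return (model1,)
    # A nested branch only reachable when model1 == 1 (hypothetical).
    model2 = rng.randrange(3)  # Categorical over 3 options
    return (model1, model2)


def discover_slps(num_slp_samples, max_slps, seed=0):
    rng = random.Random(seed)
    counts = Counter(sample_branching_choices(rng) for _ in range(num_slp_samples))
    # Keep at most max_slps SLPs, preferring those visited most often under the prior.
    return [slp for slp, _ in counts.most_common(max_slps)]


slps = discover_slps(num_slp_samples=1000, max_slps=3)
print(slps)
```

Here the toy model has four SLPs, so capping at max_slps=3 drops one of the rarer branches; inference would then only be run on the three retained SLPs.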

class DCC(model, mcmc_kwargs, kernel_cls=<class 'numpyro.infer.hmc.NUTS'>, num_slp_samples=1000, max_slps=124, proposal_scale=1.0)

Bases: StochasticSupportInference

Implements the Divide, Conquer, and Combine (DCC) algorithm for models with stochastic support from [1].

References:

  1. Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support, Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth

Example:

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.dcc import DCC

def model():
    model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
    if model1 == 0:
        mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
    else:
        mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

# Arguments forwarded to MCMC for the local inference in each SLP.
mcmc_kwargs = dict(
    num_warmup=500, num_samples=1000
)
dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
dcc_result = dcc.run(random.PRNGKey(0))
Parameters:
  • model – Python callable containing Pyro primitives.

  • mcmc_kwargs (dict) – Dictionary of arguments passed to MCMC.

  • kernel_cls (numpyro.infer.mcmc.MCMCKernel) – MCMC kernel class that is used for local inference. Defaults to NUTS.

  • num_slp_samples (int) – Number of samples to draw from the prior to discover the straight-line programs (SLPs).

  • max_slps (int) – Maximum number of SLPs to discover. DCC will not run inference on more than max_slps.

  • proposal_scale (float) – Scale parameter for the proposal distribution for estimating the normalization constant of an SLP.

class SDVI(model, optimizer, svi_num_steps=1000, combine_elbo_particles=1000, guide_init=<class 'numpyro.infer.autoguide.AutoNormal'>, loss=<numpyro.infer.elbo.Trace_ELBO object>, svi_progress_bar=False, num_slp_samples=1000, max_slps=124)

Bases: StochasticSupportInference

Implements the Support Decomposition Variational Inference (SDVI) algorithm for models with stochastic support from [1]. This implementation creates a separate guide for each SLP, trains the guides separately, and then combines the guides by weighting them proportional to their ELBO estimates.

References:

  1. Rethinking Variational Inference for Probabilistic Programs with Stochastic Support, Tim Reichelt, Luke Ong, Tom Rainforth

Example:

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.sdvi import SDVI

def model():
    model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
    if model1 == 0:
        mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
    else:
        mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

sdvi = SDVI(model, numpyro.optim.Adam(step_size=0.001))
sdvi_result = sdvi.run(random.PRNGKey(0))
Parameters:
  • model – Python callable containing Pyro primitives.

  • optimizer – An instance of _NumpyroOptim, a jax.example_libraries.optimizers.Optimizer or an Optax GradientTransformation. Gets passed to SVI.

  • svi_num_steps (int) – Number of steps to run SVI for each SLP.

  • combine_elbo_particles (int) – Number of particles to estimate ELBO for computing SLP weights.

  • guide_init – A constructor for the guide. This should be a callable that returns an AutoGuide instance. Defaults to AutoNormal.

  • loss – ELBO loss for SVI. Defaults to Trace_ELBO.

  • svi_progress_bar (bool) – Whether to use a progress bar for SVI.

  • num_slp_samples (int) – Number of samples to draw from the prior to discover the straight-line programs (SLPs).

  • max_slps (int) – Maximum number of SLPs to discover. SDVI will not run inference on more than max_slps.
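Since an ELBO is a lower bound on the log marginal likelihood of its SLP, a natural reading of "weighting the guides by their ELBO estimates" is to use weights proportional to exp(ELBO), normalized over SLPs. The sketch below computes such weights stably with the log-sum-exp trick; this interpretation and the example ELBO values are assumptions, not taken from the SDVI implementation.

```python
import math


def slp_weights(elbos):
    """Normalized weights proportional to exp(ELBO_k), computed with the log-sum-exp trick."""
    m = max(elbos)
    unnormalized = [math.exp(e - m) for e in elbos]  # subtract the max to avoid overflow
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


# Hypothetical ELBO estimates for three SLPs.
elbos = [-10.2, -12.0, -25.0]
weights = slp_weights(elbos)
print([round(w, 3) for w in weights])
```

To sample from the combined approximation, one would first pick an SLP with these probabilities and then draw from that SLP's trained guide. An ELBO gap of 13 nats, as for the third SLP here, already makes its weight negligible.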

Hilbert Space Gaussian Processes Approximation

This module contains helper functions for use in the Hilbert Space Gaussian Process (HSGP) approximation method described in [1] and [2].

Warning

This module is experimental.

Why do we need an approximation?

Gaussian processes do not scale well with the number of data points. Recall we had to invert the kernel matrix! The computational complexity of the Gaussian process model is \(\mathcal{O}(n^3)\), where \(n\) is the number of data points. The HSGP approximation method is a way to reduce the computational complexity of the Gaussian process model to \(\mathcal{O}(mn + m)\), where \(m\) is the number of basis functions used in the approximation.

Approximation Strategy Steps:

We strongly recommend reading [1] and [2] for a detailed explanation of the approximation method. In [3] you can find a practical approach using NumPyro and PyMC.

Here we provide the main steps and ingredients of the approximation method:

  1. Each stationary kernel \(k\) has an associated spectral density \(S(\omega)\). There are closed formulas for the most common kernels. These formulas depend on the hyperparameters of the kernel (e.g. amplitudes and length scales).

  2. We can approximate the spectral density \(S(\omega)\) as a polynomial series in \(||\omega||\). We call \(\omega\) the frequency.

  3. We can interpret these polynomial terms as “powers” of the Laplacian operator. The key observation is that the Fourier transform of the Laplacian operator is \(||\omega||^2\).

  4. Next, we impose Dirichlet boundary conditions on the Laplacian operator, which makes it self-adjoint with a discrete spectrum.

  5. We identify the expansion in (2) with the sum of powers of the Laplacian operator in the eigenbasis of (4).

Let \(m^\star = \prod_{d=1}^D m_d\) be the total number of terms of the approximation, where \(m_d\) is the number of basis functions used in the approximation for the \(d\)-th dimension. Then, the approximation formula, in the non-centered parameterization, is:

\[f(x) \approx \sum_{j = 1}^{m^\star} \overbrace{\color{red}{\left(S(\sqrt{\boldsymbol{\lambda}_j})\right)^{1/2}}}^{\text{all hyperparameters are here!}} \times \underbrace{\color{blue}{\phi_{j}(\boldsymbol{x})}}_{\text{easy to compute!}} \times \overbrace{\color{green}{\beta_{j}}}^{\sim \: \text{Normal}(0,1)}\]

where \(\boldsymbol{x}\) is a \(D\)-dimensional input vector, \(\boldsymbol{\lambda}_j\) are the eigenvalues of the Laplacian operator, \(\phi_{j}(\boldsymbol{x})\) are the eigenfunctions of the Laplacian operator, and \(\beta_{j}\) are the coefficients of the expansion (see Eq. (8) in [2]). We expect this to be a good approximation for a finite number \(m^\star\) of terms as long as the input values \(x\) are not too close to the boundaries \(-L_d\) and \(L_d\).
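To see the formula at work in one dimension, the sketch below builds the Dirichlet eigenpairs on \([-L, L]\), namely \(\sqrt{\lambda_j} = j\pi/(2L)\) and \(\phi_j(x) = \sin(\sqrt{\lambda_j}(x + L))/\sqrt{L}\), weights them by the squared exponential spectral density, and compares the implied covariance \(\Phi\,\mathrm{diag}(S)\,\Phi^T\) with the exact kernel \(\alpha \exp(-(x-x')^2/(2\ell^2))\). All helper names are hypothetical, alpha is treated as the kernel variance, and ell plays the role of \(L\) as in the library's parameters; this is a sketch, not the module's implementation.

```python
import jax.numpy as jnp
from jax import random


def sqrt_eigenvalues_1d(ell, m):
    # sqrt(lambda_j) = j * pi / (2 * L) for j = 1..m on [-L, L] with Dirichlet boundaries.
    return jnp.arange(1, m + 1) * jnp.pi / (2 * ell)


def eigenfunctions_1d(x, ell, m):
    # phi_j(x) = sin(sqrt(lambda_j) * (x + L)) / sqrt(L), shape (n, m).
    return jnp.sin(sqrt_eigenvalues_1d(ell, m) * (x[:, None] + ell)) / jnp.sqrt(ell)


def spectral_density_se_1d(w, alpha, length):
    # S(w) = alpha * sqrt(2 pi) * length * exp(-0.5 * length^2 * w^2).
    return alpha * jnp.sqrt(2 * jnp.pi) * length * jnp.exp(-0.5 * length**2 * w**2)


def hsgp_se_1d(x, alpha, length, ell, m, beta):
    """Non-centered f(x) ~= sum_j sqrt(S(sqrt(lambda_j))) * phi_j(x) * beta_j."""
    spd = jnp.sqrt(spectral_density_se_1d(sqrt_eigenvalues_1d(ell, m), alpha, length))
    return eigenfunctions_1d(x, ell, m) @ (spd * beta)


x = jnp.linspace(-0.5, 0.5, 50)
alpha, length, ell, m = 1.0, 0.3, 1.3, 20

# Covariance implied by the approximation: Phi diag(S) Phi^T ...
phi = eigenfunctions_1d(x, ell, m)
spd = spectral_density_se_1d(sqrt_eigenvalues_1d(ell, m), alpha, length)
approx_cov = (phi * spd) @ phi.T
# ... versus the exact squared exponential kernel.
exact_cov = alpha * jnp.exp(-((x[:, None] - x[None, :]) ** 2) / (2 * length**2))
print(jnp.max(jnp.abs(approx_cov - exact_cov)))  # small: x stays well inside [-ell, ell]

# One prior draw of f in the non-centered parameterization, beta_j ~ Normal(0, 1).
f = hsgp_se_1d(x, alpha, length, ell, m, random.normal(random.PRNGKey(0), (m,)))
```

Only the spectral density depends on the hyperparameters, so inside a model the basis phi can be computed once while alpha and length are sampled.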

Note

The periodic kernel does not fit directly into this scheme, but one can still adapt it and find a similar approximation formula. However, these kernels are not supported for multidimensional inputs. See Appendix B in [2] for more details.

Example:

Here is an example of how to use the HSGP approximation method with NumPyro. We will use the squared exponential kernel. Other kernels can be used similarly.

>>> from jax import random
>>> import jax.numpy as jnp

>>> import numpyro
>>> from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS


>>> def generate_synthetic_data(rng_key, start, stop: float, num, scale):
...     """Generate synthetic data."""
...     x = jnp.linspace(start=start, stop=stop, num=num)
...     y = jnp.sin(4 * jnp.pi * x) + jnp.sin(7 * jnp.pi * x)
...     y_obs = y + scale * random.normal(rng_key, shape=(num,))
...     return x, y_obs


>>> rng_key = random.PRNGKey(seed=42)
>>> rng_key, rng_subkey = random.split(rng_key)
>>> x, y_obs = generate_synthetic_data(
...     rng_key=rng_subkey, start=0, stop=1, num=80, scale=0.3
... )


>>> def model(x, ell, m, non_centered, y=None):
...     # --- Priors ---
...     alpha = numpyro.sample("alpha", dist.InverseGamma(concentration=12, rate=10))
...     length = numpyro.sample("length", dist.InverseGamma(concentration=6, rate=1))
...     noise = numpyro.sample("noise", dist.InverseGamma(concentration=12, rate=10))
...     # --- Parametrization ---
...     f = hsgp_squared_exponential(
...         x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered
...     )
...     # --- Likelihood ---
...     with numpyro.plate("data", x.shape[0]):
...         numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y)


>>> sampler = NUTS(model)
>>> mcmc = MCMC(sampler=sampler, num_warmup=500, num_samples=1_000, num_chains=2)

>>> rng_key, rng_subkey = random.split(rng_key)

>>> ell = 1.3
>>> m = 20
>>> non_centered = True

>>> mcmc.run(rng_subkey, x, ell, m, non_centered, y_obs)

>>> mcmc.print_summary()

              mean       std    median      5.0%     95.0%     n_eff     r_hat
   alpha      1.24      0.34      1.18      0.72      1.74   1804.01      1.00
 beta[0]     -0.10      0.66     -0.10     -1.24      0.92   1819.91      1.00
 beta[1]      0.00      0.71     -0.01     -1.09      1.26   1872.82      1.00
 beta[2]     -0.05      0.69     -0.03     -1.09      1.16   2105.88      1.00
 beta[3]      0.25      0.74      0.26     -0.98      1.42   2281.30      1.00
 beta[4]     -0.17      0.69     -0.17     -1.21      1.00   2551.39      1.00
 beta[5]      0.09      0.75      0.10     -1.13      1.30   3232.13      1.00
 beta[6]     -0.49      0.75     -0.49     -1.65      0.82   3042.31      1.00
 beta[7]      0.42      0.75      0.44     -0.78      1.65   2885.42      1.00
 beta[8]      0.69      0.71      0.71     -0.48      1.82   2811.68      1.00
 beta[9]     -1.43      0.75     -1.40     -2.63     -0.21   2858.68      1.00
beta[10]      0.33      0.71      0.33     -0.77      1.51   2198.65      1.00
beta[11]      1.09      0.73      1.11     -0.23      2.18   2765.99      1.00
beta[12]     -0.91      0.72     -0.91     -2.06      0.31   2586.53      1.00
beta[13]      0.05      0.70      0.04     -1.16      1.12   2569.59      1.00
beta[14]     -0.44      0.71     -0.44     -1.58      0.73   2626.09      1.00
beta[15]      0.69      0.73      0.70     -0.45      1.88   2626.32      1.00
beta[16]      0.98      0.74      0.98     -0.15      2.28   2282.86      1.00
beta[17]     -2.54      0.77     -2.52     -3.82     -1.29   3347.56      1.00
beta[18]      1.35      0.66      1.35      0.30      2.46   2638.17      1.00
beta[19]      1.10      0.54      1.09      0.25      2.01   2428.37      1.00
  length      0.07      0.01      0.07      0.06      0.09   2321.67      1.00
   noise      0.33      0.03      0.33      0.29      0.38   2472.83      1.00

Number of divergences: 0

Note

Additional examples with code can be found in [3], [4] and [5].

References:

  1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. Stat Comput 30, 419-446 (2020).

  2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

  3. Orduz, J., A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods.

  4. Example: Hilbert space approximation for Gaussian processes.

  5. Gelman, Vehtari, Simpson, et al., Bayesian workflow book - Birthdays.

Note

The code of this module is based on the example "Example: Hilbert space approximation for Gaussian processes" by Omar Sosa Rodríguez.

eigenindices

eigenindices(m: list[int] | int, dim: int) → ArrayImpl

Returns the indices of the first \(D \times m^\star\) eigenvalues of the laplacian operator.

\[m^\star = \prod_{i=1}^D m_i\]

For more details see Eq. (10) in [1].

References:

  1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

Parameters:
  • m (list[int] | int) – The number of desired eigenvalue indices in each dimension. If an integer, the same number of eigenvalues is computed in each dimension.

  • dim (int) – The dimension of the space.

Returns:

An array of the indices of the first \(D \times m^\star\) eigenvalues.

Return type:

ArrayImpl

Examples:

>>> import jax.numpy as jnp

>>> from numpyro.contrib.hsgp.laplacian import eigenindices

>>> m = 10
>>> S = eigenindices(m, 1)
>>> assert S.shape == (1, m)
>>> S
Array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]], dtype=int32)

>>> m = 10
>>> S = eigenindices(m, 2)
>>> assert S.shape == (2, 100)

>>> m = [2, 2, 3]  # Riutort-Mayol et al eq (10)
>>> S = eigenindices(m, 3)
>>> assert S.shape == (3, 12)
>>> S
Array([[1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
       [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2],
       [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]], dtype=int32)
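The index arrays above are the Cartesian product of the per-dimension ranges \(1, \dots, m_d\), with the last dimension varying fastest. The hypothetical re-implementation below uses itertools.product to reproduce the printed arrays; it illustrates the construction and is not the library's code.

```python
import itertools

import jax.numpy as jnp


def eigenindices_sketch(m, dim):
    """Hypothetical re-implementation: all index tuples (j_1, ..., j_D) with 1 <= j_d <= m_d."""
    ms = [m] * dim if isinstance(m, int) else list(m)
    # itertools.product varies the last dimension fastest, matching the arrays above.
    combos = itertools.product(*(range(1, m_d + 1) for m_d in ms))
    return jnp.array(list(combos), dtype=jnp.int32).T  # shape (D, m_star)


S = eigenindices_sketch([2, 2, 3], 3)
print(S.shape)  # (3, 12)
```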

sqrt_eigenvalues

sqrt_eigenvalues(ell: int | float | list[int | float], m: list[int] | int, dim: int) → ArrayImpl

The first \(m^\star \times D\) square roots of the eigenvalues of the Laplacian operator in \([-L_1, L_1] \times ... \times [-L_D, L_D]\). See Eq. (56) in [1].

References:

  1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. Stat Comput 30, 419-446 (2020)

Parameters:
  • ell (int | float | list[int | float]) – The length of the interval in each dimension divided by 2. If a float, the same length is used in each dimension.

  • m (list[int] | int) – The number of eigenvalues to compute in each dimension. If an integer, the same number of eigenvalues is computed in each dimension.

  • dim (int) – The dimension of the space.

Returns:

An array of the first \(m^\star \times D\) square roots of the eigenvalues.

Return type:

ArrayImpl

eigenfunctions

eigenfunctions(x: ArrayImpl, ell: float | list[float], m: int | list[int]) → ArrayImpl

The first \(m^\star\) eigenfunctions of the laplacian operator in \([-L_1, L_1] \times ... \times [-L_D, L_D]\) evaluated at values of x. See Eq. (56) in [1]. If x is 1D, the problem is assumed unidimensional. Otherwise, the dimension of the input space is inferred as the size of the last dimension of x. Other dimensions are treated as batch dimensions.

Example:

>>> import jax.numpy as jnp

>>> from numpyro.contrib.hsgp.laplacian import eigenfunctions

>>> n = 100
>>> m = 10

>>> x = jnp.linspace(-1, 1, n)

>>> basis = eigenfunctions(x=x, ell=1.2, m=m)

>>> assert basis.shape == (n, m)

>>> x = jnp.ones((n, 3))  # 3-dimensional input (D = 3)
>>> basis = eigenfunctions(x=x, ell=1.2, m=[2, 2, 3])
>>> assert basis.shape == (n, 12)

References:

  1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. Stat Comput 30, 419-446 (2020)

Parameters:
  • x (ArrayImpl) – The points at which to evaluate the eigenfunctions. If x is 1D the problem is assumed unidimensional. Otherwise, the dimension of the input space is inferred as the last dimension of x. Other dimensions are treated as batch dimensions.

  • ell (float | list[float]) – The length of the interval in each dimension divided by 2. If a float, the same length is used in each dimension.

  • m (int | list[int]) – The number of eigenvalues to compute in each dimension. If an integer, the same number of eigenvalues is computed in each dimension.

Returns:

An array of the first \(m^\star\) eigenfunctions evaluated at x.

Return type:

ArrayImpl
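In \(D\) dimensions, each eigenfunction is a product of one-dimensional sine factors, one per input dimension, with \(\sqrt{\lambda_{j_d}} = j_d\pi/(2L_d)\). The sketch below follows that tensor-product construction, reproduces the (n, 12) shape from the example above, and checks numerically that the one-dimensional eigenfunctions are orthonormal on \([-L, L]\). The name eigenfunctions_nd is hypothetical and, unlike the library function, it expects 1D inputs with shape (n, 1).

```python
import itertools

import jax.numpy as jnp


def eigenfunctions_nd(x, ell, m):
    """Sketch of tensor-product Laplacian eigenfunctions on prod_d [-L_d, L_d].

    phi_j(x) = prod_d sin(sqrt(lambda_{j_d}) * (x_d + L_d)) / sqrt(L_d),
    with sqrt(lambda_{j_d}) = j_d * pi / (2 * L_d). x has shape (..., D).
    """
    dim = x.shape[-1]
    ells = jnp.broadcast_to(jnp.asarray(ell, dtype=x.dtype), (dim,))
    ms = [m] * dim if isinstance(m, int) else list(m)
    # Index tuples (j_1, ..., j_D), last dimension varying fastest: shape (D, m_star).
    idx = jnp.array(list(itertools.product(*(range(1, k + 1) for k in ms)))).T
    sqrt_eig = idx * jnp.pi / (2 * ells[:, None])  # (D, m_star)
    # One 1D factor per input dimension: shape (..., D, m_star).
    factors = jnp.sin(sqrt_eig * (x[..., :, None] + ells[:, None])) / jnp.sqrt(ells[:, None])
    return jnp.prod(factors, axis=-2)  # (..., m_star)


# Shape check mirroring the example above: 3-dimensional inputs, m = [2, 2, 3].
basis = eigenfunctions_nd(jnp.ones((100, 3)), ell=1.2, m=[2, 2, 3])
print(basis.shape)  # (100, 12)

# Orthonormality in 1D: the integral over [-L, L] of phi_i * phi_j is approximately delta_ij.
ell = 1.2
grid = jnp.linspace(-ell, ell, 2001)[:, None]  # 1D inputs passed as shape (n, 1)
phi = eigenfunctions_nd(grid, ell=ell, m=5)    # (2001, 5)
gram = phi.T @ phi * (2 * ell / 2000)          # Riemann sum with spacing 2L / 2000
```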

eigenfunctions_periodic

eigenfunctions_periodic(x: ArrayImpl, w0: float, m: int)

Basis functions for the approximation of the periodic kernel.

Parameters:
  • x (ArrayImpl) – The points at which to evaluate the eigenfunctions.

  • w0 (float) – The frequency of the periodic kernel.

  • m (int) – The number of eigenfunctions to compute.

Note

If you want to parameterize it with respect to the period, use w0 = 2 * jnp.pi / period.

Warning

Multidimensional inputs are not supported.

spectral_density_squared_exponential

spectral_density_squared_exponential(dim: int, w: ArrayImpl, alpha: float, length: float | ArrayImpl) → float

Spectral density of the squared exponential kernel.

See Section 4.2 in [1] and Section 2.1 in [2].

\[S(\boldsymbol{\omega}) = \alpha (\sqrt{2\pi})^D \ell^D \exp\left(-\frac{1}{2} \ell^2 \boldsymbol{\omega}^{T} \boldsymbol{\omega}\right)\]

References:

  1. Rasmussen, C. E., & Williams, C. K. I. (2006). Gaussian Processes for Machine Learning.

  2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

Parameters:
  • dim (int) – dimension

  • w (ArrayImpl) – frequency

  • alpha (float) – amplitude

  • length (float) – length scale

Returns:

spectral density value

Return type:

float
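A quick way to sanity-check the formula is the Fourier-pair identity \(k(0) = \frac{1}{(2\pi)^D}\int S(\boldsymbol{\omega})\,d\boldsymbol{\omega}\): with this normalization, integrating the density recovers \(\alpha\), the kernel variance. The sketch below evaluates the formula with an isotropic (scalar) length scale and checks the identity numerically in 1D. The function name and the integration grid are assumptions; the library function also accepts an array of length scales.

```python
import jax.numpy as jnp


def spectral_density_se(dim, w, alpha, length):
    """S(w) = alpha * (sqrt(2 pi))^D * length^D * exp(-0.5 * length^2 * w^T w).

    w has shape (..., D); length is a scalar here (isotropic kernel).
    """
    c = alpha * (jnp.sqrt(2 * jnp.pi) * length) ** dim
    return c * jnp.exp(-0.5 * length**2 * jnp.sum(w**2, axis=-1))


# In 1D, (1 / 2 pi) * integral of S(w) dw should recover k(0) = alpha.
alpha, length = 2.0, 0.5
w = jnp.linspace(-50.0, 50.0, 20001)[:, None]  # frequencies, shape (n, 1)
dw = 100.0 / 20000                             # grid spacing
S = spectral_density_se(1, w, alpha, length)
total = jnp.sum(S) * dw / (2 * jnp.pi)
print(total)  # approximately alpha
```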

spectral_density_matern

spectral_density_matern(dim: int, nu: float, w: ArrayImpl, alpha: float, length: float | ArrayImpl) → float

Spectral density of the Matérn kernel.

See Eq. (4.15) in [1] and Section 2.1 in [2].

\[S(\boldsymbol{\omega}) = \alpha \frac{2^{D} \pi^{D/2} \Gamma(\nu + D/2) (2 \nu)^{\nu}}{\Gamma(\nu) \ell^{2 \nu}} \left(\frac{2 \nu}{\ell^2} + \boldsymbol{\omega}^{T} \boldsymbol{\omega}\right)^{-\nu - D/2}\]

References:

  1. Rasmussen, C. E., & Williams, C. K. I. (2006). Gaussian Processes for Machine Learning.

  2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

Parameters:
  • dim (int) – dimension

  • nu (float) – smoothness

  • w (ArrayImpl) – frequency

  • alpha (float) – amplitude

  • length (float) – length scale

Returns:

spectral density value

Return type:

float
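The Matérn density involves Gamma functions and large powers of the length scale, so a direct implementation is more robust in log space. The sketch below follows the formula above with a scalar length scale, using jax.scipy.special.gammaln, and checks that \(\frac{1}{2\pi}\int S(\omega)\,d\omega\) recovers \(\alpha\) for \(\nu = 3/2\) in 1D. The function name, the log-space formulation and the integration grid are assumptions.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def spectral_density_matern(dim, nu, w, alpha, length):
    """Matérn spectral density with a scalar length scale; w has shape (..., D)."""
    # Log of the constant 2^D pi^(D/2) Gamma(nu + D/2) (2 nu)^nu / (Gamma(nu) length^(2 nu)).
    log_c = (
        dim * jnp.log(2.0)
        + 0.5 * dim * jnp.log(jnp.pi)
        + gammaln(nu + 0.5 * dim)
        + nu * jnp.log(2.0 * nu)
        - gammaln(nu)
        - 2.0 * nu * jnp.log(length)
    )
    base = 2.0 * nu / length**2 + jnp.sum(w**2, axis=-1)
    return alpha * jnp.exp(log_c) * base ** (-nu - 0.5 * dim)


# In 1D, (1 / 2 pi) * integral of S(w) dw should recover k(0) = alpha.
w = jnp.linspace(-500.0, 500.0, 200001)[:, None]  # wide grid: tails decay like w^(-4)
dw = 1000.0 / 200000
S = spectral_density_matern(1, 1.5, w, 1.0, 0.5)
total = jnp.sum(S) * dw / (2 * jnp.pi)
print(total)  # approximately 1.0
```

The heavier polynomial tails of the Matérn density, compared with the Gaussian tails of the squared exponential, are why Matérn kernels generally need more basis functions in the HSGP approximation.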

diag_spectral_density_squared_exponential

diag_spectral_density_squared_exponential(alpha: float, length: float | list[float], ell: float | int | list[float | int], m: int | list[int], dim: int) → ArrayImpl

Evaluates the spectral density of the squared exponential kernel at the first \(D \times m^\star\) square root eigenvalues of the laplacian operator in \([-L_1, L_1] \times ... \times [-L_D, L_D]\).

Parameters:
  • alpha (float) – amplitude of the squared exponential kernel

  • length (float) – length scale of the squared exponential kernel

  • ell (float | int | list[float | int]) – The length of the interval divided by 2 in each dimension. If a float or int, the same length is used in each dimension.

  • m (int | list[int]) – The number of eigenvalues to compute for each dimension. If an integer, the same number of eigenvalues is computed in each dimension.

  • dim (int) – The dimension of the space

Returns:

spectral density vector evaluated at the first \(D \times m^\star\) square root eigenvalues

Return type:

ArrayImpl

diag_spectral_density_matern

diag_spectral_density_matern(nu: float, alpha: float, length: float, ell: float | int | list[float | int], m: int | list[int], dim: int) → ArrayImpl

Evaluates the spectral density of the Matérn kernel at the first \(D \times m^\star\) square root eigenvalues of the laplacian operator in \([-L_1, L_1] \times ... \times [-L_D, L_D]\).

Parameters:
  • nu (float) – smoothness parameter

  • alpha (float) – amplitude of the Matérn kernel

  • length (float) – length scale of the Matérn kernel

  • ell (float | int | list[float | int]) – The length of the interval divided by 2 in each dimension. If a float or int, the same length is used in each dimension.

  • m (int | list[int]) – The number of eigenvalues to compute for each dimension. If an integer, the same number of eigenvalues is computed in each dimension.

  • dim (int) – The dimension of the space

Returns:

spectral density vector evaluated at the first \(D \times m^\star\) square root eigenvalues

Return type:

ArrayImpl

diag_spectral_density_periodic

diag_spectral_density_periodic(alpha: float, length: float, m: int) → ArrayImpl

This is not actually a spectral density, but it is used in the same way: it returns the first m coefficients of the low-rank approximation of the periodic kernel. See Appendix B in [1].

References:

  1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

Parameters:
  • alpha (float) – amplitude

  • length (float) – length scale

  • m (int) – number of eigenvalues

Returns:

“spectral density” vector

Return type:

ArrayImpl
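A standard expansion behind such coefficients writes the periodic kernel \(k(\tau) = \alpha \exp\left(-2 \sin^2(\omega_0 \tau / 2) / \ell^2\right)\) as a cosine series \(\sum_j c_j \cos(j \omega_0 \tau)\), with \(c_0 = \alpha e^{-a} I_0(a)\), \(c_j = 2 \alpha e^{-a} I_j(a)\) for \(j \ge 1\), and \(a = \ell^{-2}\). Here \(I_j\) is the modified Bessel function of the first kind. The NumPy sketch below computes these coefficients. It assumes alpha is the kernel variance (NumPyro's scaling may differ) and evaluates \(e^{-a} I_j(a)\) through its integral representation. periodic_coefficients is a hypothetical name.

```python
import numpy as np


def periodic_coefficients(alpha, length, m, n_quad=512):
    """First m cosine-series coefficients c_0, ..., c_{m-1} of the periodic kernel.

    Assumes alpha * exp(-2 sin^2(w0 tau / 2) / length^2) = sum_j c_j cos(j w0 tau),
    with c_0 = alpha e^{-a} I_0(a), c_j = 2 alpha e^{-a} I_j(a) and a = length^-2.
    """
    a = length ** (-2)
    j = np.arange(m)
    # e^{-a} I_j(a) = (1 / 2 pi) * integral over one period of exp(a (cos t - 1)) cos(j t) dt.
    # The integrand is smooth and periodic, so an equispaced mean is very accurate.
    t = 2 * np.pi * np.arange(n_quad) / n_quad
    scaled_bessel = np.mean(np.exp(a * (np.cos(t) - 1)) * np.cos(np.outer(j, t)), axis=1)
    return alpha * np.where(j > 0, 2.0, 1.0) * scaled_bessel


c = periodic_coefficients(alpha=2.0, length=0.7, m=30)
print(c.sum())  # close to alpha: at tau = 0 the cosine series recovers k(0)
```

The coefficients depend only on alpha and length, not on \(\omega_0\), which is consistent with the signature above. If SciPy is available, scipy.special.ive(j, a) returns the same exponentially scaled Bessel values directly.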

hsgp_squared_exponential

hsgp_squared_exponential(x: ArrayImpl, alpha: float, length: float, ell: float | int | list[float | int], m: int | list[int], non_centered: bool = True) → ArrayImpl

Hilbert space Gaussian process approximation using the squared exponential kernel.

The main idea of the approach is to combine the associated spectral density of the squared exponential kernel and the spectrum of the Dirichlet Laplacian operator to obtain a low-rank approximation of the Gram matrix. For more details see [1, 2].

References:

  1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. Stat Comput 30, 419-446 (2020).

  2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

Parameters:
  • x (ArrayImpl) – input data

  • alpha (float) – amplitude of the squared exponential kernel

  • length (float) – length scale of the squared exponential kernel

  • ell (float | int | list[float | int]) – positive value that parametrizes the size of the D-dimensional box so that the input data lie in \([-L_1, L_1] \times ... \times [-L_D, L_D]\). The approximation is expected to be valid only within this box.

  • m (int | list[int]) – number of eigenvalues to compute and include in the approximation for each dimension (\(\left\{1, ..., D\right\}\)). If an integer, the same number of eigenvalues is computed in each dimension.

  • non_centered (bool) – whether to use a non-centered parameterization. Defaults to True.

Returns:

the low-rank approximation linear model

Return type:

ArrayImpl
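The squared exponential case can be sketched in one dimension with NumPy. The sketch assumes the eigenfunctions \(\phi_j(x) = L^{-1/2} \sin\left(\sqrt{\lambda_j}(x + L)\right)\) with \(\sqrt{\lambda_j} = j\pi/(2L)\). Under that assumption, the non-centered approximation is \(f = \Phi\, (\sqrt{S} \odot \beta)\) with \(\beta \sim \mathcal{N}(0, I)\), which implies the Gram matrix \(\Phi \operatorname{diag}(S) \Phi^\top\). The function names are hypothetical, alpha is assumed to be the kernel variance, and beta is drawn with a NumPy generator instead of a numpyro.sample site.

```python
import numpy as np


def dirichlet_eigenfunctions(x, ell, m):
    """phi_j(x) = sin(sqrt(lambda_j) (x + L)) / sqrt(L), sqrt(lambda_j) = j pi / (2 L), j = 1..m."""
    sqrt_lam = np.arange(1, m + 1) * np.pi / (2 * ell)
    phi = np.sin(np.outer(x + ell, sqrt_lam)) / np.sqrt(ell)  # shape (n, m)
    return phi, sqrt_lam


def hsgp_se_sketch(x, alpha, length, ell, m, rng):
    """One non-centered draw f = Phi @ (sqrt(S) * beta), beta ~ N(0, I), in one dimension."""
    phi, sqrt_lam = dirichlet_eigenfunctions(x, ell, m)
    # Assumed kernel k(r) = alpha exp(-r^2 / (2 length^2)), whose spectral density is
    # S(w) = alpha sqrt(2 pi) length exp(-length^2 w^2 / 2).
    spd = alpha * np.sqrt(2 * np.pi) * length * np.exp(-0.5 * (length * sqrt_lam) ** 2)
    beta = rng.standard_normal(m)  # stands in for a standard normal sample site
    f = phi @ (np.sqrt(spd) * beta)
    gram = (phi * spd) @ phi.T  # covariance implied by the linear model
    return f, gram


rng = np.random.default_rng(0)
x = np.linspace(-1.0, 1.0, 50)  # data well inside [-L, L] with L = 2
f, gram = hsgp_se_sketch(x, alpha=1.0, length=0.5, ell=2.0, m=30, rng=rng)
exact = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.5**2)
print(np.abs(gram - exact).max())  # small: the low-rank Gram matrix tracks the exact kernel
```

The close agreement relies on the data staying a few length scales away from \(\pm L\). The error grows as points approach the boundary of the box.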

hsgp_matern

hsgp_matern(x: ArrayImpl, nu: float, alpha: float, length: float, ell: float | int | list[float | int], m: int | list[int], non_centered: bool = True)

Hilbert space Gaussian process approximation using the Matérn kernel.

The main idea of the approach is to combine the associated spectral density of the Matérn kernel and the spectrum of the Dirichlet Laplacian operator to obtain a low-rank approximation of the Gram matrix. For more details see [1, 2].

References:

  1. Solin, A., Särkkä, S. Hilbert space methods for reduced-rank Gaussian process regression. Stat Comput 30, 419-446 (2020).

  2. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

Parameters:
  • x (ArrayImpl) – input data

  • nu (float) – smoothness parameter

  • alpha (float) – amplitude of the Matérn kernel

  • length (float) – length scale of the Matérn kernel

  • ell (float | int | list[float | int]) – positive value that parametrizes the size of the D-dimensional box so that the input data lie in \([-L_1, L_1] \times ... \times [-L_D, L_D]\). The approximation is expected to be valid only within this box.

  • m (int | list[int]) – number of eigenvalues to compute and include in the approximation for each dimension (\(\left\{1, ..., D\right\}\)). If an integer, the same number of eigenvalues is computed in each dimension.

  • non_centered (bool) – whether to use a non-centered parameterization. Defaults to True.

Returns:

the low-rank approximation linear model

Return type:

ArrayImpl
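For the Matérn kernel the Gram matrix is approximated as \(\Phi \operatorname{diag}(S) \Phi^\top\), with the Matérn spectral density \(S\) evaluated at the square root eigenvalues. The NumPy sketch below checks the one-dimensional nu = 5/2 case against the closed-form Matérn-5/2 kernel. It rests on several assumptions: alpha is the kernel variance, the eigenfunctions are the Dirichlet sines \(L^{-1/2} \sin\left(j\pi (x + L)/(2L)\right)\) on \([-L, L]\), and the spectral density uses the angular-frequency convention. All helper names are hypothetical.

```python
import math

import numpy as np


def matern52_kernel(r, alpha, length):
    """Matérn-5/2 kernel with variance k(0) = alpha."""
    s = math.sqrt(5) * np.abs(r) / length
    return alpha * (1 + s + s**2 / 3) * np.exp(-s)


def matern_sd_1d(w, nu, alpha, length):
    """One-dimensional Matérn spectral density (Fourier pair of the kernel above when nu = 5/2)."""
    const = (
        alpha * 2 * math.sqrt(math.pi) * math.gamma(nu + 0.5) * (2 * nu) ** nu
        / (math.gamma(nu) * length ** (2 * nu))
    )
    return const * (2 * nu / length**2 + w**2) ** (-(nu + 0.5))


def hsgp_matern_gram_1d(x, nu, alpha, length, ell, m):
    """Low-rank Gram matrix Phi diag(S(sqrt(lambda))) Phi^T on [-L, L]."""
    sqrt_lam = np.arange(1, m + 1) * np.pi / (2 * ell)
    phi = np.sin(np.outer(x + ell, sqrt_lam)) / np.sqrt(ell)  # shape (n, m)
    return (phi * matern_sd_1d(sqrt_lam, nu, alpha, length)) @ phi.T


x = np.linspace(-1.0, 1.0, 40)
approx = hsgp_matern_gram_1d(x, nu=2.5, alpha=1.0, length=0.3, ell=2.0, m=60)
exact = matern52_kernel(x[:, None] - x[None, :], alpha=1.0, length=0.3)
print(np.abs(approx - exact).max())
```

The Matérn spectral density decays polynomially rather than exponentially. For that reason this sketch uses more basis functions (m = 60) than a comparable squared exponential example would need.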

hsgp_periodic_non_centered

hsgp_periodic_non_centered(x: ArrayImpl, alpha: float, length: float, w0: float, m: int) → ArrayImpl

Low-rank approximation of the periodic squared exponential kernel in the non-centered parametrization.

See Appendix B in [1].

References:

  1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

Parameters:
  • x (ArrayImpl) – input data

  • alpha (float) – amplitude

  • length (float) – length scale

  • w0 (float) – frequency of the periodic kernel

  • m (int) – number of eigenvalues to compute and include in the approximation

Returns:

the low-rank approximation linear model

Return type:

ArrayImpl
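A NumPy sketch of the periodic construction follows. Each coefficient \(c_j\) weights a cosine feature and a sine feature at frequency \(j \omega_0\). With independent standard normal weights, the implied covariance is \(\sum_j c_j \cos\left(j \omega_0 (x - x')\right)\), which is the cosine series of the periodic kernel. Two parts of the sketch are assumptions for illustration, not the NumPyro source. The first is the coefficient formula: a Bessel expansion with \(a = \ell^{-2}\) and alpha taken as the kernel variance. The second is the set of helper names.

```python
import numpy as np


def periodic_coefficients(alpha, length, m, n_quad=512):
    """c_0 = alpha e^{-a} I_0(a), c_j = 2 alpha e^{-a} I_j(a), a = length^-2 (Bessel via quadrature)."""
    a = length ** (-2)
    j = np.arange(m)
    t = 2 * np.pi * np.arange(n_quad) / n_quad
    scaled_bessel = np.mean(np.exp(a * (np.cos(t) - 1)) * np.cos(np.outer(j, t)), axis=1)
    return alpha * np.where(j > 0, 2.0, 1.0) * scaled_bessel


def hsgp_periodic_sketch(x, alpha, length, w0, m, rng):
    """One non-centered draw from the low-rank periodic model, plus its implied Gram matrix."""
    c = periodic_coefficients(alpha, length, m)
    sqrt_c = np.sqrt(np.maximum(c, 0.0))  # clip tiny negative round-off in high-order terms
    angles = np.outer(x, w0 * np.arange(m))  # shape (n, m): j * w0 * x
    cos_feat, sin_feat = np.cos(angles), np.sin(angles)
    beta_cos = rng.standard_normal(m)
    beta_sin = rng.standard_normal(m)  # the j = 0 sine feature is identically zero
    f = cos_feat @ (sqrt_c * beta_cos) + sin_feat @ (sqrt_c * beta_sin)
    # Implied covariance: sum_j c_j cos(j w0 (x - x')).
    gram = (cos_feat * c) @ cos_feat.T + (sin_feat * c) @ sin_feat.T
    return f, gram


rng = np.random.default_rng(1)
x = np.linspace(0.0, 10.0, 30)
f, gram = hsgp_periodic_sketch(x, alpha=1.5, length=0.8, w0=2.0, m=25, rng=rng)
```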