[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

Composing Gaussian Processes with NumPyro using GPJax

GPJax is a Gaussian process (GP) library built on JAX and Equinox. In this notebook, we use GPJax’s GP components inside a NumPyro model and run MCMC over everything jointly, so the regression coefficients and the GP hyperparameters are sampled together.

The model is a semiparametric spatial regression:

\[y(\mathbf{x}) = \underbrace{\mathbf{w}^\top \mathbf{x} + b}_{\text{linear trend}} + \underbrace{f(\mathbf{x})}_{\text{GP residual}} + \epsilon\]

We’ll use NumPyro to handle the linear part, GPJax for the GP components, and NUTS to draw samples from the joint posterior.
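
Because the GP prior and the noise are both Gaussian, the latent function can be integrated out analytically at the training inputs. As a sketch, assuming the GP has zero prior mean and writing \(X\) for the matrix of stacked inputs, \(\theta\) for the kernel hyperparameters, and \(\sigma\) for the noise standard deviation (notation introduced here for illustration, not taken from the notebook), the observations are marginally distributed as

\[\mathbf{y} \mid \mathbf{w}, b, \theta, \sigma \sim \mathcal{N}\left(X\mathbf{w} + b\,\mathbf{1},\; K_\theta + \sigma^2 I\right), \qquad [K_\theta]_{ij} = k_\theta(\mathbf{x}_i, \mathbf{x}_j).\]

Under that assumption the sampler only has to explore \(\mathbf{w}\), \(b\), \(\theta\), and \(\sigma\); the function values \(f(\mathbf{x}_i)\) never appear as sample sites during inference.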

[2]:
import gpjax as gpx
import matplotlib as mpl
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.random as jr

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

jax.config.update("jax_enable_x64", True)
assert numpyro.__version__.startswith("0.21.0")

Data simulation

We sample 200 input locations uniformly at random from the square \([0, 5] \times [0, 5]\). The true signal has two parts: a linear trend \(y_{\text{lin}} = 2x_1 - x_2 + 1.5\) and a nonlinear residual \(y_{\text{res}} = \sin(x_1)\cos(x_2)\), plus Gaussian noise \(\epsilon \sim \mathcal{N}(0, 0.1^2)\).

[3]:
N = 200
key = jr.key(123)
keys = jr.split(key, 8)

X = jr.uniform(keys[0], shape=(N, 2), minval=0.0, maxval=5.0)

# True generative components
true_slope = jnp.array([2.0, -1.0])
true_intercept = 1.5
y_lin = X @ true_slope + true_intercept
y_res = jnp.sin(X[:, 0]) * jnp.cos(X[:, 1])

latent_signal = y_lin + y_res
noise_stddev = 0.1
y = latent_signal + noise_stddev * jr.normal(keys[1], shape=latent_signal.shape)

Building the GP inside the NumPyro model

GPJax parameters can be initialised with JAX arrays, so any value drawn from numpyro.sample can be passed in directly as a kernel or likelihood hyperparameter. This contrasts with GPJax versions before 0.14, in which a registration step was required; since v0.14 the integration is much simpler.

In the model below we draw the lengthscale, variance, and observation standard deviation as ordinary NumPyro sample sites and then construct the GPJax Prior, Likelihood, and Posterior from those draws on each model evaluation. The GP’s marginal log-likelihood on the residuals is added to the joint density via numpyro.factor. For predictions, gp_posterior.predict(X_new, train_data=D_resid) gives the GP’s predictive distribution at new locations and the total prediction is the linear trend plus the GP residual.

[4]:
def joint_model(X, Y, X_new=None):
    # Linear trend parameters
    slope = numpyro.sample("slope", dist.Normal(0.0, 5.0).expand([2]))
    intercept = numpyro.sample("intercept", dist.Normal(0.0, 5.0))

    # GP hyperparameters
    lengthscale = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    variance = numpyro.sample("variance", dist.LogNormal(0.0, 1.0))
    obs_stddev = numpyro.sample("obs_stddev", dist.LogNormal(0.0, 1.0))

    # GP construction with raw JAX scalars from numpyro.sample
    kernel = gpx.kernels.Matern32(
        active_dims=[0, 1], lengthscale=lengthscale, variance=variance
    )
    meanf = gpx.mean_functions.Constant()
    gp_prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=X.shape[0], obs_stddev=obs_stddev
    )
    gp_posterior = gp_prior * likelihood

    trend = X @ slope + intercept

    if Y is not None:
        # Residuals after removing the linear trend; the GP models these
        residuals = (Y - trend).reshape(-1, 1)
        D_resid = gpx.Dataset(X=X, y=residuals)
        mll = gpx.objectives.conjugate_mll(gp_posterior, D_resid)
        numpyro.factor("gp_log_lik", mll)

        if X_new is not None:
            # GP predictive distribution of the residual at the new locations
            latent_dist = gp_posterior.predict(X_new, train_data=D_resid)
            f_new = numpyro.sample("f_new", latent_dist).reshape((-1, 1))

            # Total prediction = linear trend + GP residual
            total_prediction = (X_new @ slope + intercept).reshape(-1, 1) + f_new
            numpyro.deterministic("y_pred", total_prediction)
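
To make the numpyro.factor term more concrete, the following is a minimal NumPy sketch of a Gaussian marginal log-likelihood on the residuals under a Matérn-3/2 kernel. It is not GPJax's implementation: the helper names `matern32` and `gaussian_mll` are hypothetical, the GP mean is assumed to be zero, and no jitter is added to the kernel matrix, so values may differ slightly from what the library computes.

```python
import numpy as np


def matern32(xa, xb, lengthscale, variance):
    """Matern-3/2 kernel matrix between the rows of xa and xb."""
    diff = xa[:, None, :] - xb[None, :, :]
    r = np.sqrt(np.sum(diff**2, axis=-1)) / lengthscale
    return variance * (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r)


def gaussian_mll(X, residuals, lengthscale, variance, obs_stddev):
    """log N(residuals | 0, K + obs_stddev^2 I), computed via a Cholesky factor."""
    n = X.shape[0]
    K = matern32(X, X, lengthscale, variance) + obs_stddev**2 * np.eye(n)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L, residuals)  # L^{-1} r, so alpha @ alpha = r^T K^{-1} r
    log_det_half = np.sum(np.log(np.diag(L)))  # equals 0.5 * log|K|
    return float(-0.5 * alpha @ alpha - log_det_half - 0.5 * n * np.log(2.0 * np.pi))


# Toy data mimicking the simulation above (NumPy RNG, so not the same draws)
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 5.0, size=(50, 2))
y = X @ np.array([2.0, -1.0]) + 1.5 + np.sin(X[:, 0]) * np.cos(X[:, 1])
y = y + 0.1 * rng.normal(size=50)

# Hand-picked hyperparameters (assumptions for illustration only)
hypers = dict(lengthscale=1.0, variance=1.0, obs_stddev=0.1)

# Residuals under the true trend and under a deliberately wrong (zero) trend
mll_true_trend = gaussian_mll(X, y - (X @ np.array([2.0, -1.0]) + 1.5), **hypers)
mll_zero_trend = gaussian_mll(X, y, **hypers)
print(mll_true_trend, mll_zero_trend)
```

Since the residuals depend on the slope and intercept, this single term couples the linear parameters to the GP hyperparameters; one would expect a trend that leaves smooth, small residuals to score higher than one that leaves a large linear component for the GP to absorb.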

MCMC inference

NUTS samples from the joint posterior over the linear parameters and the GP hyperparameters.

[5]:
nuts_kernel = NUTS(joint_model)
mcmc = MCMC(nuts_kernel, num_warmup=1500, num_samples=2000, num_chains=1)
mcmc.run(keys[2], X, y)
mcmc.print_summary()
sample: 100%|██████████| 3500/3500 [00:45<00:00, 76.45it/s, 15 steps of size 2.02e-01. acc. prob=0.94]

                   mean       std    median      5.0%     95.0%     n_eff     r_hat
    intercept      1.71      1.19      1.73     -0.14      3.76    881.30      1.00
  lengthscale      4.29      1.15      4.12      2.59      6.11    651.93      1.00
   obs_stddev      0.11      0.01      0.11      0.10      0.12   1723.87      1.00
     slope[0]      1.83      0.21      1.84      1.48      2.16   1142.20      1.00
     slope[1]     -0.98      0.20     -0.98     -1.29     -0.63   1461.80      1.00
     variance      1.81      1.45      1.38      0.27      3.48    709.84      1.00

Number of divergences: 0
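
As a quick sanity check, the 90% credible intervals in the summary can be compared with the values used to simulate the data. The snippet below is only a sketch: the bounds are transcribed by hand from the printed table (so they are rounded to two decimals), and the lengthscale and variance are left out because the simulated residual was not drawn from a GP and has no true value for them.

```python
# (5% bound, 95% bound, true value), copied from the printed summary above
summary = {
    "slope[0]": (1.48, 2.16, 2.0),
    "slope[1]": (-1.29, -0.63, -1.0),
    "intercept": (-0.14, 3.76, 1.5),
    # The printed lower bound is rounded to 0.10, so the true value sits at the edge
    "obs_stddev": (0.10, 0.12, 0.1),
}

# A parameter is "covered" when its true value lies inside the printed interval
covered = {
    name: lower <= truth <= upper for name, (lower, upper, truth) in summary.items()
}

for name, ok in covered.items():
    print(f"{name:>10s}: true value inside 90% interval = {ok}")
```

The intercept interval is much wider than the slope intervals, which is consistent with the constant offset being hard to separate from a GP whose lengthscale is comparable to the size of the domain.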

Inspecting the samples

Let’s look at the posterior samples for the linear coefficients and the observation noise, the parameters whose true values are known from the simulation. For each one we plot a histogram of the posterior density and the trace of the chain with matplotlib.

[6]:
samples = mcmc.get_samples()

param_info = [
    ("slope[0]", samples["slope"][:, 0], true_slope[0]),
    ("slope[1]", samples["slope"][:, 1], true_slope[1]),
    ("intercept", samples["intercept"], true_intercept),
    ("obs_stddev", samples["obs_stddev"], noise_stddev),
]

n_params = len(param_info)
fig, axes = plt.subplots(n_params, 2, figsize=(12, 2 * n_params))

for i, (name, chain, true_val) in enumerate(param_info):
    axes[i, 0].hist(
        chain, bins=40, density=True, alpha=0.7, color="C0", edgecolor="none"
    )
    if true_val is not None:
        axes[i, 0].axvline(
            true_val, color="C3", ls="--", lw=1.5, label=f"true = {true_val}"
        )
        axes[i, 0].legend(fontsize=8)
    axes[i, 0].set_ylabel(name, fontsize=10)
    axes[i, 0].set_yticks([])

    axes[i, 1].plot(chain, alpha=0.4, color="C0", lw=0.3)
    if true_val is not None:
        axes[i, 1].axhline(true_val, color="C3", ls="--", lw=1.5)

axes[0, 0].set_title("Posterior density")
axes[0, 1].set_title("Trace")
fig.suptitle("Posterior traces (dashed lines = true values)", fontweight="bold", y=1.01)
fig.tight_layout()
plt.show()
../_images/tutorials_gpjax_example_10_0.png

Posterior predictions

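Predictive generates predictions for each posterior sample, combining the linear trend with the GP residual. Here we predict at the training inputs (X_new=X) and compare the posterior mean with the noise-free signal.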

[7]:
samples = mcmc.get_samples()
predictive = Predictive(joint_model, samples, return_sites=["y_pred"])
preds = predictive(keys[3], X=X, Y=y, X_new=X)["y_pred"]
mean_pred = jnp.mean(preds, axis=0)

rmse = jnp.sqrt(jnp.mean((mean_pred.flatten() - latent_signal.flatten()) ** 2))
print(f"Joint Model RMSE (vs true signal): {rmse:.4f}")
Joint Model RMSE (vs true signal): 0.0450
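
For a single set of parameter values, the prediction step amounts to standard Gaussian conditioning on the residuals. The NumPy sketch below illustrates that calculation under stated assumptions: a zero-mean GP, a Matérn-3/2 kernel, hand-picked hyperparameters, and hypothetical helper names (`matern32`, `gp_residual_predict`). It is not the GPJax implementation of predict.

```python
import numpy as np


def matern32(xa, xb, lengthscale, variance):
    """Matern-3/2 kernel matrix between the rows of xa and xb."""
    diff = xa[:, None, :] - xb[None, :, :]
    r = np.sqrt(np.sum(diff**2, axis=-1)) / lengthscale
    return variance * (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r)


def gp_residual_predict(X, residuals, X_new, lengthscale, variance, obs_stddev):
    """Mean and covariance of the latent GP residual at X_new (zero prior mean)."""
    K_xx = matern32(X, X, lengthscale, variance) + obs_stddev**2 * np.eye(len(X))
    K_xs = matern32(X, X_new, lengthscale, variance)
    K_ss = matern32(X_new, X_new, lengthscale, variance)
    L = np.linalg.cholesky(K_xx)
    A = np.linalg.solve(L, K_xs)  # L^{-1} K_xs
    mean = A.T @ np.linalg.solve(L, residuals)  # K_sx K_xx^{-1} r
    cov = K_ss - A.T @ A  # K_ss - K_sx K_xx^{-1} K_xs
    return mean, cov


# Toy data mimicking the simulation above (NumPy RNG, so not the same draws)
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 5.0, size=(50, 2))
slope, intercept = np.array([2.0, -1.0]), 1.5
y = X @ slope + intercept + np.sin(X[:, 0]) * np.cos(X[:, 1])
y = y + 0.1 * rng.normal(size=50)

X_new = rng.uniform(0.0, 5.0, size=(5, 2))
residuals = y - (X @ slope + intercept)

# One "posterior sample" worth of parameters, fixed by hand for illustration
f_mean, f_cov = gp_residual_predict(X, residuals, X_new, 1.0, 1.0, 0.1)

# Total predictive mean = linear trend + GP residual mean
y_pred_mean = X_new @ slope + intercept + f_mean
print(y_pred_mean.shape, f_cov.shape)
```

Predictive repeats a calculation of this kind once per posterior draw, each time with that draw's slope, intercept, and hyperparameters, and averaging the resulting y_pred values gives the posterior mean prediction.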

Visualisation

We compare the true signal with the joint model’s posterior mean on a 30 × 30 grid covering the input domain. The white dots mark the training inputs.

[8]:
n_grid = 30
x1 = jnp.linspace(0, 5, n_grid)
x2 = jnp.linspace(0, 5, n_grid)
X1, X2 = jnp.meshgrid(x1, x2)
X_grid = jnp.column_stack([X1.ravel(), X2.ravel()])

y_grid_true = (X_grid @ true_slope + true_intercept) + (
    jnp.sin(X_grid[:, 0]) * jnp.cos(X_grid[:, 1])
)

preds_grid = predictive(keys[4], X=X, Y=y, X_new=X_grid)["y_pred"]
mean_pred_grid = jnp.mean(preds_grid, axis=0)

fig, axes = plt.subplots(1, 2, figsize=(12, 3.5), sharey=True)

vmin = min(y_grid_true.min(), mean_pred_grid.min())
vmax = max(y_grid_true.max(), mean_pred_grid.max())
levels = jnp.linspace(vmin, vmax, 20)

axes[0].tricontourf(
    X_grid[:, 0], X_grid[:, 1], y_grid_true, levels=levels, cmap="magma"
)
axes[0].set_title("True Signal")

axes[1].tricontourf(
    X_grid[:, 0],
    X_grid[:, 1],
    mean_pred_grid.flatten(),
    levels=levels,
    cmap="magma",
)
axes[1].set_title(f"Joint Model (RMSE: {rmse:.2f})")

fig.colorbar(
    axes[0].collections[0],
    ax=axes.tolist(),
    format=mpl.ticker.FormatStrFormatter("%d"),
)

for ax in axes:
    ax.set_xlabel("$x_1$")
    ax.scatter(X[:, 0], X[:, 1], c="white", s=5, alpha=0.3, edgecolors="none")
../_images/tutorials_gpjax_example_14_0.png