Example: Sparse Regression
We demonstrate how to do (fully Bayesian) sparse linear regression using the approach described in [1]. This approach is particularly suitable for situations with many feature dimensions (large P) but not too many data points (small N). In particular, we consider a quadratic regressor of the form:
\[f(X) = \text{constant} + \sum_i \theta_i X_i + \sum_{i<j} \theta_{ij} X_i X_j + \text{observation noise}\]
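Rather than sampling the P + P(P - 1)/2 coefficients directly, the approach of [1] integrates them out and works with an equivalent Gaussian process. As a sketch of the identity behind the kernel used in the code below (ignoring the diagonal jitter and the sparsity-inducing rescaling of the inputs), giving the constant an N(0, c^2) prior, each theta_i an N(0, eta1^2) prior, and each theta_ij an N(0, eta2^2) prior induces the covariance

\[k(X, Z) = c^2 + \eta_1^2 \sum_i X_i Z_i + \eta_2^2 \sum_{i<j} X_i X_j Z_i Z_j\]

between the function values at two inputs X and Z, which is what the kernel function in the code computes.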
References:
[1] Raj Agrawal, Jonathan H. Huggins, Brian Trippe, Tamara Broderick (2019), “The Kernel Interaction Trick: Fast Bayesian Discovery of Pairwise Interactions in High Dimensions”. https://arxiv.org/abs/1905.06501
import argparse
import itertools
import os
import time
import numpy as np
from jax import vmap
import jax.numpy as jnp
import jax.random as random
from jax.scipy.linalg import cho_factor, cho_solve, solve_triangular
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
def dot(X, Z):
return jnp.dot(X, Z[..., None])[..., 0]
# The kernel that corresponds to our quadratic regressor.
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
eta1sq = jnp.square(eta1)
eta2sq = jnp.square(eta2)
k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))
k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
k3 = (eta1sq - eta2sq) * dot(X, Z)
k4 = jnp.square(c) - 0.5 * eta2sq
if X.shape == Z.shape:
k4 += jitter * jnp.eye(X.shape[0])
return k1 + k2 + k3 + k4
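# Optional sanity check (an illustrative helper added here, not used elsewhere in
# this example): up to the diagonal jitter, the kernel above reduces to the explicit
# quadratic-regressor covariance
#     c^2 + eta1^2 * <x, z> + eta2^2 * sum_{i<j} x_i x_j z_i z_j,
# i.e. the covariance induced by N(0, eta1^2) linear and N(0, eta2^2) interaction
# coefficients plus an N(0, c^2) constant. The helper verifies this numerically by
# brute force for a single pair of inputs.
def _check_kernel_identity(eta1=0.7, eta2=0.3, c=1.2, P=5, seed=0):
    rng = np.random.RandomState(seed)
    x, z = rng.randn(P), rng.randn(P)
    expected = c**2 + eta1**2 * np.dot(x, z)
    for i in range(P):
        for j in range(i + 1, P):
            expected += eta2**2 * x[i] * x[j] * z[i] * z[j]
    actual = kernel(x[None, :], z[None, :], eta1, eta2, c, jitter=0.0)[0, 0]
    assert jnp.allclose(actual, expected, rtol=1.0e-4, atol=1.0e-4)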
# Most of the model code is concerned with constructing the sparsity-inducing prior.
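# Roughly speaking, eta1 sets a global scale for the linear coefficients, the
# per-dimension scales lambda (given half-Cauchy priors) act as local shrinkage
# parameters, and kappa rescales each input dimension before it enters the kernel,
# so dimensions with small kappa contribute almost nothing to the regression.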
def model(X, Y, hypers):
S, P, N = hypers["expected_sparsity"], X.shape[1], X.shape[0]
sigma = numpyro.sample("sigma", dist.HalfNormal(hypers["alpha3"]))
phi = sigma * (S / jnp.sqrt(N)) / (P - S)
eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))
msq = numpyro.sample("msq", dist.InverseGamma(hypers["alpha1"], hypers["beta1"]))
xisq = numpyro.sample("xisq", dist.InverseGamma(hypers["alpha2"], hypers["beta2"]))
eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
lam = numpyro.sample("lambda", dist.HalfCauchy(jnp.ones(P)))
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
# compute kernel
kX = kappa * X
k = kernel(kX, kX, eta1, eta2, hypers["c"]) + sigma**2 * jnp.eye(N)
assert k.shape == (N, N)
# sample Y according to the standard Gaussian process formula
numpyro.sample(
"Y",
dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
obs=Y,
)
# Compute the mean and variance of coefficient theta_i (where i = dimension) for an
# MCMC sample of the kernel hyperparameters (eta1, xisq, ...).
# Compare to Theorem 5.1 in reference [1].
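# Concretely, the probe evaluates the GP at x = +e_i and x = -e_i (the unit vector
# along `dimension`); for a quadratic regressor the contrast (f(e_i) - f(-e_i)) / 2
# isolates theta_i, which is what the weight vector [0.5, -0.5] below extracts from
# the GP posterior mean and covariance.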
def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, sigma):
P, N = X.shape[1], X.shape[0]
probe = jnp.zeros((2, P))
probe = probe.at[:, dimension].set(jnp.array([1.0, -1.0]))
eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
kX = kappa * X
kprobe = kappa * probe
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
k_xx_inv = jnp.linalg.inv(k_xx)
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
vec = jnp.array([0.50, -0.50])
mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
mu = jnp.dot(mu, vec)
var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
var = jnp.matmul(var, vec)
var = jnp.dot(var, vec)
return mu, var
# Compute the mean and variance of coefficient theta_ij for an MCMC sample of the
# kernel hyperparameters (eta1, xisq, ...). Compare to Theorem 5.1 in reference [1].
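# Here the probe evaluates the GP at the four sign combinations (+/-1, +/-1) in
# dimensions (dim1, dim2); the contrast with weights [0.25, -0.25, -0.25, 0.25]
# cancels the constant and linear terms and isolates theta_ij.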
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, sigma):
P, N = X.shape[1], X.shape[0]
probe = jnp.zeros((4, P))
probe = probe.at[:, dim1].set(jnp.array([1.0, 1.0, -1.0, -1.0]))
probe = probe.at[:, dim2].set(jnp.array([1.0, -1.0, 1.0, -1.0]))
eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
kX = kappa * X
kprobe = kappa * probe
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
k_xx_inv = jnp.linalg.inv(k_xx)
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
vec = jnp.array([0.25, -0.25, -0.25, 0.25])
mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
mu = jnp.dot(mu, vec)
var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
var = jnp.matmul(var, vec)
var = jnp.dot(var, vec)
return mu, var
# Sample coefficients theta from the posterior for a given MCMC sample.
# The first P returned values are {theta_1, theta_2, ...., theta_P}, while
# the remaining values are {theta_ij} for i,j in the list `active_dims`,
# sorted so that i < j.
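# The probe matrix stacks the same contrasts as above: two rows per singleton
# coefficient and four rows per active pair. `vec` maps GP evaluations at the probe
# points to coefficients, so theta is Gaussian with mean vec @ (posterior mean) and
# covariance vec @ (posterior covariance) @ vec.T, from which a single sample is
# drawn via a Cholesky factor.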
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
P, N, M = X.shape[1], X.shape[0], len(active_dims)
# the total number of coefficients we return
num_coefficients = P + M * (M - 1) // 2
probe = jnp.zeros((2 * P + 2 * M * (M - 1), P))
vec = jnp.zeros((num_coefficients, 2 * P + 2 * M * (M - 1)))
start1 = 0
start2 = 0
for dim in range(P):
probe = probe.at[start1 : start1 + 2, dim].set(jnp.array([1.0, -1.0]))
vec = vec.at[start2, start1 : start1 + 2].set(jnp.array([0.5, -0.5]))
start1 += 2
start2 += 1
for dim1 in active_dims:
for dim2 in active_dims:
if dim1 >= dim2:
continue
probe = probe.at[start1 : start1 + 4, dim1].set(
jnp.array([1.0, 1.0, -1.0, -1.0])
)
probe = probe.at[start1 : start1 + 4, dim2].set(
jnp.array([1.0, -1.0, 1.0, -1.0])
)
vec = vec.at[start2, start1 : start1 + 4].set(
jnp.array([0.25, -0.25, -0.25, 0.25])
)
start1 += 4
start2 += 1
eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
kX = kappa * X
kprobe = kappa * probe
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma**2 * jnp.eye(N)
L = cho_factor(k_xx, lower=True)[0]
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
mu = jnp.sum(mu * vec, axis=-1)
Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))
# sample from N(mu, covar)
L = jnp.linalg.cholesky(covar)
sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))
return sample
# Helper function for doing HMC inference
def run_inference(model, args, rng_key, X, Y, hypers):
start = time.time()
kernel = NUTS(model)
mcmc = MCMC(
kernel,
num_warmup=args.num_warmup,
num_samples=args.num_samples,
num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(rng_key, X, Y, hypers)
mcmc.print_summary()
print("\nMCMC elapsed time:", time.time() - start)
return mcmc.get_samples()
# Get the mean and variance of a Gaussian mixture
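# (law of total variance: E[theta] = E[mu], Var[theta] = E[var] + E[mu^2] - E[mu]^2)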
def gaussian_mixture_stats(mus, variances):
mean_mu = jnp.mean(mus)
mean_var = jnp.mean(variances) + jnp.mean(jnp.square(mus)) - jnp.square(mean_mu)
return mean_mu, mean_var
# Create artificial regression dataset where only S out of P feature
# dimensions contain signal and where there is a single pairwise interaction
# between the first and second dimensions.
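# Note that Y is centered and standardized before being returned, so the expected
# coefficients W and the expected pairwise coefficient (1.0 in the generating
# process) are likewise divided by the standard deviation of Y.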
def get_data(N=20, S=2, P=10, sigma_obs=0.05):
assert S < P and P > 1 and S > 0
np.random.seed(0)
X = np.random.randn(N, P)
# generate S coefficients with non-negligible magnitude
W = 0.5 + 2.5 * np.random.rand(S)
# generate data using the S coefficients and a single pairwise interaction
Y = (
np.sum(X[:, 0:S] * W, axis=-1)
+ X[:, 0] * X[:, 1]
+ sigma_obs * np.random.randn(N)
)
Y -= jnp.mean(Y)
Y_std = jnp.std(Y)
assert X.shape == (N, P)
assert Y.shape == (N,)
return X, Y / Y_std, W / Y_std, 1.0 / Y_std
# Helper function for analyzing the posterior statistics for coefficient theta_i
def analyze_dimension(samples, X, Y, dimension, hypers):
vmap_args = (
samples["msq"],
samples["lambda"],
samples["eta1"],
samples["xisq"],
samples["sigma"],
)
mus, variances = vmap(
lambda msq, lam, eta1, xisq, sigma: compute_singleton_mean_variance(
X, Y, dimension, msq, lam, eta1, xisq, hypers["c"], sigma
)
)(*vmap_args)
mean, variance = gaussian_mixture_stats(mus, variances)
std = jnp.sqrt(variance)
return mean, std
# Helper function for analyzing the posterior statistics for coefficient theta_ij
def analyze_pair_of_dimensions(samples, X, Y, dim1, dim2, hypers):
vmap_args = (
samples["msq"],
samples["lambda"],
samples["eta1"],
samples["xisq"],
samples["sigma"],
)
mus, variances = vmap(
lambda msq, lam, eta1, xisq, sigma: compute_pairwise_mean_variance(
X, Y, dim1, dim2, msq, lam, eta1, xisq, hypers["c"], sigma
)
)(*vmap_args)
mean, variance = gaussian_mixture_stats(mus, variances)
std = jnp.sqrt(variance)
return mean, std
def main(args):
X, Y, expected_thetas, expected_pairwise = get_data(
N=args.num_data, P=args.num_dimensions, S=args.active_dimensions
)
# setup hyperparameters
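    # alpha1/beta1 and alpha2/beta2 below parameterize the InverseGamma priors on
    # msq and xisq, alpha3 is the scale of the HalfNormal prior on the observation
    # noise sigma, c is the constant offset in the kernel, and expected_sparsity is
    # the number of dimensions we expect to carry signal (it sets the scale of the
    # HalfCauchy prior on eta1); see `model` above.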
hypers = {
"expected_sparsity": max(1.0, args.num_dimensions / 10),
"alpha1": 3.0,
"beta1": 1.0,
"alpha2": 3.0,
"beta2": 1.0,
"alpha3": 1.0,
"c": 1.0,
}
# do inference
rng_key = random.PRNGKey(0)
samples = run_inference(model, args, rng_key, X, Y, hypers)
# compute the posterior mean and standard deviation of each coefficient theta_i
means, stds = vmap(lambda dim: analyze_dimension(samples, X, Y, dim, hypers))(
jnp.arange(args.num_dimensions)
)
print(
"Coefficients theta_1 to theta_%d used to generate the data:"
% args.active_dimensions,
expected_thetas,
)
print(
"The single quadratic coefficient theta_{1,2} used to generate the data:",
expected_pairwise,
)
active_dimensions = []
for dim, (mean, std) in enumerate(zip(means, stds)):
# we mark the dimension as inactive if the interval [mean - 3 * std, mean + 3 * std] contains zero
lower, upper = mean - 3.0 * std, mean + 3.0 * std
inactive = "inactive" if lower < 0.0 and upper > 0.0 else "active"
if inactive == "active":
active_dimensions.append(dim)
print(
"[dimension %02d/%02d] %s:\t%.2e +- %.2e"
% (dim + 1, args.num_dimensions, inactive, mean, std)
)
print(
"Identified a total of %d active dimensions; expected %d."
% (len(active_dimensions), args.active_dimensions)
)
# Compute the posterior mean and standard deviation of coefficients theta_ij for i, j active dimensions.
# Note that the resulting numbers are only meaningful for i != j.
if len(active_dimensions) > 0:
dim_pairs = jnp.array(
list(itertools.product(active_dimensions, active_dimensions))
)
means, stds = vmap(
lambda dim_pair: analyze_pair_of_dimensions(
samples, X, Y, dim_pair[0], dim_pair[1], hypers
)
)(dim_pairs)
for dim_pair, mean, std in zip(dim_pairs, means, stds):
dim1, dim2 = dim_pair
if dim1 >= dim2:
continue
lower, upper = mean - 3.0 * std, mean + 3.0 * std
if not (lower < 0.0 and upper > 0.0):
format_str = "Identified pairwise interaction between dimensions %d and %d: %.2e +- %.2e"
print(format_str % (dim1 + 1, dim2 + 1, mean, std))
# Draw a single sample of coefficients theta from the posterior, where we return all singleton
# coefficients theta_i and pairwise coefficients theta_ij for i, j active dimensions. We use the
# final MCMC sample obtained from the HMC sampler.
thetas = sample_theta_space(
X,
Y,
active_dimensions,
samples["msq"][-1],
samples["lambda"][-1],
samples["eta1"][-1],
samples["xisq"][-1],
hypers["c"],
samples["sigma"][-1],
)
print("Single posterior sample theta:\n", thetas)
if __name__ == "__main__":
assert numpyro.__version__.startswith("0.15.3")
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=500, type=int)
parser.add_argument("--num-chains", nargs="?", default=1, type=int)
parser.add_argument("--num-data", nargs="?", default=100, type=int)
parser.add_argument("--num-dimensions", nargs="?", default=20, type=int)
parser.add_argument("--active-dimensions", nargs="?", default=3, type=int)
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
main(args)
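The script accepts the command-line flags defined above. For example, assuming the example is saved as sparse_regression.py, a CPU run with the default settings looks like:

python sparse_regression.py --num-data 100 --num-dimensions 20 --active-dimensions 3 --num-samples 1000 --num-warmup 500 --num-chains 1 --device cpu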