Bayesian Hierarchical Stacking: Well Switching Case Study
Table of Contents
Intro
Suppose you have just fit 6 models to a dataset, and need to choose which one to use to make predictions on your test set. How do you choose which one to use? A couple of common tactics are: - choose the best model based on cross-validation; - average the models, using weights based on cross-validation scores.
In the paper Bayesian hierarchical stacking: Some models are (somewhere) useful, a new technique is introduced: average models based on weights which are allowed to vary across according to the input data, based on a hierarchical structure.
Here, we’ll implement the first case study from that paper - readers are nonetheless encouraged to look at the original paper to find other cases studies, as well as theoretical results. Code from the article (in R / Stan) can be found here.
[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[2]:
import os
import arviz as az
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import BSpline
import seaborn as sns
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
plt.style.use("seaborn")
if "NUMPYRO_SPHINXBUILD" in os.environ:
set_matplotlib_formats("svg")
numpyro.set_host_device_count(4)
assert numpyro.__version__.startswith("0.16.0")
[3]:
%matplotlib inline
1. Exploratory Data Analysis
The data we have to work with looks at households in Bangladesh, some of which were affected by high levels of arsenic in their water. Would affected households want to switch to a neighbour’s well?
We’ll split the data into a train and test set, and then we’ll train six different models to try to predict whether households would switch wells. Then, we’ll see how we can stack them when predicting on the test set!
But first, let’s load it in and visualise it! Each row represents a household, and the features we have available to us are:
switch: whether a household switched to another well;
arsenic: level of arsenic in drinking water;
educ: level of education of “head of household”;
dist100: distance to nearest safe-drinking well;
assoc: whether the household participates in any community activities.
[4]:
wells = pd.read_csv(
"http://stat.columbia.edu/~gelman/arm/examples/arsenic/wells.dat", sep=" "
)
[5]:
wells.head()
[5]:
switch | arsenic | dist | assoc | educ | |
---|---|---|---|---|---|
1 | 1 | 2.36 | 16.826000 | 0 | 0 |
2 | 1 | 0.71 | 47.321999 | 0 | 0 |
3 | 0 | 2.07 | 20.966999 | 0 | 10 |
4 | 1 | 1.15 | 21.486000 | 0 | 12 |
5 | 1 | 1.10 | 40.874001 | 1 | 14 |
[6]:
fig, ax = plt.subplots(2, 2, figsize=(12, 6))
fig.suptitle("Target variable plotted against various predictors")
sns.scatterplot(data=wells, x="arsenic", y="switch", ax=ax[0][0])
sns.scatterplot(data=wells, x="dist", y="switch", ax=ax[0][1])
sns.barplot(
data=wells.groupby("assoc")["switch"].mean().reset_index(),
x="assoc",
y="switch",
ax=ax[1][0],
)
ax[1][0].set_ylabel("Proportion switch")
sns.barplot(
data=wells.groupby("educ")["switch"].mean().reset_index(),
x="educ",
y="switch",
ax=ax[1][1],
)
ax[1][1].set_ylabel("Proportion switch");
Next, we’ll choose 200 observations to be part of our train set, and 1500 to be part of our test set.
[7]:
np.random.seed(1)
train_id = wells.sample(n=200).index
test_id = wells.loc[~wells.index.isin(train_id)].sample(n=1500).index
y_train = wells.loc[train_id, "switch"].to_numpy()
y_test = wells.loc[test_id, "switch"].to_numpy()
2. Prepare 6 different candidate models
2.1 Feature Engineering
First, let’s add a few new columns: - edu0
: whether educ
is 0
, - edu1
: whether educ
is between 1
and 5
, - edu2
: whether educ
is between 6
and 11
, - edu3
: whether educ
is between 12
and 17
, - logarsenic
: natural logarithm of arsenic
, - assoc_half
: half of assoc
, - as_square
: natural logarithm of arsenic
, squared, - as_third
: natural logarithm of arsenic
, cubed, - dist100
: dist
divided by 100
, -
intercept
: just a columns of 1
s.
We’re going to start by fitting 6 different models to our train set:
logistic regression using
intercept
,arsenic
,assoc
,edu1
,edu2
, andedu3
;same as above, but with
logarsenic
instead ofarsenic
;same as the first one, but with square and cubic features as well;
same as the first one, but with spline features derived from
logarsenic
as well;same as the first one, but with spline features derived from
dist100
as well;same as the first one, but with
educ
instead of the binaryedu
variables.
[8]:
wells["edu0"] = wells["educ"].isin(np.arange(0, 1)).astype(int)
wells["edu1"] = wells["educ"].isin(np.arange(1, 6)).astype(int)
wells["edu2"] = wells["educ"].isin(np.arange(6, 12)).astype(int)
wells["edu3"] = wells["educ"].isin(np.arange(12, 18)).astype(int)
wells["logarsenic"] = np.log(wells["arsenic"])
wells["assoc_half"] = wells["assoc"] / 2.0
wells["as_square"] = wells["logarsenic"] ** 2
wells["as_third"] = wells["logarsenic"] ** 3
wells["dist100"] = wells["dist"] / 100.0
wells["intercept"] = 1
[9]:
def bs(x, knots, degree):
"""
Generate the B-spline basis matrix for a polynomial spline.
Parameters
----------
x
predictor variable.
knots
locations of internal breakpoints (not padded).
degree
degree of the piecewise polynomial.
Returns
-------
pd.DataFrame
Spline basis matrix.
Notes
-----
This mirrors ``bs`` from splines package in R.
"""
padded_knots = np.hstack(
[[x.min()] * (degree + 1), knots, [x.max()] * (degree + 1)]
)
return pd.DataFrame(
BSpline(padded_knots, np.eye(len(padded_knots) - degree - 1), degree)(x)[:, 1:],
index=x.index,
)
knots = np.quantile(wells.loc[train_id, "logarsenic"], np.linspace(0.1, 0.9, num=10))
spline_arsenic = bs(wells["logarsenic"], knots=knots, degree=3)
knots = np.quantile(wells.loc[train_id, "dist100"], np.linspace(0.1, 0.9, num=10))
spline_dist = bs(wells["dist100"], knots=knots, degree=3)
[10]:
features_0 = ["intercept", "dist100", "arsenic", "assoc", "edu1", "edu2", "edu3"]
features_1 = ["intercept", "dist100", "logarsenic", "assoc", "edu1", "edu2", "edu3"]
features_2 = [
"intercept",
"dist100",
"arsenic",
"as_third",
"as_square",
"assoc",
"edu1",
"edu2",
"edu3",
]
features_3 = ["intercept", "dist100", "assoc", "edu1", "edu2", "edu3"]
features_4 = ["intercept", "logarsenic", "assoc", "edu1", "edu2", "edu3"]
features_5 = ["intercept", "dist100", "logarsenic", "assoc", "educ"]
X0 = wells.loc[train_id, features_0].to_numpy()
X1 = wells.loc[train_id, features_1].to_numpy()
X2 = wells.loc[train_id, features_2].to_numpy()
X3 = (
pd.concat([wells.loc[:, features_3], spline_arsenic], axis=1)
.loc[train_id]
.to_numpy()
)
X4 = pd.concat([wells.loc[:, features_4], spline_dist], axis=1).loc[train_id].to_numpy()
X5 = wells.loc[train_id, features_5].to_numpy()
X0_test = wells.loc[test_id, features_0].to_numpy()
X1_test = wells.loc[test_id, features_1].to_numpy()
X2_test = wells.loc[test_id, features_2].to_numpy()
X3_test = (
pd.concat([wells.loc[:, features_3], spline_arsenic], axis=1)
.loc[test_id]
.to_numpy()
)
X4_test = (
pd.concat([wells.loc[:, features_4], spline_dist], axis=1).loc[test_id].to_numpy()
)
X5_test = wells.loc[test_id, features_5].to_numpy()
[11]:
train_x_list = [X0, X1, X2, X3, X4, X5]
test_x_list = [X0_test, X1_test, X2_test, X3_test, X4_test, X5_test]
K = len(train_x_list)
2.2 Training
Each model will be trained in the same way - with a Bernoulli likelihood and a logit link function.
[12]:
def logistic(x, y=None):
beta = numpyro.sample("beta", dist.Normal(0, 3).expand([x.shape[1]]))
logits = numpyro.deterministic("logits", jnp.matmul(x, beta))
numpyro.sample(
"obs",
dist.Bernoulli(logits=logits),
obs=y,
)
[13]:
fit_list = []
for k in range(K):
sampler = numpyro.infer.NUTS(logistic)
mcmc = numpyro.infer.MCMC(
sampler, num_chains=4, num_samples=1000, num_warmup=1000, progress_bar=False
)
rng_key = jax.random.fold_in(jax.random.PRNGKey(13), k)
mcmc.run(rng_key, x=train_x_list[k], y=y_train)
fit_list.append(mcmc)
2.3 Estimate leave-one-out cross-validated score for each training point
Rather than refitting each model 100 times, we will estimate the leave-one-out cross-validated score using LOO.
[14]:
def find_point_wise_loo_score(fit):
return az.loo(az.from_numpyro(fit), pointwise=True, scale="log").loo_i.values
lpd_point = np.vstack([find_point_wise_loo_score(fit) for fit in fit_list]).T
exp_lpd_point = np.exp(lpd_point)
3. Bayesian Hierarchical Stacking
3.1 Prepare stacking datasets
To determine how the stacking weights should vary across training and test sets, we will need to create “stacking datasets” which include all the features which we want the stacking weights to depend on. How should such features be included? For discrete features, this is easy, we just one-hot-encode them. But for continuous features, we need a trick. In Equation (16), the authors recommend the following: if you have a continuous feature f
, then replace it with the following two features:
f_l
:f
minus the median off
, clipped above at 0;f_r
:f
minus the median off
, clipped below at 0;
[15]:
dist100_median = wells.loc[wells.index[train_id], "dist100"].median()
logarsenic_median = wells.loc[wells.index[train_id], "logarsenic"].median()
wells["dist100_l"] = (wells["dist100"] - dist100_median).clip(upper=0)
wells["dist100_r"] = (wells["dist100"] - dist100_median).clip(lower=0)
wells["logarsenic_l"] = (wells["logarsenic"] - logarsenic_median).clip(upper=0)
wells["logarsenic_r"] = (wells["logarsenic"] - logarsenic_median).clip(lower=0)
stacking_features = [
"edu0",
"edu1",
"edu2",
"edu3",
"assoc_half",
"dist100_l",
"dist100_r",
"logarsenic_l",
"logarsenic_r",
]
X_stacking_train = wells.loc[train_id, stacking_features].to_numpy()
X_stacking_test = wells.loc[test_id, stacking_features].to_numpy()
3.2 Define stacking model
What we seek to find is a matrix of weights \(W\) with which to multiply the models’ predictions. Let’s define a matrix \(Pred\) such that \(Pred_{i,k}\) represents the prediction made for point \(i\) by model \(k\). Then the final prediction for point \(i\) will then be:
Such a matrix \(W\) would be required to have each column sum to \(1\). Hence, we calculate each row \(W_i\) of \(W\) as:
where \(\beta\) is a matrix whose values we seek to determine. For the discrete features, \(\beta\) is given a hierarchical structure over the possible inputs. Continuous features, on the other hand, get no hierarchical structure in this case study and just vary according to the input values.
Notice how, for the discrete features, a non-centered parametrisation is used. Also note that we only need to estimate K-1
columns of \(\beta\), because the weights W_{i, k}
will have to sum to 1
for each i
.
[16]:
def stacking(
X,
d_discrete,
X_test,
lpd_point,
tau_mu,
tau_sigma,
*,
test,
):
"""
Get weights with which to stack candidate models' predictions.
Parameters
----------
X
Training stacking matrix: features on which stacking weights should depend, for the
training set.
d_discrete
Number of discrete features in `X` and `X_test`. The first `d_discrete` features
from these matrices should be the discrete ones, with the continuous ones coming
after them.
X_test
Test stacking matrix: features on which stacking weights should depend, for the
testing set.
lpd_point
LOO score evaluated at each point in the training set, for each candidate model.
tau_mu
Hyperprior for mean of `beta`, for discrete features.
tau_sigma
Hyperprior for standard deviation of `beta`, for continuous features.
test
Whether to calculate stacking weights for test set.
Notes
-----
Naming of variables mirrors what's used in the original paper.
"""
N = X.shape[0]
d = X.shape[1]
N_test = X_test.shape[0]
K = lpd_point.shape[1] # number of candidate models
with numpyro.plate("Candidate models", K - 1, dim=-2):
# mean effect of discrete features on stacking weights
mu = numpyro.sample("mu", dist.Normal(0, tau_mu))
# standard deviation effect of discrete features on stacking weights
sigma = numpyro.sample("sigma", dist.HalfNormal(scale=tau_sigma))
with numpyro.plate("Discrete features", d_discrete, dim=-1):
# effect of discrete features on stacking weights
tau = numpyro.sample("tau", dist.Normal(0, 1))
with numpyro.plate("Continuous features", d - d_discrete, dim=-1):
# effect of continuous features on stacking weights
beta_con = numpyro.sample("beta_con", dist.Normal(0, 1))
# effects of features on stacking weights
beta = numpyro.deterministic(
"beta", jnp.hstack([(sigma.squeeze() * tau.T + mu.squeeze()).T, beta_con])
)
assert beta.shape == (K - 1, d)
# stacking weights (in unconstrained space)
f = jnp.hstack([X @ beta.T, jnp.zeros((N, 1))])
assert f.shape == (N, K)
# log probability of LOO training scores weighted by stacking weights.
log_w = jax.nn.log_softmax(f, axis=1)
# stacking weights (constrained to sum to 1)
numpyro.deterministic("w", jnp.exp(log_w))
logp = jax.nn.logsumexp(lpd_point + log_w, axis=1)
numpyro.factor("logp", jnp.sum(logp))
if test:
# test set stacking weights (in unconstrained space)
f_test = jnp.hstack([X_test @ beta.T, jnp.zeros((N_test, 1))])
# test set stacking weights (constrained to sum to 1)
numpyro.deterministic("w_test", jax.nn.softmax(f_test, axis=1))
[17]:
sampler = numpyro.infer.NUTS(stacking)
mcmc = numpyro.infer.MCMC(
sampler, num_chains=4, num_samples=1000, num_warmup=1000, progress_bar=False
)
mcmc.run(
jax.random.PRNGKey(17),
X=X_stacking_train,
d_discrete=4,
X_test=X_stacking_test,
lpd_point=lpd_point,
tau_mu=1.0,
tau_sigma=0.5,
test=True,
)
trace = mcmc.get_samples()
We can now extract the weights with which to weight the different models from the posterior, and then visualise how they vary across the training set.
Let’s compare them with what the weights would’ve been if we’d just used fixed stacking weights (computed using ArviZ - see their docs for details).
[18]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6), sharey=True)
training_stacking_weights = trace["w"].mean(axis=0)
sns.scatterplot(data=pd.DataFrame(training_stacking_weights), ax=ax[0])
fixed_weights = (
az.compare({idx: fit for idx, fit in enumerate(fit_list)}, method="stacking")
.sort_index()["weight"]
.to_numpy()
)
fixed_weights_df = pd.DataFrame(
np.repeat(
fixed_weights[jnp.newaxis, :],
len(X_stacking_train),
axis=0,
)
)
sns.scatterplot(data=fixed_weights_df, ax=ax[1])
ax[0].set_title("Training weights from Bayesian Hierarchical stacking")
ax[1].set_title("Fixed weights stacking")
ax[0].set_xlabel("Index")
ax[1].set_xlabel("Index")
ax[0].legend(ncol=6, loc="upper center")
fig.suptitle(
"Bayesian Hierarchical Stacking weights can vary according to the input",
fontsize=18,
)
fig.tight_layout();
4. Evaluate on test set
4.1 Stack predictions
Now, for each model, let’s evaluate the log predictive density for each point in the test set. Once we have predictions for each model, we need to think about how to combine them, such that for each test point, we get a single prediction.
We decided we’d do this in three ways: - Bayesian Hierarchical Stacking (bhs_pred
); - choosing the model with the best training set LOO score (model_selection_preds
); - fixed-weights stacking (fixed_weights_preds
).
[19]:
# for each candidate model, extract the posterior predictive logits
train_preds = []
for k in range(K):
predictive = numpyro.infer.Predictive(logistic, fit_list[k].get_samples())
rng_key = jax.random.fold_in(jax.random.PRNGKey(19), k)
train_pred = predictive(rng_key, x=train_x_list[k])["logits"]
train_preds.append(train_pred.mean(axis=0))
# reshape, so we have (N, K)
train_preds = np.vstack(train_preds).T
[20]:
# same as previous cell, but for test set
test_preds = []
for k in range(K):
predictive = numpyro.infer.Predictive(logistic, fit_list[k].get_samples())
rng_key = jax.random.fold_in(jax.random.PRNGKey(20), k)
test_pred = predictive(rng_key, x=test_x_list[k])["logits"]
test_preds.append(test_pred.mean(axis=0))
test_preds = np.vstack(test_preds).T
[21]:
# get the stacking weights for the test set
test_stacking_weights = trace["w_test"].mean(axis=0)
# get predictions using the stacking weights
bhs_predictions = (test_stacking_weights * test_preds).sum(axis=1)
# get predictions using only the model with the best LOO score
model_selection_preds = test_preds[:, lpd_point.sum(axis=0).argmax()]
# get predictions using fixed stacking weights, dependent on the LOO score
fixed_weights_preds = (fixed_weights * test_preds).sum(axis=1)
4.2 Compare methods
Let’s compare the negative log predictive density scores on the test set (note - lower is better):
[22]:
fig, ax = plt.subplots(figsize=(12, 6))
neg_log_pred_densities = np.vstack(
[
-dist.Bernoulli(logits=bhs_predictions).log_prob(y_test),
-dist.Bernoulli(logits=model_selection_preds).log_prob(y_test),
-dist.Bernoulli(logits=fixed_weights_preds).log_prob(y_test),
]
).T
neg_log_pred_density = pd.DataFrame(
neg_log_pred_densities,
columns=[
"Bayesian Hierarchical Stacking",
"Model selection",
"Fixed stacking weights",
],
)
sns.barplot(
data=neg_log_pred_density.reindex(
columns=neg_log_pred_density.mean(axis=0).sort_values(ascending=False).index
),
orient="h",
ax=ax,
)
ax.set_title(
"Bayesian Hierarchical Stacking performs best here", fontdict={"fontsize": 18}
)
ax.set_xlabel("Negative mean log predictive density (lower is better)");
So, in this dataset, with this particular train-test split, Bayesian Hierarchical Stacking does indeed bring a small gain compared with model selection and compared with fixed-weight stacking.
4.3 Does this prove that Bayesian Hierarchical Stacking works?
No, a single train-test split doesn’t prove anything. Check the original paper for results with varying training set sizes, repeated with different train-test splits, in which they show that Bayesian Hierarchical Stacking consistently outperforms model selection and fixed-weight stacking.
The goal of this notebook was just to show how to implement this technique in NumPyro.
Conclusion
We’ve seen how Bayesian Hierarchical Stacking can help us average models with input-dependent weights, in a manner which doesn’t overfit. We only implemented the first case study from the paper, but readers are encouraged to check out the other two as well. Also check the paper for theoretical results and results from more experiments.
References
Yuling Yao, Gregor Pirš, Aki Vehtari, Andrew Gelman (2021). Bayesian hierarchical stacking: Some models are (somewhere) useful
Måns Magnusson, Michael Riis Andersen, Johan Jonasson, Aki Vehtari (2020). Leave-One-Out Cross-Validation for Bayesian Model Comparison in Large Data
Thomas Wiecki (2017). Why hierarchical models are awesome, tricky, and Bayesian