Interactive online version: Open In Colab

Bayesian Hierarchical Stacking: Well Switching Case Study


Photo by Belinda Fewings,

Table of Contents


Suppose you have just fit 6 models to a dataset, and need to choose which one to use to make predictions on your test set. How do you choose which one to use? A couple of common tactics are: - choose the best model based on cross-validation; - average the models, using weights based on cross-validation scores.

In the paper Bayesian hierarchical stacking: Some models are (somewhere) useful, a new technique is introduced: average models based on weights which are allowed to vary across according to the input data, based on a hierarchical structure.

Here, we’ll implement the first case study from that paper - readers are nonetheless encouraged to look at the original paper to find other cases studies, as well as theoretical results. Code from the article (in R / Stan) can be found here.

!pip install -q numpyro@git+
import os

import arviz as az
from IPython.display import set_matplotlib_formats
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import BSpline
import seaborn as sns

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist"seaborn")
if "NUMPYRO_SPHINXBUILD" in os.environ:

assert numpyro.__version__.startswith("0.15.0")
%matplotlib inline

1. Exploratory Data Analysis

The data we have to work with looks at households in Bangladesh, some of which were affected by high levels of arsenic in their water. Would affected households want to switch to a neighbour’s well?

We’ll split the data into a train and test set, and then we’ll train six different models to try to predict whether households would switch wells. Then, we’ll see how we can stack them when predicting on the test set!

But first, let’s load it in and visualise it! Each row represents a household, and the features we have available to us are:

  • switch: whether a household switched to another well;

  • arsenic: level of arsenic in drinking water;

  • educ: level of education of “head of household”;

  • dist100: distance to nearest safe-drinking well;

  • assoc: whether the household participates in any community activities.

wells = pd.read_csv(
    "", sep=" "
switch arsenic dist assoc educ
1 1 2.36 16.826000 0 0
2 1 0.71 47.321999 0 0
3 0 2.07 20.966999 0 10
4 1 1.15 21.486000 0 12
5 1 1.10 40.874001 1 14
fig, ax = plt.subplots(2, 2, figsize=(12, 6))
fig.suptitle("Target variable plotted against various predictors")
sns.scatterplot(data=wells, x="arsenic", y="switch", ax=ax[0][0])
sns.scatterplot(data=wells, x="dist", y="switch", ax=ax[0][1])
ax[1][0].set_ylabel("Proportion switch")
ax[1][1].set_ylabel("Proportion switch");

Next, we’ll choose 200 observations to be part of our train set, and 1500 to be part of our test set.

train_id = wells.sample(n=200).index
test_id = wells.loc[~wells.index.isin(train_id)].sample(n=1500).index
y_train = wells.loc[train_id, "switch"].to_numpy()
y_test = wells.loc[test_id, "switch"].to_numpy()

2. Prepare 6 different candidate models

2.1 Feature Engineering

First, let’s add a few new columns: - edu0: whether educ is 0, - edu1: whether educ is between 1 and 5, - edu2: whether educ is between 6 and 11, - edu3: whether educ is between 12 and 17, - logarsenic: natural logarithm of arsenic, - assoc_half: half of assoc, - as_square: natural logarithm of arsenic, squared, - as_third: natural logarithm of arsenic, cubed, - dist100: dist divided by 100, - intercept: just a columns of 1s.

We’re going to start by fitting 6 different models to our train set:

  • logistic regression using intercept, arsenic, assoc, edu1, edu2, and edu3;

  • same as above, but with logarsenic instead of arsenic;

  • same as the first one, but with square and cubic features as well;

  • same as the first one, but with spline features derived from logarsenic as well;

  • same as the first one, but with spline features derived from dist100 as well;

  • same as the first one, but with educ instead of the binary edu variables.

wells["edu0"] = wells["educ"].isin(np.arange(0, 1)).astype(int)
wells["edu1"] = wells["educ"].isin(np.arange(1, 6)).astype(int)
wells["edu2"] = wells["educ"].isin(np.arange(6, 12)).astype(int)
wells["edu3"] = wells["educ"].isin(np.arange(12, 18)).astype(int)
wells["logarsenic"] = np.log(wells["arsenic"])
wells["assoc_half"] = wells["assoc"] / 2.0
wells["as_square"] = wells["logarsenic"] ** 2
wells["as_third"] = wells["logarsenic"] ** 3
wells["dist100"] = wells["dist"] / 100.0
wells["intercept"] = 1
def bs(x, knots, degree):
    Generate the B-spline basis matrix for a polynomial spline.

        predictor variable.
        locations of internal breakpoints (not padded).
        degree of the piecewise polynomial.

        Spline basis matrix.

    This mirrors ``bs`` from splines package in R.
    padded_knots = np.hstack(
        [[x.min()] * (degree + 1), knots, [x.max()] * (degree + 1)]
    return pd.DataFrame(
        BSpline(padded_knots, np.eye(len(padded_knots) - degree - 1), degree)(x)[:, 1:],

knots = np.quantile(wells.loc[train_id, "logarsenic"], np.linspace(0.1, 0.9, num=10))
spline_arsenic = bs(wells["logarsenic"], knots=knots, degree=3)
knots = np.quantile(wells.loc[train_id, "dist100"], np.linspace(0.1, 0.9, num=10))
spline_dist = bs(wells["dist100"], knots=knots, degree=3)
features_0 = ["intercept", "dist100", "arsenic", "assoc", "edu1", "edu2", "edu3"]
features_1 = ["intercept", "dist100", "logarsenic", "assoc", "edu1", "edu2", "edu3"]
features_2 = [
features_3 = ["intercept", "dist100", "assoc", "edu1", "edu2", "edu3"]
features_4 = ["intercept", "logarsenic", "assoc", "edu1", "edu2", "edu3"]
features_5 = ["intercept", "dist100", "logarsenic", "assoc", "educ"]

X0 = wells.loc[train_id, features_0].to_numpy()
X1 = wells.loc[train_id, features_1].to_numpy()
X2 = wells.loc[train_id, features_2].to_numpy()
X3 = (
    pd.concat([wells.loc[:, features_3], spline_arsenic], axis=1)
X4 = pd.concat([wells.loc[:, features_4], spline_dist], axis=1).loc[train_id].to_numpy()
X5 = wells.loc[train_id, features_5].to_numpy()

X0_test = wells.loc[test_id, features_0].to_numpy()
X1_test = wells.loc[test_id, features_1].to_numpy()
X2_test = wells.loc[test_id, features_2].to_numpy()
X3_test = (
    pd.concat([wells.loc[:, features_3], spline_arsenic], axis=1)
X4_test = (
    pd.concat([wells.loc[:, features_4], spline_dist], axis=1).loc[test_id].to_numpy()
X5_test = wells.loc[test_id, features_5].to_numpy()
train_x_list = [X0, X1, X2, X3, X4, X5]
test_x_list = [X0_test, X1_test, X2_test, X3_test, X4_test, X5_test]
K = len(train_x_list)

2.2 Training

Each model will be trained in the same way - with a Bernoulli likelihood and a logit link function.

def logistic(x, y=None):
    beta = numpyro.sample("beta", dist.Normal(0, 3).expand([x.shape[1]]))
    logits = numpyro.deterministic("logits", jnp.matmul(x, beta))

fit_list = []
for k in range(K):
    sampler = numpyro.infer.NUTS(logistic)
    mcmc = numpyro.infer.MCMC(
        sampler, num_chains=4, num_samples=1000, num_warmup=1000, progress_bar=False
    rng_key = jax.random.fold_in(jax.random.PRNGKey(13), k), x=train_x_list[k], y=y_train)

2.3 Estimate leave-one-out cross-validated score for each training point

Rather than refitting each model 100 times, we will estimate the leave-one-out cross-validated score using LOO.

def find_point_wise_loo_score(fit):
    return az.loo(az.from_numpyro(fit), pointwise=True, scale="log").loo_i.values

lpd_point = np.vstack([find_point_wise_loo_score(fit) for fit in fit_list]).T
exp_lpd_point = np.exp(lpd_point)

3. Bayesian Hierarchical Stacking

3.1 Prepare stacking datasets

To determine how the stacking weights should vary across training and test sets, we will need to create “stacking datasets” which include all the features which we want the stacking weights to depend on. How should such features be included? For discrete features, this is easy, we just one-hot-encode them. But for continuous features, we need a trick. In Equation (16), the authors recommend the following: if you have a continuous feature f, then replace it with the following two features:

  • f_l: f minus the median of f, clipped above at 0;

  • f_r: f minus the median of f, clipped below at 0;

dist100_median = wells.loc[wells.index[train_id], "dist100"].median()
logarsenic_median = wells.loc[wells.index[train_id], "logarsenic"].median()
wells["dist100_l"] = (wells["dist100"] - dist100_median).clip(upper=0)
wells["dist100_r"] = (wells["dist100"] - dist100_median).clip(lower=0)
wells["logarsenic_l"] = (wells["logarsenic"] - logarsenic_median).clip(upper=0)
wells["logarsenic_r"] = (wells["logarsenic"] - logarsenic_median).clip(lower=0)

stacking_features = [
X_stacking_train = wells.loc[train_id, stacking_features].to_numpy()
X_stacking_test = wells.loc[test_id, stacking_features].to_numpy()

3.2 Define stacking model

What we seek to find is a matrix of weights \(W\) with which to multiply the models’ predictions. Let’s define a matrix \(Pred\) such that \(Pred_{i,k}\) represents the prediction made for point \(i\) by model \(k\). Then the final prediction for point \(i\) will then be:

\[\sum_k W_{i, k}Pred_{i,k}\]

Such a matrix \(W\) would be required to have each column sum to \(1\). Hence, we calculate each row \(W_i\) of \(W\) as:

\[W_i = \text{softmax}(X\_\text{stacking}_i \cdot \beta),\]

where \(\beta\) is a matrix whose values we seek to determine. For the discrete features, \(\beta\) is given a hierarchical structure over the possible inputs. Continuous features, on the other hand, get no hierarchical structure in this case study and just vary according to the input values.

Notice how, for the discrete features, a non-centered parametrisation is used. Also note that we only need to estimate K-1 columns of \(\beta\), because the weights W_{i, k} will have to sum to 1 for each i.

def stacking(
    Get weights with which to stack candidate models' predictions.

        Training stacking matrix: features on which stacking weights should depend, for the
        training set.
        Number of discrete features in `X` and `X_test`. The first `d_discrete` features
        from these matrices should be the discrete ones, with the continuous ones coming
        after them.
        Test stacking matrix: features on which stacking weights should depend, for the
        testing set.
        LOO score evaluated at each point in the training set, for each candidate model.
        Hyperprior for mean of `beta`, for discrete features.
        Hyperprior for standard deviation of `beta`, for continuous features.
        Whether to calculate stacking weights for test set.

    Naming of variables mirrors what's used in the original paper.
    N = X.shape[0]
    d = X.shape[1]
    N_test = X_test.shape[0]
    K = lpd_point.shape[1]  # number of candidate models

    with numpyro.plate("Candidate models", K - 1, dim=-2):
        # mean effect of discrete features on stacking weights
        mu = numpyro.sample("mu", dist.Normal(0, tau_mu))
        # standard deviation effect of discrete features on stacking weights
        sigma = numpyro.sample("sigma", dist.HalfNormal(scale=tau_sigma))
        with numpyro.plate("Discrete features", d_discrete, dim=-1):
            # effect of discrete features on stacking weights
            tau = numpyro.sample("tau", dist.Normal(0, 1))
        with numpyro.plate("Continuous features", d - d_discrete, dim=-1):
            # effect of continuous features on stacking weights
            beta_con = numpyro.sample("beta_con", dist.Normal(0, 1))

    # effects of features on stacking weights
    beta = numpyro.deterministic(
        "beta", jnp.hstack([(sigma.squeeze() * tau.T + mu.squeeze()).T, beta_con])
    assert beta.shape == (K - 1, d)

    # stacking weights (in unconstrained space)
    f = jnp.hstack([X @ beta.T, jnp.zeros((N, 1))])
    assert f.shape == (N, K)

    # log probability of LOO training scores weighted by stacking weights.
    log_w = jax.nn.log_softmax(f, axis=1)

    # stacking weights (constrained to sum to 1)
    numpyro.deterministic("w", jnp.exp(log_w))
    logp = jax.nn.logsumexp(lpd_point + log_w, axis=1)
    numpyro.factor("logp", jnp.sum(logp))

    if test:
        # test set stacking weights (in unconstrained space)
        f_test = jnp.hstack([X_test @ beta.T, jnp.zeros((N_test, 1))])
        # test set stacking weights (constrained to sum to 1)
        numpyro.deterministic("w_test", jax.nn.softmax(f_test, axis=1))
sampler = numpyro.infer.NUTS(stacking)
mcmc = numpyro.infer.MCMC(
    sampler, num_chains=4, num_samples=1000, num_warmup=1000, progress_bar=False
trace = mcmc.get_samples()

We can now extract the weights with which to weight the different models from the posterior, and then visualise how they vary across the training set.

Let’s compare them with what the weights would’ve been if we’d just used fixed stacking weights (computed using ArviZ - see their docs for details).

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6), sharey=True)
training_stacking_weights = trace["w"].mean(axis=0)
sns.scatterplot(data=pd.DataFrame(training_stacking_weights), ax=ax[0])
fixed_weights = ({idx: fit for idx, fit in enumerate(fit_list)}, method="stacking")
fixed_weights_df = pd.DataFrame(
        fixed_weights[jnp.newaxis, :],
sns.scatterplot(data=fixed_weights_df, ax=ax[1])
ax[0].set_title("Training weights from Bayesian Hierarchical stacking")
ax[1].set_title("Fixed weights stacking")
ax[0].legend(ncol=6, loc="upper center")
    "Bayesian Hierarchical Stacking weights can vary according to the input",

4. Evaluate on test set

4.1 Stack predictions

Now, for each model, let’s evaluate the log predictive density for each point in the test set. Once we have predictions for each model, we need to think about how to combine them, such that for each test point, we get a single prediction.

We decided we’d do this in three ways: - Bayesian Hierarchical Stacking (bhs_pred); - choosing the model with the best training set LOO score (model_selection_preds); - fixed-weights stacking (fixed_weights_preds).

# for each candidate model, extract the posterior predictive logits
train_preds = []
for k in range(K):
    predictive = numpyro.infer.Predictive(logistic, fit_list[k].get_samples())
    rng_key = jax.random.fold_in(jax.random.PRNGKey(19), k)
    train_pred = predictive(rng_key, x=train_x_list[k])["logits"]
# reshape, so we have (N, K)
train_preds = np.vstack(train_preds).T
# same as previous cell, but for test set
test_preds = []
for k in range(K):
    predictive = numpyro.infer.Predictive(logistic, fit_list[k].get_samples())
    rng_key = jax.random.fold_in(jax.random.PRNGKey(20), k)
    test_pred = predictive(rng_key, x=test_x_list[k])["logits"]
test_preds = np.vstack(test_preds).T
# get the stacking weights for the test set
test_stacking_weights = trace["w_test"].mean(axis=0)
# get predictions using the stacking weights
bhs_predictions = (test_stacking_weights * test_preds).sum(axis=1)
# get predictions using only the model with the best LOO score
model_selection_preds = test_preds[:, lpd_point.sum(axis=0).argmax()]
# get predictions using fixed stacking weights, dependent on the LOO score
fixed_weights_preds = (fixed_weights * test_preds).sum(axis=1)

4.2 Compare methods

Let’s compare the negative log predictive density scores on the test set (note - lower is better):

fig, ax = plt.subplots(figsize=(12, 6))
neg_log_pred_densities = np.vstack(
neg_log_pred_density = pd.DataFrame(
        "Bayesian Hierarchical Stacking",
        "Model selection",
        "Fixed stacking weights",
    "Bayesian Hierarchical Stacking performs best here", fontdict={"fontsize": 18}
ax.set_xlabel("Negative mean log predictive density (lower is better)");

So, in this dataset, with this particular train-test split, Bayesian Hierarchical Stacking does indeed bring a small gain compared with model selection and compared with fixed-weight stacking.

4.3 Does this prove that Bayesian Hierarchical Stacking works?

No, a single train-test split doesn’t prove anything. Check the original paper for results with varying training set sizes, repeated with different train-test splits, in which they show that Bayesian Hierarchical Stacking consistently outperforms model selection and fixed-weight stacking.

The goal of this notebook was just to show how to implement this technique in NumPyro.


We’ve seen how Bayesian Hierarchical Stacking can help us average models with input-dependent weights, in a manner which doesn’t overfit. We only implemented the first case study from the paper, but readers are encouraged to check out the other two as well. Also check the paper for theoretical results and results from more experiments.


  1. Yuling Yao, Gregor Pirš, Aki Vehtari, Andrew Gelman (2021). Bayesian hierarchical stacking: Some models are (somewhere) useful

  2. Måns Magnusson, Michael Riis Andersen, Johan Jonasson, Aki Vehtari (2020). Leave-One-Out Cross-Validation for Bayesian Model Comparison in Large Data


  4. Thomas Wiecki (2017). Why hierarchical models are awesome, tricky, and Bayesian