
Bayesian Hierarchical Linear Regression

Author: Carlos Souza

Probabilistic Machine Learning models can not only make predictions about future data, but also model uncertainty. In areas such as personalized medicine, there might be a large amount of data, but there is still a relatively small amount of data for each patient. To customize predictions for each person it becomes necessary to build a model for each person — with its inherent uncertainties — and to couple these models together in a hierarchy so that information can be borrowed from other similar people [1].

The purpose of this tutorial is to demonstrate how to implement a Bayesian Hierarchical Linear Regression model using NumPyro. To motivate the tutorial, I will use the OSIC Pulmonary Fibrosis Progression competition, hosted on Kaggle.

1. Understanding the task

Pulmonary fibrosis is a disorder with no known cause and no known cure, created by scarring of the lungs. In this competition, we were asked to predict a patient’s severity of decline in lung function. Lung function is assessed based on output from a spirometer, which measures the forced vital capacity (FVC), i.e. the volume of air exhaled.

In medical applications, it is useful to evaluate a model’s confidence in its decisions. Accordingly, the metric used to rank the teams was designed to reflect both the accuracy and certainty of each prediction. It’s a modified version of the Laplace Log Likelihood (more details on that later).

Let’s explore the data and see what that’s all about:

!pip install -q numpyro@git+ arviz
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
train = pd.read_csv(''
Patient Weeks FVC Percent Age Sex SmokingStatus
0 ID00007637202177411956430 -4 2315 58.253649 79 Male Ex-smoker
1 ID00007637202177411956430 5 2214 55.712129 79 Male Ex-smoker
2 ID00007637202177411956430 7 2061 51.862104 79 Male Ex-smoker
3 ID00007637202177411956430 9 2144 53.950679 79 Male Ex-smoker
4 ID00007637202177411956430 11 2069 52.063412 79 Male Ex-smoker

In the dataset, we were provided with a baseline chest CT scan and associated clinical information for a set of patients. A patient has an image acquired at time Week = 0 and has numerous follow-up visits over the course of approximately 1-2 years, at which time their FVC is measured. For this tutorial, I will use only the Patient ID, the weeks and the FVC measurements, discarding all the rest. Using only these columns enabled our team to achieve a competitive score, which shows the power of Bayesian hierarchical linear regression models, especially when gauging uncertainty is an important part of the problem.

Since this is real medical data, the relative timing of FVC measurements varies widely, as shown in the 3 sample patients below:

def chart(patient_id, ax):
    data = train[train['Patient'] == patient_id]
    x = data['Weeks']
    y = data['FVC']
    # Pass x and y as keywords; recent seaborn versions no longer accept them positionally
    ax = sns.regplot(x=x, y=y, ax=ax, ci=None, line_kws={'color':'red'})

f, axes = plt.subplots(1, 3, figsize=(15, 5))
chart('ID00007637202177411956430', axes[0])
chart('ID00009637202177434476278', axes[1])
chart('ID00010637202177584971671', axes[2])

On average, each of the 176 patients provided made 9 visits at which FVC was measured. The visits happened in specific weeks in the [-12, 133] interval. The decline in lung capacity is very clear, although its rate varies considerably from patient to patient.

We were asked to predict every patient’s FVC measurement for every possible week in the [-12, 133] interval, and the confidence for each prediction. In other words: we were asked to fill a matrix like the one below, and provide a confidence score for each prediction:


The task was perfect to apply Bayesian inference. However, the vast majority of solutions shared by the Kaggle community used discriminative machine learning models, disregarding the fact that most discriminative methods are very poor at providing realistic uncertainty estimates. Because they are typically trained in a manner that optimizes the parameters to minimize some loss criterion (e.g. the predictive error), they do not, in general, encode any uncertainty in either their parameters or the subsequent predictions. Though many methods can produce uncertainty estimates either as a by-product or from a post-processing step, these are typically heuristic-based, rather than stemming naturally from a statistically principled estimate of the target uncertainty distribution [2].

2. Modelling: Bayesian Hierarchical Linear Regression with Partial Pooling

The simplest possible linear regression, not hierarchical, would assume all FVC decline curves have the same \(\alpha\) and \(\beta\). That’s the pooled model. In the other extreme, we could assume a model where each patient has a personalized FVC decline curve, and these curves are completely unrelated. That’s the unpooled model, where each patient has completely separate regressions.
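To make the two extremes concrete, here is a minimal sketch using ordinary least squares on synthetic data. The three patients, their intercepts and slopes, and the noise level are all made up for illustration; this is not the OSIC dataset and not part of the tutorial's model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in data (assumption): 3 hypothetical patients, each with its own
# intercept (alpha) and slope (beta), all measured at the same 8 weeks.
true_alpha = np.array([2500.0, 3000.0, 2000.0])
true_beta = np.array([-3.0, -5.0, -1.0])
patient = np.repeat(np.arange(3), 8)
weeks = np.tile(np.linspace(0.0, 70.0, 8), 3)
noise = rng.normal(0.0, 50.0, size=patient.size)
fvc = true_alpha[patient] + true_beta[patient] * weeks + noise

# Pooled: a single (alpha, beta) for everyone. np.polyfit returns [slope, intercept].
beta_pooled, alpha_pooled = np.polyfit(weeks, fvc, deg=1)

# Unpooled: a separate least-squares fit per patient, sharing no information.
unpooled = {i: np.polyfit(weeks[patient == i], fvc[patient == i], deg=1)
            for i in range(3)}

print(f"pooled:    alpha={alpha_pooled:.0f}, beta={beta_pooled:.2f}")
for i, (b, a) in unpooled.items():
    print(f"patient {i}: alpha={a:.0f}, beta={b:.2f}")
```

The pooled fit ignores differences between patients, while the unpooled fits rely only on each patient's own handful of points. A hierarchical model sits between the two: per-patient estimates are pulled toward the group mean, and more strongly so for patients with few observations.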

Here, I’ll use the middle ground: Partial pooling. Specifically, I’ll assume that while \(\alpha\)’s and \(\beta\)’s are different for each patient as in the unpooled case, the coefficients all share similarity. We can model this by assuming that each individual coefficient comes from a common group distribution. The image below represents this model graphically:


Mathematically, the model is described by the following equations:

\begin{align} \mu_{\alpha} &\sim \mathcal{N}(0, 100) \\ \sigma_{\alpha} &\sim |\mathcal{N}(0, 100)| \\ \mu_{\beta} &\sim \mathcal{N}(0, 100) \\ \sigma_{\beta} &\sim |\mathcal{N}(0, 100)| \\ \alpha_i &\sim \mathcal{N}(\mu_{\alpha}, \sigma_{\alpha}) \\ \beta_i &\sim \mathcal{N}(\mu_{\beta}, \sigma_{\beta}) \\ \sigma &\sim |\mathcal{N}(0, 100)| \\ FVC_{ij} &\sim \mathcal{N}(\alpha_i + t \beta_i, \sigma) \end{align}

where t is the time in weeks. Those are very uninformative priors, but that’s ok: our model will converge!
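One way to read these equations is as a recipe for generating data, from the top down. The sketch below draws a single sample from that generative process with plain NumPy. It assumes the second argument of each normal is a standard deviation and that σ is drawn from a half-normal, as in the NumPyro implementation; the number of patients and the week grid are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_patients = 5                  # hypothetical number of patients
weeks = np.arange(0, 50, 10)    # hypothetical visit weeks: 0, 10, 20, 30, 40

# Group-level (hyper)parameters
mu_alpha = rng.normal(0.0, 100.0)
sigma_alpha = abs(rng.normal(0.0, 100.0))   # half-normal via absolute value
mu_beta = rng.normal(0.0, 100.0)
sigma_beta = abs(rng.normal(0.0, 100.0))

# Patient-level intercepts and slopes, drawn from the shared group distributions
alpha = rng.normal(mu_alpha, sigma_alpha, size=n_patients)
beta = rng.normal(mu_beta, sigma_beta, size=n_patients)

# Observation noise and simulated FVC values, one row per patient
sigma = abs(rng.normal(0.0, 100.0))
fvc = rng.normal(alpha[:, None] + beta[:, None] * weeks[None, :], sigma)

print(fvc.shape)  # (n_patients, number of weeks)
```

Inference runs this recipe in reverse: given observed FVC values, it recovers plausible values for every quantity sampled above.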

Implementing this model in NumPyro is pretty straightforward:

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random

assert numpyro.__version__.startswith('0.7.1')
def model(PatientID, Weeks, FVC_obs=None):
    μ_α = numpyro.sample("μ_α", dist.Normal(0., 100.))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.))
    μ_β = numpyro.sample("μ_β", dist.Normal(0., 100.))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(100.))

    unique_patient_IDs = np.unique(PatientID)
    n_patients = len(unique_patient_IDs)

    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    σ = numpyro.sample("σ", dist.HalfNormal(100.))
    FVC_est = α[PatientID] + β[PatientID] * Weeks

    with numpyro.plate("data", len(PatientID)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)

That’s all for modelling!

3. Fitting the model

A great achievement of Probabilistic Programming Languages such as NumPyro is to decouple model specification and inference. After specifying my generative model, with priors, condition statements and data likelihood, I can leave the hard work to NumPyro’s inference engine.

Calling it requires just a few lines. Before we do it, let’s add a numerical Patient ID for each patient code. That can be easily done with scikit-learn’s LabelEncoder:

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
train['PatientID'] = le.fit_transform(train['Patient'].values)

FVC_obs = train['FVC'].values
Weeks = train['Weeks'].values
PatientID = train['PatientID'].values
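The integer codes matter because the model indexes its parameter vectors with them (`α[PatientID]`). As a rough illustration of what the encoding produces, the sketch below builds equivalent codes with `np.unique`; the patient strings and the parameter values are hypothetical.

```python
import numpy as np

# Hypothetical patient codes, one entry per measurement row
patients = np.array(["ID_b", "ID_a", "ID_b", "ID_c", "ID_a"])

# classes: sorted unique labels; codes: integer index of each row's label
classes, codes = np.unique(patients, return_inverse=True)
print(classes)  # ['ID_a' 'ID_b' 'ID_c']
print(codes)    # [1 0 1 2 0]

# One parameter per patient (made-up values), broadcast to one value per row
alpha = np.array([10.0, 20.0, 30.0])
per_row_alpha = alpha[codes]
print(per_row_alpha)  # [20. 10. 20. 30. 10.]

# Mapping codes back to labels, analogous to LabelEncoder.inverse_transform
assert (classes[codes] == patients).all()
```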

Now, calling NumPyro’s inference engine:

nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, PatientID, Weeks, FVC_obs=FVC_obs)

posterior_samples = mcmc.get_samples()
sample: 100%|██████████| 4000/4000 [00:20<00:00, 195.69it/s, 63 steps of size 1.06e-01. acc. prob=0.89]

4. Checking the model

4.1. Inspecting the learned parameters

First, let’s inspect the parameters learned. To do that, I will use ArviZ, which perfectly integrates with NumPyro:

import arviz as az

data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);

Looks like our model learned personalized alphas and betas for each patient!

4.2. Visualizing FVC decline curves for some patients

Now, let’s visually inspect FVC decline curves predicted by our model. We will completely fill in the FVC table, predicting all missing values. The first step is to create a table to fill:

pred_template = []
for i in range(train['Patient'].nunique()):
    df = pd.DataFrame(columns=['PatientID', 'Weeks'])
    df['Weeks'] = np.arange(-12, 134)
    df['PatientID'] = i
    pred_template.append(df)  # collect each patient's week grid
pred_template = pd.concat(pred_template, ignore_index=True)

Predicting the missing values in the FVC table and confidence (sigma) for each value becomes really easy:

PatientID = pred_template['PatientID'].values
Weeks = pred_template['Weeks'].values
predictive = Predictive(model, posterior_samples,
                        return_sites=['σ', 'obs'])
samples_predictive = predictive(random.PRNGKey(0),
                                PatientID, Weeks, None)
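Conceptually, the predictive step reuses every posterior draw: for each draw it computes the regression mean for the requested (patient, week) pairs and then adds observation noise. The NumPy sketch below mimics that logic with made-up posterior draws. It is an illustration of the idea under those assumptions, not NumPyro's actual implementation, and the draws here ignore any posterior correlation between parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, n_patients = 4000, 2

# Hypothetical stand-ins for posterior draws of alpha, beta and sigma
alpha_draws = rng.normal([2500.0, 3000.0], 30.0, size=(n_draws, n_patients))
beta_draws = rng.normal([-4.0, -2.0], 1.0, size=(n_draws, n_patients))
sigma_draws = np.abs(rng.normal(120.0, 5.0, size=n_draws))

# Rows to predict: (patient 0, week 0), (patient 0, week 100), (patient 1, week 0)
patient_id = np.array([0, 0, 1])
weeks = np.array([0, 100, 0])

# One regression mean per draw and per row, shape (n_draws, n_rows)
mean = alpha_draws[:, patient_id] + beta_draws[:, patient_id] * weeks

# Add observation noise using each draw's own sigma
obs = rng.normal(mean, sigma_draws[:, None])

fvc_pred = obs.mean(axis=0)    # point prediction per row
sigma_pred = obs.std(axis=0)   # predictive standard deviation per row
print(fvc_pred.round(0), sigma_pred.round(0))
```

Summarizing the simulated observations over the draw axis gives both a point prediction and a confidence value for every row.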

Let’s now put the predictions together with the true values, to visualize them:

df = pd.DataFrame(columns=['Patient', 'Weeks', 'FVC_pred', 'sigma'])
df['Patient'] = le.inverse_transform(pred_template['PatientID'])
df['Weeks'] = pred_template['Weeks']
df['FVC_pred'] = samples_predictive['obs'].T.mean(axis=1)
df['sigma'] = samples_predictive['obs'].T.std(axis=1)
df['FVC_inf'] = df['FVC_pred'] - df['sigma']
df['FVC_sup'] = df['FVC_pred'] + df['sigma']
df = pd.merge(df, train[['Patient', 'Weeks', 'FVC']],
              how='left', on=['Patient', 'Weeks'])
df = df.rename(columns={'FVC': 'FVC_true'})
Patient Weeks FVC_pred sigma FVC_inf FVC_sup FVC_true
0 ID00007637202177411956430 -12 2219.361084 159.272430 2060.088623 2378.633545 NaN
1 ID00007637202177411956430 -11 2209.278076 157.698868 2051.579102 2366.977051 NaN
2 ID00007637202177411956430 -10 2212.443115 154.503906 2057.939209 2366.947021 NaN
3 ID00007637202177411956430 -9 2208.173096 153.068268 2055.104736 2361.241455 NaN
4 ID00007637202177411956430 -8 2202.373047 157.185608 2045.187500 2359.558594 NaN

Finally, let’s see our predictions for 3 patients:

def chart(patient_id, ax):
    data = df[df['Patient'] == patient_id]
    x = data['Weeks']
    ax.plot(x, data['FVC_true'], 'o')
    ax.plot(x, data['FVC_pred'])
    ax = sns.regplot(x=x, y=data['FVC_true'], ax=ax, ci=None,
                     line_kws={'color':'red'})
    ax.fill_between(x, data["FVC_inf"], data["FVC_sup"],
                    alpha=0.5, color='#ffcd3c')

f, axes = plt.subplots(1, 3, figsize=(15, 5))
chart('ID00007637202177411956430', axes[0])
chart('ID00009637202177434476278', axes[1])
chart('ID00011637202177653955184', axes[2])

The results are exactly what we expected to see! Key observations:

  • The model adequately learned Bayesian linear regressions! The orange line (predicted FVC mean) closely tracks the red line (deterministic linear regression). Most importantly, it learned to predict uncertainty, shown in the light orange region (one sigma above and below the mean FVC line).
  • The model predicts higher uncertainty where the data points are more dispersed (1st and 3rd patients). Conversely, where the points are closely grouped together (2nd patient), the model predicts higher confidence (narrower light orange region).
  • Finally, in all patients, we can see that the uncertainty grows as we look further into the future: the light orange region widens as the number of weeks grows!
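One plausible way to see why the band widens is to decompose the predictive variance for patient \(i\) at week \(t\). The derivation below is a sketch based on the law of total variance applied to the model above, with all moments taken over the posterior:

```latex
\begin{align}
\mathrm{Var}(FVC_{ij}) &= \mathrm{Var}(\alpha_i + t \beta_i) + \mathbb{E}[\sigma^2] \\
&= \mathrm{Var}(\alpha_i) + 2t\,\mathrm{Cov}(\alpha_i, \beta_i) + t^2\,\mathrm{Var}(\beta_i) + \mathbb{E}[\sigma^2]
\end{align}
```

The variance is quadratic in \(t\), so any uncertainty in the slope \(\beta_i\) is amplified as \(t\) moves away from the week where the quadratic reaches its minimum, which typically lies within the range of observed visits.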

4.3. Computing the modified Laplace Log Likelihood and RMSE

As mentioned earlier, the competition was evaluated on a modified version of the Laplace Log Likelihood, designed to reflect both the accuracy and certainty of each prediction.

For each true FVC measurement, we predicted both an FVC and a confidence measure (standard deviation \(\sigma\)). The metric was computed as:

\begin{align} \sigma_{clipped} &= max(\sigma, 70) \\ \delta &= min(|FVC_{true} - FVC_{pred}|, 1000) \\ metric &= -\dfrac{\sqrt{2}\delta}{\sigma_{clipped}} - \ln(\sqrt{2} \sigma_{clipped}) \end{align}

The error was thresholded at 1000 ml to avoid large errors adversely penalizing results, while the confidence values were clipped at 70 ml to reflect the approximate measurement uncertainty in FVC. The final score was calculated by averaging the metric across all (Patient, Week) pairs. Note that metric values will be negative and higher is better.
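To get a feel for the scale of the metric, here is a small sketch that evaluates it for a single prediction. The function name and the example numbers are illustrative, not part of the competition code.

```python
import math

def laplace_log_likelihood(fvc_true, fvc_pred, sigma):
    """Modified Laplace Log Likelihood for one prediction (values in ml)."""
    sigma_clipped = max(sigma, 70.0)                # confidence clipped at 70 ml
    delta = min(abs(fvc_true - fvc_pred), 1000.0)   # error thresholded at 1000 ml
    return (-math.sqrt(2) * delta / sigma_clipped
            - math.log(math.sqrt(2) * sigma_clipped))

# Perfect prediction with the smallest allowed sigma: the best achievable value
best = laplace_log_likelihood(2000, 2000, 70)       # about -4.595

# A 50 ml error reported with sigma = 100 ml
typical = laplace_log_likelihood(2000, 2050, 100)   # about -5.659

# A 300 ml error reported with an overconfident sigma = 70 ml
overconfident = laplace_log_likelihood(2000, 2300, 70)

print(f"{best:.3f} {typical:.3f} {overconfident:.3f}")
```

Because of the clipping, no prediction can score better than roughly -4.595, and an overconfident wrong prediction is penalized much more heavily than a correct one reported with a wider sigma.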

Next, we calculate the metric and RMSE:

y = df.dropna()
rmse = ((y['FVC_pred'] - y['FVC_true']) ** 2).mean() ** (1/2)
print(f'RMSE: {rmse:.1f} ml')

# Clip confidence at 70 ml and threshold the error at 1000 ml without mutating y
sigma_c = np.maximum(y['sigma'].values, 70)
delta = (y['FVC_pred'] - y['FVC_true']).abs().clip(upper=1000)
lll = - np.sqrt(2) * delta / sigma_c - np.log(np.sqrt(2) * sigma_c)
print(f'Laplace Log Likelihood: {lll.mean():.4f}')
RMSE: 122.1 ml
Laplace Log Likelihood: -6.1376

What do these numbers mean? They mean that if you adopted this approach, you would outperform most of the public solutions in the competition. Curiously, the vast majority of public solutions adopt a standard deterministic neural network, modelling uncertainty through a quantile loss. Most people still adopt a frequentist approach.

Uncertainty for single predictions is becoming more and more important in machine learning and is often a requirement. Especially when the consequences of a wrong prediction are severe, we need to know what the probability distribution of an individual prediction is. For perspective, Kaggle just launched a new competition sponsored by Lyft, to build motion prediction models for self-driving vehicles. “We ask that you predict a few trajectories for every agent and provide a confidence score for each of them.”

Finally, I hope the great work done by Pyro/NumPyro developers helps democratize Bayesian methods, empowering an ever-growing community of researchers and practitioners to create models that can not only generate predictions, but also assess uncertainty in their predictions.


  1. Ghahramani, Z. Probabilistic machine learning and artificial intelligence. Nature 521, 452–459 (2015).
  2. Rainforth, Thomas William Gamlen. Automating Inference, Learning, and Design Using Probabilistic Programming. University of Oxford, 2017.