
Bayesian Imputation

Real-world datasets often contain many missing values. In those situations, we have to either remove the rows with missing data (an approach also known as "complete case" analysis) or replace the missing entries with some values. While the complete case approach is straightforward, it is only appropriate when the number of missing entries is small enough that discarding them does not noticeably weaken the analysis we are conducting on the data. The second strategy, known as imputation, is more broadly applicable and will be our focus in this tutorial.

Probably the most popular way to perform imputation is to fill a missing value with the mean, median, or mode of its feature. Doing so implicitly assumes that the feature containing missing values is uncorrelated with the remaining features of our dataset. This is a pretty strong assumption that might not hold in general; moreover, it does not encode any uncertainty about the filled-in values (see the short sketch after the list below). Below, we will construct a Bayesian setting that resolves those issues. In particular, given a model on the dataset, we will

  • create a generative model for the feature with missing values

  • and consider missing values as unobserved latent variables.
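
For contrast, the single-value strategy mentioned above takes just one line with pandas. Here is a minimal sketch (the dataframe is hypothetical):

import pandas as pd

df = pd.DataFrame({"age": [22.0, None, 35.0]})
# every missing entry gets the same value and no uncertainty is recorded
df["age"] = df["age"].fillna(df["age"].mean())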

[ ]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
# first, we need some imports
import os

from IPython.display import set_matplotlib_formats
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from jax import numpy as jnp, random

import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

plt.style.use("seaborn-v0_8")  # the "seaborn" style was renamed in matplotlib 3.6
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

assert numpyro.__version__.startswith("0.15.3")

Dataset

The data is taken from the competition Titanic: Machine Learning from Disaster hosted on Kaggle. It contains information about the passengers of the Titanic, such as name, age, and gender, and our target is to predict whether a person survived.

[2]:
train_df = pd.read_csv(
    "https://raw.githubusercontent.com/agconti/kaggle-titanic/master/data/train.csv"
)
train_df.info()
train_df.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   PassengerId  891 non-null    int64
 1   Survived     891 non-null    int64
 2   Pclass       891 non-null    int64
 3   Name         891 non-null    object
 4   Sex          891 non-null    object
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64
 7   Parch        891 non-null    int64
 8   Ticket       891 non-null    object
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object
 11  Embarked     889 non-null    object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
[2]:
PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
0 1 0 3 Braund, Mr. Owen Harris male 22.0 1 0 A/5 21171 7.2500 NaN S
1 2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 0 PC 17599 71.2833 C85 C
2 3 1 3 Heikkinen, Miss. Laina female 26.0 0 0 STON/O2. 3101282 7.9250 NaN S
3 4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 0 113803 53.1000 C123 S
4 5 0 3 Allen, Mr. William Henry male 35.0 0 0 373450 8.0500 NaN S

Looking at the data info, we see that the Age, Cabin, and Embarked columns contain missing values. Although Cabin is an important feature (the position of a cabin in the ship can affect the chance that its occupants survive), we will skip it in this tutorial for simplicity. The dataset contains many categorical columns and two numerical columns, Age and Fare.
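
We can quickly confirm the missing counts per column; this small check is not part of the original workflow, but the numbers agree with the info() output above:

# count missing entries per column: Age has 177, Cabin 687, Embarked 2
train_df.isnull().sum()

Let's first look at the distribution of the categorical columns: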

[3]:
for col in ["Survived", "Pclass", "Sex", "SibSp", "Parch", "Embarked"]:
    print(train_df[col].value_counts(), end="\n\n")
0    549
1    342
Name: Survived, dtype: int64

3    491
1    216
2    184
Name: Pclass, dtype: int64

male      577
female    314
Name: Sex, dtype: int64

0    608
1    209
2     28
4     18
3     16
8      7
5      5
Name: SibSp, dtype: int64

0    678
1    118
2     80
3      5
5      5
4      4
6      1
Name: Parch, dtype: int64

S    644
C    168
Q     77
Name: Embarked, dtype: int64

Prepare data

First, we will merge the rare groups in the SibSp and Parch columns. In addition, we'll fill the two missing entries in Embarked with the mode S. Note that we could also build a generative model for those missing entries in Embarked, but we skip doing so for simplicity.

[4]:
# assign back instead of using inplace=True on a column, which relies on
# chained assignment and is unreliable in recent pandas versions
train_df["SibSp"] = train_df.SibSp.clip(0, 1)
train_df["Parch"] = train_df.Parch.clip(0, 2)
train_df["Embarked"] = train_df.Embarked.fillna("S")

Looking more closely at the data, we can observe that each name contains a title. We expect age to be correlated with the title of the name: e.g. those with the Mrs. title would be older than those with the Miss. title (on average), so it might be useful to create that feature. The distribution of titles is:

[5]:
train_df.Name.str.split(", ").str.get(1).str.split(" ").str.get(0).value_counts()
[5]:
Mr.          517
Miss.        182
Mrs.         125
Master.       40
Dr.            7
Rev.           6
Mlle.          2
Col.           2
Major.         2
Lady.          1
Sir.           1
the            1
Ms.            1
Capt.          1
Mme.           1
Jonkheer.      1
Don.           1
Name: Name, dtype: int64

We will make a new column Title, where rare titles are merged into one group Misc..

[6]:
train_df["Title"] = (
    train_df.Name.str.split(", ")
    .str.get(1)
    .str.split(" ")
    .str.get(0)
    .apply(lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc.")
)

Now we are ready to turn the dataframe, which includes categorical values, into numpy arrays. We also standardize the Age column (a good practice for regression models).

[7]:
title_cat = pd.CategoricalDtype(
    categories=["Mr.", "Miss.", "Mrs.", "Master.", "Misc."], ordered=True
)
embarked_cat = pd.CategoricalDtype(categories=["S", "C", "Q"], ordered=True)
age_mean, age_std = train_df.Age.mean(), train_df.Age.std()
data = dict(
    age=train_df.Age.pipe(lambda x: (x - age_mean) / age_std).values,
    pclass=train_df.Pclass.values - 1,
    title=train_df.Title.astype(title_cat).cat.codes.values,
    sex=(train_df.Sex == "male").astype(int).values,
    sibsp=train_df.SibSp.values,
    parch=train_df.Parch.values,
    embarked=train_df.Embarked.astype(embarked_cat).cat.codes.values,
)
survived = train_df.Survived.values
# compute the age mean for each title
age_notnan = data["age"][jnp.isfinite(data["age"])]
title_notnan = data["title"][jnp.isfinite(data["age"])]
age_mean_by_title = jnp.stack([age_notnan[title_notnan == i].mean() for i in range(5)])

Modelling

First, we want to note that in NumPyro, the following models

def model1a():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]))

and

def model1b():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]).mask(False))
    numpyro.sample("x_obs", dist.Normal(0, 1).expand([10]), obs=x)

are equivalent in the sense that both of them have

  • the same latent site x drawn from a dist.Normal(0, 1) prior,

  • and the same log density dist.Normal(0, 1).log_prob(x) (the sketch below verifies this).
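
We can check this equivalence numerically. The following is a minimal sketch (assuming model1a and model1b are defined as above) that evaluates both models' log joint densities at the same value of x using NumPyro's log_density utility:

from numpyro.infer.util import log_density

x_val = random.normal(random.PRNGKey(0), (10,))
# the masked site contributes zero log probability in model1b, so the
# observed site x_obs supplies the same density as the latent x in model1a
logp_a, _ = log_density(model1a, (), {}, {"x": x_val})
logp_b, _ = log_density(model1b, (), {}, {"x": x_val})
assert jnp.allclose(logp_a, logp_b)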

Now, assume that we observed the last 6 values of x (the non-observed entries take value NaN); the typical model would be

def model2a(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]))
    x_obs = numpyro.sample("x_obs", dist.Normal(0, 1).expand([6]), obs=x[4:])
    x_imputed = jnp.concatenate([x_impute, x_obs])

or with the usage of mask,

def model2b(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
    x_imputed = jnp.concatenate([x_impute, x[4:]])
    numpyro.sample("x", dist.Normal(0, 1).expand([10]), obs=x_imputed)

Both approaches to modeling the partially observed data x are equivalent. For the model below, we will use the latter method.
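
To see the mechanics concretely, here is a small sketch (assuming model2b is defined as above) that runs the model on data whose first 4 entries are NaN and inspects the resulting trace:

from numpyro.handlers import seed, trace

# hypothetical data: 4 missing entries followed by 6 observed values
x = np.concatenate([np.full(4, np.nan), np.arange(6.0)])
tr = trace(seed(model2b, random.PRNGKey(0))).get_trace(x)
print(tr["x_impute"]["value"].shape)  # (4,) -- one draw per missing entry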

[8]:
def model(
    age, pclass, title, sex, sibsp, parch, embarked, survived=None, bayesian_impute=True
):
    b_pclass = numpyro.sample("b_Pclass", dist.Normal(0, 1).expand([3]))
    b_title = numpyro.sample("b_Title", dist.Normal(0, 1).expand([5]))
    b_sex = numpyro.sample("b_Sex", dist.Normal(0, 1).expand([2]))
    b_sibsp = numpyro.sample("b_SibSp", dist.Normal(0, 1).expand([2]))
    b_parch = numpyro.sample("b_Parch", dist.Normal(0, 1).expand([3]))
    b_embarked = numpyro.sample("b_Embarked", dist.Normal(0, 1).expand([3]))

    # impute age by Title
    isnan = np.isnan(age)
    age_nanidx = np.nonzero(isnan)[0]
    if bayesian_impute:
        age_mu = numpyro.sample("age_mu", dist.Normal(0, 1).expand([5]))
        age_mu = age_mu[title]
        age_sigma = numpyro.sample("age_sigma", dist.Normal(0, 1).expand([5]))
        age_sigma = age_sigma[title]
        age_impute = numpyro.sample(
            "age_impute",
            dist.Normal(age_mu[age_nanidx], age_sigma[age_nanidx]).mask(False),
        )
        age = jnp.asarray(age).at[age_nanidx].set(age_impute)
        numpyro.sample("age", dist.Normal(age_mu, age_sigma), obs=age)
    else:
        # fill missing data by the mean of ages for each title
        age_impute = age_mean_by_title[title][age_nanidx]
        age = jnp.asarray(age).at[age_nanidx].set(age_impute)

    a = numpyro.sample("a", dist.Normal(0, 1))
    b_age = numpyro.sample("b_Age", dist.Normal(0, 1))
    logits = a + b_age * age
    logits = logits + b_title[title] + b_pclass[pclass] + b_sex[sex]
    logits = logits + b_sibsp[sibsp] + b_parch[parch] + b_embarked[embarked]
    numpyro.sample("survived", dist.Bernoulli(logits=logits), obs=survived)

Note that in the model, the prior for age is dist.Normal(age_mu, age_sigma), where the values of age_mu and age_sigma depend on title. Because there are missing values in age, we encode those missing entries in the latent variable age_impute and then replace the NaN entries of age with age_impute.
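
As a quick sanity check, the size of the age_impute site should match the number of NaN entries in age: 891 rows with 714 non-null Age values leave 177 missing entries, which agrees with the sites age_impute[0] through age_impute[176] in the summary below:

# number of missing age entries == length of the age_impute latent vector
print(int(np.isnan(data["age"]).sum()))  # 177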

Sampling

We will use MCMC with the NUTS kernel to sample both the regression coefficients and the imputed values.

[9]:
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), **data, survived=survived)
mcmc.print_summary()
sample: 100%|██████████| 2000/2000 [00:15<00:00, 132.15it/s, 63 steps of size 5.68e-02. acc. prob=0.95]

                     mean       std    median      5.0%     95.0%     n_eff     r_hat
              a      0.12      0.82      0.11     -1.21      1.49    887.50      1.00
  age_impute[0]      0.20      0.84      0.18     -1.22      1.53   1346.09      1.00
  age_impute[1]     -0.06      0.86     -0.08     -1.41      1.26   1057.70      1.00
  age_impute[2]      0.38      0.73      0.39     -0.80      1.58   1570.36      1.00
  age_impute[3]      0.25      0.84      0.23     -0.99      1.86   1027.43      1.00
  age_impute[4]     -0.63      0.91     -0.59     -1.99      0.87   1183.66      1.00
  age_impute[5]      0.21      0.89      0.19     -1.02      1.97   1456.79      1.00
  age_impute[6]      0.45      0.82      0.46     -0.90      1.73   1239.22      1.00
  age_impute[7]     -0.62      0.86     -0.62     -2.13      0.72   1406.09      1.00
  age_impute[8]     -0.13      0.90     -0.14     -1.64      1.38   1905.07      1.00
  age_impute[9]      0.24      0.84      0.26     -1.06      1.77   1471.12      1.00
 age_impute[10]      0.20      0.89      0.21     -1.26      1.65   1588.79      1.00
 age_impute[11]      0.17      0.91      0.19     -1.59      1.48   1446.52      1.00
 age_impute[12]     -0.65      0.89     -0.68     -2.12      0.77   1457.47      1.00
 age_impute[13]      0.21      0.85      0.18     -1.24      1.53   1057.77      1.00
 age_impute[14]      0.05      0.92      0.05     -1.40      1.65   1207.08      1.00
 age_impute[15]      0.37      0.94      0.37     -1.02      1.98   1326.55      1.00
 age_impute[16]     -1.74      0.26     -1.74     -2.13     -1.32   1320.08      1.00
 age_impute[17]      0.21      0.89      0.22     -1.30      1.60   1545.73      1.00
 age_impute[18]      0.18      0.90      0.18     -1.26      1.58   2013.12      1.00
 age_impute[19]     -0.67      0.86     -0.66     -1.97      0.85   1499.50      1.00
 age_impute[20]      0.23      0.89      0.27     -1.19      1.71   1712.24      1.00
 age_impute[21]      0.21      0.87      0.20     -1.11      1.68   1400.55      1.00
 age_impute[22]      0.19      0.90      0.18     -1.26      1.63   1400.37      1.00
 age_impute[23]     -0.15      0.85     -0.15     -1.57      1.24   1205.10      1.00
 age_impute[24]     -0.71      0.89     -0.73     -2.05      0.82   1085.52      1.00
 age_impute[25]      0.20      0.85      0.19     -1.20      1.62   1708.01      1.00
 age_impute[26]      0.21      0.88      0.21     -1.20      1.68   1363.75      1.00
 age_impute[27]     -0.69      0.91     -0.73     -2.20      0.77   1224.06      1.00
 age_impute[28]      0.60      0.77      0.60     -0.61      1.95   1312.44      1.00
 age_impute[29]      0.20      0.89      0.17     -1.23      1.71    938.19      1.00
 age_impute[30]      0.24      0.87      0.23     -1.14      1.60   1324.50      1.00
 age_impute[31]     -1.72      0.26     -1.72     -2.11     -1.28   1425.46      1.00
 age_impute[32]      0.44      0.77      0.43     -0.83      1.58   1587.41      1.00
 age_impute[33]      0.34      0.89      0.32     -1.14      1.73   1375.14      1.00
 age_impute[34]     -1.72      0.26     -1.71     -2.11     -1.26   1007.71      1.00
 age_impute[35]     -0.45      0.90     -0.47     -2.06      0.92   1329.44      1.00
 age_impute[36]      0.30      0.84      0.30     -1.03      1.73   1080.80      1.00
 age_impute[37]      0.33      0.88      0.32     -1.10      1.81   1033.30      1.00
 age_impute[38]      0.33      0.76      0.35     -0.94      1.56   1550.68      1.00
 age_impute[39]      0.19      0.93      0.21     -1.32      1.82   1203.79      1.00
 age_impute[40]     -0.67      0.88     -0.69     -1.94      0.88   1382.98      1.00
 age_impute[41]      0.17      0.89      0.14     -1.30      1.43   1438.18      1.00
 age_impute[42]      0.23      0.82      0.25     -1.12      1.48   1499.59      1.00
 age_impute[43]      0.22      0.82      0.21     -1.19      1.45   1236.67      1.00
 age_impute[44]     -0.41      0.85     -0.42     -1.96      0.78    812.53      1.00
 age_impute[45]     -0.36      0.89     -0.35     -2.01      0.94   1488.83      1.00
 age_impute[46]     -0.33      0.91     -0.32     -1.76      1.27   1628.61      1.00
 age_impute[47]     -0.71      0.85     -0.69     -2.12      0.64   1363.89      1.00
 age_impute[48]      0.21      0.85      0.24     -1.21      1.64   1552.65      1.00
 age_impute[49]      0.42      0.82      0.41     -0.83      1.77    754.08      1.00
 age_impute[50]      0.26      0.86      0.24     -1.18      1.63   1155.49      1.00
 age_impute[51]     -0.29      0.91     -0.30     -1.83      1.15   1212.08      1.00
 age_impute[52]      0.36      0.85      0.34     -1.12      1.68   1190.99      1.00
 age_impute[53]     -0.68      0.89     -0.65     -2.09      0.75   1104.75      1.00
 age_impute[54]      0.27      0.90      0.25     -1.24      1.68   1331.19      1.00
 age_impute[55]      0.36      0.89      0.36     -0.96      1.86   1917.52      1.00
 age_impute[56]      0.38      0.86      0.40     -1.00      1.75   1862.00      1.00
 age_impute[57]      0.01      0.91      0.03     -1.33      1.56   1285.43      1.00
 age_impute[58]     -0.69      0.91     -0.66     -2.13      0.78   1438.41      1.00
 age_impute[59]     -0.14      0.85     -0.16     -1.44      1.37   1135.79      1.00
 age_impute[60]     -0.59      0.94     -0.61     -2.19      0.93   1222.88      1.00
 age_impute[61]      0.24      0.92      0.25     -1.35      1.65   1341.95      1.00
 age_impute[62]     -0.55      0.91     -0.57     -2.01      0.96    753.85      1.00
 age_impute[63]      0.21      0.90      0.19     -1.42      1.60   1238.50      1.00
 age_impute[64]     -0.66      0.88     -0.68     -2.04      0.73   1214.85      1.00
 age_impute[65]      0.44      0.78      0.48     -0.93      1.57   1174.41      1.00
 age_impute[66]      0.22      0.94      0.20     -1.35      1.69   1910.00      1.00
 age_impute[67]      0.33      0.76      0.34     -0.85      1.63   1210.24      1.00
 age_impute[68]      0.31      0.84      0.33     -1.08      1.60   1756.60      1.00
 age_impute[69]      0.26      0.91      0.25     -1.29      1.75   1155.87      1.00
 age_impute[70]     -0.67      0.86     -0.70     -2.02      0.70   1186.22      1.00
 age_impute[71]     -0.70      0.90     -0.69     -2.21      0.75   1469.35      1.00
 age_impute[72]      0.24      0.86      0.24     -1.07      1.66   1604.16      1.00
 age_impute[73]      0.34      0.72      0.35     -0.77      1.55   1144.55      1.00
 age_impute[74]     -0.64      0.85     -0.64     -2.10      0.77   1513.79      1.00
 age_impute[75]      0.41      0.78      0.42     -0.96      1.60    796.47      1.00
 age_impute[76]      0.18      0.89      0.21     -1.19      1.74    755.44      1.00
 age_impute[77]      0.21      0.84      0.22     -1.22      1.63   1371.73      1.00
 age_impute[78]     -0.36      0.87     -0.33     -1.81      1.01   1017.23      1.00
 age_impute[79]      0.20      0.84      0.19     -1.35      1.37   1677.57      1.00
 age_impute[80]      0.23      0.84      0.24     -1.09      1.61   1545.61      1.00
 age_impute[81]      0.28      0.90      0.32     -1.08      1.83   1735.91      1.00
 age_impute[82]      0.61      0.80      0.60     -0.61      2.03   1353.67      1.00
 age_impute[83]      0.24      0.89      0.26     -1.22      1.66   1165.03      1.00
 age_impute[84]      0.21      0.91      0.21     -1.35      1.65   1584.00      1.00
 age_impute[85]      0.24      0.92      0.21     -1.33      1.63   1271.37      1.00
 age_impute[86]      0.31      0.81      0.30     -0.86      1.76   1198.70      1.00
 age_impute[87]     -0.11      0.84     -0.10     -1.42      1.23   1248.38      1.00
 age_impute[88]      0.21      0.94      0.22     -1.31      1.77   1082.82      1.00
 age_impute[89]      0.24      0.86      0.23     -1.08      1.67   2141.98      1.00
 age_impute[90]      0.41      0.84      0.45     -0.88      1.90   1518.73      1.00
 age_impute[91]      0.21      0.86      0.20     -1.21      1.58   1723.50      1.00
 age_impute[92]      0.21      0.84      0.20     -1.21      1.57   1742.44      1.00
 age_impute[93]      0.22      0.87      0.23     -1.29      1.50   1359.74      1.00
 age_impute[94]      0.22      0.87      0.18     -1.09      1.70    906.55      1.00
 age_impute[95]      0.22      0.87      0.23     -1.16      1.65   1112.58      1.00
 age_impute[96]      0.30      0.84      0.26     -1.18      1.57   1680.70      1.00
 age_impute[97]      0.23      0.87      0.25     -1.22      1.63   1408.40      1.00
 age_impute[98]     -0.36      0.91     -0.37     -1.96      1.03   1083.67      1.00
 age_impute[99]      0.15      0.87      0.14     -1.22      1.61   1644.46      1.00
age_impute[100]      0.27      0.85      0.30     -1.27      1.45   1266.96      1.00
age_impute[101]      0.25      0.87      0.25     -1.19      1.57   1220.96      1.00
age_impute[102]     -0.29      0.85     -0.28     -1.70      1.10   1392.91      1.00
age_impute[103]      0.01      0.89      0.01     -1.46      1.39   1137.34      1.00
age_impute[104]      0.21      0.86      0.24     -1.16      1.64   1018.70      1.00
age_impute[105]      0.24      0.93      0.21     -1.14      1.90   1479.67      1.00
age_impute[106]      0.21      0.83      0.21     -1.09      1.55   1471.11      1.00
age_impute[107]      0.22      0.85      0.22     -1.09      1.64   1941.83      1.00
age_impute[108]      0.31      0.88      0.30     -1.10      1.76   1342.10      1.00
age_impute[109]      0.22      0.86      0.23     -1.25      1.56   1198.01      1.00
age_impute[110]      0.33      0.78      0.35     -0.95      1.62   1267.01      1.00
age_impute[111]      0.22      0.88      0.21     -1.11      1.71   1404.51      1.00
age_impute[112]     -0.03      0.90     -0.02     -1.38      1.55   1625.35      1.00
age_impute[113]      0.24      0.85      0.23     -1.17      1.62   1361.84      1.00
age_impute[114]      0.36      0.86      0.37     -0.99      1.76   1155.67      1.00
age_impute[115]      0.26      0.96      0.28     -1.37      1.81   1245.97      1.00
age_impute[116]      0.21      0.86      0.24     -1.18      1.69   1565.59      1.00
age_impute[117]     -0.31      0.94     -0.33     -1.91      1.19   1593.65      1.00
age_impute[118]      0.21      0.87      0.22     -1.20      1.64   1315.42      1.00
age_impute[119]     -0.69      0.88     -0.74     -2.00      0.90   1536.44      1.00
age_impute[120]      0.63      0.81      0.66     -0.65      1.89    899.61      1.00
age_impute[121]      0.27      0.90      0.26     -1.16      1.74   1744.32      1.00
age_impute[122]      0.18      0.87      0.18     -1.23      1.60   1625.58      1.00
age_impute[123]     -0.39      0.88     -0.38     -1.71      1.12   1266.58      1.00
age_impute[124]     -0.62      0.95     -0.63     -2.03      1.01   1600.28      1.00
age_impute[125]      0.23      0.88      0.23     -1.15      1.71   1604.27      1.00
age_impute[126]      0.18      0.91      0.18     -1.24      1.63   1527.38      1.00
age_impute[127]      0.32      0.85      0.36     -1.08      1.73   1074.98      1.00
age_impute[128]      0.25      0.88      0.25     -1.10      1.69   1486.79      1.00
age_impute[129]     -0.70      0.87     -0.68     -2.20      0.56   1506.55      1.00
age_impute[130]      0.21      0.88      0.20     -1.16      1.68   1451.63      1.00
age_impute[131]      0.22      0.87      0.23     -1.22      1.61    905.86      1.00
age_impute[132]      0.33      0.83      0.33     -1.01      1.66   1517.67      1.00
age_impute[133]      0.18      0.86      0.18     -1.19      1.59   1050.00      1.00
age_impute[134]     -0.14      0.92     -0.15     -1.77      1.24   1386.20      1.00
age_impute[135]      0.19      0.85      0.18     -1.22      1.53   1290.94      1.00
age_impute[136]      0.16      0.92      0.16     -1.35      1.74   1767.36      1.00
age_impute[137]     -0.71      0.90     -0.68     -2.24      0.82   1154.14      1.00
age_impute[138]      0.18      0.91      0.16     -1.30      1.67   1160.90      1.00
age_impute[139]      0.24      0.90      0.24     -1.15      1.76   1289.37      1.00
age_impute[140]      0.41      0.80      0.39     -1.05      1.53   1532.92      1.00
age_impute[141]      0.27      0.83      0.29     -1.04      1.60   1310.29      1.00
age_impute[142]     -0.28      0.89     -0.29     -1.68      1.22   1088.65      1.00
age_impute[143]     -0.12      0.91     -0.11     -1.56      1.40   1324.74      1.00
age_impute[144]     -0.65      0.87     -0.63     -1.91      0.93   1672.31      1.00
age_impute[145]     -1.73      0.26     -1.74     -2.11     -1.26   1502.96      1.00
age_impute[146]      0.40      0.85      0.40     -0.85      1.84   1443.81      1.00
age_impute[147]      0.23      0.87      0.20     -1.37      1.49   1220.62      1.00
age_impute[148]     -0.70      0.88     -0.70     -2.08      0.87   1846.67      1.00
age_impute[149]      0.27      0.87      0.29     -1.11      1.76   1451.79      1.00
age_impute[150]      0.21      0.90      0.20     -1.10      1.78   1409.94      1.00
age_impute[151]      0.25      0.87      0.26     -1.21      1.63   1224.08      1.00
age_impute[152]      0.05      0.85      0.05     -1.42      1.39   1164.23      1.00
age_impute[153]      0.18      0.90      0.15     -1.19      1.72   1697.92      1.00
age_impute[154]      1.05      0.93      1.04     -0.24      2.84   1212.82      1.00
age_impute[155]      0.20      0.84      0.18     -1.18      1.54   1398.45      1.00
age_impute[156]      0.23      0.95      0.19     -1.19      1.87   1773.79      1.00
age_impute[157]      0.19      0.85      0.22     -1.13      1.64   1123.21      1.00
age_impute[158]      0.22      0.86      0.22     -1.18      1.60   1307.64      1.00
age_impute[159]      0.18      0.84      0.18     -1.09      1.59   1499.97      1.00
age_impute[160]      0.24      0.89      0.28     -1.23      1.65   1100.08      1.00
age_impute[161]     -0.45      0.88     -0.45     -1.86      1.05   1414.97      1.00
age_impute[162]      0.39      0.89      0.40     -1.00      1.87   1525.80      1.00
age_impute[163]      0.34      0.89      0.35     -1.14      1.75   1600.03      1.00
age_impute[164]      0.21      0.94      0.19     -1.13      1.91   1090.05      1.00
age_impute[165]      0.22      0.85      0.20     -1.11      1.60   1330.87      1.00
age_impute[166]     -0.13      0.91     -0.15     -1.69      1.28   1284.90      1.00
age_impute[167]      0.22      0.89      0.24     -1.15      1.76   1261.93      1.00
age_impute[168]      0.20      0.90      0.18     -1.18      1.83   1217.16      1.00
age_impute[169]      0.07      0.89      0.05     -1.29      1.60   2007.16      1.00
age_impute[170]      0.23      0.90      0.24     -1.25      1.67    937.57      1.00
age_impute[171]      0.41      0.80      0.42     -0.82      1.82   1404.02      1.00
age_impute[172]      0.23      0.87      0.20     -1.33      1.51   2032.72      1.00
age_impute[173]     -0.44      0.88     -0.44     -1.81      1.08   1006.62      1.00
age_impute[174]      0.19      0.84      0.19     -1.11      1.63   1495.21      1.00
age_impute[175]      0.20      0.85      0.20     -1.17      1.63   1551.22      1.00
age_impute[176]     -0.43      0.92     -0.44     -1.83      1.21   1477.58      1.00
      age_mu[0]      0.19      0.04      0.19      0.12      0.26    749.16      1.00
      age_mu[1]     -0.54      0.07     -0.54     -0.66     -0.42    786.30      1.00
      age_mu[2]      0.43      0.08      0.42      0.31      0.55   1134.72      1.00
      age_mu[3]     -1.73      0.04     -1.73     -1.79     -1.65   1194.53      1.00
      age_mu[4]      0.85      0.17      0.85      0.58      1.13   1111.96      1.00
   age_sigma[0]      0.88      0.03      0.88      0.82      0.93    766.67      1.00
   age_sigma[1]      0.90      0.06      0.90      0.81      0.99    992.72      1.00
   age_sigma[2]      0.79      0.05      0.78      0.71      0.87    708.34      1.00
   age_sigma[3]      0.26      0.03      0.25      0.20      0.31    959.62      1.00
   age_sigma[4]      0.93      0.13      0.93      0.74      1.15   1092.88      1.00
          b_Age     -0.45      0.14     -0.44     -0.66     -0.22    744.95      1.00
  b_Embarked[0]     -0.28      0.58     -0.30     -1.28      0.64    496.51      1.00
  b_Embarked[1]      0.30      0.60      0.29     -0.74      1.20    495.25      1.00
  b_Embarked[2]      0.04      0.61      0.03     -0.93      1.02    482.67      1.00
     b_Parch[0]      0.45      0.57      0.47     -0.45      1.42    336.02      1.02
     b_Parch[1]      0.12      0.58      0.14     -0.91      1.00    377.61      1.02
     b_Parch[2]     -0.49      0.58     -0.45     -1.48      0.41    358.61      1.01
    b_Pclass[0]      1.22      0.57      1.24      0.33      2.17    371.15      1.00
    b_Pclass[1]      0.06      0.57      0.07     -0.84      1.03    369.58      1.00
    b_Pclass[2]     -1.18      0.57     -1.16     -2.18     -0.31    373.55      1.00
       b_Sex[0]      1.15      0.74      1.18     -0.03      2.31    568.65      1.00
       b_Sex[1]     -1.05      0.74     -1.02     -2.18      0.21    709.29      1.00
     b_SibSp[0]      0.28      0.66      0.26     -0.86      1.25    585.03      1.00
     b_SibSp[1]     -0.17      0.67     -0.18     -1.28      0.87    596.44      1.00
     b_Title[0]     -0.94      0.54     -0.96     -1.86     -0.11    437.32      1.00
     b_Title[1]     -0.33      0.61     -0.33     -1.32      0.60    570.32      1.00
     b_Title[2]      0.53      0.62      0.53     -0.52      1.46    452.87      1.00
     b_Title[3]      1.48      0.59      1.48      0.60      2.48    562.71      1.00
     b_Title[4]     -0.68      0.58     -0.66     -1.71      0.15    472.57      1.00

Number of divergences: 0

To double-check that the assumption "age is correlated with title" is reasonable, let's look at the inferred age by title. Recall that we standardized age, so here we need to scale back to the original domain.

[10]:
age_by_title = age_mean + age_std * mcmc.get_samples()["age_mu"].mean(axis=0)
dict(zip(title_cat.categories, age_by_title))
[10]:
{'Mr.': 32.434227,
 'Miss.': 21.763992,
 'Mrs.': 35.852997,
 'Master.': 4.6297398,
 'Misc.': 42.081936}

The inferred result confirms our assumption that Age is correlated with Title:

  • those with the Master. title have a pretty small age (in other words, they are children on the ship) compared to the other groups,

  • those with the Mrs. title are older than those with the Miss. title (on average).

We can also see that the result is similar to the empirical mean of Age given Title in our training dataset:

[11]:
train_df.groupby("Title")["Age"].mean()
[11]:
Title
Master.     4.574167
Misc.      42.384615
Miss.      21.773973
Mr.        32.368090
Mrs.       35.898148
Name: Age, dtype: float64

So far so good. We now have a lot of information about the regression coefficients, together with the imputed values and their uncertainties. Let's inspect those results a bit:

  • The mean value -0.45 of b_Age implies that passengers with smaller ages have a better chance of surviving (see the odds-ratio sketch below).

  • The mean value (1.15, -1.05) of b_Sex implies that female passengers have a higher chance of surviving than male passengers.
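
Since age was standardized, we can translate b_Age into an odds ratio to make it more tangible. A back-of-the-envelope sketch using the posterior mean printed above:

# a one-standard-deviation increase in age multiplies the estimated
# odds of survival by roughly exp(-0.45) ~ 0.64
print(np.exp(-0.45))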

Prediction

In NumPyro, we can use the Predictive utility to make predictions from posterior samples. Let's check how well the model performs on the training dataset. For simplicity, we will get a survived prediction for each posterior sample and then take a majority vote over those predictions.

[12]:
posterior = mcmc.get_samples()
survived_pred = Predictive(model, posterior)(random.PRNGKey(1), **data)["survived"]
survived_pred = (survived_pred.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
    pd.Series(survived, name="actual"), pd.Series(survived_pred, name="predict")
)
# normalize each row by the number of actual cases in that class
confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)
Accuracy: 0.8271605
[12]:
predict 0 1
actual
0 0.876138 0.123862
1 0.251462 0.748538

This is a pretty good result using a simple logistic regression model. Let’s see how the model performs if we don’t use Bayesian imputation here.

[13]:
mcmc.run(random.PRNGKey(2), **data, survived=survived, bayesian_impute=False)
posterior_1 = mcmc.get_samples()
survived_pred_1 = Predictive(model, posterior_1)(random.PRNGKey(2), **data)["survived"]
survived_pred_1 = (survived_pred_1.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred_1 == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
    pd.Series(survived, name="actual"), pd.Series(survived_pred_1, name="predict")
)
# normalize each row by the number of actual cases in that class
confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)
sample: 100%|██████████| 2000/2000 [00:11<00:00, 166.79it/s, 63 steps of size 7.18e-02. acc. prob=0.93]
Accuracy: 0.82042646
[13]:
predict 0 1
actual
0 0.872495 0.127505
1 0.263158 0.736842

We can see that the model with Bayesian imputation does slightly better here (accuracy 0.827 vs. 0.820).

Remark. When using posterior samples to make predictions on new data, we need to marginalize out age_impute because those imputed values are specific to the training data:

posterior.pop("age_impute")
survived_pred = Predictive(model, posterior)(random.PRNGKey(3), **new_data)
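
With age_impute removed from the posterior, Predictive draws fresh imputed values from dist.Normal(age_mu, age_sigma) for each posterior sample (the mask(False) flag only zeroes out the log density; it does not affect sampling), which gives us the desired marginalization over the missing entries.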

References

  1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan.

  2. Kaggle competition: Titanic: Machine Learning from Disaster