Bayesian Imputation
Real-world datasets often contain many missing values. In those situations, we have to either remove the rows with missing data (also known as the "complete case" approach) or replace the missing entries with some values. Though complete-case analysis is straightforward, it is only applicable when the number of missing entries is small enough that discarding them does not materially reduce the power of the analysis we are conducting on the data. The second strategy, known as imputation, is more broadly applicable and will be our focus in this tutorial.
Probably the most popular way to perform imputation is to fill a missing value with the mean, median, or mode of its corresponding feature (a minimal sketch of this baseline follows the list below). Doing so implicitly assumes that the feature containing missing values is uncorrelated with the remaining features of the dataset. This is a pretty strong assumption and might not hold in general. In addition, it does not encode any uncertainty we might have about those filled-in values. Below, we will construct a Bayesian setting to resolve those issues. In particular, given a model on the dataset, we will
- create a generative model for the feature with missing values,
- and consider the missing values as unobserved latent variables.
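For contrast, here is a minimal sketch of the single-value baseline described above, applied to a toy dataframe (the data here is hypothetical, only to illustrate the mechanics):

import numpy as np
import pandas as pd

toy_df = pd.DataFrame({"age": [22.0, 38.0, np.nan, 35.0]})
# every missing entry receives the same fill value, ignoring the other
# features entirely and carrying no uncertainty
toy_df["age_mean"] = toy_df["age"].fillna(toy_df["age"].mean())
toy_df["age_median"] = toy_df["age"].fillna(toy_df["age"].median())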
[ ]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
# first, we need some imports
import os
from IPython.display import set_matplotlib_formats
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from jax import numpy as jnp, random
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
plt.style.use("seaborn")
if "NUMPYRO_SPHINXBUILD" in os.environ:
set_matplotlib_formats("svg")
assert numpyro.__version__.startswith("0.16.0")
Dataset
The data is taken from the competition Titanic: Machine Learning from Disaster hosted on Kaggle. It contains information about passengers in the Titanic accident, such as name, age, and gender. Our target is to predict whether a person survived.
[2]:
train_df = pd.read_csv(
"https://raw.githubusercontent.com/agconti/kaggle-titanic/master/data/train.csv"
)
train_df.info()
train_df.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
[2]:
|   | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
Looking at the data info, we see that there are missing values in the Age, Cabin, and Embarked columns. Although Cabin is an important feature (the position of a cabin in the ship can affect the chance that the people in that cabin survive), we will skip it in this tutorial for simplicity. The dataset contains many categorical columns and two numerical columns, Age and Fare. Let's first look at the distribution of those categorical columns:
[3]:
for col in ["Survived", "Pclass", "Sex", "SibSp", "Parch", "Embarked"]:
print(train_df[col].value_counts(), end="\n\n")
0 549
1 342
Name: Survived, dtype: int64
3 491
1 216
2 184
Name: Pclass, dtype: int64
male 577
female 314
Name: Sex, dtype: int64
0 608
1 209
2 28
4 18
3 16
8 7
5 5
Name: SibSp, dtype: int64
0 678
1 118
2 80
3 5
5 5
4 4
6 1
Name: Parch, dtype: int64
S 644
C 168
Q 77
Name: Embarked, dtype: int64
Prepare data
First, we will merge the rare groups in the SibSp and Parch columns. In addition, we'll fill the 2 missing entries in Embarked with the mode S. Note that we could also build a generative model for those missing entries in Embarked, but let's skip doing so for simplicity.
[4]:
# merge rare groups in SibSp/Parch and fill the 2 missing Embarked entries
train_df["SibSp"] = train_df.SibSp.clip(0, 1)
train_df["Parch"] = train_df.Parch.clip(0, 2)
train_df["Embarked"] = train_df.Embarked.fillna("S")
Looking closer at the data, we can observe that each name contains a title. We know that age is correlated with the title of a name: e.g., those with the Mrs. title are, on average, older than those with the Miss. title, so it might be good to create that feature. The distribution of titles is:
[5]:
train_df.Name.str.split(", ").str.get(1).str.split(" ").str.get(0).value_counts()
[5]:
Mr. 517
Miss. 182
Mrs. 125
Master. 40
Dr. 7
Rev. 6
Mlle. 2
Col. 2
Major. 2
Lady. 1
Sir. 1
the 1
Ms. 1
Capt. 1
Mme. 1
Jonkheer. 1
Don. 1
Name: Name, dtype: int64
We will make a new column Title, where rare titles are merged into one group, Misc.
[6]:
train_df["Title"] = (
train_df.Name.str.split(", ")
.str.get(1)
.str.split(" ")
.str.get(0)
.apply(lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc.")
)
Now we are ready to turn the dataframe, which includes categorical values, into numpy arrays. We also standardize the Age column (a good practice for regression models).
[7]:
title_cat = pd.CategoricalDtype(
categories=["Mr.", "Miss.", "Mrs.", "Master.", "Misc."], ordered=True
)
embarked_cat = pd.CategoricalDtype(categories=["S", "C", "Q"], ordered=True)
age_mean, age_std = train_df.Age.mean(), train_df.Age.std()
data = dict(
age=train_df.Age.pipe(lambda x: (x - age_mean) / age_std).values,
pclass=train_df.Pclass.values - 1,
title=train_df.Title.astype(title_cat).cat.codes.values,
sex=(train_df.Sex == "male").astype(int).values,
sibsp=train_df.SibSp.values,
parch=train_df.Parch.values,
embarked=train_df.Embarked.astype(embarked_cat).cat.codes.values,
)
survived = train_df.Survived.values
# compute the age mean for each title
age_notnan = data["age"][jnp.isfinite(data["age"])]
title_notnan = data["title"][jnp.isfinite(data["age"])]
age_mean_by_title = jnp.stack([age_notnan[title_notnan == i].mean() for i in range(5)])
Modelling
First, we want to note that in NumPyro, the following models
def model1a():
x = numpyro.sample("x", dist.Normal(0, 1).expand([10]))
and
def model1b():
x = numpyro.sample("x", dist.Normal(0, 1).expand([10]).mask(False))
numpyro.sample("x_obs", dist.Normal(0, 1).expand([10]), obs=x)
are equivalent in the sense that both of them have

- the same latent site x drawn from the dist.Normal(0, 1) prior,
- and the same log density dist.Normal(0, 1).log_prob(x).
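We can check this equivalence numerically with numpyro.infer.util.log_density, which returns the joint log density of a model for a given set of latent values; a minimal sketch:

from jax import numpy as jnp, random
import numpyro
from numpyro import distributions as dist
from numpyro.infer.util import log_density

def model1a():
    numpyro.sample("x", dist.Normal(0, 1).expand([10]))

def model1b():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]).mask(False))
    numpyro.sample("x_obs", dist.Normal(0, 1).expand([10]), obs=x)

x_val = random.normal(random.PRNGKey(0), (10,))
# both joint densities reduce to dist.Normal(0, 1).log_prob(x_val).sum()
logp_a, _ = log_density(model1a, (), {}, {"x": x_val})
logp_b, _ = log_density(model1b, (), {}, {"x": x_val})
assert jnp.allclose(logp_a, logp_b)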
Now, assume that we observed the last 6 values of x (the non-observed entries take value NaN); the typical model will be
def model2a(x):
x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]))
x_obs = numpyro.sample("x_obs", dist.Normal(0, 1).expand([6]), obs=x[4:])
x_imputed = jnp.concatenate([x_impute, x_obs])
or, with the usage of mask,
def model2b(x):
x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
x_imputed = jnp.concatenate([x_impute, x[4:]])
numpyro.sample("x", dist.Normal(0, 1).expand([10]), obs=x_imputed)
Both approaches to modeling the partially observed data x are equivalent. For the model below, we will use the latter method.
[8]:
def model(
age, pclass, title, sex, sibsp, parch, embarked, survived=None, bayesian_impute=True
):
b_pclass = numpyro.sample("b_Pclass", dist.Normal(0, 1).expand([3]))
b_title = numpyro.sample("b_Title", dist.Normal(0, 1).expand([5]))
b_sex = numpyro.sample("b_Sex", dist.Normal(0, 1).expand([2]))
b_sibsp = numpyro.sample("b_SibSp", dist.Normal(0, 1).expand([2]))
b_parch = numpyro.sample("b_Parch", dist.Normal(0, 1).expand([3]))
b_embarked = numpyro.sample("b_Embarked", dist.Normal(0, 1).expand([3]))
# impute age by Title
isnan = np.isnan(age)
age_nanidx = np.nonzero(isnan)[0]
if bayesian_impute:
age_mu = numpyro.sample("age_mu", dist.Normal(0, 1).expand([5]))
age_mu = age_mu[title]
age_sigma = numpyro.sample("age_sigma", dist.Normal(0, 1).expand([5]))
age_sigma = age_sigma[title]
age_impute = numpyro.sample(
"age_impute",
dist.Normal(age_mu[age_nanidx], age_sigma[age_nanidx]).mask(False),
)
age = jnp.asarray(age).at[age_nanidx].set(age_impute)
numpyro.sample("age", dist.Normal(age_mu, age_sigma), obs=age)
else:
# fill missing data by the mean of ages for each title
age_impute = age_mean_by_title[title][age_nanidx]
age = jnp.asarray(age).at[age_nanidx].set(age_impute)
a = numpyro.sample("a", dist.Normal(0, 1))
b_age = numpyro.sample("b_Age", dist.Normal(0, 1))
logits = a + b_age * age
logits = logits + b_title[title] + b_pclass[pclass] + b_sex[sex]
logits = logits + b_sibsp[sibsp] + b_parch[parch] + b_embarked[embarked]
numpyro.sample("survived", dist.Bernoulli(logits=logits), obs=survived)
Note that in the model, the prior for age is dist.Normal(age_mu, age_sigma), where the values of age_mu and age_sigma depend on title. Because there are missing values in age, we will encode those missing values in the latent parameter age_impute. Then we can replace the NaN entries in age with the vector age_impute.
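Because jnp arrays are immutable, this replacement uses JAX's functional index-update syntax .at[idx].set(values) instead of in-place assignment. A small standalone illustration (with made-up numbers):

import numpy as np
from jax import numpy as jnp

age = np.array([22.0, np.nan, 26.0, np.nan])
nanidx = np.nonzero(np.isnan(age))[0]   # positions of the missing entries
imputed = jnp.array([30.0, 40.0])       # stand-ins for samples of age_impute
filled = jnp.asarray(age).at[nanidx].set(imputed)
# filled is [22., 30., 26., 40.]; the original array is left unchanged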
Sampling
We will use MCMC with NUTS kernel to sample both regression coefficients and imputed values.
[9]:
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), **data, survived=survived)
mcmc.print_summary()
sample: 100%|██████████| 2000/2000 [00:15<00:00, 132.15it/s, 63 steps of size 5.68e-02. acc. prob=0.95]
mean std median 5.0% 95.0% n_eff r_hat
a 0.12 0.82 0.11 -1.21 1.49 887.50 1.00
age_impute[0] 0.20 0.84 0.18 -1.22 1.53 1346.09 1.00
age_impute[1] -0.06 0.86 -0.08 -1.41 1.26 1057.70 1.00
age_impute[2] 0.38 0.73 0.39 -0.80 1.58 1570.36 1.00
age_impute[3] 0.25 0.84 0.23 -0.99 1.86 1027.43 1.00
age_impute[4] -0.63 0.91 -0.59 -1.99 0.87 1183.66 1.00
age_impute[5] 0.21 0.89 0.19 -1.02 1.97 1456.79 1.00
age_impute[6] 0.45 0.82 0.46 -0.90 1.73 1239.22 1.00
age_impute[7] -0.62 0.86 -0.62 -2.13 0.72 1406.09 1.00
age_impute[8] -0.13 0.90 -0.14 -1.64 1.38 1905.07 1.00
age_impute[9] 0.24 0.84 0.26 -1.06 1.77 1471.12 1.00
age_impute[10] 0.20 0.89 0.21 -1.26 1.65 1588.79 1.00
age_impute[11] 0.17 0.91 0.19 -1.59 1.48 1446.52 1.00
age_impute[12] -0.65 0.89 -0.68 -2.12 0.77 1457.47 1.00
age_impute[13] 0.21 0.85 0.18 -1.24 1.53 1057.77 1.00
age_impute[14] 0.05 0.92 0.05 -1.40 1.65 1207.08 1.00
age_impute[15] 0.37 0.94 0.37 -1.02 1.98 1326.55 1.00
age_impute[16] -1.74 0.26 -1.74 -2.13 -1.32 1320.08 1.00
age_impute[17] 0.21 0.89 0.22 -1.30 1.60 1545.73 1.00
age_impute[18] 0.18 0.90 0.18 -1.26 1.58 2013.12 1.00
age_impute[19] -0.67 0.86 -0.66 -1.97 0.85 1499.50 1.00
age_impute[20] 0.23 0.89 0.27 -1.19 1.71 1712.24 1.00
age_impute[21] 0.21 0.87 0.20 -1.11 1.68 1400.55 1.00
age_impute[22] 0.19 0.90 0.18 -1.26 1.63 1400.37 1.00
age_impute[23] -0.15 0.85 -0.15 -1.57 1.24 1205.10 1.00
age_impute[24] -0.71 0.89 -0.73 -2.05 0.82 1085.52 1.00
age_impute[25] 0.20 0.85 0.19 -1.20 1.62 1708.01 1.00
age_impute[26] 0.21 0.88 0.21 -1.20 1.68 1363.75 1.00
age_impute[27] -0.69 0.91 -0.73 -2.20 0.77 1224.06 1.00
age_impute[28] 0.60 0.77 0.60 -0.61 1.95 1312.44 1.00
age_impute[29] 0.20 0.89 0.17 -1.23 1.71 938.19 1.00
age_impute[30] 0.24 0.87 0.23 -1.14 1.60 1324.50 1.00
age_impute[31] -1.72 0.26 -1.72 -2.11 -1.28 1425.46 1.00
age_impute[32] 0.44 0.77 0.43 -0.83 1.58 1587.41 1.00
age_impute[33] 0.34 0.89 0.32 -1.14 1.73 1375.14 1.00
age_impute[34] -1.72 0.26 -1.71 -2.11 -1.26 1007.71 1.00
age_impute[35] -0.45 0.90 -0.47 -2.06 0.92 1329.44 1.00
age_impute[36] 0.30 0.84 0.30 -1.03 1.73 1080.80 1.00
age_impute[37] 0.33 0.88 0.32 -1.10 1.81 1033.30 1.00
age_impute[38] 0.33 0.76 0.35 -0.94 1.56 1550.68 1.00
age_impute[39] 0.19 0.93 0.21 -1.32 1.82 1203.79 1.00
age_impute[40] -0.67 0.88 -0.69 -1.94 0.88 1382.98 1.00
age_impute[41] 0.17 0.89 0.14 -1.30 1.43 1438.18 1.00
age_impute[42] 0.23 0.82 0.25 -1.12 1.48 1499.59 1.00
age_impute[43] 0.22 0.82 0.21 -1.19 1.45 1236.67 1.00
age_impute[44] -0.41 0.85 -0.42 -1.96 0.78 812.53 1.00
age_impute[45] -0.36 0.89 -0.35 -2.01 0.94 1488.83 1.00
age_impute[46] -0.33 0.91 -0.32 -1.76 1.27 1628.61 1.00
age_impute[47] -0.71 0.85 -0.69 -2.12 0.64 1363.89 1.00
age_impute[48] 0.21 0.85 0.24 -1.21 1.64 1552.65 1.00
age_impute[49] 0.42 0.82 0.41 -0.83 1.77 754.08 1.00
age_impute[50] 0.26 0.86 0.24 -1.18 1.63 1155.49 1.00
age_impute[51] -0.29 0.91 -0.30 -1.83 1.15 1212.08 1.00
age_impute[52] 0.36 0.85 0.34 -1.12 1.68 1190.99 1.00
age_impute[53] -0.68 0.89 -0.65 -2.09 0.75 1104.75 1.00
age_impute[54] 0.27 0.90 0.25 -1.24 1.68 1331.19 1.00
age_impute[55] 0.36 0.89 0.36 -0.96 1.86 1917.52 1.00
age_impute[56] 0.38 0.86 0.40 -1.00 1.75 1862.00 1.00
age_impute[57] 0.01 0.91 0.03 -1.33 1.56 1285.43 1.00
age_impute[58] -0.69 0.91 -0.66 -2.13 0.78 1438.41 1.00
age_impute[59] -0.14 0.85 -0.16 -1.44 1.37 1135.79 1.00
age_impute[60] -0.59 0.94 -0.61 -2.19 0.93 1222.88 1.00
age_impute[61] 0.24 0.92 0.25 -1.35 1.65 1341.95 1.00
age_impute[62] -0.55 0.91 -0.57 -2.01 0.96 753.85 1.00
age_impute[63] 0.21 0.90 0.19 -1.42 1.60 1238.50 1.00
age_impute[64] -0.66 0.88 -0.68 -2.04 0.73 1214.85 1.00
age_impute[65] 0.44 0.78 0.48 -0.93 1.57 1174.41 1.00
age_impute[66] 0.22 0.94 0.20 -1.35 1.69 1910.00 1.00
age_impute[67] 0.33 0.76 0.34 -0.85 1.63 1210.24 1.00
age_impute[68] 0.31 0.84 0.33 -1.08 1.60 1756.60 1.00
age_impute[69] 0.26 0.91 0.25 -1.29 1.75 1155.87 1.00
age_impute[70] -0.67 0.86 -0.70 -2.02 0.70 1186.22 1.00
age_impute[71] -0.70 0.90 -0.69 -2.21 0.75 1469.35 1.00
age_impute[72] 0.24 0.86 0.24 -1.07 1.66 1604.16 1.00
age_impute[73] 0.34 0.72 0.35 -0.77 1.55 1144.55 1.00
age_impute[74] -0.64 0.85 -0.64 -2.10 0.77 1513.79 1.00
age_impute[75] 0.41 0.78 0.42 -0.96 1.60 796.47 1.00
age_impute[76] 0.18 0.89 0.21 -1.19 1.74 755.44 1.00
age_impute[77] 0.21 0.84 0.22 -1.22 1.63 1371.73 1.00
age_impute[78] -0.36 0.87 -0.33 -1.81 1.01 1017.23 1.00
age_impute[79] 0.20 0.84 0.19 -1.35 1.37 1677.57 1.00
age_impute[80] 0.23 0.84 0.24 -1.09 1.61 1545.61 1.00
age_impute[81] 0.28 0.90 0.32 -1.08 1.83 1735.91 1.00
age_impute[82] 0.61 0.80 0.60 -0.61 2.03 1353.67 1.00
age_impute[83] 0.24 0.89 0.26 -1.22 1.66 1165.03 1.00
age_impute[84] 0.21 0.91 0.21 -1.35 1.65 1584.00 1.00
age_impute[85] 0.24 0.92 0.21 -1.33 1.63 1271.37 1.00
age_impute[86] 0.31 0.81 0.30 -0.86 1.76 1198.70 1.00
age_impute[87] -0.11 0.84 -0.10 -1.42 1.23 1248.38 1.00
age_impute[88] 0.21 0.94 0.22 -1.31 1.77 1082.82 1.00
age_impute[89] 0.24 0.86 0.23 -1.08 1.67 2141.98 1.00
age_impute[90] 0.41 0.84 0.45 -0.88 1.90 1518.73 1.00
age_impute[91] 0.21 0.86 0.20 -1.21 1.58 1723.50 1.00
age_impute[92] 0.21 0.84 0.20 -1.21 1.57 1742.44 1.00
age_impute[93] 0.22 0.87 0.23 -1.29 1.50 1359.74 1.00
age_impute[94] 0.22 0.87 0.18 -1.09 1.70 906.55 1.00
age_impute[95] 0.22 0.87 0.23 -1.16 1.65 1112.58 1.00
age_impute[96] 0.30 0.84 0.26 -1.18 1.57 1680.70 1.00
age_impute[97] 0.23 0.87 0.25 -1.22 1.63 1408.40 1.00
age_impute[98] -0.36 0.91 -0.37 -1.96 1.03 1083.67 1.00
age_impute[99] 0.15 0.87 0.14 -1.22 1.61 1644.46 1.00
age_impute[100] 0.27 0.85 0.30 -1.27 1.45 1266.96 1.00
age_impute[101] 0.25 0.87 0.25 -1.19 1.57 1220.96 1.00
age_impute[102] -0.29 0.85 -0.28 -1.70 1.10 1392.91 1.00
age_impute[103] 0.01 0.89 0.01 -1.46 1.39 1137.34 1.00
age_impute[104] 0.21 0.86 0.24 -1.16 1.64 1018.70 1.00
age_impute[105] 0.24 0.93 0.21 -1.14 1.90 1479.67 1.00
age_impute[106] 0.21 0.83 0.21 -1.09 1.55 1471.11 1.00
age_impute[107] 0.22 0.85 0.22 -1.09 1.64 1941.83 1.00
age_impute[108] 0.31 0.88 0.30 -1.10 1.76 1342.10 1.00
age_impute[109] 0.22 0.86 0.23 -1.25 1.56 1198.01 1.00
age_impute[110] 0.33 0.78 0.35 -0.95 1.62 1267.01 1.00
age_impute[111] 0.22 0.88 0.21 -1.11 1.71 1404.51 1.00
age_impute[112] -0.03 0.90 -0.02 -1.38 1.55 1625.35 1.00
age_impute[113] 0.24 0.85 0.23 -1.17 1.62 1361.84 1.00
age_impute[114] 0.36 0.86 0.37 -0.99 1.76 1155.67 1.00
age_impute[115] 0.26 0.96 0.28 -1.37 1.81 1245.97 1.00
age_impute[116] 0.21 0.86 0.24 -1.18 1.69 1565.59 1.00
age_impute[117] -0.31 0.94 -0.33 -1.91 1.19 1593.65 1.00
age_impute[118] 0.21 0.87 0.22 -1.20 1.64 1315.42 1.00
age_impute[119] -0.69 0.88 -0.74 -2.00 0.90 1536.44 1.00
age_impute[120] 0.63 0.81 0.66 -0.65 1.89 899.61 1.00
age_impute[121] 0.27 0.90 0.26 -1.16 1.74 1744.32 1.00
age_impute[122] 0.18 0.87 0.18 -1.23 1.60 1625.58 1.00
age_impute[123] -0.39 0.88 -0.38 -1.71 1.12 1266.58 1.00
age_impute[124] -0.62 0.95 -0.63 -2.03 1.01 1600.28 1.00
age_impute[125] 0.23 0.88 0.23 -1.15 1.71 1604.27 1.00
age_impute[126] 0.18 0.91 0.18 -1.24 1.63 1527.38 1.00
age_impute[127] 0.32 0.85 0.36 -1.08 1.73 1074.98 1.00
age_impute[128] 0.25 0.88 0.25 -1.10 1.69 1486.79 1.00
age_impute[129] -0.70 0.87 -0.68 -2.20 0.56 1506.55 1.00
age_impute[130] 0.21 0.88 0.20 -1.16 1.68 1451.63 1.00
age_impute[131] 0.22 0.87 0.23 -1.22 1.61 905.86 1.00
age_impute[132] 0.33 0.83 0.33 -1.01 1.66 1517.67 1.00
age_impute[133] 0.18 0.86 0.18 -1.19 1.59 1050.00 1.00
age_impute[134] -0.14 0.92 -0.15 -1.77 1.24 1386.20 1.00
age_impute[135] 0.19 0.85 0.18 -1.22 1.53 1290.94 1.00
age_impute[136] 0.16 0.92 0.16 -1.35 1.74 1767.36 1.00
age_impute[137] -0.71 0.90 -0.68 -2.24 0.82 1154.14 1.00
age_impute[138] 0.18 0.91 0.16 -1.30 1.67 1160.90 1.00
age_impute[139] 0.24 0.90 0.24 -1.15 1.76 1289.37 1.00
age_impute[140] 0.41 0.80 0.39 -1.05 1.53 1532.92 1.00
age_impute[141] 0.27 0.83 0.29 -1.04 1.60 1310.29 1.00
age_impute[142] -0.28 0.89 -0.29 -1.68 1.22 1088.65 1.00
age_impute[143] -0.12 0.91 -0.11 -1.56 1.40 1324.74 1.00
age_impute[144] -0.65 0.87 -0.63 -1.91 0.93 1672.31 1.00
age_impute[145] -1.73 0.26 -1.74 -2.11 -1.26 1502.96 1.00
age_impute[146] 0.40 0.85 0.40 -0.85 1.84 1443.81 1.00
age_impute[147] 0.23 0.87 0.20 -1.37 1.49 1220.62 1.00
age_impute[148] -0.70 0.88 -0.70 -2.08 0.87 1846.67 1.00
age_impute[149] 0.27 0.87 0.29 -1.11 1.76 1451.79 1.00
age_impute[150] 0.21 0.90 0.20 -1.10 1.78 1409.94 1.00
age_impute[151] 0.25 0.87 0.26 -1.21 1.63 1224.08 1.00
age_impute[152] 0.05 0.85 0.05 -1.42 1.39 1164.23 1.00
age_impute[153] 0.18 0.90 0.15 -1.19 1.72 1697.92 1.00
age_impute[154] 1.05 0.93 1.04 -0.24 2.84 1212.82 1.00
age_impute[155] 0.20 0.84 0.18 -1.18 1.54 1398.45 1.00
age_impute[156] 0.23 0.95 0.19 -1.19 1.87 1773.79 1.00
age_impute[157] 0.19 0.85 0.22 -1.13 1.64 1123.21 1.00
age_impute[158] 0.22 0.86 0.22 -1.18 1.60 1307.64 1.00
age_impute[159] 0.18 0.84 0.18 -1.09 1.59 1499.97 1.00
age_impute[160] 0.24 0.89 0.28 -1.23 1.65 1100.08 1.00
age_impute[161] -0.45 0.88 -0.45 -1.86 1.05 1414.97 1.00
age_impute[162] 0.39 0.89 0.40 -1.00 1.87 1525.80 1.00
age_impute[163] 0.34 0.89 0.35 -1.14 1.75 1600.03 1.00
age_impute[164] 0.21 0.94 0.19 -1.13 1.91 1090.05 1.00
age_impute[165] 0.22 0.85 0.20 -1.11 1.60 1330.87 1.00
age_impute[166] -0.13 0.91 -0.15 -1.69 1.28 1284.90 1.00
age_impute[167] 0.22 0.89 0.24 -1.15 1.76 1261.93 1.00
age_impute[168] 0.20 0.90 0.18 -1.18 1.83 1217.16 1.00
age_impute[169] 0.07 0.89 0.05 -1.29 1.60 2007.16 1.00
age_impute[170] 0.23 0.90 0.24 -1.25 1.67 937.57 1.00
age_impute[171] 0.41 0.80 0.42 -0.82 1.82 1404.02 1.00
age_impute[172] 0.23 0.87 0.20 -1.33 1.51 2032.72 1.00
age_impute[173] -0.44 0.88 -0.44 -1.81 1.08 1006.62 1.00
age_impute[174] 0.19 0.84 0.19 -1.11 1.63 1495.21 1.00
age_impute[175] 0.20 0.85 0.20 -1.17 1.63 1551.22 1.00
age_impute[176] -0.43 0.92 -0.44 -1.83 1.21 1477.58 1.00
age_mu[0] 0.19 0.04 0.19 0.12 0.26 749.16 1.00
age_mu[1] -0.54 0.07 -0.54 -0.66 -0.42 786.30 1.00
age_mu[2] 0.43 0.08 0.42 0.31 0.55 1134.72 1.00
age_mu[3] -1.73 0.04 -1.73 -1.79 -1.65 1194.53 1.00
age_mu[4] 0.85 0.17 0.85 0.58 1.13 1111.96 1.00
age_sigma[0] 0.88 0.03 0.88 0.82 0.93 766.67 1.00
age_sigma[1] 0.90 0.06 0.90 0.81 0.99 992.72 1.00
age_sigma[2] 0.79 0.05 0.78 0.71 0.87 708.34 1.00
age_sigma[3] 0.26 0.03 0.25 0.20 0.31 959.62 1.00
age_sigma[4] 0.93 0.13 0.93 0.74 1.15 1092.88 1.00
b_Age -0.45 0.14 -0.44 -0.66 -0.22 744.95 1.00
b_Embarked[0] -0.28 0.58 -0.30 -1.28 0.64 496.51 1.00
b_Embarked[1] 0.30 0.60 0.29 -0.74 1.20 495.25 1.00
b_Embarked[2] 0.04 0.61 0.03 -0.93 1.02 482.67 1.00
b_Parch[0] 0.45 0.57 0.47 -0.45 1.42 336.02 1.02
b_Parch[1] 0.12 0.58 0.14 -0.91 1.00 377.61 1.02
b_Parch[2] -0.49 0.58 -0.45 -1.48 0.41 358.61 1.01
b_Pclass[0] 1.22 0.57 1.24 0.33 2.17 371.15 1.00
b_Pclass[1] 0.06 0.57 0.07 -0.84 1.03 369.58 1.00
b_Pclass[2] -1.18 0.57 -1.16 -2.18 -0.31 373.55 1.00
b_Sex[0] 1.15 0.74 1.18 -0.03 2.31 568.65 1.00
b_Sex[1] -1.05 0.74 -1.02 -2.18 0.21 709.29 1.00
b_SibSp[0] 0.28 0.66 0.26 -0.86 1.25 585.03 1.00
b_SibSp[1] -0.17 0.67 -0.18 -1.28 0.87 596.44 1.00
b_Title[0] -0.94 0.54 -0.96 -1.86 -0.11 437.32 1.00
b_Title[1] -0.33 0.61 -0.33 -1.32 0.60 570.32 1.00
b_Title[2] 0.53 0.62 0.53 -0.52 1.46 452.87 1.00
b_Title[3] 1.48 0.59 1.48 0.60 2.48 562.71 1.00
b_Title[4] -0.68 0.58 -0.66 -1.71 0.15 472.57 1.00
Number of divergences: 0
To double-check that the assumption "age is correlated with title" is reasonable, let's look at the inferred age by title. Recall that we standardized age, so here we need to scale back to the original domain.
[10]:
age_by_title = age_mean + age_std * mcmc.get_samples()["age_mu"].mean(axis=0)
dict(zip(title_cat.categories, age_by_title))
[10]:
{'Mr.': 32.434227,
'Miss.': 21.763992,
'Mrs.': 35.852997,
'Master.': 4.6297398,
'Misc.': 42.081936}
The inferred result confirms our assumption that Age is correlated with Title:

- those with the Master. title are quite young (in other words, they are children on the ship) compared to the other groups,
- those with the Mrs. title are, on average, older than those with the Miss. title.

We can also see that the result is similar to the actual statistical mean of Age given Title in our training dataset:
[11]:
train_df.groupby("Title")["Age"].mean()
[11]:
Title
Master. 4.574167
Misc. 42.384615
Miss. 21.773973
Mr. 32.368090
Mrs. 35.898148
Name: Age, dtype: float64
So far so good: we have plenty of information about the regression coefficients together with the imputed values and their uncertainties. Let's inspect those results a bit; a short conversion of the b_Sex coefficients to the odds scale follows the list:

- The mean value -0.45 of b_Age implies that younger passengers had a better chance to survive.
- The mean values (1.15, -1.05) of b_Sex imply that female passengers had a higher chance to survive than male passengers.
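To make those logit-scale numbers more tangible, here is a small back-of-the-envelope conversion of the b_Sex posterior means to an odds ratio (holding every other term in the logit fixed):

import numpy as np

# female (b_Sex[0] = 1.15) vs. male (b_Sex[1] = -1.05) contribution to the logit
logit_diff = 1.15 - (-1.05)
odds_ratio = np.exp(logit_diff)
print(odds_ratio)  # ~9.0: roughly nine times higher odds of survival for females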
Prediction
In NumPyro, we can use the Predictive utility to make predictions from posterior samples. Let's check how well the model performs on the training dataset. For simplicity, we will get a survived prediction for each posterior sample and apply the majority rule to the predictions.
[12]:
posterior = mcmc.get_samples()
survived_pred = Predictive(model, posterior)(random.PRNGKey(1), **data)["survived"]
survived_pred = (survived_pred.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
pd.Series(survived, name="actual"), pd.Series(survived_pred, name="predict")
)
confusion_matrix / confusion_matrix.sum(axis=1)
Accuracy: 0.8271605
[12]:
| actual \ predict | 0 | 1 |
|---|---|---|
| 0 | 0.876138 | 0.198830 |
| 1 | 0.156648 | 0.748538 |
This is a pretty good result using a simple logistic regression model. Let’s see how the model performs if we don’t use Bayesian imputation here.
[13]:
mcmc.run(random.PRNGKey(2), **data, survived=survived, bayesian_impute=False)
posterior_1 = mcmc.get_samples()
survived_pred_1 = Predictive(model, posterior_1)(random.PRNGKey(2), **data)["survived"]
survived_pred_1 = (survived_pred_1.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred_1 == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
pd.Series(survived, name="actual"), pd.Series(survived_pred_1, name="predict")
)
confusion_matrix / confusion_matrix.sum(axis=1)
sample: 100%|██████████| 2000/2000 [00:11<00:00, 166.79it/s, 63 steps of size 7.18e-02. acc. prob=0.93]
Accuracy: 0.82042646
[13]:
| actual \ predict | 0 | 1 |
|---|---|---|
| 0 | 0.872495 | 0.204678 |
| 1 | 0.163934 | 0.736842 |
We can see that Bayesian imputation does a little bit better here.
Remark. When using posterior samples to perform prediction on new data, we need to marginalize out age_impute because those imputed values are specific to the training data:
posterior.pop("age_impute")
survived_pred = Predictive(model, posterior)(random.PRNGKey(3), **new_data)
References

1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan.
2. Kaggle competition: Titanic: Machine Learning from Disaster.