Variationally Inferred Parameterization
Author: Madhav Kanda
The Hamiltonian Monte Carlo (HMC) sampler sometimes struggles to sample effectively from a posterior distribution; Neal's funnel is a classic example. In such cases the conventional centered parameterization can be inadequate, and a non-centered parameterization is often used instead. Sometimes even the non-centered parameterization does not suffice, and a partially centered one works best. Variationally Inferred Parameterization (VIP) handles this by learning a centeredness value in the range 0 to 1.
This tutorial implements Variationally Inferred Parameterization, based on Automatic Reparameterization of Probabilistic Programs, using LocScaleReparam in NumPyro.
[ ]:
%pip -qq install numpyro
%pip -qq install ucimlrepo
[ ]:
import arviz as az
import numpy as np
from ucimlrepo import fetch_ucirepo
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer.reparam import LocScaleReparam
rng_key = jax.random.PRNGKey(0)
1. Dataset
We will use the German Credit Dataset for this illustration. The dataset consists of 1000 entries with 20 attributes, both categorical (symbolic) and numerical, prepared by Prof. Hofmann. Each entry represents a person who takes credit from a bank and is classified as a good or bad credit risk according to these attributes.
[ ]:
def load_german_credit():
    statlog_german_credit_data = fetch_ucirepo(id=144)
    X = statlog_german_credit_data.data.features
    y = statlog_german_credit_data.data.targets
    return X, y
[ ]:
X, y = load_german_credit()
X
 | Attribute1 | Attribute2 | Attribute3 | Attribute4 | Attribute5 | Attribute6 | Attribute7 | Attribute8 | Attribute9 | Attribute10 | Attribute11 | Attribute12 | Attribute13 | Attribute14 | Attribute15 | Attribute16 | Attribute17 | Attribute18 | Attribute19 | Attribute20 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | A11 | 6 | A34 | A43 | 1169 | A65 | A75 | 4 | A93 | A101 | 4 | A121 | 67 | A143 | A152 | 2 | A173 | 1 | A192 | A201 |
1 | A12 | 48 | A32 | A43 | 5951 | A61 | A73 | 2 | A92 | A101 | 2 | A121 | 22 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
2 | A14 | 12 | A34 | A46 | 2096 | A61 | A74 | 2 | A93 | A101 | 3 | A121 | 49 | A143 | A152 | 1 | A172 | 2 | A191 | A201 |
3 | A11 | 42 | A32 | A42 | 7882 | A61 | A74 | 2 | A93 | A103 | 4 | A122 | 45 | A143 | A153 | 1 | A173 | 2 | A191 | A201 |
4 | A11 | 24 | A33 | A40 | 4870 | A61 | A73 | 3 | A93 | A101 | 4 | A124 | 53 | A143 | A153 | 2 | A173 | 2 | A191 | A201 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
995 | A14 | 12 | A32 | A42 | 1736 | A61 | A74 | 3 | A92 | A101 | 4 | A121 | 31 | A143 | A152 | 1 | A172 | 1 | A191 | A201 |
996 | A11 | 30 | A32 | A41 | 3857 | A61 | A73 | 4 | A91 | A101 | 4 | A122 | 40 | A143 | A152 | 1 | A174 | 1 | A192 | A201 |
997 | A14 | 12 | A32 | A43 | 804 | A61 | A75 | 4 | A93 | A101 | 4 | A123 | 38 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
998 | A11 | 45 | A32 | A43 | 1845 | A61 | A73 | 4 | A93 | A101 | 4 | A124 | 23 | A143 | A153 | 1 | A173 | 1 | A192 | A201 |
999 | A12 | 45 | A34 | A41 | 4576 | A62 | A71 | 3 | A93 | A101 | 4 | A123 | 27 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
1000 rows × 20 columns
Here, X holds the values of the 20 attributes for each person in the dataset, and y is the corresponding output variable (the credit-risk label).
[ ]:
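def data_transform(X, y):
    def categorical_to_int(x):
        # Map each distinct category to an integer code.
        d = {u: i for i, u in enumerate(np.unique(x))}
        return np.array([d[i] for i in x])

    categoricals = []
    numericals = []
    # Leading column of ones acts as the intercept.
    numericals.append(np.ones([len(y)]))
    for column in X:
        column = X[column]
        if column.dtype == "O":
            categoricals.append(categorical_to_int(column))
        else:
            # Standardize numerical columns to zero mean and unit variance.
            numericals.append((column - column.mean()) / column.std())
    numericals = np.array(numericals).T
    # Binary target: 1 where the original label equals 1, else 0.
    status = np.array(y == 1, dtype=np.int32)
    status = np.squeeze(status)
    return jnp.array(numericals), jnp.array(categoricals), jnp.array(status)
The function above transforms the data into the form required by the NumPyro model: numerical columns are standardized (with a leading column of ones as the intercept), categorical columns are integer-coded, and the target is converted to a binary indicator.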
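To make the two preprocessing steps concrete, here is a minimal sketch on made-up values. The arrays `codes` and `amounts` are hypothetical and not taken from the dataset; the sketch only mirrors the integer coding and standardization that `data_transform` performs.

```python
import numpy as np

# Hypothetical toy inputs, not taken from the dataset.
codes = np.array(["A12", "A11", "A12", "A14"], dtype=object)
amounts = np.array([1000.0, 2000.0, 3000.0, 6000.0])

# np.unique returns the distinct values in sorted order, so the
# integer codes follow that sorted order.
lookup = {u: i for i, u in enumerate(np.unique(codes))}
int_codes = np.array([lookup[c] for c in codes])

# Standardize a numerical column. ddof=1 matches the default of
# pandas' Series.std(), which the notebook's function relies on.
standardized = (amounts - amounts.mean()) / amounts.std(ddof=1)

print(int_codes)
print(standardized.mean(), standardized.std(ddof=1))
```

The integer codes carry no ordering information by themselves, which is why the notebook one-hot encodes them before building the design matrix.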
[ ]:
numericals, categoricals, status = data_transform(X, y)
[ ]:
x_numeric = numericals.astype(jnp.float32)
x_categorical = [jnp.eye(c.max() + 1)[c] for c in categoricals]
all_x = jnp.concatenate([x_numeric] + x_categorical, axis=1)
num_features = all_x.shape[1]
y = status[jnp.newaxis, Ellipsis]
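The cell above one-hot encodes each integer-coded column by indexing into an identity matrix. The following sketch shows the same idiom on a hypothetical column `c`, using NumPy instead of `jax.numpy` only to keep the example dependency-free.

```python
import numpy as np

# Hypothetical integer-coded column with three categories.
c = np.array([1, 0, 1, 2])

# Row i of the identity matrix is the one-hot vector for category i,
# so indexing with c selects one row per observation.
one_hot = np.eye(c.max() + 1)[c]

print(one_hot.shape)
print(one_hot)
```

Each row contains exactly one 1, and concatenating such blocks along `axis=1` yields the combined feature matrix.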
2. Model
We will use a logistic regression model with a hierarchical prior on the coefficient scales:
\begin{align} \log \tau_0 & \sim \mathcal{N}(0,10) & \log \tau_i & \sim \mathcal{N}\left(\log \tau_0, 1\right) \\ \beta_i & \sim \mathcal{N}\left(0, \tau_i\right) & y & \sim \operatorname{Bernoulli}\left(\sigma\left(\beta X^T\right)\right) \end{align}
[ ]:
def german_credit():
    log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
    log_tau_i = numpyro.sample(
        "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
    )
    beta = numpyro.sample(
        "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
    )
    numpyro.sample(
        "obs",
        dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
        obs=y,
    )
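The `einsum` expression in the likelihood can be hard to read at a glance. As a rough sketch with hypothetical sizes `n` and `d` and random stand-in arrays, the pattern `"nd,md->mn"` computes one logit per observation and returns them with a leading axis of length 1, which matches the shape of `y` built earlier with `status[jnp.newaxis, Ellipsis]`.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 3  # hypothetical number of observations and features
all_x_toy = rng.normal(size=(n, d))
beta_toy = rng.normal(size=d)

# Same contraction as in the model: sum over the feature axis d.
logits = np.einsum("nd,md->mn", all_x_toy, beta_toy[np.newaxis, :])

# Equivalent matrix-vector product with an added leading axis.
same = (all_x_toy @ beta_toy)[np.newaxis, :]

print(logits.shape)
print(np.allclose(logits, same))
```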
[ ]:
nuts_kernel = NUTS(german_credit)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:21<00:00, 94.07it/s, 63 steps of size 6.31e-02. acc. prob=0.87]
[ ]:
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
beta[0] 0.13 0.38 0.05 -0.36 0.74 284.06 1.00
beta[1] -0.34 0.12 -0.34 -0.52 -0.15 621.55 1.00
beta[2] -0.27 0.13 -0.27 -0.45 -0.03 542.13 1.00
beta[3] -0.30 0.10 -0.30 -0.44 -0.11 566.55 1.00
beta[4] -0.00 0.07 -0.00 -0.12 0.11 782.35 1.00
beta[5] 0.12 0.09 0.11 -0.02 0.27 728.28 1.01
beta[6] -0.08 0.08 -0.07 -0.22 0.05 822.89 1.00
beta[7] -0.05 0.07 -0.04 -0.19 0.05 752.66 1.00
beta[8] -0.42 0.32 -0.39 -0.87 0.05 198.00 1.00
beta[9] -0.07 0.26 -0.02 -0.50 0.31 220.27 1.00
beta[10] 0.26 0.31 0.18 -0.15 0.78 404.97 1.00
beta[11] 1.23 0.34 1.25 0.68 1.79 227.34 1.01
beta[12] -0.26 0.34 -0.17 -0.81 0.22 349.10 1.00
beta[13] -0.30 0.34 -0.21 -0.86 0.13 387.72 1.00
beta[14] 0.07 0.20 0.04 -0.26 0.38 240.45 1.03
beta[15] 0.10 0.22 0.05 -0.18 0.50 287.41 1.02
beta[16] 0.76 0.30 0.76 0.22 1.24 364.73 1.03
beta[17] -0.53 0.28 -0.55 -0.94 -0.05 269.95 1.00
beta[18] 0.70 0.42 0.70 -0.02 1.29 367.28 1.00
beta[19] 0.17 0.40 0.06 -0.43 0.77 333.54 1.00
beta[20] 0.03 0.19 0.01 -0.23 0.39 381.57 1.00
beta[21] 0.18 0.22 0.13 -0.14 0.53 335.48 1.00
beta[22] -0.05 0.32 -0.01 -0.56 0.46 439.54 1.00
beta[23] -0.10 0.30 -0.04 -0.63 0.30 508.20 1.00
beta[24] -0.34 0.36 -0.25 -0.94 0.12 283.15 1.00
beta[25] 0.14 0.40 0.04 -0.46 0.71 433.69 1.00
beta[26] -0.01 0.19 -0.00 -0.34 0.28 438.64 1.00
beta[27] -0.36 0.27 -0.33 -0.78 0.04 377.33 1.01
beta[28] -0.07 0.22 -0.03 -0.43 0.26 493.09 1.00
beta[29] 0.01 0.22 0.00 -0.32 0.34 448.21 1.00
beta[30] 0.35 0.43 0.22 -0.18 1.08 314.69 1.00
beta[31] 0.41 0.33 0.40 -0.10 0.90 402.62 1.00
beta[32] -0.03 0.21 -0.01 -0.39 0.30 525.23 1.00
beta[33] -0.12 0.18 -0.09 -0.41 0.16 334.94 1.00
beta[34] -0.02 0.16 -0.01 -0.24 0.26 318.25 1.00
beta[35] 0.42 0.27 0.42 -0.04 0.81 455.99 1.00
beta[36] 0.05 0.17 0.03 -0.18 0.35 506.34 1.00
beta[37] -0.12 0.25 -0.06 -0.57 0.21 470.11 1.00
beta[38] -0.07 0.20 -0.04 -0.39 0.24 410.71 1.00
beta[39] 0.36 0.24 0.35 -0.04 0.71 359.55 1.00
beta[40] 0.05 0.20 0.02 -0.29 0.35 441.70 1.00
beta[41] -0.00 0.21 0.00 -0.34 0.37 513.67 1.00
beta[42] -0.13 0.27 -0.08 -0.59 0.23 402.64 1.00
beta[43] 0.55 0.46 0.49 -0.11 1.28 570.74 1.00
beta[44] 0.19 0.21 0.15 -0.14 0.50 379.76 1.00
beta[45] -0.00 0.16 0.00 -0.25 0.26 352.19 1.00
beta[46] 0.01 0.16 0.01 -0.25 0.25 411.05 1.00
beta[47] -0.16 0.24 -0.11 -0.55 0.18 455.59 1.00
beta[48] -0.12 0.24 -0.07 -0.55 0.21 322.67 1.04
beta[49] -0.04 0.23 -0.02 -0.45 0.30 437.47 1.02
beta[50] 0.38 0.28 0.37 -0.03 0.82 266.19 1.04
beta[51] -0.14 0.22 -0.09 -0.52 0.16 406.31 1.00
beta[52] 0.19 0.23 0.14 -0.14 0.55 338.97 1.00
beta[53] 0.04 0.22 0.02 -0.23 0.43 438.03 1.00
beta[54] 0.05 0.24 0.02 -0.32 0.41 522.43 1.00
beta[55] 0.02 0.14 0.01 -0.22 0.23 562.00 1.00
beta[56] -0.01 0.13 -0.01 -0.24 0.21 638.20 1.00
beta[57] 0.01 0.17 0.00 -0.25 0.34 590.99 1.00
beta[58] -0.07 0.18 -0.04 -0.34 0.23 481.37 1.00
beta[59] 0.13 0.19 0.09 -0.12 0.47 507.56 1.00
beta[60] -0.14 0.33 -0.06 -0.64 0.37 303.00 1.00
beta[61] 0.48 0.56 0.32 -0.18 1.41 438.86 1.00
log_tau_i[0] -1.51 0.95 -1.52 -3.03 0.11 290.78 1.00
log_tau_i[1] -1.07 0.67 -1.11 -2.12 0.03 641.04 1.00
log_tau_i[2] -1.24 0.76 -1.26 -2.47 0.03 666.31 1.00
log_tau_i[3] -1.16 0.65 -1.19 -2.20 -0.10 821.60 1.00
log_tau_i[4] -2.11 0.88 -2.13 -3.50 -0.61 806.15 1.00
log_tau_i[5] -1.71 0.86 -1.68 -3.28 -0.44 697.00 1.00
log_tau_i[6] -1.88 0.84 -1.91 -3.30 -0.58 623.56 1.00
log_tau_i[7] -1.99 0.90 -1.98 -3.51 -0.65 710.21 1.00
log_tau_i[8] -1.00 0.86 -0.96 -2.23 0.52 445.30 1.00
log_tau_i[9] -1.69 0.93 -1.63 -3.17 -0.14 326.33 1.00
log_tau_i[10] -1.41 0.95 -1.35 -2.93 0.19 441.60 1.01
log_tau_i[11] -0.11 0.57 -0.12 -0.97 0.80 539.60 1.00
log_tau_i[12] -1.36 0.96 -1.31 -3.16 0.01 336.11 1.00
log_tau_i[13] -1.30 0.95 -1.26 -2.85 0.28 335.04 1.00
log_tau_i[14] -1.72 0.89 -1.70 -3.05 -0.25 584.38 1.00
log_tau_i[15] -1.65 0.92 -1.63 -3.07 -0.10 345.77 1.03
log_tau_i[16] -0.51 0.65 -0.49 -1.42 0.59 676.64 1.00
log_tau_i[17] -0.84 0.76 -0.76 -2.09 0.34 303.14 1.00
log_tau_i[18] -0.69 0.82 -0.59 -2.03 0.61 359.35 1.00
log_tau_i[19] -1.45 0.99 -1.42 -2.97 0.25 397.18 1.00
log_tau_i[20] -1.75 0.94 -1.73 -3.39 -0.40 617.54 1.00
log_tau_i[21] -1.51 0.88 -1.49 -3.16 -0.27 488.52 1.00
log_tau_i[22] -1.56 0.93 -1.56 -3.06 -0.10 348.20 1.00
log_tau_i[23] -1.58 0.94 -1.57 -3.05 0.02 278.69 1.00
log_tau_i[24] -1.26 1.00 -1.12 -2.91 0.29 205.38 1.00
log_tau_i[25] -1.53 0.95 -1.56 -3.11 0.02 351.09 1.00
log_tau_i[26] -1.73 0.91 -1.74 -3.17 -0.22 492.18 1.00
log_tau_i[27] -1.15 0.89 -1.08 -2.66 0.17 485.34 1.00
log_tau_i[28] -1.69 0.92 -1.65 -3.15 -0.19 425.75 1.00
log_tau_i[29] -1.71 0.99 -1.71 -3.19 0.01 374.58 1.00
log_tau_i[30] -1.24 0.99 -1.20 -2.74 0.50 327.47 1.00
log_tau_i[31] -1.02 0.89 -0.91 -2.40 0.51 587.85 1.00
log_tau_i[32] -1.71 0.94 -1.70 -3.22 -0.11 511.74 1.00
log_tau_i[33] -1.69 0.90 -1.68 -3.13 -0.28 538.65 1.00
log_tau_i[34] -1.82 0.92 -1.81 -3.35 -0.35 423.01 1.00
log_tau_i[35] -1.06 0.82 -1.00 -2.30 0.34 470.50 1.00
log_tau_i[36] -1.79 0.87 -1.76 -3.15 -0.34 527.47 1.00
log_tau_i[37] -1.58 0.95 -1.54 -3.11 0.04 485.52 1.00
log_tau_i[38] -1.71 0.87 -1.65 -3.18 -0.34 482.67 1.00
log_tau_i[39] -1.12 0.85 -1.01 -2.44 0.33 337.59 1.00
log_tau_i[40] -1.76 0.96 -1.73 -3.58 -0.36 533.15 1.00
log_tau_i[41] -1.74 0.94 -1.70 -3.26 -0.22 500.91 1.00
log_tau_i[42] -1.57 0.95 -1.54 -3.04 0.01 499.44 1.00
log_tau_i[43] -0.87 0.93 -0.74 -2.28 0.58 445.98 1.00
log_tau_i[44] -1.52 0.89 -1.45 -2.95 -0.13 442.63 1.00
log_tau_i[45] -1.84 0.94 -1.79 -3.21 -0.09 673.31 1.00
log_tau_i[46] -1.82 0.85 -1.83 -3.26 -0.56 579.66 1.00
log_tau_i[47] -1.54 0.90 -1.51 -3.35 -0.30 428.50 1.00
log_tau_i[48] -1.62 0.89 -1.60 -3.00 -0.15 413.30 1.01
log_tau_i[49] -1.71 0.95 -1.68 -3.23 -0.13 514.04 1.00
log_tau_i[50] -1.12 0.92 -0.99 -2.67 0.38 206.76 1.03
log_tau_i[51] -1.61 0.92 -1.58 -3.07 -0.03 477.41 1.00
log_tau_i[52] -1.54 0.90 -1.49 -2.96 -0.09 459.83 1.00
log_tau_i[53] -1.74 0.92 -1.69 -3.13 -0.14 509.51 1.00
log_tau_i[54] -1.68 0.95 -1.67 -3.07 0.10 477.21 1.00
log_tau_i[55] -1.87 0.97 -1.88 -3.49 -0.35 514.38 1.00
log_tau_i[56] -1.87 0.96 -1.84 -3.23 -0.12 574.80 1.00
log_tau_i[57] -1.77 0.86 -1.72 -3.26 -0.36 646.10 1.00
log_tau_i[58] -1.78 0.92 -1.77 -3.18 -0.15 617.59 1.00
log_tau_i[59] -1.67 0.93 -1.61 -3.19 -0.21 510.74 1.00
log_tau_i[60] -1.50 0.99 -1.44 -3.09 0.08 386.86 1.00
log_tau_i[61] -1.09 1.06 -1.00 -2.79 0.52 421.27 1.00
log_tau_zero -1.49 0.26 -1.49 -1.90 -1.05 169.88 1.00
Number of divergences: 37
The output of mcmc.print_summary() reports 37 divergences. We will therefore use Variationally Inferred Parameterization (VIP) to reduce them.
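The divergence count can also be read programmatically instead of from the printed summary. The sketch below assumes a dictionary shaped like the one returned by `mcmc.get_extra_fields()`, which for NUTS includes a boolean `"diverging"` array; the helper name `count_divergences` and the stand-in data are hypothetical.

```python
import numpy as np


def count_divergences(extra_fields):
    """Return the number of divergent transitions recorded in extra_fields."""
    return int(np.sum(np.asarray(extra_fields["diverging"])))


# Stand-in for mcmc.get_extra_fields(); these values are made up.
fake_extra_fields = {"diverging": np.array([False, True, False, True, True])}
n_div = count_divergences(fake_extra_fields)
print(n_div)
```

In the notebook this would be called as `count_divergences(mcmc.get_extra_fields())`, which makes it easy to compare runs in code.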
[ ]:
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
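3. Reparameterization
We introduce a centeredness parameter \(\lambda \in [0,1]\) for a variable \(z\) with prior \(z \sim \mathcal{N}(\mu, \sigma)\). Instead of sampling \(z\) directly, we sample an auxiliary variable \(\hat{z}\) and recover \(z\) deterministically:
\begin{align} \hat{z} & \sim \mathcal{N}\left(\lambda \mu, \sigma^\lambda\right) & z & = \mu + \sigma^{1-\lambda}(\hat{z} - \lambda \mu) \end{align}
Setting \(\lambda = 1\) recovers the centered parameterization (\(z = \hat{z}\)), while \(\lambda = 0\) gives the fully non-centered one (\(\hat{z} \sim \mathcal{N}(0, 1)\) and \(z = \mu + \sigma \hat{z}\)).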
Applying this transformation to \(\mu\) in a simple hierarchical model, the joint density
\begin{align} p(\theta, \mu, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\mu \mid \theta, \sigma_\mu\right) \times \mathcal{N}(\mathbf{y} \mid \mu, \sigma) \end{align}
becomes
\begin{align} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\hat{\mu} \mid \lambda \theta, \sigma_\mu^\lambda\right) \times \mathcal{N}\left(\mathbf{y} \mid \theta+\sigma_\mu^{1-\lambda}(\hat{\mu}-\lambda \theta), \sigma\right) \end{align}
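A quick numerical check can make the transformation less abstract. The sketch below uses plain NumPy with hypothetical values for the location and scale, draws the auxiliary variable from a normal with mean \(\lambda\mu\) and standard deviation \(\sigma^\lambda\), maps it back, and confirms that the recovered variable has mean \(\mu\) and standard deviation \(\sigma\) for several values of \(\lambda\). This follows from the algebra: the mean is \(\mu + \sigma^{1-\lambda}(\lambda\mu - \lambda\mu) = \mu\) and the standard deviation is \(\sigma^{1-\lambda}\sigma^{\lambda} = \sigma\).

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 2.0, 3.0  # hypothetical location and scale

results = {}
for lam in (0.0, 0.3, 1.0):
    # Sample the auxiliary (partially decentered) variable.
    z_hat = rng.normal(lam * mu, sigma**lam, size=200_000)
    # Map back to the original variable.
    z = mu + sigma ** (1.0 - lam) * (z_hat - lam * mu)
    results[lam] = (z.mean(), z.std())
    print(f"lambda={lam}: mean={z.mean():.3f}, std={z.std():.3f}")
```

The marginal distribution of the recovered variable is the same for every \(\lambda\); only the geometry seen by the sampler changes, which is what VIP exploits.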
[ ]:
def german_credit_reparam(beta_centeredness=None):
    def model():
        log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
        log_tau_i = numpyro.sample(
            "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
        )
        with numpyro.handlers.reparam(
            config={"beta": LocScaleReparam(beta_centeredness)}
        ):
            beta = numpyro.sample(
                "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
            )
        numpyro.sample(
            "obs",
            dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
            obs=y,
        )

    return model
Now we use SVI to optimize \(\lambda\). Because LocScaleReparam is constructed here with a centeredness of None, it registers \(\lambda\) as a learnable parameter named beta_centered, constrained to \([0, 1]\).
[ ]:
model = german_credit_reparam()
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(3e-4), Trace_ELBO(10))
svi_results = svi.run(rng_key, 10000)
100%|██████████| 10000/10000 [00:16<00:00, 588.87it/s, init loss: 2165.2424, avg. loss [9501-10000]: 576.7846]
[ ]:
reparam_model = german_credit_reparam(
beta_centeredness=svi_results.params["beta_centered"]
)
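Before running NUTS it can be useful to check where the learned centeredness landed between the two extremes. The helper below is a hypothetical convenience function, not part of NumPyro; it accepts a scalar or an array (what `svi_results.params["beta_centered"]` contains depends on the site's shape, so both cases are handled) and is demonstrated on made-up values.

```python
import numpy as np


def describe_centeredness(lam):
    """Summarize centeredness values, which must lie in [0, 1]."""
    lam = np.atleast_1d(np.asarray(lam, dtype=float))
    if np.any((lam < 0) | (lam > 1)):
        raise ValueError("centeredness must lie in [0, 1]")
    return {
        "min": float(lam.min()),
        "mean": float(lam.mean()),
        "max": float(lam.max()),
    }


# Made-up values standing in for svi_results.params["beta_centered"].
summary = describe_centeredness([0.1, 0.5, 0.9])
print(summary)
```

Values near 0 indicate that a mostly non-centered parameterization was preferred, and values near 1 indicate a mostly centered one.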
[ ]:
nuts_kernel = NUTS(reparam_model)
mcmc_reparam = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_reparam.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:07<00:00, 285.41it/s, 31 steps of size 1.28e-01. acc. prob=0.89]
[ ]:
mcmc_reparam.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
beta_decentered[0] 0.12 0.40 0.06 -0.48 0.80 338.70 1.00
beta_decentered[1] -0.45 0.15 -0.45 -0.70 -0.21 791.23 1.00
beta_decentered[2] -0.38 0.17 -0.38 -0.65 -0.09 691.79 1.00
beta_decentered[3] -0.41 0.13 -0.41 -0.61 -0.19 1022.79 1.00
beta_decentered[4] -0.01 0.11 -0.01 -0.18 0.20 1176.84 1.00
beta_decentered[5] 0.19 0.14 0.19 -0.04 0.41 1194.41 1.00
beta_decentered[6] -0.13 0.14 -0.13 -0.36 0.09 1227.24 1.00
beta_decentered[7] -0.07 0.12 -0.06 -0.24 0.14 1096.31 1.00
beta_decentered[8] -0.46 0.34 -0.46 -0.99 0.08 330.30 1.00
beta_decentered[9] -0.03 0.32 -0.02 -0.57 0.49 310.35 1.00
beta_decentered[10] 0.35 0.39 0.30 -0.26 1.00 426.11 1.00
beta_decentered[11] 1.29 0.31 1.30 0.81 1.82 433.16 1.00
beta_decentered[12] -0.32 0.39 -0.25 -0.96 0.24 521.05 1.00
beta_decentered[13] -0.38 0.40 -0.32 -1.00 0.24 410.05 1.00
beta_decentered[14] 0.08 0.28 0.06 -0.37 0.57 457.72 1.00
beta_decentered[15] 0.14 0.30 0.10 -0.28 0.66 612.31 1.00
beta_decentered[16] 0.85 0.31 0.86 0.41 1.45 432.14 1.00
beta_decentered[17] -0.64 0.28 -0.65 -1.05 -0.14 523.15 1.00
beta_decentered[18] 0.78 0.42 0.78 0.07 1.46 545.52 1.00
beta_decentered[19] 0.15 0.39 0.08 -0.50 0.80 662.60 1.00
beta_decentered[20] 0.04 0.25 0.03 -0.39 0.40 445.85 1.00
beta_decentered[21] 0.24 0.27 0.21 -0.20 0.65 477.68 1.00
beta_decentered[22] -0.03 0.38 -0.01 -0.64 0.60 984.59 1.00
beta_decentered[23] -0.13 0.34 -0.08 -0.72 0.35 702.87 1.00
beta_decentered[24] -0.41 0.39 -0.37 -1.08 0.13 603.13 1.00
beta_decentered[25] 0.19 0.47 0.09 -0.48 0.92 529.68 1.00
beta_decentered[26] 0.00 0.25 0.01 -0.47 0.35 690.54 1.00
beta_decentered[27] -0.46 0.31 -0.46 -0.95 0.04 464.44 1.00
beta_decentered[28] -0.09 0.30 -0.06 -0.56 0.41 464.65 1.00
beta_decentered[29] 0.02 0.30 0.01 -0.47 0.52 747.44 1.00
beta_decentered[30] 0.38 0.44 0.31 -0.30 1.05 717.12 1.00
beta_decentered[31] 0.47 0.36 0.47 -0.09 1.03 564.18 1.00
beta_decentered[32] -0.03 0.26 -0.02 -0.44 0.44 572.03 1.00
beta_decentered[33] -0.17 0.25 -0.15 -0.63 0.19 713.40 1.00
beta_decentered[34] -0.02 0.21 -0.01 -0.39 0.32 620.45 1.01
beta_decentered[35] 0.53 0.31 0.55 -0.03 1.00 681.60 1.00
beta_decentered[36] 0.09 0.24 0.06 -0.27 0.49 610.06 1.00
beta_decentered[37] -0.14 0.31 -0.10 -0.74 0.28 826.87 1.00
beta_decentered[38] -0.12 0.25 -0.11 -0.53 0.30 493.49 1.00
beta_decentered[39] 0.44 0.28 0.44 -0.01 0.89 542.71 1.00
beta_decentered[40] 0.05 0.26 0.03 -0.39 0.45 709.78 1.00
beta_decentered[41] 0.02 0.30 0.01 -0.52 0.46 389.41 1.00
beta_decentered[42] -0.15 0.33 -0.10 -0.74 0.32 607.05 1.00
beta_decentered[43] 0.66 0.47 0.65 -0.11 1.38 539.31 1.01
beta_decentered[44] 0.25 0.27 0.23 -0.19 0.66 686.63 1.00
beta_decentered[45] -0.01 0.22 -0.00 -0.36 0.35 909.70 1.00
beta_decentered[46] 0.02 0.22 0.01 -0.33 0.39 741.33 1.00
beta_decentered[47] -0.25 0.31 -0.21 -0.72 0.25 487.26 1.00
beta_decentered[48] -0.18 0.30 -0.15 -0.67 0.30 400.57 1.00
beta_decentered[49] -0.07 0.30 -0.05 -0.56 0.41 594.38 1.00
beta_decentered[50] 0.46 0.31 0.46 -0.09 0.88 384.33 1.00
beta_decentered[51] -0.23 0.27 -0.19 -0.69 0.16 561.35 1.00
beta_decentered[52] 0.22 0.25 0.21 -0.19 0.60 456.41 1.00
beta_decentered[53] 0.07 0.29 0.04 -0.41 0.53 500.99 1.00
beta_decentered[54] 0.05 0.30 0.02 -0.49 0.51 896.08 1.00
beta_decentered[55] 0.01 0.23 0.01 -0.34 0.43 1033.13 1.00
beta_decentered[56] -0.02 0.19 -0.01 -0.33 0.28 717.87 1.00
beta_decentered[57] 0.01 0.23 0.01 -0.33 0.41 684.61 1.00
beta_decentered[58] -0.09 0.26 -0.08 -0.55 0.30 455.66 1.01
beta_decentered[59] 0.20 0.27 0.18 -0.20 0.63 413.57 1.01
beta_decentered[60] -0.14 0.37 -0.09 -0.76 0.46 411.83 1.00
beta_decentered[61] 0.58 0.57 0.50 -0.23 1.50 495.86 1.00
log_tau_i[0] -1.56 0.91 -1.53 -2.94 -0.01 480.69 1.00
log_tau_i[1] -1.07 0.64 -1.08 -1.99 0.06 950.02 1.00
log_tau_i[2] -1.25 0.74 -1.23 -2.36 -0.06 796.37 1.00
log_tau_i[3] -1.15 0.65 -1.18 -2.25 -0.14 838.40 1.00
log_tau_i[4] -2.09 0.90 -2.12 -3.51 -0.49 1120.31 1.00
log_tau_i[5] -1.70 0.84 -1.69 -3.10 -0.35 1291.74 1.00
log_tau_i[6] -1.87 0.92 -1.84 -3.67 -0.58 1066.38 1.00
log_tau_i[7] -2.06 0.89 -2.07 -3.42 -0.42 641.48 1.00
log_tau_i[8] -1.11 0.91 -1.03 -2.79 0.19 482.69 1.00
log_tau_i[9] -1.65 0.91 -1.62 -3.05 -0.04 772.73 1.00
log_tau_i[10] -1.31 0.94 -1.27 -2.77 0.24 578.34 1.00
log_tau_i[11] -0.04 0.49 -0.08 -0.81 0.77 736.20 1.00
log_tau_i[12] -1.37 0.98 -1.32 -2.86 0.31 623.14 1.00
log_tau_i[13] -1.28 0.97 -1.21 -2.78 0.30 653.51 1.00
log_tau_i[14] -1.70 0.95 -1.71 -3.20 -0.15 831.88 1.00
log_tau_i[15] -1.67 0.97 -1.65 -3.15 0.04 726.23 1.00
log_tau_i[16] -0.53 0.65 -0.49 -1.55 0.54 558.56 1.00
log_tau_i[17] -0.81 0.68 -0.80 -1.91 0.27 655.36 1.00
log_tau_i[18] -0.71 0.83 -0.60 -1.89 0.75 536.55 1.00
log_tau_i[19] -1.55 0.98 -1.53 -3.17 -0.02 675.27 1.00
log_tau_i[20] -1.81 0.90 -1.79 -3.36 -0.41 823.93 1.00
log_tau_i[21] -1.53 0.93 -1.51 -3.03 -0.04 767.64 1.00
log_tau_i[22] -1.60 0.99 -1.58 -3.25 -0.01 840.71 1.00
log_tau_i[23] -1.64 0.97 -1.60 -3.15 0.10 615.62 1.00
log_tau_i[24] -1.21 0.96 -1.11 -2.61 0.47 674.58 1.00
log_tau_i[25] -1.54 1.04 -1.49 -3.39 -0.01 509.69 1.00
log_tau_i[26] -1.77 0.95 -1.74 -3.27 -0.13 915.91 1.00
log_tau_i[27] -1.14 0.89 -1.07 -2.53 0.36 569.10 1.01
log_tau_i[28] -1.72 0.93 -1.66 -3.25 -0.20 911.70 1.01
log_tau_i[29] -1.70 0.91 -1.71 -3.16 -0.14 648.29 1.00
log_tau_i[30] -1.28 1.02 -1.28 -2.99 0.34 575.90 1.00
log_tau_i[31] -1.10 0.86 -1.04 -2.39 0.45 730.21 1.00
log_tau_i[32] -1.79 0.95 -1.82 -3.32 -0.19 667.95 1.00
log_tau_i[33] -1.64 0.86 -1.62 -3.04 -0.32 948.64 1.00
log_tau_i[34] -1.87 0.92 -1.88 -3.29 -0.25 909.49 1.00
log_tau_i[35] -1.00 0.85 -0.95 -2.44 0.25 672.42 1.00
log_tau_i[36] -1.76 0.92 -1.73 -3.19 -0.22 889.56 1.00
log_tau_i[37] -1.63 1.00 -1.58 -3.11 0.12 973.29 1.00
log_tau_i[38] -1.72 0.85 -1.70 -3.13 -0.30 837.73 1.00
log_tau_i[39] -1.15 0.80 -1.13 -2.32 0.18 627.15 1.00
log_tau_i[40] -1.76 0.93 -1.70 -3.36 -0.34 686.88 1.00
log_tau_i[41] -1.75 0.96 -1.74 -3.20 -0.09 612.83 1.00
log_tau_i[42] -1.62 0.91 -1.62 -3.21 -0.25 596.18 1.00
log_tau_i[43] -0.79 0.93 -0.74 -2.20 0.83 560.52 1.01
log_tau_i[44] -1.52 0.94 -1.48 -3.08 0.01 888.27 1.00
log_tau_i[45] -1.83 0.90 -1.84 -3.27 -0.44 1122.53 1.00
log_tau_i[46] -1.83 0.89 -1.78 -3.21 -0.31 990.03 1.00
log_tau_i[47] -1.56 0.93 -1.48 -3.20 -0.21 736.01 1.00
log_tau_i[48] -1.59 0.94 -1.54 -3.20 -0.01 588.77 1.00
log_tau_i[49] -1.71 0.93 -1.68 -3.17 -0.17 813.94 1.00
log_tau_i[50] -1.15 0.83 -1.11 -2.51 0.14 514.68 1.00
log_tau_i[51] -1.54 0.89 -1.51 -2.97 -0.13 780.67 1.00
log_tau_i[52] -1.59 0.93 -1.56 -3.21 -0.28 807.47 1.00
log_tau_i[53] -1.74 0.92 -1.72 -3.10 -0.15 657.89 1.00
log_tau_i[54] -1.74 0.93 -1.74 -3.09 -0.10 867.29 1.00
log_tau_i[55] -1.89 0.94 -1.87 -3.34 -0.27 1132.90 1.00
log_tau_i[56] -1.89 0.92 -1.89 -3.27 -0.27 980.72 1.00
log_tau_i[57] -1.84 0.91 -1.84 -3.33 -0.38 813.32 1.00
log_tau_i[58] -1.75 0.93 -1.74 -3.25 -0.34 583.52 1.00
log_tau_i[59] -1.59 0.96 -1.53 -3.11 -0.06 754.39 1.01
log_tau_i[60] -1.58 0.98 -1.52 -3.23 -0.08 561.05 1.00
log_tau_i[61] -0.98 1.03 -0.85 -2.55 0.79 502.63 1.00
log_tau_zero -1.49 0.26 -1.49 -1.91 -1.10 220.32 1.00
Number of divergences: 1
The number of divergences has dropped from 37 to 1. Note that the summary now reports the auxiliary site beta_decentered rather than beta.
[ ]:
data = az.from_numpyro(mcmc_reparam)
az.plot_trace(data, compact=True, figsize=(15, 25));