Interactive online version: Open In Colab

Variationally Inferred Parameterization

Author: Madhav Kanda

Occasionally, the Hamiltonian Monte Carlo (HMC) sampler encounters challenges in effectively sampling from the posterior distribution. One illustrative case is Neal’s funnel. In these situations, the conventional centered parameterization may prove inadequate, leading us to employ non-centered parameterization. However, there are instances where even non-centered parameterization may not suffice, necessitating the utilization of Variationally Inferred Parameterization to attain the desired centeredness within the range of 0 to 1.

The purpose of this tutorial is to implement Variationally Inferred Parameterization based on Automatic Reparameterization of Probabilistic Programs using LocScaleReparam in Numpyro.

[ ]:
%pip -qq install numpyro
%pip -qq install ucimlrepo
[ ]:
import arviz as az
import numpy as np
from ucimlrepo import fetch_ucirepo

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer.reparam import LocScaleReparam

rng_key = jax.random.PRNGKey(0)

1. Dataset

We will be using the German Credit Dataset for this illustration. The dataset consists of 1000 entries with 20 categorial symbolic attributes prepared by Prof. Hofmann. In this dataset, each entry represents a person who takes a credit by a bank. Each person is classified as good or bad credit risks according to the set of attributes.

[ ]:
def load_german_credit():
    statlog_german_credit_data = fetch_ucirepo(id=144)
    X = statlog_german_credit_data.data.features
    y = statlog_german_credit_data.data.targets
    return X, y
[ ]:
X, y = load_german_credit()
X
Attribute1 Attribute2 Attribute3 Attribute4 Attribute5 Attribute6 Attribute7 Attribute8 Attribute9 Attribute10 Attribute11 Attribute12 Attribute13 Attribute14 Attribute15 Attribute16 Attribute17 Attribute18 Attribute19 Attribute20
0 A11 6 A34 A43 1169 A65 A75 4 A93 A101 4 A121 67 A143 A152 2 A173 1 A192 A201
1 A12 48 A32 A43 5951 A61 A73 2 A92 A101 2 A121 22 A143 A152 1 A173 1 A191 A201
2 A14 12 A34 A46 2096 A61 A74 2 A93 A101 3 A121 49 A143 A152 1 A172 2 A191 A201
3 A11 42 A32 A42 7882 A61 A74 2 A93 A103 4 A122 45 A143 A153 1 A173 2 A191 A201
4 A11 24 A33 A40 4870 A61 A73 3 A93 A101 4 A124 53 A143 A153 2 A173 2 A191 A201
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 A14 12 A32 A42 1736 A61 A74 3 A92 A101 4 A121 31 A143 A152 1 A172 1 A191 A201
996 A11 30 A32 A41 3857 A61 A73 4 A91 A101 4 A122 40 A143 A152 1 A174 1 A192 A201
997 A14 12 A32 A43 804 A61 A75 4 A93 A101 4 A123 38 A143 A152 1 A173 1 A191 A201
998 A11 45 A32 A43 1845 A61 A73 4 A93 A101 4 A124 23 A143 A153 1 A173 1 A192 A201
999 A12 45 A34 A41 4576 A62 A71 3 A93 A101 4 A123 27 A143 A152 1 A173 1 A191 A201

1000 rows × 20 columns

Here, X depicts 20 attributes and the values corresponding to these attributes for each person represented in the data entry and y is the output variable corresponding to these attributes

[ ]:
def data_transform(X, y):
    def categorical_to_int(x):
        d = {u: i for i, u in enumerate(np.unique(x))}
        return np.array([d[i] for i in x])

    categoricals = []
    numericals = []
    numericals.append(np.ones([len(y)]))
    for column in X:
        column = X[column]
        if column.dtype == "O":
            categoricals.append(categorical_to_int(column))
        else:
            numericals.append((column - column.mean()) / column.std())
    numericals = np.array(numericals).T
    status = np.array(y == 1, dtype=np.int32)
    status = np.squeeze(status)

    return jnp.array(numericals), jnp.array(categoricals), jnp.array(status)

Data transformation for feeding it into the Numpyro model

[ ]:
numericals, categoricals, status = data_transform(X, y)
[ ]:
x_numeric = numericals.astype(jnp.float32)
x_categorical = [jnp.eye(c.max() + 1)[c] for c in categoricals]
all_x = jnp.concatenate([x_numeric] + x_categorical, axis=1)
num_features = all_x.shape[1]
y = status[jnp.newaxis, Ellipsis]

2. Model

We will be using a logistic regression model with hierarchical prior on coefficient scales

\begin{align} \log \tau_0 & \sim \mathcal{N}(0,10) & \log \tau_i & \sim \mathcal{N}\left(\log \tau_0, 1\right) \\ \beta_i & \sim \mathcal{N}\left(0, \tau_i\right) & y & \sim \operatorname{Bernoulli}\left(\sigma\left(\beta X^T\right)\right) \end{align}

[ ]:
def german_credit():
    log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
    log_tau_i = numpyro.sample(
        "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
    )
    beta = numpyro.sample(
        "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
    )
    numpyro.sample(
        "obs",
        dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
        obs=y,
    )
[ ]:
nuts_kernel = NUTS(german_credit)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:21<00:00, 94.07it/s, 63 steps of size 6.31e-02. acc. prob=0.87]
[ ]:
mcmc.print_summary()

                    mean       std    median      5.0%     95.0%     n_eff     r_hat
       beta[0]      0.13      0.38      0.05     -0.36      0.74    284.06      1.00
       beta[1]     -0.34      0.12     -0.34     -0.52     -0.15    621.55      1.00
       beta[2]     -0.27      0.13     -0.27     -0.45     -0.03    542.13      1.00
       beta[3]     -0.30      0.10     -0.30     -0.44     -0.11    566.55      1.00
       beta[4]     -0.00      0.07     -0.00     -0.12      0.11    782.35      1.00
       beta[5]      0.12      0.09      0.11     -0.02      0.27    728.28      1.01
       beta[6]     -0.08      0.08     -0.07     -0.22      0.05    822.89      1.00
       beta[7]     -0.05      0.07     -0.04     -0.19      0.05    752.66      1.00
       beta[8]     -0.42      0.32     -0.39     -0.87      0.05    198.00      1.00
       beta[9]     -0.07      0.26     -0.02     -0.50      0.31    220.27      1.00
      beta[10]      0.26      0.31      0.18     -0.15      0.78    404.97      1.00
      beta[11]      1.23      0.34      1.25      0.68      1.79    227.34      1.01
      beta[12]     -0.26      0.34     -0.17     -0.81      0.22    349.10      1.00
      beta[13]     -0.30      0.34     -0.21     -0.86      0.13    387.72      1.00
      beta[14]      0.07      0.20      0.04     -0.26      0.38    240.45      1.03
      beta[15]      0.10      0.22      0.05     -0.18      0.50    287.41      1.02
      beta[16]      0.76      0.30      0.76      0.22      1.24    364.73      1.03
      beta[17]     -0.53      0.28     -0.55     -0.94     -0.05    269.95      1.00
      beta[18]      0.70      0.42      0.70     -0.02      1.29    367.28      1.00
      beta[19]      0.17      0.40      0.06     -0.43      0.77    333.54      1.00
      beta[20]      0.03      0.19      0.01     -0.23      0.39    381.57      1.00
      beta[21]      0.18      0.22      0.13     -0.14      0.53    335.48      1.00
      beta[22]     -0.05      0.32     -0.01     -0.56      0.46    439.54      1.00
      beta[23]     -0.10      0.30     -0.04     -0.63      0.30    508.20      1.00
      beta[24]     -0.34      0.36     -0.25     -0.94      0.12    283.15      1.00
      beta[25]      0.14      0.40      0.04     -0.46      0.71    433.69      1.00
      beta[26]     -0.01      0.19     -0.00     -0.34      0.28    438.64      1.00
      beta[27]     -0.36      0.27     -0.33     -0.78      0.04    377.33      1.01
      beta[28]     -0.07      0.22     -0.03     -0.43      0.26    493.09      1.00
      beta[29]      0.01      0.22      0.00     -0.32      0.34    448.21      1.00
      beta[30]      0.35      0.43      0.22     -0.18      1.08    314.69      1.00
      beta[31]      0.41      0.33      0.40     -0.10      0.90    402.62      1.00
      beta[32]     -0.03      0.21     -0.01     -0.39      0.30    525.23      1.00
      beta[33]     -0.12      0.18     -0.09     -0.41      0.16    334.94      1.00
      beta[34]     -0.02      0.16     -0.01     -0.24      0.26    318.25      1.00
      beta[35]      0.42      0.27      0.42     -0.04      0.81    455.99      1.00
      beta[36]      0.05      0.17      0.03     -0.18      0.35    506.34      1.00
      beta[37]     -0.12      0.25     -0.06     -0.57      0.21    470.11      1.00
      beta[38]     -0.07      0.20     -0.04     -0.39      0.24    410.71      1.00
      beta[39]      0.36      0.24      0.35     -0.04      0.71    359.55      1.00
      beta[40]      0.05      0.20      0.02     -0.29      0.35    441.70      1.00
      beta[41]     -0.00      0.21      0.00     -0.34      0.37    513.67      1.00
      beta[42]     -0.13      0.27     -0.08     -0.59      0.23    402.64      1.00
      beta[43]      0.55      0.46      0.49     -0.11      1.28    570.74      1.00
      beta[44]      0.19      0.21      0.15     -0.14      0.50    379.76      1.00
      beta[45]     -0.00      0.16      0.00     -0.25      0.26    352.19      1.00
      beta[46]      0.01      0.16      0.01     -0.25      0.25    411.05      1.00
      beta[47]     -0.16      0.24     -0.11     -0.55      0.18    455.59      1.00
      beta[48]     -0.12      0.24     -0.07     -0.55      0.21    322.67      1.04
      beta[49]     -0.04      0.23     -0.02     -0.45      0.30    437.47      1.02
      beta[50]      0.38      0.28      0.37     -0.03      0.82    266.19      1.04
      beta[51]     -0.14      0.22     -0.09     -0.52      0.16    406.31      1.00
      beta[52]      0.19      0.23      0.14     -0.14      0.55    338.97      1.00
      beta[53]      0.04      0.22      0.02     -0.23      0.43    438.03      1.00
      beta[54]      0.05      0.24      0.02     -0.32      0.41    522.43      1.00
      beta[55]      0.02      0.14      0.01     -0.22      0.23    562.00      1.00
      beta[56]     -0.01      0.13     -0.01     -0.24      0.21    638.20      1.00
      beta[57]      0.01      0.17      0.00     -0.25      0.34    590.99      1.00
      beta[58]     -0.07      0.18     -0.04     -0.34      0.23    481.37      1.00
      beta[59]      0.13      0.19      0.09     -0.12      0.47    507.56      1.00
      beta[60]     -0.14      0.33     -0.06     -0.64      0.37    303.00      1.00
      beta[61]      0.48      0.56      0.32     -0.18      1.41    438.86      1.00
  log_tau_i[0]     -1.51      0.95     -1.52     -3.03      0.11    290.78      1.00
  log_tau_i[1]     -1.07      0.67     -1.11     -2.12      0.03    641.04      1.00
  log_tau_i[2]     -1.24      0.76     -1.26     -2.47      0.03    666.31      1.00
  log_tau_i[3]     -1.16      0.65     -1.19     -2.20     -0.10    821.60      1.00
  log_tau_i[4]     -2.11      0.88     -2.13     -3.50     -0.61    806.15      1.00
  log_tau_i[5]     -1.71      0.86     -1.68     -3.28     -0.44    697.00      1.00
  log_tau_i[6]     -1.88      0.84     -1.91     -3.30     -0.58    623.56      1.00
  log_tau_i[7]     -1.99      0.90     -1.98     -3.51     -0.65    710.21      1.00
  log_tau_i[8]     -1.00      0.86     -0.96     -2.23      0.52    445.30      1.00
  log_tau_i[9]     -1.69      0.93     -1.63     -3.17     -0.14    326.33      1.00
 log_tau_i[10]     -1.41      0.95     -1.35     -2.93      0.19    441.60      1.01
 log_tau_i[11]     -0.11      0.57     -0.12     -0.97      0.80    539.60      1.00
 log_tau_i[12]     -1.36      0.96     -1.31     -3.16      0.01    336.11      1.00
 log_tau_i[13]     -1.30      0.95     -1.26     -2.85      0.28    335.04      1.00
 log_tau_i[14]     -1.72      0.89     -1.70     -3.05     -0.25    584.38      1.00
 log_tau_i[15]     -1.65      0.92     -1.63     -3.07     -0.10    345.77      1.03
 log_tau_i[16]     -0.51      0.65     -0.49     -1.42      0.59    676.64      1.00
 log_tau_i[17]     -0.84      0.76     -0.76     -2.09      0.34    303.14      1.00
 log_tau_i[18]     -0.69      0.82     -0.59     -2.03      0.61    359.35      1.00
 log_tau_i[19]     -1.45      0.99     -1.42     -2.97      0.25    397.18      1.00
 log_tau_i[20]     -1.75      0.94     -1.73     -3.39     -0.40    617.54      1.00
 log_tau_i[21]     -1.51      0.88     -1.49     -3.16     -0.27    488.52      1.00
 log_tau_i[22]     -1.56      0.93     -1.56     -3.06     -0.10    348.20      1.00
 log_tau_i[23]     -1.58      0.94     -1.57     -3.05      0.02    278.69      1.00
 log_tau_i[24]     -1.26      1.00     -1.12     -2.91      0.29    205.38      1.00
 log_tau_i[25]     -1.53      0.95     -1.56     -3.11      0.02    351.09      1.00
 log_tau_i[26]     -1.73      0.91     -1.74     -3.17     -0.22    492.18      1.00
 log_tau_i[27]     -1.15      0.89     -1.08     -2.66      0.17    485.34      1.00
 log_tau_i[28]     -1.69      0.92     -1.65     -3.15     -0.19    425.75      1.00
 log_tau_i[29]     -1.71      0.99     -1.71     -3.19      0.01    374.58      1.00
 log_tau_i[30]     -1.24      0.99     -1.20     -2.74      0.50    327.47      1.00
 log_tau_i[31]     -1.02      0.89     -0.91     -2.40      0.51    587.85      1.00
 log_tau_i[32]     -1.71      0.94     -1.70     -3.22     -0.11    511.74      1.00
 log_tau_i[33]     -1.69      0.90     -1.68     -3.13     -0.28    538.65      1.00
 log_tau_i[34]     -1.82      0.92     -1.81     -3.35     -0.35    423.01      1.00
 log_tau_i[35]     -1.06      0.82     -1.00     -2.30      0.34    470.50      1.00
 log_tau_i[36]     -1.79      0.87     -1.76     -3.15     -0.34    527.47      1.00
 log_tau_i[37]     -1.58      0.95     -1.54     -3.11      0.04    485.52      1.00
 log_tau_i[38]     -1.71      0.87     -1.65     -3.18     -0.34    482.67      1.00
 log_tau_i[39]     -1.12      0.85     -1.01     -2.44      0.33    337.59      1.00
 log_tau_i[40]     -1.76      0.96     -1.73     -3.58     -0.36    533.15      1.00
 log_tau_i[41]     -1.74      0.94     -1.70     -3.26     -0.22    500.91      1.00
 log_tau_i[42]     -1.57      0.95     -1.54     -3.04      0.01    499.44      1.00
 log_tau_i[43]     -0.87      0.93     -0.74     -2.28      0.58    445.98      1.00
 log_tau_i[44]     -1.52      0.89     -1.45     -2.95     -0.13    442.63      1.00
 log_tau_i[45]     -1.84      0.94     -1.79     -3.21     -0.09    673.31      1.00
 log_tau_i[46]     -1.82      0.85     -1.83     -3.26     -0.56    579.66      1.00
 log_tau_i[47]     -1.54      0.90     -1.51     -3.35     -0.30    428.50      1.00
 log_tau_i[48]     -1.62      0.89     -1.60     -3.00     -0.15    413.30      1.01
 log_tau_i[49]     -1.71      0.95     -1.68     -3.23     -0.13    514.04      1.00
 log_tau_i[50]     -1.12      0.92     -0.99     -2.67      0.38    206.76      1.03
 log_tau_i[51]     -1.61      0.92     -1.58     -3.07     -0.03    477.41      1.00
 log_tau_i[52]     -1.54      0.90     -1.49     -2.96     -0.09    459.83      1.00
 log_tau_i[53]     -1.74      0.92     -1.69     -3.13     -0.14    509.51      1.00
 log_tau_i[54]     -1.68      0.95     -1.67     -3.07      0.10    477.21      1.00
 log_tau_i[55]     -1.87      0.97     -1.88     -3.49     -0.35    514.38      1.00
 log_tau_i[56]     -1.87      0.96     -1.84     -3.23     -0.12    574.80      1.00
 log_tau_i[57]     -1.77      0.86     -1.72     -3.26     -0.36    646.10      1.00
 log_tau_i[58]     -1.78      0.92     -1.77     -3.18     -0.15    617.59      1.00
 log_tau_i[59]     -1.67      0.93     -1.61     -3.19     -0.21    510.74      1.00
 log_tau_i[60]     -1.50      0.99     -1.44     -3.09      0.08    386.86      1.00
 log_tau_i[61]     -1.09      1.06     -1.00     -2.79      0.52    421.27      1.00
  log_tau_zero     -1.49      0.26     -1.49     -1.90     -1.05    169.88      1.00

Number of divergences: 37

From mcmc.print_summary it is evident that there are 37 divergences. Thus, we will use Variationally Inferred Parameterization (VIP) to reduce these divergences

[ ]:
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
../_images/tutorials_variationally_inferred_parameterization_16_0.png

3. Reparameterization

We introduce a parameterization parameters \(\lambda \in [0,1]\) for any variable \(z\), and transform:

=> \(z\) ~ \(N (z | μ, σ)\)

=> by defining \(z\) ~ \(N(λμ, σ^λ)\)

=> \(z\) = \(μ + σ^{1-λ}(z - λμ)\).

Thus, using the above transformation the joint density can be transformed as follows: \begin{align} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\mu \mid \theta, \sigma_\mu\right) \times \mathcal{N}(\mathbf{y} \mid \mu, \sigma) \end{align}

\begin{align} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\hat{\mu} \mid \lambda \theta, \sigma_\mu^\lambda\right) \times \mathcal{N}\left(\mathbf{y} \mid \theta+\sigma_\mu^{1-\lambda}(\hat{\mu}-\lambda \theta), \sigma\right) \end{align}

[ ]:
def german_credit_reparam(beta_centeredness=None):
    def model():
        log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
        log_tau_i = numpyro.sample(
            "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
        )
        with numpyro.handlers.reparam(
            config={"beta": LocScaleReparam(beta_centeredness)}
        ):
            beta = numpyro.sample(
                "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
            )
        numpyro.sample(
            "obs",
            dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
            obs=y,
        )

    return model

Now, using SVI we optimize \(\lambda\).

[ ]:
model = german_credit_reparam()
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(3e-4), Trace_ELBO(10))
svi_results = svi.run(rng_key, 10000)
100%|██████████| 10000/10000 [00:16<00:00, 588.87it/s, init loss: 2165.2424, avg. loss [9501-10000]: 576.7846]
[ ]:
reparam_model = german_credit_reparam(
    beta_centeredness=svi_results.params["beta_centered"]
)
[ ]:
nuts_kernel = NUTS(reparam_model)
mcmc_reparam = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_reparam.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:07<00:00, 285.41it/s, 31 steps of size 1.28e-01. acc. prob=0.89]
[ ]:
mcmc_reparam.print_summary()

                         mean       std    median      5.0%     95.0%     n_eff     r_hat
 beta_decentered[0]      0.12      0.40      0.06     -0.48      0.80    338.70      1.00
 beta_decentered[1]     -0.45      0.15     -0.45     -0.70     -0.21    791.23      1.00
 beta_decentered[2]     -0.38      0.17     -0.38     -0.65     -0.09    691.79      1.00
 beta_decentered[3]     -0.41      0.13     -0.41     -0.61     -0.19   1022.79      1.00
 beta_decentered[4]     -0.01      0.11     -0.01     -0.18      0.20   1176.84      1.00
 beta_decentered[5]      0.19      0.14      0.19     -0.04      0.41   1194.41      1.00
 beta_decentered[6]     -0.13      0.14     -0.13     -0.36      0.09   1227.24      1.00
 beta_decentered[7]     -0.07      0.12     -0.06     -0.24      0.14   1096.31      1.00
 beta_decentered[8]     -0.46      0.34     -0.46     -0.99      0.08    330.30      1.00
 beta_decentered[9]     -0.03      0.32     -0.02     -0.57      0.49    310.35      1.00
beta_decentered[10]      0.35      0.39      0.30     -0.26      1.00    426.11      1.00
beta_decentered[11]      1.29      0.31      1.30      0.81      1.82    433.16      1.00
beta_decentered[12]     -0.32      0.39     -0.25     -0.96      0.24    521.05      1.00
beta_decentered[13]     -0.38      0.40     -0.32     -1.00      0.24    410.05      1.00
beta_decentered[14]      0.08      0.28      0.06     -0.37      0.57    457.72      1.00
beta_decentered[15]      0.14      0.30      0.10     -0.28      0.66    612.31      1.00
beta_decentered[16]      0.85      0.31      0.86      0.41      1.45    432.14      1.00
beta_decentered[17]     -0.64      0.28     -0.65     -1.05     -0.14    523.15      1.00
beta_decentered[18]      0.78      0.42      0.78      0.07      1.46    545.52      1.00
beta_decentered[19]      0.15      0.39      0.08     -0.50      0.80    662.60      1.00
beta_decentered[20]      0.04      0.25      0.03     -0.39      0.40    445.85      1.00
beta_decentered[21]      0.24      0.27      0.21     -0.20      0.65    477.68      1.00
beta_decentered[22]     -0.03      0.38     -0.01     -0.64      0.60    984.59      1.00
beta_decentered[23]     -0.13      0.34     -0.08     -0.72      0.35    702.87      1.00
beta_decentered[24]     -0.41      0.39     -0.37     -1.08      0.13    603.13      1.00
beta_decentered[25]      0.19      0.47      0.09     -0.48      0.92    529.68      1.00
beta_decentered[26]      0.00      0.25      0.01     -0.47      0.35    690.54      1.00
beta_decentered[27]     -0.46      0.31     -0.46     -0.95      0.04    464.44      1.00
beta_decentered[28]     -0.09      0.30     -0.06     -0.56      0.41    464.65      1.00
beta_decentered[29]      0.02      0.30      0.01     -0.47      0.52    747.44      1.00
beta_decentered[30]      0.38      0.44      0.31     -0.30      1.05    717.12      1.00
beta_decentered[31]      0.47      0.36      0.47     -0.09      1.03    564.18      1.00
beta_decentered[32]     -0.03      0.26     -0.02     -0.44      0.44    572.03      1.00
beta_decentered[33]     -0.17      0.25     -0.15     -0.63      0.19    713.40      1.00
beta_decentered[34]     -0.02      0.21     -0.01     -0.39      0.32    620.45      1.01
beta_decentered[35]      0.53      0.31      0.55     -0.03      1.00    681.60      1.00
beta_decentered[36]      0.09      0.24      0.06     -0.27      0.49    610.06      1.00
beta_decentered[37]     -0.14      0.31     -0.10     -0.74      0.28    826.87      1.00
beta_decentered[38]     -0.12      0.25     -0.11     -0.53      0.30    493.49      1.00
beta_decentered[39]      0.44      0.28      0.44     -0.01      0.89    542.71      1.00
beta_decentered[40]      0.05      0.26      0.03     -0.39      0.45    709.78      1.00
beta_decentered[41]      0.02      0.30      0.01     -0.52      0.46    389.41      1.00
beta_decentered[42]     -0.15      0.33     -0.10     -0.74      0.32    607.05      1.00
beta_decentered[43]      0.66      0.47      0.65     -0.11      1.38    539.31      1.01
beta_decentered[44]      0.25      0.27      0.23     -0.19      0.66    686.63      1.00
beta_decentered[45]     -0.01      0.22     -0.00     -0.36      0.35    909.70      1.00
beta_decentered[46]      0.02      0.22      0.01     -0.33      0.39    741.33      1.00
beta_decentered[47]     -0.25      0.31     -0.21     -0.72      0.25    487.26      1.00
beta_decentered[48]     -0.18      0.30     -0.15     -0.67      0.30    400.57      1.00
beta_decentered[49]     -0.07      0.30     -0.05     -0.56      0.41    594.38      1.00
beta_decentered[50]      0.46      0.31      0.46     -0.09      0.88    384.33      1.00
beta_decentered[51]     -0.23      0.27     -0.19     -0.69      0.16    561.35      1.00
beta_decentered[52]      0.22      0.25      0.21     -0.19      0.60    456.41      1.00
beta_decentered[53]      0.07      0.29      0.04     -0.41      0.53    500.99      1.00
beta_decentered[54]      0.05      0.30      0.02     -0.49      0.51    896.08      1.00
beta_decentered[55]      0.01      0.23      0.01     -0.34      0.43   1033.13      1.00
beta_decentered[56]     -0.02      0.19     -0.01     -0.33      0.28    717.87      1.00
beta_decentered[57]      0.01      0.23      0.01     -0.33      0.41    684.61      1.00
beta_decentered[58]     -0.09      0.26     -0.08     -0.55      0.30    455.66      1.01
beta_decentered[59]      0.20      0.27      0.18     -0.20      0.63    413.57      1.01
beta_decentered[60]     -0.14      0.37     -0.09     -0.76      0.46    411.83      1.00
beta_decentered[61]      0.58      0.57      0.50     -0.23      1.50    495.86      1.00
       log_tau_i[0]     -1.56      0.91     -1.53     -2.94     -0.01    480.69      1.00
       log_tau_i[1]     -1.07      0.64     -1.08     -1.99      0.06    950.02      1.00
       log_tau_i[2]     -1.25      0.74     -1.23     -2.36     -0.06    796.37      1.00
       log_tau_i[3]     -1.15      0.65     -1.18     -2.25     -0.14    838.40      1.00
       log_tau_i[4]     -2.09      0.90     -2.12     -3.51     -0.49   1120.31      1.00
       log_tau_i[5]     -1.70      0.84     -1.69     -3.10     -0.35   1291.74      1.00
       log_tau_i[6]     -1.87      0.92     -1.84     -3.67     -0.58   1066.38      1.00
       log_tau_i[7]     -2.06      0.89     -2.07     -3.42     -0.42    641.48      1.00
       log_tau_i[8]     -1.11      0.91     -1.03     -2.79      0.19    482.69      1.00
       log_tau_i[9]     -1.65      0.91     -1.62     -3.05     -0.04    772.73      1.00
      log_tau_i[10]     -1.31      0.94     -1.27     -2.77      0.24    578.34      1.00
      log_tau_i[11]     -0.04      0.49     -0.08     -0.81      0.77    736.20      1.00
      log_tau_i[12]     -1.37      0.98     -1.32     -2.86      0.31    623.14      1.00
      log_tau_i[13]     -1.28      0.97     -1.21     -2.78      0.30    653.51      1.00
      log_tau_i[14]     -1.70      0.95     -1.71     -3.20     -0.15    831.88      1.00
      log_tau_i[15]     -1.67      0.97     -1.65     -3.15      0.04    726.23      1.00
      log_tau_i[16]     -0.53      0.65     -0.49     -1.55      0.54    558.56      1.00
      log_tau_i[17]     -0.81      0.68     -0.80     -1.91      0.27    655.36      1.00
      log_tau_i[18]     -0.71      0.83     -0.60     -1.89      0.75    536.55      1.00
      log_tau_i[19]     -1.55      0.98     -1.53     -3.17     -0.02    675.27      1.00
      log_tau_i[20]     -1.81      0.90     -1.79     -3.36     -0.41    823.93      1.00
      log_tau_i[21]     -1.53      0.93     -1.51     -3.03     -0.04    767.64      1.00
      log_tau_i[22]     -1.60      0.99     -1.58     -3.25     -0.01    840.71      1.00
      log_tau_i[23]     -1.64      0.97     -1.60     -3.15      0.10    615.62      1.00
      log_tau_i[24]     -1.21      0.96     -1.11     -2.61      0.47    674.58      1.00
      log_tau_i[25]     -1.54      1.04     -1.49     -3.39     -0.01    509.69      1.00
      log_tau_i[26]     -1.77      0.95     -1.74     -3.27     -0.13    915.91      1.00
      log_tau_i[27]     -1.14      0.89     -1.07     -2.53      0.36    569.10      1.01
      log_tau_i[28]     -1.72      0.93     -1.66     -3.25     -0.20    911.70      1.01
      log_tau_i[29]     -1.70      0.91     -1.71     -3.16     -0.14    648.29      1.00
      log_tau_i[30]     -1.28      1.02     -1.28     -2.99      0.34    575.90      1.00
      log_tau_i[31]     -1.10      0.86     -1.04     -2.39      0.45    730.21      1.00
      log_tau_i[32]     -1.79      0.95     -1.82     -3.32     -0.19    667.95      1.00
      log_tau_i[33]     -1.64      0.86     -1.62     -3.04     -0.32    948.64      1.00
      log_tau_i[34]     -1.87      0.92     -1.88     -3.29     -0.25    909.49      1.00
      log_tau_i[35]     -1.00      0.85     -0.95     -2.44      0.25    672.42      1.00
      log_tau_i[36]     -1.76      0.92     -1.73     -3.19     -0.22    889.56      1.00
      log_tau_i[37]     -1.63      1.00     -1.58     -3.11      0.12    973.29      1.00
      log_tau_i[38]     -1.72      0.85     -1.70     -3.13     -0.30    837.73      1.00
      log_tau_i[39]     -1.15      0.80     -1.13     -2.32      0.18    627.15      1.00
      log_tau_i[40]     -1.76      0.93     -1.70     -3.36     -0.34    686.88      1.00
      log_tau_i[41]     -1.75      0.96     -1.74     -3.20     -0.09    612.83      1.00
      log_tau_i[42]     -1.62      0.91     -1.62     -3.21     -0.25    596.18      1.00
      log_tau_i[43]     -0.79      0.93     -0.74     -2.20      0.83    560.52      1.01
      log_tau_i[44]     -1.52      0.94     -1.48     -3.08      0.01    888.27      1.00
      log_tau_i[45]     -1.83      0.90     -1.84     -3.27     -0.44   1122.53      1.00
      log_tau_i[46]     -1.83      0.89     -1.78     -3.21     -0.31    990.03      1.00
      log_tau_i[47]     -1.56      0.93     -1.48     -3.20     -0.21    736.01      1.00
      log_tau_i[48]     -1.59      0.94     -1.54     -3.20     -0.01    588.77      1.00
      log_tau_i[49]     -1.71      0.93     -1.68     -3.17     -0.17    813.94      1.00
      log_tau_i[50]     -1.15      0.83     -1.11     -2.51      0.14    514.68      1.00
      log_tau_i[51]     -1.54      0.89     -1.51     -2.97     -0.13    780.67      1.00
      log_tau_i[52]     -1.59      0.93     -1.56     -3.21     -0.28    807.47      1.00
      log_tau_i[53]     -1.74      0.92     -1.72     -3.10     -0.15    657.89      1.00
      log_tau_i[54]     -1.74      0.93     -1.74     -3.09     -0.10    867.29      1.00
      log_tau_i[55]     -1.89      0.94     -1.87     -3.34     -0.27   1132.90      1.00
      log_tau_i[56]     -1.89      0.92     -1.89     -3.27     -0.27    980.72      1.00
      log_tau_i[57]     -1.84      0.91     -1.84     -3.33     -0.38    813.32      1.00
      log_tau_i[58]     -1.75      0.93     -1.74     -3.25     -0.34    583.52      1.00
      log_tau_i[59]     -1.59      0.96     -1.53     -3.11     -0.06    754.39      1.01
      log_tau_i[60]     -1.58      0.98     -1.52     -3.23     -0.08    561.05      1.00
      log_tau_i[61]     -0.98      1.03     -0.85     -2.55      0.79    502.63      1.00
       log_tau_zero     -1.49      0.26     -1.49     -1.91     -1.10    220.32      1.00

Number of divergences: 1

The number of divergences have significantly reduced from 37 to 1.

[ ]:
data = az.from_numpyro(mcmc_reparam)
az.plot_trace(data, compact=True, figsize=(15, 25));
../_images/tutorials_variationally_inferred_parameterization_26_0.png

4. References:

  1. https://arxiv.org/abs/1906.03028

  2. https://github.com/mgorinova/autoreparam/tree/master