Bad posterior geometry and how to deal with it

HMC and its variant NUTS use gradient information to draw (approximate) samples from a posterior distribution. These gradients are computed in a particular coordinate system, and different choices of coordinate system can make HMC more or less efficient. This is analogous to the situation in constrained optimization problems where, for example, parameterizing a positive quantity via an exponential versus softplus transformation results in distinct optimization dynamics.

Consequently it is important to pay attention to the geometry of the posterior distribution. Reparameterizing the model (i.e. changing the coordinate system) can make a big practical difference for many complex models. For the most complex models it can be absolutely essential. For the same reason it can be important to pay attention to some of the hyperparameters that control HMC/NUTS (in particular the max_tree_depth and dense_mass).

In this tutorial we use a few concrete examples to explore models with bad posterior geometries and what one can do to achieve better performance.

[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[1]:
from functools import partial

import numpy as np

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import summary

from numpyro.infer import MCMC, NUTS

# NB: replace cpu by gpu to run this notebook on gpu
numpyro.set_platform("cpu")

We begin by writing a helper function to do NUTS inference.

[2]:
def run_inference(
    model, num_warmup=1000, num_samples=1000, max_tree_depth=10, dense_mass=False
):
    # Run a single NUTS chain on model and print the largest r_hat of each site in the samples.
    kernel = NUTS(model, max_tree_depth=max_tree_depth, dense_mass=dense_mass)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))
    summary_dict = summary(mcmc.get_samples(), group_by_chain=False)

    # print the largest r_hat for each variable
    for k, v in summary_dict.items():
        spaces = " " * max(12 - len(k), 0)
        print("[{}] {} \t max r_hat: {:.4f}".format(k, spaces, np.max(v["r_hat"])))

Evaluating HMC/NUTS

In general it is difficult to assess whether the samples returned from HMC or NUTS are accurate (approximate) samples from the posterior. Two general rules of thumb, however, are to look at the effective sample size (ESS) and r_hat diagnostics returned by mcmc.print_summary(). If we see values of r_hat close to 1.0 (say, below 1.05) and effective sample sizes that are comparable to the total number of samples num_samples (assuming thinning=1), then we have good reason to believe that HMC is doing a good job. If, however, we see low effective sample sizes or large r_hats for some of the variables (e.g. r_hat = 1.15), then HMC is likely struggling with the posterior geometry. In the following we will use r_hat as our primary diagnostic metric.
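To make r_hat less of a black box, here is a minimal NumPy sketch of the split-R-hat statistic for a single chain. It splits the chain into two halves, treats them as two separate chains and compares between-half and within-half variance using the Gelman-Rubin formula. As far as we know, this split variant is what NumPyro's summary reports when only one chain is available, as in run_inference. The function name split_rhat is our own and is not a NumPyro API. Because the pooled variance estimate is slightly shrunk by the factor (n - 1) / n, well-mixed single chains can report values a little below 1.0.

```python
import numpy as np


def split_rhat(chain):
    """Split-R-hat for a single 1-D chain of draws of one scalar quantity."""
    chain = np.asarray(chain, dtype=float)
    n = chain.shape[0] // 2
    # Treat the first and second halves as two separate "chains".
    halves = np.stack([chain[:n], chain[n : 2 * n]])  # shape (2, n)
    within = np.mean(np.var(halves, axis=1, ddof=1))  # W: mean within-half variance
    between = np.var(np.mean(halves, axis=1), ddof=1)  # B / n: variance of the half means
    # Pooled estimate of the posterior variance.
    var_plus = within * (n - 1) / n + between
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(0)
good_chain = rng.normal(size=1000)  # independent draws: the ideal case
drifting_chain = np.linspace(0.0, 5.0, 1000) + rng.normal(size=1000)  # chain that has not mixed

print(split_rhat(good_chain))      # close to 1.0
print(split_rhat(drifting_chain))  # well above 1.05
```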

Model reparameterization

Example #1

We begin with an example (horseshoe regression; see examples/horseshoe_regression.py for a complete example script) where reparameterization helps a lot. This particular example demonstrates a general reparameterization strategy that is useful in many models with hierarchical/multi-level structure. For more discussion of some of the issues that can arise in hierarchical models see reference [1].

[3]:
# In this unreparameterized model some of the parameters of the distributions
# explicitly depend on other parameters (in particular beta depends on lambdas and tau).
# This kind of coordinate system can be a challenge for HMC.
def _unrep_hs_model(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    betas = numpyro.sample("betas", dist.Normal(scale=tau * lambdas))
    mean_function = jnp.dot(X, betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

To deal with the bad geometry that results from this coordinate system, we change coordinates using the following rewrite. Instead of

\[\beta \sim {\rm Normal}(0, \lambda \tau)\]

we write

\[\beta^\prime \sim {\rm Normal}(0, 1)\]

and

\[\beta \equiv \lambda \tau \beta^\prime\]

where \(\beta\) is now defined deterministically in terms of \(\lambda\), \(\tau\), and \(\beta^\prime\). In effect we’ve changed to a coordinate system where the different latent variables are less correlated with one another. In this new coordinate system we can expect HMC with a diagonal mass matrix to behave much better than it would in the original coordinate system.
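One quick way to see what "less correlated" means here is to sample from the prior in both coordinate systems with NumPy. In the original coordinates the magnitude of \(\beta\) is tied to its scale \(\sigma = \lambda \tau\). This produces the funnel shape discussed in reference [1]. In the new coordinates \(\beta^\prime\) is independent of \(\sigma\) by construction. For simplicity, the sketch uses a single HalfCauchy(1) draw as a stand-in for \(\lambda \tau\). That substitution is an assumption made for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Stand-in for the scale lambda * tau: a HalfCauchy(1) draw (assumption for illustration).
sigma = np.abs(rng.standard_cauchy(n))

# New coordinates: beta_prime ~ Normal(0, 1), independent of sigma.
beta_prime = rng.normal(size=n)
# Original coordinates: beta = sigma * beta_prime, i.e. beta ~ Normal(0, sigma).
beta = sigma * beta_prime

log_sigma = np.log(sigma)
corr_centered = np.corrcoef(log_sigma, np.log(np.abs(beta)))[0, 1]
corr_noncentered = np.corrcoef(log_sigma, np.log(np.abs(beta_prime)))[0, 1]
print(f"corr(log sigma, log|beta|)  = {corr_centered:.2f}")     # strongly positive
print(f"corr(log sigma, log|beta'|) = {corr_noncentered:.2f}")  # approximately 0
```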

There are basically two ways to implement this kind of reparameterization in NumPyro:

  • manually (i.e. by hand)

  • using numpyro.infer.reparam, which automates a few common reparameterization strategies

To begin with let’s do the reparameterization by hand.

[4]:
# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. This model is exactly equivalent
# to _unrep_hs_model but is expressed in a different coordinate system.
def _rep_hs_model1(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    unscaled_betas = numpyro.sample(
        "unscaled_betas", dist.Normal(scale=jnp.ones(X.shape[1]))
    )
    scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)
    mean_function = jnp.dot(X, scaled_betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)

Next we do the reparameterization using numpyro.infer.reparam. There are at least two ways to do this. First let’s use LocScaleReparam.

[5]:
from numpyro.infer.reparam import LocScaleReparam

# LocScaleReparam with centered=0 fully "decenters" the prior over betas.
config = {"betas": LocScaleReparam(centered=0)}
# The coordinate system of this model is equivalent to that in _rep_hs_model1 above.
_rep_hs_model2 = numpyro.handlers.reparam(_unrep_hs_model, config=config)
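The centered argument interpolates between the two coordinate systems. With centered=1 the site is left as it is, and with centered=0 we get the fully decentered form used in _rep_hs_model1. The NumPy sketch below reflects our understanding of the partial decentering that LocScaleReparam applies to a Normal(loc, scale) site. Please check the numpyro.infer.reparam source for the exact form. The helper names sample_decentered and recenter are hypothetical. For any centered in [0, 1], the auxiliary variable is drawn from Normal(centered * loc, scale ** centered). It is then mapped back deterministically, so the final value is still distributed as Normal(loc, scale).

```python
import numpy as np


def sample_decentered(rng, loc, scale, centered, size):
    """Draw the auxiliary ("decentered") variable for a Normal(loc, scale) site."""
    return rng.normal(centered * loc, scale**centered, size=size)


def recenter(decentered, loc, scale, centered):
    """Deterministically map the auxiliary variable back to the original coordinates."""
    return loc + scale ** (1.0 - centered) * (decentered - centered * loc)


rng = np.random.default_rng(0)
loc, scale = 2.0, 3.0
results = {}
for centered in (0.0, 0.5, 1.0):
    aux = sample_decentered(rng, loc, scale, centered, size=200_000)
    value = recenter(aux, loc, scale, centered)
    # In every case value ~ Normal(loc, scale): only the coordinate system changes.
    results[centered] = (value.mean(), value.std())
    print(f"centered={centered}: mean={value.mean():.2f}, std={value.std():.2f}")
```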

To show the versatility of the numpyro.infer.reparam library let’s do the reparameterization using TransformReparam instead.

[6]:
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer.reparam import TransformReparam

# In this reparameterized model none of the parameters of the distributions
# explicitly depend on other parameters. This model is exactly equivalent
# to _unrep_hs_model but is expressed in a different coordinate system.
def _rep_hs_model3(X, Y):
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(X.shape[1])))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    # instruct NumPyro to do the reparameterization automatically.
    reparam_config = {"betas": TransformReparam()}
    with numpyro.handlers.reparam(config=reparam_config):
        betas_root_variance = tau * lambdas
        # in order to use TransformReparam we have to express the prior
        # over betas as a TransformedDistribution
        betas = numpyro.sample(
            "betas",
            dist.TransformedDistribution(
                dist.Normal(0.0, jnp.ones(X.shape[1])),
                AffineTransform(0.0, betas_root_variance),
            ),
        )

    mean_function = jnp.dot(X, betas)
    numpyro.sample("Y", dist.Normal(mean_function, 0.05), obs=Y)
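TransformReparam works because a TransformedDistribution is simply a base distribution pushed through a bijection. NumPyro samples the base variable, which shows up as betas_base in the r_hat output, and then applies AffineTransform(0.0, betas_root_variance) deterministically. The density of the transformed variable follows from the change-of-variables formula, log p_y(y) = log p_x((y - loc) / scale) - log|scale|. The NumPy sketch below uses helper names of our own. It checks that this formula reproduces the Normal(loc, scale) log density, which is why both parameterizations describe the same model.

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def affine_pushforward_logpdf(y, loc, scale):
    """Log density of y = loc + scale * x with x ~ Normal(0, 1), via change of variables."""
    x = (y - loc) / scale                          # inverse transform
    log_abs_det_jacobian = np.log(np.abs(scale))   # dy/dx = scale
    return normal_logpdf(x, 0.0, 1.0) - log_abs_det_jacobian


y = np.linspace(-5.0, 5.0, 11)
scale = 0.7  # plays the role of tau * lambdas for a single coefficient
print(np.allclose(affine_pushforward_logpdf(y, 0.0, scale), normal_logpdf(y, 0.0, scale)))  # True
```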

Finally we verify that _rep_hs_model1, _rep_hs_model2, and _rep_hs_model3 do indeed achieve better r_hats than _unrep_hs_model.

[8]:
# create fake dataset
X = np.random.RandomState(0).randn(100, 500)
Y = X[:, 0]

print("unreparameterized model (very bad r_hats)")
run_inference(partial(_unrep_hs_model, X, Y))

print("\nreparameterized model with manual reparameterization (good r_hats)")
run_inference(partial(_rep_hs_model1, X, Y))

print("\nreparameterized model with LocScaleReparam (good r_hats)")
run_inference(partial(_rep_hs_model2, X, Y))

print("\nreparameterized model with TransformReparam (good r_hats)")
run_inference(partial(_rep_hs_model3, X, Y))
unreparameterized model (very bad r_hats)
[betas]                  max r_hat: 1.0775
[lambdas]                max r_hat: 3.2450
[tau]                    max r_hat: 2.1926

reparameterized model with manual reparameterization (good r_hats)
[betas]                  max r_hat: 1.0074
[lambdas]                max r_hat: 1.0146
[tau]                    max r_hat: 1.0036
[unscaled_betas]         max r_hat: 1.0059

reparameterized model with LocScaleReparam (good r_hats)
[betas]                  max r_hat: 1.0103
[betas_decentered]       max r_hat: 1.0060
[lambdas]                max r_hat: 1.0124
[tau]                    max r_hat: 0.9998

reparameterized model with TransformReparam (good r_hats)
[betas]                  max r_hat: 1.0087
[betas_base]             max r_hat: 1.0080
[lambdas]                max r_hat: 1.0114
[tau]                    max r_hat: 1.0029

Aside: numpyro.deterministic

In _rep_hs_model1 above we used numpyro.deterministic to define scaled_betas. Using this primitive is not strictly necessary. Its effect is to record the value in the trace under the site name "betas", so that it appears in the summary reported by mcmc.print_summary() (and in the r_hat output above). If we do not need betas to be reported, we could instead have written:

scaled_betas = tau * lambdas * unscaled_betas

without invoking the deterministic primitive.
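If you do drop numpyro.deterministic, you can still recover betas after inference from the other sites. mcmc.get_samples() returns a dict of arrays with a leading sample dimension, so the product broadcasts sample by sample. The arrays below are random stand-ins, not real MCMC output. They have the shapes we would expect for _rep_hs_model1 with the fake dataset used below. tau has shape (num_samples, 1) because its prior is HalfCauchy(jnp.ones(1)).

```python
import numpy as np

num_samples, num_features = 1000, 500
rng = np.random.default_rng(0)

# Hypothetical stand-in for mcmc.get_samples() from _rep_hs_model1 without the deterministic site.
samples = {
    "tau": np.abs(rng.standard_cauchy((num_samples, 1))),
    "lambdas": np.abs(rng.standard_cauchy((num_samples, num_features))),
    "unscaled_betas": rng.normal(size=(num_samples, num_features)),
}

# Same formula as in the model; tau broadcasts across the feature dimension.
betas = samples["tau"] * samples["lambdas"] * samples["unscaled_betas"]
print(betas.shape)  # (1000, 500)
```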

Mass matrices

By default HMC/NUTS use diagonal mass matrices. For models with complex geometries it can pay to use a mass matrix with richer structure.

Example #2

In this first simple example we show that using a full-rank (i.e. dense) mass matrix leads to a better r_hat.

[9]:
# Because rho is very close to 1.0 the posterior is a long, narrow,
# strongly correlated ridge, and the axis-aligned ("diagonal") coordinate
# system implied by dense_mass=False leads to bad results
rho = 0.9999
cov = jnp.array([[10.0, rho], [rho, 0.1]])


def mvn_model():
    numpyro.sample("x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov))


print("dense_mass = False (bad r_hat)")
run_inference(mvn_model, dense_mass=False, max_tree_depth=3)

print("dense_mass = True (good r_hat)")
run_inference(mvn_model, dense_mass=True, max_tree_depth=3)
dense_mass = False (bad r_hat)
[x]                      max r_hat: 1.3810
dense_mass = True (good r_hat)
[x]                      max r_hat: 0.9992
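One way to see why the dense mass matrix helps so much is to look at preconditioning. Adapting the inverse mass matrix to the posterior covariance is roughly equivalent to running HMC with an identity mass matrix on linearly transformed coordinates. A diagonal mass matrix can only rescale each coordinate separately, which turns the covariance into a correlation matrix. A dense mass matrix can whiten the covariance completely. The condition number of what remains (the ratio of the largest to the smallest eigenvalue) is a rough, heuristic measure of how hard the geometry still is. The sketch assumes the adapted mass matrix matches the true covariance exactly, which warmup only approximates.

```python
import numpy as np

rho = 0.9999
cov = np.array([[10.0, rho], [rho, 0.1]])

# Identity mass matrix: the sampler sees the raw covariance.
cond_identity = np.linalg.cond(cov)

# Diagonal mass matrix: rescale each coordinate by its marginal standard deviation.
d = 1.0 / np.sqrt(np.diag(cov))
cond_diag = np.linalg.cond(cov * np.outer(d, d))  # this is the correlation matrix

# Dense mass matrix: whiten with the Cholesky factor, cov = L @ L.T.
L = np.linalg.cholesky(cov)
L_inv = np.linalg.inv(L)
cond_dense = np.linalg.cond(L_inv @ cov @ L_inv.T)

print(f"identity: {cond_identity:.3g}, diagonal: {cond_diag:.3g}, dense: {cond_dense:.3g}")
```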

Example #3

Using dense_mass=True can be very expensive when the dimension D of the latent space is very large. In addition, if D is large it can be difficult to estimate a full-rank mass matrix, which has O(D^2) parameters, from a moderate number of samples. In these cases dense_mass=True can be a poor choice. Luckily, the argument dense_mass can also be used to specify structured mass matrices that are richer than a diagonal mass matrix but more constrained (i.e. have fewer parameters) than a full-rank mass matrix (see the NUTS documentation). In this second example we show how we can use dense_mass to specify such a structured mass matrix.

[10]:
rho = 0.9
cov = jnp.array([[10.0, rho], [rho, 0.1]])

# In this model x1 and x2 are highly correlated with one another
# but not correlated with y at all.
def partially_correlated_model():
    x1 = numpyro.sample(
        "x1", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)
    )
    x2 = numpyro.sample(
        "x2", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=cov)
    )
    y = numpyro.sample("y", dist.Normal(jnp.zeros(100), 1.0))
    numpyro.sample("obs", dist.Normal(x1 - x2, 0.1), obs=jnp.ones(2))

Now let’s compare two choices of dense_mass.

[11]:
print("dense_mass = False (very bad r_hats)")
run_inference(partially_correlated_model, dense_mass=False, max_tree_depth=3)

print("\ndense_mass = True (bad r_hats)")
run_inference(partially_correlated_model, dense_mass=True, max_tree_depth=3)

# We use dense_mass=[("x1", "x2")] to specify
# a structured mass matrix in which the y-part of the mass matrix is diagonal
# and the (x1, x2) block of the mass matrix is full-rank.

# Graphically:
#
#       x1 x2 y
#   x1 | * * 0 |
#   x2 | * * 0 |
#   y  | 0 0 * |

print("\nstructured mass matrix (good r_hats)")
run_inference(partially_correlated_model, dense_mass=[("x1", "x2")], max_tree_depth=3)
dense_mass = False (very bad r_hats)
[x1]                     max r_hat: 1.5882
[x2]                     max r_hat: 1.5410
[y]                      max r_hat: 2.0179

dense_mass = True (bad r_hats)
[x1]                     max r_hat: 1.0697
[x2]                     max r_hat: 1.0738
[y]                      max r_hat: 1.2746

structured mass matrix (good r_hats)
[x1]                     max r_hat: 1.0023
[x2]                     max r_hat: 1.0024
[y]                      max r_hat: 1.0030

max_tree_depth

The hyperparameter max_tree_depth can play an important role in determining the quality of posterior samples generated by NUTS. The default value in NumPyro is max_tree_depth=10. In some models, in particular those with especially difficult geometries, it may be necessary to increase max_tree_depth above 10. In other cases, where computing the gradient of the model log density is particularly expensive, it may be necessary to decrease max_tree_depth below 10 to reduce compute. As an example where a large max_tree_depth is essential, we return to a variant of Example #2. (In this particular case, using dense_mass=True would be another way to improve performance.)
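The compute side of this trade-off comes from how NUTS grows its trajectory by repeated doubling. As a result, the maximum number of leapfrog steps (and hence gradient evaluations) per iteration grows exponentially with the tree depth. The sketch below assumes the Stan-style convention, in which a saturated tree of depth d contains 2**d - 1 leapfrog steps. Under that assumption, it computes the worst-case gradient budget for the default 1000 warmup plus 1000 sampling iterations in run_inference.

```python
def max_leapfrog_steps(max_tree_depth):
    # Assumption: doubling k = 0, ..., d - 1 adds 2**k steps, so a full tree has 2**d - 1.
    return 2**max_tree_depth - 1


num_iterations = 1000 + 1000  # num_warmup + num_samples in run_inference
for depth in (3, 5, 10):
    per_iter = max_leapfrog_steps(depth)
    print(f"max_tree_depth={depth:2d}: <= {per_iter:4d} steps/iteration, "
          f"<= {per_iter * num_iterations:,} gradient evaluations in total")
```

These are upper bounds. Most iterations end before the tree saturates, because the U-turn criterion stops the doubling early.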

Example #4

[12]:
# Because rho is very close to 1.0 the posterior is a long, narrow ridge
# along the all-ones direction, and using a small max_tree_depth leads to bad results.
rho = 0.999
dim = 200
cov = rho * jnp.ones((dim, dim)) + (1 - rho) * jnp.eye(dim)


def mvn_model():
    x = numpyro.sample(
        "x", dist.MultivariateNormal(jnp.zeros(dim), covariance_matrix=cov)
    )


print("max_tree_depth = 5 (bad r_hat)")
run_inference(mvn_model, max_tree_depth=5)

print("max_tree_depth = 10 (good r_hat)")
run_inference(mvn_model, max_tree_depth=10)
max_tree_depth = 5 (bad r_hat)
[x]                      max r_hat: 1.1159
max_tree_depth = 10 (good r_hat)
[x]                      max r_hat: 1.0166

Other strategies

  • In some cases it can make sense to use variational inference to learn a new coordinate system. For details see examples/neutra.py and reference [2].

References

[1] “Hamiltonian Monte Carlo for Hierarchical Models,” M. J. Betancourt, Mark Girolami.

[2] “NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport,” Matthew Hoffman, Pavel Sountsov, Joshua V. Dillon, Ian Langmore, Dustin Tran, Srinivas Vasudevan.

[3] “Reparameterization” in the Stan user’s manual. https://mc-stan.org/docs/2_27/stan-users-guide/reparameterization-section.html