"""
Example: MCMC Methods for Tall Data
===================================

This example illustrates the usage of various MCMC methods which are suitable
for tall data:

    - algo="SA" uses the sample adaptive MCMC method in [1]
    - algo="HMCECS" uses the energy conserving subsampling method in [2]
    - algo="FlowHMCECS" utilizes a normalizing flow to neutralize the posterior
      geometry into a Gaussian-like one. Then HMCECS is used to draw the
      posterior samples. Currently, this method gives the best mixing rate
      among those methods.

**References:**

    1. *Sample Adaptive MCMC*, Michael Zhu (2019)
    2. *Hamiltonian Monte Carlo with energy conserving subsampling*,
       Dang, K. D., Quiroz, M., Kohn, R., Minh-Ngoc, T., & Villani, M. (2019)
    3. *NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural
       Transport*, Hoffman, M. et al. (2019)
"""
import argparse
import time

import matplotlib.pyplot as plt

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam

def _load_dataset():
    """Load the covertype dataset and preprocess it for logistic regression.

    Returns:
        tuple: ``(features, labels)`` where ``features`` is a standardized
        design matrix with an appended intercept column and ``labels`` is a
        boolean vector marking membership in the most frequent class.
    """
    # NOTE: the `def` header was missing in the original source; it is restored
    # here because `main()` calls `_load_dataset()`.
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()

    # normalize features and add intercept
    features = (features - features.mean(0)) / features.std(0)
    features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

    # make binary feature: 1 for the most frequent class, 0 otherwise
    _, counts = jnp.unique(labels, return_counts=True)
    specific_category = jnp.argmax(counts)
    labels = labels == specific_category

    N, dim = features.shape
    print("Data shape:", features.shape)
    print(
        "Label distribution: {} has label 1, {} has label 0".format(
            labels.sum(), N - labels.sum()
        )
    )
    return features, labels

def model(data, labels, subsample_size=None):
    """Bayesian logistic regression with an optional subsampled likelihood.

    Args:
        data: design matrix of shape ``(N, dim)``.
        labels: binary observations of shape ``(N,)``.
        subsample_size: if given, the plate subsamples this many rows per step
            (required by the HMCECS kernels); ``None`` uses the full dataset.
    """
    dim = data.shape[1]
    # standard-normal prior over all regression coefficients
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
    with numpyro.plate("N", data.shape[0], subsample_size=subsample_size) as idx:
        logits = jnp.dot(data[idx], coefs)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels[idx])

def benchmark_hmc(args, features, labels):
    """Build the MCMC kernel selected by ``args.algo``, run it, and print
    the mean acceptance probability, a posterior summary, and elapsed time.

    Args:
        args: parsed CLI arguments (uses ``algo``, ``dense_mass``,
            ``num_steps``, ``num_warmup``, ``num_samples``).
        features: design matrix from :func:`_load_dataset`.
        labels: binary labels from :func:`_load_dataset`.
    """
    rng_key = random.PRNGKey(1)
    start = time.time()
    # A MAP estimate of the coefficients; used to initialize chains and as the
    # reference point for the HMCECS Taylor proxy.
    ref_params = {
        "coefs": jnp.array(
            [
                +2.03420663e00,
                -3.53567265e-02,
                -1.49223924e-01,
                -3.07049364e-01,
                -1.00028366e-01,
                -1.46827862e-01,
                -1.64167881e-01,
                -4.20344204e-01,
                +9.47479829e-02,
                -1.12681836e-02,
                +2.64442056e-01,
                -1.22087866e-01,
                -6.00568838e-02,
                -3.79419506e-01,
                -1.06668741e-01,
                -2.97053963e-01,
                -2.05253899e-01,
                -4.69537191e-02,
                -2.78072730e-02,
                -1.43250525e-01,
                -6.77954629e-02,
                -4.34899796e-03,
                +5.90927452e-02,
                +7.23133609e-02,
                +1.38526391e-02,
                -1.24497898e-01,
                -1.50733739e-02,
                -2.68872194e-02,
                -1.80925727e-02,
                +3.47936489e-02,
                +4.03552800e-02,
                -9.98773426e-03,
                +6.20188080e-02,
                +1.15002751e-01,
                +1.32145107e-01,
                +2.69109547e-01,
                +2.45785132e-01,
                +1.19035013e-01,
                -2.59744357e-02,
                +9.94279515e-04,
                +3.39266285e-02,
                -1.44057125e-02,
                -6.95222765e-02,
                -7.52013028e-02,
                +1.21171586e-01,
                +2.29205526e-02,
                +1.47308692e-01,
                -8.34354162e-02,
                -9.34122875e-02,
                -2.97472421e-02,
                -3.03937674e-01,
                -1.70958012e-01,
                -1.59496680e-01,
                -1.88516974e-01,
                -1.20889175e00,
            ]
        )
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: if num_blocks=100, we'll update 10 indices at each MCMC step
        # so it takes 50,000 MCMC steps to iterate over the whole dataset
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params)
        )
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples
        # and running on GPU is much faster than on CPU
        # fix: SA requires the model; initialize from the reference MAP
        # estimate for consistency with the HMCECS branch above
        kernel = SA(
            model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)
        )
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
        )
        kernel = HMCECS(
            inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params)
        )
    else:
        # fix: the message previously omitted "SA" and "FlowHMCECS" although
        # both are handled above and advertised in the CLI help string
        raise ValueError(
            "Invalid algorithm, either 'HMC', 'NUTS', 'HMCECS', 'SA',"
            " or 'FlowHMCECS'."
        )
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
    print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    print("\nMCMC elapsed time:", time.time() - start)

def main(args):
    """Entry point: load the covertype data and benchmark the chosen sampler."""
    features, labels = _load_dataset()
    benchmark_hmc(args, features, labels)

if __name__ == "__main__":
    assert numpyro.__version__.startswith("0.7.0")
    parser = argparse.ArgumentParser(description="parse args")
    # fix: the `parser.add_argument(` call openers were missing for the four
    # multi-line argument declarations below
    parser.add_argument(
        "-n", "--num-samples", default=1000, type=int, help="number of samples"
    )
    parser.add_argument(
        "--num-warmup", default=1000, type=int, help="number of warmup steps"
    )
    parser.add_argument(
        "--num-steps", default=10, type=int, help='number of steps (for "HMC")'
    )
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    parser.add_argument(
        "--algo",
        default="HMCECS",
        type=str,
        help='whether to run "HMC", "NUTS", "HMCECS", "SA" or "FlowHMCECS"',
    )