Embarrassingly Parallel MCMC via Consensus Monte Carlo
This example demonstrates how to scale Bayesian inference to large datasets using embarrassingly parallel MCMC methods. We show how to split a dataset into shards, run independent MCMC on each shard, and merge the subposteriors using two algorithms implemented in NumPyro:
Consensus Monte Carlo (Scott et al., 2013)
Parametric Monte Carlo (Neiswanger et al., 2014)
References:
Scott, S. L., Blocker, A. W., Bonassi, F. V., Chipman, H. A., George, E. I., & McCulloch, R. E. (2016). Bayes and big data: The consensus Monte Carlo algorithm. International Journal of Management Science and Engineering Management, 11(2), 78-88.
Neiswanger, W., Wang, C., & Xing, E. (2014). Asymptotically exact, embarrassingly parallel MCMC. Proceedings of the Thirtieth Conference on Uncertainty in Artificial Intelligence (UAI).
Motivation: Why Embarrassingly Parallel MCMC?
Standard MCMC methods like NUTS (the No-U-Turn Sampler) evaluate the full log-likelihood at every leapfrog step. For a dataset of size \(N\), this means each gradient evaluation costs \(O(N)\), making posterior inference slow when \(N\) is large (e.g., millions of observations).
There are several strategies to deal with this:
Stochastic Variational Inference (SVI): Optimizes a parametric approximation to the posterior using mini-batches. Very fast, but the approximate posterior may be misspecified (e.g., a mean-field Gaussian cannot capture multimodality or complex correlations).
Data subsampling MCMC (e.g., HMCECS): Uses mini-batches within a single MCMC chain, with control variates to correct the bias. Requires good reference parameters and is not trivially parallelizable across data.
Embarrassingly parallel MCMC: Splits the data across machines or cores, runs completely independent MCMC chains on each shard, and merges the results in a single post-processing step. There is no communication between chains during sampling, making this approach trivially parallelizable and easy to implement on distributed systems.
This tutorial focuses on the third approach.
How Consensus Monte Carlo Works
Setup
We want to draw samples from the posterior distribution:
\[ p(\theta \mid D) \propto p(\theta) \prod_{i=1}^{N} p(y_i \mid x_i, \theta), \]
where \(p(\theta)\) is the prior, \(D = \{(x_i, y_i)\}_{i=1}^{N}\) is the full dataset, and \(\theta\) are the model parameters.
Step 1: Partition the data
Split the dataset \(D\) into \(K\) disjoint shards \(D_1, D_2, \ldots, D_K\).
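Here is a minimal NumPy sketch of Step 1. It assumes the data fit in memory as arrays, and the helper name partition is hypothetical. One shared permutation shuffles inputs and outputs together, so each (x_i, y_i) pair stays aligned. np.array_split then cuts the shuffled arrays into \(K\) disjoint, nearly equal shards:

```python
import numpy as np


def partition(x, y, num_shards, seed=0):
    """Hypothetical helper: shuffle (x, y) jointly and split into disjoint shards."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(x))  # one shared permutation keeps (x_i, y_i) pairs aligned
    x_shards = np.array_split(x[perm], num_shards)
    y_shards = np.array_split(y[perm], num_shards)
    return list(zip(x_shards, y_shards))


x = np.linspace(0.0, 5.0, 103)
y = 2.0 * x
shards = partition(x, y, num_shards=4)
print([len(xk) for xk, _ in shards])  # [26, 26, 26, 25]
```

The notebook's run_sharded_mcmc puts any remainder into the last shard. np.array_split instead spreads the remainder over the first few shards. Both choices keep the shards disjoint.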
Step 2: Define subposteriors
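For each shard \(k\), define a subposterior:
\[ p_k(\theta \mid D_k) \propto p(\theta)^{1/K} \prod_{(x_i, y_i) \in D_k} p(y_i \mid x_i, \theta). \]
The key insight is that the product of all \(K\) subposteriors recovers the full posterior:
\[ \prod_{k=1}^{K} p_k(\theta \mid D_k) \propto p(\theta) \prod_{i=1}^{N} p(y_i \mid x_i, \theta) \propto p(\theta \mid D). \]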
Each shard uses \(p(\theta)^{1/K}\) (the prior raised to the power \(1/K\)) so that the prior is not over-counted when the subposteriors are combined. In practice, this is implemented by scaling the log-prior by \(1/K\) using numpyro.handlers.scale.
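The claim that the product of the \(K\) subposteriors recovers the full posterior can be checked numerically. The sketch below uses a hypothetical toy model (a Gaussian mean with known unit noise, not the exponential decay model) and evaluates unnormalized log densities at a fixed \(\theta\). On each shard the log-prior is scaled by \(1/K\), which is the role numpyro.handlers.scale plays in the NumPyro model. With that scaling, the \(K\) log-subposteriors sum exactly to the full log-posterior:

```python
import numpy as np


def normal_logpdf(v, loc, scale):
    """Log density of Normal(loc, scale) evaluated at v."""
    return -0.5 * ((v - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def log_prior(theta):
    # Hypothetical toy prior: theta ~ Normal(0, 10).
    return normal_logpdf(theta, 0.0, 10.0)


def log_lik(theta, y):
    # Hypothetical toy likelihood: y_i ~ Normal(theta, 1).
    return normal_logpdf(y, theta, 1.0).sum()


rng = np.random.default_rng(0)
y = rng.normal(3.0, 1.0, size=1_000)
K = 5
shards = np.array_split(y, K)
theta = 2.7  # any fixed parameter value

# Unnormalized log-subposterior per shard: log-prior scaled by 1/K plus that shard's log-likelihood.
log_sub = [log_prior(theta) / K + log_lik(theta, yk) for yk in shards]
log_full = log_prior(theta) + log_lik(theta, y)
print(np.isclose(sum(log_sub), log_full))  # True
```

Without the \(1/K\) factor, the sum would contain \(K\) copies of the log-prior, so the merged posterior would be over-regularized.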
Step 3: Sample independently
Run MCMC (e.g., NUTS) on each shard independently. This is “embarrassingly parallel” since there is no inter-chain communication during sampling. Each chain produces samples \(\{\theta_k^{(s)}\}_{s=1}^{S}\) from its subposterior \(p_k(\theta \mid D_k)\).
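The shards never communicate, so Step 3 maps directly onto a pool of independent workers. The sketch below is a stand-in under simplifying assumptions. It uses a hypothetical conjugate toy model (a Normal prior on a mean, with known Gaussian noise) whose tempered subposterior is Gaussian, so exact draws replace the MCMC run. It also uses Python threads, only to keep the example self-contained:

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

PRIOR_SD = 10.0  # hypothetical prior: theta ~ Normal(0, PRIOR_SD)
NOISE_SD = 1.0  # hypothetical likelihood: y_i ~ Normal(theta, NOISE_SD)


def sample_subposterior(args):
    """Exact draws from one shard's subposterior (stand-in for an MCMC run)."""
    y_k, num_shards, num_samples, seed = args
    # p(theta)^(1/K) for a Normal(0, s) prior is Normal(0, s * sqrt(K)) up to a constant.
    prior_prec = 1.0 / (num_shards * PRIOR_SD**2)
    post_prec = prior_prec + len(y_k) / NOISE_SD**2
    post_mean = (y_k.sum() / NOISE_SD**2) / post_prec
    rng = np.random.default_rng(seed)
    return rng.normal(post_mean, 1.0 / np.sqrt(post_prec), size=num_samples)


rng = np.random.default_rng(1)
y = rng.normal(3.0, NOISE_SD, size=2_000)
K, S = 4, 5_000
jobs = [(y_k, K, S, seed) for seed, y_k in enumerate(np.array_split(y, K))]

# No communication between workers: each shard is sampled on its own.
with ThreadPoolExecutor(max_workers=K) as pool:
    subposteriors = list(pool.map(sample_subposterior, jobs))

print([round(float(s.mean()), 2) for s in subposteriors])
```

With real MCMC, each shard would more typically go to a separate process, device, or machine. The notebook below simply loops over the shards one after another.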
Step 4: Merge the subposteriors
Combine the \(K\) sets of subposterior samples into a single set of samples approximating the full posterior. NumPyro provides two merging strategies:
Consensus MC (Scott et al., 2013)
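Computes a weighted average of samples across shards. For each draw \(s\), the merged sample is:
\[ \theta^{(s)} = \left( \sum_{k=1}^{K} W_k \right)^{-1} \sum_{k=1}^{K} W_k \, \theta_k^{(s)}, \]
where the weights \(W_k\) are each shard's inverse covariance (precision matrix):
\[ W_k = \Sigma_k^{-1}, \]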
and \(\Sigma_k\) is the sample covariance of shard \(k\)’s subposterior.
With diagonal=True, the weights are computed per-parameter using inverse variances (cheaper, ignores cross-parameter correlations). With diagonal=False, full covariance matrices are used (captures correlations but is more expensive and can be numerically unstable in high dimensions).
This method preserves non-Gaussian features of the subposteriors to some extent, since it operates directly on the samples rather than fitting a parametric distribution.
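The consensus rule can be sketched in NumPy as follows. The sketch assumes each subposterior is an array of \(S\) draws by \(d\) parameters. Draw \(s\) of every shard is combined into merged draw \(s\) via \(\theta^{(s)} = (\sum_k W_k)^{-1} \sum_k W_k \theta_k^{(s)}\) with \(W_k = \Sigma_k^{-1}\). This is an illustrative reimplementation (the name consensus_merge is hypothetical), not NumPyro's consensus source:

```python
import numpy as np


def consensus_merge(subposteriors, diagonal=False):
    """Precision-weighted average of draws; subposteriors is a list of K arrays of shape (S, d)."""
    draws = np.stack(subposteriors)  # (K, S, d)
    if diagonal:
        w = 1.0 / draws.var(axis=1, ddof=1)  # (K, d): per-parameter inverse variances
        return (w[:, None, :] * draws).sum(axis=0) / w.sum(axis=0)
    # Full covariance: W_k is the inverse of shard k's sample covariance.
    W = np.stack([np.linalg.inv(np.cov(s, rowvar=False)) for s in subposteriors])  # (K, d, d)
    weighted = np.einsum("kij,ksj->si", W, draws)  # sum_k W_k @ theta_k^(s), shape (S, d)
    return np.linalg.solve(W.sum(axis=0), weighted.T).T


rng = np.random.default_rng(0)
means = [np.array([1.0, 2.0]), np.array([1.5, 1.0]), np.array([0.5, 2.5])]
covs = [
    np.array([[1.0, 0.6], [0.6, 1.0]]),
    np.array([[0.5, -0.2], [-0.2, 0.8]]),
    np.array([[2.0, 0.0], [0.0, 0.5]]),
]
subs = [rng.multivariate_normal(m, c, size=20_000) for m, c in zip(means, covs)]
merged = consensus_merge(subs)

# Closed-form product of the Gaussian subposteriors, for comparison.
precs = [np.linalg.inv(c) for c in covs]
exact_mean = np.linalg.solve(sum(precs), sum(p @ m for p, m in zip(precs, means)))
print(merged.mean(axis=0), exact_mean)
```

For Gaussian subposteriors this weighting is exact. The merged draws then have the mean and covariance of the normalized product of Gaussians, which the example compares against the closed form.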
Parametric MC (Neiswanger et al., 2014)
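Fits a Gaussian to each subposterior by estimating its mean \(\mu_k\) and covariance \(\Sigma_k\), then analytically combines the Gaussians. The merged posterior is itself Gaussian with:
\[ \Sigma = \left( \sum_{k=1}^{K} \Sigma_k^{-1} \right)^{-1}, \qquad \mu = \Sigma \sum_{k=1}^{K} \Sigma_k^{-1} \mu_k. \]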
Fresh samples are then drawn from \(\text{MultivariateNormal}(\mu, \Sigma)\). This is principled when the posterior is approximately Gaussian (which is often reasonable for large \(N\) by the Bernstein-von Mises theorem), but will lose information about non-Gaussian features such as skewness or multimodality.
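The Gaussian product can be sketched in a few lines of NumPy. The inputs to this hypothetical helper, parametric_merge, are the fitted means and covariances. In practice these would be the sample mean and np.cov of each shard's draws:

```python
import numpy as np


def parametric_merge(means, covs, num_draws, seed=0):
    """Combine K Gaussian fits N(mu_k, Sigma_k) into their normalized product and sample it."""
    precs = [np.linalg.inv(c) for c in covs]
    cov = np.linalg.inv(sum(precs))  # Sigma = (sum_k Sigma_k^{-1})^{-1}
    mean = cov @ sum(p @ m for p, m in zip(precs, means))  # mu = Sigma sum_k Sigma_k^{-1} mu_k
    rng = np.random.default_rng(seed)
    return mean, cov, rng.multivariate_normal(mean, cov, size=num_draws)


# Two identical Gaussians: the product keeps the mean and halves the covariance.
m = np.array([1.0, -2.0])
c = np.array([[1.0, 0.3], [0.3, 2.0]])
mean, cov, draws = parametric_merge([m, m], [c, c], num_draws=1_000)
print(np.allclose(mean, m), np.allclose(cov, c / 2))  # True True
```

Combining two identical Gaussians is a quick sanity check. The precisions add, so the covariance halves while the mean is unchanged.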
When to Use Which Approach
| Method | Best for | Limitations |
|---|---|---|
| NUTS | Gold standard; data fits in memory; runtime acceptable | Cost per step \(\propto N\); slow for very large \(N\) |
| SVI | Fast approximate inference; very large data; deep models | Approximate posterior; may miss multimodality |
| HMCECS | Large data; want exact MCMC with subsampling | Needs reference params (e.g., from SVI); single chain |
| Consensus MC | Data too large for one machine; multiple GPUs available | Quality depends on Gaussian assumption; each shard needs enough data |
| Parametric MC | Same as consensus; want clean Gaussian approximation | Explicitly Gaussian; loses non-Gaussian features |
Rules of thumb:
If your data fits in memory and NUTS runs in reasonable time, use NUTS.
If you need speed and an approximate posterior is acceptable, use SVI.
If data is too large for NUTS but the posterior is likely near-Gaussian (e.g., many observations per parameter, generalized linear models), consensus MC or parametric MC are good choices.
Consensus MC works best when each shard has enough data that its subposterior is well-identified and approximately Gaussian (i.e., in the large-sample regime where the central limit theorem applies).
Example: Bayesian Exponential Decay Regression
We illustrate the method on a non-linear regression problem using an exponential decay model. Unlike linear models (where the posterior is approximately Gaussian and all merging methods agree), this model has parameters that enter non-linearly, producing a posterior with curved correlations that clearly reveals the differences between merging strategies.
Model
\[ y_i \sim \text{Normal}\left(A \exp(-r x_i) + c,\ \sigma\right), \qquad i = 1, \ldots, N, \]
where:
\(A\) is the amplitude (initial value of the decay),
\(r\) is the decay rate (how fast the signal decays),
\(c\) is an offset (asymptotic value as \(x \to \infty\)),
\(\sigma\) is the observation noise standard deviation.
Why This Model Shows Differences
The amplitude \(A\) and rate \(r\) interact non-linearly: a larger \(A\) combined with a larger \(r\) can produce a similar curve to a smaller \(A\) with a smaller \(r\). This creates a banana-shaped (curved) joint posterior for \((A, r)\) that is clearly non-Gaussian. Since consensus and parametric merging both rely on Gaussian approximations (either implicitly via covariance weighting or explicitly via Gaussian fitting), this curvature makes the differences between methods visible:
Diagonal methods ignore the \(A\)–\(r\) correlation entirely.
Parametric methods fit an elliptical Gaussian, losing the banana curvature.
Consensus (full cov) better preserves the shape through direct sample weighting.
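A local linearization shows the sign and strength of the dependence between \(A\) and \(r\). For Gaussian noise, the Fisher information is \(J^\top J / \sigma^2\), where \(J\) is the Jacobian of the mean function with respect to \((A, r, c)\). Its inverse approximates the posterior covariance near the true values. The sketch below assumes the simulation's true parameter values and input design (\(x\) uniform on \([0, 5]\), \(N = 10{,}000\)). It describes only the elliptical, Gaussian part of the dependence, not any curvature:

```python
import numpy as np

# Assumed values: the true parameters and input design used to simulate the data.
A, r, c = 3.0, 2.5, 1.0
sigma = 0.5
x = np.random.default_rng(0).uniform(0.0, 5.0, size=10_000)

# Jacobian of mu(x) = A * exp(-r x) + c with respect to (A, r, c), one row per observation.
J = np.stack([np.exp(-r * x), -A * x * np.exp(-r * x), np.ones_like(x)], axis=1)

# Gaussian-noise Fisher information; its inverse is the local (Laplace-style) covariance.
fisher = J.T @ J / sigma**2
cov = np.linalg.inv(fisher)
corr_A_r = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])
print(round(float(corr_A_r), 2))
```

For this design the correlation between \(A\) and \(r\) comes out positive and sizable, roughly 0.6. This is consistent with a larger amplitude trading off against a faster decay.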
```python
# !pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
```
```python
import arviz as az
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import scale
from numpyro.infer import MCMC, NUTS
from numpyro.infer.hmc_util import consensus, parametric_draws
from numpyro.infer.util import Predictive

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

# Expose 4 host (CPU) devices to JAX; this only takes effect at the start of the program.
numpyro.set_host_device_count(n=4)
rng_key = random.key(seed=42)

assert numpyro.__version__.startswith("0.21.0")

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
```
True Function and Data Generation
We generate synthetic data from a known exponential decay function corrupted by Gaussian noise.
```python
TRUE_AMPLITUDE = 3.0
TRUE_RATE = 2.5
TRUE_OFFSET = 1.0
TRUE_SIGMA = 0.5


def true_function(x):
    """The true underlying function: exponential decay with offset."""
    return TRUE_AMPLITUDE * jnp.exp(-TRUE_RATE * x) + TRUE_OFFSET


def generate_data(rng_key, num_obs=200, x_range=(0.0, 5.0)):
    # Inputs drawn uniformly over x_range; outputs corrupted by Gaussian noise.
    k1, k2 = random.split(rng_key)
    x = random.uniform(k1, (num_obs,), minval=x_range[0], maxval=x_range[1])
    y = true_function(x) + TRUE_SIGMA * random.normal(k2, (num_obs,))
    return x, y
```
```python
# Generate synthetic data.
X_RANGE = (0.0, 5.0)
NUM_OBS = 10_000

rng_key, data_key = random.split(rng_key)
x, y = generate_data(data_key, num_obs=NUM_OBS, x_range=X_RANGE)

print(f"Dataset: {NUM_OBS} observations")
print("Model parameters: amplitude, rate, offset, sigma (4 total)")
```
Dataset: 10000 observations
Model parameters: amplitude, rate, offset, sigma (4 total)
```python
fig, ax = plt.subplots()
x_grid = jnp.linspace(X_RANGE[0], X_RANGE[1], 200)
ax.plot(x_grid, true_function(x_grid), c="red", lw=4, label="true function")
ax.scatter(x, y, c="black", alpha=0.2, label="observed")
ax.legend()
ax.set(xlabel="x", ylabel="y")
ax.set_title("Synthetic data", fontsize=18, fontweight="bold");
```
Model Definition
The model places weakly informative priors on all parameters. The prior_scale argument multiplies the log-prior via numpyro.handlers.scale. On each of the \(K\) shards it is set to \(1/K\); on the full data it is 1.
Note that only the prior is scaled, not the likelihood: each shard's likelihood is evaluated, unscaled, on that shard's own data.
```python
def model(x, y=None, prior_scale=1.0):
    # The log-prior is multiplied by prior_scale (1/K on each shard, 1.0 on the full data).
    with scale(scale=prior_scale):
        amplitude = numpyro.sample("amplitude", dist.Normal(5.0, 3.0))
        rate = numpyro.sample("rate", dist.HalfNormal(3.0))
        offset = numpyro.sample("offset", dist.Normal(0.0, 2.0))
        sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    # The likelihood sits outside the scale handler, so it is not rescaled.
    mu = amplitude * jnp.exp(-rate * x) + offset
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
```
MCMC Helper
A utility function that runs NUTS on a given dataset shard with the specified prior scaling.
```python
def run_mcmc(
    rng_key,
    x,
    y,
    prior_scale=1.0,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True,
    )
    mcmc.run(rng_key, x, y, prior_scale)
    return mcmc
```
Full-Data MCMC (Baseline)
We run NUTS on the entire dataset with the full prior (prior_scale=1.0). This is the gold-standard baseline.
```python
def run_full_mcmc(rng_key, x, y, num_warmup, num_samples, num_chains):
    return run_mcmc(
        rng_key,
        x,
        y,
        prior_scale=1.0,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    )
```
```python
%%time
# Run full-data NUTS (baseline).
num_warmup = 1_000
num_samples = 1_000
num_chains = 1

rng_key, mcmc_key = random.split(rng_key)
full_mcmc = run_full_mcmc(
    mcmc_key,
    x,
    y,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=num_chains,
)

full_posterior_predictive = Predictive(model, full_mcmc.get_samples())
rng_key, predictive_key = random.split(rng_key)
idata_full_mcmc = az.from_numpyro(
    posterior=full_mcmc,
    posterior_predictive=full_posterior_predictive(predictive_key, x),
)
```
sample: 100%|██████████| 2000/2000 [00:01<00:00, 1899.95it/s, 7 steps of size 4.25e-01. acc. prob=0.94]
CPU times: user 4.1 s, sys: 1.3 s, total: 5.41 s
Wall time: 2.61 s
Sharded MCMC
We split the data into num_shards disjoint partitions and run independent NUTS on each shard. Each shard uses prior_scale = 1/K so that the prior contributions sum to the full prior when combined. Note that we shuffle the data before splitting to ensure each shard has a representative mix of input locations.
```python
def run_sharded_mcmc(rng_key, x, y, num_shards, num_warmup, num_samples, num_chains):
    N = x.shape[0]
    shard_size = N // num_shards
    subposteriors = []
    prior_scale = 1.0 / num_shards

    # Shuffle data before splitting into shards
    rng_key, shuffle_key = random.split(rng_key)
    perm = random.permutation(shuffle_key, N)
    x_shuffled = x[perm]
    y_shuffled = y[perm]

    # This loop runs the shards sequentially; since they share nothing,
    # each iteration could instead run on its own device or machine.
    for k in tqdm(range(num_shards)):
        rng_key, sub_key = random.split(rng_key)
        start = k * shard_size
        # The last shard absorbs the remainder when N is not divisible by num_shards.
        end = start + shard_size if k < num_shards - 1 else N
        x_k = x_shuffled[start:end]
        y_k = y_shuffled[start:end]
        samples_k = run_mcmc(
            sub_key,
            x_k,
            y_k,
            prior_scale=prior_scale,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
        ).get_samples()
        subposteriors.append(samples_k)
    return subposteriors
```
```python
%%time
# Run sharded NUTS (embarrassingly parallel).
num_shards = 20

rng_key, shard_key = random.split(rng_key)
subposteriors = run_sharded_mcmc(
    shard_key,
    x,
    y,
    num_shards=num_shards,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=num_chains,
)
```
CPU times: user 35.3 s, sys: 4.45 s, total: 39.8 s
Wall time: 15.2 s
Merging Subposteriors
We apply both merging strategies with both diagonal=True and diagonal=False:
consensus(diagonal=True): per-parameter inverse-variance weighting
consensus(diagonal=False): full inverse-covariance weighting
parametric_draws(diagonal=True): Gaussian merge with diagonal covariance
parametric_draws(diagonal=False): Gaussian merge with full covariance
```python
def merge_subposteriors(subposteriors, num_draws, rng_key):
    k1, k2, k3, k4 = random.split(rng_key, 4)
    results = {
        "Consensus (diagonal)": consensus(
            subposteriors, num_draws=num_draws, diagonal=True, rng_key=k1
        ),
        "Consensus (full cov)": consensus(
            subposteriors, num_draws=num_draws, diagonal=False, rng_key=k2
        ),
        "Parametric (diagonal)": parametric_draws(
            subposteriors, num_draws=num_draws, diagonal=True, rng_key=k3
        ),
        "Parametric (full cov)": parametric_draws(
            subposteriors, num_draws=num_draws, diagonal=False, rng_key=k4
        ),
    }
    return results
```
```python
# Merge the subposteriors.
rng_key, merge_key = random.split(rng_key)
merged = merge_subposteriors(
    subposteriors,
    num_draws=num_samples,
    rng_key=merge_key,
)
```
We store the results of each method in an ArviZ InferenceData object.
```python
idatas = {}
idatas["Full NUTS"] = idata_full_mcmc

for method, samples in merged.items():
    # Add a leading chain dimension so ArviZ sees one chain: (draws, ...) -> (1, draws, ...).
    idata = az.from_dict(
        {k: v[jnp.newaxis, :] for k, v in samples.items()},
    )
    rng_key, predictive_key = random.split(rng_key)
    pred = Predictive(model, samples)
    posterior_predictive = az.from_numpyro(
        posterior_predictive=pred(predictive_key, x),
    )
    idata.extend(posterior_predictive)
    idatas[method] = idata
```
Plotting
We now analyze some plots to understand the results.
Posterior
We plot the posterior distributions of all model parameters.
```python
axes = az.plot_forest(
    [p["posterior"] for p in idatas.values()],
    model_names=idatas.keys(),
    combined=True,
    figsize=(6, 6),
)
ax = axes.flatten()[0]
legend = ax.get_legend()
legend.set_bbox_to_anchor((1.0, 0.60))
ax.set_title("Posterior of Model Parameters (94% HDI)", fontsize=18, fontweight="bold");
```
Posterior Predictive
For each method, we plot the \(94\%\) and \(50\%\) highest density intervals (HDIs) of the posterior predictive distribution.
```python
def plot_posterior_predictive(x, y, idatas):
    fig, ax = plt.subplots(
        nrows=len(idatas),
        ncols=1,
        figsize=(8, 12),
        sharex=True,
        sharey=True,
        layout="constrained",
    )
    for j, (method, idata) in enumerate(idatas.items()):
        ax[j].scatter(x, y, c="black", alpha=0.2, label="observed")
        # Wider 94% band first, then the more opaque 50% band on top.
        for i, hdi in enumerate([0.94, 0.5]):
            az.plot_hdi(
                x,
                idata["posterior_predictive"]["obs"],
                hdi_prob=hdi,
                color=f"C{j}",
                fill_kwargs={"alpha": 0.3 + 0.3 * i, "label": f"{hdi:.0%} HDI"},
                ax=ax[j],
            )
        ax[j].legend(loc="upper right")
        ax[j].set_title(method)
    fig.suptitle(
        "Fitted Curves: Full NUTS vs Embarrassingly Parallel Methods",
        fontsize=18,
        fontweight="bold",
    )


plot_posterior_predictive(x, y, idatas)
```
2D Posterior Correlations
We plot the joint posterior of amplitude and rate. These two parameters interact non-linearly in the exponential decay model, producing a characteristic banana-shaped correlation. This is where the differences between merging methods are most visible.
```python
def plot_2d_comparison(idatas, var0="amplitude", var1="rate"):
    fig, ax = plt.subplots(
        nrows=2,
        ncols=3,
        figsize=(10, 7),
        sharex=True,
        sharey=True,
        layout="constrained",
    )
    ax = ax.flatten()
    for i, (method, idata) in enumerate(idatas.items()):
        az.plot_kde(
            idata.posterior[var0].sel(chain=0),
            idata.posterior[var1].sel(chain=0),
            ax=ax[i],
        )
        ax[i].set(xlabel=var0, ylabel=var1)
        ax[i].set_title(method)
    # Hide grid panels left over when there are fewer methods than panels.
    for extra_ax in ax[len(idatas):]:
        extra_ax.set_axis_off()
    fig.suptitle(
        f"2D Posterior: {var0} vs {var1}",
        fontsize=18,
        fontweight="bold",
    )


plot_2d_comparison(idatas, var0="amplitude", var1="rate")
```
Discussion: Why Do the Methods Differ Here?
The 2D plots above reveal a key difference: the full NUTS posterior for \((A, r)\) has a rotated elliptical correlation structure, not a circular one. This rotation arises because \(A\) and \(r\) enter the model as \(A \exp(-r x)\), a product of parameters inside a non-linear function.
Each merging method handles this curvature differently:
Full NUTS (baseline): Samples the true posterior directly, capturing the full banana shape.
Consensus (full cov): Computes a weighted average of subposterior samples using the full inverse covariance. Since it operates on the actual samples, it can preserve some of the non-Gaussian shape, though the linear weighting still tends to “straighten” the banana.
Consensus (diagonal): Uses only per-parameter inverse variances, ignoring the \(A\)–\(r\) correlation entirely. This produces a rounder, less structured joint distribution.
Parametric (full cov): Fits a multivariate Gaussian to each subposterior and combines them analytically. The result is an elliptical distribution that captures the orientation of the correlation but loses the curvature.
Parametric (diagonal): Fits independent Gaussians per parameter, losing both the correlation and the curvature.
The marginal posteriors in the forest plot differ only slightly, because each shard still contains many observations (\(N/K = 10{,}000/20 = 500\)).
Summary
This example illustrates the trade-offs of embarrassingly parallel MCMC methods:
When the posterior is approximately Gaussian (e.g., linear models with large \(N\) per shard), all four variants (consensus and parametric, each with diagonal or full covariance) produce similar results. In this regime, the Gaussian assumption is justified and the methods work well.
When the posterior is non-Gaussian (e.g., non-linear models, small \(N\) per shard, parameters near boundaries), the methods diverge:
Parametric methods force a Gaussian shape, which can distort both the marginals (losing skewness) and the joint distribution (losing curvature).
Consensus methods operate directly on samples, better preserving non-Gaussian features, especially when using the full covariance.
Diagonal variants of both methods ignore cross-parameter correlations, which matters when parameters are strongly correlated (as \(A\) and \(r\) are here).
Practical guidance: Consensus MC with full covariance (diagonal=False) is generally the safest default for embarrassingly parallel MCMC, as it best preserves the posterior structure. Parametric methods are cheaper and work well when the Gaussian assumption holds (large \(N\), linear or mildly non-linear models). When in doubt, take a subset of the data small enough for NUTS, run both full NUTS and the sharded pipeline on it, and compare the two posteriors as a sanity check.
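Here is a minimal sketch of such a sanity check. It assumes posterior draws are available as 1D arrays per parameter, for example via idata.posterior["rate"].values.ravel() from the InferenceData objects above. For each parameter, it reports how far the merged mean shifts, in units of the reference standard deviation, and the ratio of standard deviations. The helper compare_to_reference and the synthetic draws below are hypothetical:

```python
import numpy as np


def compare_to_reference(reference, merged):
    """Per-parameter mean shift (in reference-SD units) and SD ratio.

    reference, merged: dicts mapping parameter name -> 1D array of draws.
    """
    report = {}
    for name, ref in reference.items():
        est = merged[name]
        ref_sd = ref.std(ddof=1)
        report[name] = {
            "z_shift": float((est.mean() - ref.mean()) / ref_sd),
            "sd_ratio": float(est.std(ddof=1) / ref_sd),
        }
    return report


# Synthetic draws for illustration only.
rng = np.random.default_rng(0)
reference = {"rate": rng.normal(2.5, 0.05, 4_000)}
merged = {"rate": rng.normal(2.52, 0.05, 4_000)}
report = compare_to_reference(reference, merged)
print(report)
```

Shifts near zero and ratios near one indicate agreement. How much deviation is acceptable is a judgment call that depends on the application.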