Effect Handlers in NumPyro: A Practical Guide
NumPyro’s effect handler system is the backbone of its inference engine. Every time your model calls numpyro.sample(...), the call passes through a stack of handlers that can intercept, modify, or record the operation without changing the model code itself. This composability is a large part of what makes NumPyro (and Pyro) so flexible among probabilistic programming languages.
In this tutorial we will:
Understand what effect handlers are and how they work.
Walk through each of NumPyro’s built-in handlers with concrete examples.
Learn how to compose handlers for practical tasks like posterior prediction, causal inference, and reparameterization.
Prerequisites: Basic familiarity with NumPyro models (e.g. numpyro.sample, numpyro.plate, MCMC). No prior knowledge of effect handlers is assumed.
Tutorial Outline:
What Are Effect Handlers?
Specifying the Model
Inspection: seed and trace
Conditioning: condition vs substitute
Model Surgery: block, uncondition, lift
Composition: scope and replay
Scaling and Masking: scale and mask
Causal Inference: do
Reparameterization: reparam
Composing Handlers: Nesting and Order
Practical Recipe: Posterior Predictive with Handlers
Summary and References
Prepare Notebook
[1]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
numpyro.set_host_device_count(4)
plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
rng_key = jax.random.key(42)
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
What Are Effect Handlers?
In NumPyro, a probabilistic model is just a Python function that calls primitives like numpyro.sample("x", dist.Normal(0, 1)). Normally, this would just draw a random sample. But what if we want to:
Record every sample site for debugging? → use trace
Fix a variable to an observed value? → use condition
Intervene on a variable for causal inference? → use do
Block certain variables from being seen by inference? → use block
Effect handlers let us do all of this without modifying the model function. They work by intercepting the messages that primitives send through a global handler stack.
The Messenger Pattern
The conceptual foundation comes from Pyro’s Mini Pyro tutorial, which is an excellent resource for understanding the design from scratch. Here is the key idea:
Every handler is a Messenger subclass that gets pushed onto a global stack (_PYRO_STACK) when used as a context manager.
When a primitive like sample() is called, it creates a message dictionary and passes it through apply_stack().
apply_stack has three phases:
Phase 1, process_message (bottom → top of stack): each handler can inspect and modify the message.
Phase 2, default execution: if no handler set msg["value"], the distribution is sampled.
Phase 3, postprocess_message (top → bottom): handlers can record or further modify the final result.
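To make the three phases concrete, here is a minimal, dependency-free sketch loosely modeled on Mini Pyro. It is not NumPyro's actual implementation: the names STACK, Messenger, Condition, Trace, apply_stack, and sample are simplified stand-ins chosen for illustration, and the msg["stop"] mechanism used by block is left out.

```python
import random

# Hypothetical, simplified handler stack (illustration only).
STACK = []


class Messenger:
    """Base handler: pushes itself on the stack while active."""

    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        STACK.append(self)
        return self

    def __exit__(self, *exc):
        assert STACK.pop() is self

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        # Function-wrapper mode: run fn with this handler active.
        with self:
            return self.fn(*args, **kwargs)


class Condition(Messenger):
    """Fixes the value of named sites and marks them as observed."""

    def __init__(self, fn=None, data=None):
        super().__init__(fn)
        self.data = data or {}

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["is_observed"] = True


class Trace(Messenger):
    """Records every finished message, keyed by site name."""

    def __enter__(self):
        self.trace = {}
        return super().__enter__()

    def postprocess_message(self, msg):
        self.trace[msg["name"]] = msg.copy()


def apply_stack(msg):
    # Phase 1: process_message, innermost (last pushed) handler first.
    for handler in reversed(STACK):
        handler.process_message(msg)
    # Phase 2: default execution if no handler supplied a value.
    if msg["value"] is None:
        msg["value"] = msg["fn"]()
    # Phase 3: postprocess_message, outermost handler first.
    for handler in STACK:
        handler.postprocess_message(msg)
    return msg


def sample(name, fn):
    msg = {"type": "sample", "name": name, "fn": fn,
           "value": None, "is_observed": False}
    return apply_stack(msg)["value"]


def model():
    x = sample("x", lambda: random.gauss(0.0, 1.0))
    return sample("y", lambda: random.gauss(x, 1.0))


with Trace() as tr:
    Condition(model, data={"x": 1.0})()

print({name: (m["value"], m["is_observed"]) for name, m in tr.trace.items()})
```

The model function never changes: conditioning and recording are added purely by what sits on the stack when sample is called.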
Message Anatomy
A sample-site message is a dictionary with fields like:
{
    "type": "sample",        # primitive type
    "name": "slope",         # site name
    "fn": Normal(0, 5),      # distribution
    "value": None,           # sampled/observed value (filled by handler or default)
    "is_observed": False,    # True if this is observed data
    "scale": None,           # log-prob scaling factor
    "cond_indep_stack": [],  # enclosing plates
    "infer": {},             # metadata for inference algorithms
}
Each handler reads and writes specific fields in this dictionary. Understanding which fields each handler touches is the key to mastering their composition.
Two Usage Modes
Every handler can be used as a context manager or as a function wrapper:
# Context manager
with handlers.condition(data={"x": 1.0}):
    model()
# Function wrapper (equivalent)
handlers.condition(model, data={"x": 1.0})()
Throughout this tutorial we will use both styles depending on which is clearer.
Specifying the Model
We will use a simple Bayesian linear regression model throughout this tutorial. This gives us a concrete, familiar context for exploring each handler.
First, we generate and visualize some synthetic data:
[2]:
n_obs = 50
rng_key, rng_subkey = jax.random.split(rng_key)
true_intercept, true_slope, true_sigma = 1.0, 2.0, 0.5
x_data = jnp.linspace(-2, 2, n_obs)
y_data = (
true_intercept
+ true_slope * x_data
+ true_sigma * jax.random.normal(rng_subkey, shape=(n_obs,))
)
fig, ax = plt.subplots()
ax.scatter(x_data, y_data, c="black")
ax.set(
title="Synthetic Data",
xlabel="x",
ylabel="y",
);
Now we define the linear model:
[3]:
def linear_model(x):
    intercept = numpyro.sample("intercept", dist.Normal(0.0, 10.0))
    slope = numpyro.sample("slope", dist.Normal(0.0, 5.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    mu = numpyro.deterministic("mu", intercept + slope * x)
    with numpyro.plate("data", len(x)):
        numpyro.sample("obs", dist.Normal(mu, sigma))
# Visualize the model structure
numpyro.render_model(linear_model, model_args=(x_data,), render_distributions=True)
Inspection: seed and trace
These two handlers are the foundation for everything else. seed provides the random number generation keys that JAX requires, and trace records the execution of every primitive site.
seed: Providing Randomness
JAX has no global random state. Every random operation needs an explicit PRNG key. The seed handler automatically splits keys for each sample site so you don’t have to thread keys manually.
[4]:
# Using seed as a context manager
with handlers.seed(rng_seed=0):
    val = numpyro.sample("x", dist.Normal(0.0, 1.0))
print(f"Sampled value: {val:.4f}")
Sampled value: -2.4425
[5]:
# Using seed as a function wrapper (equivalent)
val2 = handlers.seed(lambda: numpyro.sample("x", dist.Normal(0.0, 1.0)), rng_seed=0)()
print(f"Sampled value (wrapper): {val2:.4f}")
assert jnp.allclose(val, val2)
Sampled value (wrapper): -2.4425
trace: Recording Execution
The trace handler records every message into an OrderedDict keyed by site name. This is the primary introspection tool for understanding what a model does.
[6]:
# Trace the model and inspect the result
exec_trace = handlers.trace(handlers.seed(linear_model, rng_seed=0)).get_trace(x_data)
print(f"{'Site':>12s} | {'Type':>15s} | {'is_observed':>12s} | {'Value shape'}")
print("-" * 70)
for name, site in exec_trace.items():
    print(
        f"{name:>12s} | {site['type']:>15s} | "
        f"{str(site.get('is_observed', 'N/A')):>12s} | "
        f"{jnp.shape(site['value'])}"
    )
Site | Type | is_observed | Value shape
----------------------------------------------------------------------
intercept | sample | False | ()
slope | sample | False | ()
sigma | sample | False | ()
mu | deterministic | N/A | (50,)
data | plate | N/A | (50,)
obs | sample | False | (50,)
Let’s deep-dive into the message dictionary for the slope site to see all the fields a handler can read or modify:
[7]:
slope_site = exec_trace["slope"]
# Display the full message dictionary recorded for this site
slope_site
[7]:
{'type': 'sample',
'name': 'slope',
'fn': <numpyro.distributions.continuous.Normal object at 0x1352e6030 with batch shape () and event shape ()>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[1353695780 2116000888],
'sample_shape': ()},
'value': Array(-6.2873883, dtype=float32),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}
Conditioning: condition vs substitute
These two handlers both set the value of a site, but they differ in a critical way:
| Handler | Sets value | Sets is_observed | Use case |
|---|---|---|---|
| condition | Yes | Yes (→ True) | Observed data, clamping for inference |
| substitute | Yes | No (stays False) | Plugging in posterior samples, testing |
The is_observed flag matters because it affects how inference algorithms treat the site: observed sites contribute to the likelihood but are not sampled.
[8]:
# condition: sets value AND marks as observed
conditioned_trace = handlers.trace(
handlers.seed(handlers.condition(linear_model, data={"slope": 3.0}), rng_seed=0)
).get_trace(x_data)
print("condition:")
print(f" slope value: {conditioned_trace['slope']['value']}")
print(f" slope is_observed: {conditioned_trace['slope']['is_observed']}")
condition:
slope value: 3.0
slope is_observed: True
[9]:
# substitute: sets value but does NOT mark as observed
substituted_trace = handlers.trace(
handlers.seed(handlers.substitute(linear_model, data={"slope": 3.0}), rng_seed=0)
).get_trace(x_data)
print("substitute:")
print(f" slope value: {substituted_trace['slope']['value']}")
print(f" slope is_observed: {substituted_trace['slope']['is_observed']}")
substitute:
slope value: 3.0
slope is_observed: False
When to use which:
Use condition when you want to observe a variable, i.e., fix it to a known value and have it contribute to the log-likelihood (e.g., conditioning on data).
Use substitute when you want to plug in a value without changing the observation status (e.g., substituting posterior samples for prediction).
Model Surgery: block, uncondition, lift
These handlers let you modify which sites are visible, which are observed, and whether parameters become random variables.
block: Hiding Sites
The block handler hides sites from outer handlers by setting msg["stop"] = True. Hidden sites still execute, but they won’t appear in an outer trace.
[10]:
# Hide the "sigma" site
blocked_trace = handlers.trace(
handlers.seed(handlers.block(linear_model, hide=["sigma"]), rng_seed=0)
).get_trace(x_data)
print("Sites in blocked trace:", list(blocked_trace.keys()))
assert "sigma" not in blocked_trace
Sites in blocked trace: ['intercept', 'slope', 'mu', 'data', 'obs']
[11]:
# Inverse: expose only specific sites (hide everything else)
exposed_trace = handlers.trace(
handlers.seed(
handlers.block(linear_model, expose=["slope", "intercept"]), rng_seed=0
)
).get_trace(x_data)
print("Sites in exposed trace:", list(exposed_trace.keys()))
Sites in exposed trace: ['intercept', 'slope']
[12]:
# Custom logic with hide_fn
hidden_samples_trace = handlers.trace(
handlers.seed(
handlers.block(
linear_model,
hide_fn=lambda site: site["type"] == "sample" and not site["is_observed"],
),
rng_seed=0,
)
).get_trace(x_data)
print("Sites (hiding unobserved samples):", list(hidden_samples_trace.keys()))
Sites (hiding unobserved samples): ['mu', 'data']
uncondition: Forgetting Observations
The uncondition handler forces observed sites to sample from their prior distribution instead of using the observed value. This is useful for prior predictive checking when your model has hard-coded observations.
[13]:
# Condition the model on the observed data.
conditioned_linear_model = handlers.condition(linear_model, data={"obs": y_data})
# Normal execution: obs is observed
obs_trace = handlers.trace(
handlers.seed(conditioned_linear_model, rng_seed=0)
).get_trace(x_data)
print(f"obs is_observed: {obs_trace['obs']['is_observed']}")
print(f"obs value[:3]: {obs_trace['obs']['value'][:3]}")
print(f"y_data[:3]: {y_data[:3]}")
obs is_observed: True
obs value[:3]: [-2.697118 -2.4372127 -3.127933 ]
y_data[:3]: [-2.697118 -2.4372127 -3.127933 ]
[14]:
# Unconditioned: obs is sampled again even though the model was conditioned on y_data
uncond_trace = handlers.trace(
handlers.uncondition(handlers.seed(conditioned_linear_model, rng_seed=0))
).get_trace(x_data)
print(f"obs is_observed: {uncond_trace['obs']['is_observed']}")
print(f"obs was_observed: {uncond_trace['obs']['infer']['was_observed']}")
print(f"obs value[:3]: {uncond_trace['obs']['value'][:3]} (sampled, not y_data!)")
obs is_observed: False
obs was_observed: True
obs value[:3]: [-15.044792 -12.289793 -10.538498] (sampled, not y_data!)
lift: Parameters to Random Variables
The lift handler converts numpyro.param sites into numpyro.sample sites by providing a prior distribution. This bridges the gap between optimization-based inference (SVI with params) and full Bayesian inference (MCMC with samples).
[15]:
def model_with_param(x):
    """A model using numpyro.param (typically for SVI/optimization)."""
    s = numpyro.param("s", 1.0, constraint=dist.constraints.positive)
    loc = numpyro.sample("loc", dist.Normal(0.0, s))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=x)
# Without lift: 's' is a param
param_trace = handlers.trace(handlers.seed(model_with_param, rng_seed=0)).get_trace(
jnp.ones(5)
)
print(f"Without lift: 's' type = {param_trace['s']['type']}")
# With lift: 's' becomes a sample from Exponential(1)
lifted_model = handlers.lift(model_with_param, prior={"s": dist.Exponential(1.0)})
lifted_trace = handlers.trace(handlers.seed(lifted_model, rng_seed=0)).get_trace(
jnp.ones(5)
)
print(f"With lift: 's' type = {lifted_trace['s']['type']}")
print(f" 's' value = {lifted_trace['s']['value']:.4f}")
Without lift: 's' type = param
With lift: 's' type = sample
's' value = 0.0073
Composition: scope and replay
These handlers support modular model building and deterministic replay of execution traces.
scope: Namespacing Sites
The scope handler prepends a prefix to all site names, enabling you to reuse the same model component multiple times without name collisions.
[16]:
def component():
    z = numpyro.sample("z", dist.Normal(0.0, 1.0))
    return z
def composed_model():
    with handlers.scope(prefix="first", divider="::"):
        z1 = component()
    with handlers.scope(prefix="second", divider="/"):
        z2 = component()
    numpyro.deterministic("diff", z1 - z2)
trace_composed = handlers.trace(handlers.seed(composed_model, rng_seed=0)).get_trace()
print("Sites:", list(trace_composed.keys()))
Sites: ['first::z', 'second/z', 'diff']
The divider parameter controls the separator (default is "/"):
with handlers.scope(prefix="group", divider="."):
    ...  # sites become "group.x", "group.y", etc.
replay: Deterministic Replay
The replay handler substitutes sample values from a previously recorded trace. Unlike substitute (which takes a simple dict), replay takes a full trace dict with metadata.
[17]:
# Record a trace
original_trace = handlers.trace(handlers.seed(linear_model, rng_seed=0)).get_trace(
x_data
)
print(f"Original slope: {original_trace['slope']['value']:.4f}")
# Replay with a DIFFERENT seed — replay overrides the sampled values
replayed_trace = handlers.trace(
handlers.replay(handlers.seed(linear_model, rng_seed=99), trace=original_trace)
).get_trace(x_data)
print(f"Replayed slope: {replayed_trace['slope']['value']:.4f}")
assert jnp.allclose(original_trace["slope"]["value"], replayed_trace["slope"]["value"])
print("Values match despite different seed!")
Original slope: -6.2874
Replayed slope: -6.2874
Values match despite different seed!
Scaling and Masking: scale and mask
These handlers modify the log-probability contribution of sample sites.
scale: Rescaling Log-Probabilities
The scale handler multiplies msg["scale"] by a positive factor. This is used internally for data subsampling (upweighting a mini-batch to represent the full dataset).
[18]:
def simple_obs_model():
    z = numpyro.sample("z", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(z, 1.0), obs=0.5)
# Without scale
normal_trace = handlers.trace(handlers.seed(simple_obs_model, rng_seed=0)).get_trace()
print(f"obs scale (no handler): {normal_trace['obs']['scale']}")
# With scale=10.0
with handlers.trace() as scaled_trace:
    with handlers.scale(scale=10.0):
        handlers.seed(simple_obs_model, rng_seed=0)()
print(f"obs scale (with handler): {scaled_trace['obs']['scale']}")
obs scale (no handler): None
obs scale (with handler): 10.0
mask: Element-wise Masking
The mask handler applies a boolean mask to the log-probability computation. Elements where the mask is False are excluded from the likelihood. This is useful for handling missing data or implementing train/test splits within a model.
[19]:
def model_for_mask():
    z = numpyro.sample("z", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(z, 1.0), obs=jnp.array([1.0, 2.0, 3.0]))
# Mask out the second observation
mask_array = jnp.array([True, False, True])
masked_trace = handlers.trace(
handlers.seed(handlers.mask(model_for_mask, mask=mask_array), rng_seed=0)
).get_trace()
print(f"obs fn type: {type(masked_trace['obs']['fn']).__name__}")
print("(The distribution is now wrapped in a MaskedDistribution)")
obs fn type: MaskedDistribution
(The distribution is now wrapped in a MaskedDistribution)
You can also query the current mask inside your model using numpyro.get_mask(). This is useful for conditionally skipping expensive computations:
def model():
    if numpyro.get_mask() is not False:
        numpyro.factor("expensive_term", expensive_computation())
Causal Inference: do
The do handler implements the intervention operator from Pearl’s do-calculus, following the Single World Intervention Graph (SWIG) framework by Richardson & Robins. This is one of NumPyro’s more distinctive features: few other PPLs offer first-class support for causal interventions (PyMC, which provides pm.do, is one that does).
How do Works
When you call handlers.do(model, data={"z": val}), the handler:
Creates a fresh sample at the original site name "z" (sampled freely from the prior, this represents the counterfactual “what would have happened”). This fresh sample appears in the trace.
Internally renames the original message to "z__CF" and sets it to the intervention value, but this is hidden from outer handlers (stop=True).
The intervened value val is what downstream code actually uses.
The net effect: the model behaves as if z was hard-coded to val, while the trace still records a free sample at "z" for bookkeeping.
Let’s see this with a simple causal chain:
[20]:
def causal_chain(x):
    """x → z → y"""
    z = numpyro.sample("z", dist.Normal(x, 1.0))
    y = numpyro.sample("y", dist.Normal(z, 0.5))
    return y
# Natural (observational) execution
natural_trace = handlers.trace(handlers.seed(causal_chain, rng_seed=0)).get_trace(1.0)
print(
f"Natural: z = {natural_trace['z']['value']:.3f}, "
f"y = {natural_trace['y']['value']:.3f}"
)
Natural: z = -1.442, y = -2.071
[21]:
# Intervene: do(z = 5.0)
intervened_model = handlers.do(causal_chain, data={"z": 5.0})
with handlers.trace() as int_trace:
    y_result = handlers.seed(intervened_model, rng_seed=0)(1.0)
print("After do(z=5.0):")
print(f" Trace sites: {list(int_trace.keys())}")
print(
f" 'z' in trace (fresh sample, unused downstream): "
f"value={int_trace['z']['value']:.3f}, "
f"is_observed={int_trace['z']['is_observed']}"
)
print(f" Returned y: {y_result:.3f} (downstream used the intervened z=5.0)")
# The key insight: z in the trace is NOT the intervention value,
# but downstream code (y) used the intervention z=5.0
assert int_trace["z"]["value"] != 5.0, (
"z in trace is the free sample, not the intervention"
)
After do(z=5.0):
Trace sites: ['z', 'y']
'z' in trace (fresh sample, unused downstream): value=-1.442, is_observed=False
Returned y: 4.371 (downstream used the intervened z=5.0)
Applied Example: Collections Email Campaign
We now apply the do handler to a real-world causal inference problem from Causal Inference for the Brave and True, Ch.7 (see also this NumPyro causal inference notebook).
A fintech company ran a randomized experiment sending debt negotiation emails to customers with late payments. The question: does the email cause customers to pay more?
Variables:
email (treatment): whether the customer received the email
payments (outcome): amount paid
credit_limit, risk_score (good controls): pre-treatment predictors of payments that reduce residual variance
Since email was randomly assigned, there is no confounding. However, including credit_limit and risk_score as good controls improves statistical power by explaining outcome variance.
[22]:
# Load data
data_url = (
"https://raw.githubusercontent.com/matheusfacure/"
"python-causality-handbook/master/causal-inference-for-the-brave-and-true"
"/data/collections_email.csv"
)
df = pd.read_csv(data_url)
df.head()
[22]:
| | payments | email | opened | agreement | credit_limit | risk_score |
|---|---|---|---|---|---|---|
| 0 | 740 | 1 | 1.0 | 0.0 | 2348.495260 | 0.666752 |
| 1 | 580 | 1 | 1.0 | 1.0 | 334.111969 | 0.207395 |
| 2 | 600 | 1 | 1.0 | 1.0 | 1360.660722 | 0.550479 |
| 3 | 770 | 0 | 0.0 | 0.0 | 1531.828576 | 0.560488 |
| 4 | 660 | 0 | 0.0 | 0.0 | 979.855647 | 0.455140 |
[23]:
# Data statistics
df.describe()
[23]:
| | payments | email | opened | agreement | credit_limit | risk_score |
|---|---|---|---|---|---|---|
| count | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 |
| mean | 669.672000 | 0.490800 | 0.273400 | 0.160800 | 1194.845188 | 0.480812 |
| std | 103.970065 | 0.499965 | 0.445749 | 0.367383 | 480.978996 | 0.100376 |
| min | 330.000000 | 0.000000 | 0.000000 | 0.000000 | 193.695573 | 0.131784 |
| 25% | 600.000000 | 0.000000 | 0.000000 | 0.000000 | 843.049867 | 0.414027 |
| 50% | 670.000000 | 0.000000 | 0.000000 | 0.000000 | 1127.640297 | 0.486389 |
| 75% | 730.000000 | 1.000000 | 1.000000 | 0.000000 | 1469.096523 | 0.552727 |
| max | 1140.000000 | 1.000000 | 1.000000 | 1.000000 | 3882.178408 | 0.773459 |
[24]:
# Prepare data for the model
email = jnp.array(df["email"].values, dtype=jnp.float32)
payments = jnp.array(df["payments"].values, dtype=jnp.float32)
credit_limit = jnp.array(df["credit_limit"].values, dtype=jnp.float32)
risk_score = jnp.array(df["risk_score"].values, dtype=jnp.float32)
We define a Bayesian structural causal model. The key design: email is a sample site so the do handler can intervene on it.
[25]:
def email_causal_model(credit_limit, risk_score, email=None, payments=None):
    n_obs = credit_limit.shape[0]
    # Treatment model: email assignment (in this RCT it's random,
    # but we model it to enable do-calculus)
    email_prob = numpyro.sample("email_prob", dist.Beta(1.0, 1.0))
    email = numpyro.sample(
        "email", dist.Bernoulli(email_prob).expand([n_obs]).to_event(1), obs=email
    )
    # Outcome model: payments ~ email + credit_limit + risk_score
    b_intercept = numpyro.sample("b_intercept", dist.Normal(500.0, 100.0))
    b_email = numpyro.sample("b_email", dist.Normal(0.0, 10.0))
    b_credit = numpyro.sample("b_credit", dist.Normal(0.0, 10.0))
    b_risk = numpyro.sample("b_risk", dist.Normal(0.0, 10.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(10.0))
    mu = b_intercept + b_email * email + b_credit * credit_limit + b_risk * risk_score
    with numpyro.plate("data", n_obs):
        numpyro.sample("payments", dist.Normal(mu, sigma), obs=payments)
numpyro.render_model(
email_causal_model,
model_args=(credit_limit, risk_score),
render_distributions=True,
)
[26]:
# Fit the model with MCMC
kernel = NUTS(email_causal_model)
mcmc = MCMC(kernel, num_warmup=1_000, num_samples=1_000, num_chains=4)
mcmc.run(jax.random.key(0), credit_limit, risk_score, email=email, payments=payments)
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
b_credit 0.15 0.00 0.15 0.14 0.15 3909.11 1.00
b_email 4.28 2.06 4.32 0.77 7.49 3562.58 1.00
b_intercept 488.97 3.80 488.94 482.90 495.46 2869.60 1.00
b_risk -0.05 9.54 0.06 -16.63 14.78 2733.98 1.00
email_prob 0.49 0.01 0.49 0.48 0.50 3784.22 1.00
sigma 74.77 0.76 74.77 73.55 76.04 4074.74 1.00
Number of divergences: 0
[27]:
idata = az.from_numpyro(posterior=mcmc)
axes = az.plot_trace(
data=idata,
var_names=["b_intercept", "b_email", "b_credit", "b_risk", "sigma"],
compact=True,
backend_kwargs={"figsize": (10, 7), "layout": "constrained"},
)
plt.gcf().suptitle("Posterior Trace Plots", fontsize=16, y=1.02)
plt.show()
The b_email coefficient tells us the observational association.
[28]:
point_estimate_blog = 4.4304
ci_blog = (0.255, 8.606)
fig, ax = plt.subplots()
az.plot_posterior(
idata,
var_names=["b_email"],
hdi_prob=0.95,
ref_val=point_estimate_blog,
ax=ax,
)
ax.axvspan(
ci_blog[0],
ci_blog[1],
alpha=0.15,
color="C1",
label="95% Confidence Interval (Blog)",
)
ax.legend()
ax.set(xlabel="ATE")
ax.set_title(
"Posterior Distribution of b_email (Average Treatment Effect)",
fontsize=18,
fontweight="bold",
);
Now let’s use the do handler to compute the interventional average treatment effect (ATE).
[29]:
posterior_samples = mcmc.get_samples()
# do(email=1): everyone receives email
intervened_model_1 = handlers.do(
email_causal_model, data={"email": jnp.ones_like(email)}
)
predictive_do_1 = Predictive(intervened_model_1, posterior_samples=posterior_samples)
rng_key, rng_subkey = jax.random.split(rng_key)
predictions_do_1 = az.from_dict(
{
k: v[jnp.newaxis, :]
for k, v in predictive_do_1(rng_subkey, credit_limit, risk_score).items()
},
coords={"obs_idx": np.arange(len(email))},
dims={"email": ["obs_idx"], "payments": ["obs_idx"]},
)
# do(email=0): no one receives email
intervened_model_0 = handlers.do(
email_causal_model, data={"email": jnp.zeros_like(email)}
)
predictive_do_0 = Predictive(intervened_model_0, posterior_samples=posterior_samples)
predictions_do_0 = az.from_dict(
{
k: v[jnp.newaxis, :]
for k, v in predictive_do_0(rng_subkey, credit_limit, risk_score).items()
},
coords={"obs_idx": np.arange(len(email))},
dims={"email": ["obs_idx"], "payments": ["obs_idx"]},
)
# Compute ATE (average over individuals, keep posterior samples)
y_do_1 = predictions_do_1["posterior"]["payments"].mean(dim="obs_idx")
y_do_0 = predictions_do_0["posterior"]["payments"].mean(dim="obs_idx")
ate_posterior = (y_do_1 - y_do_0).rename("ate")
Let’s compare the interventional ATE with the regression coefficient:
[30]:
ax, *_ = az.plot_forest(
[idata["posterior"]["b_email"].rename("ate"), ate_posterior],
model_names=["Regression", "Do-calculus"],
combined=True,
figsize=(6, 4),
)
ax.set_title("ATE Posterior Distribution", fontsize=18, fontweight="bold");
Since email was randomly assigned in this experiment, the regression coefficient and the interventional ATE obtained with the do handler should be very similar: both estimate roughly 4 units of additional payment. The do handler provides a principled way to compute such effects, which becomes essential when the treatment or the model is more complex.
Reparameterization: reparam
The reparam handler transforms sample sites to improve posterior geometry. This is critical for hierarchical models that suffer from the “funnel” problem, where the posterior has regions of high curvature that make sampling difficult.
We demonstrate this with the classic Eight Schools example from Gelman et al., Bayesian Data Analysis (Sec. 5.5, 2003).
[31]:
# Eight Schools data
J = 8
y_schools = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma_schools = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def eight_schools(J, sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        theta = numpyro.sample("theta", dist.Normal(mu, tau))
        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
Centered Parameterization (with divergences)
The centered model theta ~ Normal(mu, tau) creates a “funnel”: when tau is small, theta values must be tightly clustered around mu, creating a narrow region that NUTS struggles to navigate.
[32]:
kernel_centered = NUTS(
eight_schools,
target_accept_prob=0.9,
)
mcmc_centered = MCMC(
kernel_centered,
num_warmup=1_000,
num_samples=1_000,
num_chains=4,
)
rng_key, rng_subkey = jax.random.split(rng_key)
mcmc_centered.run(rng_subkey, J, sigma_schools, y=y_schools)
mcmc_centered.print_summary(exclude_deterministic=False)
n_divergences_centered = mcmc_centered.get_extra_fields()["diverging"].sum()
print(f"\nNumber of divergences (centered): {n_divergences_centered}")
mean std median 5.0% 95.0% n_eff r_hat
mu 4.18 3.30 4.19 -0.84 9.74 411.66 1.01
tau 3.87 3.03 3.03 0.43 7.72 267.52 1.01
theta[0] 6.04 5.67 5.44 -3.22 14.14 617.84 1.01
theta[1] 4.81 4.82 4.64 -3.41 11.90 628.66 1.01
theta[2] 3.72 5.21 3.95 -3.93 11.81 944.77 1.01
theta[3] 4.52 4.99 4.53 -3.60 12.07 875.17 1.00
theta[4] 3.39 4.74 3.53 -3.91 11.02 758.90 1.01
theta[5] 3.76 4.86 3.94 -4.18 11.43 722.05 1.01
theta[6] 6.33 5.40 5.71 -1.72 15.22 538.00 1.01
theta[7] 4.65 5.51 4.59 -3.92 12.49 902.01 1.00
Number of divergences: 88
Number of divergences (centered): 88
Non-Centered Parameterization with reparam
Passing LocScaleReparam(centered=0) to the reparam handler transforms theta ~ Normal(mu, tau) into theta_decentered ~ Normal(0, 1) and then computes theta = mu + tau * theta_decentered. This eliminates the funnel geometry.
[33]:
reparam_config = {"theta": LocScaleReparam(centered=0)}
reparam_eight_schools = handlers.reparam(eight_schools, config=reparam_config)
kernel_noncentered = NUTS(
reparam_eight_schools,
target_accept_prob=0.9,
)
mcmc_noncentered = MCMC(
kernel_noncentered, num_warmup=1_000, num_samples=1_000, num_chains=4
)
rng_key, rng_subkey = jax.random.split(rng_key)
mcmc_noncentered.run(rng_subkey, J, sigma_schools, y=y_schools)
mcmc_noncentered.print_summary(exclude_deterministic=False)
n_divergences_noncentered = mcmc_noncentered.get_extra_fields()["diverging"].sum()
print(f"\nNumber of divergences (non-centered): {n_divergences_noncentered}")
mean std median 5.0% 95.0% n_eff r_hat
mu 4.37 3.38 4.38 -0.69 10.26 3948.88 1.00
tau 3.66 3.30 2.79 0.00 8.01 2674.42 1.00
theta[0] 6.15 5.57 5.67 -2.82 14.48 3835.82 1.00
theta[1] 4.96 4.74 4.83 -2.22 13.17 4389.62 1.00
theta[2] 3.94 5.36 4.21 -5.02 11.95 4288.62 1.00
theta[3] 4.76 4.91 4.66 -3.35 12.31 4360.83 1.00
theta[4] 3.53 4.78 3.75 -3.95 11.16 4292.33 1.00
theta[5] 3.97 4.90 4.10 -3.72 11.88 4726.12 1.00
theta[6] 6.40 5.11 5.87 -1.98 14.32 4306.90 1.00
theta[7] 4.85 5.38 4.76 -3.11 13.36 4177.88 1.00
theta_decentered[0] 0.30 1.00 0.32 -1.34 1.88 4533.85 1.00
theta_decentered[1] 0.10 0.93 0.09 -1.48 1.58 4267.47 1.00
theta_decentered[2] -0.08 0.96 -0.08 -1.70 1.48 4957.02 1.00
theta_decentered[3] 0.08 0.96 0.08 -1.36 1.80 4671.31 1.00
theta_decentered[4] -0.16 0.92 -0.17 -1.65 1.34 4361.58 1.00
theta_decentered[5] -0.09 0.93 -0.10 -1.58 1.46 4997.15 1.00
theta_decentered[6] 0.36 0.97 0.39 -1.21 1.98 4546.13 1.00
theta_decentered[7] 0.07 0.98 0.09 -1.48 1.67 4644.83 1.00
Number of divergences: 0
Number of divergences (non-centered): 0
Let’s inspect the trace to see the reparameterized sites:
[34]:
reparam_trace = handlers.trace(
handlers.seed(reparam_eight_schools, rng_seed=0)
).get_trace(J, sigma_schools)
print(f"{'Site':>20s} | {'Type':>15s} | {'Shape'}")
print("-" * 55)
for name, site in reparam_trace.items():
    print(f"{name:>20s} | {site['type']:>15s} | {jnp.shape(site['value'])}")
Site | Type | Shape
-------------------------------------------------------
mu | sample | ()
tau | sample | ()
J | plate | (8,)
theta_decentered | sample | (8,)
theta | deterministic | (8,)
obs | sample | (8,)
Notice that theta_decentered is a new sample site (sampling from Normal(0, 1)), while theta has become a deterministic site (computed as mu + tau * theta_decentered). This is the key insight: reparam changes the parameterization without changing the mathematical model.
Composing Handlers: Nesting and Order
The real power of effect handlers emerges when you compose them. Handlers are stacked via nesting, and order matters because process_message runs from the bottom of the stack (innermost handler) to the top (outermost handler). The outermost handler’s process_message runs last, so it can overwrite values set by inner handlers.
with handlers.trace():  # TOP: postprocess first, process last
    with handlers.condition():  # MIDDLE
        with handlers.seed():  # BOTTOM: process first, postprocess last
            model()
[35]:
# condition then substitute on same site
trace1 = handlers.trace(
handlers.condition(
handlers.substitute(
handlers.seed(linear_model, rng_seed=0), data={"slope": 10.0}
),
data={"slope": 3.0},
)
).get_trace(x_data)
print(
f"condition(substitute(...)): slope={trace1['slope']['value']}, "
f"is_observed={trace1['slope']['is_observed']}"
)
# substitute then condition on same site
trace2 = handlers.trace(
handlers.substitute(
handlers.condition(
handlers.seed(linear_model, rng_seed=0), data={"slope": 3.0}
),
data={"slope": 10.0},
)
).get_trace(x_data)
print(
f"substitute(condition(...)): slope={trace2['slope']['value']}, "
f"is_observed={trace2['slope']['is_observed']}"
)
condition(substitute(...)): slope=3.0, is_observed=True
substitute(condition(...)): slope=10.0, is_observed=True
In the first case, condition (outer) runs its process_message after substitute (inner), so it overwrites the value with 3.0 and sets is_observed=True.
In the second case, substitute (outer) runs after condition (inner), so it overwrites the value with 10.0. The is_observed flag was already set to True by condition, and substitute never touches it, so the site is still reported as observed.
[36]:
# Equivalent using context managers
with handlers.trace() as tr:
    with handlers.condition(data={"slope": 2.5}):
        handlers.seed(linear_model, rng_seed=0)(x_data)

print(f"Context manager style: slope = {tr['slope']['value']}")
Context manager style: slope = 2.5
Practical Recipe: Posterior Predictive with Handlers
Let’s tie everything together by using handlers to implement the full posterior predictive workflow, the same operations that Predictive does internally.
Prior Predictive (using seed + trace)
To get a prior predictive sample, we simply trace the model without conditioning on any data:
[37]:
rng_key, rng_subkey = jax.random.split(rng_key)
prior_trace = handlers.trace(
handlers.seed(linear_model, rng_seed=rng_subkey)
).get_trace(x_data)
y_prior = prior_trace["obs"]["value"]
print(f"Prior predictive sample shape: {y_prior.shape}")
Prior predictive sample shape: (50,)
[38]:
fig, ax = plt.subplots()
ax.scatter(x_data, y_data, c="black", label="observed data")
ax.scatter(x_data, y_prior, c="C0", label="prior predictive sample")
ax.legend()
ax.set(
title="Synthetic Data",
xlabel="x",
ylabel="y",
);
Posterior Predictive (using substitute + seed + trace)
To get a posterior predictive sample, we substitute a single posterior draw (one value per latent site) into the model and trace it:
[39]:
# First, fit the linear model with MCMC
kernel = NUTS(conditioned_linear_model)
mcmc_linear = MCMC(
kernel, num_warmup=1_000, num_samples=1_000, num_chains=4, progress_bar=False
)
rng_key, rng_subkey = jax.random.split(rng_key)
mcmc_linear.run(rng_subkey, x_data)
mcmc_linear.print_summary()
posterior_samples_linear = mcmc_linear.get_samples()
mean std median 5.0% 95.0% n_eff r_hat
intercept 0.91 0.06 0.91 0.81 1.01 3982.28 1.00
sigma 0.45 0.05 0.45 0.38 0.52 3744.90 1.00
slope 2.02 0.05 2.02 1.93 2.10 3469.39 1.00
Number of divergences: 0
[40]:
# Pick one posterior sample
one_sample = {k: v[0] for k, v in posterior_samples_linear.items()}
post_pred_trace = handlers.trace(
handlers.seed(handlers.substitute(linear_model, data=one_sample), rng_seed=2)
).get_trace(x_data)
y_post = post_pred_trace["obs"]["value"]
print(f"Posterior predictive sample shape: {y_post.shape}")
Posterior predictive sample shape: (50,)
[41]:
fig, ax = plt.subplots()
ax.scatter(x_data, y_data, c="black", label="observed data")
ax.scatter(x_data, y_post, c="C1", label="posterior predictive sample")
ax.legend()
ax.set(
title="Synthetic Data",
xlabel="x",
ylabel="y",
);
Log-Likelihood (using substitute + seed + trace)
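We can compose handlers to compute the pointwise log-likelihood for one posterior draw. After substituting the draw into the conditioned model, the traced obs site holds both the likelihood distribution (fn) and the observed data (value), so the log-likelihood is a single log_prob call: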
[42]:
def log_likelihood(rng_key, params, model, *args, **kwargs):
    """Compute pointwise log-likelihood using handler composition."""
    model_sub = handlers.substitute(handlers.seed(model, rng_key), params)
    model_trace = handlers.trace(model_sub).get_trace(*args, **kwargs)
    obs_site = model_trace["obs"]
    return obs_site["fn"].log_prob(obs_site["value"])

rng_key, rng_subkey = jax.random.split(rng_key)
ll = log_likelihood(rng_subkey, one_sample, conditioned_linear_model, x_data)
print(f"Log-likelihood shape: {ll.shape}, sum: {ll.sum():.2f}")
Log-likelihood shape: (50,), sum: -29.19
Comparison with Predictive
The Predictive utility follows the same recipe under the hood, composing substitute, seed, and trace handlers, and batches it over all posterior samples (sequentially by default, or with vmap when parallel=True):
[43]:
predictive = Predictive(linear_model, posterior_samples=posterior_samples_linear)
pred_samples = predictive(jax.random.key(4), x_data)
posterior_predictive = az.from_dict(
posterior_predictive={k: v[jnp.newaxis, :] for k, v in pred_samples.items()},
coords={"obs_idx": np.arange(len(x_data))},
dims={"obs_idx": ["obs_idx"], "obs": ["obs_idx"]},
)
[44]:
fig, ax = plt.subplots()
az.plot_hdi(
x_data,
posterior_predictive["posterior_predictive"]["obs"],
hdi_prob=0.5,
fill_kwargs={"alpha": 0.5, "label": "50% HDI"},
ax=ax,
)
az.plot_hdi(
x_data,
posterior_predictive["posterior_predictive"]["obs"],
hdi_prob=0.94,
fill_kwargs={"alpha": 0.3, "label": "94% HDI"},
ax=ax,
)
ax.scatter(x_data, y_data, c="black", label="observed data")
ax.legend()
ax.set(
title="Synthetic Data",
xlabel="x",
ylabel="y",
);
Summary and References
NumPyro provides 13+ composable effect handlers that let you manipulate probabilistic programs without modifying model code. Here is a quick reference:
| Handler | What it does | Key |
|---|---|---|
| seed | Provides PRNG key splitting | |
| trace | Records all sites into an OrderedDict | (reads all) |
| condition | Fixes value and marks as observed | |
| substitute | Fixes value, keeps observation status | |
| do | Causal intervention (SWIG) | |
| block | Hides sites from outer handlers | |
| uncondition | Forces observed sites to sample from prior | |
| lift | Converts param sites to sample sites | |
| replay | Replays values from a previous trace | |
| scope | Prepends prefix to site names | |
| scale | Multiplies log-prob scaling factor | |
| mask | Element-wise boolean masking of log-prob | |
| reparam | Reparameterizes sample sites | |
References
Mini Pyro: the foundational tutorial on effect handlers in the Pyro ecosystem
Pyro Effect Handlers: Pyro’s perspective on the handler pattern
NumPyro Handlers API: official API documentation
Causal Inference for the Brave and True, Ch. 7: the collections email example
Single World Intervention Graphs: Richardson & Robins, the theory behind the do handler
Gelman et al., Bayesian Data Analysis (Sec. 5.5, 2003): the Eight Schools example