NumPyro Integration with Other Libraries¶
In this notebook we describe how to integrate NumPyro with other libraries to take advantage of alternative inference algorithms. We focus on two libraries:
- Blackjax: we consider the Pathfinder variational inference algorithm.
- FlowMC: we look into the normalizing-flow enhanced Markov chain Monte Carlo sampler.
The main idea behind the integration is to use the function numpyro.infer.util.initialize_model
to compute the log-density and the necessary transformations to go from the unconstrained space to the constrained space. Let’s see how to do it.
This example is based on the original example notebook NumPyro with Pathfinder.
Prepare Notebook¶
[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz blackjax flowMC
[2]:
import arviz as az
import blackjax
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.proposal.MALA import MALA
from flowMC.Sampler import Sampler
import matplotlib.pyplot as plt
import numpy as np
import jax
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import Predictive, initialize_model
plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
jax.config.update("jax_enable_x64", True)
numpyro.set_host_device_count(n=4)
rng_key = random.PRNGKey(seed=42)
assert numpyro.__version__.startswith("0.15.3")
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Generate Synthetic Data¶
We generate some data from a simple linear regression model.
[3]:
def generate_data(rng_key, a, b, sigma, n):
x = random.normal(rng_key, (n,))
rng_key, rng_subkey = random.split(rng_key)
epsilon = sigma * random.normal(rng_subkey, (n,))
y = a + b * x + epsilon
return x, y
# true parameters
a = 1.0
b = 2.0
sigma = 0.5
n = 100
# generate data
rng_key, rng_subkey = random.split(rng_key)
x, y = generate_data(rng_key, a, b, sigma, n)
# plot data
fig, ax = plt.subplots(figsize=(8, 7))
ax.plot(x, y, "o", c="C0", label="data")
ax.axline((0, a), slope=b, color="C1", label="true mean")
ax.legend(loc="upper left")
ax.set(xlabel="x", ylabel="y", title="Raw Data");
Model Specification¶
We define a simple linear regression model in NumPyro.
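In equations, the model implemented in the next cell is

$$
a \sim \mathcal{N}(0, 2), \qquad b \sim \mathrm{HalfNormal}(2), \qquad \sigma \sim \mathrm{Exponential}(1),
$$

$$
y_i \mid a, b, \sigma \sim \mathcal{N}(a + b x_i, \sigma), \qquad i = 1, \dots, n.
$$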
[4]:
def model(x, y=None):
a = numpyro.sample("a", dist.Normal(loc=0.0, scale=2.0))
b = numpyro.sample("b", dist.HalfNormal(scale=2.0))
sigma = numpyro.sample("sigma", dist.Exponential(rate=1.0))
mean = numpyro.deterministic("mu", a + b * x)
with numpyro.plate("data", len(x)):
numpyro.sample("likelihood", dist.Normal(loc=mean, scale=sigma), obs=y)
numpyro.render_model(
model=model,
model_args=(x, y),
render_distributions=True,
render_params=True,
)
[4]:
Extract Model Ingredients¶
As mentioned in the introduction, we need to use the function numpyro.infer.util.initialize_model
to extract the log-density and the necessary transformations to go from the unconstrained space to the constrained space needed by Blackjax and FlowMC. The input to this function is the model, the data, and a random key.
[5]:
rng_key, rng_subkey = random.split(rng_key)
param_info, potential_fn, postprocess_fn, *_ = initialize_model(
rng_subkey,
model,
model_args=(x, y),
dynamic_args=True, # <- this is important!
)
- param_info is a namedtuple ParamInfo containing values from the prior used to initiate MCMC.
- potential_fn is a callable that returns the potential energy of the model given the data and the parameters.
- postprocess_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within each site's support, in addition to returning values at deterministic sites in the model.
Let’s extract an initial position from param_info.
[6]:
# get initial position
initial_position = param_info.z
initial_position
[6]:
{'a': Array(-1.5517484, dtype=float64),
'b': Array(1.12366214, dtype=float64),
'sigma': Array(-0.52973833, dtype=float64)}
Remark: Observe that the initial position of sigma is negative. The reason is that the prior distribution for sigma is dist.Exponential(rate=1.0), which has positive support, so the parameter is mapped to an unconstrained space through a bijective transformation. The function postprocess_fn will transform this negative value back to the positive space using the inverse transform.
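As a quick check, here is a small sketch (using the objects returned by initialize_model above; with dynamic_args=True, postprocess_fn(x, y) can also be applied to a single unbatched position) confirming that the inverse transform maps the unconstrained sigma back to the positive axis:

# map the unconstrained initial position back to the constrained space;
# for the Exponential prior the unconstrained-to-constrained map is exp,
# so sigma should come out positive (exp(-0.53) ≈ 0.59)
constrained_initial_position = postprocess_fn(x, y)(initial_position)
constrained_initial_position["sigma"]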
Next, we transform the potential energy function to a log-density function.
[7]:
# get log-density from the potential function
def logdensity_fn(position):
func = potential_fn(x, y)
return -func(position)
Let’s verify we can evaluate the log-density function at the initial position.
[8]:
logdensity_fn(initial_position)
[8]:
Array(-1141.81434653, dtype=float64)
Now, we are ready to run our first sampler.
Pathfinder Sampler¶
From Blackjax documentation:
Pathfinder locates normal approximations to the target density along a quasi-Newton optimization path, with local covariance estimated using the inverse Hessian estimates produced by the L-BFGS optimizer. PathfinderState stores, for an iteration of the L-BFGS optimizer, the resulting ELBO and all factors needed to sample from the approximated target density.
For more information about Pathfinder, please refer to the paper:
Lu Zhang, Bob Carpenter, Andrew Gelman, and Aki Vehtari. Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306):1–49, 2022.
Remark: From Blackjax’s sampling book documentation:
The L-BFGS algorithm struggles with float32s and log-likelihood functions; it’s suggested to use double-precision numbers.
Run Sampler¶
We can now use blackjax.vi.pathfinder.approximate to run the variational inference algorithm.
[9]:
%%time
# run pathfinder
rng_key, rng_subkey = random.split(rng_key)
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
rng_key=rng_subkey,
logdensity_fn=logdensity_fn,
initial_position=initial_position,
num_samples=15_000,
ftol=1e-4,
)
# sample from the posterior
rng_key, rng_subkey = random.split(rng_key)
posterior_samples_pathfinder, _ = blackjax.vi.pathfinder.sample(
rng_key=rng_subkey,
state=pathfinder_state,
num_samples=5_000,
)
# convert to arviz
idata_pathfinder = az.from_dict(
posterior={
k: np.expand_dims(a=np.asarray(v), axis=0)
for k, v in posterior_samples_pathfinder.items()
},
)
CPU times: user 2.59 s, sys: 278 ms, total: 2.87 s
Wall time: 2.55 s
Visualize Results¶
We can visualize the results after sampling.
[10]:
az.summary(data=idata_pathfinder, round_to=3)
arviz - WARNING - Shape validation failed: input_shape: (1, 5000), minimum_shape: (chains=2, draws=4)
[10]:
|       | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| a     | 0.973  | 0.052 | 0.878  | 1.070   | 0.001     | 0.001   | 4882.712 | 4860.828 | NaN   |
| b     | 0.684  | 0.022 | 0.645  | 0.726   | 0.000     | 0.000   | 4797.817 | 4793.793 | NaN   |
| sigma | -0.632 | 0.063 | -0.753 | -0.515  | 0.001     | 0.001   | 4723.374 | 4790.730 | NaN   |
[11]:
axes = az.plot_trace(
data=idata_pathfinder,
compact=True,
figsize=(10, 6),
backend_kwargs={"layout": "constrained"},
)
plt.gcf().suptitle(
t="Pathfinder Trace - Transformed Space", fontsize=18, fontweight="bold"
);
Note that the value for a is close to the true value of 1.0. However, the values for b and sigma do not match the true values of 2.0 and 0.5, respectively. Again, the reason is that we are working in the unconstrained space. We need to transform the samples to the original space to compare them with the true values.
Transform Samples¶
We can use the postprocess_fn function returned by initialize_model to transform the samples from the unconstrained space to the constrained space:
[12]:
# posterior samples
posterior_samples_pathfinder_transformed = jax.vmap(postprocess_fn(x, y))(
posterior_samples_pathfinder
)
# posterior predictive samples
rng_key, rng_subkey = random.split(rng_key)
posterior_predictive_samples_pathfinder_transformed = Predictive(
model=model, posterior_samples=posterior_samples_pathfinder_transformed
)(rng_subkey, x)
Let’s see the posterior distribution in the original space.
[13]:
idata_pathfinder_transformed = az.from_dict(
posterior={
k: np.expand_dims(a=np.asarray(v), axis=0)
for k, v in posterior_samples_pathfinder_transformed.items()
},
posterior_predictive={
k: np.expand_dims(a=np.asarray(v), axis=0)
for k, v in posterior_predictive_samples_pathfinder_transformed.items()
},
)
axes = az.plot_trace(
data=idata_pathfinder_transformed,
var_names=["~mu"],
compact=True,
figsize=(10, 6),
lines=[
("a", {}, a),
("b", {}, b),
("sigma", {}, sigma),
],
backend_kwargs={"layout": "constrained"},
)
plt.gcf().suptitle(
t="Pathfinder Trace - Original Space", fontsize=18, fontweight="bold"
);
Finally, we can visualize the posterior predictive distribution.
[14]:
fig, ax = plt.subplots(figsize=(7, 6))
ax.plot(x, y, "o", c="C0", label="data")
ax.axline((0, a), slope=b, color="C1", label="true mean")
az.plot_hdi(
x=x,
y=idata_pathfinder_transformed["posterior_predictive"]["mu"],
color="C2",
fill_kwargs={"alpha": 0.7, "label": "mu posterior ($94\\%$ HDI)"},
ax=ax,
)
az.plot_hdi(
x=x,
y=idata_pathfinder_transformed["posterior_predictive"]["likelihood"],
color="C2",
fill_kwargs={"alpha": 0.2, "label": "posterior predictive ($94\\%$ HDI)"},
ax=ax,
)
ax.legend(loc="upper left")
ax.set(xlabel="x", ylabel="y", title="Pathfinder Posterior Predictive");
The results look good!
FlowMC Normalizing Flow Sampler¶
We can run the FlowMC sampler in a similar way to the above; we just need to adapt the log-density function to the format FlowMC expects.
Define Log-Density Function¶
[15]:
def logdensity_fn_flowmc(position, data):
"""FlowMC log-density function requires the position to be an array of shape
(n_chains, n_dim) and the data to be a dictionary."""
x = data["x"]
y = data["y"]
dict_position = dict(zip(param_info.z.keys(), position[..., None]))
func = potential_fn(x, y)
return -func(dict_position)
Let’s verify that the log-density function is working.
[16]:
n_dim = 3 # number of parameters
n_chains = 20 # number of chains
[17]:
data = {"x": x, "y": y}
rng_key, subkey = random.split(rng_key)
initial_position_array = jax.random.normal(subkey, shape=(n_chains, n_dim))
[18]:
logdensity_fn_flowmc(initial_position_array, data)
[18]:
Array(-868.2817303, dtype=float64)
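Note that passing the whole (n_chains, n_dim) array at once returns a single value. FlowMC itself evaluates the log-density one chain at a time, i.e. with a position of shape (n_dim,); assuming that calling convention, a small sketch of a per-chain check with jax.vmap looks like this:

# assumption: the sampler calls logdensity_fn_flowmc with a single chain's
# position of shape (n_dim,); vmap over the chain axis to get one
# log-density value per chain, i.e. an array of shape (n_chains,)
per_chain_logdensity = jax.vmap(logdensity_fn_flowmc, in_axes=(0, None))(
    initial_position_array, data
)
per_chain_logdensity.shape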
Define FlowMC Sampler¶
We can now define the FlowMC sampler. For more details see this example from the documentation.
[19]:
# local sampler: Metropolis-adjusted Langevin algorithm (MALA) sampler class building the mala_sampler method
mala_sampler = MALA(logpdf=logdensity_fn_flowmc, jit=True, step_size=0.1)
rng_key, subkey = random.split(rng_key)
# normalizing flow model: rational quadratic spline normalizing flow model using distrax.
nf_model = MaskedCouplingRQSpline(
n_features=n_dim, n_layers=4, hidden_size=[32, 32], num_bins=8, key=subkey
)
[20]:
%%time
sampler_params = {
"n_loop_training": 7,
"n_loop_production": 7,
"n_local_steps": 150,
"n_global_steps": 100,
"learning_rate": 0.001,
"momentum": 0.9,
"num_epochs": 30,
"batch_size": 10_000,
"use_global": True,
}
rng_key, rng_subkey = random.split(rng_key)
nf_sampler = Sampler(
n_dim=n_dim,
rng_key=rng_subkey,
data=data,
local_sampler=mala_sampler,
nf_model=nf_model,
**sampler_params,
)
nf_sampler.sample(initial_position_array, data)
rng_key, subkey = jax.random.split(rng_key)
nf_samples = nf_sampler.sample_flow(subkey, 5_000)
['n_dim', 'n_chains', 'n_local_steps', 'n_global_steps', 'n_loop', 'output_thinning', 'verbose']
Global Tuning: 0%| | 0/7 [00:00<?, ?it/s]
Compiling MALA body
Global Tuning: 100%|██████████| 7/7 [00:45<00:00, 6.57s/it]
Global Sampling: 100%|██████████| 7/7 [00:00<00:00, 13.46it/s]
CPU times: user 2min 44s, sys: 5min 15s, total: 7min 59s
Wall time: 47.2 s
Visualize Results¶
We collect the posterior samples and visualize the results.
[21]:
posterior_samples_flowmc = dict(zip(param_info.z.keys(), nf_samples.T))
flowmc_idata = az.from_dict(posterior=posterior_samples_flowmc)
[22]:
axes = az.plot_trace(
data=flowmc_idata,
compact=True,
figsize=(10, 6),
backend_kwargs={"layout": "constrained"},
)
plt.gcf().suptitle(
t="FlowMC Trace - Transformed Space", fontsize=18, fontweight="bold"
);
Transform Samples¶
We transform the samples to the original space as we did for Pathfinder.
[23]:
# posterior samples
posterior_samples_flowmc_transformed = jax.vmap(postprocess_fn(x, y))(
posterior_samples_flowmc
)
# posterior predictive samples
rng_key, rng_subkey = random.split(rng_key)
posterior_predictive_samples_flowmc_transformed = Predictive(
model=model, posterior_samples=posterior_samples_flowmc_transformed
)(rng_subkey, x)
[24]:
idata_flowmc_transformed = az.from_dict(
posterior={
k: np.expand_dims(a=np.asarray(v), axis=0)
for k, v in posterior_samples_flowmc_transformed.items()
},
posterior_predictive={
k: np.expand_dims(a=np.asarray(v), axis=0)
for k, v in posterior_predictive_samples_flowmc_transformed.items()
},
)
axes = az.plot_trace(
data=idata_flowmc_transformed,
var_names=["~mu"],
compact=True,
figsize=(10, 6),
lines=[
("a", {}, a),
("b", {}, b),
("sigma", {}, sigma),
],
backend_kwargs={"layout": "constrained"},
)
plt.gcf().suptitle(t="FlowMC Trace - Original Space", fontsize=18, fontweight="bold");
[25]:
fig, ax = plt.subplots(figsize=(7, 6))
ax.plot(x, y, "o", c="C0", label="data")
ax.axline((0, a), slope=b, color="C1", label="true mean")
az.plot_hdi(
x=x,
y=idata_flowmc_transformed["posterior_predictive"]["mu"],
color="C2",
fill_kwargs={"alpha": 0.7, "label": "mu posterior ($94\\%$ HDI)"},
ax=ax,
)
az.plot_hdi(
x=x,
y=idata_flowmc_transformed["posterior_predictive"]["likelihood"],
color="C2",
fill_kwargs={"alpha": 0.2, "label": "posterior predictive ($94\\%$ HDI)"},
ax=ax,
)
ax.legend(loc="upper left")
ax.set(xlabel="x", ylabel="y", title="FlowMC Posterior Predictive");
Model Comparison¶
Finally, we compare the results of the two samplers.
[26]:
az.plot_forest(
data=[idata_pathfinder_transformed, idata_flowmc_transformed],
model_names=["Pathfinder", "FlowMC"],
var_names=["a", "b", "sigma"],
combined=True,
figsize=(8, 5),
backend_kwargs={"layout": "constrained"},
);
Both samplers perform well and the results are very similar.
Remark: We would like to mention a relevant project that helps fit NumPyro models with other inference algorithms:
bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be simple, self-descriptive, and helpful. Simply provide a log density function (which doesn’t even have to be normalized), along with a single point (specified as a pytree) where that log density is finite. Then let bayeux do the rest!
Check it out!
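For illustration, here is a minimal sketch adapted from bayeux’s README (the exact API may differ across versions) showing how the logdensity_fn and initial_position defined earlier in this notebook could be plugged in:

import bayeux as bx

# wrap the log-density and a finite starting point (a pytree) into a bayeux model
bx_model = bx.Model(log_density=logdensity_fn, test_point=initial_position)

# run one of the available inference methods, e.g. NUTS via NumPyro
rng_key, rng_subkey = random.split(rng_key)
idata_bayeux = bx_model.mcmc.numpyro_nuts(seed=rng_subkey)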