
NNX and NumPyro Integration

In this example notebook, we illustrate how to incorporate neural network components from the NNX library into NumPyro models. The same approach works with the Flax Linen API.
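If you prefer the Linen API, NumPyro provides the analogous helpers flax_module and random_flax_module in numpyro.contrib.module. Below is a minimal sketch of how they would be used inside a model; the layer sizes and prior choices are illustrative only and are not used in the rest of this notebook:

from flax import linen as nn

import numpyro.distributions as dist
from numpyro.contrib.module import flax_module, random_flax_module


def linen_model(x):
    # Network whose parameters are point-estimated (analogous to nnx_module below).
    mu_nn = flax_module("mu_nn", nn.Dense(features=1), input_shape=(1,))
    # Network whose parameters are latent random variables with priors
    # (analogous to random_nnx_module below).
    sigma_nn = random_flax_module(
        "sigma_nn",
        nn.Dense(features=1),
        prior={"kernel": dist.HalfNormal(scale=1), "bias": dist.Normal(loc=0, scale=1)},
        input_shape=(1,),
    )
    ...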

This notebook is based on the blog post Flax and NumPyro Toy Example.

Prepare Notebook

[1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz flax matplotlib
[2]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

from flax import nnx
from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.module import nnx_module, random_nnx_module
import numpyro.distributions as dist
from numpyro.handlers import condition
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import Predictive

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

numpyro.set_host_device_count(n=4)

rng_key = random.PRNGKey(seed=42)

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

Generate Data

We generate a synthetic dataset with non-linear mean and standard deviation.

[3]:
n = 32 * 10
rng_key, rng_subkey = random.split(rng_key)
x = jnp.linspace(1, jnp.pi, n)
mu_true = jnp.sqrt(x + 0.5) * jnp.sin(9 * x)
sigma_true = 0.15 * x**2
rng_key, rng_subkey = random.split(rng_key)
y = mu_true + sigma_true * random.normal(rng_subkey, shape=(n,))

Let’s visualize the resulting data set:

[4]:
fig, ax = plt.subplots()
ax.plot(x, mu_true, color="C0", label=r"$\mu$", linewidth=3)
ax.fill_between(
    x,
    (mu_true - 2 * sigma_true),
    (mu_true + 2 * sigma_true),
    color="C0",
    alpha=0.2,
    label=r"$\mu \pm 2 \sigma$",
)
ax.scatter(x, y, color="black", label="data")
ax.legend(loc="upper left")
ax.set_title(label="Simulated Data", fontsize=18, fontweight="bold")
ax.set(xlabel="x", ylabel="y");
../_images/tutorials_nnx_example_7_0.png

We clearly see that the data is non-linear and has heteroskedastic noise. We would like to model the mean and the standard deviation as functions of the input \(x\), using neural networks to capture the non-linearities.
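Concretely, we will fit a heteroskedastic Gaussian model in which both likelihood parameters are functions of the input,

\[y_i \sim \mathcal{N}\left(\mu(x_i), \sigma(x_i)^2\right), \qquad i = 1, \dots, n,\]

where \(\mu(\cdot)\) and \(\sigma(\cdot)\) are parameterized by neural networks.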

Model Specification

First, we prepare the data for training.

[5]:
x_train = x[..., None]
y_train = y

Next, we use NNX to define two MLP components, one for the mean and one for the standard deviation. See the `NNX basics <https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/nnx_basics.html>`__ guide for more details.

[6]:
class LocMLP(nnx.Module):
    """3-layer Multi-layer perceptron for the mean."""

    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dmid, rngs=rngs)
        self.linear3 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x, rngs=None):
        x = self.linear1(x)
        x = nnx.sigmoid(x)
        x = self.linear2(x)
        x = nnx.sigmoid(x)
        x = self.linear3(x)
        return x


class ScaleMLP(nnx.Module):
    """Single-layer MLP for the standard deviation."""

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        self.linear = nnx.Linear(1, 1, rngs=rngs)

    def __call__(self, x, rngs=None):
        x = self.linear(x)
        return nnx.softplus(x)
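
Note that the softplus activation, \(\mathrm{softplus}(z) = \log(1 + e^{z})\), guarantees a strictly positive output, which is exactly what we need for the standard deviation.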

We now instantiate the neural network components in an “eager” way.

[7]:
mu_nn_module = LocMLP(din=1, dmid=8, dout=1, rngs=nnx.Rngs(0))

sigma_nn_module = ScaleMLP(rngs=nnx.Rngs(1))

Finally, we add the neural network components to the NumPyro model, where we use a Normal distribution for the likelihood and allow its parameters to vary over the input \(x\).

[8]:
def model(x):
    # Neural network component for the mean. Here we consider the parameters of the
    # neural network as learnable.
    mu_nn = nnx_module("mu_nn", mu_nn_module)
    # Here we consider the parameters of the neural network as random variables.
    # Hence we can set priors for them.
    sigma_nn = random_nnx_module(
        "sigma_nn",
        sigma_nn_module,
        prior={
            # From the data we know the variance is increasing over x.
            # Hence we use a HalfNormal distribution to model the kernel term.
            "linear.kernel": dist.HalfNormal(scale=1),
            # We use a Normal distribution for the bias.
            "linear.bias": dist.Normal(loc=0, scale=1),
        },
    )

    mu = numpyro.deterministic("mu", mu_nn(x).squeeze())

    sigma = numpyro.deterministic("sigma", sigma_nn(x).squeeze())

    with numpyro.plate("data", len(x)):
        numpyro.sample("likelihood", dist.Normal(loc=mu, scale=sigma))


numpyro.render_model(
    model=model,
    model_args=(x_train,),
    render_distributions=True,
    render_params=True,
)
[8]:
../_images/tutorials_nnx_example_16_0.svg

Prior Predictive Checks

Before we perform inference, we can check the prior predictive distribution to ensure the priors are meaningful.

[9]:
prior_predictive = Predictive(model=model, num_samples=100)
rng_key, rng_subkey = random.split(key=rng_key)
prior_predictive_samples = prior_predictive(rng_subkey, x_train)

obs_train = jnp.arange(x_train.size)

idata = az.from_dict(
    prior_predictive={
        k: np.expand_dims(a=np.asarray(v), axis=0)
        for k, v in prior_predictive_samples.items()
    },
    coords={"obs": obs_train},
    dims={"mu": ["obs"], "sigma": ["obs"], "likelihood": ["obs"]},
)

First, we visualize the prior predictive distribution. We clearly see that the variance increases over \(x\), as expected (and as encoded in the prior). We also see that the samples fall within a reasonable range of the data.

[10]:
fig, ax = plt.subplots()

for i in range(10):
    ax.plot(x, idata["prior_predictive"]["likelihood"].sel(chain=0, draw=i), alpha=0.25)
ax.scatter(x, y, color="black", label="data")
ax.set(xlabel="x", ylabel="y", title="Prior Predictive Distribution Samples");
../_images/tutorials_nnx_example_20_0.png

We can also verify that the prior for the sigma kernel behaves as expected:

[11]:
fig, ax = plt.subplots()

az.plot_dist(
    idata["prior_predictive"]["sigma_nn/linear.kernel"].squeeze(axis=(-1, -2)),
    color="C2",
    fill_kwargs={"alpha": 0.3},
)

ax.set(
    xlabel="sigma_nn/linear.kernel",
    title="Prior Predictive Distribution - Kernel of the Standard Deviation MLP",
);
../_images/tutorials_nnx_example_22_0.png

We do see the prior is positive!

Model Inference

We now perform inference on the model using SVI.

[12]:
# We condition the model on the training data
conditioned_model = condition(model, data={"likelihood": y_train})
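# Equivalently, we could pass `obs=y_train` to the numpyro.sample("likelihood", ...) call
# inside the model; using the condition handler keeps the unconditioned model reusable
# for prior and posterior predictive sampling.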

guide = AutoNormal(model=conditioned_model)
optimizer = numpyro.optim.Adam(step_size=0.005)
svi = SVI(conditioned_model, guide, optimizer, loss=Trace_ELBO())
n_samples = 8_000
rng_key, rng_subkey = random.split(key=rng_key)
svi_result = svi.run(rng_subkey, n_samples, x_train)

fig, ax = plt.subplots()
ax.plot(svi_result.losses)
ax.set_title("ELBO loss", fontsize=18, fontweight="bold");
100%|██████████| 8000/8000 [00:01<00:00, 5976.75it/s, init loss: 1032.5969, avg. loss [7601-8000]: 287.5595]
../_images/tutorials_nnx_example_25_1.png

We now generate the posterior predictive distribution.

[13]:
params = svi_result.params
posterior_predictive = Predictive(
    model=model,
    guide=guide,
    params=params,
    num_samples=2_000,
    return_sites=["mu", "sigma", "likelihood"],
)
rng_key, rng_subkey = random.split(key=rng_key)
posterior_predictive_samples = posterior_predictive(rng_subkey, x_train)

We now collect the posterior predictive samples for visualization purposes.

[14]:
idata.extend(
    az.from_dict(
        posterior_predictive={
            k: np.expand_dims(a=np.asarray(v), axis=0)
            for k, v in posterior_predictive_samples.items()
        },
        coords={"obs": obs_train},
        dims={"mu": ["obs"], "sigma": ["obs"], "likelihood": ["obs"]},
    )
)

Finally, we visualize the posterior predictive distribution together with the mean and standard deviation components.

[15]:
fig, ax = plt.subplots(
    nrows=3,
    ncols=1,
    sharex=True,
    sharey=True,
    figsize=(10, 9),
    layout="constrained",
)

az.plot_hdi(
    x,
    idata["posterior_predictive"]["likelihood"],
    color="C1",
    fill_kwargs={"alpha": 0.3, "label": "94% HDI"},
    ax=ax[0],
)
ax[0].plot(
    x_train,
    idata["posterior_predictive"]["likelihood"].mean(dim=("chain", "draw")),
    color="C1",
    linewidth=3,
    label="SVI Posterior Mean",
)
ax[0].plot(x, mu_true, color="C0", label=r"$\mu$", linewidth=3)
ax[0].scatter(x, y, color="black", label="data")
ax[0].legend(loc="upper left")
ax[0].set(ylabel="y")
ax[0].set_title(label="Posterior Predictive Distribution")

ax[1].plot(
    x,
    idata["posterior_predictive"]["mu"].mean(dim=("chain", "draw")),
    linewidth=3,
    color="C2",
)
ax[1].set_title(label="Mean")

ax[2].plot(
    x,
    idata["posterior_predictive"]["sigma"].mean(dim=("chain", "draw")),
    linewidth=3,
    color="C3",
)
ax[2].set(xlabel="x")
ax[2].set_title(label="Standard Deviation");
../_images/tutorials_nnx_example_31_0.png

The results look great! The fit and the components are working as expected.
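As a follow-up, the same Predictive object can be reused to generate predictions for inputs not seen during training, since the model only takes x as an argument. A minimal sketch, where x_new is a hypothetical evaluation grid:

x_new = jnp.linspace(1, jnp.pi, 100)[..., None]
rng_key, rng_subkey = random.split(key=rng_key)
new_predictions = posterior_predictive(rng_subkey, x_new)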