
Hierarchical Forecasting

In this notebook we port the Pyro hierarchical multivariate forecasting example from the Forecasting III: hierarchical models tutorial to NumPyro.

We use the BART train ridership dataset, which contains the hourly number of trips between all pairs of stations in the BART system. The goal is to forecast future trips for every station pair. Rather than modeling each station pair separately, we want to exploit the station-to-station structure and shared features (seasonality) to generate the forecasts.

This model port was originally presented in the blog post From Pyro to NumPyro: Forecasting Hierarchical Models - Part II.

Prepare Notebook

[1]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import Array, random
import jax.numpy as jnp
import optax

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import load_bart_od
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.reparam import LocScaleReparam

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

numpyro.set_host_device_count(n=4)

rng_key = random.PRNGKey(seed=42)

assert numpyro.__version__.startswith("0.18.0")

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
[2]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz matplotlib optax

Read Data

Let’s load the data.

[3]:
dataset = load_bart_od()
print(dataset.keys())
print(dataset["counts"].shape)
print(" ".join(dataset["stations"]))
dict_keys(['stations', 'start_date', 'counts'])
(78888, 50, 50)
12TH 16TH 19TH 24TH ANTC ASHB BALB BAYF BERY CAST CIVC COLM COLS CONC DALY DBRK DELN DUBL EMBR FRMT FTVL GLEN HAYW LAFY LAKE MCAR MLBR MLPT MONT NBRK NCON OAKL ORIN PCTR PHIL PITT PLZA POWL RICH ROCK SANL SBRN SFIA SHAY SSAN UCTY WARM WCRK WDUB WOAK

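Recall that our objective is to model all the rides from all stations to all other stations. For a simpler forecasting task, take a look at the introductory example. Below, we move the time axis of the counts to the end, so the data has shape (origin, destination, time), and we work with the log1p-transformed counts.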

[4]:
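# (time, origin, destin) -> (origin, destin, time), on the log1p scale
data = jnp.log1p(np.permute_dims(dataset["counts"], (1, 2, 0)))
T = data.shape[-1]  # number of hourly time steps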
print(data.shape)
(50, 50, 78888)
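The raw counts come with shape (78888, 50, 50), that is (time, origin, destination), and the permutation above moves time to the last axis. The toy NumPy sketch below (made-up counts, not the BART data) illustrates the indexing after the permutation and shows that `np.expm1` inverts `np.log1p`, so raw trip counts can be recovered from the modeled values. Because `log1p(0) = 0`, the transformed data stays non-negative.

```python
import numpy as np

# Toy counts laid out like dataset["counts"]: (time, origin, destin)
counts = np.arange(4 * 3 * 3).reshape(4, 3, 3)

# Move time to the last axis and switch to the log1p scale.
# np.transpose with explicit axes does the same as np.permute_dims (NumPy >= 2.0).
data = np.log1p(np.transpose(counts, (1, 2, 0)))
print(data.shape)  # (3, 3, 4)

# data[o, d, t] corresponds to counts[t, o, d]
assert np.isclose(data[1, 2, 3], np.log1p(counts[3, 1, 2]))

# expm1 inverts log1p, recovering the raw trip counts
assert np.allclose(np.expm1(data), np.transpose(counts, (1, 2, 0)))
```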

Train - Test Split

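As in the Pyro example, we split the data into training and test sets. The last two weeks are held out for testing, and the 90 days before them are used for training.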

[5]:
T2 = data.shape[-1]  # end
T1 = T2 - 24 * 7 * 2  # train/test split
T0 = T1 - 24 * 90  # beginning: train on 90 days of data
[6]:
y = data[..., T0:T2]
y_train = data[..., T0:T1]
y_test = data[..., T1:T2]

print(f"y: {y.shape}")
print(f"y_train: {y_train.shape}")
print(f"y_test: {y_test.shape}")
y: (50, 50, 2496)
y_train: (50, 50, 2160)
y_test: (50, 50, 336)
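As a quick check on the index arithmetic, the sketch below recomputes the split by hand for the 78,888 hourly time steps printed earlier. The resulting window lengths match the printed shapes. It also confirms that the Christmas marker used in the final plot (hour index 78736) falls inside the test window.

```python
# Reproduce the split indices for the full series length printed above.
T2 = 78_888            # end of the series (number of hourly time steps)
T1 = T2 - 24 * 7 * 2   # train/test split: the last two weeks are held out
T0 = T1 - 24 * 90      # beginning: 90 days of training data before T1

print(T0, T1, T2)                 # 76392 78552 78888
print(T1 - T0, T2 - T1, T2 - T0)  # 2160 336 2496

# The Christmas marker used in the final plot lies in the test window.
christmas_index = 78736
print(T1 <= christmas_index < T2)  # True
```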
[7]:
n_stations = y_train.shape[-2]

time = jnp.array(range(T0, T2))
time_train = jnp.array(range(T0, T1))
t_max_train = time_train.size

time_test = jnp.array(range(T1, T2))
t_max_test = time_test.size

covariates = jnp.zeros_like(y)
covariates_train = jnp.zeros_like(y_train)
covariates_test = jnp.zeros_like(y_test)

assert time_train.size + time_test.size == time.size
assert y_train.shape == (n_stations, n_stations, t_max_train)
assert y_test.shape == (n_stations, n_stations, t_max_test)
assert covariates.shape == y.shape
assert covariates_train.shape == y_train.shape
assert covariates_test.shape == y_test.shape

Repeating Seasonal Features

To model the weekly seasonality, Pyro provides a handy helper function, periodic_repeat (https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat), that repeats the seasonal features along the time axis. Here we provide a JAX version of this function.

[8]:
def periodic_repeat_jax(tensor: Array, size: int, dim: int) -> Array:
    """
    Repeat a period-sized tensor up to given size using JAX.

    Parameters
    ----------
    tensor : Array
        A JAX array to be repeated.
    size : int
        Desired size of the result along dimension `dim`.
    dim : int
        The tensor dimension along which to repeat.

    Returns
    -------
    Array
        The repeated tensor.

    References
    ----------
    https://docs.pyro.ai/en/stable/ops.html#pyro.ops.tensor_utils.periodic_repeat
    """
    assert isinstance(size, int) and size >= 0
    assert isinstance(dim, int)
    if dim >= 0:
        dim -= tensor.ndim

    period = tensor.shape[dim]
    repeats = [1] * tensor.ndim
    repeats[dim] = (size + period - 1) // period
    result = jnp.tile(tensor, repeats)

    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(None, size)

    return result[tuple(slices)]
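To see what the helper does, here is a NumPy mirror of the same logic. It is a sketch for illustration: `np.tile` behaves like `jnp.tile`, and `periodic_repeat_np` is our own name. A period of length 3 repeated up to length 7 wraps around. A 168-hour weekly pattern stretched over the 2,160 training hours satisfies `out[t] == seasonal[t % 168]`, which is how the model uses it.

```python
import numpy as np


def periodic_repeat_np(tensor: np.ndarray, size: int, dim: int) -> np.ndarray:
    """NumPy mirror of periodic_repeat_jax, for illustration only."""
    assert isinstance(size, int) and size >= 0
    assert isinstance(dim, int)
    if dim >= 0:
        dim -= tensor.ndim  # work with a negative axis index

    period = tensor.shape[dim]
    repeats = [1] * tensor.ndim
    repeats[dim] = (size + period - 1) // period  # ceil(size / period) copies
    result = np.tile(tensor, repeats)

    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(None, size)  # trim the last, partial period
    return result[tuple(slices)]


# 1D: a period of 3 repeated up to length 7
print(periodic_repeat_np(np.arange(3), 7, dim=0))  # [0 1 2 0 1 2 0]

# Model-like shapes: (origin, destin, hour_of_week) repeated along the time axis
seasonal = np.zeros((2, 2, 24 * 7))
print(periodic_repeat_np(seasonal, 2160, dim=-1).shape)  # (2, 2, 2160)
```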

Model Specification

Here are the main components of the forecasting model:

  • The local level dynamic is driven by the destination station. We use a hierarchical structure to model the destination-level drift of the local level component. As such hierarchical models can have difficult geometries to sample from, we also learn the degree of centering of the LocScaleReparam reparameterization from the data.

  • The seasonal components and the noise scales are each the sum of an origin-station term and a destination-station term.

  • In addition, a static pairwise station-to-station affinity term enters the mean.
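For reference, here is our reading of the model code below written as equations, a summary rather than an authoritative specification. Here o is the origin station, d the destination station, t = 0, …, T-1 the hour within the modeled window, h(t) = t mod 168 the hour of the week relative to the start of that window, and y the log1p-transformed counts. The centering parameter of LocScaleReparam gets a Uniform(0, 1) prior. The reparameterization only changes how δ is parameterized, not its distribution.

```latex
\begin{aligned}
\sigma_{\text{drift}} &\sim \operatorname{LogNormal}(-20, 5), \qquad
\delta_{d,t} \sim \mathcal{N}(0, \sigma_{\text{drift}}), \\
\ell_{d,t} &= \ell_{d,t-1} + \delta_{d,t}, \quad \ell_{d,-1} = 0
  \quad\Longrightarrow\quad \ell_{d,t} = \sum_{s=0}^{t} \delta_{d,s}, \\
s^{\text{origin}}_{o,h},\ s^{\text{destin}}_{d,h} &\sim \mathcal{N}(0, 5), \qquad
p_{o,d} \sim \mathcal{N}(0, 1), \qquad
\sigma^{\text{origin}}_{o},\ \sigma^{\text{destin}}_{d} \sim \operatorname{LogNormal}(-5, 5), \\
\mu_{o,d,t} &= \ell_{d,t} + s^{\text{origin}}_{o,h(t)} + s^{\text{destin}}_{d,h(t)} + p_{o,d}, \\
y_{o,d,t} &\sim \mathcal{N}\!\left(\mu_{o,d,t},\ \sigma^{\text{origin}}_{o} + \sigma^{\text{destin}}_{d}\right).
\end{aligned}
```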

To make everything more tangible, let’s look at the model structure in code.

[9]:
def model(covariates: Array, y: Array | None = None) -> None:
    # Get the number of stations and time steps
    n_series, _, t_max = covariates.shape

    # Define the plates to be able to use them below
    origin_plate = numpyro.plate("origin", n_series, dim=-3)
    destin_plate = numpyro.plate("destin", n_series, dim=-2)
    hour_of_week_plate = numpyro.plate("hour_of_week", 24 * 7, dim=-1)

    # Global scale for the drift
    drift_scale = numpyro.sample("drift_scale", dist.LogNormal(loc=-20, scale=5))

    # Sample the centered parameter for the LocScaleReparam
    destin_centered = numpyro.sample("destin_centered", dist.Uniform(low=0, high=1))

    with origin_plate, hour_of_week_plate:
        origin_seasonal = numpyro.sample("origin_seasonal", dist.Normal(loc=0, scale=5))

    with destin_plate:
        with (
            numpyro.plate("time", t_max),
            numpyro.handlers.reparam(
                config={"drift": LocScaleReparam(centered=destin_centered)}
            ),
        ):
            # Sample the drift parameters
            # We have one drift parameter per destination station and time point
            drift = numpyro.sample("drift", dist.Normal(loc=0, scale=drift_scale))

        with hour_of_week_plate:
            # Sample the seasonal parameters
            # We have one seasonal parameter per hour of the week and per destination station
            destin_seasonal = numpyro.sample(
                "destin_seasonal", dist.Normal(loc=0, scale=5)
            )

    # We model a static pairwise station->station affinity, which e.g.
    # can compensate for the fact that people tend not to travel from
    # a station to itself.
    with origin_plate, destin_plate:
        pairwise = numpyro.sample("pairwise", dist.Normal(0, 1))

    # We model the origin and destination scales separately
    # and then add them together to get the final scale.
    with origin_plate:
        origin_scale = numpyro.sample("origin_scale", dist.LogNormal(-5, 5))
    with destin_plate:
        destin_scale = numpyro.sample("destin_scale", dist.LogNormal(-5, 5))
        scale = origin_scale + destin_scale

    # Repeat the seasonal parameters to match the length of the time series
    seasonal = origin_seasonal + destin_seasonal
    seasonal_repeat = periodic_repeat_jax(seasonal, t_max, dim=-1)

    # Define the local level transition function
    def transition_fn(carry, t):
        "Local level transition function"
        previous_level = carry
        current_level = previous_level + drift[..., t]
        return current_level, current_level

    # Compute the latent levels using scan
    _, pred_levels = jax.lax.scan(
        transition_fn, init=jnp.zeros((n_series,)), xs=jnp.arange(t_max)
    )

    # We need to transpose the prediction levels to match the shape of the data
    pred_levels = pred_levels.transpose(1, 0)

    # Compute the mean of the model
    mu = pred_levels + seasonal_repeat + pairwise

    # Sample the observations
    with numpyro.handlers.condition(data={"obs": y}):
        numpyro.sample("obs", dist.Normal(loc=mu, scale=scale))
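The `jax.lax.scan` above only accumulates the drift, so the local level is a cumulative sum of the drift along time. The level is indexed by destination only. The NumPy sketch below uses toy sizes and is our own illustration, not part of the model. Its loop mirrors `transition_fn`, and the check against `np.cumsum` confirms the cumulative-sum reading. The final broadcast shows that the (destin, time) levels are shared across all origins when `mu` is formed.

```python
import numpy as np

rng = np.random.default_rng(0)
n_destin, t_max = 3, 10
drift = rng.normal(scale=0.1, size=(n_destin, t_max))

# Loop mirroring transition_fn: level_t = level_{t-1} + drift_t, starting from 0
level = np.zeros(n_destin)
levels = []
for t in range(t_max):
    level = level + drift[:, t]
    levels.append(level)

# Stacked as (t_max, n_destin) like the scan output, then transposed as in the model
pred_levels = np.stack(levels).T

# The local level is the cumulative sum of the drift along time
assert np.allclose(pred_levels, np.cumsum(drift, axis=-1))

# (destin, time) levels broadcast against (origin, destin, time) terms
n_origin = n_destin
seasonal_repeat = np.zeros((n_origin, n_destin, t_max))
pairwise = np.zeros((n_origin, n_destin, 1))
mu = pred_levels + seasonal_repeat + pairwise
print(mu.shape)  # (3, 3, 10)
```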

Let’s visualize the model structure.

[10]:
numpyro.render_model(
    model=model,
    model_kwargs={"covariates": covariates_train, "y": y_train},
    render_distributions=True,
    render_params=True,
)
[10]:
[Figure: graphical structure of the forecasting model, rendered by numpyro.render_model]

Prior Predictive Checks

As usual (highly recommended!), we should perform prior predictive checks.

[11]:
prior_predictive = Predictive(model=model, num_samples=500, return_sites=["obs"])

rng_key, rng_subkey = random.split(rng_key)

prior_samples = prior_predictive(rng_subkey, covariates_train)

idata_prior = az.from_dict(
    prior_predictive={k: v[None, ...] for k, v in prior_samples.items()},
    coords={
        "time_train": time_train,
        "origin": dataset["stations"],
        "destin": dataset["stations"],
    },
    dims={"obs": ["origin", "destin", "time_train"]},
)

Let’s plot the prior predictive distribution for trips from the first 8 origin stations to the destination station ANTC, over the last three weeks of the training period.

[12]:
station = "ANTC"
idx = np.nonzero(dataset["stations"] == station)[0].item()

fig, axes = plt.subplots(
    nrows=8, ncols=1, figsize=(12, 15), sharex=True, sharey=True, layout="constrained"
)
for i, ax in enumerate(axes):
    for j, hdi_prob in enumerate([0.94, 0.5]):
        az.plot_hdi(
            time_train[time_train >= T1 - 3 * (24 * 7)],
            idata_prior["prior_predictive"]["obs"]
            .sel(destin=station)
            .isel(origin=i)[:, :, np.array(time_train) >= T1 - 3 * (24 * 7)]
            .clip(min=0),
            hdi_prob=hdi_prob,
            color="C0",
            fill_kwargs={
                "alpha": 0.3 + 0.2 * j,
                "label": f"{hdi_prob * 100:.0f}% HDI (train)",
            },
            smooth=False,
            ax=ax,
        )

    ax.plot(
        time_train[time_train >= T1 - 3 * (24 * 7)],
        data[i, idx, T1 - 3 * (24 * 7) : T1],
        "black",
        lw=1,
        label="Truth",
    )

    ax.legend(
        bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0, fontsize=12
    )

fig.suptitle("Prior predictive checks", fontsize=18, fontweight="bold");
[Figure: prior predictive checks, 94% and 50% HDIs against the observed data for 8 origin stations into ANTC]

Overall, the prior ranges look very reasonable.

Inference with SVI

We now fit the model to the data using stochastic variational inference.

[13]:
%%time

num_steps = 8_000

# See https://optax.readthedocs.io/en/latest/getting_started.html#custom-optimizers
scheduler = optax.linear_onecycle_schedule(
    transition_steps=num_steps,
    peak_value=0.1,
    pct_start=0.1,
    pct_final=0.7,
    div_factor=2,
    final_div_factor=4,
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # clip gradients for numerical stability
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),  # descend: step against the gradient
)

guide = AutoNormal(model)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

rng_key, rng_subkey = random.split(key=rng_key)

svi_result = svi.run(rng_subkey, num_steps, covariates_train, y_train)

fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(svi_result.losses)
ax.set_yscale("log")
ax.set_title("ELBO loss", fontsize=18, fontweight="bold");
100%|██████████| 8000/8000 [02:54<00:00, 45.80it/s, init loss: 8788430848.0000, avg. loss [7601-8000]: 3667545.0000]
CPU times: user 8min 51s, sys: 1min 17s, total: 10min 8s
Wall time: 3min 2s
[Figure: ELBO loss over the SVI steps, log scale]

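The ELBO loss decreases by more than three orders of magnitude, from about 8.8e9 at initialization to about 3.7e6 averaged over the last 400 steps, so the optimization looks good!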
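The learning rate in the SVI cell above follows a one-cycle schedule. The pure-Python sketch below traces the curve we expect from the arguments used there. It assumes that optax's `linear_onecycle_schedule` interpolates linearly between four knots: `peak_value / div_factor` at step 0, `peak_value` at `pct_start * transition_steps`, back to `peak_value / div_factor` at `pct_final * transition_steps`, and that value divided by `final_div_factor` at `transition_steps`. `onecycle_lr` is a hypothetical helper, not optax code.

```python
def onecycle_lr(
    step: int,
    transition_steps: int = 8_000,
    peak_value: float = 0.1,
    pct_start: float = 0.1,
    pct_final: float = 0.7,
    div_factor: float = 2.0,
    final_div_factor: float = 4.0,
) -> float:
    """Hypothetical linear one-cycle schedule (assumed knot layout, see lead-in)."""
    init_value = peak_value / div_factor
    knots = [
        (0, init_value),
        (int(pct_start * transition_steps), peak_value),
        (int(pct_final * transition_steps), init_value),
        (transition_steps, init_value / final_div_factor),
    ]
    if step >= transition_steps:
        return knots[-1][1]  # constant after the last knot
    for (s0, v0), (s1, v1) in zip(knots[:-1], knots[1:]):
        if s0 <= step < s1:
            # Linear interpolation between neighboring knots
            return v0 + (step - s0) / (s1 - s0) * (v1 - v0)
    return init_value  # negative steps: treat as the initial value


for step in [0, 800, 3_200, 5_600, 8_000]:
    print(step, onecycle_lr(step))
```

Under this assumption, the learning rate starts at 0.05, peaks at 0.1 after 800 steps, returns to 0.05 at step 5,600, and ends at 0.0125. The short warm-up combined with the gradient clipping helps tame the very large initial loss.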

Posterior Predictive Check

Next, we generate posterior predictive samples for each station pair, both over the training period and over the full period including the forecast horizon.

[14]:
posterior = Predictive(
    model=model,
    guide=guide,
    params=svi_result.params,
    num_samples=200,
    return_sites=["obs"],
)
[15]:
# Use separate keys for the train and full-horizon predictions
rng_key, rng_subkey_train, rng_subkey_test = random.split(rng_key, 3)

idata_train = az.from_dict(
    posterior_predictive={
        k: v[None, ...]
        for k, v in posterior(rng_subkey_train, covariates_train).items()
    },
    coords={
        "time_train": time_train,
        "origin": dataset["stations"],
        "destin": dataset["stations"],
    },
    dims={"obs": ["origin", "destin", "time_train"]},
)

idata_test = az.from_dict(
    posterior_predictive={
        k: v[None, ...] for k, v in posterior(rng_subkey_test, covariates).items()
    },
    coords={
        "time": time,
        "origin": dataset["stations"],
        "destin": dataset["stations"],
    },
    dims={"obs": ["origin", "destin", "time"]},
)

Model Evaluation

To evaluate the model performance, we compute the continuous ranked probability score (CRPS) on the training and test data. To do so, we port Pyro's empirical CRPS implementation (crps_empirical in pyro.ops.stats) to JAX.

Since the log1p-transformed counts are non-negative, we clip the predictions at zero before comparing them with the data.
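The estimator below rests on a standard identity. For an observation y and M predictive samples x_1, …, x_M, the empirical CRPS is the mean absolute error minus half the mean absolute difference between samples. After sorting the samples, the gap between the k-th and (k+1)-th smallest sample is spanned by exactly k(M - k) unordered pairs. This is why the code below only needs sorted differences weighted by `arange(1, M) * arange(M - 1, 0, -1)`:

```latex
\begin{aligned}
\widehat{\mathrm{CRPS}}(y)
  &= \frac{1}{M}\sum_{m=1}^{M} |x_m - y|
   - \frac{1}{2M^2}\sum_{m=1}^{M}\sum_{m'=1}^{M} |x_m - x_{m'}|, \\
\sum_{m=1}^{M}\sum_{m'=1}^{M} |x_m - x_{m'}|
  &= 2\sum_{k=1}^{M-1} k\,(M-k)\,\bigl(x_{(k+1)} - x_{(k)}\bigr), \\
\widehat{\mathrm{CRPS}}(y)
  &= \frac{1}{M}\sum_{m=1}^{M} |x_m - y|
   - \frac{1}{M^2}\sum_{k=1}^{M-1} k\,(M-k)\,\bigl(x_{(k+1)} - x_{(k)}\bigr).
\end{aligned}
```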

[16]:
def crps(truth: Array, pred: Array) -> float:
    """Compute the CRPS for a given truth and prediction.

    Parameters
    ----------
    truth : Array
        The truth values.
    pred : Array
        A set of sample predictions batched on the leftmost dim.
        This should have shape ``(num_samples,) + truth.shape``.

    Returns
    -------
    crps : float
        The average CRPS score.

    References
    ----------
    https://docs.pyro.ai/en/stable/_modules/pyro/ops/stats.html
    """
    if pred.shape[1:] != (1,) * (pred.ndim - truth.ndim - 1) + truth.shape:
        raise ValueError(
            f"""Expected pred to have one extra sample dim on left.
            Actual shapes: {pred.shape} versus {truth.shape}"""
        )

    absolute_error = jnp.mean(jnp.abs(pred - truth), axis=0)

    num_samples = pred.shape[0]
    if num_samples == 1:
        return jnp.average(absolute_error)

    pred = jnp.sort(pred, axis=0)
    diff = pred[1:] - pred[:-1]
    weight = jnp.arange(1, num_samples) * jnp.arange(num_samples - 1, 0, -1)
    weight = weight.reshape(weight.shape + (1,) * (diff.ndim - 1))

    per_obs_crps = absolute_error - jnp.sum(diff * weight, axis=0) / num_samples**2
    return jnp.average(per_obs_crps)


crps_train = crps(
    y_train,
    jnp.array(idata_train["posterior_predictive"]["obs"].sel(chain=0).clip(min=0)),
)

crps_test = crps(
    y_test,
    jnp.array(
        idata_test["posterior_predictive"]["obs"]
        .sel(chain=0)
        .sel(time=slice(T1, T2))
        .clip(min=0)
    ),
)
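To check that the sorted-difference formula in `crps` matches the pairwise definition, here is a small NumPy check on random data. `crps_sorted` is a NumPy mirror of the function above without its input validation, and `crps_pairwise` is the naive O(M²) version. Both names are our own. A point forecast at c should give a CRPS equal to the absolute error |c - y|.

```python
import numpy as np


def crps_sorted(truth: np.ndarray, pred: np.ndarray) -> float:
    """NumPy mirror of the sorted-difference CRPS estimator, averaged over observations."""
    num_samples = pred.shape[0]
    absolute_error = np.mean(np.abs(pred - truth), axis=0)
    pred = np.sort(pred, axis=0)
    diff = pred[1:] - pred[:-1]
    weight = np.arange(1, num_samples) * np.arange(num_samples - 1, 0, -1)
    weight = weight.reshape(weight.shape + (1,) * (diff.ndim - 1))
    return float(np.mean(absolute_error - np.sum(diff * weight, axis=0) / num_samples**2))


def crps_pairwise(truth: np.ndarray, pred: np.ndarray) -> float:
    """Naive estimator: E|X - y| - 0.5 * E|X - X'| over all sample pairs."""
    absolute_error = np.mean(np.abs(pred - truth), axis=0)
    pairwise = np.mean(np.abs(pred[:, None] - pred[None, :]), axis=(0, 1))
    return float(np.mean(absolute_error - 0.5 * pairwise))


rng = np.random.default_rng(0)
truth = rng.normal(size=(3, 4))
pred = rng.normal(size=(200, 3, 4))  # 200 samples on the leftmost dim
print(crps_sorted(truth, pred), crps_pairwise(truth, pred))

# A point forecast at 1.0 for observations at 0.0 has CRPS |1 - 0| = 1
print(crps_sorted(np.zeros(2), np.ones((10, 2))))
```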

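Finally, we reproduce the forecast plot from the Pyro example. For the first 8 origin stations and the destination station ANTC, we show the last week of training data and the two-week test period, along with the train and test CRPS.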

[17]:
station = "ANTC"
idx = np.nonzero(dataset["stations"] == station)[0].item()

fig, axes = plt.subplots(
    nrows=8, ncols=1, figsize=(12, 15), sharex=True, sharey=True, layout="constrained"
)
for i, ax in enumerate(axes):
    for j, hdi_prob in enumerate([0.94, 0.5]):
        az.plot_hdi(
            time_train[time_train >= T1 - 24 * 7],
            idata_train["posterior_predictive"]["obs"]
            .sel(destin=station)
            .isel(origin=i)[:, :, np.array(time_train) >= T1 - 24 * 7]
            .clip(min=0),
            hdi_prob=hdi_prob,
            color="C0",
            fill_kwargs={
                "alpha": 0.3 + 0.2 * j,
                "label": f"{hdi_prob * 100:.0f}% HDI (train)",
            },
            smooth=False,
            ax=ax,
        )

        az.plot_hdi(
            time[time >= T1],
            idata_test["posterior_predictive"]["obs"]
            .sel(destin=station)
            .isel(origin=i)[:, :, np.array(time) >= T1]
            .clip(min=0),
            hdi_prob=hdi_prob,
            color="C1",
            fill_kwargs={
                "alpha": 0.2 + 0.2 * j,
                "label": f"{hdi_prob * 100:.0f}% HDI (test)",
            },
            smooth=False,
            ax=ax,
        )

    # Hour index marking Christmas (it falls inside the test period)
    christmas_index = 78736
    ax.axvline(christmas_index, color="C2", lw=20, alpha=0.2, label="Christmas")

    ax.plot(
        time[time >= T1 - 24 * 7],
        data[i, idx, T1 - 24 * 7 : T2],
        "black",
        lw=1,
        label="Truth",
    )

    ax.axvline(T1, color="C3", linestyle="--", label="Train/test split")

    ax.legend(
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
        borderaxespad=0.0,
        fontsize=9,
        labelspacing=0.6,
    )

fig.suptitle(
    f"""Posterior predictive checks

    Train CRPS: {crps_train:.4f} | Test CRPS: {crps_test:.4f}
    """,
    fontsize=18,
    fontweight="bold",
);
[Figure: posterior predictive checks for the train and test periods, 8 origin stations into ANTC, with train and test CRPS in the title]

Overall, the predictions look quite good! The train and test CRPS values in the figure title summarize the fit quantitatively.