Example: VAR(2) process

In this example, we demonstrate how to implement and perform Bayesian inference for a Vector Autoregressive process of order 2 (VAR(2)). VAR models are widely used in time series analysis, especially for capturing the dynamics between multiple variables.

A VAR(2) process for a multivariate time series \(y_t\) with \(K\) variables is defined as:

\[y_t = c + \Phi_1 y_{t-1} + \Phi_2 y_{t-2} + \epsilon_t\]

Here, \(c\) is a constant (intercept) vector of length \(K\), \(\Phi_1\) and \(\Phi_2\) are the \(K \times K\) coefficient matrices for lag 1 and lag 2, respectively, and \(\epsilon_t\) is a Gaussian noise term with zero mean and covariance matrix \(\Sigma\).

This example uses NumPyro’s scan utility to efficiently model the temporal dependencies without explicit Python loops.

For more general time series forecasting techniques and examples, refer to the Time Series Forecasting tutorial: https://num.pyro.ai/en/stable/tutorials/time_series_forecasting.html#Forecasting

Reference

For more information on Vector Autoregressive models, see: https://otexts.com/fpp2/VAR.html

../_images/var2.png
import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist


def var2_scan(y):
    """
    NumPyro model of a VAR(2) process with correlated Gaussian noise.

    For every ``t >= 2`` the observation is conditioned on the two preceding
    observations::

        y[t] ~ MultivariateNormal(c + Phi1 @ y[t-1] + Phi2 @ y[t-2], Sigma)

    Args:
        y: Observed time series of shape ``(T, K)`` (time steps x variables).

    Sites registered by the model:
        c: intercept vector, shape ``(K,)``.
        Phi1, Phi2: lag-1 and lag-2 coefficient matrices, shape ``(K, K)``.
        sigma: per-variable noise scales, shape ``(K,)``.
        L_omega: Cholesky factor of the noise correlation matrix.
        y: observed site inside the scan, one step per row of ``y[2:]``.
        mu: deterministic site holding the one-step-ahead means for time
            steps ``2 .. T-1``, shape ``(T - 2, K)``.
    """
    y = jnp.asarray(y)
    _, K = y.shape  # Number of variables (the first axis is time)

    # Priors for constants and coefficients
    c = numpyro.sample("c", dist.Normal(0, 1).expand([K]))  # Constants vector of size K
    Phi1 = numpyro.sample(
        "Phi1", dist.Normal(0, 1).expand([K, K]).to_event(2)
    )  # Coefficients for lag 1
    Phi2 = numpyro.sample(
        "Phi2", dist.Normal(0, 1).expand([K, K]).to_event(2)
    )  # Coefficients for lag 2

    # Priors for error terms: scales and a correlation Cholesky factor are
    # combined into the Cholesky factor of the covariance matrix.
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0).expand([K]).to_event(1))
    L_omega = numpyro.sample(
        "L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)
    )
    L_Sigma = (
        sigma[..., None] * L_omega
    )  # Alternative: jnp.einsum("...i,...ij->...ij", sigma, L_omega)

    def transition(carry, y_curr):
        # The carry holds only the two most recent observations; the current
        # observation arrives as the scanned-over input.
        y_prev1, y_prev2 = carry
        m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2)  # Mean prediction
        # Use a fixed site name: inside ``scan`` the step is traced once, so a
        # per-step name built from the loop variable would embed a tracer repr.
        y_t = numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma),
            obs=y_curr,
        )
        return (y_t, y_prev1), m_t

    # Initial carry: observations at time steps 1 and 0; scan over steps 2..T-1
    _, mu = scan(transition, (y[1], y[0]), y[2:])

    # Store the mean trajectory as a deterministic variable
    numpyro.deterministic("mu", mu)


def generate_var2_data(T, K, c, Phi1, Phi2, sigma, seed=None):
    """
    Generate time series data from a VAR(2) process.

    The first two rows are drawn from ``N(0, sigma)``; every later row follows
    ``y[t] = c + Phi1 @ y[t-1] + Phi2 @ y[t-2] + eps_t`` with
    ``eps_t ~ N(0, sigma)``.

    Args:
        T (int): Number of time steps.
        K (int): Number of variables in the time series.
        c (array): Constants (shape: (K,)).
        Phi1 (array): Coefficients for lag 1 (shape: (K, K)).
        Phi2 (array): Coefficients for lag 2 (shape: (K, K)).
        sigma (array): Covariance matrix for the noise (shape: (K, K)).
        seed (int, optional): Seed for a private ``numpy.random.RandomState``.
            When ``None`` (default) the draws come from NumPy's global random
            state, so callers relying on ``np.random.seed`` keep working.
    Returns:
        np.ndarray: Generated time series data (shape: (T, K)).
    """
    # Convert once so the recursion runs purely in float64 NumPy, even when the
    # caller passes JAX arrays.
    c = np.asarray(c, dtype=float)
    Phi1 = np.asarray(Phi1, dtype=float)
    Phi2 = np.asarray(Phi2, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    zero_mean = np.zeros(K)

    # ``np.random`` (module) and ``RandomState`` expose the same
    # ``multivariate_normal`` signature and the same underlying algorithm.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Initialize time series with random values; slicing to T keeps the
    # assignment valid when fewer than two time steps are requested.
    y = np.zeros((T, K))
    y[:2] = rng.multivariate_normal(mean=zero_mean, cov=sigma, size=2)[:T]

    # Generate the time series
    for t in range(2, T):
        noise = rng.multivariate_normal(mean=zero_mean, cov=sigma)
        y[t] = c + Phi1 @ y[t - 1] + Phi2 @ y[t - 2] + noise

    return y


def run_inference(model, args, rng_key, y):
    """
    Sample from the posterior of ``model`` with NUTS and report the wall time.

    Args:
        model: The probabilistic model to infer.
        args: Command-line arguments; ``num_warmup``, ``num_samples`` and
            ``num_chains`` are read from it.
        rng_key: PRNG key for randomness.
        y: Observed time series data, forwarded to the model as ``y``.
    Returns:
        The posterior samples as returned by ``MCMC.get_samples()``.
    """
    t0 = time.time()

    # The progress bar is suppressed only while the docs are being built.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ

    kernel = numpyro.infer.NUTS(model)
    runner = numpyro.infer.MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    runner.run(rng_key, y=y)
    runner.print_summary()

    elapsed = time.time() - t0
    print("\nMCMC elapsed time:", elapsed)
    return runner.get_samples()


def main(args):
    """
    Simulate a two-variable VAR(2) series, fit ``var2_scan`` to it with MCMC,
    and save a plot of the data against the posterior mean to ``var2.png``.

    Args:
        args: Parsed command-line arguments. ``num_data`` sets the number of
            time steps; the remaining fields are consumed by ``run_inference``.
    """
    # Generate artificial dataset
    T = args.num_data  # Number of time steps
    K = 2  # Number of variables
    c_true = jnp.array([0.5, -0.3])  # Constants
    Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]])  # Coefficients for lag 1
    Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]])  # Coefficients for lag 2
    sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]])  # Covariance matrix

    # The data generator draws from NumPy's global random state, which the JAX
    # key below does not control; seed it so every run uses the same dataset.
    np.random.seed(0)
    y = generate_var2_data(T, K, c_true, Phi1_true, Phi2_true, sigma_true)

    # Perform inference
    rng_key = random.PRNGKey(0)
    samples = run_inference(var2_scan, args, rng_key, y)

    # Prediction: posterior summaries of the one-step-ahead mean "mu"
    mean_prediction = samples["mu"].mean(axis=0)
    lower_bound = jnp.percentile(samples["mu"], 2.5, axis=0)  # 2.5th percentile
    upper_bound = jnp.percentile(samples["mu"], 97.5, axis=0)  # 97.5th percentile

    # Plot results
    fig, axes = plt.subplots(K, 1, figsize=(10, 6), sharex=True)
    time_steps = jnp.arange(T)

    for i in range(K):
        # True values
        axes[i].plot(time_steps, y[:, i], label=f"True Variable {i + 1}", color="blue")
        # Posterior mean prediction (defined from time step 2 onwards)
        axes[i].plot(
            time_steps[2:],
            mean_prediction[:, i],
            label=f"Predicted Mean Variable {i + 1}",
            color="orange",
        )
        # 95% credible interval (2.5th to 97.5th posterior percentile)
        axes[i].fill_between(
            time_steps[2:],
            lower_bound[:, i],
            upper_bound[:, i],
            color="orange",
            alpha=0.2,
            label="95% CI",
        )
        axes[i].set_title(f"Variable {i + 1}")
        axes[i].legend()
        axes[i].grid(True)

    plt.xlabel("Time Steps")
    plt.tight_layout()
    plt.savefig("var2.png")


if __name__ == "__main__":
    # Every numeric option requires a value. With ``nargs="?"`` and no
    # ``const``, a bare flag such as ``--num-data`` would parse as ``None``
    # and only fail later, deep inside data generation or MCMC setup.
    parser = argparse.ArgumentParser(description="VAR(2) example")
    parser.add_argument("--num-data", default=100, type=int)
    parser.add_argument("-n", "--num-samples", default=1000, type=int)
    parser.add_argument("--num-warmup", default=1000, type=int)
    parser.add_argument("--num-chains", default=1, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Configure the backend and the host device count before running the model.
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)

Gallery generated by Sphinx-Gallery