# Example: Holt-Winters Exponential Smoothing

In this example we show how to implement Exponential Smoothing. This is intended to be a simple counterpart to the Time Series Forecasting notebook.

The idea is that we have some time series

$y_1, ..., y_T, y_{T+1}, ..., y_{T+H}$

where we train on $$y_1, ..., y_T$$ and predict for $$y_{T+1}, ..., y_{T+H}$$, where $$T$$ is the maximum training timestamp and $$H$$ is the maximum number of future timesteps for which we want to forecast.

We will be using the update equations from the excellent book Forecasting Principles and Practice:

\begin{align}\begin{aligned}\hat{y}_{t+h|t} = l_t + hb_t + s_{t+h-m(k+1)}\\l_t = \alpha(y_t - s_{t-m}) + (1-\alpha)(l_{t-1} + b_{t-1})\\b_t = \beta^*(l_t-l_{t-1}) + (1-\beta^*)b_{t-1}\\s_t = \gamma(y_t-l_{t-1}-b_{t-1})+(1-\gamma)s_{t-m}\end{aligned}\end{align}

where

• $$\hat{y}_{t+h|t}$$ is the forecast for time $$t+h$$ made using the data available up to time $$t$$;

• $$h$$ is the number of time steps into the future which we want to predict for;

• $$l_t$$ is the level, $$b_t$$ is the trend, and $$s_t$$ is the seasonality,

• $$\alpha$$ is the level smoothing, $$\beta^*$$ is the trend smoothing, and $$\gamma$$ is the seasonality smoothing.

• $$m$$ is the number of time steps in one seasonal period, and $$k$$ is the integer part of $$(h-1)/m$$ (this looks more complicated than it is, it just takes the latest seasonality estimate for this time point).

import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.control_flow import scan
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# Select matplotlib's non-interactive "Agg" backend: the script only writes the
# figure to a file, so no display is required.
matplotlib.use("Agg")

# Sampling density of the synthetic series (points per unit of time). main()
# also passes this value as `n_seasons`, because the sine component it
# generates has period 1 and therefore repeats every N_POINTS_PER_UNIT samples.
N_POINTS_PER_UNIT = 10  # number of points to plot for each unit interval

def holt_winters(y, n_seasons, future=0):
    """Additive Holt-Winters exponential smoothing written as a NumPyro model.

    Sample sites: ``level_smoothing``, ``trend_smoothing`` and
    ``seasonality_smoothing`` (Beta(1, 1) priors), ``noise`` (HalfNormal(1)),
    ``level_init`` and ``trend_init`` (Normal(0, 1)), ``seasonality_init``
    (Normal(0, 1), one value per season) and ``pred`` (one value per time
    step, conditioned on ``y`` for the observed steps).

    Args:
        y: 1-D array holding the observed series ``y_1, ..., y_T``.
        n_seasons: number of time steps in one seasonal period (``m``).
        future: number of steps after the last observation to forecast.

    Returns:
        None. When ``future > 0`` the last ``future`` values of ``pred`` are
        recorded as the deterministic site ``y_forecast``.
    """
    # Number of observed time steps. This must be an integer (not the shape
    # tuple) because it is compared with the step index and added to `future`.
    T = y.shape[0]
    level_smoothing = numpyro.sample("level_smoothing", dist.Beta(1, 1))
    trend_smoothing = numpyro.sample("trend_smoothing", dist.Beta(1, 1))
    seasonality_smoothing = numpyro.sample("seasonality_smoothing", dist.Beta(1, 1))
    # Rescale so the effective seasonal weight lies in [0, 1 - level_smoothing].
    adj_seasonality_smoothing = seasonality_smoothing * (1 - level_smoothing)
    noise = numpyro.sample("noise", dist.HalfNormal(1))
    level_init = numpyro.sample("level_init", dist.Normal(0, 1))
    trend_init = numpyro.sample("trend_init", dist.Normal(0, 1))
    with numpyro.plate("n_seasons", n_seasons):
        seasonality_init = numpyro.sample("seasonality_init", dist.Normal(0, 1))

    def transition_fn(carry, t):
        """Advance level, trend and the seasonal buffer by one time step."""
        previous_level, previous_trend, previous_seasonality = carry
        # The buffer holds the last `n_seasons` seasonal terms, oldest first,
        # so its head is s_{t-m}: the seasonal estimate for the current phase.
        current_season = previous_seasonality[0]
        # For t >= T there is no observation and the state is carried over
        # unchanged. y[t] is still evaluated there, but jnp.where discards it.
        is_observed = t < T
        level = jnp.where(
            is_observed,
            level_smoothing * (y[t] - current_season)
            + (1 - level_smoothing) * (previous_level + previous_trend),
            previous_level,
        )
        trend = jnp.where(
            is_observed,
            trend_smoothing * (level - previous_level)
            + (1 - trend_smoothing) * previous_trend,
            previous_trend,
        )
        new_season = jnp.where(
            is_observed,
            adj_seasonality_smoothing * (y[t] - (previous_level + previous_trend))
            + (1 - adj_seasonality_smoothing) * current_season,
            current_season,
        )
        # One step ahead while observing; h = t - T + 1 steps ahead of the
        # frozen level/trend once forecasting.
        step = jnp.where(is_observed, 1, t - T + 1)
        mu = previous_level + step * previous_trend + current_season
        pred = numpyro.sample("pred", dist.Normal(mu, noise))

        # Rotate the buffer: drop the head just used and append the updated
        # seasonal term for this phase at the tail.
        seasonality = jnp.concatenate(
            [previous_seasonality[1:], new_season[None]], axis=0
        )
        return (level, trend, seasonality), pred

    # Condition "pred" on the observed series; y covers only the first T of
    # the T + future scanned steps.
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(
            transition_fn,
            (level_init, trend_init, seasonality_init),
            jnp.arange(T + future),
        )

    if future > 0:
        numpyro.deterministic("y_forecast", preds[-future:])

def run_inference(model, args, rng_key, y, n_seasons):
    """Run NUTS/MCMC on `model` and return the collected samples.

    Warm-up length, sample count and chain count are read from
    ``args.num_warmup``, ``args.num_samples`` and ``args.num_chains``. The
    progress bar is switched off when the ``NUMPYRO_SPHINXBUILD`` environment
    variable is set. A summary and the elapsed wall-clock time are printed.
    """
    t0 = time.time()
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(rng_key, y=y, n_seasons=n_seasons)
    mcmc.print_summary()
    elapsed = time.time() - t0
    print("\nMCMC elapsed time:", elapsed)
    return mcmc.get_samples()

def predict(model, args, samples, rng_key, y, n_seasons):
    """Return the ``y_forecast`` site of `model` for the given `samples`.

    The forecast horizon passed to the model is
    ``args.future * N_POINTS_PER_UNIT`` time steps.
    """
    forecaster = Predictive(model, samples, return_sites=["y_forecast"])
    horizon = args.future * N_POINTS_PER_UNIT
    draws = forecaster(rng_key, y=y, n_seasons=n_seasons, future=horizon)
    return draws["y_forecast"]

def main(args):
    """Generate a synthetic series, fit Holt-Winters, forecast and plot.

    Reads ``args.T`` (training length in time units) and ``args.future``
    (forecast length in time units), plus the MCMC settings consumed by
    run_inference. The figure is written to ``holt_winters_plot.pdf``.
    """
    # generate artificial dataset: unit-period sine + linear trend + noise
    data_key = random.split(random.PRNGKey(0))[0]
    total_units = args.T + args.future
    n_test = args.future * N_POINTS_PER_UNIT
    t = jnp.linspace(0, total_units, total_units * N_POINTS_PER_UNIT)
    signal = jnp.sin(2 * np.pi * t) + 0.3 * t
    y = signal + jax.random.normal(data_key, t.shape) * 0.1
    n_seasons = N_POINTS_PER_UNIT
    # hold out the last `n_test` points for comparison with the forecast
    y_train = y[:-n_test]
    t_test = t[-n_test:]

    # do inference
    inference_key = random.split(random.PRNGKey(1))[0]
    samples = run_inference(holt_winters, args, inference_key, y_train, n_seasons)

    # do prediction
    predict_key = random.split(random.PRNGKey(2))[0]
    preds = predict(holt_winters, args, samples, predict_key, y_train, n_seasons)
    mean_preds = preds.mean(axis=0)
    lower, upper = hpdi(preds)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    # plot true data and predictions
    ax.plot(t, y, color="blue", label="True values")
    ax.plot(t_test, mean_preds, color="orange", label="Mean predictions")
    ax.fill_between(t_test, lower, upper, color="orange", alpha=0.2, label="90% CI")
    ax.set(xlabel="time", ylabel="y", title="Holt-Winters Exponential Smoothing")
    ax.legend()

    plt.savefig("holt_winters_plot.pdf")

if __name__ == "__main__":
assert numpyro.__version__.startswith("0.10.0")
parser = argparse.ArgumentParser(description="Holt-Winters")