Example: Hidden Markov Model

In this example, we will follow [1] to construct a semi-supervised Hidden Markov Model for a generative model in which the observations are words and the latent variables are categories. Instead of automatically marginalizing all discrete latent variables (as in [2]), we will use the “forward algorithm” (which exploits the conditional independence of a Markov model - see [3]) to do this marginalization iteratively.

The semi-supervised problem is chosen instead of an unsupervised one because it is hard to make inference work for an unsupervised model (see the discussion in [4]). In addition, this example illustrates the usage of JAX’s lax.scan primitive, which greatly reduces the compilation time of the model.

References:

import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

from jax import lax, random
import jax.numpy as jnp
from jax.scipy.special import logsumexp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def simulate_data(
    rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data
):
    """Simulate HMM parameters plus a supervised and an unsupervised sequence.

    Transition and emission matrices are drawn row by row from Dirichlet
    distributions (all-ones concentration for transitions, 0.1 for emissions).
    A chain of ``num_supervised_data + num_unsupervised_data`` steps is then
    sampled; the category is re-drawn from a uniform start distribution at
    step 0 and again at step ``num_supervised_data``, so the two segments
    each begin afresh.

    :param rng_key: JAX PRNG key driving every random draw.
    :param num_categories: number of hidden categories.
    :param num_words: size of the word vocabulary.
    :param num_supervised_data: length of the labelled segment.
    :param num_unsupervised_data: length of the unlabelled segment.
    :return: tuple ``(transition_prior, emission_prior, transition_prob,
        emission_prob, supervised_categories, supervised_words,
        unsupervised_words)``.
    """
    rng_key, key_trans, key_emit = random.split(rng_key, 3)

    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)

    # One Dirichlet draw per category: each draw is a row of the matrix.
    rows = (num_categories,)
    transition_prob = dist.Dirichlet(transition_prior).sample(
        key=key_trans, sample_shape=rows
    )
    emission_prob = dist.Dirichlet(emission_prior).sample(
        key=key_emit, sample_shape=rows
    )

    start_prob = jnp.repeat(1.0 / num_categories, num_categories)
    total_steps = num_supervised_data + num_unsupervised_data
    drawn_categories, drawn_words = [], []
    for step in range(total_steps):
        rng_key, key_trans, key_emit = random.split(rng_key, 3)
        # The first step of each segment ignores the previous category.
        if step in (0, num_supervised_data):
            category_dist = dist.Categorical(start_prob)
        else:
            category_dist = dist.Categorical(transition_prob[category])
        category = category_dist.sample(key=key_trans)
        word = dist.Categorical(emission_prob[category]).sample(key=key_emit)
        drawn_categories.append(category)
        drawn_words.append(word)

    # Split the single simulated chain into its labelled and unlabelled parts.
    all_categories = jnp.stack(drawn_categories)
    all_words = jnp.stack(drawn_words)
    return (
        transition_prior,
        emission_prior,
        transition_prob,
        emission_prob,
        all_categories[:num_supervised_data],
        all_words[:num_supervised_data],
        all_words[num_supervised_data:],
    )

def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    """Advance the forward-algorithm log-probabilities by one observed word.

    :param prev_log_prob: log-probabilities indexed by the previous category.
    :param curr_word: index of the word observed at the current step.
    :param transition_log_prob: log transition matrix, indexed
        ``[previous category, current category]``.
    :param emission_log_prob: log emission matrix, indexed ``[category, word]``.
    :return: log-probabilities indexed by the current category, with the
        previous category summed out in log space.
    """
    # joint[i, j] = prev_log_prob[i] + transition_log_prob[i, j]
    joint = prev_log_prob[:, None] + transition_log_prob
    # Add the emission term of the current category j (broadcast over i).
    joint = joint + emission_log_prob[:, curr_word]
    # Marginalize the previous category i.
    return logsumexp(joint, axis=0)

def forward_log_prob(
    init_log_prob, words, transition_log_prob, emission_log_prob, unroll_loop=False
):
    """Run the forward algorithm over ``words`` starting from ``init_log_prob``.

    :param init_log_prob: log-probabilities per category before the first
        word in ``words`` is consumed.
    :param words: sequence of word indices, consumed in order.
    :param transition_log_prob: log transition matrix passed to
        ``forward_one_step``.
    :param emission_log_prob: log emission matrix passed to
        ``forward_one_step``.
    :param unroll_loop: if true, iterate with a plain Python loop; otherwise
        use ``lax.scan``.
    :return: log-probabilities per category after the last word.
    """
    if unroll_loop:
        # Naive implementation: a Python loop over the words. As noted in the
        # original example, this makes compilation and inference very slow,
        # which is why lax.scan is the default path below.
        running = init_log_prob
        for word in words:
            running = forward_one_step(
                running, word, transition_log_prob, emission_log_prob
            )
        return running

    def advance(carry, word):
        updated = forward_one_step(
            carry, word, transition_log_prob, emission_log_prob
        )
        # Only the final carry is needed, so nothing is collected per step.
        return updated, None

    final_log_prob, _ = lax.scan(advance, init_log_prob, words)
    return final_log_prob

def semi_supervised_hmm(
    transition_prior,
    emission_prior,
    supervised_categories,
    supervised_words,
    unsupervised_words,
    unroll_loop=False,
):
    """NumPyro model for a semi-supervised Hidden Markov Model.

    Samples a transition matrix and an emission matrix, conditions on the
    labelled sequence directly, and adds the marginal log-likelihood of the
    unlabelled words (computed with the forward algorithm) as a factor.

    :param transition_prior: Dirichlet concentration for one row of the
        transition matrix; its length defines the number of categories.
    :param emission_prior: Dirichlet concentration for one row of the
        emission matrix; its length defines the number of words.
    :param supervised_categories: observed categories of the labelled data.
    :param supervised_words: observed words of the labelled data.
    :param unsupervised_words: observed words whose categories are latent.
    :param unroll_loop: forwarded to ``forward_log_prob``; selects the Python
        loop instead of ``lax.scan``.
    """
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    # Each category gets its own row, so the per-row concentration is
    # broadcast to one row per category before building the Dirichlet.
    transition_prob = numpyro.sample(
        "transition_prob",
        dist.Dirichlet(
            jnp.broadcast_to(transition_prior, (num_categories, num_categories))
        ),
    )
    emission_prob = numpyro.sample(
        "emission_prob",
        dist.Dirichlet(jnp.broadcast_to(emission_prior, (num_categories, num_words))),
    )

    # models supervised data;
    # here we don't make any assumption about the first supervised category, in other words,
    # we place a flat/uniform prior on it.
    numpyro.sample(
        "supervised_categories",
        dist.Categorical(transition_prob[supervised_categories[:-1]]),
        obs=supervised_categories[1:],
    )
    numpyro.sample(
        "supervised_words",
        dist.Categorical(emission_prob[supervised_categories]),
        obs=supervised_words,
    )

    # computes log prob of unsupervised data
    transition_log_prob = jnp.log(transition_prob)
    emission_log_prob = jnp.log(emission_prob)
    # The first unsupervised word contributes only its emission term; the
    # remaining words are folded in one step at a time.
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(
        init_log_prob,
        unsupervised_words[1:],
        transition_log_prob,
        emission_log_prob,
        unroll_loop,
    )
    # Sum out the final category to obtain the marginal log-likelihood.
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # inject log_prob to potential function
    numpyro.factor("forward_log_prob", log_prob)

def print_results(posterior, transition_prob, emission_prob):
    """Print true probabilities next to posterior quartiles for every entry.

    A banner and a column-header row are printed first, followed by one row
    per entry of the transition matrix and then of the emission matrix. Each
    row shows the actual value and the 25th/50th/75th percentiles of the
    posterior samples; every row is followed by a blank line.

    :param posterior: mapping with ``"transition_prob"`` and
        ``"emission_prob"`` sample arrays whose leading axis indexes samples.
    :param transition_prob: true transition matrix (2-D).
    :param emission_prob: true emission matrix (2-D).
    """
    header = semi_supervised_hmm.__name__ + " - TRAIN"
    columns = ["", "ActualProb", "Pred(p25)", "Pred(p50)", "Pred(p75)"]
    header_format = "{:>20} {:>10} {:>10} {:>10} {:>10}"
    row_format = "{:>20} {:>10.2f} {:>10.2f} {:>10.2f} {:>10.2f}"
    print("\n", "=" * 20 + header + "=" * 20, "\n")
    # Column titles were defined but never emitted; print them so the
    # numeric columns below are labelled.
    print(header_format.format(*columns))

    tables = (
        ("transition", "transition_prob", transition_prob),
        ("emission", "emission_prob", emission_prob),
    )
    for label, site, actual in tables:
        # Quartiles are taken over the sample axis (axis 0).
        quantiles = np.quantile(posterior[site], [0.25, 0.5, 0.75], axis=0)
        for i in range(actual.shape[0]):
            for j in range(actual.shape[1]):
                idx = "{}[{},{}]".format(label, i, j)
                print(
                    row_format.format(idx, actual[i, j], *quantiles[:, i, j]), "\n"
                )

def main(args):
    """Simulate data, run NUTS on the semi-supervised HMM, report and plot.

    Reads ``num_categories``, ``num_words``, ``num_supervised``,
    ``num_unsupervised``, ``num_warmup``, ``num_samples``, ``num_chains`` and
    ``unroll_loop`` from ``args``. Prints posterior summaries and the elapsed
    time, then saves a density plot of the transition-probability posterior
    to ``hmm_plot.pdf``.

    :param args: parsed command-line arguments carrying the attributes above.
    """
    print("Simulating data...")
    (
        transition_prior,
        emission_prior,
        transition_prob,
        emission_prob,
        *observed,  # supervised categories, supervised words, unsupervised words
    ) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )

    print("Starting inference...")
    rng_key = random.PRNGKey(2)
    start = time.time()
    # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set in the
    # environment.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    mcmc = MCMC(
        NUTS(semi_supervised_hmm),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(
        rng_key, transition_prior, emission_prior, *observed, args.unroll_loop
    )
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print("\nMCMC elapsed time:", time.time() - start)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    grid = np.linspace(0, 1, 101)
    transition_samples = samples["transition_prob"]
    # One kernel-density curve per transition-matrix entry, in row-major order.
    for i, j in np.ndindex(*transition_prob.shape):
        density = gaussian_kde(transition_samples[:, i, j])
        curve_label = "trans_prob[{}, {}], true value = {:.2f}".format(
            i, j, transition_prob[i, j]
        )
        ax.plot(grid, density(grid), label=curve_label)
    ax.set(
        xlabel="Probability",
        ylabel="Frequency",
        title="Transition probability posterior",
    )
    ax.legend()

    plt.savefig("hmm_plot.pdf")

if __name__ == "__main__":
assert numpyro.__version__.startswith("0.7.0")
parser = argparse.ArgumentParser(description="Semi-supervised Hidden Markov Model")