# Example: Deep Markov Model inferred using SteinVI

In this example we use SteinVI to infer a deep Markov model (DMM) for generating music (chorales by Johann Sebastian Bach).

The DMM is based on references [1] and [2] and on the Pyro DMM example: https://pyro.ai/examples/dmm.html.

References:

1. Pathwise Derivatives for Multivariate Distributions. Martin Jankowiak and Theofanis Karaletsos (2019)
2. Structured Inference Networks for Nonlinear State Space Models [arXiv:1609.09869]. Rahul G. Krishnan, Uri Shalit and David Sontag (2016)
import argparse

import numpy as np

import jax
from jax import nn, numpy as jnp, random
from optax import adam, chain

import numpyro
from numpyro.contrib.einstein import SteinVI
from numpyro.contrib.einstein.kernels import RBFKernel
import numpyro.distributions as dist
from numpyro.examples.datasets import JSB_CHORALES, load_dataset
from numpyro.infer import Predictive, Trace_ELBO
from numpyro.optim import optax_to_numpyro

def _reverse_single(p, length):
    """Reverse the first ``length`` entries of ``p`` along axis 0.

    ``p`` is flipped along its leading axis and then rolled forward by
    ``length`` positions, so the first ``length`` rows come back in reverse
    order at the front and the remaining (padding) rows end up at the back.

    :param p: array whose leading axis is the one being reversed.
    :param length: number of leading rows to reverse.
    :return: array with the same shape and dtype as ``p``.
    """
    # The roll already yields a fresh array of the right shape/dtype, so no
    # intermediate zero buffer is needed.
    return jnp.roll(p[::-1], length, axis=0)


def _reverse_padded(padded, lengths):
    """Apply :func:`_reverse_single` to every row of a padded batch.

    This helper is called by the data loader and by the guide, so it has to
    exist at module level.

    :param padded: batch of padded sequences, batch axis first.
    :param lengths: one length per batch entry (paired with ``padded`` along
        axis 0).
    :return: array shaped like ``padded`` with each sequence reversed in place
        and its padding kept at the end.
    """
    return jax.vmap(_reverse_single)(padded, lengths)

def load_data(split="train"):
    """Load one split of the JSB chorales dataset.

    :param split: name of the dataset split forwarded to ``load_dataset``.
        Defaults to ``"train"`` so that a bare ``load_data()`` call works
        (NOTE(review): default chosen to match the training call site; confirm
        against the dataset's split names).
    :return: tuple ``(seqs, seqs_reversed, lengths)`` where ``seqs_reversed``
        is ``seqs`` passed through ``_reverse_padded`` with ``lengths``.
    """
    _, fetch = load_dataset(JSB_CHORALES, split=split)
    lengths, seqs = fetch(0)
    return (seqs, _reverse_padded(seqs, lengths), lengths)

def emitter(x, params):
    """Parameterizes the bernoulli observation likelihood p(x_t | z_t).

    Two ReLU-activated linear layers (weights ``params["l1"]`` and
    ``params["l2"]``) followed by a linear read-out with ``params["l3"]``.
    No bias terms are used and the output is left unsquashed.

    :param x: input array whose last axis matches the rows of ``params["l1"]``.
    :param params: dict with weight matrices under ``"l1"``, ``"l2"``, ``"l3"``.
    :return: result of the final matrix product.
    """
    hidden = x
    for layer in ("l1", "l2"):
        hidden = nn.relu(jnp.matmul(hidden, params[layer]))
    return jnp.matmul(hidden, params["l3"])

def transition(x, params):
    """Parameterizes the gaussian latent transition probability p(z_t | z_{t-1})
    See section 5 in [1].

    The location is a gated blend of a purely linear map of ``x`` and a
    two-layer non-linear proposal; the scale is computed from that proposal.

    :param x: previous latent state(s); last axis must match the rows of the
        first-layer weights.
    :param params: nested dict with sub-dicts ``"gate"`` and ``"shared"``
        (each holding ``"l1"`` and ``"l2"``), and ``"mean"`` and ``"std"``
        (each holding ``"l1"``).
    :return: tuple ``(loc, std)``.

    **Reference:**
        1. Structured Inference Networks for Nonlinear State Space Models [arXiv:1609.09869]
           Rahul G. Krishnan, Uri Shalit and David Sontag (2016)
    """
    gate_w = params["gate"]
    shared_w = params["shared"]

    # Gate in (0, 1): how much of the non-linear proposal enters the mean.
    gate_hidden = nn.relu(jnp.matmul(x, gate_w["l1"]))
    gate = nn.sigmoid(jnp.matmul(gate_hidden, gate_w["l2"]))

    # Non-linear proposal, reused for both the mean and the scale.
    shared_hidden = nn.relu(jnp.matmul(x, shared_w["l1"]))
    proposal = jnp.matmul(shared_hidden, shared_w["l2"])

    linear_mean = jnp.matmul(x, params["mean"]["l1"])
    loc = (1 - gate) * linear_mean + gate * proposal

    # softplus keeps the scale strictly positive.
    std = nn.softplus(jnp.matmul(nn.relu(proposal), params["std"]["l1"]))
    return loc, std

def combiner(x, params):
    """Map hidden states to the parameters of a normal distribution.

    :param x: input array; last axis must match the rows of both weights.
    :param params: dict with weight matrices under ``"mean"`` and ``"std"``.
    :return: tuple ``(mean, std)`` where ``std`` is passed through softplus
        and is therefore positive.
    """
    loc = jnp.matmul(x, params["mean"])
    raw_scale = jnp.matmul(x, params["std"])
    return loc, nn.softplus(raw_scale)

def gru(xs, lengths, init_hidden, params):
    """RNN with GRU. Based on https://github.com/google/jax/pull/2298

    Scans a bias-free GRU cell over the leading axis of ``xs``. For every
    batch entry, once the step index reaches that entry's length the hidden
    state is replaced by zeros.

    :param xs: inputs indexed as ``(step, batch, feature)``; ``xs.shape[0]``
        steps are scanned and ``xs.shape[1]`` is used as the batch size.
    :param lengths: per-batch-entry number of valid steps.
    :param init_hidden: initial hidden state, broadcast to
        ``(xs.shape[1], init_hidden.shape[1])``.
    :param params: dict of weight matrices with keys ``"update_in"``,
        ``"update_weight"``, ``"reset_in"``, ``"reset_weight"``, ``"out_in"``
        and ``"out_weight"``.
    :return: the ``jax.lax.scan`` result ``(final_hidden, hidden_per_step)``.
    """

    def step(prev, step_inputs):
        t, x_t = step_inputs

        update = nn.sigmoid(
            jnp.matmul(x_t, params["update_in"])
            + jnp.dot(prev, params["update_weight"])
        )
        reset = nn.sigmoid(
            jnp.matmul(x_t, params["reset_in"])
            + jnp.dot(prev, params["reset_weight"])
        )
        candidate = jnp.tanh(
            jnp.matmul(x_t, params["out_in"])
            + jnp.dot(reset * prev, params["out_weight"])
        )
        blended = update * prev + (1 - update) * candidate

        # Entries whose sequence has already ended get a zero hidden state.
        still_running = (t < lengths)[:, None]
        nxt = jnp.where(still_running, blended, jnp.zeros_like(prev))
        return nxt, nxt

    num_steps, batch_size = xs.shape[0], xs.shape[1]
    carry0 = jnp.broadcast_to(init_hidden, (batch_size, init_hidden.shape[1]))
    return jax.lax.scan(step, carry0, (jnp.arange(num_steps), xs))

def _normal_init(*shape):
    """Build an initializer that draws from a normal with scale 0.1.

    :param shape: dimensions of the array to sample, given as separate
        positional integers.
    :return: function taking a PRNG key and returning a sample of ``shape``.
    """

    def init(rng_key):
        return dist.Normal(scale=0.1).sample(rng_key, shape)

    return init

def model(
    seqs,
    seqs_rev,
    lengths,
    *,
    subsample_size=77,
    latent_dim=32,
    emission_dim=100,
    transition_dim=200,
    data_dim=88,
    gru_dim=150,
    annealing_factor=1.0,
    predict=False,
):
    """Generative deep Markov model over a (sub-sampled) batch of sequences.

    A latent chain ``z`` is sampled per sequence, scored under the gated
    transition network, and decoded by the emitter into Bernoulli logits for
    the observed ``"tunes"`` site.

    :param seqs: padded sequences; axis 0 is indexed by the data plate and
        ``seqs.shape[1]`` is used as the maximum sequence length.
    :param seqs_rev: unused here; accepted so model and guide share a signature.
    :param lengths: number of valid steps for each sequence in ``seqs``.
    :param subsample_size: plate subsample size; when it equals
        ``seqs.shape[0]`` the full data is used without indexing.
    :param latent_dim: size of the latent state ``z_t``.
    :param emission_dim: hidden width of the emitter network.
    :param transition_dim: hidden width of the transition network.
    :param data_dim: size of the emitter output per step.
    :param gru_dim: unused here; accepted so model and guide share a signature.
    :param annealing_factor: scale applied to the log-density of ``"z_aux"``.
    :param predict: when true, ``"tunes"`` is sampled instead of observed.
    """
    max_seq_length = seqs.shape[1]

    emitter_params = {
        "l1": numpyro.param("emitter_l1", _normal_init(latent_dim, emission_dim)),
        "l2": numpyro.param("emitter_l2", _normal_init(emission_dim, emission_dim)),
        "l3": numpyro.param("emitter_l3", _normal_init(emission_dim, data_dim)),
    }

    trans_params = {
        "gate": {
            "l1": numpyro.param("gate_l1", _normal_init(latent_dim, transition_dim)),
            "l2": numpyro.param("gate_l2", _normal_init(transition_dim, latent_dim)),
        },
        "shared": {
            "l1": numpyro.param("shared_l1", _normal_init(latent_dim, transition_dim)),
            "l2": numpyro.param("shared_l2", _normal_init(transition_dim, latent_dim)),
        },
        "mean": {"l1": numpyro.param("mean_l1", _normal_init(latent_dim, latent_dim))},
        "std": {"l1": numpyro.param("std_l1", _normal_init(latent_dim, latent_dim))},
    }

    # Learnable initial latent state, shared by every sequence in the batch.
    z0 = numpyro.param(
        "z0", lambda rng_key: dist.Normal(0, 1.0).sample(rng_key, (latent_dim,))
    )
    z0 = jnp.broadcast_to(z0, (subsample_size, 1, latent_dim))
    with numpyro.plate(
        "data", seqs.shape[0], subsample_size=subsample_size, dim=-1
    ) as idx:
        if subsample_size == seqs.shape[0]:
            seqs_batch = seqs
            lengths_batch = lengths
        else:
            seqs_batch = seqs[idx]
            lengths_batch = lengths[idx]

        # masks[b, t] is True while step t lies inside sequence b, so padded
        # steps can be excluded from every log-density below.
        masks = jnp.repeat(
            jnp.expand_dims(jnp.arange(max_seq_length), axis=0), subsample_size, axis=0
        ) < jnp.expand_dims(lengths_batch, axis=-1)
        # NB: Mask is to avoid scoring 'z' using distribution at this point
        z = numpyro.sample(
            "z",
            dist.Normal(0.0, jnp.ones((max_seq_length, latent_dim)))
            .mask(False)
            .to_event(2),
        )

        # Shift the chain right by one step so position t sees z_{t-1}.
        z_shift = jnp.concatenate([z0, z[:, :-1, :]], axis=-2)
        z_loc, z_scale = transition(z_shift, params=trans_params)

        with numpyro.handlers.scale(scale=annealing_factor):
            # Actually score 'z'
            numpyro.sample(
                "z_aux",
                dist.Normal(z_loc, z_scale)
                .mask(masks[..., None])
                .to_event(2),
                obs=z,
            )

        # The emitter output is consumed as Bernoulli logits.
        emission_probs = emitter(z, params=emitter_params)
        if predict:
            tunes = None
        else:
            tunes = seqs_batch
        numpyro.sample(
            "tunes",
            dist.Bernoulli(logits=emission_probs)
            .mask(masks[..., None])
            .to_event(2),
            obs=tunes,
        )

def guide(
    seqs,
    seqs_rev,
    lengths,
    *,
    subsample_size=77,
    latent_dim=32,
    emission_dim=100,
    transition_dim=200,
    data_dim=88,
    gru_dim=150,
    annealing_factor=1.0,
    predict=False,
):
    """Variational guide: a GRU over the reversed sequences proposes ``z``.

    The reversed, padded sequences are run through :func:`gru`; the hidden
    states are flipped back into forward order and mapped by
    :func:`combiner` to the mean and scale of a normal over ``"z"``.

    :param seqs: padded sequences; only its shape is used here
        (``seqs.shape[0]`` for the plate, ``seqs.shape[1]`` as max length).
    :param seqs_rev: reversed padded sequences, batch axis first; transposed
        internally so that the step axis comes first for the GRU scan.
    :param lengths: number of valid steps for each sequence.
    :param subsample_size: plate subsample size; when it equals
        ``seqs.shape[0]`` the full data is used without indexing.
    :param latent_dim: size of the latent state ``z_t``.
    :param emission_dim: unused here; accepted to mirror the model signature.
    :param transition_dim: unused here; accepted to mirror the model signature.
    :param data_dim: input feature size of the GRU.
    :param gru_dim: hidden size of the GRU.
    :param annealing_factor: scale applied to the log-density of ``"z"``.
    :param predict: unused here; accepted to mirror the model signature.
    """
    max_seq_length = seqs.shape[1]
    seqs_rev = jnp.transpose(seqs_rev, axes=(1, 0, 2))

    combiner_params = {
        "mean": numpyro.param("combiner_mean", _normal_init(gru_dim, latent_dim)),
        "std": numpyro.param("combiner_std", _normal_init(gru_dim, latent_dim)),
    }

    gru_params = {
        "update_in": numpyro.param("update_in", _normal_init(data_dim, gru_dim)),
        "update_weight": numpyro.param("update_weight", _normal_init(gru_dim, gru_dim)),
        "reset_in": numpyro.param("reset_in", _normal_init(data_dim, gru_dim)),
        "reset_weight": numpyro.param("reset_weight", _normal_init(gru_dim, gru_dim)),
        "out_in": numpyro.param("out_in", _normal_init(data_dim, gru_dim)),
        "out_weight": numpyro.param("out_weight", _normal_init(gru_dim, gru_dim)),
    }

    with numpyro.plate(
        "data", seqs.shape[0], subsample_size=subsample_size, dim=-1
    ) as idx:
        if subsample_size == seqs.shape[0]:
            seqs_rev_batch = seqs_rev
            lengths_batch = lengths
        else:
            # The batch axis is axis 1 after the transpose above.
            seqs_rev_batch = seqs_rev[:, idx, :]
            lengths_batch = lengths[idx]

        # masks[b, t] is True while step t lies inside sequence b, so padded
        # steps do not contribute to the log-density of 'z'.
        masks = jnp.repeat(
            jnp.expand_dims(jnp.arange(max_seq_length), axis=0), subsample_size, axis=0
        ) < jnp.expand_dims(lengths_batch, axis=-1)

        h0 = numpyro.param(
            "h0",
            lambda rng_key: dist.Normal(0.0, 1).sample(rng_key, (1, gru_dim)),
        )
        _, hs = gru(seqs_rev_batch, lengths_batch, h0, gru_params)
        # Back to batch-first layout, then undo the time reversal.
        hs = _reverse_padded(jnp.transpose(hs, axes=(1, 0, 2)), lengths_batch)
        with numpyro.handlers.scale(scale=annealing_factor):
            numpyro.sample(
                "z",
                dist.Normal(*combiner(hs, combiner_params))
                .mask(masks[..., None])
                .to_event(2),
            )

def vis_tune(i, tunes, lengths, name="stein_dmm.pdf"):
    """Render tune ``i`` to the file ``name``.

    If ``music21`` is importable the tune is converted into a stream of
    chords and plotted with it; otherwise the raw on/off matrix is drawn as a
    greyscale image with matplotlib.

    :param i: index of the tune to draw.
    :param tunes: array indexed as ``tunes[i, step]``; each step is a vector
        whose positive entries mark the active notes.
    :param lengths: per-tune number of valid steps; only ``lengths[i]`` steps
        are drawn.
    :param name: output file path.
    """
    tune = tunes[i, : lengths[i]]
    try:
        from music21.chord import Chord
        from music21.pitch import Pitch
        from music21.stream import Stream

        # presumably MIDI note numbers of the 88 piano keys (21..108) - confirm
        pitch_numbers = np.arange(88) + 21

        score = Stream()
        for active in tune:
            sounding = [Pitch(number) for number in pitch_numbers[active > 0]]
            score.append(Chord(sounding))
        figure = score.plot(doneAction=None)
        figure.write(name)
    except ModuleNotFoundError:
        import matplotlib.pyplot as plt

        plt.imshow(tune.T, cmap="Greys")
        plt.ylabel("Pitch")
        plt.xlabel("Offset")
        plt.savefig(name)

def main(args):
    """Fit the DMM with SteinVI, then sample and plot one predicted tune.

    :param args: namespace providing ``rng_seed``, ``num_particles``,
        ``max_iter``, ``gru_dim`` and ``subsample_size``.
    """
    root_key = random.PRNGKey(seed=args.rng_seed)
    inf_key, pred_key = random.split(root_key, 2)

    stein = SteinVI(
        model,
        guide,
        Trace_ELBO(),
        RBFKernel(),
        num_particles=args.num_particles,
    )

    # Training: default split.
    train_seqs, train_rev_seqs, train_lengths = load_data()
    fit = stein.run(
        inf_key,
        args.max_iter,
        train_seqs,
        train_rev_seqs,
        train_lengths,
        gru_dim=args.gru_dim,
        subsample_size=args.subsample_size,
    )

    predictive = Predictive(
        model,
        params=fit.params,
        num_samples=1,
        batch_ndims=1,
    )

    # Prediction: run on the whole validation split with no sub-sampling.
    valid_seqs, valid_rev_seqs, valid_lengths = load_data("valid")
    samples = predictive(
        pred_key,
        valid_seqs,
        valid_rev_seqs,
        valid_lengths,
        subsample_size=valid_seqs.shape[0],
        predict=True,
    )
    pred_notes = samples["tunes"]

    vis_tune(0, pred_notes[0], valid_lengths)

if __name__ == "__main__":
parser = argparse.ArgumentParser()