
Solving ordinary differential equations (ODEs) for multiple initial conditions

Ordinary differential equations (ODEs) find applications in many fields, including epidemiology, physics, chemistry, and banking. Often, an ODE system must be integrated for multiple initial conditions while the parameters are kept constant. Additionally, typical datasets contain missing values, span different durations, and have irregularly spaced data points. This tutorial expands upon the previous Predator-Prey Model tutorial to address these challenges. We will:

  1. Define ODEs and the probabilistic model.

  2. Generate synthetic datasets with imperfections.

  3. Perform parameter estimation using the MCMC algorithm.

[ ]:
#!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
[2]:
import functools

import matplotlib.pyplot as plt

import jax
from jax.experimental.ode import odeint
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample

# Numerical instabilities may arise during ODE solving,
# so one sometimes has to adjust the solver settings,
# switch solvers, or change the numeric precision, as we do here.
numpyro.enable_x64(True)

Model

Let’s start by defining our differential equations, dz_dt, and the probabilistic model, model. The differential equations remain the same as in the Lotka-Volterra tutorial, but the model introduces notable changes to accommodate multiple initial conditions simultaneously. We begin by sampling the initial conditions, z_init, and the parameters, theta. The ODE system is then solved in vectorized form. Vectorization is achieved with jax.vmap, using functools.partial to pass keyword arguments. Next, we sample sigma to represent the measurement error. Finally, we sample the measured populations. Since the observed y may contain missing values, we mask non-finite values.
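Before looking at the model, here is a minimal sketch of the vectorization pattern with a toy ODE (the names toy_dz_dt, solve_ode, etc. are ours, not part of the model below): jax.vmap maps odeint over the leading axis of the arguments whose in_axes entry is 0, while the arguments marked None (the ODE function and theta) are shared across datasets.

[ ]:
# Minimal sketch: vectorized solving of a toy ODE, dz/dt = -theta * z
import functools

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def toy_dz_dt(z, t, theta):
    return -theta * z


solve_ode = jax.vmap(
    functools.partial(odeint, rtol=1e-6, atol=1e-5),
    in_axes=(None, 0, 0, None),  # batch over initial states and time grids
)
z0 = jnp.ones((3, 2))  # 3 datasets, 2 state variables
t_grids = jnp.tile(jnp.linspace(0.0, 1.0, 5), (3, 1))  # 3 time grids
print(solve_ode(toy_dz_dt, z0, t_grids, 0.5).shape)  # (3, 5, 2)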

[3]:
def dz_dt(z, t, theta):
    """
    Lotka–Volterra equations. The real positive parameters `alpha`, `beta`, `gamma`, `delta`
    describe the interaction of two species.
    """
    u, v = z
    alpha, beta, gamma, delta = theta

    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])


def model(ts, y_init, y=None):
    """
    :param numpy.ndarray ts: measurement times
    :param numpy.ndarray y_init: measured initial conditions
    :param numpy.ndarray y: measured populations
    """
    # initial population
    z_init = numpyro.sample(
        "z_init", dist.LogNormal(jnp.log(y_init), jnp.ones_like(y_init))
    )

    # parameters alpha, beta, gamma, delta of dz_dt
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(
            low=0.0,
            loc=jnp.array([1.0, 0.05, 1.0, 0.05]),
            scale=jnp.array([0.2, 0.01, 0.2, 0.01]),
        ),
    )

    # helpers to solve ODEs in a vectorized form
    odeint_with_kwargs = functools.partial(odeint, rtol=1e-6, atol=1e-5, mxstep=1000)
    vect_solve_ode = jax.vmap(
        odeint_with_kwargs,
        in_axes=(None, 0, 0, None),
    )

    # integrate dz/dt
    zs = vect_solve_ode(dz_dt, z_init, ts, theta)
    # measurement errors
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    # measured populations
    if y is not None:
        # mask missing observations in the observed y
        mask = jnp.isfinite(jnp.log(y))
        numpyro.sample("y", dist.LogNormal(jnp.log(zs), sigma).mask(mask), obs=y)
    else:
        numpyro.sample("y", dist.LogNormal(jnp.log(zs), sigma))

Dataset

For the purpose of this tutorial, we will utilize synthetic datasets generated by sampling from the previously defined model. To emulate the non-ideal properties of real-life datasets, we will introduce missing values, varying durations, and irregular spacing between timepoints. It’s important to note that JAX works with vectorized and compiled calculations, requiring datasets to have the same length. In our case, although we have different spacing, we maintain the same number of points. If that is not the case, one can use jnp.pad to extend all datasets to a common length with dummy fill values, which can later be masked.
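For instance, padding two hypothetical series of unequal length to a common length could look as follows (a minimal sketch; the NaN fill values would later be masked):

[ ]:
# pad unequal-length arrays to a common length with NaN fill values
series = [jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0, 7.0])]
n = max(len(s) for s in series)
padded = jnp.array(
    [jnp.pad(s, (0, n - len(s)), constant_values=jnp.nan) for s in series]
)
print(padded)  # [[1. 2. 3. nan], [4. 5. 6. 7.]]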

First, let’s establish simulation settings. The datasets will exhibit varying timespans between t_min and t_max, with the number of points constrained between n_points_min and n_points_max. Additionally, we will introduce missing values with a probability of p_missing.

[4]:
n_datasets = 3  # int n_datasets: number of datasets to generate
t_min = 100  # int t_min: minimal allowed duration of a generated time array
t_max = 200  # int t_max: maximal allowed duration of a generated time array
n_points_min = 80  # int n_points_min: minimal allowed number of points in a data set
n_points_max = 120  # int n_points_max: maximal allowed number of points in a data set
y0_min = 2.0  # float y0_min: minimal allowed value for initial conditions
y0_max = 10.0  # float y0_max: maximal allowed value for initial conditions
p_missing = 0.1  # float p_missing: probability of having missing values

Generate an array with initial conditions

[5]:
# generate an array with initial conditions
z_inits = jnp.array(
    [jnp.linspace(y0_min, y0_max, n_datasets), jnp.linspace(y0_max, y0_min, n_datasets)]
).T

print(f"Initial conditons are: \n {z_inits}")
Initial conditons are:
 [[ 2. 10.]
 [ 6.  6.]
 [10.  2.]]

Next, let’s create a time matrix ts to store the time points for each individual dataset. We will generate random integers in rand_duration between t_min and t_max to represent varying durations. Similarly, rand_n_points will correspond to different spacings in each dataset. Since JAX requires a matrix with a constant shape, we will use jnp.pad to pad individual observations to the common length of the longest array.

[6]:
# generate an array of random integers between t_min and t_max,
# representing the time duration of each dataset
rand_duration = jax.random.randint(
    PRNGKey(1), shape=(n_datasets,), minval=t_min, maxval=t_max
)

# generate array with random integers between n_points_min and n_points_max,
# representing number of time points per dataset
rand_n_points = jax.random.randint(
    PRNGKey(1), shape=(n_datasets,), minval=n_points_min, maxval=n_points_max
)

# Note that the arrays have different lengths and are stored in a list
time_arrays = [
    jnp.linspace(0, j, num=rand_n_points[i]).astype(float)
    for i, j in enumerate(rand_duration)
]
longest = jnp.max(jnp.array([len(i) for i in time_arrays]))

# Make a time matrix
ts = jnp.array(
    [
        jnp.pad(arr, pad_width=(0, longest - len(arr)), constant_values=jnp.nan)
        for arr in time_arrays
    ]
)

print(f"The shape of the time matrix is {ts.shape}")
print(f"First values are \n {ts[:, :10]}")
print(f"Last values are \n {ts[:, -10:]}")
The shape of the time matrix is (3, 108)
First values are
 [[ 0.          1.00934579  2.01869159  3.02803738  4.03738318  5.04672897
   6.05607477  7.06542056  8.07476636  9.08411215]
 [ 0.          1.23863636  2.47727273  3.71590909  4.95454545  6.19318182
   7.43181818  8.67045455  9.90909091 11.14772727]
 [ 0.          1.21212121  2.42424242  3.63636364  4.84848485  6.06060606
   7.27272727  8.48484848  9.6969697  10.90909091]]
Last values are
 [[ 98.91588785  99.92523364 100.93457944 101.94392523 102.95327103
  103.96261682 104.97196262 105.98130841 106.99065421 108.        ]
 [         nan          nan          nan          nan          nan
           nan          nan          nan          nan          nan]
 [118.78787879 120.                  nan          nan          nan
           nan          nan          nan          nan          nan]]

We’ll use NumPyro’s Predictive to draw a single sample, which will serve as our synthetic dataset. We’ll then apply a NaN mask to the data to simulate missing values. For simplicity, we’ll ensure that the initial values are non-missing; in real datasets where this does not hold, various imputation methods can be applied.

[7]:
# take a single sample that will be our synthetic data
sample = Predictive(model, num_samples=1)(PRNGKey(100), ts, z_inits)
data = sample["y"][0]

# create a mask that will add missing values to the data
missing_obs_mask = jax.random.choice(
    PRNGKey(1),
    jnp.array([True, False]),
    shape=data.shape,
    p=jnp.array([p_missing, 1 - p_missing]),
)
# make sure that initial values are not missing
missing_obs_mask = missing_obs_mask.at[:, 0, :].set(False)

# data with missing values
data = data.at[missing_obs_mask].set(jnp.nan)

Finally, for compatibility with NUTS later on, we need to fill the NaN values in the time matrix ts with dummy values. The odeint function from JAX requires time values to be in increasing order, so we fill them with values greater than t_max. Importantly, these dummy values do not affect the MCMC estimation, as the corresponding values in the data are missing and are thereby ignored during posterior estimation.

[8]:
# replace trailing NaNs in a time array with increasing dummy values above t_max
# (assumes the NaNs are contiguous at the end, as produced by the padding above)
def fill_nans(ts):
    n_nan = jnp.sum(jnp.isnan(ts))
    if n_nan > 0:
        loc_first_nan = jnp.where(jnp.isnan(ts))[0][0]
        ts_filled_nans = ts.at[loc_first_nan:].set(
            jnp.linspace(t_max, t_max + 20, n_nan)
        )
        return ts_filled_nans
    else:
        return ts


ts_filled_nans = jnp.array([fill_nans(t) for t in ts])

Let’s briefly summarize our synthetic dataset:

[9]:
print(f"The dataset has the shape {data.shape}, (n_datasets, n_points, n_observables)")
print(f"The time matrix has the shape {ts.shape}, (n_datasets, n_timepoints)")
print(f"The time matrix has different spacing between timepoints: \n {ts[:,:5]}")
print(f"The final timepoints are: {jnp.nanmax(ts,1)} years.")
print(
    f"The dataset has {jnp.sum(jnp.isnan(data))/jnp.size(data):.0%} missing observations"
)
print(f"True params mean: {sample['theta'][0]}")
The dataset has the shape (3, 108, 2), (n_datasets, n_points, n_observables)
The time matrix has the shape (3, 108), (n_datasets, n_timepoints)
The time matrix has different spacing between timepoints:
 [[0.         1.00934579 2.01869159 3.02803738 4.03738318]
 [0.         1.23863636 2.47727273 3.71590909 4.95454545]
 [0.         1.21212121 2.42424242 3.63636364 4.84848485]]
The final timepoints are: [108. 109. 120.] years.
The dataset has 19% missing observations
True params mean: [0.78770691 0.05049109 0.89073622 0.05296055]

Let’s visualize the dataset, with solid lines to guide the eye. You’ll notice line breaks where NaN values occur.

[10]:
# Plotting
fig, axs = plt.subplots(2, n_datasets, figsize=(15, 4))

for i in range(n_datasets):
    axs[0, i].plot(
        ts[i, :], data[i, :, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67
    )
    axs[0, i].plot(ts[i, :], data[i, :, 0], alpha=0.67)  # line to guide the eye
    axs[0, i].set_xlabel("Time, year")
    axs[0, i].set_ylabel("Population")
    axs[0, i].set_xlim([-5, jnp.nanmax(ts)])

    axs[1, i].plot(ts[i, :], data[i, :, 1], "bx", label="true lynx")
    axs[1, i].plot(ts[i, :], data[i, :, 1])  # line to guide the eye
    axs[1, i].set_xlabel("Time, year")
    axs[1, i].set_ylabel("Population")
    axs[1, i].set_xlim([-5, jnp.nanmax(ts)])

fig.tight_layout()
../_images/tutorials_lotka_volterra_multiple_21_0.png

Perform MCMC

To balance accuracy and speed, one has to tune the parameters of both the MCMC sampler and the ODE solver to suit the specific problem.
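For reference, a hypothetical, stricter configuration (not the one used in this run) might raise target_accept_prob and max_tree_depth on the NUTS side, complemented by tighter rtol/atol and a larger mxstep in the odeint call inside model:

[ ]:
# hypothetical stricter sampler settings (not used below)
strict_kernel = NUTS(
    model,
    dense_mass=True,
    target_accept_prob=0.9,  # higher target acceptance -> smaller steps
    max_tree_depth=12,  # allow longer trajectories
)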

[11]:
y_init = data[:, 0, :]

mcmc = MCMC(
    NUTS(
        model,
        dense_mass=True,
        init_strategy=init_to_sample(),
        max_tree_depth=10,
    ),
    num_warmup=1000,
    num_samples=1000,
    num_chains=1,
    progress_bar=True,
)

mcmc.run(PRNGKey(1031410), ts=ts_filled_nans, y_init=y_init, y=data)
mcmc.print_summary()

print(f"True params mean: {sample['theta'][0]}")
print(f"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}")
sample: 100%|██████████| 2000/2000 [09:09<00:00,  3.64it/s, 31 steps of size 1.23e-01. acc. prob=0.94]

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
   sigma[0]      0.29      0.01      0.29      0.27      0.31   1064.77      1.00
   sigma[1]      0.51      0.02      0.51      0.47      0.55   1593.43      1.00
   theta[0]      0.77      0.02      0.77      0.74      0.79    760.41      1.00
   theta[1]      0.05      0.00      0.05      0.05      0.05    888.74      1.00
   theta[2]      0.91      0.02      0.91      0.87      0.94    842.09      1.00
   theta[3]      0.06      0.00      0.06      0.05      0.06    858.84      1.00
z_init[0,0]      1.51      0.05      1.51      1.43      1.60    782.07      1.00
z_init[0,1]      9.11      0.55      9.09      8.06      9.88   1072.88      1.00
z_init[1,0]      3.83      0.14      3.83      3.63      4.07    986.01      1.00
z_init[1,1]      8.54      0.57      8.54      7.66      9.52    945.91      1.00
z_init[2,0]      3.87      0.15      3.86      3.64      4.11   1210.24      1.00
z_init[2,1]      3.70      0.19      3.69      3.39      4.02   1342.93      1.00

Number of divergences: 0
True params mean: [0.78770691 0.05049109 0.89073622 0.05296055]
Estimated params mean: [0.7684689  0.05000161 0.90749349 0.05559383]

Run predictions

[12]:
# predict
ts_pred = jnp.tile(jnp.linspace(0, 200, 1000), (n_datasets, 1))
pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(1041140), ts_pred, y_init)["y"]
mu = jnp.mean(pop_pred, 0)
pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)


print(f"True params mean: {sample['theta'][0]}")
print(f"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}")
True params mean: [0.78770691 0.05049109 0.89073622 0.05296055]
Estimated params mean: [0.7684689  0.05000161 0.90749349 0.05559383]

Plot the observed points and predicted mean with prediction intervals.

[13]:
# Plotting
fig, axs = plt.subplots(2, n_datasets, figsize=(15, 4))

for i in range(n_datasets):
    axs[0, i].plot(
        ts_pred[i, :], mu[i, :, 0], "k-.", label="pred hare", lw=1, alpha=0.67
    )
    axs[0, i].plot(
        ts[i, :], data[i, :, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67
    )
    axs[0, i].fill_between(
        ts_pred[i, :], pi[0, i, :, 0], pi[1, i, :, 0], color="k", alpha=0.2
    )
    axs[0, i].set_xlabel("Time, year")
    axs[0, i].set_ylabel("Population")
    axs[0, i].set_xlim([-5, jnp.nanmax(ts)])

    axs[1, i].plot(ts_pred[i, :], mu[i, :, 1], "b--", label="pred lynx")
    axs[1, i].plot(ts[i, :], data[i, :, 1], "bx", label="true lynx")
    axs[1, i].fill_between(
        ts_pred[i, :], pi[0, i, :, 1], pi[1, i, :, 1], color="b", alpha=0.2
    )
    axs[1, i].set_xlabel("Time, year")
    axs[1, i].set_ylabel("Population")
    axs[1, i].set_xlim([-5, jnp.nanmax(ts)])


fig.tight_layout()
../_images/tutorials_lotka_volterra_multiple_28_0.png