{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Example: Predator-Prey Model\n\nThis example replicates the great case study [1], which leverages the Lotka-Volterra\nequations [2] to describe the dynamics of Canada lynx (predator) and snowshoe hare\n(prey) populations. We will use the dataset obtained from [3] and run MCMC to get\ninferences about the parameters of the differential equations governing the dynamics.\n\n**References:**\n\n1. Bob Carpenter (2018), [\"Predator-Prey Population Dynamics: the Lotka-Volterra model in Stan\"](https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html).\n2. https://en.wikipedia.org/wiki/Lotka-Volterra_equations\n3. http://people.whitman.edu/~hundledr/courses/M250F03/M250.html\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import argparse\n",
"import os\n",
"\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from jax.experimental.ode import odeint\n",
"import jax.numpy as jnp\n",
"from jax.random import PRNGKey\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.examples.datasets import LYNXHARE, load_dataset\n",
"from numpyro.infer import MCMC, NUTS, Predictive\n",
"\n",
"matplotlib.use(\"Agg\")  # noqa: E402\n",
"\n",
"\n",
"def dz_dt(z, t, theta):\n",
"    \"\"\"\n",
"    Lotka\u2013Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`\n",
"    describe the interaction of the two species.\n",
"\n",
"    :param z: current state; `z[0]` is read as `u` and `z[1]` as `v`\n",
"    :param t: current time (not used by the equations)\n",
"    :param theta: parameters, with `alpha, beta, gamma, delta` along the last axis\n",
"    :return: the stacked derivatives `[du/dt, dv/dt]`\n",
"    \"\"\"\n",
"    u = z[0]\n",
"    v = z[1]\n",
"    alpha, beta, gamma, delta = (\n",
"        theta[..., 0],\n",
"        theta[..., 1],\n",
"        theta[..., 2],\n",
"        theta[..., 3],\n",
"    )\n",
"    du_dt = (alpha - beta * v) * u\n",
"    dv_dt = (-gamma + delta * u) * v\n",
"    return jnp.stack([du_dt, dv_dt])\n",
"\n",
"\n",
"def model(N, y=None):\n",
"    \"\"\"\n",
"    Lotka-Volterra dynamics observed through log-normal measurement noise.\n",
"\n",
"    :param int N: number of measurement times\n",
"    :param numpy.ndarray y: measured populations with shape (N, 2)\n",
"    \"\"\"\n",
"    # initial population\n",
"    z_init = numpyro.sample(\"z_init\", dist.LogNormal(jnp.log(10), 1).expand([2]))\n",
"    # measurement times\n",
"    ts = jnp.arange(float(N))\n",
"    # parameters alpha, beta, gamma, delta of dz_dt\n",
"    theta = numpyro.sample(\n",
"        \"theta\",\n",
"        dist.TruncatedNormal(\n",
"            low=0.0,\n",
"            loc=jnp.array([1.0, 0.05, 1.0, 0.05]),\n",
"            scale=jnp.array([0.5, 0.05, 0.5, 0.05]),\n",
"        ),\n",
"    )\n",
"    # integrate dz/dt, the result will have shape N x 2\n",
"    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-6, atol=1e-5, mxstep=1000)\n",
"    # measurement errors\n",
"    sigma = numpyro.sample(\"sigma\", dist.LogNormal(-1, 1).expand([2]))\n",
"    # measured populations\n",
"    numpyro.sample(\"y\", dist.LogNormal(jnp.log(z), sigma), obs=y)\n",
"\n",
"\n",
"def main(args):\n",
"    \"\"\"\n",
"    Run NUTS on the lynx-hare dataset, print the posterior summary and save a\n",
"    posterior predictive plot to `ode_plot.pdf`.\n",
"\n",
"    :param args: parsed arguments providing `num_warmup`, `num_samples` and `num_chains`\n",
"    \"\"\"\n",
"    _, fetch = load_dataset(LYNXHARE, shuffle=False)\n",
"    year, data = fetch()  # data is in hare -> lynx order\n",
"\n",
"    # use dense_mass for better mixing rate\n",
"    mcmc = MCMC(\n",
"        NUTS(model, dense_mass=True),\n",
"        num_warmup=args.num_warmup,\n",
"        num_samples=args.num_samples,\n",
"        num_chains=args.num_chains,\n",
"        # the progress bar is disabled only when NUMPYRO_SPHINXBUILD is set\n",
"        progress_bar=\"NUMPYRO_SPHINXBUILD\" not in os.environ,\n",
"    )\n",
"    mcmc.run(PRNGKey(1), N=data.shape[0], y=data)\n",
"    mcmc.print_summary()\n",
"\n",
"    # predict populations\n",
"    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])[\"y\"]\n",
"    mu = jnp.mean(pop_pred, 0)\n",
"    # 10th and 90th percentiles bound the 80% interval shown in the plot\n",
"    pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)\n",
"    plt.figure(figsize=(8, 6), constrained_layout=True)\n",
"    plt.plot(year, data[:, 0], \"ko\", mfc=\"none\", ms=4, label=\"true hare\", alpha=0.67)\n",
"    plt.plot(year, data[:, 1], \"bx\", label=\"true lynx\")\n",
"    plt.plot(year, mu[:, 0], \"k-.\", label=\"pred hare\", lw=1, alpha=0.67)\n",
"    plt.plot(year, mu[:, 1], \"b--\", label=\"pred lynx\")\n",
"    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color=\"k\", alpha=0.2)\n",
"    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color=\"b\", alpha=0.3)\n",
"    plt.gca().set(ylim=(0, 160), xlabel=\"year\", ylabel=\"population (in thousands)\")\n",
"    plt.title(\"Posterior predictive (80% CI) with predator-prey pattern.\")\n",
"    plt.legend()\n",
"\n",
"    plt.savefig(\"ode_plot.pdf\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    assert numpyro.__version__.startswith(\"0.11.0\")\n",
"    parser = argparse.ArgumentParser(description=\"Predator-Prey Model\")\n",
"    parser.add_argument(\"-n\", \"--num-samples\", nargs=\"?\", default=1000, type=int)\n",
"    parser.add_argument(\"--num-warmup\", nargs=\"?\", default=1000, type=int)\n",
"    parser.add_argument(\"--num-chains\", nargs=\"?\", default=1, type=int)\n",
"    parser.add_argument(\"--device\", default=\"cpu\", type=str, help='use \"cpu\" or \"gpu\".')\n",
"    # parse_known_args() ignores arguments this parser does not define. A Jupyter\n",
"    # kernel puts its own `-f <connection file>` in sys.argv, which would make\n",
"    # parse_args() abort the cell with SystemExit.\n",
"    args, _ = parser.parse_known_args()\n",
"\n",
"    numpyro.set_platform(args.device)\n",
"    numpyro.set_host_device_count(args.num_chains)\n",
"\n",
"    main(args)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 0
}