Source code for numpyro.svi

import os

from jax import random, value_and_grad

from numpyro.handlers import replay, seed, substitute, trace
from numpyro.hmc_util import log_density


def _seed(model, guide, rng):
    model_seed, guide_seed = random.split(rng, 2)
    model_init = seed(model, model_seed)
    guide_init = seed(guide, guide_seed)
    return model_init, guide_init
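# Note (illustration only, not part of the original module): JAX handles
# randomness through explicit PRNG keys, so `_seed` splits the incoming key
# into two independent keys, one for the model and one for the guide, e.g.:
#
#     rng = random.PRNGKey(0)
#     model_seed, guide_seed = random.split(rng, 2)  # two independent keys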


def svi(model, guide, loss, optim_init, optim_update, get_params, **kwargs):
    """
    Stochastic Variational Inference given an ELBo loss objective.

    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param loss: ELBo loss, i.e. negative Evidence Lower Bound, to minimize.
    :param optim_init: initialization function returned by a JAX optimizer.
        see: :mod:`jax.experimental.optimizers`.
    :param optim_update: update function for the optimizer.
    :param get_params: function to get current parameter values given the
        optimizer state.
    :param `**kwargs`: static arguments for the model / guide, i.e. arguments
        that remain constant during fitting.
    :return: tuple of `(init_fn, update_fn, evaluate)`.
    """
    def init_fn(rng, model_args=(), guide_args=(), params=None):
        """
        :param jax.random.PRNGKey rng: random number generator seed.
        :param tuple model_args: arguments to the model (these can possibly
            vary during the course of fitting).
        :param tuple guide_args: arguments to the guide (these can possibly
            vary during the course of fitting).
        :param dict params: initial parameter values to condition on. This can
            be useful for warm-starting optimization from known parameter
            values instead of the sites' default initial values.
        :return: initial optimizer state.
        """
        assert isinstance(model_args, tuple)
        assert isinstance(guide_args, tuple)
        model_init, guide_init = _seed(model, guide, rng)
        if params is None:
            params = {}
        else:
            model_init = substitute(model_init, params)
            guide_init = substitute(guide_init, params)
        guide_trace = trace(guide_init).get_trace(*guide_args, **kwargs)
        model_trace = trace(model_init).get_trace(*model_args, **kwargs)
        # Collect initial values for every `param` site found in either trace.
        for site in list(guide_trace.values()) + list(model_trace.values()):
            if site['type'] == 'param':
                params[site['name']] = site['value']
        return optim_init(params)

    def update_fn(i, opt_state, rng, model_args=(), guide_args=()):
        """
        Take a single step of SVI (possibly on a batch / minibatch of data),
        using the optimizer.

        :param int i: represents the i'th iteration over the epoch, passed as
            an argument to the optimizer's update function.
        :param opt_state: current optimizer state.
        :param jax.random.PRNGKey rng: random number generator seed.
        :param tuple model_args: dynamic arguments to the model.
        :param tuple guide_args: dynamic arguments to the guide.
        :return: tuple of `(loss_val, opt_state, rng)`.
        """
        model_init, guide_init = _seed(model, guide, rng)
        params = get_params(opt_state)
        loss_val, grads = value_and_grad(loss)(params, model_init, guide_init,
                                               model_args, guide_args, kwargs)
        opt_state = optim_update(i, grads, opt_state)
        rng, = random.split(rng, 1)
        return loss_val, opt_state, rng

    def evaluate(opt_state, rng, model_args=(), guide_args=()):
        """
        Evaluate the ELBo loss given the current parameter values (held within
        `opt_state`), without taking a gradient step.

        :param opt_state: current optimizer state.
        :param jax.random.PRNGKey rng: random number generator seed.
        :param tuple model_args: arguments to the model (these can possibly
            vary during the course of fitting).
        :param tuple guide_args: arguments to the guide (these can possibly
            vary during the course of fitting).
        :return: ELBo loss given the current parameter values.
        """
        model_init, guide_init = _seed(model, guide, rng)
        params = get_params(opt_state)
        return loss(params, model_init, guide_init, model_args, guide_args, kwargs)

    # Make local functions visible from the global scope once
    # `svi` is called for sphinx doc generation.
    if 'SPHINX_BUILD' in os.environ:
        svi.init_fn = init_fn
        svi.update_fn = update_fn
        svi.evaluate = evaluate

    return init_fn, update_fn, evaluate
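# Usage sketch (illustrative, not part of the original module): wiring `svi`
# and `elbo` to a JAX optimizer for a toy Normal-Normal model. The model,
# guide, and data below are assumptions for illustration, and the primitive
# import path is too (`sample` / `param` moved between modules across early
# numpyro versions).
#
#     import jax.numpy as np
#     from jax import random
#     from jax.experimental import optimizers
#
#     import numpyro.distributions as dist
#     from numpyro.handlers import param, sample
#     from numpyro.svi import elbo, svi
#
#     def model(data):
#         # p(mu) = Normal(0, 1); p(x | mu) = Normal(mu, 1)
#         mu = sample('mu', dist.Normal(0., 1.))
#         sample('obs', dist.Normal(mu, 1.), obs=data)
#
#     def guide(data):
#         # q(mu) = Normal(loc, 1), with a learnable location parameter
#         loc = param('loc', 0.)
#         sample('mu', dist.Normal(loc, 1.))
#
#     opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
#     init_fn, update_fn, evaluate = svi(model, guide, elbo,
#                                        opt_init, opt_update, get_params)
#
#     rng = random.PRNGKey(0)
#     data = np.array([0.3, 0.5, 0.2])
#     opt_state = init_fn(rng, model_args=(data,), guide_args=(data,))
#     for i in range(1000):
#         loss_val, opt_state, rng = update_fn(i, opt_state, rng,
#                                              model_args=(data,),
#                                              guide_args=(data,))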
def elbo(param_map, model, guide, model_args, guide_args, kwargs):
    """
    This is the most basic implementation of the Evidence Lower Bound, which is
    the fundamental objective in Variational Inference. This implementation has
    various limitations (for example it only supports random variables with
    reparameterized samplers) but can be used as a template to build more
    sophisticated loss objectives.

    For more details, refer to http://pyro.ai/examples/svi_part_i.html.

    :param dict param_map: dictionary of current parameter values keyed by site
        name.
    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param tuple model_args: arguments to the model (these can possibly vary
        during the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary
        during the course of fitting).
    :param dict kwargs: static keyword arguments to the model / guide.
    :return: negative of the Evidence Lower Bound (ELBo) to be minimized.
    """
    guide_log_density, guide_trace = log_density(guide, guide_args, kwargs, param_map)
    model_log_density, _ = log_density(replay(model, guide_trace), model_args, kwargs, param_map)
    # log p(x, z) - log q(z)
    elbo = model_log_density - guide_log_density
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo
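# Note on the estimator above (illustrative, not part of the original module):
# with a single latent sample z~ drawn from the guide q, the function computes
#
#     ELBo = E_q[ log p(x, z) - log q(z) ]  ~=  log p(x, z~) - log q(z~)
#
# `replay` ensures the model's log density is evaluated at the same latent
# sample z~ the guide drew, which is what makes this a valid one-sample
# Monte Carlo estimate of the ELBo (for reparameterized samplers).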