# Source code for numpyro.contrib.einstein.steinvi

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
import functools
from functools import partial
from itertools import chain
import operator
from typing import Callable

import jax
import jax.numpy as jnp
import jax.random
from jax.tree_util import tree_map

from numpyro import handlers
from numpyro.contrib.einstein.kernels import SteinKernel
from numpyro.contrib.einstein.util import batch_ravel_pytree, get_parameter_transform
from numpyro.contrib.funsor import config_enumerate, enum
from numpyro.distributions import Distribution, Normal
from numpyro.distributions.constraints import real
from numpyro.distributions.transforms import IdentityTransform
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _guess_max_plate_nesting, transform_fn
from numpyro.util import fori_collect, ravel_pytree

# State threaded through SteinVI updates: the optimizer state and the PRNG key
# to be split on the next step.
SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"])
# Result of `SteinVI.run`: final parameters, final `SteinVIState`, and the values
# collected at every step. The typename matches the bound name so that the repr
# is accurate and instances can be pickled by qualified name.
SteinVIRunResult = namedtuple("SteinVIRunResult", ["params", "state", "losses"])

def _numel(shape):
    return functools.reduce(operator.mul, shape, 1)

class SteinVI:
    """
    Stein variational inference for stein mixtures.

    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide (recognition network).
    :param optim: an instance of :class:`~numpyro.optim._NumpyroOptim`.
    :param loss: ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
    :param kernel_fn: Function that produces a logarithm of the statistical kernel to use
        with Stein inference
    :param num_particles: number of particles for Stein inference.
        (More particles capture more of the posterior distribution)
    :param loss_temperature: scaling of loss factor
    :param repulsion_temperature: scaling of repulsive forces (Non-linear Stein)
    :param classic_guide_params_fn: predicate on names of parameters in guide which should be
        optimized classically without Stein (E.g. parameters for large normal networks or
        other transformation)
    :param enum: whether to apply automatic marginalization of discrete variables
    :param static_kwargs: Static keyword arguments for the model / guide, i.e. arguments
        that remain constant during fitting.
    """

    def __init__(
        self,
        model,
        guide,
        optim,
        loss,
        kernel_fn: SteinKernel,
        num_particles: int = 10,
        loss_temperature: float = 1.0,
        repulsion_temperature: float = 1.0,
        classic_guide_params_fn: Callable[[str], bool] = lambda name: False,
        enum=True,
        **static_kwargs,
    ):
        # `_inference_model` is the model actually passed to the loss; `init`
        # may replace it with an enumerated version of `self.model`.
        self._inference_model = model
        self.model = model
        self.guide = guide
        self.optim = optim
        self.loss = loss
        self.kernel_fn = kernel_fn
        self.static_kwargs = static_kwargs
        self.num_particles = num_particles
        self.loss_temperature = loss_temperature
        self.repulsion_temperature = repulsion_temperature
        self.enum = enum
        self.classic_guide_params_fn = classic_guide_params_fn
        # The attributes below are populated by `init`.
        self.guide_param_names = None
        self.constrain_fn = None
        self.uconstrain_fn = None
        self.particle_transform_fn = None
        self.particle_transforms = None

    def _apply_kernel(self, kernel, x, y, v):
        """Apply ``kernel(x, y)`` to ``v``: elementwise product for the "norm" and
        "vector" kernel modes, matrix product for any other mode."""
        if self.kernel_fn.mode in ("norm", "vector"):
            return kernel(x, y) * v
        return kernel(x, y) @ v

    def _kernel_grad(self, kernel, x, y):
        """Differentiate ``kernel(x, y)`` with respect to ``x``, in the form matching
        the kernel mode ("norm", "vector", or any other mode, which is indexed as a
        two-dimensional output)."""
        if self.kernel_fn.mode == "norm":
            return jax.grad(lambda x: kernel(x, y))(x)
        if self.kernel_fn.mode == "vector":
            # i-th entry: derivative of the i-th kernel output w.r.t. x[i].
            return jax.vmap(lambda i: jax.grad(lambda x: kernel(x, y)[i])(x)[i])(
                jnp.arange(x.shape[0])
            )
        # Entry `row`: sum over `col` of d kernel(x, y)[row, col] / d x[col].
        return jax.vmap(
            lambda row: jnp.sum(
                jax.vmap(
                    lambda col: jax.grad(lambda x: kernel(x, y)[row, col])(x)[col]
                )(jnp.arange(x.shape[0]))
            )
        )(jnp.arange(x.shape[0]))

    def _param_size(self, param):
        """Return ``param.size``, summed recursively over tuples and lists."""
        if isinstance(param, (tuple, list)):
            return sum(map(self._param_size, param))
        return param.size

    def _calc_particle_info(self, uparams, num_particles, start_index=0):
        """Map each parameter name to a ``(start, end)`` index pair.

        Keys are visited in sorted order and nested dicts are handled recursively.
        Each leaf occupies ``size // num_particles`` consecutive indices, starting
        at ``start_index``.

        :return: tuple of (dict with the same nesting as ``uparams`` holding the
            index pairs, index one past the last assigned entry).
        """
        res = {}
        end_index = start_index
        for k in sorted(uparams):
            if isinstance(uparams[k], dict):
                res_sub, end_index = self._calc_particle_info(
                    uparams[k], num_particles, start_index
                )
                res[k] = res_sub
            else:
                end_index = start_index + self._param_size(uparams[k]) // num_particles
                res[k] = (start_index, end_index)
            start_index = end_index
        return res, end_index

    def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace):
        """Build per-particle initial values for every ``param`` site of the guide.

        :param particle_seed: PRNG key; split once per param site.
        :param inner_guide: the guide; sites of an :class:`AutoGuide` whose name
            contains ``"<prefix>_loc"`` are initialized by adding noise.
        :param inner_guide_trace: trace of the guide.
        :return: dict mapping site name to ``(init_value, constraint)``, where
            ``init_value`` has a leading axis of length ``num_particles``.
        """

        def extract_info(site):
            nonlocal particle_seed
            name = site["name"]
            value = site["value"]
            constraint = site["kwargs"].get("constraint", real)
            transform = get_parameter_transform(site)
            if (
                isinstance(inner_guide, AutoGuide)
                and "_".join((inner_guide.prefix, "loc")) in name
            ):
                _, particle_seed = jax.random.split(particle_seed)
                unconstrained_shape = transform.inverse_shape(value.shape)
                # Add gaussian noise in unconstrained space, then map back.
                init_value = jnp.expand_dims(transform.inv(value), 0) + Normal(
                    scale=0.1
                ).sample(particle_seed, (self.num_particles, *unconstrained_shape))
                init_value = transform(init_value)
            else:
                site_fn = site["fn"]
                site_args = site["args"]
                _, particle_seed = jax.random.split(particle_seed)

                def _reinit(seed):
                    # Re-run the site's function under a fresh seed.
                    with handlers.seed(rng_seed=seed):
                        return site_fn(*site_args)

                init_value = jax.vmap(_reinit)(
                    jax.random.split(particle_seed, self.num_particles)
                )
            return init_value, constraint

        init_params = {
            name: extract_info(site)
            for name, site in inner_guide_trace.items()
            if site.get("type") == "param"
        }
        return init_params

    def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
        """Compute the loss and the gradients for all parameters.

        :param rng_key: PRNG key passed to the loss.
        :param unconstr_params: dict of unconstrained parameter values.
        :param args: positional arguments forwarded to the loss.
        :param kwargs: keyword arguments forwarded to the loss.
        :return: tuple ``(loss, grads)`` where ``grads`` is a dict covering both the
            classically optimized and the Stein parameters.
        """
        # 0. Separate model and guide parameters, since only guide parameters are updated using Stein
        classic_uparams = {
            p: v
            for p, v in unconstr_params.items()
            if p not in self.guide_param_names or self.classic_guide_params_fn(p)
        }
        stein_uparams = {
            p: v for p, v in unconstr_params.items() if p not in classic_uparams
        }
        # 1. Collect each guide parameter into monolithic particles that capture correlations
        # between parameter values across each individual particle
        stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
            stein_uparams, nbatch_dims=1
        )
        particle_info, _ = self._calc_particle_info(
            stein_uparams, stein_particles.shape[0]
        )

        # 2. Calculate loss and gradients for each parameter
        def scaled_loss(rng_key, classic_params, stein_params):
            params = {**classic_params, **stein_params}
            loss_val = self.loss.loss(
                rng_key,
                params,
                handlers.scale(self._inference_model, self.loss_temperature),
                self.guide,
                *args,
                **kwargs,
            )
            return -loss_val

        def kernel_particle_loss_fn(ps):
            return scaled_loss(
                rng_key,
                self.constrain_fn(classic_uparams),
                self.constrain_fn(unravel_pytree(ps)),
            )

        def particle_transform_fn(particle):
            params = unravel_pytree(particle)
            tparams = self.particle_transform_fn(params)
            tparticle, _ = ravel_pytree(tparams)
            return tparticle

        tstein_particles = jax.vmap(particle_transform_fn)(stein_particles)

        loss, particle_ljp_grads = jax.vmap(
            jax.value_and_grad(kernel_particle_loss_fn)
        )(tstein_particles)
        classic_param_grads = jax.vmap(
            lambda ps: jax.grad(
                lambda cps: scaled_loss(
                    rng_key,
                    self.constrain_fn(cps),
                    self.constrain_fn(unravel_pytree(ps)),
                )
            )(classic_uparams)
        )(stein_particles)
        # Average the classic-parameter gradients over particles.
        classic_param_grads = tree_map(partial(jnp.mean, axis=0), classic_param_grads)

        # 3. Calculate kernel on monolithic particle
        kernel = self.kernel_fn.compute(
            stein_particles, particle_info, kernel_particle_loss_fn
        )

        # 4. Calculate the attractive force and repulsive force on the monolithic particles
        attractive_force = jax.vmap(
            lambda y: jnp.sum(
                jax.vmap(
                    lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
                )(tstein_particles, particle_ljp_grads),
                axis=0,
            )
        )(tstein_particles)
        repulsive_force = jax.vmap(
            lambda y: jnp.sum(
                jax.vmap(
                    lambda x: self.repulsion_temperature
                    * self._kernel_grad(kernel, x, y)
                )(tstein_particles),
                axis=0,
            )
        )(tstein_particles)

        def single_particle_grad(particle, attr_forces, rep_forces):
            def _nontrivial_jac(var_name, var):
                # Identity transforms need no Jacobian correction.
                if isinstance(self.particle_transforms[var_name], IdentityTransform):
                    return None
                return jax.jacfwd(self.particle_transforms[var_name].inv)(var)

            def _update_force(attr_force, rep_force, jac):
                force = attr_force.reshape(-1) + rep_force.reshape(-1)
                if jac is not None:
                    # Flatten the Jacobian to 2-D (first half of its axes become
                    # the rows) before applying it to the flattened force.
                    force = force @ jac.reshape(
                        (_numel(jac.shape[: len(jac.shape) // 2]), -1)
                    )
                return force.reshape(attr_force.shape)

            reparam_jac = {
                name: tree_map(lambda var: _nontrivial_jac(name, var), variables)
                for name, variables in unravel_pytree(particle).items()
            }
            jac_params = tree_map(
                _update_force,
                unravel_pytree(attr_forces),
                unravel_pytree(rep_forces),
                reparam_jac,
            )
            jac_particle, _ = ravel_pytree(jac_params)
            return jac_particle

        particle_grads = (
            jax.vmap(single_particle_grad)(
                stein_particles, attractive_force, repulsive_force
            )
            / self.num_particles
        )

        # 5. Decompose the monolithic particle forces back to concrete parameter values
        stein_param_grads = unravel_pytree_batched(particle_grads)

        # 6. Return loss and gradients (based on parameter forces)
        res_grads = tree_map(lambda x: -x, {**classic_param_grads, **stein_param_grads})
        return -jnp.mean(loss), res_grads

    def init(self, rng_key, *args, **kwargs):
        """
        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: initial :data:`SteinVIState`
        :raises Exception: if the model has an unobserved discrete sample site that
            cannot be enumerated (no enumerate support, or ``enum`` is disabled).
        """
        rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(rng_key, 4)
        model_init = handlers.seed(self.model, model_seed)
        guide_init = handlers.seed(self.guide, guide_seed)
        guide_trace = handlers.trace(guide_init).get_trace(
            *args, **kwargs, **self.static_kwargs
        )
        model_trace = handlers.trace(model_init).get_trace(
            *args, **kwargs, **self.static_kwargs
        )
        rng_key, particle_seed = jax.random.split(rng_key)
        guide_init_params = self._find_init_params(
            particle_seed, self.guide, guide_trace
        )
        params = {}
        transforms = {}
        inv_transforms = {}
        particle_transforms = {}
        guide_param_names = set()
        should_enum = False
        for site in model_trace.values():
            if (
                "fn" in site
                and site["type"] == "sample"
                and not site["is_observed"]
                and isinstance(site["fn"], Distribution)
                and site["fn"].is_discrete
            ):
                if site["fn"].has_enumerate_support and self.enum:
                    should_enum = True
                else:
                    raise Exception(
                        "Cannot enumerate model with discrete variables without enumerate support"
                    )
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in chain(model_trace.values(), guide_trace.values()):
            if site["type"] == "param":
                transform = get_parameter_transform(site)
                inv_transforms[site["name"]] = transform
                transforms[site["name"]] = transform.inv
                particle_transforms[site["name"]] = site.get(
                    "particle_transform", IdentityTransform()
                )
                if site["name"] in guide_init_params:
                    pval, _ = guide_init_params[site["name"]]
                    if self.classic_guide_params_fn(site["name"]):
                        # Classically optimized params keep a single value:
                        # take the first particle's.
                        pval = tree_map(lambda x: x[0], pval)
                else:
                    pval = site["value"]
                params[site["name"]] = transform.inv(pval)
                if site["name"] in guide_trace:
                    guide_param_names.add(site["name"])

        if should_enum:
            mpn = _guess_max_plate_nesting(model_trace)
            self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
        self.guide_param_names = guide_param_names
        self.constrain_fn = partial(transform_fn, inv_transforms)
        self.uconstrain_fn = partial(transform_fn, transforms)
        self.particle_transforms = particle_transforms
        self.particle_transform_fn = partial(transform_fn, particle_transforms)
        stein_particles, _, _ = batch_ravel_pytree(
            {
                k: params[k]
                for k, site in guide_trace.items()
                if site["type"] == "param" and site["name"] in guide_init_params
            },
            nbatch_dims=1,
        )

        self.kernel_fn.init(kernel_seed, stein_particles.shape)
        return SteinVIState(self.optim.init(params), rng_key)

    def get_params(self, state: SteinVIState):
        """
        Gets values at `param` sites of the `model` and `guide`.

        :param state: current state of the optimizer.
        :return: the optimizer's parameters after applying ``constrain_fn``.
        """
        params = self.constrain_fn(self.optim.get_params(state.optim_state))
        return params

    def update(self, state: SteinVIState, *args, **kwargs):
        """
        Take a single step of Stein (possibly on a batch / minibatch of data),
        using the optimizer.

        :param state: current state of Stein.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: tuple of `(state, loss)`.
        """
        # Split in three so that the step key matches the one used by `evaluate`.
        rng_key, _, rng_key_step = jax.random.split(state.rng_key, num=3)
        params = self.optim.get_params(state.optim_state)
        optim_state = state.optim_state
        loss_val, grads = self._svgd_loss_and_grads(
            rng_key_step, params, *args, **kwargs, **self.static_kwargs
        )
        optim_state = self.optim.update(grads, optim_state)
        return SteinVIState(optim_state, rng_key), loss_val

    def run(
        self,
        rng_key,
        num_steps,
        *args,
        progress_bar=True,
        init_state=None,
        collect_fn=lambda val: val[1],  # TODO: refactor
        **kwargs,
    ):
        """
        Run `num_steps` updates and collect a value at every step.

        :param rng_key: random number generator seed, used to initialize the state
            when `init_state` is None.
        :param num_steps: number of update steps to take.
        :param args: arguments to the model / guide.
        :param progress_bar: forwarded to `fori_collect` as `progbar`.
        :param init_state: state to start from; if None, :meth:`init` is called.
        :param collect_fn: applied to the loop value `(state, loss, *args)` at each
            step to pick what is collected; the default selects the loss.
        :param kwargs: keyword arguments to the model / guide.
        :return: :data:`SteinVIRunResult` with the final parameters, the final state
            and the collected values.
        """

        def bodyfn(_i, info):
            body_state = info[0]
            return (*self.update(body_state, *info[2:], **kwargs), *info[2:])

        if init_state is None:
            state = self.init(rng_key, *args, **kwargs)
        else:
            state = init_state
        loss = self.evaluate(state, *args, **kwargs)
        auxiliaries, last_res = fori_collect(
            0,
            num_steps,
            lambda info: bodyfn(0, info),
            (state, loss, *args),
            progbar=progress_bar,
            transform=collect_fn,
            return_last_val=True,
        )
        state = last_res[0]
        return SteinVIRunResult(self.get_params(state), state, auxiliaries)

    def evaluate(self, state, *args, **kwargs):
        """
        Evaluate the loss at the current parameter values without updating them
        (possibly on a batch / minibatch of data).

        :param state: current state of Stein.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide.
        :return: evaluate loss given the current parameter values (held within
            `state.optim_state`).
        """
        # we split to have the same seed as `update_fn` given a state
        _, _, rng_key_eval = jax.random.split(state.rng_key, num=3)
        params = self.optim.get_params(state.optim_state)
        loss_val, _ = self._svgd_loss_and_grads(
            rng_key_eval, params, *args, **kwargs, **self.static_kwargs
        )
        return loss_val