# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
import functools
from functools import partial
from itertools import chain
import operator
from typing import Callable

import jax
import jax.numpy as jnp
import jax.random
from jax.tree_util import tree_map

from numpyro import handlers
from numpyro.contrib.einstein.kernels import SteinKernel
from numpyro.contrib.einstein.util import batch_ravel_pytree, get_parameter_transform
from numpyro.contrib.funsor import config_enumerate, enum
from numpyro.distributions import Distribution, Normal
from numpyro.distributions.constraints import real
from numpyro.distributions.transforms import IdentityTransform
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _guess_max_plate_nesting, transform_fn
from numpyro.util import fori_collect, ravel_pytree

SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"])
SteinVIRunResult = namedtuple("SteinVIRunResult", ["params", "state", "losses"])


def _numel(shape):
    return functools.reduce(operator.mul, shape, 1)


class SteinVI:
    """Stein variational inference for Stein mixtures.

    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param optim: an instance of :class:`~numpyro.optim._NumpyroOptim`.
    :param loss: ELBO loss, i.e. the negative Evidence Lower Bound, to minimize.
    :param kernel_fn: function that produces a logarithm of the statistical kernel
        to use with Stein inference.
    :param num_particles: number of particles for Stein inference
        (more particles capture more of the posterior distribution).
    :param loss_temperature: scaling factor of the loss.
    :param repulsion_temperature: scaling factor of the repulsive forces
        (non-linear Stein).
    :param enum: whether to apply automatic marginalization of discrete variables.
    :param classic_guide_params_fn: predicate on names of parameters in the guide
        which should be optimized classically without Stein (e.g. parameters of
        large neural networks or other transformations).
    :param static_kwargs: static keyword arguments for the model / guide, i.e.
        arguments that remain constant during fitting.
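
    **Example:** a minimal usage sketch; the model, guide choice, and data
    below are illustrative and not part of this module.

    .. code-block:: python

        from jax import random
        import numpyro
        import numpyro.distributions as dist
        from numpyro.contrib.einstein import RBFKernel, SteinVI
        from numpyro.infer import Trace_ELBO
        from numpyro.infer.autoguide import AutoNormal
        from numpyro.optim import Adam

        def model(data):
            # Gaussian mean model for the observed data.
            mean = numpyro.sample("mean", dist.Normal(0.0, 1.0))
            numpyro.sample("obs", dist.Normal(mean, 1.0), obs=data)

        guide = AutoNormal(model)
        stein = SteinVI(
            model, guide, Adam(0.1), Trace_ELBO(), RBFKernel(), num_particles=5
        )
        data = random.normal(random.PRNGKey(1), (100,))
        result = stein.run(random.PRNGKey(0), 1000, data)
        params, losses = result.params, result.losses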
"""

    def __init__(
self,
model,
guide,
optim,
loss,
kernel_fn: SteinKernel,
num_particles: int = 10,
loss_temperature: float = 1.0,
repulsion_temperature: float = 1.0,
classic_guide_params_fn: Callable[[str], bool] = lambda name: False,
enum=True,
**static_kwargs,
):
self._inference_model = model
self.model = model
self.guide = guide
self.optim = optim
self.loss = loss
self.kernel_fn = kernel_fn
self.static_kwargs = static_kwargs
self.num_particles = num_particles
self.loss_temperature = loss_temperature
self.repulsion_temperature = repulsion_temperature
self.enum = enum
self.classic_guide_params_fn = classic_guide_params_fn
self.guide_param_names = None
self.constrain_fn = None
self.uconstrain_fn = None
self.particle_transform_fn = None
self.particle_transforms = None
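
    # The kernel function operates in one of three modes (see
    # :class:`SteinKernel`): "norm" produces a scalar kernel value, "vector" a
    # per-coordinate value applied elementwise, and "matrix" a matrix-valued
    # kernel applied by matrix-vector product (as dispatched below).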
    def _apply_kernel(self, kernel, x, y, v):
        if self.kernel_fn.mode in ("norm", "vector"):
            return kernel(x, y) * v
        else:
            return kernel(x, y) @ v

    def _kernel_grad(self, kernel, x, y):
        if self.kernel_fn.mode == "norm":
            return jax.grad(lambda x: kernel(x, y))(x)
        elif self.kernel_fn.mode == "vector":
            return jax.vmap(lambda i: jax.grad(lambda x: kernel(x, y)[i])(x)[i])(
                jnp.arange(x.shape[0])
            )
        else:
            return jax.vmap(
                lambda i: jnp.sum(
                    jax.vmap(lambda j: jax.grad(lambda x: kernel(x, y)[i, j])(x)[j])(
                        jnp.arange(x.shape[0])
                    )
                )
            )(jnp.arange(x.shape[0]))

    def _param_size(self, param):
        if isinstance(param, (tuple, list)):
            return sum(map(self._param_size, param))
        return param.size
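
    # For example (hypothetical shapes): with num_particles=2 and unconstrained
    # params {"a": jnp.ones((2, 3)), "b": jnp.ones((2, 2, 4))}, each particle
    # holds 3 entries of "a" and 8 of "b", so `_calc_particle_info` below
    # returns ({"a": (0, 3), "b": (3, 11)}, 11).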
    def _calc_particle_info(self, uparams, num_particles, start_index=0):
        uparam_keys = sorted(uparams.keys())
        res = {}
        end_index = start_index
        for k in uparam_keys:
            if isinstance(uparams[k], dict):
                res_sub, end_index = self._calc_particle_info(
                    uparams[k], num_particles, start_index
                )
                res[k] = res_sub
            else:
                end_index = start_index + self._param_size(uparams[k]) // num_particles
                res[k] = (start_index, end_index)
            start_index = end_index
        return res, end_index

    def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace):
        def extract_info(site):
            nonlocal particle_seed
            name = site["name"]
            value = site["value"]
            constraint = site["kwargs"].get("constraint", real)
            transform = get_parameter_transform(site)
            if (
                isinstance(inner_guide, AutoGuide)
                and "_".join((inner_guide.prefix, "loc")) in name
            ):
                # Initialize particles around the guide's location estimate by
                # adding Gaussian noise in the unconstrained space.
                site_key, particle_seed = jax.random.split(particle_seed)
                unconstrained_shape = transform.inverse_shape(value.shape)
                init_value = jnp.expand_dims(transform.inv(value), 0) + Normal(
                    scale=0.1
                ).sample(particle_seed, (self.num_particles, *unconstrained_shape))
                init_value = transform(init_value)
            else:
                # Otherwise, re-initialize the site once per particle.
                site_fn = site["fn"]
                site_args = site["args"]
                site_key, particle_seed = jax.random.split(particle_seed)

                def _reinit(seed):
                    with handlers.seed(rng_seed=seed):
                        return site_fn(*site_args)

                init_value = jax.vmap(_reinit)(
                    jax.random.split(particle_seed, self.num_particles)
                )
            return init_value, constraint

        init_params = {
            name: extract_info(site)
            for name, site in inner_guide_trace.items()
            if site.get("type") == "param"
        }
        return init_params

    def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
        # 0. Separate model and guide parameters, since only guide parameters
        # are updated using Stein.
        classic_uparams = {
            p: v
            for p, v in unconstr_params.items()
            if p not in self.guide_param_names or self.classic_guide_params_fn(p)
        }
        stein_uparams = {
            p: v for p, v in unconstr_params.items() if p not in classic_uparams
        }
        # 1. Collect each guide parameter into monolithic particles that capture
        # correlations between parameter values within each individual particle.
stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
stein_uparams, nbatch_dims=1
)
particle_info, _ = self._calc_particle_info(
stein_uparams, stein_particles.shape[0]
)

        # 2. Calculate loss and gradients for each particle.
        def scaled_loss(rng_key, classic_params, stein_params):
            params = {**classic_params, **stein_params}
            loss_val = self.loss.loss(
                rng_key,
                params,
                handlers.scale(self._inference_model, self.loss_temperature),
                self.guide,
                *args,
                **kwargs,
            )
            return -loss_val

        def kernel_particle_loss_fn(ps):
            return scaled_loss(
                rng_key,
                self.constrain_fn(classic_uparams),
                self.constrain_fn(unravel_pytree(ps)),
            )

        def particle_transform_fn(particle):
            params = unravel_pytree(particle)
            tparams = self.particle_transform_fn(params)
            tparticle, _ = ravel_pytree(tparams)
            return tparticle

tstein_particles = jax.vmap(particle_transform_fn)(stein_particles)
loss, particle_ljp_grads = jax.vmap(
jax.value_and_grad(kernel_particle_loss_fn)
)(tstein_particles)
classic_param_grads = jax.vmap(
lambda ps: jax.grad(
lambda cps: scaled_loss(
rng_key,
self.constrain_fn(cps),
self.constrain_fn(unravel_pytree(ps)),
)
)(classic_uparams)
)(stein_particles)
classic_param_grads = tree_map(partial(jnp.mean, axis=0), classic_param_grads)
# 3. Calculate kernel on monolithic particle
kernel = self.kernel_fn.compute(
stein_particles, particle_info, kernel_particle_loss_fn
)
# 4. Calculate the attractive force and repulsive force on the monolithic particles
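        # This mirrors the SVGD update direction of Liu & Wang (2016):
        #   phi(y) = (1/m) * sum_x [ k(x, y) * grad_x log p(x) + grad_x k(x, y) ]
        # where the first term (attraction) pulls particles toward high posterior
        # density and the second (repulsion, scaled by repulsion_temperature)
        # keeps them apart; the 1/m average is the division by num_particles below.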
attractive_force = jax.vmap(
lambda y: jnp.sum(
jax.vmap(
lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
)(tstein_particles, particle_ljp_grads),
axis=0,
)
)(tstein_particles)
repulsive_force = jax.vmap(
lambda y: jnp.sum(
jax.vmap(
lambda x: self.repulsion_temperature
* self._kernel_grad(kernel, x, y)
)(tstein_particles),
axis=0,
)
)(tstein_particles)

        def single_particle_grad(particle, attr_forces, rep_forces):
            def _nontrivial_jac(var_name, var):
                if isinstance(self.particle_transforms[var_name], IdentityTransform):
                    return None
                return jax.jacfwd(self.particle_transforms[var_name].inv)(var)

            def _update_force(attr_force, rep_force, jac):
                force = attr_force.reshape(-1) + rep_force.reshape(-1)
                if jac is not None:
                    force = force @ jac.reshape(
                        (_numel(jac.shape[: len(jac.shape) // 2]), -1)
                    )
                return force.reshape(attr_force.shape)

            reparam_jac = {
                name: tree_map(lambda var: _nontrivial_jac(name, var), variables)
                for name, variables in unravel_pytree(particle).items()
            }
            # `tree_map` over multiple trees replaces the deprecated
            # `jax.tree_multimap`.
            jac_params = tree_map(
                _update_force,
                unravel_pytree(attr_forces),
                unravel_pytree(rep_forces),
                reparam_jac,
            )
            jac_particle, _ = ravel_pytree(jac_params)
            return jac_particle

particle_grads = (
jax.vmap(single_particle_grad)(
stein_particles, attractive_force, repulsive_force
)
/ self.num_particles
)
# 5. Decompose the monolithic particle forces back to concrete parameter values
stein_param_grads = unravel_pytree_batched(particle_grads)
# 6. Return loss and gradients (based on parameter forces)
res_grads = tree_map(lambda x: -x, {**classic_param_grads, **stein_param_grads})
return -jnp.mean(loss), res_grads

    def init(self, rng_key, *args, **kwargs):
        """
        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly
            vary during the course of fitting).
        :return: initial :data:`SteinVIState`.
        """
rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(rng_key, 4)
model_init = handlers.seed(self.model, model_seed)
guide_init = handlers.seed(self.guide, guide_seed)
guide_trace = handlers.trace(guide_init).get_trace(
*args, **kwargs, **self.static_kwargs
)
model_trace = handlers.trace(model_init).get_trace(
*args, **kwargs, **self.static_kwargs
)
rng_key, particle_seed = jax.random.split(rng_key)
guide_init_params = self._find_init_params(
particle_seed, self.guide, guide_trace
)
params = {}
transforms = {}
inv_transforms = {}
particle_transforms = {}
guide_param_names = set()
should_enum = False
for site in model_trace.values():
if (
"fn" in site
and site["type"] == "sample"
and not site["is_observed"]
and isinstance(site["fn"], Distribution)
and site["fn"].is_discrete
):
if site["fn"].has_enumerate_support and self.enum:
should_enum = True
else:
raise Exception(
"Cannot enumerate model with discrete variables without enumerate support"
)
# NB: params in model_trace will be overwritten by params in guide_trace
for site in chain(model_trace.values(), guide_trace.values()):
if site["type"] == "param":
transform = get_parameter_transform(site)
inv_transforms[site["name"]] = transform
transforms[site["name"]] = transform.inv
particle_transforms[site["name"]] = site.get(
"particle_transform", IdentityTransform()
)
if site["name"] in guide_init_params:
pval, _ = guide_init_params[site["name"]]
if self.classic_guide_params_fn(site["name"]):
pval = tree_map(lambda x: x[0], pval)
else:
pval = site["value"]
params[site["name"]] = transform.inv(pval)
if site["name"] in guide_trace:
guide_param_names.add(site["name"])
if should_enum:
mpn = _guess_max_plate_nesting(model_trace)
self._inference_model = enum(config_enumerate(self.model), -mpn - 1)
self.guide_param_names = guide_param_names
self.constrain_fn = partial(transform_fn, inv_transforms)
self.uconstrain_fn = partial(transform_fn, transforms)
self.particle_transforms = particle_transforms
self.particle_transform_fn = partial(transform_fn, particle_transforms)
stein_particles, _, _ = batch_ravel_pytree(
{
k: params[k]
for k, site in guide_trace.items()
if site["type"] == "param" and site["name"] in guide_init_params
},
nbatch_dims=1,
)
self.kernel_fn.init(kernel_seed, stein_particles.shape)
return SteinVIState(self.optim.init(params), rng_key)

    def get_params(self, state: SteinVIState):
        """
        Gets values at `param` sites of the `model` and `guide`.

        :param state: current state of the optimizer.
        """
params = self.constrain_fn(self.optim.get_params(state.optim_state))
return params

    def update(self, state: SteinVIState, *args, **kwargs):
        """
        Take a single step of Stein (possibly on a batch / minibatch of data),
        using the optimizer.

        :param state: current state of Stein.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly
            vary during the course of fitting).
        :return: tuple of `(state, loss)`.
        """
rng_key, rng_key_mcmc, rng_key_step = jax.random.split(state.rng_key, num=3)
params = self.optim.get_params(state.optim_state)
optim_state = state.optim_state
loss_val, grads = self._svgd_loss_and_grads(
rng_key_step, params, *args, **kwargs, **self.static_kwargs
)
optim_state = self.optim.update(grads, optim_state)
return SteinVIState(optim_state, rng_key), loss_val
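
    # A manual stepping sketch (hypothetical `stein` instance and `data`),
    # equivalent in spirit to calling :meth:`run`:
    #
    #     state = stein.init(random.PRNGKey(0), data)
    #     for _ in range(1000):
    #         state, loss = stein.update(state, data)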

    def run(
self,
rng_key,
num_steps,
*args,
progress_bar=True,
init_state=None,
collect_fn=lambda val: val[1], # TODO: refactor
**kwargs,
):
        def bodyfn(_i, info):
            body_state = info[0]
            return (*self.update(body_state, *info[2:], **kwargs), *info[2:])

        if init_state is None:
            state = self.init(rng_key, *args, **kwargs)
        else:
            state = init_state
        loss = self.evaluate(state, *args, **kwargs)
auxiliaries, last_res = fori_collect(
0,
num_steps,
lambda info: bodyfn(0, info),
(state, loss, *args),
progbar=progress_bar,
transform=collect_fn,
return_last_val=True,
)
state = last_res[0]
return SteinVIRunResult(self.get_params(state), state, auxiliaries)

    def evaluate(self, state, *args, **kwargs):
        """
        Evaluate the loss (possibly on a batch / minibatch of data) without
        updating the parameters.

        :param state: current state of Stein.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide.
        :return: loss given the current parameter values (held within
            `state.optim_state`).
        """
        # Split the key the same way as :meth:`update` so that, given the same
        # state, evaluation uses the same seed as an update step would.
        _, _, rng_key_eval = jax.random.split(state.rng_key, num=3)
params = self.optim.get_params(state.optim_state)
loss_val, _ = self._svgd_loss_and_grads(
rng_key_eval, params, *args, **kwargs, **self.static_kwargs
)
return loss_val