# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import namedtuple
from copy import deepcopy
import functools
from functools import partial
from itertools import chain
import operator
from jax import grad, numpy as jnp, random, tree, vmap
from jax.flatten_util import ravel_pytree
from numpyro import handlers
from numpyro.contrib.einstein.stein_loss import SteinLoss
from numpyro.contrib.einstein.stein_util import (
batch_ravel_pytree,
get_parameter_transform,
)
from numpyro.distributions import Distribution
from numpyro.infer.autoguide import AutoDelta, AutoGuide
from numpyro.infer.util import transform_fn
from numpyro.util import fori_collect
SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"])
SteinVIRunResult = namedtuple("SteinVIRunResult", ["params", "state", "losses"])
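# SteinVIState carries the optimizer state and the RNG key between update steps;
# SteinVIRunResult bundles the constrained parameters, the final state, and the
# per-step loss values (norms of the Stein force).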
def _numel(shape):
return functools.reduce(operator.mul, shape, 1)
class SteinVI:
"""Variational inference with Stein mixtures inference.
**Example:**
.. doctest::
>>> from jax import random, numpy as jnp
>>> from numpyro import sample, param, plate
>>> from numpyro.distributions import Beta, Bernoulli
>>> from numpyro.distributions.constraints import positive
>>> from numpyro.optim import Adagrad
>>> from numpyro.contrib.einstein import MixtureGuidePredictive, SteinVI, RBFKernel
>>> def model(data):
... f = sample("fairness", Beta(10, 10))
... n = data.shape[0] if data is not None else 1
... with plate("N", n):
... sample("obs", Bernoulli(f), obs=data)
>>> def guide(data):
        ...     # Initialize all particles at the same point.
... alpha_q = param("alpha_q", 15., constraint=positive)
... # Initialize particles by sampling an Exponential distribution.
... beta_q = param("beta_q",
... lambda rng_key: random.exponential(rng_key),
... constraint=positive)
... sample("fairness", Beta(alpha_q, beta_q))
>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> opt = Adagrad(step_size=0.05)
>>> k = RBFKernel()
>>> stein = SteinVI(model, guide, opt, k, num_stein_particles=2)
>>> stein_result = stein.run(random.PRNGKey(0), 200, data)
>>> params = stein_result.params
>>> # Use guide to make predictions.
>>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=10, guide_sites=stein.guide_sites)
>>> samples = predictive(random.PRNGKey(1), data=None)
:param Callable model: Python callable with NumPyro primitives for the model.
:param Callable guide: Python callable with NumPyro primitives for the guide.
    :param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumPyroOptim`.
        Adagrad should be preferred over Adam [1].
:param SteinKernel kernel_fn: Function that computes the reproducing kernel to use with Stein mixture
inference. We currently recommend :class:`~numpyro.contrib.einstein.RBFKernel`.
This may change as criteria for kernel selection are not well understood yet.
:param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation.
Default is `10`.
:param num_elbo_particles: Number of Monte Carlo draws used to approximate the attractive force gradient.
More particles give better gradient approximations. Default is `10`.
:param Float loss_temperature: Scaling factor of the attractive force. Default is `1`.
:param Float repulsion_temperature: Scaling factor of the repulsive force [2].
We recommend not scaling the repulsion. Default is `1`.
    :param Callable non_mixture_guide_params_fn: Predicate on names of parameters in the guide which should be
        optimized using one particle. This could be parameters for large neural networks or other transformations.
        Default excludes all parameters from this option.
:param static_kwargs: Static keyword arguments for the model and guide. These arguments cannot change
during inference.
**References:** (MLA style)
1. Liu, Chang, et al. "Understanding and Accelerating Particle-Based Variational Inference."
International Conference on Machine Learning. PMLR, 2019.
2. Wang, Dilin, and Qiang Liu. "Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models."
International Conference on Machine Learning. PMLR, 2019.
""" # noqa: E501
def __init__(
self,
model,
guide,
optim,
kernel_fn,
num_stein_particles=10,
num_elbo_particles=10,
loss_temperature=1.0,
repulsion_temperature=1.0,
non_mixture_guide_params_fn=lambda name: False,
**static_kwargs,
):
if isinstance(guide, AutoGuide):
            not_compatible_guides = [
"AutoIAFNormal",
"AutoBNAFNormal",
"AutoDAIS",
"AutoSemiDAIS",
"AutoSurrogateLikelihoodDAIS",
]
guide_name = guide.__class__.__name__
            assert guide_name not in not_compatible_guides, (
f"SteinVI currently not compatible with {guide_name}. "
f"If you have a use case, feel free to open an issue."
)
init_loc_error_message = (
"SteinVI is not compatible with init_to_feasible, init_to_value, "
"and init_to_uniform with radius=0. If you have a use case, "
"feel free to open an issue."
)
if isinstance(guide.init_loc_fn, partial):
init_fn_name = guide.init_loc_fn.func.__name__
if init_fn_name == "init_to_uniform":
assert (
guide.init_loc_fn.keywords.get("radius", None) != 0.0
), init_loc_error_message
else:
init_fn_name = guide.init_loc_fn.__name__
assert init_fn_name not in [
"init_to_feasible",
"init_to_value",
], init_loc_error_message
self._inference_model = model
self.model = model
self.guide = guide
self._init_guide = deepcopy(guide)
self.optim = optim
self.stein_loss = SteinLoss( # TODO: @OlaRonning handle enum
elbo_num_particles=num_elbo_particles,
stein_num_particles=num_stein_particles,
)
self.kernel_fn = kernel_fn
self.static_kwargs = static_kwargs
self.num_stein_particles = num_stein_particles
self.loss_temperature = loss_temperature
self.repulsion_temperature = repulsion_temperature
self.non_mixture_params_fn = non_mixture_guide_params_fn
self.guide_sites = None
self.constrain_fn = None
self.uconstrain_fn = None
self.particle_transform_fn = None
self.particle_transforms = None
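    # The kernel function operates in one of three modes, given by ``kernel_fn.mode``:
    #   "norm":   kernel(x, y) is a scalar that rescales the vector v.
    #   "vector": kernel(x, y) is a vector applied elementwise to v.
    #   "matrix": kernel(x, y) is a matrix applied to v by matrix-vector product.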
def _apply_kernel(self, kernel, x, y, v):
if self.kernel_fn.mode == "norm" or self.kernel_fn.mode == "vector":
return kernel(x, y) * v
else:
return kernel(x, y) @ v
def _kernel_grad(self, kernel, x, y):
if self.kernel_fn.mode == "norm":
return grad(lambda x: kernel(x, y))(x)
elif self.kernel_fn.mode == "vector":
return vmap(lambda i: grad(lambda x: kernel(x, y)[i])(x)[i])(
jnp.arange(x.shape[0])
)
else:
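            # Matrix-valued kernel: entry ``a`` of the gradient is
            # sum_b d/dx[b] kernel(x, y)[a, b], computed with nested vmaps.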
return vmap(
lambda a: jnp.sum(
vmap(lambda b: grad(lambda x: kernel(x, y)[a, b])(x)[b])(
jnp.arange(x.shape[0])
)
)
)(jnp.arange(x.shape[0]))
def _param_size(self, param):
if isinstance(param, tuple) or isinstance(param, list):
return sum(map(self._param_size, param))
return param.size
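    # Maps each (possibly nested) guide parameter to the half-open index range
    # (start, end) it occupies within a single flattened particle vector.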
def _calc_particle_info(self, uparams, num_particles, start_index=0):
uparam_keys = list(uparams.keys())
uparam_keys.sort()
res = {}
end_index = start_index
for k in uparam_keys:
if isinstance(uparams[k], dict):
res_sub, end_index = self._calc_particle_info(
uparams[k], num_particles, start_index
)
res[k] = res_sub
else:
end_index = start_index + self._param_size(uparams[k]) // num_particles
res[k] = (start_index, end_index)
start_index = end_index
return res, end_index
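    # Traces the guide once per Stein particle, each under its own seed, so every
    # particle gets an independent draw of the guide's initial parameter values.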
def _find_init_params(self, particle_seed, inner_guide, model_args, model_kwargs):
def local_trace(key):
guide = deepcopy(inner_guide)
with handlers.seed(rng_seed=key), handlers.trace() as mixture_trace:
guide(*model_args, **model_kwargs)
init_params = {
name: site["value"]
for name, site in mixture_trace.items()
if site.get("type") == "param"
}
return init_params
return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles))
def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
non_mixture_uparams = { # Includes any marked guide parameters and all model parameters
p: v
for p, v in unconstr_params.items()
if p not in self.guide_sites or self.non_mixture_params_fn(p)
}
stein_uparams = {
p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
}
        # 1. Collect each guide parameter into monolithic particles that capture correlations
        #    between parameter values within each individual particle
stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
stein_uparams, nbatch_dims=1
)
particle_info, _ = self._calc_particle_info(
stein_uparams, stein_particles.shape[0]
)
attractive_key, classic_key = random.split(rng_key)
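        # Maps a flat unconstrained particle to its flat constrained counterpart.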
def particle_transform_fn(particle):
params = unravel_pytree(particle)
ctparams = self.constrain_fn(self.particle_transform_fn(params))
ctparticle, _ = ravel_pytree(ctparams)
return ctparticle
# 2. Calculate gradients for each particle
def kernel_particles_loss_fn(rng_key, particles):
particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
grads = vmap(
lambda i: grad(
lambda particle: self.stein_loss.particle_loss(
rng_key=particle_keys[i],
model=handlers.scale(
self._inference_model, self.loss_temperature
),
guide=self.guide,
selected_particle=self.constrain_fn(unravel_pytree(particle)),
unravel_pytree=unravel_pytree,
flat_particles=vmap(particle_transform_fn)(particles),
select_index=i,
model_args=args,
model_kwargs=kwargs,
param_map=self.constrain_fn(non_mixture_uparams),
)
)(particles[i])
)(jnp.arange(self.stein_loss.stein_num_particles))
return grads
# 2.1 Compute particle gradients (for attractive force)
particle_ljp_grads = kernel_particles_loss_fn(attractive_key, stein_particles)
        # 2.2 Lift particles to constrained space
ctstein_particles = vmap(particle_transform_fn)(stein_particles)
        # 2.3 Compute non-mixture parameter gradients
non_mixture_param_grads = grad(
lambda cps: -self.stein_loss.loss(
classic_key,
self.constrain_fn(cps),
handlers.scale(self._inference_model, self.loss_temperature),
self.guide,
unravel_pytree_batched(ctstein_particles),
*args,
**kwargs,
)
)(non_mixture_uparams)
# 3. Calculate kernel of particles
def loss_fn(particle, i):
return self.stein_loss.particle_loss(
rng_key=rng_key,
model=handlers.scale(self._inference_model, self.loss_temperature),
guide=self.guide,
selected_particle=self.constrain_fn(unravel_pytree(particle)),
unravel_pytree=unravel_pytree,
flat_particles=ctstein_particles,
select_index=i,
model_args=args,
model_kwargs=kwargs,
param_map=self.constrain_fn(non_mixture_uparams),
)
kernel = self.kernel_fn.compute(
rng_key, stein_particles, particle_info, loss_fn
)
# 4. Calculate the attractive force and repulsive force on the particles
attractive_force = vmap(
lambda y: jnp.sum(
vmap(
lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
)(stein_particles, particle_ljp_grads),
axis=0,
)
)(stein_particles)
repulsive_force = vmap(
lambda y: jnp.mean(
vmap(
lambda x: self.repulsion_temperature
* self._kernel_grad(kernel, x, y)
)(stein_particles),
axis=0,
)
)(stein_particles)
        # 5. Compute the Stein force
particle_grads = attractive_force + repulsive_force
        # 6. Decompose the monolithic particle forces back to concrete parameter values
stein_param_grads = unravel_pytree_batched(particle_grads)
        # 7. Return the norm of the Stein forces as the reported loss value, together
        #    with the negated forces, which the optimizer treats as gradients.
res_grads = tree.map(
lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
)
return jnp.linalg.norm(particle_grads), res_grads
def init(self, rng_key, *args, **kwargs):
"""Register random variable transformations, constraints and determine initialize positions of the particles.
:param jax.random.PRNGKey rng_key: Random number generator seed.
:param args: Positional arguments to the model and guide.
:param kwargs: Keyword arguments to the model and guide.
:return: Initial :data:`SteinVIState`.
"""
rng_key, kernel_seed, model_seed, guide_seed, particle_seed = random.split(
rng_key, 5
)
model_init = handlers.seed(self.model, model_seed)
model_trace = handlers.trace(model_init).get_trace(
*args, **kwargs, **self.static_kwargs
)
guide_init_params = self._find_init_params(
particle_seed, self._init_guide, args, kwargs
)
guide_init = handlers.seed(self.guide, guide_seed)
guide_trace = handlers.trace(guide_init).get_trace(
*args, **kwargs, **self.static_kwargs
)
params = {}
transforms = {}
inv_transforms = {}
particle_transforms = {}
guide_param_names = set()
for site in model_trace.values():
if (
"fn" in site
and site["type"] == "sample"
and not site["is_observed"]
and isinstance(site["fn"], Distribution)
and site["fn"].is_discrete
):
if site["fn"].has_enumerate_support:
raise Exception(
"Cannot enumerate model with discrete variables without enumerate support"
)
# NB: params in model_trace will be overwritten by params in guide_trace
for site in chain(model_trace.values(), guide_trace.values()):
if site["type"] == "param":
transform = get_parameter_transform(site)
inv_transforms[site["name"]] = transform
transforms[site["name"]] = transform.inv
particle_transforms[site["name"]] = transform
if site["name"] in guide_init_params:
pval = guide_init_params[site["name"]]
if self.non_mixture_params_fn(site["name"]):
pval = tree.map(lambda x: x[0], pval)
else:
pval = site["value"]
params[site["name"]] = transform.inv(pval)
if site["name"] in guide_trace:
guide_param_names.add(site["name"])
self.guide_sites = guide_param_names
self.constrain_fn = partial(transform_fn, inv_transforms)
self.uconstrain_fn = partial(transform_fn, transforms)
self.particle_transforms = particle_transforms
self.particle_transform_fn = partial(transform_fn, particle_transforms)
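        # Flatten the per-particle guide parameters into a (num_particles, flat_dim)
        # array so the kernel can be initialized with the particle shape.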
stein_particles, _, _ = batch_ravel_pytree(
{
k: params[k]
for k, site in guide_trace.items()
if site["type"] == "param" and site["name"] in guide_init_params
},
nbatch_dims=1,
)
self.kernel_fn.init(kernel_seed, stein_particles.shape)
return SteinVIState(self.optim.init(params), rng_key)
def get_params(self, state: SteinVIState):
"""Gets values at `param` sites of the `model` and `guide`.
:param SteinVIState state: Current state of optimization.
        :return: Constrained parameter values (i.e., particles).
"""
params = self.constrain_fn(self.optim.get_params(state.optim_state))
return params
    def update(self, state: SteinVIState, *args, **kwargs) -> tuple[SteinVIState, float]:
"""Take a single step of SteinVI using the optimizer. We recommend using
the run method instead of update.
:param SteinVIState state: Current state of inference.
        :param args: Positional arguments to the model and guide.
:param kwargs: Keyword arguments to the model and guide.
        :return: A tuple of the next :data:`SteinVIState` and the norm of the Stein force.
"""
rng_key, rng_key_mcmc, rng_key_step = random.split(state.rng_key, num=3)
params = self.optim.get_params(state.optim_state)
optim_state = state.optim_state
loss_val, grads = self._svgd_loss_and_grads(
rng_key_step, params, *args, **kwargs, **self.static_kwargs
)
optim_state = self.optim.update(grads, optim_state)
return SteinVIState(optim_state, rng_key), loss_val
def setup_run(self, rng_key, num_steps, args, init_state, kwargs):
if init_state is None:
state = self.init(rng_key, *args, **kwargs)
else:
state = init_state
loss = self.evaluate(state, *args, **kwargs)
info_init = (state, loss)
def step(info):
state, loss = info
return self.update(state, *args, **kwargs) # uses closure!
def collect(info):
_, loss = info
return loss
def extract(info):
state, _ = info
return state
def diagnostic(info):
_, loss = info
return f"Stein force {loss:.2f}."
return step, diagnostic, collect, extract, info_init
def run(
self,
rng_key,
num_steps,
*args,
progress_bar=True,
init_state=None,
**kwargs,
):
"""Run SteinVI inference.
:param jax.random.PRNGKey rng_key: Random number generator seed.
:param int num_steps: Number of steps to optimize.
:param *args: Positional arguments to the model and guide.
:param bool progress_bar: Use a progress bar. Default is `True`.
Inference is faster with `False`.
:param SteinVIState init_state: Initial state of inference.
Default is ``None``, which will initialize using init before running inference.
:param **kwargs: Keyword arguments to the model and guide.
"""
step, diagnostic, collect, extract, init_info = self.setup_run(
rng_key, num_steps, args, init_state, kwargs
)
auxiliaries, last_res = fori_collect(
0,
num_steps,
step,
init_info,
progbar=progress_bar,
transform=collect,
return_last_val=True,
diagnostics_fn=diagnostic if progress_bar else None,
)
state = extract(last_res)
return SteinVIRunResult(self.get_params(state), state, auxiliaries)
def evaluate(self, state: SteinVIState, *args, **kwargs):
"""Take a single step of Stein (possibly on a batch / minibatch of data).
:param SteinVIState state: Current state of inference.
:param args: Positional arguments to the model and guide.
:param kwargs: Keyword arguments to the model and guide.
        :return: Norm of the Stein force given the current :data:`SteinVIState`.
"""
        # Split in the same way as `update` so evaluation uses the same seed for a given state.
_, _, rng_key_eval = random.split(state.rng_key, num=3)
params = self.optim.get_params(state.optim_state)
normed_stein_force, _ = self._svgd_loss_and_grads(
rng_key_eval, params, *args, **kwargs, **self.static_kwargs
)
return normed_stein_force
class SVGD(SteinVI):
"""Stein variational gradient descent [1].
**Example:**
.. doctest::
>>> from jax import random, numpy as jnp
>>> from numpyro import sample, param, plate
>>> from numpyro.distributions import Beta, Bernoulli
>>> from numpyro.distributions.constraints import positive
>>> from numpyro.optim import Adagrad
>>> from numpyro.contrib.einstein import SVGD, RBFKernel
>>> from numpyro.infer import Predictive
>>> def model(data):
... f = sample("fairness", Beta(10, 10))
... n = data.shape[0] if data is not None else 1
... with plate("N", n):
... sample("obs", Bernoulli(f), obs=data)
>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> opt = Adagrad(step_size=0.05)
>>> k = RBFKernel()
>>> svgd = SVGD(model, opt, k, num_stein_particles=2)
>>> svgd_result = svgd.run(random.PRNGKey(0), 200, data)
>>> params = svgd_result.params
>>> predictive = Predictive(model, guide=svgd.guide, params=params, num_samples=10, batch_ndims=1)
>>> samples = predictive(random.PRNGKey(1), data=None)
:param Callable model: Python callable with NumPyro primitives for the model.
    :param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumPyroOptim`.
        Adagrad should be preferred over Adam [1].
:param SteinKernel kernel_fn: Function that computes the reproducing kernel to use with SVGD.
We currently recommend :class:`~numpyro.contrib.einstein.RBFKernel`. This may change as criteria for
kernel selection are not well understood yet.
:param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation.
Default is 10.
:param Dict guide_kwargs: Keyword arguments for :class:`~numpyro.infer.autoguide.AutoDelta`.
Default behaviour is the same as the default for :class:`~numpyro.infer.autoguide.AutoDelta`.
Usage::
opt = Adagrad(step_size=0.05)
k = RBFKernel()
svgd = SVGD(model, opt, k, guide_kwargs={'init_loc_fn': partial(init_to_uniform, radius=0.1)})
:param Dict static_kwargs: Static keyword arguments for the model and guide. These arguments cannot
change during inference.
**References:** (MLA style)
1. Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm."
Advances in neural information processing systems 29 (2016).
"""
def __init__(
self,
model,
optim,
kernel_fn,
num_stein_particles=10,
guide_kwargs={},
**static_kwargs,
):
super().__init__(
model=model,
guide=AutoDelta(model, **guide_kwargs),
optim=optim,
kernel_fn=kernel_fn,
num_stein_particles=num_stein_particles,
# With a Delta guide we only need one draw
# per particle to get its contribution to the expectation.
num_elbo_particles=1,
loss_temperature=1.0 / float(num_stein_particles),
# For SVGD repulsion temperature != 1 changes the
# target posterior so we keep it fixed at 1.
repulsion_temperature=1.0,
non_mixture_guide_params_fn=lambda name: False,
**static_kwargs,
)
class ASVGD(SVGD):
"""Annealing Stein variational gradient descent [1].
**Example:**
.. doctest::
>>> from jax import random, numpy as jnp
>>> from numpyro import sample, param, plate
>>> from numpyro.distributions import Beta, Bernoulli
>>> from numpyro.distributions.constraints import positive
>>> from numpyro.optim import Adagrad
>>> from numpyro.contrib.einstein import ASVGD, RBFKernel
>>> from numpyro.infer import Predictive
>>> def model(data):
... f = sample("fairness", Beta(10, 10))
... n = data.shape[0] if data is not None else 1
... with plate("N", n):
... sample("obs", Bernoulli(f), obs=data)
>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> opt = Adagrad(step_size=0.05)
>>> k = RBFKernel()
>>> asvgd = ASVGD(model, opt, k, num_stein_particles=2)
>>> asvgd_result = asvgd.run(random.PRNGKey(0), 200, data)
>>> params = asvgd_result.params
>>> predictive = Predictive(model, guide=asvgd.guide, params=params, num_samples=10, batch_ndims=1)
>>> samples = predictive(random.PRNGKey(1), data=None)
:param Callable model: Python callable with NumPyro primitives for the model.
    :param _NumPyroOptim optim: An instance of :class:`~numpyro.optim._NumPyroOptim`.
        Adagrad should be preferred over Adam [1].
:param SteinKernel kernel_fn: Function that computes the reproducing kernel to use with ASVGD.
We currently recommend :class:`~numpyro.contrib.einstein.RBFKernel`.
This may change as criteria for kernel selection are not well understood yet.
:param num_stein_particles: Number of particles (i.e., mixture components) in the mixture approximation.
Default is `10`.
:param num_cycles: The total number of cycles during inference. This corresponds to :math:`C` in eq. 4 of [1].
Default is `10`.
:param trans_speed: Speed of transition between two phases during inference. This corresponds to :math:`p` in eq. 4
of [1]. Default is `10`.
:param Dict guide_kwargs: Keyword arguments for :class:`~numpyro.infer.autoguide.AutoDelta`.
Default behaviour is the same as the default for :class:`~numpyro.infer.autoguide.AutoDelta`.
Usage::
opt = Adagrad(step_size=0.05)
k = RBFKernel()
asvgd = ASVGD(model, opt, k, guide_kwargs={'init_loc_fn': partial(init_to_uniform, radius=0.1)})
:param Dict static_kwargs: Static keyword arguments for the model and guide. These arguments cannot
change during inference.
**References:** (MLA style)
1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent."
Third Symposium on Advances in Approximate Bayesian Inference, 2021.
"""
def __init__(
self,
model,
optim,
kernel_fn,
num_stein_particles=10,
num_cycles=10,
trans_speed=10,
guide_kwargs={},
**static_kwargs,
):
self.num_cycles = num_cycles
self.trans_speed = trans_speed
super().__init__(
model,
optim,
kernel_fn,
num_stein_particles,
guide_kwargs,
**static_kwargs,
)
@staticmethod
def _cyclical_annealing(num_steps: int, num_cycles: int, trans_speed: int):
"""Cyclical annealing schedule as in eq. 4 of [1].
**References** (MLA)
1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent."
Third Symposium on Advances in Approximate Bayesian Inference, 2021.
:param num_steps: The total number of steps. Corresponds to $T$ in eq. 4 of [1].
:param num_cycles: The total number of cycles. Corresponds to $C$ in eq. 4 of [1].
:param trans_speed: Speed of transition between two phases. Corresponds to $p$ in eq. 4 of [1].
"""
norm = float(num_steps + 1) / float(num_cycles)
cycle_len = num_steps // num_cycles
last_start = (num_cycles - 1) * cycle_len
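        # For every cycle but the last, the annealing factor grows as
        # (((t mod cycle_len) + 1) / norm) ** trans_speed; once t reaches the start of
        # the final cycle (t >= last_start), the factor is held at 1.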
def cycle_fn(t):
last_cycle = t // last_start
return (1 - last_cycle) * (
((t % cycle_len) + 1) / norm
) ** trans_speed + last_cycle
return cycle_fn
def setup_run(self, rng_key, num_steps, args, init_state, kwargs):
cyc_fn = ASVGD._cyclical_annealing(num_steps, self.num_cycles, self.trans_speed)
(
istep,
idiag,
icol,
iext,
iinit,
) = super().setup_run(
rng_key,
num_steps,
args,
init_state,
kwargs,
)
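        # Wrap the parent step function so that every step first rescales the loss
        # temperature according to the cyclical annealing schedule.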
def step(info):
t, iinfo = info[0], info[-1]
self.loss_temperature = cyc_fn(t) / float(self.num_stein_particles)
return (t + 1, istep(iinfo))
def diagnostic(info):
_, iinfo = info
return idiag(iinfo)
def collect(info):
_, iinfo = info
return icol(iinfo)
def extract_state(info):
_, iinfo = info
return iext(iinfo)
info_init = (0, iinit)
return step, diagnostic, collect, extract_state, info_init