# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections import namedtuple
import warnings
import jax
from jax import random, vmap
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde
import numpyro.distributions as dist
from numpyro.infer.ensemble_util import batch_ravel_pytree, get_nondiagonal_indices
from numpyro.infer.initialization import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import initialize_model
from numpyro.util import identity, is_prng_key
EnsembleSamplerState = namedtuple(
"EnsembleSamplerState", ["z", "inner_state", "rng_key"]
)
"""
A :func:`~collections.namedtuple` consisting of the following fields:
- **z** - Python collection representing values (unconstrained samples from
the posterior) at latent sites.
- **inner_state** - A namedtuple containing information needed to update half the ensemble.
- **rng_key** - random number generator seed used for generating proposals, etc.
"""
AIESState = namedtuple("AIESState", ["i", "accept_prob", "mean_accept_prob", "rng_key"])
"""
A :func:`~collections.namedtuple` consisting of the following fields.
- **i** - iteration.
- **accept_prob** - Acceptance probability of the proposal. Note that ``z``
does not correspond to the proposal if it is rejected.
- **mean_accept_prob** - Mean acceptance probability until current iteration
during warmup adaptation or sampling (for diagnostics).
- **rng_key** - random number generator seed used for generating proposals, etc.
"""
ESSState = namedtuple("ESSState", ["i",
"n_expansions",
"n_contractions",
"mu",
"rng_key"
]
)
"""
A :func:`~collections.namedtuple` used as an inner state for Ensemble Sampler.
This consists of the following fields:
- **i** - iteration.
- **n_expansions** - number of expansions in the current batch. Used for tuning mu.
- **n_contractions** - number of contractions in the current batch. Used for tuning mu.
- **mu** - Scale factor. This is tuned if tune_mu=True.
- **rng_key** - random number generator seed used for generating proposals, etc.
"""
class EnsembleSampler(MCMCKernel, ABC):
    """
    Abstract class for ensemble samplers. Each MCMC sample is divided into two sub-iterations
    in which half of the ensemble is updated.

    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
        If model is provided, `potential_fn` will be inferred using the model.
    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        :meth:`init` has the same type.
    :param bool randomize_split: whether or not to permute the chain order at each iteration.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    """

    def __init__(
        self, model=None, potential_fn=None, *, randomize_split, init_strategy
    ):
        # exactly one of `model` / `potential_fn` must be given
        if not (model is None) ^ (potential_fn is None):
            raise ValueError("Only one of `model` or `potential_fn` must be specified.")
        self._model = model
        self._potential_fn = potential_fn
        # unravel an (n_chains, n_params) Array into a pytree and
        # evaluate the log density at each chain
        self._batch_log_density = None
        # --- other hyperparams go here
        self._num_chains = None  # must be an even number >= 2
        self._randomize_split = randomize_split
        # ---
        self._init_strategy = init_strategy
        self._postprocess_fn = None

    @property
    def model(self):
        return self._model

    @property
    def sample_field(self):
        return "z"

    @property
    def is_ensemble_kernel(self):
        return True

    @abstractmethod
    def init_inner_state(self, rng_key):
        """return inner_state"""
        raise NotImplementedError

    @abstractmethod
    def update_active_chains(self, active, inactive, inner_state):
        """return (updated active set of chains, updated inner state)"""
        raise NotImplementedError

    def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        """Trace the model (if given), build the vectorized log density, and
        return initial positions for every chain."""
        if self._model is not None:
            new_params_info, potential_fn_gen, self._postprocess_fn, _ = initialize_model(
                rng_key,
                self._model,
                dynamic_args=True,
                init_strategy=self._init_strategy,
                model_args=model_args,
                model_kwargs=model_kwargs,
                validate_grad=False,
            )
            new_init_params = new_params_info[0]
            self._potential_fn = potential_fn_gen(*model_args, **model_kwargs)

            # user-supplied init_params take precedence over model-derived ones
            if init_params is None:
                init_params = new_init_params

        flat_params, unravel_fn = batch_ravel_pytree(init_params)
        self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z))

        if self._num_chains < 2 * flat_params.shape[1]:
            warnings.warn(
                "Setting n_chains to at least 2*n_params is strongly recommended.\n"
                f"n_chains: {self._num_chains}, n_params: {flat_params.shape[1]}"
            )

        return init_params

    def init(
        self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={}
    ):
        # ensemble kernels are driven by a batch of PRNG keys (one per chain);
        # a single key means chain_method != "vectorized"
        assert not is_prng_key(rng_key), (
            "EnsembleSampler only supports chain_method='vectorized' with num_chains > 1.\n"
            "If you want to run chains in parallel, please raise a github issue."
        )
        assert rng_key.shape[0] % 2 == 0, "Number of chains must be even."

        self._num_chains = rng_key.shape[0]

        if self._potential_fn and init_params is None:
            raise ValueError(
                "Valid value of `init_params` must be provided with `potential_fn`."
            )
        if init_params is not None:
            assert all(
                [
                    param.shape[0] == self._num_chains
                    for param in jax.tree_util.tree_leaves(init_params)
                ]
            ), "The batch dimension of each param must match n_chains"

        rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3)
        rng_key_init_model = random.split(rng_key_init_model, self._num_chains)

        init_params = self._init_state(
            rng_key_init_model, model_args, model_kwargs, init_params
        )
        self._num_warmup = num_warmup

        return EnsembleSamplerState(
            init_params, self.init_inner_state(rng_key_inner_state), rng_key
        )

    def postprocess_fn(self, args, kwargs):
        if self._postprocess_fn is None:
            return identity
        return self._postprocess_fn(*args, **kwargs)

    def sample(self, state, model_args, model_kwargs):
        z, inner_state, rng_key = state
        rng_key, _ = random.split(rng_key)
        z_flat, unravel_fn = batch_ravel_pytree(z)

        if self._randomize_split:
            z_flat = random.permutation(rng_key, z_flat, axis=0)

        split_ind = self._num_chains // 2

        def body_fn(i, z_flat_inner_state):
            z_flat, inner_state = z_flat_inner_state
            # each sub-iteration updates one half of the ensemble conditioned
            # on the complementary half
            active, inactive = jax.lax.cond(
                i == 0,
                lambda x: (x[:split_ind], x[split_ind:]),
                # BUG FIX: the complementary set for the second half is the
                # *first* half; the original passed the second half twice,
                # so the second sub-iteration conditioned on itself.
                lambda x: (x[split_ind:], x[:split_ind]),
                z_flat,
            )

            z_updates, inner_state = self.update_active_chains(
                active, inactive, inner_state
            )
            z_flat = jax.lax.cond(
                i == 0,
                lambda x: x.at[:split_ind].set(z_updates),
                lambda x: x.at[split_ind:].set(z_updates),
                z_flat,
            )
            return (z_flat, inner_state)

        z_flat, inner_state = jax.lax.fori_loop(0, 2, body_fn, (z_flat, inner_state))

        return EnsembleSamplerState(unravel_fn(z_flat), inner_state, rng_key)
class AIES(EnsembleSampler):
    """
    Affine-Invariant Ensemble Sampling: a gradient free method that informs Metropolis-Hastings
    proposals by sharing information between chains. Suitable for low to moderate dimensional models.
    Generally, `num_chains` should be at least twice the dimensionality of the model.

    .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized"`
        in :class:`MCMC`. The number of chains must be divisible by 2.

    **References:**

    1. *emcee: The MCMC Hammer* (https://iopscience.iop.org/article/10.1086/670067),
       Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman.

    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
        If model is provided, `potential_fn` will be inferred using the model.
    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        :meth:`init` has the same type.
    :param bool randomize_split: whether or not to permute the chain order at each iteration.
        Defaults to False.
    :param moves: a dictionary mapping moves to their respective probabilities of being selected.
        Valid keys are `AIES.DEMove()` and `AIES.StretchMove()`. Both tend to work well in practice.
        If the sum of probabilites exceeds 1, the probabilities will be normalized. Defaults to `{AIES.DEMove(): 1.0}`.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.

    **Example**

    .. doctest::

        >>> import jax
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import MCMC, AIES

        >>> def model():
        ...    x = numpyro.sample("x", dist.Normal().expand([10]))
        ...    numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
        >>>
        >>> kernel = AIES(model, moves={AIES.DEMove() : 0.5,
        ...                             AIES.StretchMove() : 0.5})
        >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized')
        >>> mcmc.run(jax.random.PRNGKey(0))
    """

    def __init__(
        self,
        model=None,
        potential_fn=None,
        randomize_split=False,
        moves=None,
        init_strategy=init_to_uniform,
    ):
        if not moves:
            # default: pure differential-evolution proposals
            self._moves = [AIES.DEMove()]
            self._weights = jnp.array([1.0])
        else:
            self._moves = list(moves.keys())
            self._weights = jnp.array([weight for weight in moves.values()]) / len(
                moves
            )

            assert all([hasattr(move, "__call__") for move in self._moves]), (
                "Each move must be a callable (one of AIES.DEMove(), or AIES.StretchMove())."
            )
            assert jnp.all(self._weights >= 0), (
                "Each specified move must have probability >= 0"
            )

        super().__init__(
            model,
            potential_fn,
            randomize_split=randomize_split,
            init_strategy=init_strategy,
        )

    def get_diagnostics_str(self, state):
        return "acc. prob={:.2f}".format(state.inner_state.mean_accept_prob)

    def init_inner_state(self, rng_key):
        # XXX hack -- we don't know num_chains until we init the inner state
        self._moves = [
            move(self._num_chains) if move.__name__ == "make_de_move" else move
            for move in self._moves
        ]

        return AIESState(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), rng_key)

    def update_active_chains(self, active, inactive, inner_state):
        i, _, mean_accept_prob, rng_key = inner_state
        rng_key, move_key, proposal_key, accept_key = random.split(rng_key, 4)

        move_i = random.choice(move_key, len(self._moves), p=self._weights)
        proposal, factors = jax.lax.switch(
            move_i, self._moves, proposal_key, active, inactive
        )

        # --- evaluate the proposal ---
        log_accept_prob = (
            factors
            + self._batch_log_density(proposal)
            - self._batch_log_density(active)
        )

        accepted = dist.Uniform().sample(accept_key, (active.shape[0],)) < jnp.exp(
            log_accept_prob
        )
        updated_active_chains = jnp.where(accepted[:, jnp.newaxis], proposal, active)

        accept_prob = jnp.count_nonzero(accepted) / accepted.shape[0]
        # each call handles half an ensemble, so a full MCMC iteration is two steps
        itr = i + 0.5
        n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
        # running mean of the acceptance probability (for diagnostics)
        mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n

        return updated_active_chains, AIESState(
            itr, accept_prob, mean_accept_prob, rng_key
        )

    @staticmethod
    def DEMove(sigma=1.0e-5, g0=None):
        """
        A proposal using differential evolution.

        This `Differential evolution proposal
        <http://www.stat.columbia.edu/~gelman/stuff_for_blog/cajo.pdf>`_ is
        implemented following `Nelson et al. (2013)
        <https://doi.org/10.1088/0067-0049/210/1/11>`_.

        :param sigma: (optional)
            The standard deviation of the Gaussian used to stretch the proposal vector.
            Defaults to `1.0.e-5`.
        :param g0 (optional):
            The mean stretch factor for the proposal vector. By default,
            it is `2.38 / sqrt(2*ndim)` as recommended by the two references.
        """

        def make_de_move(n_chains):
            PAIRS = get_nondiagonal_indices(n_chains // 2)

            def de_move(rng_key, active, inactive):
                pairs_key, gamma_key = random.split(rng_key)
                n_active_chains, n_params = inactive.shape

                # XXX: if we pass in n_params to parent scope we don't need to
                # recompute this each time
                g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0

                selected_pairs = random.choice(pairs_key, PAIRS, shape=(n_active_chains,))

                # Compute diff vectors
                diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(axis=1)

                # Sample a gamma value for each walker following Nelson et al. (2013)
                gamma = dist.Normal(g, g * sigma).sample(
                    gamma_key, sample_shape=(n_active_chains, 1)
                )

                # In this way, sigma is the standard deviation of the distribution of gamma,
                # instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006).
                # Otherwise, sigma should be tuned for each dimension, which confronts the idea of affine-invariance.
                proposal = active + gamma * diffs

                return proposal, jnp.zeros(n_active_chains)

            return de_move

        return make_de_move

    @staticmethod
    def StretchMove(a=2.0):
        """
        A `Goodman & Weare (2010)
        <https://msp.org/camcos/2010/5-1/p04.xhtml>`_ "stretch move" with
        parallelization as described in `Foreman-Mackey et al. (2013)
        <https://arxiv.org/abs/1202.3665>`_.

        :param a: (optional)
            The stretch scale parameter. (default: ``2.0``)
        """

        def stretch_move(rng_key, active, inactive):
            n_active_chains, n_params = active.shape
            unif_key, idx_key = random.split(rng_key)

            # inverse-CDF sample of the stretch factor z ~ g(z) on [1/a, a]
            zz = (
                (a - 1.0) * random.uniform(unif_key, shape=(n_active_chains,)) + 1
            ) ** 2.0 / a
            factors = (n_params - 1.0) * jnp.log(zz)
            r_idxs = random.randint(
                idx_key, shape=(n_active_chains,), minval=0, maxval=n_active_chains
            )

            proposal = inactive[r_idxs] - (inactive[r_idxs] - active) * zz[:, jnp.newaxis]

            return proposal, factors

        return stretch_move
class ESS(EnsembleSampler):
    """
    Ensemble Slice Sampling: a gradient free method that finds better slice sampling directions
    by sharing information between chains. Suitable for low to moderate dimensional models.
    Generally, `num_chains` should be at least twice the dimensionality of the model.

    .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized"`
        in :class:`MCMC`. The number of chains must be divisible by 2.

    **References:**

    1. *zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference* (https://academic.oup.com/mnras/article/508/3/3589/6381726),
       Minas Karamanis, Florian Beutler, and John A. Peacock.
    2. *Ensemble slice sampling* (https://link.springer.com/article/10.1007/s11222-021-10038-2),
       Minas Karamanis, Florian Beutler.

    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
        If model is provided, `potential_fn` will be inferred using the model.
    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type, provided that `init_params` argument to
        :meth:`init` has the same type.
    :param bool randomize_split: whether or not to permute the chain order at each iteration.
        Defaults to True.
    :param moves: a dictionary mapping moves to their respective probabilities of being selected.
        If the sum of probabilites exceeds 1, the probabilities will be normalized. Valid keys include:
        `ESS.DifferentialMove()` -> default proposal, works well along a wide range of target distributions,
        `ESS.GaussianMove()` -> for approximately normally distributed targets,
        `ESS.KDEMove()` -> for multimodal posteriors - requires large `num_chains`, and they must be well initialized
        `ESS.RandomMove()` -> no chain interaction, useful for debugging.
        Defaults to `{ESS.DifferentialMove(): 1.0}`.
    :param int max_steps: number of maximum stepping-out steps per sample. Defaults to 10,000.
    :param int max_iter: number of maximum expansions/contractions per sample. Defaults to 10,000.
    :param float init_mu: initial scale factor. Defaults to 1.0.
    :param bool tune_mu: whether or not to tune the initial scale factor. Defaults to True.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.

    **Example**

    .. doctest::

        >>> import jax
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import MCMC, ESS

        >>> def model():
        ...    x = numpyro.sample("x", dist.Normal().expand([10]))
        ...    numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
        >>>
        >>> kernel = ESS(model, moves={ESS.DifferentialMove() : 0.8,
        ...                            ESS.RandomMove() : 0.2})
        >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized')
        >>> mcmc.run(jax.random.PRNGKey(0))
    """

    def __init__(
        self,
        model=None,
        potential_fn=None,
        randomize_split=True,
        moves=None,
        max_steps=10_000,
        max_iter=10_000,
        init_mu=1.0,
        tune_mu=True,
        init_strategy=init_to_uniform,
    ):
        if not moves:
            # default: differential move works well across a wide range of targets
            self._moves = [ESS.DifferentialMove()]
            self._weights = jnp.array([1.0])
        else:
            self._moves = list(moves.keys())
            self._weights = jnp.array([weight for weight in moves.values()]) / len(
                moves
            )

            assert all([hasattr(move, "__call__") for move in self._moves]), (
                "Each move must be a callable (one of `ESS.DifferentialMove()`, "
                "`ESS.GaussianMove()`, `ESS.KDEMove()`, `ESS.RandomMove()`)"
            )
            assert jnp.all(self._weights >= 0), (
                "Each specified move must have probability >= 0"
            )

        assert init_mu > 0, "Scale factor should be strictly positive"

        self._max_steps = max_steps  # max number of stepping out steps
        self._max_iter = max_iter  # max number of expansions/contractions
        self._init_mu = init_mu
        self._tune_mu = tune_mu

        super().__init__(
            model,
            potential_fn,
            randomize_split=randomize_split,
            init_strategy=init_strategy,
        )

    def init_inner_state(self, rng_key):
        # keep a trailing singleton axis so slice heights broadcast against
        # the (n_chains, 1) interval arrays below
        self.batch_log_density = lambda x: self._batch_log_density(x)[:, jnp.newaxis]

        # XXX hack -- we don't know num_chains until we init the inner state
        self._moves = [
            move(self._num_chains) if move.__name__ == "make_differential_move" else move
            for move in self._moves
        ]

        return ESSState(jnp.array(0.0), jnp.array(0), jnp.array(0), self._init_mu, rng_key)

    def update_active_chains(self, active, inactive, inner_state):
        i, n_expansions, n_contractions, mu, rng_key = inner_state
        (
            rng_key,
            move_key,
            dir_key,
            height_key,
            step_out_key,
            shrink_key,
        ) = random.split(rng_key, 6)

        n_active_chains, n_params = active.shape

        move_i = random.choice(move_key, len(self._moves), p=self._weights)
        directions = jax.lax.switch(move_i, self._moves, dir_key, inactive, mu)

        # draw the slice height in log space: log y = log f(x) - Exp(1)
        log_slice_height = self.batch_log_density(active) - dist.Exponential().sample(
            height_key, sample_shape=(n_active_chains, 1)
        )

        curr_n_expansions, L, R = self._step_out(
            step_out_key, log_slice_height, active, directions
        )
        proposal, curr_n_contractions = self._shrink(
            shrink_key, log_slice_height, L, R, active, directions
        )

        n_expansions += curr_n_expansions
        n_contractions += curr_n_contractions
        # each call handles half an ensemble, so a full MCMC iteration is two steps
        itr = i + 0.5

        if self._tune_mu:
            safe_n_expansions = jnp.max(jnp.array([1, n_expansions]))

            # only update tuning scale if a full iteration has passed
            mu, n_expansions, n_contractions = jax.lax.cond(
                jnp.all(itr % 1 == 0),
                lambda n_exp, n_con: (
                    2.0 * n_exp / (n_exp + n_con),
                    jnp.array(0),
                    jnp.array(0),
                ),
                lambda _, __: (mu, n_expansions, n_contractions),
                safe_n_expansions,
                n_contractions,
            )

        return proposal, ESSState(itr, n_expansions, n_contractions, mu, rng_key)

    @staticmethod
    def RandomMove():
        """
        The `Karamanis & Beutler (2020) <https://arxiv.org/abs/2002.06212>`_ "Random Move" with parallelization.
        When this move is used the walkers move along random directions. There is no communication between the
        walkers and this Move corresponds to the vanilla Slice Sampling method. This Move should be used for
        debugging purposes only.
        """

        def random_move(rng_key, inactive, mu):
            directions = dist.Normal(loc=0, scale=1).sample(
                rng_key, sample_shape=inactive.shape
            )
            directions /= jnp.linalg.norm(directions, axis=0)

            return 2.0 * mu * directions

        return random_move

    @staticmethod
    def KDEMove(bw_method=None):
        """
        The `Karamanis & Beutler (2020) <https://arxiv.org/abs/2002.06212>`_ "KDE Move" with parallelization.
        When this Move is used the distribution of the walkers of the complementary ensemble is traced using
        a Gaussian Kernel Density Estimation methods. The walkers then move along random direction vectos
        sampled from this distribution.
        """

        def kde_move(rng_key, inactive, mu):
            n_active_chains, n_params = inactive.shape

            kde = gaussian_kde(inactive.T, bw_method=bw_method)
            vectors = kde.resample(rng_key, (2 * n_active_chains,)).T
            # direction = difference of two independent KDE samples
            directions = vectors[:n_active_chains] - vectors[n_active_chains:]

            return 2.0 * mu * directions

        return kde_move

    @staticmethod
    def GaussianMove():
        """
        The `Karamanis & Beutler (2020) <https://arxiv.org/abs/2002.06212>`_ "Gaussian Move" with parallelization.
        When this Move is used the walkers move along directions defined by random vectors sampled from the Gaussian
        approximation of the walkers of the complementary ensemble.
        """

        # In high dimensional regimes with sufficiently small n_active_chains,
        # it is more efficient to sample without computing the Cholesky
        # decomposition of the covariance matrix:
        #
        # eps = dist.Normal(0, 1).sample(rng_key, (n_active_chains, n_active_chains))
        # return 2.0 * mu * (eps @ (inactive - jnp.mean(inactive, axis=0)) / jnp.sqrt(n_active_chains))

        def gaussian_move(rng_key, inactive, mu):
            n_active_chains, n_params = inactive.shape
            cov = jnp.cov(inactive, rowvar=False)

            return (
                2.0
                * mu
                * dist.MultivariateNormal(0, cov).sample(
                    rng_key, sample_shape=(n_active_chains,)
                )
            )

        return gaussian_move

    @staticmethod
    def DifferentialMove():
        """
        The `Karamanis & Beutler (2020) <https://arxiv.org/abs/2002.06212>`_ "Differential Move" with parallelization.
        When this Move is used the walkers move along directions defined by random pairs of walkers sampled (with no
        replacement) from the complementary ensemble. This is the default choice and performs well along a wide range
        of target distributions.
        """

        def make_differential_move(n_chains):
            PAIRS = get_nondiagonal_indices(n_chains // 2)

            def differential_move(rng_key, inactive, mu):
                n_active_chains, n_params = inactive.shape

                selected_pairs = random.choice(rng_key, PAIRS, shape=(n_active_chains,))
                diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(
                    axis=1
                )  # get the pairwise difference of each vector

                return 2.0 * mu * diffs

            return differential_move

        return make_differential_move

    def _step_out(self, rng_key, log_slice_height, active, directions):
        """Stepping-out procedure (Karamanis & Beutler 2020, Alg. 2): expand the
        interval [L, R] along each chain's direction until both ends fall
        outside the slice, or the step/iteration budget is exhausted."""
        init_L_key, init_J_key = random.split(rng_key)
        n_active_chains, n_params = active.shape

        iteration = 0
        n_expansions = 0
        # set initial interval boundaries
        L = -dist.Uniform().sample(init_L_key, sample_shape=(n_active_chains, 1))
        R = L + 1.0

        # stepping out
        J = jnp.floor(
            dist.Uniform(low=0, high=self._max_steps).sample(
                init_J_key, sample_shape=(n_active_chains, 1)
            )
        )
        K = (self._max_steps - 1) - J

        # left stepping-out initialisation
        mask_J = jnp.full((n_active_chains, 1), True)
        # right stepping-out initialisation
        mask_K = jnp.full((n_active_chains, 1), True)

        init_values = (n_expansions, L, R, J, K, mask_J, mask_K, iteration)

        def cond_fn(args):
            n_expansions, L, R, J, K, mask_J, mask_K, iteration = args

            return (jnp.count_nonzero(mask_J) + jnp.count_nonzero(mask_K) > 0) & (
                iteration < self._max_iter
            )

        def body_fn(args):
            n_expansions, L, R, J, K, mask_J, mask_K, iteration = args

            log_prob_L = self.batch_log_density(directions * L + active)
            log_prob_R = self.batch_log_density(directions * R + active)

            # expand left while the left edge is still inside the slice
            can_expand_L = log_prob_L > log_slice_height
            L = jnp.where(can_expand_L, L - 1, L)
            J = jnp.where(can_expand_L, J - 1, J)
            mask_J = jnp.where(can_expand_L, mask_J, False)

            # expand right while the right edge is still inside the slice
            can_expand_R = log_prob_R > log_slice_height
            R = jnp.where(can_expand_R, R + 1, R)
            K = jnp.where(can_expand_R, K - 1, K)
            mask_K = jnp.where(can_expand_R, mask_K, False)

            iteration += 1
            n_expansions += jnp.count_nonzero(can_expand_L) + jnp.count_nonzero(
                can_expand_R
            )

            return (n_expansions, L, R, J, K, mask_J, mask_K, iteration)

        n_expansions, L, R, J, K, mask_J, mask_K, iteration = jax.lax.while_loop(
            cond_fn, body_fn, init_values
        )

        return n_expansions, L, R

    def _shrink(self, rng_key, log_slice_height, L, R, active, directions):
        """Shrinking procedure (Karamanis & Beutler 2020, Alg. 3): sample
        within [L, R], contracting the interval around rejected points until
        every chain has a point inside the slice."""
        n_active_chains, n_params = active.shape

        iteration = 0
        n_contractions = 0
        widths = jnp.zeros((n_active_chains, 1))
        proposed = jnp.zeros((n_active_chains, n_params))
        can_shrink = jnp.full((n_active_chains, 1), True)

        init_values = (
            rng_key,
            proposed,
            n_contractions,
            L,
            R,
            widths,
            can_shrink,
            iteration,
        )

        def cond_fn(args):
            (
                rng_key,
                proposed,
                n_contractions,
                L,
                R,
                widths,
                can_shrink,
                iteration,
            ) = args

            return (jnp.count_nonzero(can_shrink) > 0) & (iteration < self._max_iter)

        def body_fn(args):
            (
                rng_key,
                proposed,
                n_contractions,
                L,
                R,
                widths,
                can_shrink,
                iteration,
            ) = args

            rng_key, _ = random.split(rng_key)

            widths = jnp.where(
                can_shrink, dist.Uniform(low=L, high=R).sample(rng_key), widths
            )

            # compute new positions
            proposed = jnp.where(can_shrink, directions * widths + active, proposed)
            proposed_log_prob = self.batch_log_density(proposed)

            # shrink slices
            can_shrink = proposed_log_prob < log_slice_height

            L_cond = can_shrink & (widths < 0.0)
            L = jnp.where(L_cond, widths, L)

            R_cond = can_shrink & (widths > 0.0)
            R = jnp.where(R_cond, widths, R)

            iteration += 1
            n_contractions += jnp.count_nonzero(L_cond) + jnp.count_nonzero(R_cond)

            return (
                rng_key,
                proposed,
                n_contractions,
                L,
                R,
                widths,
                can_shrink,
                iteration,
            )

        (
            rng_key,
            proposed,
            n_contractions,
            L,
            R,
            widths,
            can_shrink,
            iteration,
        ) = jax.lax.while_loop(cond_fn, body_fn, init_values)

        return proposed, n_contractions