# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from abc import ABCMeta
from collections import namedtuple
import inspect
import jax
from jax import random, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from numpyro.infer import init_to_uniform
from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import initialize_model
from numpyro.util import identity, is_prng_key
# State carried between TFP transitions: the current parameter pytree `z`,
# the TFP kernel's own `kernel_results`, and the PRNG key for the next step.
TFPKernelState = namedtuple("TFPKernelState", "z kernel_results rng_key")
def _extract_kernel_functions(kernel):
    """Adapt a TFP transition kernel to NumPyro's ``(init_fn, sample_fn)`` pair.

    Parameters are flattened with ``ravel_pytree`` before being handed to the
    TFP kernel, and unflattened again after each transition.

    :param kernel: a TFP kernel exposing ``bootstrap_results`` and ``one_step``.
    :return: a tuple ``(init_fn, sample_fn)`` operating on ``TFPKernelState``.
    """

    def init_fn(z, rng_key):
        flat_z = ravel_pytree(z)[0]
        return TFPKernelState(z, kernel.bootstrap_results(flat_z), rng_key)

    def sample_fn(state, model_args=(), model_kwargs=None):
        # `model_args` / `model_kwargs` are accepted for interface
        # compatibility; the kernel's log density is already fixed.
        next_key, step_key = random.split(state.rng_key)
        flat_z, unravel = ravel_pytree(state.z)
        new_flat_z, new_results = kernel.one_step(
            flat_z, state.kernel_results, seed=step_key
        )
        return TFPKernelState(unravel(new_flat_z), new_results, next_key)

    return init_fn, sample_fn
def _make_log_prob_fn(potential_fn, unravel_fn):
def log_prob_fn(x):
# we deal with batched x in case the kernel is ReplicaExchangeMC
batch_shape = jnp.shape(x)[:-1]
if batch_shape:
flatten_result = vmap(lambda a: -potential_fn(unravel_fn(a)))(
jnp.reshape(x, (-1,) + jnp.shape(x)[-1:])
)
return jax.tree.map(
lambda a: jnp.reshape(a, batch_shape + jnp.shape(a)[1:]), flatten_result
)
else:
return -potential_fn(unravel_fn(x))
return log_prob_fn
class _TFPKernelMeta(ABCMeta):
    """Metaclass that makes ``TFPKernel[SomeTFPKernel]`` build a wrapper subclass."""

    def __getitem__(cls, kernel_class):
        # Only TFP transition kernels whose constructor takes
        # `target_log_prob_fn` can be wrapped.
        assert issubclass(kernel_class, tfp.mcmc.TransitionKernel)
        assert (
            "target_log_prob_fn" in inspect.getfullargspec(kernel_class).args
        ), f"the first argument of {kernel_class} must be `target_log_prob_fn`"
        # A fresh TFPKernel subclass named after the TFP kernel, with the
        # wrapped class bound as `kernel_class`.
        return type(
            kernel_class.__name__, (TFPKernel,), {"kernel_class": kernel_class}
        )
class TFPKernel(MCMCKernel, metaclass=_TFPKernelMeta):
    """
    A thin wrapper for TensorFlow Probability (TFP) MCMC transition kernels.
    The argument `target_log_prob_fn` in TFP is replaced by either `model`
    or `potential_fn` (which is the negative of `target_log_prob_fn`).

    This class can be used to convert a TFP kernel to a NumPyro-compatible one
    as follows::

        from numpyro.contrib.tfp.mcmc import TFPKernel

        kernel = TFPKernel[tfp.mcmc.NoUTurnSampler](model, step_size=1.)

    .. note:: By default, uncalibrated kernels will be inner kernels of the
        :class:`~tensorflow_probability.substrates.jax.mcmc.MetropolisHastings` kernel.

    .. note:: For :class:`~numpyro.contrib.tfp.mcmc.ReplicaExchangeMC`, TFP requires
        that the shape of `step_size` of the inner kernel must be
        `[len(inverse_temperatures), 1]` or `[len(inverse_temperatures), latent_size]`.

    .. note:: The TFP kernel is constructed on the first call to :meth:`init` and
        reused by later calls, so its target log density is the one built from the
        `model_args` / `model_kwargs` passed to that first call.

    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
        If model is provided, `potential_fn` will be inferred using the model.
    :param potential_fn: Python callable that computes the target potential energy
        given input parameters. The input parameters to `potential_fn`
        can be any python collection type, provided that `init_params` argument to
        :meth:`init` has the same type.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param kernel_kwargs: other arguments to be passed to TFP kernel constructor.
    """

    kernel_class = None

    def __init__(
        self,
        model=None,
        potential_fn=None,
        init_strategy=init_to_uniform,
        **kernel_kwargs,
    ):
        if not (model is None) ^ (potential_fn is None):
            raise ValueError("Only one of `model` or `potential_fn` must be specified.")
        self._model = model
        self._potential_fn = potential_fn
        self._kernel_kwargs = kernel_kwargs
        self._init_strategy = init_strategy
        # Set on first call to init
        self._init_fn = None
        self._postprocess_fn = None
        self._sample_fn = None
        # Unbatched sample function; `init` derives `_sample_fn` from it on every
        # call so that repeated vectorized inits do not stack `vmap`s.
        self._base_sample_fn = None

    def _build_kernel(self, log_prob_fn):
        """Instantiate `kernel_class` around `log_prob_fn` and cache its functions."""
        kernel = self.kernel_class(log_prob_fn, **self._kernel_kwargs)
        # Uncalibrated... kernels have to be used inside MetropolisHastings, see
        # https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/mcmc/UncalibratedLangevin
        if self.kernel_class.__name__.startswith("Uncalibrated"):
            kernel = tfp.mcmc.MetropolisHastings(kernel)
        self._init_fn, self._base_sample_fn = _extract_kernel_functions(kernel)
        self._sample_fn = self._base_sample_fn

    @staticmethod
    def _unravel_fn_for(params, batched):
        """Return the function that unflattens ONE chain's flat parameter vector.

        When `batched` is True every leaf of `params` carries a leading chain
        axis (it is later mapped over with `vmap` in :meth:`init`), so the
        unravel function is derived from the first chain only.
        """
        if batched:
            params = jax.tree.map(lambda leaf: leaf[0], params)
        return ravel_pytree(params)[1]

    def _init_state(self, rng_key, model_args, model_kwargs, init_params):
        """Compute initial parameters and, on first use, build the TFP kernel.

        :param rng_key: a single PRNG key, or a batch of keys (vectorized chains).
        :return: the initial parameters (drawn from the model when one was given,
            otherwise the `init_params` passed in).
        """
        batched = not is_prng_key(rng_key)
        if self._model is not None:
            init_params, potential_fn, postprocess_fn, _ = initialize_model(
                rng_key,
                self._model,
                init_strategy=self._init_strategy,
                dynamic_args=True,
                model_args=model_args,
                model_kwargs=model_kwargs,
            )
            init_params = init_params.z
            if self._init_fn is None:
                unravel_fn = self._unravel_fn_for(init_params, batched)
                # With `dynamic_args=True`, `potential_fn` is a factory taking
                # the model arguments.
                self._build_kernel(
                    _make_log_prob_fn(
                        potential_fn(*model_args, **model_kwargs), unravel_fn
                    )
                )
                self._postprocess_fn = postprocess_fn
        elif self._init_fn is None:
            unravel_fn = self._unravel_fn_for(init_params, batched)
            self._build_kernel(_make_log_prob_fn(self._potential_fn, unravel_fn))
        return init_params

    @property
    def model(self):
        return self._model

    @property
    def sample_field(self):
        return "z"

    @property
    def default_fields(self):
        return ("z",)

    def get_diagnostics_str(self, state):
        """
        Given the current `state`, returns the diagnostics string to
        be added to progress bar for diagnostics purpose.
        """
        return ""

    def init(
        self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs=None
    ):
        """Initialize the sampler state.

        :param rng_key: a single PRNG key, or a batch of keys to run chains
            vectorized.
        :param num_warmup: number of warmup steps (unused by this wrapper).
        :param init_params: initial parameters; required when `potential_fn`
            was given.
        :param model_args: positional arguments for the model.
        :param model_kwargs: keyword arguments for the model (defaults to `{}`).
        :return: the initial :data:`TFPKernelState` (batched when vectorized).
        :raises ValueError: if `potential_fn` was given but `init_params` is None.
        """
        if model_kwargs is None:
            model_kwargs = {}
        # Validate before building anything: otherwise a kernel whose unravel
        # function was derived from `None` would be cached and reused on retry.
        if self._potential_fn is not None and init_params is None:
            raise ValueError(
                "Valid value of `init_params` must be provided with"
                " `potential_fn`."
            )
        vectorized = not is_prng_key(rng_key)
        if vectorized:
            rng_key, rng_key_init_model = jnp.swapaxes(
                vmap(random.split)(rng_key), 0, 1
            )
        else:
            rng_key, rng_key_init_model = random.split(rng_key)
        init_params = self._init_state(
            rng_key_init_model, model_args, model_kwargs, init_params
        )
        if vectorized:
            # Map the per-chain functions over the leading chain axis; the
            # model args/kwargs are shared by all chains.
            init_state = vmap(self._init_fn)(init_params, rng_key)
            self._sample_fn = vmap(self._base_sample_fn, in_axes=(0, None, None))
        else:
            init_state = self._init_fn(init_params, rng_key)
            self._sample_fn = self._base_sample_fn
        return init_state

    def postprocess_fn(self, args, kwargs):
        if self._postprocess_fn is None:
            return identity
        return self._postprocess_fn(*args, **kwargs)

    def sample(self, state, model_args, model_kwargs):
        """
        Run the kernel from the given :data:`~numpyro.contrib.tfp.mcmc.TFPKernelState`
        and return the resulting :data:`~numpyro.contrib.tfp.mcmc.TFPKernelState`.

        :param TFPKernelState state: Represents the current state.
        :param model_args: Arguments provided to the model.
        :param model_kwargs: Keyword arguments provided to the model.
        :return: Next `state` after running the kernel.
        """
        return self._sample_fn(state, model_args, model_kwargs)
__all__ = ["TFPKernel"]

# Generate a NumPyro wrapper for every TFP transition kernel whose constructor
# accepts `target_log_prob_fn`, and export it from this module under the same name.
for _name, _Kernel in tfp.mcmc.__dict__.items():
    if not isinstance(_Kernel, type):
        continue
    if not issubclass(_Kernel, tfp.mcmc.TransitionKernel):
        continue
    if "target_log_prob_fn" not in inspect.getfullargspec(_Kernel).args:
        continue

    _PyroKernel = TFPKernel[_Kernel]
    _PyroKernel.__module__ = __name__
    # Bind the wrapper as a module attribute via `globals()`, the documented
    # namespace for this, rather than writing through `locals()`.
    globals()[_name] = _PyroKernel
    _PyroKernel.__doc__ = """
Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/mcmc/{}>`_
with :class:`~numpyro.contrib.tfp.mcmc.TFPKernel`. The first argument `target_log_prob_fn`
in TFP kernel construction is replaced by either `model` or `potential_fn`.
""".format(_Kernel.__module__, _Kernel.__name__, _Kernel.__name__)
    __all__.append(_name)
# Create sphinx documentation: one section per exported name, with the generic
# `TFPKernel` first and the generated wrappers following in alphabetical order.
__doc__ = "\n\n".join(
    """
{0}
----------------------------------------------------------------
.. autoclass:: numpyro.contrib.tfp.mcmc.{0}
""".format(_kernel_name)
    for _kernel_name in __all__[:1] + sorted(__all__[1:])
)