Source code for numpyro.distributions.flows
from jax import lax
import jax.numpy as np
from numpyro.distributions.constraints import real_vector
from numpyro.distributions.transforms import Transform
from numpyro.util import fori_loop
def _clamp_preserve_gradients(x, min, max):
return x + lax.stop_gradient(np.clip(x, a_min=min, a_max=max) - x)
# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py
class InverseAutoregressiveTransform(Transform):
    """
    An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,

        :math:`\\mathbf{y} = \\mu_t + \\sigma_t\\odot\\mathbf{x}`

    where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs, :math:`\\mu_t,\\sigma_t`
    are calculated from an autoregressive network on :math:`\\mathbf{x}`, and :math:`\\sigma_t>0`.

    The log scale produced by the network is clipped to
    ``[log_scale_min_clip, log_scale_max_clip]`` in its forward value (gradients
    pass through the clip unchanged), so :math:`\\sigma_t` stays bounded.

    **References**

    1. *Improving Variational Inference with Inverse Autoregressive Flow* [arXiv:1606.04934],
       Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
    """
    domain = real_vector
    codomain = real_vector
    event_dim = 1

    def __init__(self, autoregressive_nn, log_scale_min_clip=-5., log_scale_max_clip=3.):
        """
        :param autoregressive_nn: an autoregressive neural network whose forward call returns a real-valued
            mean and log scale as a tuple
        :param float log_scale_min_clip: lower bound applied to the network's log scale output
        :param float log_scale_max_clip: upper bound applied to the network's log scale output
        """
        self.arn = autoregressive_nn
        self.log_scale_min_clip = log_scale_min_clip
        self.log_scale_max_clip = log_scale_max_clip

    def _clipped_log_scale(self, log_scale):
        """Clip ``log_scale`` to the configured bounds while preserving its gradient."""
        return _clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)

    def __call__(self, x):
        """
        :param numpy.ndarray x: the input into the transform
        :return: the transformed output ``exp(log_scale) * x + mean``
        """
        return self.call_with_intermediates(x)[0]

    def call_with_intermediates(self, x):
        """
        Apply the forward transform and also return the clipped log scale, which
        :meth:`log_abs_det_jacobian` can accept as ``intermediates`` to avoid a
        second network evaluation.

        :param numpy.ndarray x: the input into the transform
        :return: a tuple ``(y, log_scale)`` where ``y = exp(log_scale) * x + mean``
            and ``log_scale`` is the clipped log scale output of the network
        """
        mean, log_scale = self.arn(x)
        log_scale = self._clipped_log_scale(log_scale)
        return np.exp(log_scale) * x + mean, log_scale

    def inv(self, y):
        """
        :param numpy.ndarray y: the output of the transform to be inverted
        :return: the recovered input ``x``
        """
        # NOTE: Inversion is an expensive operation that scales in the dimension of the input:
        # every pass re-evaluates the network on the current estimate of x. Assuming ``self.arn``
        # is autoregressive (output i depends only on inputs that precede i in some ordering),
        # each pass fixes at least one more component, so y.shape[-1] passes recover x exactly.
        def _update_x(_, x):
            mean, log_scale = self.arn(x)
            inverse_scale = np.exp(-self._clipped_log_scale(log_scale))
            return (y - mean) * inverse_scale

        # Seed with zeros of y's dtype (rather than the default float dtype) so the
        # loop carry's initial value has the same dtype as the data being inverted.
        return fori_loop(0, y.shape[-1], _update_x, np.zeros_like(y))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """
        Calculates the log absolute determinant of the Jacobian of the forward
        transform, i.e. the clipped log scale summed over the last (event) dimension.

        :param numpy.ndarray x: the input to the transform
        :param numpy.ndarray y: the output of the transform (unused; the result is
            computed from ``x`` or from ``intermediates``)
        :param intermediates: the clipped log scale returned as the second element of
            :meth:`call_with_intermediates`; if ``None`` it is recomputed from ``x``
        """
        if intermediates is None:
            log_scale = self._clipped_log_scale(self.arn(x)[1])
        else:
            log_scale = intermediates
        return log_scale.sum(-1)