from jax import lax
import jax.numpy as np

from numpyro.distributions.constraints import real_vector
from numpyro.distributions.transforms import Transform
from numpyro.util import fori_loop

def _clamp_preserve_gradients(x, min, max):
    """
    Clip ``x`` to ``[min, max]`` in the forward pass while leaving its gradient untouched.

    The clipping offset ``clip(x) - x`` is wrapped in :func:`jax.lax.stop_gradient`, so the
    returned value equals ``clip(x, min, max)`` while its derivative with respect to ``x`` is the
    identity (straight-through), even where ``x`` lies outside the bounds.

    :param x: array to clip
    :param min: lower bound (parameter name kept because callers pass it by keyword)
    :param max: upper bound (parameter name kept because callers pass it by keyword)
    :return: the clipped array, with gradients passed straight through to ``x``
    """
    # Bounds are passed positionally: ``a_min``/``a_max`` are deprecated keyword names in newer
    # ``jax.numpy.clip`` (renamed to ``min``/``max``), whereas positional arguments work on both.
    clipped = np.clip(x, min, max)
    return x + lax.stop_gradient(clipped - x)

class InverseAutoregressiveTransform(Transform):
    """
    An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,

        :math:`\\mathbf{y} = \\mu_t + \\sigma_t\\odot\\mathbf{x}`

    where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs,
    :math:`\\mu_t,\\sigma_t` are calculated from an autoregressive network on :math:`\\mathbf{x}`,
    and :math:`\\sigma_t>0`.

    **References**

    1. *Improving Variational Inference with Inverse Autoregressive Flow* [arXiv:1606.04934],
       Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
    """
    domain = real_vector
    codomain = real_vector
    event_dim = 1

    def __init__(self, autoregressive_nn, log_scale_min_clip=-5., log_scale_max_clip=3.):
        """
        :param autoregressive_nn: an autoregressive neural network whose forward call returns a
            real-valued mean and log scale as a tuple
        :param float log_scale_min_clip: lower bound applied to the network's log scale
            (gradients pass straight through the clip)
        :param float log_scale_max_clip: upper bound applied to the network's log scale
            (gradients pass straight through the clip)
        """
        self.arn = autoregressive_nn
        self.log_scale_min_clip = log_scale_min_clip
        self.log_scale_max_clip = log_scale_max_clip

    def _mean_and_log_scale(self, x):
        """Run the autoregressive network on ``x`` and return ``(mean, clipped_log_scale)``."""
        mean, log_scale = self.arn(x)
        log_scale = _clamp_preserve_gradients(
            log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip)
        return mean, log_scale

    def __call__(self, x):
        """
        :param numpy.ndarray x: the input into the transform
        :return: the transformed output ``y = exp(log_scale) * x + mean``
        """
        return self.call_with_intermediates(x)[0]

    def call_with_intermediates(self, x):
        """
        Apply the forward transform and also return the clipped log scale, so that
        :meth:`log_abs_det_jacobian` can reuse it instead of re-running the network.

        :param numpy.ndarray x: the input into the transform
        :return: tuple ``(y, log_scale)``
        """
        mean, log_scale = self._mean_and_log_scale(x)
        return np.exp(log_scale) * x + mean, log_scale

    def inv(self, y):
        """
        :param numpy.ndarray y: the output of the transform to be inverted
        :return: the input ``x`` such that ``self(x) == y``
        """
        # NOTE: Inversion is an expensive operation that scales in the dimension of the input:
        # the network is re-evaluated once per entry of the last axis. Presumably each pass fixes
        # one more coordinate of ``x``, which holds if output i of ``self.arn`` depends only on
        # ``x[..., :i]`` (strict autoregressive ordering) -- not verifiable from this file.
        def _update_x(_i, x):
            mean, log_scale = self._mean_and_log_scale(x)
            return (y - mean) * np.exp(-log_scale)

        # ``zeros_like`` keeps the loop carry in ``y``'s dtype, matching what ``_update_x``
        # returns; ``np.zeros(y.shape)`` always used the default float dtype.
        return fori_loop(0, y.shape[-1], _update_x, np.zeros_like(y))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """
        Calculates the log absolute determinant of the Jacobian of the transform.

        The forward map scales each coordinate by ``exp(log_scale)``, so (for an autoregressive
        network, whose Jacobian is triangular) the result is the clipped log scale summed over
        the last (event) dimension.

        :param numpy.ndarray x: the input to the transform
        :param numpy.ndarray y: the output of the transform (unused)
        :param intermediates: the clipped log scale returned by :meth:`call_with_intermediates`;
            when ``None`` it is recomputed from ``x``
        :return: array with the last dimension of ``x`` summed out
        """
        if intermediates is None:
            _, log_scale = self._mean_and_log_scale(x)
        else:
            log_scale = intermediates
        return log_scale.sum(-1)