# Source code for numpyro.distributions.constraints

# The implementation follows the design in PyTorch: torch.distributions.constraints.py
#
# Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
# Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU                      (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.


# Public names exported by ``from ... import *``. Every module-level
# constraint object bound at the bottom of this file should be listed here.
__all__ = [
    'boolean',
    'corr_cholesky',
    'corr_matrix',
    'dependent',
    'greater_than',
    'integer_interval',
    'integer_greater_than',
    'interval',
    'is_dependent',
    'lower_cholesky',
    'multinomial',
    'nonnegative_integer',
    'ordered_vector',
    'positive',
    'positive_definite',
    'positive_integer',
    'real',
    'real_vector',
    'simplex',
    'unit_interval',
    'Constraint',
]


import jax.numpy as np


class Constraint:
    """Base class for constraints.

    A constraint is a callable taking a value ``x``; subclasses override
    ``__call__`` to report whether ``x`` satisfies the constraint.
    """

    def __call__(self, x):
        """Check ``x`` against the constraint.

        The base class provides no check and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()


class _Boolean(Constraint):
    """Constraint satisfied where a value equals 0 or 1."""

    def __call__(self, x):
        """Return the elementwise result of ``x == 0`` or ``x == 1``."""
        is_zero = x == 0
        is_one = x == 1
        return is_zero | is_one


class _CorrCholesky(Constraint):
    """Validity check for Cholesky factors of correlation matrices.

    Over its last two axes, a value passes when it is lower triangular, has a
    strictly positive diagonal, and every row has unit Euclidean norm within
    an absolute tolerance of 1e-6.
    """

    def __call__(self, x):
        """Return a boolean result with the last two axes of ``x`` reduced.

        :param x: array whose last two axes hold the candidate factor.
        :return: boolean array of shape ``x.shape[:-2]``.
        """
        # Lower triangular: zeroing the upper triangle must leave x unchanged.
        tril = np.tril(x)
        lower_triangular = np.all(np.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
        positive_diagonal = np.all(np.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
        x_norm = np.linalg.norm(x, axis=-1)
        # The tolerance is two-sided: floating-point rounding (notably in
        # float32) can put the norm of a properly normalised row marginally
        # above 1, which a strict ``x_norm <= 1`` test would reject.
        unit_norm_row = np.all((x_norm < 1 + 1e-6) & (x_norm > 1 - 1e-6), axis=-1)
        return lower_triangular & positive_diagonal & unit_norm_row


class _CorrMatrix(Constraint):
    """Validity check for correlation matrices stored in the last two axes."""

    def __call__(self, x):
        """Return a boolean result with the last two axes of ``x`` reduced.

        A matrix passes when it equals its own transpose exactly, its first
        eigenvalue from ``eigh`` is positive, and each diagonal entry is
        within 1e-6 of 1.
        """
        # symmetry: compare against the transpose over the matrix axes
        transposed = np.swapaxes(x, -2, -1)
        rows_match = np.all(x == transposed, axis=-1)
        is_symmetric = np.all(rows_match, axis=-1)

        # positive definiteness: the smallest eigenvalue must be positive
        eigenvalues = np.linalg.eigh(x)[0]
        is_positive = eigenvalues[..., 0] > 0

        # unit variance: the diagonal must equal 1 up to a small tolerance
        diagonal = np.diagonal(x, axis1=-2, axis2=-1)
        has_unit_diagonal = np.all(np.abs(diagonal - 1) < 1e-6, axis=-1)

        return is_symmetric & is_positive & has_unit_diagonal


class _Dependent(Constraint):
    """Constraint whose validity cannot be checked by calling it."""

    def __call__(self, x):
        """Always raise ``ValueError``; no check is performed on ``x``."""
        message = 'Cannot determine validity of dependent constraint'
        raise ValueError(message)


def is_dependent(constraint):
    """Return ``True`` if ``constraint`` is an instance of ``_Dependent``."""
    dependent_type = _Dependent
    return isinstance(constraint, dependent_type)


class _GreaterThan(Constraint):
    """Constraint satisfied where a value is strictly above a lower bound."""

    def __init__(self, lower_bound):
        """Store ``lower_bound``, the exclusive lower limit."""
        self.lower_bound = lower_bound

    def __call__(self, x):
        """Return the elementwise result of ``x > lower_bound``."""
        exceeds_bound = x > self.lower_bound
        return exceeds_bound


class _IntegerInterval(Constraint):
    """Constraint satisfied by whole numbers in a closed interval."""

    def __init__(self, lower_bound, upper_bound):
        """Store the inclusive ``lower_bound`` and ``upper_bound``."""
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def __call__(self, x):
        """Return where ``x`` is integer-valued and within both bounds."""
        within_bounds = (x >= self.lower_bound) & (x <= self.upper_bound)
        # a value is whole exactly when flooring leaves it unchanged
        is_whole = x == np.floor(x)
        return within_bounds & is_whole


class _IntegerGreaterThan(Constraint):
    """Constraint satisfied by whole numbers at or above a lower bound."""

    def __init__(self, lower_bound):
        """Store ``lower_bound``, the inclusive lower limit."""
        self.lower_bound = lower_bound

    def __call__(self, x):
        """Return where ``x`` is integer-valued and ``x >= lower_bound``."""
        # a value is whole exactly when it has no remainder modulo 1
        is_whole = x % 1 == 0
        at_or_above = x >= self.lower_bound
        return is_whole & at_or_above


class _Interval(Constraint):
    """Constraint satisfied by values strictly inside an open interval."""

    def __init__(self, lower_bound, upper_bound):
        """Store the exclusive ``lower_bound`` and ``upper_bound``."""
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def __call__(self, x):
        """Return where ``lower_bound < x < upper_bound`` (both strict)."""
        above_lower = x > self.lower_bound
        below_upper = x < self.upper_bound
        return above_lower & below_upper


class _LowerCholesky(Constraint):
    """Validity check for lower-triangular matrices with positive diagonal."""

    def __call__(self, x):
        """Return a boolean result with the last two axes of ``x`` reduced."""
        # Flatten the two matrix axes into one so a single reduction tells
        # whether zeroing the upper triangle left every entry unchanged.
        flat_shape = x.shape[:-2] + (-1,)
        unchanged_by_tril = np.reshape(np.tril(x) == x, flat_shape)
        is_lower_triangular = np.all(unchanged_by_tril, axis=-1)

        diagonal = np.diagonal(x, axis1=-2, axis2=-1)
        has_positive_diagonal = np.all(diagonal > 0, axis=-1)

        return is_lower_triangular & has_positive_diagonal


class _Multinomial(Constraint):
    """Constraint on non-negative vectors whose entries sum to a fixed total."""

    def __init__(self, upper_bound):
        """Store ``upper_bound``, the total the last axis must sum to."""
        self.upper_bound = upper_bound

    def __call__(self, x):
        """Return a boolean result with the last axis of ``x`` reduced."""
        all_nonnegative = np.all(x >= 0, axis=-1)
        total = np.sum(x, -1)
        return all_nonnegative & (total == self.upper_bound)


class _OrderedVector(Constraint):
    """Constraint on vectors that are strictly increasing along the last axis."""

    def __call__(self, x):
        """Return where each entry exceeds the one before it (last axis)."""
        successors = x[..., 1:]
        predecessors = x[..., :-1]
        return np.all(successors > predecessors, axis=-1)


class _PositiveDefinite(Constraint):
    """Validity check for symmetric positive-definite matrices."""

    def __call__(self, x):
        """Return a boolean result with the last two axes of ``x`` reduced.

        A matrix passes when it equals its own transpose exactly and its
        first eigenvalue from ``eigh`` is positive.
        """
        # symmetry: compare against the transpose over the matrix axes
        transposed = np.swapaxes(x, -2, -1)
        rows_match = np.all(x == transposed, axis=-1)
        is_symmetric = np.all(rows_match, axis=-1)

        # positive definiteness: the smallest eigenvalue must be positive
        eigenvalues = np.linalg.eigh(x)[0]
        is_positive = eigenvalues[..., 0] > 0

        return is_symmetric & is_positive


class _Real(Constraint):
    """Constraint satisfied by finite values."""

    def __call__(self, x):
        """Return the elementwise result of ``np.isfinite(x)``."""
        is_finite = np.isfinite(x)
        return is_finite


class _RealVector(Constraint):
    """Constraint satisfied by vectors whose entries are all finite."""

    def __call__(self, x):
        """Return where every entry along the last axis of ``x`` is finite."""
        entries_finite = np.isfinite(x)
        return np.all(entries_finite, axis=-1)


class _Simplex(Constraint):
    """Validity check for points on the probability simplex.

    A value passes when every entry along the last axis is strictly positive
    and the entries sum to 1 within an absolute tolerance of 1e-6.
    """

    def __call__(self, x):
        """Return a boolean result with the last axis of ``x`` reduced.

        :param x: array whose last axis holds the candidate simplex vector.
        :return: boolean array of shape ``x.shape[:-1]``.
        """
        x_sum = np.sum(x, axis=-1)
        # The tolerance is two-sided: rounding in the sum (notably in
        # float32) can land marginally above 1 for a properly normalised
        # vector, which a strict ``x_sum <= 1`` test would reject.
        return np.all(x > 0, axis=-1) & (x_sum < 1 + 1e-6) & (x_sum > 1 - 1e-6)


# TODO: Make types consistent -- some names below are ready-to-use constraint
# instances, while others are classes that must first be called with bounds.

# Constraint instances that take no parameters.
boolean = _Boolean()
corr_cholesky = _CorrCholesky()
corr_matrix = _CorrMatrix()
dependent = _Dependent()
lower_cholesky = _LowerCholesky()
ordered_vector = _OrderedVector()
positive_definite = _PositiveDefinite()
real = _Real()
real_vector = _RealVector()
simplex = _Simplex()

# Constraint classes: call with the bound(s) to obtain an instance.
greater_than = _GreaterThan
integer_greater_than = _IntegerGreaterThan
integer_interval = _IntegerInterval
interval = _Interval
multinomial = _Multinomial

# Pre-built instances of the parameterised constraints.
nonnegative_integer = _IntegerGreaterThan(lower_bound=0)
positive = _GreaterThan(lower_bound=0.)
positive_integer = _IntegerGreaterThan(lower_bound=1)
unit_interval = _Interval(lower_bound=0., upper_bound=1.)