# The implementation follows the design in PyTorch: torch.distributions.distribution.py
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
# Copyright (c) 2011-2013 NYU (Clement Farabet)
# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import jax.numpy as np
from numpyro.distributions.constraints import Transform, is_dependent
from numpyro.distributions.util import sum_rightmost
class Distribution(object):
    """
    Base class for probability distributions in NumPyro. The design largely
    follows from :mod:`torch.distributions`.

    :param batch_shape: The batch shape for the distribution. This designates
        independent (possibly non-identical) dimensions of a sample from the
        distribution. This is fixed for a distribution instance and is inferred
        from the shape of the distribution parameters.
    :param event_shape: The event shape for the distribution. This designates
        the dependent dimensions of a sample from the distribution. These are
        collapsed when we evaluate the log probability density of a batch of
        samples using `.log_prob`.
    :param validate_args: Whether to enable validation of distribution
        parameters and arguments to `.log_prob` method. When left as `None`,
        the class-level default `_validate_args` is used.

    As an example:

    .. testsetup::

       import jax.numpy as np
       import numpyro.distributions as dist

    .. doctest::

       >>> d = dist.Dirichlet(np.ones((2, 3, 4)))
       >>> d.batch_shape
       (2, 3)
       >>> d.event_shape
       (4,)
    """
    # Maps a parameter name to the constraint checked in `__init__` when
    # validation is enabled.
    arg_constraints = {}
    # Constraint callable applied to values in `_validate_sample`.
    support = None
    # NOTE(review): shared class-level list; not read anywhere in this class,
    # presumably overridden by subclasses -- confirm before mutating in place.
    reparametrized_params = []
    # Class-wide default for argument validation; `__init__` shadows it on the
    # instance when `validate_args` is given explicitly.
    _validate_args = False

    def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
        """
        Store the shapes and, if validation is enabled, check the parameters.

        Each parameter named in `arg_constraints` is read from the instance
        with `getattr`, so it must already be set as an attribute before this
        initializer runs.

        :param tuple batch_shape: batch shape of the distribution.
        :param tuple event_shape: event shape of the distribution.
        :param validate_args: overrides the class-level `_validate_args` when
            not `None`.
        :raises ValueError: if validation is enabled and a parameter does not
            satisfy its constraint.
        """
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            self._validate_args = validate_args
        if self._validate_args:
            for param, constraint in self.arg_constraints.items():
                if is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if not np.all(constraint(getattr(self, param))):
                    raise ValueError("The parameter {} has invalid values".format(param))
        super(Distribution, self).__init__()

    @property
    def batch_shape(self):
        """
        Returns the shape over which the distribution parameters are batched.

        :return: batch shape of the distribution.
        :rtype: tuple
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """
        Returns the shape of a single sample from the distribution without
        batching.

        :return: event shape of the distribution.
        :rtype: tuple
        """
        return self._event_shape

    def sample(self, key, sample_shape=()):
        """
        Returns a sample from the distribution having shape given by
        `sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
        leading dimensions (of size `sample_shape`) of the returned sample will
        be filled with iid draws from the distribution instance.

        :param jax.random.PRNGKey key: the rng key to be used for the distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: a `numpy.ndarray` of shape `sample_shape + batch_shape + event_shape`
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def log_prob(self, value):
        """
        Evaluates the log probability density for a batch of samples given by
        `value`.

        :param value: A batch of samples from the distribution.
        :return: a `numpy.ndarray` whose shape is the shape of `value` with
            the trailing `event_shape` dimensions removed.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def mean(self):
        """
        Mean of the distribution.
        """
        raise NotImplementedError

    @property
    def variance(self):
        """
        Variance of the distribution.
        """
        raise NotImplementedError

    def _validate_sample(self, value):
        """
        Check that every entry of `value` satisfies the `support` constraint.

        :param value: values to check against `self.support`.
        :raises ValueError: if any entry of `value` fails the check.
        """
        if not np.all(self.support(value)):
            raise ValueError('Invalid values provided to log prob method. '
                             'The value argument must be within the support.')

    def __call__(self, *args, **kwargs):
        """
        Draw a sample by calling the instance directly.

        The rng key is taken from the required keyword argument
        `random_state`; all remaining positional and keyword arguments are
        forwarded to :meth:`sample` after the key.

        :return: whatever :meth:`sample` returns.
        :raises KeyError: if `random_state` is not supplied.
        """
        key = kwargs.pop('random_state')
        return self.sample(key, *args, **kwargs)