from collections import namedtuple
import functools
import jax
from jax import lax
import numpyro
from numpyro.distributions.discrete import PRNGIdentity
# Global stack of currently active effect handlers. Handlers push themselves
# on entry and pop themselves on exit, so the most recently entered one is last.
_PYRO_STACK = []

# Immutable record describing one enclosing plate as seen by a site:
# the plate's name, the (negative) batch dim it occupies, and its size.
CondIndepStackFrame = namedtuple('CondIndepStackFrame', ('name', 'dim', 'size'))
def apply_stack(msg):
    """
    Send a message through every handler on the global handler stack.

    Handlers are visited from the most recently entered one outwards and each
    gets a chance to modify ``msg`` in ``process_message``. If, afterwards,
    the message still has no value, the site's function is called to produce
    one. Finally ``postprocess_message`` is invoked on the handlers that saw
    the message.

    :param dict msg: message describing the site; mutated in place.
    :return: the same ``msg`` dict, after processing.
    """
    # Position (counted from the top of the stack) of the last handler that
    # processed the message; stays 0 when the stack is empty.
    depth = 0
    for depth, messenger in enumerate(reversed(_PYRO_STACK)):
        messenger.process_message(msg)
        # A handler that sets "stop" hides the message from every handler
        # further out on the stack.
        if msg.get("stop"):
            break

    if msg['value'] is None:
        fn, args, kwargs = msg['fn'], msg['args'], msg['kwargs']
        if msg['type'] != 'sample':
            msg['value'] = fn(*args, **kwargs)
        else:
            # Sample sites additionally record the intermediate values
            # returned by the site's function.
            value, intermediates = fn(*args, sample_intermediates=True, **kwargs)
            msg['value'] = value
            msg['intermediates'] = intermediates

    # Only the handlers that actually processed the message get to
    # post-process it: the slice starts at the handler where the loop above
    # stopped and runs to the top of the stack.
    for messenger in _PYRO_STACK[-(depth + 1):]:
        messenger.postprocess_message(msg)
    return msg
class Messenger(object):
    """
    Base class for effect handlers.

    A messenger can be used as a context manager, in which case it sits on the
    global handler stack for the duration of the ``with`` block, or it can
    wrap a callable ``fn`` and be called directly, in which case it is active
    only while ``fn`` runs. Subclasses customise behaviour by overriding
    :meth:`process_message` and :meth:`postprocess_message`.

    :param fn: optional callable to wrap.
    """
    def __init__(self, fn=None):
        self.fn = fn
        # Copy the wrapped callable's metadata onto this object; `updated=[]`
        # avoids touching the instance ``__dict__`` beyond those attributes.
        functools.update_wrapper(self, fn, updated=[])

    def __enter__(self):
        _PYRO_STACK.append(self)

    def __exit__(self, *args, **kwargs):
        exc_type = args[0] if args else None
        if exc_type is None:
            # Normal exit: handlers must leave in the reverse order they entered.
            assert _PYRO_STACK[-1] is self
            _PYRO_STACK.pop()
            return
        # The wrapped function or block raised. Handlers entered after this
        # one may still be on the stack, so drop this handler and everything
        # above it instead of asserting; an assertion failure here would mask
        # the original exception and leave stale handlers behind.
        for loc in range(len(_PYRO_STACK) - 1, -1, -1):
            if _PYRO_STACK[loc] is self:
                del _PYRO_STACK[loc:]
                break

    def process_message(self, msg):
        """Hook called before the site's value is computed; no-op by default."""
        pass

    def postprocess_message(self, msg):
        """Hook called after the site's value is computed; no-op by default."""
        pass

    def __call__(self, *args, **kwargs):
        # Run the wrapped callable with this handler active.
        with self:
            return self.fn(*args, **kwargs)
def sample(name, fn, obs=None, rng_key=None, sample_shape=()):
    """
    Returns a random sample from the stochastic function `fn`. This can have
    additional side effects when wrapped inside effect handlers like
    :class:`~numpyro.handlers.substitute`.

    .. note::
        By design, `sample` primitive is meant to be used inside a NumPyro model.
        Then :class:`~numpyro.handlers.seed` handler is used to inject a random
        state to `fn`. In those situations, `rng_key` keyword will take no
        effect.

    :param str name: name of the sample site
    :param fn: Python callable
    :param numpy.ndarray obs: observed value
    :param jax.random.PRNGKey rng_key: an optional random key for `fn`.
    :param sample_shape: Shape of samples to be drawn.
    :return: sample from the stochastic `fn`, or `obs` when an observed value
        is given.
    """
    # If there are no active Messengers, return the observed value if there is
    # one (matching what the handler path returns), otherwise draw a sample.
    if not _PYRO_STACK:
        if obs is not None:
            return obs
        return fn(rng_key=rng_key, sample_shape=sample_shape)

    # Otherwise, we initialize a message...
    initial_msg = {
        'type': 'sample',
        'name': name,
        'fn': fn,
        'args': (),
        'kwargs': {'rng_key': rng_key, 'sample_shape': sample_shape},
        'value': obs,
        'scale': 1.0,
        'is_observed': obs is not None,
        'intermediates': [],
        'cond_indep_stack': [],
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg['value']
def identity(x, *_ignored_args, **_ignored_kwargs):
    """Return ``x`` unchanged; any further arguments are accepted and ignored."""
    return x
def param(name, init_value=None, **kwargs):
    """
    Annotate the given site as an optimizable parameter for use with
    :mod:`jax.experimental.optimizers`. For an example of how `param` statements
    can be used in inference algorithms, refer to :func:`~numpyro.svi.svi`.

    :param str name: name of site.
    :param numpy.ndarray init_value: initial value specified by the user. Note that
        the onus of using this to initialize the optimizer is on the user /
        inference algorithm, since there is no global parameter store in
        NumPyro.
    :param kwargs: extra keyword arguments recorded in the site's message
        for handlers to inspect.
    :return: value for the parameter. Unless wrapped inside a
        handler like :class:`~numpyro.handlers.substitute`, this will simply
        return the initial value.
    """
    # if there are no active Messengers, there is nothing to intercept the
    # site, so we just hand back the initial value:
    if not _PYRO_STACK:
        return init_value

    # Otherwise, we initialize a message...
    initial_msg = {
        'type': 'param',
        'name': name,
        'fn': identity,
        'args': (init_value,),
        'kwargs': kwargs,
        'value': None,
        'scale': 1.0,
        'cond_indep_stack': [],
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg['value']
def module(name, nn, input_shape=None):
    """
    Declare a :mod:`~jax.experimental.stax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a :mod:`~jax.experimental.stax`
        constructor function.
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :return: a `apply_fn` with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    :raises ValueError: if the parameters have not been initialized yet and
        `input_shape` is not provided.
    """
    module_key = name + '$params'
    nn_init, nn_apply = nn
    # A handler may already supply the parameters for this site.
    nn_params = param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        # Obtain a random key through a sample site so that it can be
        # provided by the active handlers.
        rng_key = numpyro.sample(name + '$rng_key', PRNGIdentity())
        _, nn_params = nn_init(rng_key, input_shape)
        # Register the freshly initialized parameters under the module's site.
        param(module_key, nn_params)
    # `functools.partial` replaces `jax.partial`, which was only an accidental
    # re-export and is no longer available in JAX.
    return functools.partial(nn_apply, nn_params)
class plate(Messenger):
    """
    Construct for annotating conditionally independent variables. Within a
    `plate` context manager, `sample` sites will be automatically broadcasted to
    the size of the plate. Additionally, a scale factor might be applied by
    certain inference algorithms if `subsample_size` is specified.

    :param str name: Name of the plate.
    :param int size: Size of the plate.
    :param int subsample_size: Optional argument denoting the size of the mini-batch.
        This can be used to apply a scaling factor by inference algorithms. e.g.
        when computing ELBO using a mini-batch.
    :param int dim: Optional argument to specify which dimension in the tensor
        is used as the plate dim. If `None` (default), the leftmost available dim
        is allocated.
    :raises ValueError: if `dim` is given and is not negative.
    """
    def __init__(self, name, size, subsample_size=None, dim=None):
        self.name = name
        self.size = size
        self.subsample_size = size if subsample_size is None else subsample_size
        if dim is not None and dim >= 0:
            raise ValueError('dim arg must be negative.')
        self.dim = dim
        self._validate_and_set_dim()
        super(plate, self).__init__()

    def _validate_and_set_dim(self):
        """
        Pick (or validate) this plate's dim against the enclosing plates.

        A dummy ``'plate'`` message is sent through the handler stack; every
        active plate appends its frame to the message's ``cond_indep_stack``,
        which reveals the dims that are already taken.
        """
        msg = {
            'type': 'plate',
            'fn': identity,
            'name': self.name,
            'args': (None,),
            'kwargs': {},
            'value': None,
            'scale': 1.0,
            'cond_indep_stack': [],
        }
        apply_stack(msg)
        occupied_dims = {f.dim for f in msg['cond_indep_stack']}
        if self.dim is None:
            # Allocate the first free dim, searching from -1 leftwards.
            dim = -1
            while dim in occupied_dims:
                dim -= 1
            self.dim = dim
        else:
            assert self.dim not in occupied_dims, \
                'dim = {} is already used by an enclosing plate.'.format(self.dim)

    @staticmethod
    def _get_batch_shape(cond_indep_stack):
        """
        Build the batch shape implied by a stack of plate frames.

        The shape has as many dims as the leftmost (most negative) frame dim
        requires; dims not claimed by any frame have size 1. An empty stack
        yields the empty shape ``()``.
        """
        n_dims = max((-f.dim for f in cond_indep_stack), default=0)
        batch_shape = [1] * n_dims
        for f in cond_indep_stack:
            # Frame dims are negative, so this indexes from the right.
            batch_shape[f.dim] = f.size
        return tuple(batch_shape)

    def process_message(self, msg):
        """
        Record this plate on the site and expand/scale the site accordingly.

        Appends this plate's frame to the message's ``cond_indep_stack``,
        checks that a sample site's distribution batch shape is compatible
        with the shape implied by the enclosing plates, extends the site's
        ``sample_shape`` with the missing leading plate dims, and multiplies
        the site's ``scale`` by ``size / subsample_size``.

        :raises ValueError: if the distribution has more batch dims than the
            plates provide, or if its batch shape would have to be broadcast
            up to match the plates.
        """
        cond_indep_stack = msg['cond_indep_stack']
        frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
        cond_indep_stack.append(frame)
        expected_shape = self._get_batch_shape(cond_indep_stack)
        # Only sample sites carry a distribution with a batch shape.
        dist_batch_shape = msg['fn'].batch_shape if msg['type'] == 'sample' else ()
        overlap_idx = len(expected_shape) - len(dist_batch_shape)
        if overlap_idx < 0:
            raise ValueError('Expected dimensions within plate = {}, which is less than the '
                             'distribution\'s batch shape = {}.'.format(len(expected_shape), len(dist_batch_shape)))
        # Part of the expected shape that lines up with the distribution's own batch dims.
        trailing_shape = expected_shape[overlap_idx:]
        # e.g. distribution with batch shape (1, 5) cannot be broadcast to (5, 5)
        broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)
        if broadcast_shape != dist_batch_shape:
            raise ValueError('Distribution batch shape = {} cannot be broadcast up to {}. '
                             'Consider using unbatched distributions.'
                             .format(dist_batch_shape, broadcast_shape))
        # Leading plate dims the distribution does not have are supplied via sample_shape.
        batch_shape = expected_shape[:overlap_idx]
        if 'sample_shape' in msg['kwargs']:
            batch_shape = lax.broadcast_shapes(msg['kwargs']['sample_shape'], batch_shape)
        msg['kwargs']['sample_shape'] = batch_shape
        msg['scale'] = msg['scale'] * self.size / self.subsample_size
def factor(name, log_factor):
    """
    Factor statement to add arbitrary log probability factor to a
    probabilistic model.

    :param str name: Name of the trivial sample.
    :param numpy.ndarray log_factor: A possibly batched log probability factor.
    """
    unit_dist = numpyro.distributions.distribution.Unit(log_factor)
    # No random key is passed here (presumably `Unit` needs no randomness
    # to produce its value -- confirm against the distribution).
    unit_value = unit_dist.sample(None)
    # Register the value as an observed site so the site is treated as
    # observed rather than sampled.
    sample(name, unit_dist, obs=unit_value)