from collections import namedtuple
from contextlib import contextmanager
import random
import numpy as onp
import tqdm
from jax import jit, lax, ops, vmap
from jax.lib.xla_bridge import canonicalize_dtype
import jax.numpy as np
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a registry/cache consumed elsewhere -- confirm before removing.
_DATA_TYPES = {}
# When True, `cond`, `while_loop` and `fori_loop` in this module execute as
# plain Python control flow instead of dispatching to the `jax.lax`
# primitives. Temporarily set by `control_flow_prims_disabled`.
_DISABLE_CONTROL_FLOW_PRIM = False
def set_rng_seed(rng_seed):
    """
    Seed both Python's :mod:`random` module and NumPy's global random state
    with the same value.

    :param rng_seed: seed forwarded unchanged to ``random.seed`` and
        ``onp.random.seed``.
    """
    for seed_fn in (random.seed, onp.random.seed):
        seed_fn(rng_seed)
@contextmanager
def optional(condition, context_manager):
    """
    Run the managed body inside `context_manager` only when `condition` is
    truthy; otherwise run the body with no wrapping at all.

    :param condition: value whose truthiness decides whether to wrap.
    :param context_manager: context manager entered when `condition` holds.
    """
    if not condition:
        # Nothing to enter: just hand control to the body.
        yield
        return
    with context_manager:
        yield
@contextmanager
def control_flow_prims_disabled():
    """
    Context manager that sets the module flag `_DISABLE_CONTROL_FLOW_PRIM`
    to `True` for the duration of the body and restores the value it had
    on entry afterwards, even if the body raises.
    """
    global _DISABLE_CONTROL_FLOW_PRIM
    previous = _DISABLE_CONTROL_FLOW_PRIM
    _DISABLE_CONTROL_FLOW_PRIM = True
    try:
        yield
    finally:
        # Restore rather than reset to False so nested uses behave correctly.
        _DISABLE_CONTROL_FLOW_PRIM = previous
def cond(pred, true_operand, true_fun, false_operand, false_fun):
    """
    Conditional that dispatches to :func:`jax.lax.cond` unless control flow
    primitives are disabled, in which case a plain Python branch is used.

    :param pred: condition selecting the branch.
    :param true_operand: argument passed to `true_fun`.
    :param true_fun: callable applied when `pred` is truthy.
    :param false_operand: argument passed to `false_fun`.
    :param false_fun: callable applied when `pred` is falsy.
    :return: result of the selected branch.
    """
    if not _DISABLE_CONTROL_FLOW_PRIM:
        return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    return true_fun(true_operand) if pred else false_fun(false_operand)
def while_loop(cond_fun, body_fun, init_val):
    """
    While loop that dispatches to :func:`jax.lax.while_loop` unless control
    flow primitives are disabled, in which case a Python `while` loop runs.

    :param cond_fun: callable evaluated on the current value; the loop
        continues while it is truthy.
    :param body_fun: callable mapping the current value to the next one.
    :param init_val: value the loop starts from.
    :return: the value for which `cond_fun` first fails.
    """
    if not _DISABLE_CONTROL_FLOW_PRIM:
        return lax.while_loop(cond_fun, body_fun, init_val)
    state = init_val
    while cond_fun(state):
        state = body_fun(state)
    return state
def fori_loop(lower, upper, body_fun, init_val):
    """
    Counted loop that dispatches to :func:`jax.lax.fori_loop` unless control
    flow primitives are disabled, in which case a Python `for` loop over
    ``range(int(lower), int(upper))`` runs.

    :param lower: first index (inclusive).
    :param upper: last index (exclusive).
    :param body_fun: callable taking ``(index, value)`` and returning the
        next value.
    :param init_val: value the loop starts from.
    :return: the value after the final iteration.
    """
    if not _DISABLE_CONTROL_FLOW_PRIM:
        return lax.fori_loop(lower, upper, body_fun, init_val)
    state = init_val
    for step in range(int(lower), int(upper)):
        state = body_fun(step, state)
    return state
def identity(x):
    """Return the argument unchanged."""
    return x
def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that, `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fun`.
    :param progbar: whether to post progress bar updates.
    :param `**progbar_opts`: optional additional progress bar arguments, only
        consulted when `progbar` is true. A `diagnostics_fn` can be supplied
        which when passed the current value from `body_fun` returns a string
        that is used to update the progress bar postfix. Also a
        `progbar_desc` callable can be supplied which when passed the current
        value from `body_fun` returns the string used as the progress bar
        description. Any other keys are ignored.
    :return: collection with the same structure as `transform(init_val)`
        with values collected along the leading axis of `np.ndarray` objects.
    """
    assert lower < upper
    # Flatten one transformed value up front to learn the per-sample layout
    # and to obtain the function that restores the original structure.
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    ravel_fn = lambda x: ravel_pytree(transform(x))[0]  # noqa: E731

    if not progbar:
        collection = np.zeros((upper - lower,) + init_val_flat.shape)

        def _body_fn(i, vals):
            val, collection = vals
            val = body_fun(val)
            # Iterations before `lower` all write to slot 0; that slot is
            # overwritten by the first collected sample at i == lower.
            i = np.where(i >= lower, i - lower, 0)
            collection = ops.index_update(collection, i, ravel_fn(val))
            return val, collection

        _, collection = fori_loop(0, upper, _body_fn, (init_val, collection))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '')
        # Wrap once instead of re-wrapping on every iteration.
        jitted_ravel_fn = jit(ravel_fn)
        collection = []
        val = init_val
        with tqdm.trange(upper) as t:
            for i in t:
                val = body_fun(val)
                if i >= lower:
                    collection.append(jitted_ravel_fn(val))
                t.set_description(progbar_desc(val), refresh=False)
                if diagnostics_fn:
                    t.set_postfix_str(diagnostics_fn(val), refresh=False)
        collection = np.stack(collection)

    # Each row is one flattened sample; restore the structure row by row.
    return vmap(unravel_fn)(collection)
def copy_docs_from(source_class, full_text=False):
    """
    Class decorator factory that fills in missing docstrings on the public
    attributes of the decorated class from the same-named attributes of
    `source_class`.

    :param source_class: class whose attribute docstrings are used.
    :param full_text: if true, copy the source docstring verbatim; otherwise
        write a ``See :meth:`` cross-reference to the source attribute,
        unless the source docstring itself already starts with ``'See '``,
        in which case it is copied as is.
    :return: a decorator that updates the class in place and returns it.
    """
    def decorator(destin_class):
        for name in dir(destin_class):
            if name.startswith('_'):
                continue
            target = getattr(destin_class, name)
            # Bound/class methods expose the underlying function as __func__.
            target = getattr(target, '__func__', target)
            source_doc = getattr(getattr(source_class, name, None), '__doc__', None)
            # Skip when there is nothing to copy or the target has its own doc.
            if not source_doc or getattr(target, '__doc__', None):
                continue
            new_doc = (
                source_doc
                if full_text or source_doc.startswith('See ')
                else 'See :meth:`{}.{}.{}`'.format(
                    source_class.__module__, source_class.__name__, name)
            )
            if not isinstance(target, property):
                target.__doc__ = new_doc
                continue
            # A property's __doc__ is read-only, so rebuild the property
            # with the same accessors and the new docstring.
            rebuilt = property(target.fget, target.fset, target.fdel, new_doc)
            setattr(destin_class, name, rebuilt)
        return destin_class

    return decorator
# Per-leaf record built by `_ravel_list`: the raveled values (`flat`) plus
# the `shape`, `size` and `dtype` needed to rebuild the leaf afterwards.
pytree_metadata = namedtuple(
    'pytree_metadata',
    ['flat', 'shape', 'size', 'dtype'],
)
def _ravel_list(*leaves):
    """
    Ravel every leaf and concatenate the results into a single flat array.

    :param leaves: array-like values to flatten.
    :return: a pair ``(flat, unravel_list)`` where `flat` is the
        concatenation of the raveled leaves and `unravel_list` maps such a
        flat array back to a list of arrays with each leaf's recorded shape
        and dtype.
    """
    def _describe(leaf):
        return pytree_metadata(
            flat=np.ravel(leaf),
            shape=np.shape(leaf),
            size=np.size(leaf),
            dtype=canonicalize_dtype(lax.dtype(leaf)),
        )

    leaves_metadata = tree_map(_describe, leaves)
    sizes = tuple(meta.size for meta in leaves_metadata)
    # offsets[k] is where the k-th leaf starts inside the flat array.
    offsets = np.cumsum(np.array((0,) + sizes))

    def unravel_list(arr):
        pieces = []
        for pos, meta in enumerate(leaves_metadata):
            chunk = lax.dynamic_slice_in_dim(arr, offsets[pos], meta.size)
            pieces.append(np.reshape(chunk, meta.shape).astype(meta.dtype))
        return pieces

    flat = np.concatenate([meta.flat for meta in leaves_metadata])
    return flat, unravel_list
def ravel_pytree(pytree):
    """
    Flatten a pytree into a single flat array.

    :param pytree: the pytree to flatten.
    :return: a pair ``(flat, unravel_pytree)`` where `flat` is the flat
        array produced by `_ravel_list` from the tree's leaves and
        `unravel_pytree` rebuilds a pytree with the original structure from
        such a flat array.
    """
    flat_leaves, structure = tree_flatten(pytree)
    flat, split_back = _ravel_list(*flat_leaves)

    def unravel_pytree(arr):
        restored_leaves = split_back(arr)
        return tree_unflatten(structure, restored_leaves)

    return flat, unravel_pytree