from functools import partial

import jax
from jax import random, value_and_grad, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as np
from jax.tree_util import tree_flatten
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.constraints import ComposeTransform, biject_to, real
from numpyro.handlers import block, seed, substitute, trace, condition
from numpyro.util import while_loop
def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False):
    """
    Computes the log joint density of ``model`` with latent values fixed to ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :param bool skip_dist_transforms: if True, ``params`` are substituted as base
        values (``base_param_map``) and any sample site that recorded intermediates
        is scored by its ``base_dist`` at the first intermediate value; otherwise
        ``params`` are substituted as site values (``param_map``) and scored by the
        site's own distribution.
    :return: tuple of (log joint density, model trace).
    """
    # Transforms are skipped for autoguide's model and HMC's model, and are
    # applied for autoguide's guide and for SVI's model and guide.
    substitution = ({'base_param_map': params} if skip_dist_transforms
                    else {'param_map': params})
    model_trace = trace(substitute(model, **substitution)).get_trace(*model_args, **model_kwargs)

    def _site_log_prob(site):
        # Summed log-probability of a single sample site, weighted by its scale if present.
        value = site['value']
        intermediates = site['intermediates']
        fn = site['fn']
        if not intermediates:
            lp = fn.log_prob(value)
        elif skip_dist_transforms:
            lp = fn.base_dist.log_prob(intermediates[0][0])
        else:
            lp = fn.log_prob(value, intermediates)
        lp = np.sum(lp)
        return site['scale'] * lp if 'scale' in site else lp

    total = 0.
    for site in model_trace.values():
        if site['type'] != 'sample':
            continue
        total = total + _site_log_prob(site)
    return total, model_trace
def constrain_fn(model, model_args, model_kwargs, transforms, params):
    """
    Gets value at each latent site in `model` given unconstrained parameters `params`.

    The `transforms` is used to transform these unconstrained parameters to base values
    of the corresponding priors in `model`. If a prior is a transformed distribution,
    the corresponding base value lies in the support of base distribution. Otherwise,
    the base value lies in the support of the distribution.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict transforms: dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param dict params: dictionary of unconstrained values keyed by site
        names.
    :return: `dict` mapping each name in `params` that is also a site in the
        model's trace to that site's (constrained) value.
    """
    params_constrained = transform_fn(transforms, params)
    # Base values are substituted so that transformed-distribution sites
    # recompute their constrained value through the distribution's transforms.
    substituted_model = substitute(model, base_param_map=params_constrained)
    model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
    return {name: model_trace[name]['value'] for name in params if name in model_trace}
def potential_energy(model, model_args, model_kwargs, inv_transforms, params):
    """
    Computes the potential energy of a model at unconstrained parameters `params`.

    The potential energy is the negative of the log joint density (evaluated with
    distribution transforms skipped) plus the log-abs-det-Jacobian of each
    transform in `inv_transforms`. Jacobian terms are multiplied by the site's
    ``scale`` when the site has one.

    This function returns a value, not a callable. To get a callable of `params`
    alone, bind the leading arguments, e.g. with ``functools.partial``.

    The `inv_transforms` is used to transform the unconstrained parameters to base
    values of the corresponding priors in `model`. If a prior is a transformed
    distribution, the corresponding base value lies in the support of the base
    distribution. Otherwise, the base value lies in the support of the distribution.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict inv_transforms: dictionary of transforms keyed by site names.
        Every name must be present both in `params` and in the model's trace.
    :param dict params: dictionary of unconstrained values keyed by site names.
    :return: the potential energy, ``-(log_joint + sum of log|det J|)``.
    """
    params_constrained = transform_fn(inv_transforms, params)
    log_joint, model_trace = log_density(model, model_args, model_kwargs, params_constrained,
                                         skip_dist_transforms=True)
    for name, t in inv_transforms.items():
        # Change-of-variables correction from unconstrained to base/support space.
        t_log_det = np.sum(t.log_abs_det_jacobian(params[name], params_constrained[name]))
        if 'scale' in model_trace[name]:
            t_log_det = model_trace[name]['scale'] * t_log_det
        log_joint = log_joint + t_log_det
    return -log_joint
def init_to_prior(site, skip_param=False):
    """
    Initialize ``site`` to a single sample drawn from its prior.

    Delegates to ``init_to_median`` with ``num_samples=1``.
    """
    single_draw = 1
    return init_to_median(site, skip_param=skip_param, num_samples=single_draw)
def init_to_feasible(site, skip_param=False):
    """
    Initialize ``site`` to an arbitrary feasible point, ignoring distribution
    parameters.

    Delegates to ``init_to_uniform`` with a zero ``radius``.
    """
    zero_radius = 0
    return init_to_uniform(site, skip_param=skip_param, radius=zero_radius)
def find_valid_initial_params(rng, model, *model_args, init_strategy=init_to_uniform,
                              param_as_improper=False, prototype_params=None, **model_kwargs):
    """
    Given a model with NumPyro primitives, returns initial unconstrained parameters
    together with an `is_valid` flag.

    `is_valid` is true when both the potential energy and every entry of its
    gradient are finite at the returned parameters. At most 100 attempts are made,
    each with a fresh subkey split from `rng`. If none succeeds, the parameters from
    the last attempt are returned with `is_valid` false.

    :param jax.random.PRNGKey rng: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng.shape[:-1]``.
    :param model: Python callable containing NumPyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function. It is
        called with ``skip_param=not param_as_improper``.
    :param bool param_as_improper: a flag to decide whether to consider sites with
        `param` statement as sites with improper priors.
    :param dict prototype_params: optional unconstrained parameters used only as the
        initial loop state. The first attempt always runs and overwrites them, so only
        their structure matters. Presumably this is so the loop carry has a fixed
        structure without an eager attempt; TODO confirm against `while_loop`. When
        omitted, one attempt is made eagerly to build the initial state.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `is_valid`).
    """
    init_strategy = partial(init_strategy, skip_param=not param_as_improper)

    def cond_fn(state):
        i, _, _, is_valid = state
        # `np.logical_not` (not `~`): the initial `is_valid` may be a Python bool,
        # for which `~` is integer negation (`~False == -1`).
        return (i < 100) & np.logical_not(is_valid)

    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)
        # Wrap model in a `substitute` handler to initialize from `init_strategy`.
        # Use `block` to not record sample primitives in `init_strategy`.
        seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    # Transformed distribution: work in its base distribution's support.
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                # Read (not pop) the constraint so the trace record is left untouched.
                constraint = v['kwargs'].get('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # Keep only the first part of the composed transform, and map the
                    # param value back through the full inverse, then forward through
                    # that first part.
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        params = transform_fn(inv_transforms, constrained_values, invert=True)
        potential_fn = partial(potential_energy, model, model_args, model_kwargs, inv_transforms)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid

    if prototype_params is not None:
        init_state = (0, rng, prototype_params, False)
    else:
        init_state = body_fn((0, rng, None, None))
    _, _, init_params, is_valid = while_loop(cond_fn, body_fn, init_state)
    return init_params, is_valid
def predictive(rng, model, posterior_samples, return_sites=None, *args, **kwargs):
    """
    Run model by sampling latent parameters from `posterior_samples`, and return
    values at sample sites from the forward run. By default, only sample sites not
    contained in `posterior_samples` are returned. This can be modified by changing
    the `return_sites` keyword argument.

    .. warning::
        The interface for the `predictive` function is experimental, and
        might change in the future.

    .. note::
        `return_sites` precedes `*args`. When passing model arguments positionally,
        give `return_sites` explicitly first (e.g. ``None``). Otherwise the first
        model argument is consumed as `return_sites`.

    :param jax.random.PRNGKey rng: seed to draw samples
    :param model: Python callable containing NumPyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior. Every
        leaf must share the same leading (sample) dimension.
    :param list return_sites: sites to return; by default only sample sites not present
        in `posterior_samples` are returned.
    :param args: model arguments.
    :param kwargs: model kwargs.
    :return: dict of samples from the predictive distribution.
    :raises ValueError: if `posterior_samples` contains no arrays.
    """
    # TODO: consider to support `num_samples`, `return_traces`, `parallel` kwargs
    def single_prediction(rng, samples):
        model_trace = trace(seed(condition(model, samples), rng)).get_trace(*args, **kwargs)
        if return_sites is None:
            # Honour the documented default: only *sample* sites that were not
            # conditioned on (e.g. exclude `param` sites).
            sites = {name for name, site in model_trace.items()
                     if site['type'] == 'sample' and name not in samples}
        else:
            sites = return_sites
        return {name: site['value'] for name, site in model_trace.items() if name in sites}

    leaves = tree_flatten(posterior_samples)[0]
    if not leaves:
        raise ValueError("`posterior_samples` must contain at least one array of samples.")
    num_samples = leaves[0].shape[0]
    rngs = random.split(rng, num_samples)
    return vmap(single_prediction)(rngs, posterior_samples)