# Source code for numpyro.infer.util

from functools import partial
import warnings

import jax
from jax import device_get, lax, random, value_and_grad, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.distributions.constraints import real
from numpyro.distributions.transforms import ComposeTransform, biject_to
from numpyro.handlers import block, seed, substitute, trace
from numpyro.util import not_jax_tracer, while_loop

# Names exported by ``from numpyro.infer.util import *``.
# NOTE(review): `transform_fn` and `constrain_fn` are defined in this module
# but are not listed here -- confirm whether that omission is intentional.
__all__ = [
    'find_valid_initial_params',
    'get_potential_fn',
    'log_density',
    'log_likelihood',
    'init_to_feasible',
    'init_to_median',
    'init_to_prior',
    'init_to_uniform',
    'init_to_value',
    'potential_energy',
    'initialize_model',
    'Predictive',
    'transformed_potential_energy',
]


def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False):
    """
    (EXPERIMENTAL INTERFACE) Evaluate the log joint density of ``model`` at the
    latent values given in ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :param bool skip_dist_transforms: if `True`, ``params`` is substituted as
        base values and, for a site that recorded intermediates, the log
        probability is taken from the base distribution of its prior.
    :return: tuple of (log joint density, model trace).
    """
    # Transforms are skipped for autoguide's model and hmc's model;
    # they are applied for autoguide's guide and svi's model + guide.
    if skip_dist_transforms:
        substituted_model = substitute(model, base_param_map=params)
    else:
        substituted_model = substitute(model, param_map=params)
    model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)

    def _site_log_prob(site):
        # Summed (and optionally scaled) log probability of one sample site.
        value = site['value']
        intermediates = site['intermediates']
        if not intermediates:
            site_lp = site['fn'].log_prob(value)
        elif skip_dist_transforms:
            site_lp = site['fn'].base_dist.log_prob(intermediates[0][0])
        else:
            site_lp = site['fn'].log_prob(value, intermediates)
        site_lp = np.sum(site_lp)
        if 'scale' in site:
            site_lp = site['scale'] * site_lp
        return site_lp

    log_joint = 0.
    for site in model_trace.values():
        if site['type'] != 'sample':
            continue
        log_joint = log_joint + _site_log_prob(site)
    return log_joint, model_trace
def transform_fn(transforms, params, invert=False):
    """
    (EXPERIMENTAL INTERFACE) Apply the per-name transforms in `transforms` to
    the matching entries of `params`.

    Entries of `params` whose name has no transform are passed through
    unchanged.

    :param transforms: Dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param params: Dictionary of arrays keyed by names.
    :param invert: Whether to apply the inverse (``.inv``) of the transforms.
    :return: `dict` of transformed params, keyed by the names in `params`.
    """
    if invert:
        transforms = {name: t.inv for name, t in transforms.items()}

    transformed = {}
    for name, value in params.items():
        if name in transforms:
            transformed[name] = transforms[name](value)
        else:
            transformed[name] = value
    return transformed
def constrain_fn(model, transforms, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Return the value at each latent site of `model`
    given unconstrained parameters `params`.

    `params` is first mapped through `transforms` to base values, which are
    substituted into `model` as a base parameter map. The model is then traced
    and the traced site values are collected.

    :param model: a callable containing NumPyro primitives.
    :param dict transforms: dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of unconstrained values keyed by site names.
    :return: `dict` of traced values for the names in `params` that appear in
        the model trace.
    """
    base_values = transform_fn(transforms, params)
    conditioned_model = substitute(model, base_param_map=base_values)
    model_trace = trace(conditioned_model).get_trace(*model_args, **model_kwargs)

    site_values = {}
    for name in params:
        if name in model_trace:
            site_values[name] = model_trace[name]['value']
    return site_values
def potential_energy(model, inv_transforms, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Compute the potential energy (negative log joint
    density, including the log-det-Jacobian terms of `inv_transforms`) of
    `model` at unconstrained `params`.

    `params` is mapped through `inv_transforms` to base values, and the log
    density is evaluated with ``skip_dist_transforms=True``.

    :param model: a callable containing NumPyro primitives.
    :param dict inv_transforms: dictionary of transforms keyed by names.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: unconstrained parameters of `model`.
    :return: potential energy given unconstrained parameters.
    """
    base_values = transform_fn(inv_transforms, params)
    log_joint, model_trace = log_density(
        model, model_args, model_kwargs, base_values, skip_dist_transforms=True)

    for name, transform in inv_transforms.items():
        log_det = np.sum(transform.log_abs_det_jacobian(params[name], base_values[name]))
        site = model_trace[name]
        # The Jacobian term is scaled the same way as the site itself.
        if 'scale' in site:
            log_det = site['scale'] * log_det
        log_joint = log_joint + log_det
    return - log_joint
def transformed_potential_energy(potential_energy, inv_transform, z):
    """
    Given a potential energy `p(x)`, compute potential energy of `p(z)`
    with `z = transform(x)` (i.e. `x = inv_transform(z)`).

    :param potential_energy: a callable to compute potential energy of the
        original variable `x`.
    :param ~numpyro.distributions.constraints.Transform inv_transform: a
        transform from the new variable `z` to `x`.
    :param z: new variable to compute potential energy
    :return: potential energy of `z`, i.e. ``potential_energy(x)`` minus the
        log absolute determinant of the Jacobian of `inv_transform` at `z`.
    """
    forward_result = inv_transform.call_with_intermediates(z)
    x, intermediates = forward_result
    log_det = inv_transform.log_abs_det_jacobian(z, x, intermediates=intermediates)
    return potential_energy(x) - log_det


def _init_to_median(site, num_samples=15, skip_param=False):
    """
    Per-site initialisation strategy behind :func:`init_to_median`.

    For an unobserved sample site, draws `num_samples` samples (under the site
    name ``'_init'``) from the site's distribution, or from its base
    distribution when it is a ``TransformedDistribution``, and returns their
    median over the leading axis. For a param site (unless `skip_param`),
    returns the base value of the param's initial value.

    :param dict site: the site message being initialised.
    :param int num_samples: number of prior draws used for the median.
    :param bool skip_param: if `True`, param sites are ignored.
    :return: the base value for the site, or `None` for any other site.
    """
    kind = site['type']

    if kind == 'sample' and not site['is_observed']:
        prior = site['fn']
        if isinstance(prior, dist.TransformedDistribution):
            prior = prior.base_dist
        draw_shape = (num_samples,) + site['kwargs']['sample_shape']
        draws = numpyro.sample('_init', prior, sample_shape=draw_shape)
        return np.median(draws, axis=0)

    if kind == 'param' and not skip_param:
        # return base value of param site
        # (`pop` removes 'constraint' from the site's kwargs)
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        init_value = site['args'][0]
        if not isinstance(transform, ComposeTransform):
            return init_value
        # Undo the whole transform, then re-apply only its first part.
        return transform.parts[0](transform.inv(init_value))

    return None
def init_to_median(num_samples=15):
    """
    Initialize to the prior median.

    :param int num_samples: number of prior points to calculate median.
    :return: a per-site initialisation callable with `num_samples` bound.
    """
    strategy = partial(_init_to_median, num_samples=num_samples)
    return strategy
def init_to_prior():
    """
    Initialize to a prior sample.

    Implemented as the median strategy with a single prior draw.
    """
    single_draw_strategy = init_to_median(num_samples=1)
    return single_draw_strategy
def _init_to_uniform(site, radius=2, skip_param=False):
    """
    Per-site initialisation strategy behind :func:`init_to_uniform`.

    Draws an unconstrained value from ``Uniform(-radius, radius)`` (under the
    site name ``'_unconstrained_init'``) and maps it to the base domain of the
    site. For an unobserved sample site a prototype draw named ``'_init'`` is
    taken first, only to determine the unconstrained shape.

    :param dict site: the site message being initialised.
    :param radius: half-width of the uniform range in the unconstrained domain.
    :param bool skip_param: if `True`, param sites are ignored.
    :return: the base value for the site, or `None` for any other site.
    """
    kind = site['type']

    if kind == 'sample' and not site['is_observed']:
        prior = site['fn']
        if isinstance(prior, dist.TransformedDistribution):
            prior = prior.base_dist
        prototype = numpyro.sample('_init', prior, sample_shape=site['kwargs']['sample_shape'])
        to_support = biject_to(prior.support)
        uniform = dist.Uniform(-radius, radius)
        unconstrained_shape = np.shape(to_support.inv(prototype))
        unconstrained = numpyro.sample('_unconstrained_init', uniform,
                                       sample_shape=unconstrained_shape)
        return to_support(unconstrained)

    if kind == 'param' and not skip_param:
        # return base value of param site
        # (`pop` removes 'constraint' from the site's kwargs)
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        init_value = site['args'][0]
        uniform = dist.Uniform(-radius, radius)
        unconstrained_shape = np.shape(transform.inv(init_value))
        unconstrained = numpyro.sample('_unconstrained_init', uniform,
                                       sample_shape=unconstrained_shape)
        # Only the first part of a composed transform maps into the base domain.
        to_base = transform.parts[0] if isinstance(transform, ComposeTransform) else transform
        return to_base(unconstrained)

    return None
def init_to_uniform(radius=2):
    """
    Initialize to a random point in the area `(-radius, radius)` of
    unconstrained domain.

    :param float radius: specifies the range to draw an initial point in the
        unconstrained domain.
    :return: a per-site initialisation callable with `radius` bound.
    """
    strategy = partial(_init_to_uniform, radius=radius)
    return strategy
def init_to_feasible():
    """
    Initialize to an arbitrary feasible point, ignoring distribution
    parameters.

    Implemented as the uniform strategy with a radius of zero.
    """
    zero_radius_strategy = init_to_uniform(radius=0)
    return zero_radius_strategy
def _init_to_value(site, values={}, skip_param=False):
    """
    Per-site initialisation strategy behind :func:`init_to_value`.

    For an unobserved sample site, returns the entry of `values` for the site
    name, mapped back through the inverse of the distribution's transforms
    when the prior is a ``TransformedDistribution``. Sites missing from
    `values` fall back to :func:`_init_to_uniform`. For a param site (unless
    `skip_param`), returns the base value of the param's initial value.

    :param dict site: the site message being initialised.
    :param dict values: initial values keyed by site name (only read here).
    :param bool skip_param: if `True`, param sites are ignored.
    :return: the base value for the site, or `None` for any other site.
    """
    kind = site['type']

    if kind == 'sample' and not site['is_observed']:
        name = site['name']
        if name not in values:
            return _init_to_uniform(site, skip_param=skip_param)
        init_value = values[name]
        if isinstance(site['fn'], dist.TransformedDistribution):
            init_value = ComposeTransform(site['fn'].transforms).inv(init_value)
        return init_value

    if kind == 'param' and not skip_param:
        # return base value of param site
        # (`pop` removes 'constraint' from the site's kwargs)
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        init_value = site['args'][0]
        if not isinstance(transform, ComposeTransform):
            return init_value
        # Undo the whole transform, then re-apply only its first part.
        return transform.parts[0](transform.inv(init_value))

    return None
def init_to_value(values):
    """
    Initialize to the value specified in `values`. We defer to
    :func:`init_to_uniform` strategy for sites which do not appear in `values`.

    :param dict values: dictionary of initial values keyed by site name.
    :return: a per-site initialisation callable with `values` bound.
    """
    strategy = partial(_init_to_value, values=values)
    return strategy
def find_valid_initial_params(rng_key, model, init_strategy=init_to_uniform(), param_as_improper=False,
                              model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an
    initial valid unconstrained value for all the parameters. This function
    also returns an `is_valid` flag to say whether the initial parameters are
    valid. Parameter values are considered valid if the values and the
    gradients for the log density have finite values.

    Up to 100 initialisation attempts are made per key.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
    :param bool param_as_improper: a flag to decide whether to consider sites
        with `param` statement as sites with improper priors.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model; `None` is treated
        as an empty dict.
    :return: tuple of (`init_params`, `is_valid`).
    """
    # `model_kwargs` is unpacked with `**` below, which fails on `None`.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    init_strategy = partial(init_strategy, skip_param=not param_as_improper)

    def cond_fn(state):
        # Keep trying while under the attempt budget and still invalid.
        i, _, _, is_valid = state
        return (i < 100) & (~is_valid)

    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_strategy`.
        # Use `block` to not record sample primitives in `init_strategy`.
        seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)

        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    # Transformed prior: work in the base distribution's domain.
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                # `pop` removes 'constraint' from the traced site's kwargs.
                constraint = v['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']

        # Map base values to the unconstrained domain.
        params = transform_fn(inv_transforms, constrained_values, invert=True)
        potential_fn = partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid

    def _find_valid_params(rng_key_):
        _, _, prototype_params, is_valid = init_state = body_fn((0, rng_key_, None, None))
        # Early return if valid params found.
        if not_jax_tracer(is_valid):
            if device_get(is_valid):
                return prototype_params, is_valid

        _, _, init_params, is_valid = while_loop(cond_fn, body_fn, init_state)
        return init_params, is_valid

    # Handle possible vectorization
    if rng_key.ndim == 1:
        init_params, is_valid = _find_valid_params(rng_key)
    else:
        init_params, is_valid = lax.map(_find_valid_params, rng_key)
    return init_params, is_valid
def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    """
    Trace `model` once and collect, per latent site, the transform from the
    unconstrained domain to the site's base domain.

    :param jax.random.PRNGKey rng_key: key used to seed the model; for a
        batch of keys only the first one is used.
    :param model: Python callable containing Pyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model; `None` is treated
        as an empty dict.
    :return: tuple of (`inv_transforms`, `has_transformed_dist`), where the
        flag is `True` if any sample site recorded intermediates or any param
        constraint maps to a ``ComposeTransform``.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            # `pop` removes 'constraint' from the traced site's kwargs.
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[k] = transform.parts[0]
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform
    return inv_transforms, has_transformed_dist


def get_potential_fn(rng_key, model, dynamic_args=False, model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed used to
        trace the model when collecting its transforms.
    :param model: Python callable containing Pyro primitives.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `constrain_fn`). The latter is used to
        constrain unconstrained samples (e.g. those returned by HMC) to values
        that lie within the site's support.
    """
    # `functools.partial` is used throughout: `jax.partial` is not exported by
    # newer JAX releases.
    if dynamic_args:
        def potential_fn(*args, **kwargs):
            inv_transforms, _ = get_model_transforms(rng_key, model, args, kwargs)
            return partial(potential_energy, model, inv_transforms, args, kwargs)

        def constrain_fun(*args, **kwargs):
            inv_transforms, has_transformed_dist = get_model_transforms(rng_key, model, args, kwargs)
            if has_transformed_dist:
                return partial(constrain_fn, model, inv_transforms, args, kwargs)
            return partial(transform_fn, inv_transforms)
    else:
        inv_transforms, has_transformed_dist = get_model_transforms(rng_key, model, model_args, model_kwargs)
        potential_fn = partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        if has_transformed_dist:
            # Base values must be pushed through the model to get site values.
            constrain_fun = partial(constrain_fn, model, inv_transforms, model_args, model_kwargs)
        else:
            constrain_fun = partial(transform_fn, inv_transforms)

    return potential_fn, constrain_fun
def initialize_model(rng_key, model, init_strategy=init_to_uniform(), dynamic_args=False,
                     model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls
    :func:`~numpyro.infer.util.get_potential_fn` and
    :func:`~numpyro.infer.util.find_valid_initial_params` under the hood to
    return a tuple of (`init_params`, `potential_fn`, `constrain_fn`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    :raises RuntimeError: if validity can be checked concretely and not all
        initial parameters are valid.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # The potential function only needs one key, even for a batch of keys.
    single_key = rng_key if rng_key.ndim == 1 else rng_key[0]
    potential_fun, constrain_fun = get_potential_fn(
        single_key, model,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs)

    init_params, is_valid = find_valid_initial_params(
        rng_key, model,
        init_strategy=init_strategy,
        param_as_improper=True,
        model_args=model_args,
        model_kwargs=model_kwargs)

    # The check is only possible on concrete (non-traced) values.
    if not_jax_tracer(is_valid) and device_get(~np.all(is_valid)):
        raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    return init_params, potential_fun, constrain_fun
def _predictive(rng_key, model, posterior_samples, num_samples, return_sites=None,
                parallel=True, model_args=(), model_kwargs={}):
    """
    Run `model` once per posterior sample and collect the traced site values.

    :param jax.random.PRNGKey rng_key: key split into `num_samples` sub-keys.
    :param model: callable containing Pyro primitives.
    :param dict posterior_samples: samples substituted into the model, mapped
        over together with the sub-keys.
    :param int num_samples: number of keys to split off.
    :param return_sites: collection of site names to return; `None` returns
        non-plate sites not present in the samples, and the empty string
        ``''`` is a special signal to return all non-plate sites.
    :param bool parallel: use :func:`jax.vmap` if `True`, else `lax.map`.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model (only read here).
    :return: dict of site values keyed by site name.
    """
    rng_keys = random.split(rng_key, num_samples)

    def single_prediction(val):
        rng_key, samples = val
        conditioned = seed(substitute(model, samples), rng_key)
        model_trace = trace(conditioned).get_trace(*model_args, **model_kwargs)

        if return_sites is None:
            sites = {k for k, site in model_trace.items()
                     if site['type'] != 'plate' and k not in samples}
        elif return_sites == '':
            sites = {k for k, site in model_trace.items() if site['type'] != 'plate'}
        else:
            sites = return_sites

        predictions = {}
        for name, site in model_trace.items():
            if name in sites:
                predictions[name] = site['value']
        return predictions

    batched_inputs = (rng_keys, posterior_samples)
    if parallel:
        return vmap(single_prediction)(batched_inputs)
    return lax.map(single_prediction, batched_inputs)
class Predictive(object):
    """
    This class is used to construct predictive distribution. The predictive
    distribution is obtained by running model conditioned on latent samples
    from `posterior_samples`.

    .. warning::
        The interface for the `Predictive` class is experimental, and
        might change in the future.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param callable guide: optional guide to get posterior samples of sites
        not present in `posterior_samples`.
    :param dict params: dictionary of values for param sites of model/guide.
    :param int num_samples: number of samples
    :param list return_sites: sites to return; by default only sample sites
        not present in `posterior_samples` are returned.
    :param bool parallel: whether to predict in parallel using JAX vectorized
        map :func:`jax.vmap`. Defaults to False.

    :return: dict of samples from the predictive distribution.
    """

    def __init__(self, model, posterior_samples=None, guide=None, params=None, num_samples=None,
                 return_sites=None, parallel=False):
        if posterior_samples is None and num_samples is None:
            raise ValueError("Either posterior_samples or num_samples must be specified.")

        if posterior_samples is None:
            posterior_samples = {}

        # The leading dimension of the samples takes precedence over the
        # `num_samples` argument; a mismatch only triggers a warning.
        for sample in posterior_samples.values():
            batch_size = sample.shape[0]
            mismatch = (num_samples is not None) and (num_samples != batch_size)
            if mismatch:
                warnings.warn("Sample's leading dimension size {} is different from the "
                              "provided {} num_samples argument. Defaulting to {}."
                              .format(batch_size, num_samples, batch_size), UserWarning)
            num_samples = batch_size

        if num_samples is None:
            raise ValueError("No sample sites in posterior samples to infer `num_samples`.")

        if return_sites is not None:
            assert isinstance(return_sites, (list, tuple, set))

        self.model = model
        self.posterior_samples = posterior_samples
        self.num_samples = num_samples
        self.guide = guide
        self.params = params if params is not None else {}
        self.return_sites = return_sites
        self.parallel = parallel

    def get_samples(self, rng_key, *args, **kwargs):
        """
        Returns dict of samples from the predictive distribution. By default,
        only sample sites not contained in `posterior_samples` are returned.
        This can be modified by changing the `return_sites` keyword argument
        of this :class:`Predictive` instance.

        :param jax.random.PRNGKey rng_key: random key to draw samples.
        :param args: model arguments.
        :param kwargs: model kwargs.
        """
        shared = dict(parallel=self.parallel, model_args=args, model_kwargs=kwargs)
        posterior_samples = self.posterior_samples

        if self.guide is not None:
            rng_key, guide_rng_key = random.split(rng_key)
            guide = substitute(self.guide, self.params)
            # use return_sites='' as a special signal to return all sites
            posterior_samples = _predictive(guide_rng_key, guide, posterior_samples,
                                            self.num_samples, return_sites='', **shared)

        model = substitute(self.model, self.params)
        return _predictive(rng_key, model, posterior_samples, self.num_samples,
                           return_sites=self.return_sites, **shared)
def log_likelihood(model, posterior_samples, *args, **kwargs):
    """
    (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of
    model, given samples of all latent variables.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior;
        mapped over with :func:`jax.vmap`.
    :param args: model arguments.
    :param kwargs: model kwargs.
    :return: dict of log likelihoods at observation sites.
    """
    def single_loglik(samples):
        conditioned_model = substitute(model, samples)
        model_trace = trace(conditioned_model).get_trace(*args, **kwargs)
        logliks = {}
        for name, site in model_trace.items():
            # Only observed sample sites contribute a likelihood term.
            if site['type'] == 'sample' and site['is_observed']:
                logliks[name] = site['fn'].log_prob(site['value'])
        return logliks

    batched_loglik = vmap(single_loglik)
    return batched_loglik(posterior_samples)