# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from contextlib import contextmanager
from functools import partial
from typing import Callable, Dict, List, Optional
import warnings

import numpy as np

import jax
from jax import device_get, jacfwd, lax, random, tree_flatten, value_and_grad
from jax.flatten_util import ravel_pytree
from jax.lax import broadcast_shapes
import jax.numpy as jnp
from jax.tree_util import tree_map

import numpyro
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.util import is_identically_one, sum_rightmost
from numpyro.handlers import condition, replay, seed, substitute, trace
from numpyro.infer.initialization import init_to_uniform, init_to_value
from numpyro.util import (
    _validate_model,
    find_stack_level,
    not_jax_tracer,
    soft_vmap,
    while_loop,
)

# Result of `initialize_model`: the initial parameter info together with the
# potential / postprocess callables and the model trace used to build them.
ModelInfo = namedtuple(
    "ModelInfo", ["param_info", "potential_fn", "postprocess_fn", "model_trace"]
)
# Initial unconstrained parameters `z`, their potential energy, and the
# gradient of the potential energy with respect to `z`.
ParamInfo = namedtuple("ParamInfo", ["z", "potential_energy", "z_grad"])

def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    Every ``sample`` site of the traced model contributes the sum of its
    (optionally scaled) ``log_prob`` to the result.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    :raises ValueError: if a site's value does not broadcast against the
        shape of its distribution.
    """
    conditioned = substitute(model, data=params)
    model_trace = trace(conditioned).get_trace(*model_args, **model_kwargs)

    log_joint = jnp.zeros(())
    sample_sites = (s for s in model_trace.values() if s["type"] == "sample")
    for site in sample_sites:
        value = site["value"]
        intermediates = site["intermediates"]
        scale = site["scale"]
        dist = site["fn"]

        if intermediates:
            site_log_prob = dist.log_prob(value, intermediates)
        else:
            value_shape = jnp.shape(value)
            # TensorShape from tfp needs casting to tuple
            dist_shape = tuple(dist.shape())
            try:
                broadcast_shapes(value_shape, dist_shape)
            except ValueError:
                raise ValueError(
                    "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                        site["name"], dist_shape, value_shape
                    )
                )
            site_log_prob = dist.log_prob(value)

        if scale is not None and not is_identically_one(scale):
            site_log_prob = scale * site_log_prob

        log_joint = log_joint + jnp.sum(site_log_prob)
    return log_joint, model_trace
class _without_rsample_stop_gradient(numpyro.primitives.Messenger):
    """
    Stop gradient for samples at latent sample sites for which has_rsample=False.
    """

    def postprocess_message(self, msg):
        # Only unobserved sample sites are candidates.
        if msg["type"] != "sample" or msg["is_observed"]:
            return
        # Reparametrized samplers keep their gradients.
        if msg["fn"].has_rsample:
            return
        msg["value"] = lax.stop_gradient(msg["value"])
        # TODO: reconsider this logic
        # Clear the cached intermediates so that gradients of log_prob(value) w.r.t.
        # all parameters of the transformed distributions match the behavior of
        # TransformedDistribution(d, transform) in Pyro with transform.cache_size == 0
        msg["intermediates"] = None


def get_importance_trace(model, guide, args, kwargs, params):
    """
    (EXPERIMENTAL) Returns traces from the guide and the model that is run against it.
    The returned traces also store the log probability at each site.

    The guide is run first, and the model is then replayed against the guide
    trace. Both are substituted with ``params``.

    .. note:: Gradients are blocked at latent sites which do not have reparametrized samplers.

    :return: tuple ``(model_trace, guide_trace)``; every ``sample`` site in each
        trace carries a ``"log_prob"`` entry, scaled by the site's ``scale``
        when one is set.
    """
    substituted_guide = substitute(guide, data=params)
    with _without_rsample_stop_gradient():
        guide_trace = trace(substituted_guide).get_trace(*args, **kwargs)
    replayed_model = substitute(replay(model, guide_trace), data=params)
    model_trace = trace(replayed_model).get_trace(*args, **kwargs)

    for tr in (guide_trace, model_trace):
        for site in tr.values():
            # Skip non-sample sites and sites whose log_prob is already recorded.
            if site["type"] != "sample" or "log_prob" in site:
                continue
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            log_prob = (
                site["fn"].log_prob(value, intermediates)
                if intermediates
                else site["fn"].log_prob(value)
            )
            if scale is not None and not is_identically_one(scale):
                log_prob = scale * log_prob
            site["log_prob"] = log_prob
    return model_trace, guide_trace
def transform_fn(transforms, params, invert=False):
    """
    (EXPERIMENTAL INTERFACE) Apply per-name transforms to a dict of values.

    Every entry of ``params`` whose name also appears in ``transforms`` is passed
    through the matching transform. All other entries are copied unchanged.
    The result is a new dict with the same keys as ``params``.

    :param transforms: Dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param params: Dictionary of arrays keyed by names.
    :param invert: Whether to apply the inverse (``.inv``) of the transforms.
    :return: `dict` of transformed params.
    """
    lookup = {name: t.inv for name, t in transforms.items()} if invert else transforms
    transformed = {}
    for name, value in params.items():
        if name in lookup:
            value = lookup[name](value)
        transformed[name] = value
    return transformed
def constrain_fn(model, model_args, model_kwargs, params, return_deterministic=False):
    """
    (EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given
    unconstrained parameters `params`.

    Unconstrained values at ``sample`` sites are mapped into the support of the
    site's distribution via ``biject_to(support)``. Values for any other site
    type found in ``params`` are substituted as-is.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of unconstrained values keyed by site names.
    :param bool return_deterministic: whether to return the value of `deterministic`
        sites from the model. Defaults to `False`.
    :return: `dict` of transformed params.
    """

    def substitute_fn(site):
        name = site["name"]
        if name not in params:
            return None
        if site["type"] != "sample":
            return params[name]
        with helpful_support_errors(site):
            return biject_to(site["fn"].support)(params[name])

    def keep(name, site):
        # Always keep the requested sites; keep deterministic ones only on request.
        return name in params or (
            return_deterministic and site["type"] == "deterministic"
        )

    substituted_model = substitute(model, substitute_fn=substitute_fn)
    model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
    return {
        name: site["value"] for name, site in model_trace.items() if keep(name, site)
    }
def _unconstrain_reparam(params, site):
    """
    ``substitute_fn`` used by :func:`potential_energy`.

    For a site named in ``params``, maps the unconstrained value into the
    support of ``site["fn"]`` with ``biject_to(support)``. It also records the
    log-abs-det-Jacobian of that map (summed over event dims and multiplied by
    ``site["scale"]`` when set) as a ``numpyro.factor`` site, then returns the
    constrained value. Sites not present in ``params`` yield ``None``.
    """
    name = site["name"]
    if name not in params:
        return None

    unconstrained = params[name]
    dist = site["fn"]
    support = dist.support
    with helpful_support_errors(site):
        bijector = biject_to(support)

    # in scan, we might only want to substitute an item at index i, rather
    # than the whole sequence
    scan_index = site["infer"].get("_scan_current_index", None)
    if scan_index is not None:
        event_dim_shift = bijector.codomain.event_dim - bijector.domain.event_dim
        expected_unconstrained_dim = len(dist.shape()) - event_dim_shift
        # the value carries an additional time dimension when its rank is too large
        if jnp.ndim(unconstrained) > expected_unconstrained_dim:
            unconstrained = unconstrained[scan_index]

    # These supports are returned as-is: no transform and no log-det factor.
    if support in [constraints.real, constraints.real_vector]:
        return unconstrained

    value = bijector(unconstrained)
    log_det = bijector.log_abs_det_jacobian(unconstrained, value)
    log_det = sum_rightmost(
        log_det, jnp.ndim(log_det) - jnp.ndim(value) + len(dist.event_shape)
    )
    if site["scale"] is not None:
        log_det = site["scale"] * log_det
    numpyro.factor("_{}_log_det".format(name), log_det)
    return value
def potential_energy(model, model_args, model_kwargs, params, enum=False):
    """
    (EXPERIMENTAL INTERFACE) Computes potential energy of a model given
    unconstrained params.

    The unconstrained ``params`` are mapped into the supports of the
    corresponding priors in `model` (with Jacobian corrections) through a
    ``substitute`` handler. The negated log joint density is then returned.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: unconstrained parameters of `model`.
    :param bool enum: whether to enumerate over discrete latent sites.
    :return: potential energy given unconstrained parameters.
    """
    if not enum:
        density_fn = log_density
    else:
        from numpyro.contrib.funsor import log_density as density_fn

    reparam_model = substitute(
        model, substitute_fn=partial(_unconstrain_reparam, params)
    )
    # `params` is already injected by the substitute handler above, so the
    # density is evaluated with an empty params dict.
    log_joint, _ = density_fn(reparam_model, model_args, model_kwargs, {})
    return -log_joint
def _init_to_unconstrained_value(site=None, values={}): if site is None: return partial(_init_to_unconstrained_value, values=values)
def find_valid_initial_params(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    enum=False,
    model_args=(),
    model_kwargs=None,
    prototype_params=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial
    valid unconstrained value for all the parameters. This function also returns
    the corresponding potential energy, the gradients, and an
    `is_valid` flag to say whether the initial parameters are valid. Parameter values
    are considered valid if the values and the gradients for the log density have
    finite values.

    Up to 100 attempts are made (see ``cond_fn``) before giving up.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
    :param bool enum: whether to enumerate over discrete latent sites.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict prototype_params: optional prototype parameters, used
        to define the shape of the initial parameters.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. Defaults to False.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: tuple of `init_params_info` and `is_valid`, where `init_params_info` is the tuple
        containing the initial params, their potential energy, and their gradients.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    init_strategy = (
        init_strategy if isinstance(init_strategy, partial) else init_strategy()
    )
    # handle those init strategies differently to save computation
    if init_strategy.func is init_to_uniform:
        radius = init_strategy.keywords.get("radius")
        init_values = {}
    elif init_strategy.func is _init_to_unconstrained_value:
        radius = 2
        init_values = init_strategy.keywords.get("values") or {}
    else:
        # radius=None forces the tracing branch below; init_values is unused there.
        radius = None
        init_values = {}

    def cond_fn(state):
        i, _, _, is_valid = state
        return (i < 100) & (~is_valid)

    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        if radius is None or prototype_params is None:
            # XXX: we don't want to apply enum to draw latent samples
            model_ = model
            if enum:
                from numpyro.contrib.funsor import enum as enum_handler

                if isinstance(model, substitute) and isinstance(model.fn, enum_handler):
                    # keep the outer substitution but drop the enum layer below it
                    model_ = substitute(model.fn.fn, data=model.data)
                elif isinstance(model, enum_handler):
                    model_ = model.fn

            # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
            seeded_model = substitute(seed(model_, subkey), substitute_fn=init_strategy)
            model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
            constrained_values, inv_transforms = {}, {}
            for k, v in model_trace.items():
                if (
                    v["type"] == "sample"
                    and not v["is_observed"]
                    and not v["fn"].support.is_discrete
                ):
                    constrained_values[k] = v["value"]
                    with helpful_support_errors(v):
                        inv_transforms[k] = biject_to(v["fn"].support)
            params = transform_fn(inv_transforms, constrained_values, invert=True)
        else:
            # this branch doesn't require tracing the model
            params = {}
            for k, v in prototype_params.items():
                if k in init_values:
                    params[k] = init_values[k]
                else:
                    params[k] = random.uniform(
                        subkey, jnp.shape(v), minval=-radius, maxval=radius
                    )
                    # fresh subkey for the next uniformly initialized site
                    key, subkey = random.split(key)

        potential_fn = partial(
            potential_energy, model, model_args, model_kwargs, enum=enum
        )
        if validate_grad:
            if forward_mode_differentiation:
                pe = potential_fn(params)
                z_grad = jacfwd(potential_fn)(params)
            else:
                pe, z_grad = value_and_grad(potential_fn)(params)
            z_grad_flat = ravel_pytree(z_grad)[0]
            is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
        else:
            pe = potential_fn(params)
            is_valid = jnp.isfinite(pe)
            z_grad = None
        return i + 1, key, (params, pe, z_grad), is_valid

    def _find_valid_params(rng_key, exit_early=False):
        init_state = (0, rng_key, (prototype_params, 0.0, prototype_params), False)
        if exit_early and not_jax_tracer(rng_key):
            # Early return if valid params found. This is only helpful for single chain,
            # where we can avoid compiling body_fn in while_loop.
            _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
            if not_jax_tracer(is_valid):
                if device_get(is_valid):
                    return (init_params, pe, z_grad), is_valid

        # XXX: this requires compiling the model, so for multi-chain, we trace the model 2-times
        # even if the init_state is a valid result
        _, _, (init_params, pe, z_grad), is_valid = while_loop(
            cond_fn, body_fn, init_state
        )
        return (init_params, pe, z_grad), is_valid

    # Handle possible vectorization: a batch of keys is processed sequentially.
    if rng_key.ndim == 1:
        (init_params, pe, z_grad), is_valid = _find_valid_params(
            rng_key, exit_early=True
        )
    else:
        (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
    return (init_params, pe, z_grad), is_valid
def _get_model_transforms(model, model_args=(), model_kwargs=None): model_kwargs = {} if model_kwargs is None else model_kwargs model_trace = trace(model).get_trace(*model_args, **model_kwargs) inv_transforms = {} # model code may need to be replayed in the presence of deterministic sites replay_model = False has_enumerate_support = False for k, v in model_trace.items(): if v["type"] == "sample" and not v["is_observed"]: if v["fn"].support.is_discrete: enum_type = v["infer"].get("enumerate") if enum_type is not None and (enum_type != "parallel"): raise RuntimeError( "This algorithm might only work for discrete sites with" f" enumerate marked 'parallel'. But the site {k} is marked" f" as '{enum_type}'." ) has_enumerate_support = True if not v["fn"].has_enumerate_support: dist_name = type(v["fn"]).__name__ raise RuntimeError( "This algorithm might only work for discrete sites with" f" enumerate support. But the {dist_name} distribution at" f" site {k} does not have enumerate support." ) if enum_type is None: warnings.warn( "Some algorithms will automatically enumerate the discrete" f" latent site {k} of your model. 
In the future," " enumerated sites need to be marked with" " `infer={'enumerate': 'parallel'}`.", FutureWarning, stacklevel=find_stack_level(), ) else: support = v["fn"].support with helpful_support_errors(v, raise_warnings=True): inv_transforms[k] = biject_to(support) # XXX: the following code filters out most situations with dynamic supports args = () if isinstance(support, constraints._GreaterThan): args = ("lower_bound",) elif isinstance(support, constraints._Interval): args = ("lower_bound", "upper_bound") for arg in args: if not isinstance(getattr(support, arg), (int, float)): replay_model = True elif v["type"] == "deterministic": replay_model = True return inv_transforms, replay_model, has_enumerate_support, model_trace def _partial_args_kwargs(fn, *args, **kwargs): """Returns a partial function of `fn` and args, kwargs.""" return partial(fn, args, kwargs) def _drop_args_kwargs(fn, *args, **kwargs): """Returns the input function `fn`, ignoring args and kwargs.""" return fn def get_potential_fn( model, inv_transforms, *, enum=False, replay_model=False, dynamic_args=False, model_args=(), model_kwargs=None, ): """ (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a function which, given unconstrained parameters, evaluates the potential energy (negative log joint density). In addition, this returns a function to transform unconstrained values at sample sites to constrained values within their respective support. :param model: Python callable containing Pyro primitives. :param dict inv_transforms: dictionary of transforms keyed by names. :param bool enum: whether to enumerate over discrete latent sites. :param bool replay_model: whether we need to replay model in `postprocess_fn` to obtain `deterministic` sites. :param bool dynamic_args: if `True`, the `potential_fn` and `constraints_fn` are themselves dependent on model arguments. 
When provided a `*model_args, **model_kwargs`, they return `potential_fn` and `constraints_fn` callables, respectively. :param tuple model_args: args provided to the model. :param dict model_kwargs: kwargs provided to the model. :return: tuple of (`potential_fn`, `postprocess_fn`). The latter is used to constrain unconstrained samples (e.g. those returned by HMC) to values that lie within the site's support, and return values at `deterministic` sites in the model. """ if dynamic_args: potential_fn = partial( _partial_args_kwargs, partial(potential_energy, model, enum=enum) ) if replay_model: # XXX: we seed to sample discrete sites (but not collect them) model_ = seed(model.fn, 0) if enum else model postprocess_fn = partial( _partial_args_kwargs, partial(constrain_fn, model, return_deterministic=True), ) else: postprocess_fn = partial( _drop_args_kwargs, partial(transform_fn, inv_transforms) ) else: model_kwargs = {} if model_kwargs is None else model_kwargs potential_fn = partial( potential_energy, model, model_args, model_kwargs, enum=enum ) if replay_model: model_ = seed(model.fn, 0) if enum else model postprocess_fn = partial( constrain_fn, model_, model_args, model_kwargs, return_deterministic=True, ) else: postprocess_fn = partial(transform_fn, inv_transforms) return potential_fn, postprocess_fn def _guess_max_plate_nesting(model_trace): """ Guesses max_plate_nesting by using model trace. This optimistically assumes static model structure. """ sites = [site for site in model_trace.values() if site["type"] == "sample"] dims = [ frame.dim for site in sites for frame in site["cond_indep_stack"] if frame.dim is not None ] max_plate_nesting = -min(dims) if dims else 0 return max_plate_nesting
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
    and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
    to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode but the forward
        mode can be useful in some cases to improve the performance. In addition, some
        control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only supports forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters are found (only checked
        when `is_valid` is a concrete value rather than a JAX tracer).
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # A batch of keys (one per chain) is reduced to its first key for the
    # structural trace below.
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace to model so
    # we don't need to generate again parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items()
            if site["type"] in ["param"]
        },
    )
    # Values of continuous latent sites from the initial trace; their
    # unconstrained images become `prototype_params` below.
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items()
        if v["type"] == "sample"
        and not v["is_observed"]
        and not v["fn"].support.is_discrete
    }

    if has_enumerate_support:
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace, plate_warning="error")
            model = enum(config_enumerate(model), -max_plate_nesting - 1)
    else:
        _validate_model(model_trace, plate_warning="loose")

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    init_strategy = (
        init_strategy if isinstance(init_strategy, partial) else init_strategy()
    )
    # For `init_to_value` without replay, pre-transform the user values to
    # unconstrained space so `find_valid_initial_params` can use them directly.
    if (init_strategy.func is init_to_value) and not replay_model:
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    # NOTE(review): recorded values of `plate` sites are substituted too,
    # presumably so plate (subsample) indices stay fixed across attempts — confirm.
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    # Diagnostics only run outside of JAX tracing, where `is_valid` is concrete.
    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                        if len(ws) > 0:
                            for w in ws:
                                # add site information to the warning message
                                w.message.args = (
                                    "Site {}: {}".format(
                                        site["name"], w.message.args[0]
                                    ),
                                ) + w.message.args[1:]
                                warnings.showwarning(
                                    w.message,
                                    w.category,
                                    w.filename,
                                    w.lineno,
                                    file=w.file,
                                    line=w.line,
                                )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(
        ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
    )
def _predictive(
    rng_key,
    model,
    posterior_samples,
    batch_shape,
    return_sites=None,
    infer_discrete=False,
    parallel=True,
    model_args=(),
    model_kwargs=None,
):
    """
    Run ``model`` once per batch element of ``posterior_samples`` and collect
    the requested site values.

    :param rng_key: a single PRNG key; it is split into one key per batch element.
    :param model: Python callable containing NumPyro primitives.
    :param dict posterior_samples: samples with leading batch shape ``batch_shape``.
    :param tuple batch_shape: leading batch shape of ``posterior_samples``.
    :param return_sites: ``None`` for the default selection (sample sites not in
        the given samples, plus deterministic sites), ``""`` for every non-plate
        site, or an explicit collection of site names.
    :param bool infer_discrete: sample discrete sites from their posterior via
        funsor-based enumeration.
    :param bool parallel: vectorize over the whole batch instead of one element
        at a time.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: dict of predicted values keyed by site name.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    masked_model = numpyro.handlers.mask(model, mask=False)
    if infer_discrete:
        # inspect the model to get some structure
        rng_key, subkey = random.split(rng_key)
        batch_ndim = len(batch_shape)
        prototype_sample = tree_map(
            lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0],
            posterior_samples,
        )
        prototype_trace = trace(
            seed(substitute(masked_model, prototype_sample), subkey)
        ).get_trace(*model_args, **model_kwargs)
        first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

    def single_prediction(val):
        rng_key, samples = val
        if infer_discrete:
            from numpyro.contrib.funsor import config_enumerate
            from numpyro.contrib.funsor.discrete import _sample_posterior

            model_trace = prototype_trace
            temperature = 1
            pred_samples = _sample_posterior(
                config_enumerate(condition(model, samples)),
                first_available_dim,
                temperature,
                rng_key,
                *model_args,
                **model_kwargs,
            )
        else:
            model_trace = trace(
                seed(substitute(masked_model, samples), rng_key)
            ).get_trace(*model_args, **model_kwargs)
            pred_samples = {name: site["value"] for name, site in model_trace.items()}

        if return_sites is not None:
            if return_sites == "":
                # '' is the special signal for "every non-plate site"
                sites = {
                    k for k, site in model_trace.items() if site["type"] != "plate"
                }
            else:
                sites = set(return_sites)
        else:
            sites = {
                k
                for k, site in model_trace.items()
                if (site["type"] == "sample" and k not in samples)
                or (site["type"] == "deterministic")
            }
        return {name: value for name, value in pred_samples.items() if name in sites}

    num_samples = int(np.prod(batch_shape))
    if num_samples > 1:
        rng_key = random.split(rng_key, num_samples)
    rng_key = rng_key.reshape((*batch_shape, 2))
    chunk_size = num_samples if parallel else 1
    return soft_vmap(
        single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    )
class Predictive(object):
    """
    This class is used to construct predictive distribution. The predictive distribution is obtained
    by running model conditioned on latent samples from `posterior_samples`.

    .. warning::
        The interface for the `Predictive` class is experimental, and
        might change in the future.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param callable guide: optional guide to get posterior samples of sites not present
        in `posterior_samples`.
    :param dict params: dictionary of values for param sites of model/guide.
    :param int num_samples: number of samples
    :param list return_sites: sites to return; by default only sample sites not present
        in `posterior_samples` are returned.
    :param bool infer_discrete: whether or not to sample discrete sites from the
        posterior, conditioned on observations and other latent values in
        ``posterior_samples``. Under the hood, those sites will be marked with
        ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at
        the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_.
        Note that this requires ``funsor`` installation.
    :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`.
        Defaults to False.
    :param batch_ndims: the number of batch dimensions in posterior samples or parameters. If `None` defaults
        to 0 if guide is set (i.e. not `None`) and 1 otherwise. Usages for batched posterior samples:

        + set `batch_ndims=0` to get prediction for 1 single sample

        + set `batch_ndims=1` to get prediction for `posterior_samples`
          with shapes `(num_samples x ...)` (same as `batch_ndims=None` with `guide=None`)

        + set `batch_ndims=2` to get prediction for `posterior_samples`
          with shapes `(num_chains x N x ...)`. Note that if `num_samples`
          argument is not None, its value should be equal to `num_chains x N`.

        Usages for batched parameters:

        + set `batch_ndims=0` to get 1 sample from the guide and parameters (same as `batch_ndims=None` with guide)

        + set `batch_ndims=1` to get predictions from a one dimensional batch of the guide and parameters
          with shapes `(num_samples x batch_size x ...)`

    :return: dict of samples from the predictive distribution.

    **Example:**

    Given a model::

        def model(X, y=None):
            ...
            return numpyro.sample("obs", likelihood, obs=y)

    you can sample from the prior predictive::

        predictive = Predictive(model, num_samples=1000)
        y_pred = predictive(rng_key, X)["obs"]

    If you also have posterior samples, you can sample from the posterior predictive::

        predictive = Predictive(model, posterior_samples=posterior_samples)
        y_pred = predictive(rng_key, X)["obs"]

    See docstrings for :class:`~numpyro.infer.svi.SVI` and :class:`~numpyro.infer.mcmc.MCMCKernel`
    to see example code of this in context.
    """

    def __init__(
        self,
        model: Callable,
        posterior_samples: Optional[Dict] = None,
        *,
        guide: Optional[Callable] = None,
        params: Optional[Dict] = None,
        num_samples: Optional[int] = None,
        return_sites: Optional[List[str]] = None,
        infer_discrete: bool = False,
        parallel: bool = False,
        batch_ndims: Optional[int] = None,
    ):
        if posterior_samples is None and num_samples is None:
            raise ValueError(
                "Either posterior_samples or num_samples must be specified."
            )
        # Documented default: 1 batch dim without a guide, 0 with one.
        batch_ndims = (
            batch_ndims if batch_ndims is not None else 1 if guide is None else 0
        )

        posterior_samples = {} if posterior_samples is None else posterior_samples

        prototype_site = batch_shape = batch_size = None
        for name, sample in posterior_samples.items():
            if batch_shape is not None and sample.shape[:batch_ndims] != batch_shape:
                raise ValueError(
                    f"Batch shapes at site {name} and {prototype_site} "
                    f"should be the same, but got "
                    f"{sample.shape[:batch_ndims]} and {batch_shape}"
                )
            else:
                prototype_site = name
                batch_shape = sample.shape[:batch_ndims]
                batch_size = int(np.prod(batch_shape))
                if (num_samples is not None) and (num_samples != batch_size):
                    warnings.warn(
                        "Sample's batch dimension size {} is different from the "
                        "provided {} num_samples argument. Defaulting to {}.".format(
                            batch_size, num_samples, batch_size
                        ),
                        UserWarning,
                        stacklevel=find_stack_level(),
                    )
                num_samples = batch_size

        if num_samples is None:
            raise ValueError(
                "No sample sites in posterior samples to infer `num_samples`."
            )

        if batch_shape is None:
            # No posterior samples: the batch is `num_samples` along the last batch dim.
            batch_shape = (1,) * (batch_ndims - 1) + (num_samples,)

        if return_sites is not None:
            assert isinstance(return_sites, (list, tuple, set))

        self.model = model
        self.posterior_samples = posterior_samples
        self.num_samples = num_samples
        self.guide = guide
        self.params = {} if params is None else params
        self.infer_discrete = infer_discrete
        self.return_sites = return_sites
        self.parallel = parallel
        self.batch_ndims = batch_ndims
        self._batch_shape = batch_shape

    def _call_with_params(self, rng_key, params, args, kwargs):
        """Draw predictions with ``params`` substituted into both guide and model."""
        posterior_samples = self.posterior_samples
        if self.guide is not None:
            rng_key, guide_rng_key = random.split(rng_key)
            # use return_sites='' as a special signal to return all sites
            guide = substitute(self.guide, params)
            posterior_samples = _predictive(
                guide_rng_key,
                guide,
                posterior_samples,
                self._batch_shape,
                return_sites="",
                parallel=self.parallel,
                model_args=args,
                model_kwargs=kwargs,
            )
        # Use the same `params` as the guide. Under the batched path in
        # `__call__` this is a single slice, not the full batched `self.params`.
        model = substitute(self.model, params)
        return _predictive(
            rng_key,
            model,
            posterior_samples,
            self._batch_shape,
            return_sites=self.return_sites,
            infer_discrete=self.infer_discrete,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
        )

    def __call__(self, rng_key, *args, **kwargs):
        """
        Returns dict of samples from the predictive distribution. By default, only sample sites not
        contained in `posterior_samples` are returned. This can be modified by changing the
        `return_sites` keyword argument of this :class:`Predictive` instance.

        :param jax.random.PRNGKey rng_key: random key to draw samples.
        :param args: model arguments.
        :param kwargs: model kwargs.
        :raises NotImplementedError: for batched parameters with ``batch_ndims > 1``.
        """
        if self.batch_ndims == 0 or self.params == {} or self.guide is None:
            return self._call_with_params(rng_key, self.params, args, kwargs)
        elif self.batch_ndims == 1:  # batch over parameters
            batch_size = jnp.shape(jax.tree_util.tree_leaves(self.params)[0])[0]
            rng_keys = random.split(rng_key, batch_size)
            return jax.vmap(
                partial(self._call_with_params, args=args, kwargs=kwargs),
                in_axes=0,
                out_axes=1,
            )(rng_keys, self.params)
        else:
            raise NotImplementedError
def log_likelihood(
    model, posterior_samples, *args, parallel=False, batch_ndims=1, **kwargs
):
    """
    (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model,
    given samples of all latent variables.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param args: model arguments.
    :param batch_ndims: the number of batch dimensions in posterior samples. Some usages:

        + set `batch_ndims=0` to get log likelihoods for 1 single sample

        + set `batch_ndims=1` to get log likelihoods for `posterior_samples`
          with shapes `(num_samples x ...)`

        + set `batch_ndims=2` to get log likelihoods for `posterior_samples`
          with shapes `(num_chains x num_samples x ...)`

    :param kwargs: model kwargs.
    :return: dict of log likelihoods at observation sites.
    :raises ValueError: if sites in `posterior_samples` have different batch shapes.
    """

    def single_loglik(samples):
        substituted_model = (
            substitute(model, samples) if isinstance(samples, dict) else model
        )
        model_trace = trace(substituted_model).get_trace(*args, **kwargs)
        return {
            name: site["fn"].log_prob(site["value"])
            for name, site in model_trace.items()
            if site["type"] == "sample" and site["is_observed"]
        }

    prototype_site = batch_shape = None
    for name, sample in posterior_samples.items():
        # jnp.shape (not `.shape`) so array-likes without a shape attribute work
        sample_batch_shape = jnp.shape(sample)[:batch_ndims]
        if batch_shape is not None and sample_batch_shape != batch_shape:
            raise ValueError(
                f"Batch shapes at site {name} and {prototype_site} "
                f"should be the same, but got "
                f"{sample_batch_shape} and {batch_shape}"
            )
        else:
            prototype_site = name
            batch_shape = sample_batch_shape

    if batch_shape is None:  # posterior_samples is an empty dict
        batch_shape = (1,) * batch_ndims
        # a dummy array only carries the batch shape; single_loglik ignores it
        posterior_samples = np.zeros(batch_shape)

    batch_size = int(np.prod(batch_shape))
    chunk_size = batch_size if parallel else 1
    return soft_vmap(single_loglik, posterior_samples, len(batch_shape), chunk_size)
@contextmanager
def helpful_support_errors(site, raise_warnings=False):
    """
    Context manager that turns an unhelpful ``NotImplementedError`` into a
    site-specific ``ValueError`` hint.

    If the support of ``site["fn"]`` (unwrapped from ``constraints.independent``)
    looks discrete (integer/boolean) or spherical, a ``ValueError`` naming the
    site is raised. Any other ``NotImplementedError`` is re-raised with its
    context suppressed. With ``raise_warnings=True``, a ``UserWarning`` is
    emitted up front for circular supports.

    :param dict site: a site message with ``"name"`` and ``"fn"`` entries.
    :param bool raise_warnings: whether to emit the circular-support warning.
    """
    name = site["name"]
    support = getattr(site["fn"], "support", None)
    if isinstance(support, constraints.independent):
        support = support.base_constraint

    # Warnings
    if raise_warnings and support is constraints.circular:
        warnings.warn(
            f"Continuous inference poorly handles circular sample site '{name}'. "
            + "Consider using VonMises distribution together with "
            + "a reparameterizer, e.g. "
            + f"numpyro.handlers.reparam(config={{'{name}': CircularReparam()}}).",
            UserWarning,
            stacklevel=find_stack_level(),
        )

    # Exceptions
    try:
        yield
    except NotImplementedError as err:
        lowered_name = repr(support).lower()
        if any(word in lowered_name for word in ("integer", "boolean")):
            # TODO: mention enumeration when it is supported in SVI
            raise ValueError(
                f"Continuous inference cannot handle discrete sample site '{name}'."
            )
        if "sphere" in lowered_name:
            raise ValueError(
                f"Continuous inference cannot handle spherical sample site '{name}'. "
                "Consider using ProjectedNormal distribution together with "
                "a reparameterizer, e.g. "
                f"numpyro.handlers.reparam(config={{'{name}': ProjectedNormalReparam()}})."
            )
        raise err from None