# Source code for numpyro.contrib.stochastic_support.sdvi

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple

import jax
import jax.numpy as jnp

from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference
from numpyro.handlers import condition
from numpyro.infer import (
    SVI,
    Trace_ELBO,
    TraceEnum_ELBO,
    TraceGraph_ELBO,
    TraceMeanField_ELBO,
)
from numpyro.infer.autoguide import AutoNormal

# Result container returned by ``SDVI``: ``guides`` holds the per-SLP
# ``(guide, params)`` pairs and ``slp_weights`` the normalized weight of each SLP.
SDVIResult = namedtuple("SDVIResult", "guides slp_weights")

# ELBO objectives that ``SDVI`` accepts for per-SLP training. The order here is
# also the order in which they are listed in the ``ValueError`` raised for any
# other loss.
VALID_ELBOS = (
    Trace_ELBO,
    TraceMeanField_ELBO,
    TraceEnum_ELBO,
    TraceGraph_ELBO,
)


class SDVI(StochasticSupportInference):
    """
    Implements the Support Decomposition Variational Inference (SDVI) algorithm for models
    with stochastic support from [1]. This implementation creates a separate guide for each
    SLP, trains the guides separately, and then combines the guides by weighting each SLP
    proportionally to the exponential of its ELBO estimate (i.e. a softmax over the per-SLP
    ELBO estimates).

    **References:**

    1. *Rethinking Variational Inference for Probabilistic Programs with Stochastic Support*,
       Tim Reichelt, Luke Ong, Tom Rainforth

    **Example:**

    .. code-block:: python

        def model():
            model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
            if model1 == 0:
                mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
            else:
                mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
            numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

        sdvi = SDVI(model, numpyro.optim.Adam(step_size=0.001))
        sdvi_result = sdvi.run(random.PRNGKey(0))

    :param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
    :param optimizer: An instance of :class:`~numpyro.optim._NumpyroOptim`, a
        ``jax.example_libraries.optimizers.Optimizer`` or an Optax
        ``GradientTransformation``. Gets passed to :class:`~numpyro.infer.SVI`.
    :param int svi_num_steps: Number of steps to run SVI for each SLP.
    :param int combine_elbo_particles: Number of particles used to estimate the ELBO of each
        SLP when computing the SLP weights.
    :param guide_init: A constructor for the guide. This should be a callable that returns a
        :class:`~numpyro.infer.autoguide.AutoGuide` instance. Defaults to
        :class:`~numpyro.infer.autoguide.AutoNormal`.
    :param loss: ELBO loss for SVI. Defaults to :class:`~numpyro.infer.Trace_ELBO`. Must be an
        instance of ``Trace_ELBO``, ``TraceMeanField_ELBO``, ``TraceEnum_ELBO`` or
        ``TraceGraph_ELBO``; otherwise a ``ValueError`` is raised.
    :param bool svi_progress_bar: Whether to use a progress bar for SVI.
    :param int num_slp_samples: Number of samples to draw from the prior to discover the
        straight-line programs (SLPs).
    :param int max_slps: Maximum number of SLPs to discover. SDVI will not run inference on
        more than `max_slps`.
    """

    def __init__(
        self,
        model,
        optimizer,
        svi_num_steps=1000,
        combine_elbo_particles=1000,
        guide_init=AutoNormal,
        # NOTE: this default instance is created once, when the class is defined, and is
        # shared by every SDVI object that relies on the default.
        loss=Trace_ELBO(),
        svi_progress_bar=False,
        num_slp_samples=1000,
        max_slps=124,
    ):
        # Fail fast on an unsupported loss before assigning any attributes.
        if not isinstance(loss, VALID_ELBOS):
            err_str = ", ".join(x.__name__ for x in VALID_ELBOS)
            raise ValueError(f"loss must be an instance of: ({err_str})")

        self.guide_init = guide_init
        self.optimizer = optimizer
        self.svi_num_steps = svi_num_steps
        self.svi_progress_bar = svi_progress_bar
        self.loss = loss
        self.combine_elbo_particles = combine_elbo_particles
        super().__init__(model, num_slp_samples, max_slps)

    def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
        """
        Run SVI on a given SLP defined by its branching trace.

        :param rng_key: PRNG key forwarded to :meth:`~numpyro.infer.SVI.run`.
        :param branching_trace: Branching trace identifying the SLP; the model is
            conditioned on it (via :func:`~numpyro.handlers.condition`) before SVI runs.
        :param args: Positional arguments forwarded to the model/guide through SVI.
        :param kwargs: Keyword arguments forwarded to the model/guide through SVI.
        :return: Tuple ``(guide, params)`` with the guide built by ``guide_init`` and the
            parameters returned by SVI.
        """
        # Conditioning on the branching trace restricts the model to this single SLP.
        slp_model = condition(self.model, branching_trace)
        guide = self.guide_init(slp_model)
        svi = SVI(slp_model, guide, self.optimizer, loss=self.loss)
        svi_result = svi.run(
            rng_key,
            self.svi_num_steps,
            *args,
            progress_bar=self.svi_progress_bar,
            **kwargs,
        )
        return guide, svi_result.params

    def _combine_inferences(self, rng_key, guides, branching_traces, *args, **kwargs):
        """
        Weight each SLP proportionally to the exponential of its estimated ELBO.

        :param rng_key: PRNG key used for the Monte Carlo ELBO estimate; the same key is
            reused for every SLP.
        :param dict guides: Maps each SLP key to a ``(guide, params)`` pair.
        :param dict branching_traces: Maps the same SLP keys to the branching trace used to
            condition the model onto that SLP.
        :param args: Positional arguments forwarded to the model/guide.
        :param kwargs: Keyword arguments forwarded to the model/guide.
        :return: :data:`SDVIResult` holding ``guides`` unchanged and ``slp_weights``, a dict
            mapping each SLP key to its normalized weight (for non-empty ``guides`` the
            weights sum to one).
        """
        # The estimator is the same for every SLP, so build it once outside the loop.
        elbo_estimator = Trace_ELBO(num_particles=self.combine_elbo_particles)
        elbos = {}
        for bt, (guide, param_map) in guides.items():
            slp_model = condition(self.model, branching_traces[bt])
            # ``loss`` is the negative ELBO, hence the sign flip.
            elbos[bt] = -elbo_estimator.loss(
                rng_key, param_map, slp_model, guide, *args, **kwargs
            )

        # Softmax over the ELBO estimates, computed in log space for numerical stability.
        normalizer = jax.scipy.special.logsumexp(jnp.array(list(elbos.values())))
        slp_weights = {k: jnp.exp(v - normalizer) for k, v in elbos.items()}
        return SDVIResult(guides, slp_weights)