Contributed Code

Nested Sampling

class NestedSampler(model, *, constructor_kwargs=None, termination_kwargs=None)

Bases: object

(EXPERIMENTAL) A wrapper for jaxns, a nested sampling package based on JAX.

See reference [1] for details on the meaning of each parameter. Please consider citing this reference if you use the nested sampler in your research.


To enumerate over a discrete latent variable, add the keyword infer={"enumerate": "parallel"} to the corresponding sample statement.
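Enumeration sums a discrete latent variable out of the model instead of sampling it. As a rough illustration of the quantity involved, the plain NumPy sketch below marginalizes a two-valued latent z in a toy Gaussian mixture by evaluating every value of z at once along a new leading axis, which is the idea behind "parallel" enumeration. The toy model, the helpers `log_normal` and `logsumexp`, and the numbers are hypothetical; NumPyro performs enumeration through its own machinery, not this code.

```python
import numpy as np


def log_normal(y, mean, scale):
    """Log density of Normal(mean, scale) evaluated at y."""
    return -0.5 * np.log(2.0 * np.pi * scale ** 2) - 0.5 * ((y - mean) / scale) ** 2


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


# Toy mixture: z ~ Categorical([0.3, 0.7]), y | z ~ Normal(means[z], 1)
log_prior_z = np.log(np.array([0.3, 0.7]))
means = np.array([-1.0, 2.0])
y = np.array([0.5, 1.8, -0.7])  # three observations

# Evaluate all values of z in parallel: axis 0 indexes z, axis 1 indexes observations.
log_joint = log_prior_z[:, None] + log_normal(y[None, :], means[:, None], 1.0)  # (2, 3)
log_marginal = logsumexp(log_joint, axis=0)  # log p(y_n) with z summed out, shape (3,)
```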


To improve performance, please consider enabling x64 mode at the beginning of your NumPyro program with numpyro.enable_x64().


  1. JAXNS: a high-performance nested sampling package based on JAX, Joshua G. Albert (

  • model (callable) – a callable with NumPyro primitives

  • constructor_kwargs (dict) – additional keyword arguments to construct an upstream jaxns.NestedSampler instance.

  • termination_kwargs (dict) – keyword arguments to terminate the sampler. Please refer to the upstream jaxns.NestedSampler.__call__() method.


>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.nested_sampling import NestedSampler

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)
>>> ns = NestedSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)
>>> samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)
>>> assert jnp.mean(jnp.abs(samples['intercept'])) < 0.05
>>> print(jnp.mean(samples['coefs'], axis=0))  
[0.93661342 1.95034876 2.86123884]
run(rng_key, *args, **kwargs)

Run the nested samplers and collect weighted samples.

  • rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.

  • args – The arguments needed by the model.

  • kwargs – The keyword arguments needed by the model.

get_samples(rng_key, num_samples)

Draws samples from the weighted samples collected from the run.

  • rng_key (random.PRNGKey) – Random number generator key to be used to draw samples.

  • num_samples (int) – The number of samples.


a dict of posterior samples
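Nested sampling produces samples that each carry a weight, so turning them into ordinary posterior draws means resampling in proportion to those weights. The NumPy sketch below shows one such scheme, multinomial resampling with replacement. It is an assumption for illustration only, since NestedSampler delegates this step to jaxns, whose scheme may differ; the `resample` helper and the toy values are hypothetical.

```python
import numpy as np


def resample(rng, samples, log_weights, num_samples):
    """Draw equally weighted samples from weighted ones (multinomial, with replacement).

    samples: dict mapping site name -> array whose leading axis indexes the weighted draws.
    log_weights: unnormalized log weights, one per weighted draw.
    """
    # Subtract the max before exponentiating to avoid overflow, then normalize.
    w = np.exp(log_weights - np.max(log_weights))
    probs = w / w.sum()
    idx = rng.choice(probs.shape[0], size=num_samples, replace=True, p=probs)
    return {name: value[idx] for name, value in samples.items()}


rng = np.random.default_rng(0)
weighted = {"x": np.array([10.0, 20.0, 30.0])}
log_w = np.log(np.array([0.7, 0.2, 0.1]))
draws = resample(rng, weighted, log_w, num_samples=100_000)
```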


Gets weighted samples and their corresponding log weights.


Print summary of the result. This is a wrapper of jaxns.utils.summary().


Plot diagnostics of the result. This is a wrapper of jaxns.plotting.plot_diagnostics() and jaxns.plotting.plot_cornerplot().

Stein Variational Inference

Stein Variational Inference (SteinVI) is a family of VI techniques for approximate Bayesian inference based on Stein’s method (see [1] for an overview). It is gaining popularity as it combines the scalability of traditional VI with the flexibility of non-parametric particle-based methods.

Stein variational gradient descent (SVGD) [2] is a recent SteinVI technique that iteratively moves a set of particles \(\{z_i\}_{i=1}^N\) to approximate a distribution \(p(z)\). As a particle-based method, SVGD is well suited to capturing correlations between latent variables. The technique preserves the scalability of traditional VI approaches while offering the flexibility and modeling scope of methods such as Markov chain Monte Carlo (MCMC). SVGD is also effective at capturing multi-modality [3][4].
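The particle update behind SVGD [2] moves each particle along \(\phi(z_i) = \frac{1}{N}\sum_{j=1}^{N}\left[k(z_j, z_i)\nabla_{z_j}\log p(z_j) + \nabla_{z_j} k(z_j, z_i)\right]\): the first term pulls particles toward high-density regions and the second pushes them apart. The NumPy sketch below applies this update with an RBF kernel and a median-heuristic bandwidth to a one-dimensional Gaussian target. The target, step size, iteration count, and helper names are illustrative choices, not part of numpyro.contrib.einstein.

```python
import numpy as np


def rbf_kernel_and_grad(x):
    """RBF kernel matrix and its gradient for particles x of shape (n, d)."""
    n = x.shape[0]
    diffs = x[:, None, :] - x[None, :, :]       # diffs[j, i] = x_j - x_i, shape (n, n, d)
    sq_dists = np.sum(diffs ** 2, axis=-1)      # squared distances, shape (n, n)
    h = np.median(sq_dists) / np.log(n)         # median heuristic bandwidth
    k = np.exp(-sq_dists / h)                   # k[j, i] = k(x_j, x_i)
    grad_k = -2.0 / h * diffs * k[:, :, None]   # gradient of k(x_j, x_i) with respect to x_j
    return k, grad_k


def svgd_step(x, grad_log_p, step_size):
    """One SVGD update: kernel-smoothed gradient (attraction) plus kernel gradient (repulsion)."""
    n = x.shape[0]
    k, grad_k = rbf_kernel_and_grad(x)
    phi = (k.T @ grad_log_p(x) + grad_k.sum(axis=0)) / n
    return x + step_size * phi


def grad_log_p(z):
    """Score of the toy target p(z) = N(2, 1)."""
    return -(z - 2.0)


rng = np.random.default_rng(0)
particles = rng.normal(-2.0, 0.5, size=(100, 1))  # initialize far from the target
for _ in range(3000):
    particles = svgd_step(particles, grad_log_p, step_size=0.05)

print(particles.mean(), particles.std())  # roughly 2 and 1 once converged
```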

numpyro.contrib.einstein is a framework for particle-based inference using the Stein mixture algorithm. The framework works on Stein mixtures, a restricted mixture of guide programs parameterized by Stein particles. As in SVGD, Stein mixtures approximate model posteriors by moving the Stein particles according to the Stein forces. Because the Stein particles parameterize a guide, each one captures a neighborhood rather than a single point.

numpyro.contrib.einstein mimics the interface of numpyro.infer.svi, so trying SteinVI requires only minimal changes to existing code for models inferred with SVI. For primary usage, see the Bayesian neural network example.

The framework currently supports several kernels, including:

  • RBFKernel

  • LinearKernel

  • RandomFeatureKernel

  • MixtureKernel

  • GraphicalKernel

  • ProbabilityProductKernel

For example usage, see:


1. Stein’s Method Meets Statistics: A Review of Some Recent Developments (2021) Andreas Anastasiou, Alessandro Barp, François-Xavier Briol, Bruno Ebner, Robert E. Gaunt, Fatemeh Ghaderinezhad, Jackson Gorham, Arthur Gretton, Christophe Ley, Qiang Liu, Lester Mackey, Chris J. Oates, Gesine Reinert, Yvik Swan.

2. Stein variational gradient descent: A general-purpose Bayesian inference algorithm (2016) Qiang Liu, Dilin Wang. NeurIPS.

3. Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models (2019) Dilin Wang, Qiang Liu. PMLR.

SteinVI Interface

class SteinVI(model: Callable, guide: Callable, optim: ~numpyro.optim._NumPyroOptim, kernel_fn: ~numpyro.contrib.einstein.stein_kernels.SteinKernel, num_stein_particles: int = 10, num_elbo_particles: int = 10, loss_temperature: float = 1.0, repulsion_temperature: float = 1.0, non_mixture_guide_params_fn: Callable[[str], bool] = <function SteinVI.<lambda>>, enum=True, **static_kwargs)

Variational inference with Stein mixtures.


>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.contrib.einstein import MixtureGuidePredictive, SteinVI, RBFKernel

>>> def model(data):
...     f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
...     with numpyro.plate("N", data.shape[0] if data is not None else 10):
...         numpyro.sample("obs", dist.Bernoulli(f), obs=data)

>>> def guide(data):
...     alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
...                            constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

>>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
>>> optimizer = numpyro.optim.Adam(step_size=0.0005)
>>> stein = SteinVI(model, guide, optimizer, kernel_fn=RBFKernel())
>>> stein_result = stein.run(random.PRNGKey(0), 2000, data)
>>> params = stein_result.params
>>> # use guide to make predictive
>>> predictive = MixtureGuidePredictive(model, guide, params, num_samples=1000, guide_sites=stein.guide_sites)
>>> samples = predictive(random.PRNGKey(1), data=None)
  • model (Callable) – Python callable with Pyro primitives for the model.

  • guide – Python callable with Pyro primitives for the guide (recognition network).

  • optim (_NumPyroOptim) – An instance of _NumPyroOptim.

  • kernel_fn (SteinKernel) – Function that produces a logarithm of the statistical kernel to use with Stein mixture inference.

  • num_stein_particles – Number of particles (i.e., mixture components) in the Stein mixture.

  • num_elbo_particles – Number of Monte Carlo draws used to approximate the attractive force gradient. More particles give better gradient approximations.

  • loss_temperature (float) – Scaling factor of the attractive force.

  • repulsion_temperature (float) – Scaling factor of the repulsive force (nonlinear Stein).

  • non_mixture_guide_params_fn (Callable) – Predicate on names of parameters in the guide that should be optimized classically, without Stein (e.g., parameters of large neural networks or other transformations).

  • static_kwargs – Static keyword arguments for the model / guide, i.e. arguments that remain constant during inference.
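The attractive and repulsive forces mentioned in these parameters can be illustrated with a one-dimensional, fixed-bandwidth SVGD-style update in NumPy. The sketch assumes, for illustration only, that loss_temperature and repulsion_temperature multiply the two terms directly; it does not reproduce SteinVI's Stein-mixture objective or its handling of guides, and the helper names are hypothetical. With repulsion switched off, the particles bunch up at the mode; with it on, they stay spread over the target.

```python
import numpy as np


def tempered_stein_direction(x, grad_log_p, h, loss_temperature=1.0, repulsion_temperature=1.0):
    """Update direction for 1-D particles x (shape (n,)) with a fixed-bandwidth RBF kernel."""
    n = x.shape[0]
    diffs = x[:, None] - x[None, :]              # diffs[j, i] = x_j - x_i
    k = np.exp(-diffs ** 2 / h)                  # k[j, i] = k(x_j, x_i)
    # Attractive force: kernel-weighted gradients of log p pull particles toward high density.
    attractive = (k * grad_log_p(x)[:, None]).sum(axis=0) / n
    # Repulsive force: gradients of the kernel push nearby particles apart.
    repulsive = (-2.0 / h * diffs * k).sum(axis=0) / n
    return loss_temperature * attractive + repulsion_temperature * repulsive


def simulate(repulsion_temperature, num_steps=3000, step_size=0.1):
    rng = np.random.default_rng(1)
    x = rng.normal(0.0, 2.0, size=50)            # start wider than the target
    for _ in range(num_steps):
        # Target p = N(0, 1), so grad log p(x) = -x.
        x = x + step_size * tempered_stein_direction(
            x, lambda z: -z, h=0.5, repulsion_temperature=repulsion_temperature
        )
    return x


collapsed = simulate(repulsion_temperature=0.0)  # attraction only
spread = simulate(repulsion_temperature=1.0)     # attraction and repulsion
print(collapsed.std(), spread.std())
```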

SteinVI Kernels

class RBFKernel(mode='norm', matrix_mode='norm_diag', bandwidth_factor: Callable[[float], float] = <function RBFKernel.<lambda>>)

Calculates the Gaussian RBF kernel function from [1], \(k(x,y) = \exp(-\frac{1}{h} \|x-y\|^2)\), where the bandwidth \(h\) is computed using the median heuristic \(h = \frac{1}{\log(n)} \text{med}(\|x-y\|^2)\).


  1. Stein Variational Gradient Descent by Liu and Wang

  • mode (str) – One of 'norm' (default), which takes the norm of each particle; 'vector', which returns a component-wise kernel; or 'matrix', which returns a matrix-valued kernel.

  • matrix_mode (str) – Either 'norm_diag' (default), for a diagonal filled with the norm kernel, or 'vector_diag', for a diagonal filled with the vector-valued kernel.

  • bandwidth_factor – A multiplier to the bandwidth based on data size n (default 1/log(n))
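The difference between the 'norm' and 'vector' modes can be sketched in NumPy: 'norm' produces one scalar kernel value per particle pair from the squared distance, while 'vector' produces one value per coordinate. The sketch assumes the median heuristic is applied to squared distances with the default factor 1/log(n) and, in 'vector' mode, separately for each dimension; the 'matrix' mode and the bandwidth_factor hook are omitted, and the function name is hypothetical.

```python
import numpy as np


def rbf_kernel(particles, mode="norm"):
    """RBF kernel values for particles of shape (n, d).

    mode="norm":   returns shape (n, n), one kernel value per pair from ||x - y||^2.
    mode="vector": returns shape (n, n, d), one kernel value per pair and coordinate.
    """
    n = particles.shape[0]
    sq = (particles[:, None, :] - particles[None, :, :]) ** 2   # (n, n, d) squared differences
    if mode == "norm":
        sq = sq.sum(axis=-1)                                   # (n, n) squared distances
        h = np.median(sq) / np.log(n)                          # one shared bandwidth
    elif mode == "vector":
        # One median-heuristic bandwidth per coordinate, shape (d,).
        h = np.median(sq.reshape(-1, sq.shape[-1]), axis=0) / np.log(n)
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    return np.exp(-sq / h)


rng = np.random.default_rng(0)
particles = rng.normal(size=(8, 3))
k_norm = rbf_kernel(particles, mode="norm")      # shape (8, 8)
k_vector = rbf_kernel(particles, mode="vector")  # shape (8, 8, 3)
```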

class LinearKernel(mode='norm')

Calculates the linear kernel \(k(x,y) = x \cdot y + 1\) from [1].


  1. Stein Variational Gradient Descent as Moment Matching by Liu and Wang
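As a small illustration of the formula (not NumPyro's implementation), the linear kernel and its Gram matrix over a set of particles can be computed in NumPy as follows; the function names are hypothetical.

```python
import numpy as np


def linear_kernel(x, y):
    """k(x, y) = x . y + 1 for two particles given as 1-D arrays."""
    return np.dot(x, y) + 1.0


def linear_gram(particles):
    """Gram matrix K[i, j] = k(x_i, x_j) for particles of shape (n, d)."""
    return particles @ particles.T + 1.0


x = np.array([1.0, 2.0])
y = np.array([3.0, 4.0])
print(linear_kernel(x, y))  # 1*3 + 2*4 + 1 = 12.0
```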

class RandomFeatureKernel(mode='norm', bandwidth_subset=None, bandwidth_factor: Callable[[float], float] = <function RandomFeatureKernel.<lambda>>)

Calculates the random feature kernel \(k(x,y) = \frac{1}{m}\sum_{l=1}^{m}\phi(x,w_l)\phi(y,w_l)\) from [1].


  1. Stein Variational Gradient Descent as Moment Matching by Liu and Wang

  • bandwidth_subset – Number of particles used to calculate the bandwidth (default None, meaning all particles).

  • random_indices – The set of indices on which to perform the random feature expansion (default None, meaning all indices).

  • bandwidth_factor – A multiplier to the bandwidth based on data size n (default 1/log(n))
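The formula above leaves the feature map \(\phi\) unspecified. One common choice, assumed in the NumPy sketch below rather than taken from NumPyro's implementation, is random Fourier features \(\phi(x, w_l) = \sqrt{2}\cos(\omega_l^\top x + b_l)\) with \(w_l = (\omega_l, b_l)\), \(\omega_l \sim \mathcal{N}(0, I)\) and \(b_l \sim \mathrm{Uniform}(0, 2\pi)\). With this choice the average approximates the Gaussian kernel \(\exp(-\|x-y\|^2/2)\) as \(m\) grows. The function name and the default m are hypothetical.

```python
import numpy as np


def random_feature_kernel(x, y, m=20_000, seed=0):
    """Monte Carlo kernel estimate (1/m) * sum_l phi(x, w_l) * phi(y, w_l).

    Assumes random Fourier features phi(x, w_l) = sqrt(2) * cos(omega_l . x + b_l),
    with omega_l ~ N(0, I) and b_l ~ Uniform(0, 2*pi).
    """
    rng = np.random.default_rng(seed)
    omega = rng.normal(size=(m, x.shape[-1]))        # frequency part of each w_l
    b = rng.uniform(0.0, 2.0 * np.pi, size=m)        # phase part of each w_l
    phi_x = np.sqrt(2.0) * np.cos(omega @ x + b)     # the same w_l are used for x and y
    phi_y = np.sqrt(2.0) * np.cos(omega @ y + b)
    return np.mean(phi_x * phi_y)


x = np.array([0.5, -0.2])
y = np.array([0.1, 0.3])
approx = random_feature_kernel(x, y)
exact = np.exp(-np.sum((x - y) ** 2) / 2.0)          # Gaussian kernel being approximated
```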

class MixtureKernel(ws: list[float], kernel_fns: list[SteinKernel], mode='norm')

Calculates a mixture of multiple kernels \(k(x,y) = \sum_i w_i k_i(x,y)\).


  1. Stein Variational Gradient Descent as Moment Matching by Liu and Wang

  • ws – Weight of each kernel in the mixture

  • kernel_fns – Different kernel functions to mix together
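A minimal NumPy sketch of the weighted sum, combining a fixed-bandwidth RBF kernel with the linear kernel. The bandwidth, weights, and function names are illustrative assumptions and do not reflect how MixtureKernel is implemented internally.

```python
import numpy as np


def rbf(x, y, h=1.0):
    """Scalar RBF kernel with a fixed (illustrative) bandwidth h."""
    return np.exp(-np.sum((x - y) ** 2) / h)


def linear(x, y):
    """Linear kernel x . y + 1."""
    return np.dot(x, y) + 1.0


def mixture_kernel(ws, kernel_fns):
    """Return a function computing k(x, y) = sum_i ws[i] * kernel_fns[i](x, y)."""
    if len(ws) != len(kernel_fns):
        raise ValueError("expected one weight per kernel")

    def mixed(x, y):
        return sum(w * k_i(x, y) for w, k_i in zip(ws, kernel_fns))

    return mixed


k = mixture_kernel([0.5, 0.5], [rbf, linear])
x = np.array([1.0, 0.0])
y = np.array([0.0, 1.0])
value = k(x, y)  # 0.5 * exp(-2) + 0.5 * 1
```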

class GraphicalKernel(mode='matrix', local_kernel_fns: dict[str, ~numpyro.contrib.einstein.stein_kernels.SteinKernel] | None = None, default_kernel_fn: ~numpyro.contrib.einstein.stein_kernels.SteinKernel = <numpyro.contrib.einstein.stein_kernels.RBFKernel object>)

Calculates the graphical kernel \(k(x,y) = \operatorname{diag}(\{K_l(x_l,y_l)\})\) for local kernels \(K_l\) from [1][2].


  1. Stein Variational Message Passing for Continuous Graphical Models by Wang, Zheng, and Liu

  2. Stein Variational Gradient Descent with Matrix-Valued Kernels by Wang, Tang, Bajaj, and Liu

  • local_kernel_fns – A mapping between parameters and a choice of kernel function for that parameter (defaults to default_kernel_fn for each parameter)

  • default_kernel_fn – The default choice of kernel function when none is specified for a particular parameter
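To make the block-diagonal structure concrete, the NumPy sketch below assembles the matrix-valued kernel for one pair of particles whose parameters are split by name, falling back to a default kernel for parameters without a local choice. Using per-parameter RBF kernels with fixed bandwidths (a scalar value times an identity block), and the helper names `rbf_matrix` and `graphical_kernel`, are assumptions for illustration.

```python
import numpy as np


def rbf_matrix(h):
    """Local kernel returning exp(-||x - y||^2 / h) times an identity block."""
    def k(x, y):
        return np.exp(-np.sum((x - y) ** 2) / h) * np.eye(x.size)
    return k


def graphical_kernel(x_parts, y_parts, local_kernel_fns, default_kernel_fn):
    """Block-diagonal matrix kernel diag({K_l(x_l, y_l)}) over named parameter blocks."""
    blocks = [
        local_kernel_fns.get(name, default_kernel_fn)(x_parts[name], y_parts[name])
        for name in x_parts
    ]
    size = sum(block.shape[0] for block in blocks)
    out = np.zeros((size, size))
    offset = 0
    for block in blocks:
        d = block.shape[0]
        out[offset:offset + d, offset:offset + d] = block  # place K_l on the diagonal
        offset += d
    return out


x_parts = {"w": np.array([0.0, 1.0, 2.0]), "b": np.array([0.5])}
y_parts = {"w": np.array([0.0, 1.0, 2.5]), "b": np.array([1.5])}
K = graphical_kernel(x_parts, y_parts, {"w": rbf_matrix(1.0)}, rbf_matrix(10.0))
```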

class ProbabilityProductKernel(guide, scale=1.0)

Stochastic Support

class DCC(model, mcmc_kwargs, kernel_cls=<class 'numpyro.infer.hmc.NUTS'>, num_slp_samples=1000, max_slps=124, proposal_scale=1.0)

Bases: object

Implements the Divide, Conquer, and Combine (DCC) algorithm from [1] for models with stochastic support.


This implementation assumes that all stochastic branching is done based on the outcomes of discrete sampling sites that are annotated with infer={"branching": True}. For example,

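import numpyro
import numpyro.distributions as dist

def model():
    # The branching site selects which straight-line program (SLP) is executed.
    model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
    if model1 == 0:
        mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
    else:
        mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)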


  1. Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support, Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth

  • model – Python callable containing Pyro primitives.

  • mcmc_kwargs (dict) – Dictionary of arguments passed to MCMC.

  • kernel_cls (numpyro.infer.mcmc.MCMCKernel) – MCMC kernel class that is used for local inference. Defaults to NUTS.

  • num_slp_samples (int) – Number of samples to draw from the prior to discover the straight-line programs (SLPs).

  • max_slps (int) – Maximum number of SLPs to discover. DCC will not run inference on more than max_slps SLPs.

  • proposal_scale (float) – Scale parameter for the proposal distribution for estimating the normalization constant of an SLP.
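In the combine step, DCC weights the local posterior of each discovered SLP by that SLP's marginal likelihood \(Z_k\). For the example model above, both SLPs are conjugate Gaussian models, so each \(Z_k\) (including the 0.5 prior probability of its branch) has a closed form, and the NumPy sketch below computes the resulting mixture weights. The analytic \(Z_k\) is an assumption made for illustration only: DCC estimates it numerically, with proposal_scale controlling the proposal used for that estimate, and the helper name `normal_logpdf` is hypothetical.

```python
import numpy as np


def normal_logpdf(x, mean, var):
    """Log density of N(mean, var) evaluated at x."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)


obs = 0.2
# SLP 0: model1 = 0, mean ~ N(0, 1), obs ~ N(mean, 1)  =>  obs ~ N(0, 2) marginally
# SLP 1: model1 = 1, mean ~ N(1, 1), obs ~ N(mean, 1)  =>  obs ~ N(1, 2) marginally
log_branch_prior = np.log(np.array([0.5, 0.5]))   # Bernoulli(0.5) branching site
log_evidence = np.array([normal_logpdf(obs, 0.0, 2.0), normal_logpdf(obs, 1.0, 2.0)])
log_z = log_branch_prior + log_evidence             # log Z_k for each SLP

# Normalize in log space for stability: w_k = Z_k / sum_j Z_j
slp_weights = np.exp(log_z - log_z.max())
slp_weights /= slp_weights.sum()
print(slp_weights)
```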

run(rng_key, *args, **kwargs)

Run DCC and collect samples for all SLPs.

  • rng_key (jax.random.PRNGKey) – Random number generator key.

  • args – Arguments to the model.

  • kwargs – Keyword arguments to the model.