Contributed Code

Nested Sampling

class NestedSampler(model, *, num_live_points=1000, max_samples=100000, sampler_name='slice', depth=5, num_slices=5, termination_frac=0.01)

Bases: object

(EXPERIMENTAL) A wrapper for jaxns, a nested sampling package based on JAX.

See reference [1] for details on the meaning of each parameter. Please consider citing this reference if you use the nested sampler in your research.


To enumerate over a discrete latent variable, add the keyword argument infer={"enumerate": "parallel"} to the corresponding sample statement.
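Conceptually, parallel enumeration removes a discrete latent variable by summing its contribution over every value in its support instead of sampling it. The plain-Python sketch below shows that marginalization for a hypothetical two-component Gaussian mixture; it illustrates the math only, not how NumPyro implements enumeration, and `log_normal`, `logsumexp`, and `mixture_log_likelihood` are illustrative helpers.

```python
import math

def log_normal(x, loc, scale):
    # Log density of Normal(loc, scale) evaluated at x.
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale * math.sqrt(2.0 * math.pi))

def logsumexp(values):
    # Numerically stable log(sum(exp(v) for v in values)).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mixture_log_likelihood(x, weights, locs, scale):
    # log p(x) = log sum_k p(z = k) * p(x | z = k): the discrete z is summed out.
    terms = [math.log(w) + log_normal(x, loc, scale) for w, loc in zip(weights, locs)]
    return logsumexp(terms)

# Hypothetical mixture: z ~ Categorical([0.3, 0.7]), x | z ~ Normal(locs[z], 1).
ll = mixture_log_likelihood(0.5, weights=[0.3, 0.7], locs=[-1.0, 2.0], scale=1.0)
print(ll)
```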


To improve performance, consider enabling x64 mode with numpyro.enable_x64() at the beginning of your NumPyro program.


  1. JAXNS: a high-performance nested sampling package based on JAX, Joshua G. Albert (

Parameters:
  • model (callable) – a callable containing NumPyro primitives
  • num_live_points (int) – the number of live points. As a rule of thumb, allocate around 50 live points per possible mode.
  • max_samples (int) – the maximum number of iterations and samples
  • sampler_name (str) – either “slice” (default value) or “multi_ellipsoid”
  • depth (int) – determines the maximum number of ellipsoids to construct via hierarchical splitting (typical range: 3 to 9; defaults to 5)
  • num_slices (int) – the number of slice sampling proposals at each sampling step (typical range: 1 to 5; defaults to 5)
  • termination_frac (float) – termination condition (typical range: 0.001 to 0.01; defaults to 0.01).
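To show how `num_live_points`, `max_samples`, and `termination_frac` interact, here is a minimal textbook nested sampling loop in plain Python for a hypothetical one-dimensional problem with a Uniform(0, 1) prior and a Gaussian likelihood. It is a sketch under stated assumptions, not jaxns: the worst live point is replaced by naive rejection sampling from the constrained prior (jaxns uses its "slice" or "multi_ellipsoid" samplers), the prior volume is assumed to shrink by exp(-1/N) per iteration, and the stopping rule (stop once the largest live likelihood times the remaining prior volume falls below `termination_frac` times the current evidence) is an assumed reading of the termination fraction.

```python
import math
import random

def log_likelihood(theta):
    # Hypothetical toy likelihood: Normal(0.5, 0.1) density evaluated at theta.
    return -0.5 * ((theta - 0.5) / 0.1) ** 2 - math.log(0.1 * math.sqrt(2.0 * math.pi))

def logaddexp(a, b):
    # Stable log(exp(a) + exp(b)), allowing -inf inputs.
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def nested_sampling(num_live_points=100, max_samples=10000, termination_frac=0.01, seed=0):
    rng = random.Random(seed)
    # Prior is Uniform(0, 1), so a prior draw is rng.random().
    live = [rng.random() for _ in range(num_live_points)]
    live_logl = [log_likelihood(t) for t in live]
    log_z = -math.inf   # running log-evidence estimate
    log_x = 0.0         # log of the remaining prior volume (X_0 = 1)
    dead = []           # collected (theta, log_weight) pairs
    shrink = -1.0 / num_live_points  # assumed log-volume shrinkage per iteration
    for _ in range(max_samples):
        # Stop once the live points can add at most termination_frac of the current Z.
        if max(live_logl) + log_x < math.log(termination_frac) + log_z:
            break
        worst = min(range(num_live_points), key=lambda j: live_logl[j])
        logl_min = live_logl[worst]
        # Weight of the dead point: L_min * (X_i - X_{i+1}).
        log_w = logl_min + log_x + math.log1p(-math.exp(shrink))
        dead.append((live[worst], log_w))
        log_z = logaddexp(log_z, log_w)
        log_x += shrink
        # Replace the worst point with a prior draw satisfying L > L_min.
        while True:
            theta = rng.random()
            logl = log_likelihood(theta)
            if logl > logl_min:
                break
        live[worst], live_logl[worst] = theta, logl
    # Remaining live points share the leftover prior volume equally.
    for theta, logl in zip(live, live_logl):
        log_w = logl + log_x - math.log(num_live_points)
        dead.append((theta, log_w))
        log_z = logaddexp(log_z, log_w)
    return dead, log_z

dead, log_z = nested_sampling()
print(len(dead), log_z)  # log_z should be close to log(1) = 0 for this toy problem
```

Raising `num_live_points` shrinks the prior volume more slowly per iteration, so the run takes more iterations but yields a lower-variance evidence estimate.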


>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.nested_sampling import NestedSampler

>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)
>>> ns = NestedSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)
>>> samples = ns.get_samples(random.PRNGKey(3), num_samples=1000)
>>> assert jnp.mean(jnp.abs(samples['intercept'])) < 0.05
>>> print(jnp.mean(samples['coefs'], axis=0))  
[0.93661342 1.95034876 2.86123884]

run(rng_key, *args, **kwargs)

Run the nested sampler and collect the weighted samples.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.
  • args – The arguments needed by the model.
  • kwargs – The keyword arguments needed by the model.

get_samples(rng_key, num_samples)

Draws samples from the weighted samples collected from the run.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to be used to draw samples.
  • num_samples (int) – The number of samples.

Returns: a dict of posterior samples
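One common way to turn weighted samples into `num_samples` unweighted draws is multinomial resampling with probabilities proportional to exp(log_weight). The sketch below assumes that scheme, which may differ from what `get_samples` does internally, and uses hypothetical Python lists in place of the dict the library returns; `resample` is an illustrative name.

```python
import math
import random

def resample(samples, log_weights, num_samples, seed=0):
    # Shift by the max log weight before exponentiating to avoid overflow/underflow.
    m = max(log_weights)
    weights = [math.exp(lw - m) for lw in log_weights]
    rng = random.Random(seed)
    # Multinomial resampling: draw indices with probability proportional to weight.
    return rng.choices(samples, weights=weights, k=num_samples)

# Hypothetical weighted samples: three values whose weights favour 2.0.
samples = [0.0, 1.0, 2.0]
log_weights = [math.log(0.1), math.log(0.2), math.log(0.7)]
draws = resample(samples, log_weights, num_samples=10000)
print(sum(draws) / len(draws))  # roughly 0.1*0 + 0.2*1 + 0.7*2 = 1.6
```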


Gets weighted samples and their corresponding log weights.
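When working with the weighted samples and their log weights directly, posterior expectations can be estimated with self-normalized weights, which avoids the extra noise that resampling adds. The sketch also computes Kish's effective sample size, (Σw)² / Σw², as a rough gauge of how many independent draws the weights are worth. Both are standard formulas rather than anything specific to jaxns, and the helper names are hypothetical.

```python
import math

def normalized_weights(log_weights):
    # Convert log weights to weights that sum to 1 (max-shift for stability).
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    total = sum(w)
    return [wi / total for wi in w]

def weighted_mean(samples, log_weights):
    # Self-normalized estimate of the posterior mean.
    return sum(wi * s for wi, s in zip(normalized_weights(log_weights), samples))

def kish_ess(log_weights):
    # (sum w)^2 / sum w^2; with normalized weights this reduces to 1 / sum w^2.
    w = normalized_weights(log_weights)
    return 1.0 / sum(wi * wi for wi in w)

samples = [0.0, 1.0, 2.0]
log_weights = [math.log(0.1), math.log(0.2), math.log(0.7)]
print(weighted_mean(samples, log_weights))  # 1.6, up to floating-point rounding
print(kish_ess(log_weights))                # 1 / (0.01 + 0.04 + 0.49), about 1.85
```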


Print a summary of the result. This is a wrapper around jaxns.utils.summary().


Plot diagnostics of the result. This is a wrapper around jaxns.plotting.plot_diagnostics() and jaxns.plotting.plot_cornerplot().