# Example: Nested Sampling for Gaussian Shells

This example illustrates the use of the contrib class `NestedSampler`, a wrapper around the jaxns library ([1]) that makes it usable with NumPyro models.

Here we replicate the Gaussian Shells demo at [2] and compare the nested-sampling results against those of the NUTS sampler.

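References:

1. jaxns library: https://github.com/Joshuaalbert/jaxns
2. dynesty's Gaussian Shells demo: https://github.com/joshspeagle/dynesty/blob/master/demos/Examples%20--%20Gaussian%20Shells.ipynb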

```python
import argparse

import matplotlib.pyplot as plt

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.nested_sampling import NestedSampler
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

class GaussianShell(dist.Distribution):
    """Density concentrated on a thin shell around ``loc``.

    The log density of a point depends only on its Euclidean distance ``d``
    from ``loc``: it is the log of a univariate normal density with mean
    ``radius`` and standard deviation ``width`` evaluated at ``d``.

    :param loc: centre of the shell; its last axis is the event dimension.
    :param radius: distance from ``loc`` at which the density peaks.
    :param width: standard deviation of the shell profile (must be positive,
        since its logarithm is taken in ``log_prob``).
    """

    # Parameter constraints; NumPyro only checks them when argument
    # validation is enabled.
    arg_constraints = {
        "loc": dist.constraints.real_vector,
        "radius": dist.constraints.positive,
        "width": dist.constraints.positive,
    }
    support = dist.constraints.real_vector

    def __init__(self, loc, radius, width):
        # Store the parameters before delegating to the base class so that
        # ``log_prob`` can read them.
        self.loc, self.radius, self.width = loc, radius, width
        super().__init__(batch_shape=loc.shape[:-1], event_shape=loc.shape[-1:])

    def sample(self, key, sample_shape=()):
        """Return zeros of the right shape; ``key`` is ignored.

        This is only a dummy sample used to initialize the samplers.
        """
        return jnp.zeros(sample_shape + self.shape())

    def log_prob(self, value):
        """Return the log density of ``value`` under the shell profile."""
        # log(1 / (sqrt(2 * pi) * width))
        normalizer = (-0.5) * (jnp.log(2.0 * jnp.pi) + 2.0 * jnp.log(self.width))
        # Distance of each point from the shell centre, reduced over the
        # event axis.
        d = jnp.linalg.norm(value - self.loc, axis=-1)
        return normalizer - 0.5 * ((d - self.radius) / self.width) ** 2


def model(center1, center2, radius, width, enum=False):
    """Mixture of two Gaussian shells over a 2-D point ``x``.

    :param center1: centre of the shell selected when ``z == 0``.
    :param center2: centre of the shell selected when ``z == 1``.
    :param radius: radius shared by both shells.
    :param width: width shared by both shells.
    :param enum: if true, mark the discrete site ``z`` for parallel
        enumeration instead of sampling it.
    """
    z = numpyro.sample(
        "z", dist.Bernoulli(0.5), infer={"enumerate": "parallel"} if enum else {}
    )
    x = numpyro.sample("x", dist.Uniform(-6.0, 6.0).expand([2]).to_event(1))
    # Pick the centre of the shell indicated by the latent assignment z.
    center = jnp.stack([center1, center2])[z]
    # Weight x by the shell density: observing the latent x under the shell
    # distribution adds the shell log density to the model's joint.
    numpyro.sample("shell", GaussianShell(center, radius, width), obs=x)

def run_inference(args, data):
    """Sample ``x`` from ``model`` with nested sampling and with MCMC.

    :param args: options object providing ``enum``, ``num_samples`` and
        ``num_warmup``.
    :param data: keyword arguments forwarded to ``model``.
    :return: tuple ``(nested_sampling_x, mcmc_x)`` holding the draws of the
        site ``"x"`` from each method.
    """
    enumerate_z = args.enum

    print("=== Performing Nested Sampling ===")
    nested = NestedSampler(model)
    nested.run(random.PRNGKey(0), **data, enum=enumerate_z)
    nested.print_summary()
    # Nested-sampling draws carry weights, so a random key is supplied to
    # resample an unweighted set from them.
    nested_draws = nested.get_samples(
        random.PRNGKey(1), num_samples=args.num_samples
    )

    print("\n=== Performing MCMC Sampling ===")
    # With enumeration requested, NUTS is used on its own; otherwise NUTS is
    # wrapped in DiscreteHMCGibbs.
    kernel = NUTS(model)
    if not enumerate_z:
        kernel = DiscreteHMCGibbs(kernel)
    sampler = MCMC(
        kernel, num_warmup=args.num_warmup, num_samples=args.num_samples
    )
    sampler.run(random.PRNGKey(2), **data, enum=enumerate_z)
    sampler.print_summary()
    mcmc_draws = sampler.get_samples()

    return nested_draws["x"], mcmc_draws["x"]

def main(args):
    """Run both samplers on the two-shell problem and save a comparison plot.

    The figure is written to ``gaussian_shells_plot.pdf`` in the current
    working directory.

    :param args: parsed command-line options, forwarded to ``run_inference``.
    """
    data = dict(
        # ``model`` requires a radius. With centres at x = +/-3.5, a radius
        # of 2.0 keeps both shells inside the Uniform(-6, 6) prior box and
        # inside the plotted y-range of (-2.5, 2.5).
        radius=2.0,
        width=0.1,
        center1=jnp.array([-3.5, 0.0]),
        center2=jnp.array([3.5, 0.0]),
    )
    ns_samples, mcmc_samples = run_inference(args, data)

    # plotting
    fig, (ax1, ax2) = plt.subplots(
        2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
    )

    ax1.plot(mcmc_samples[:, 0], mcmc_samples[:, 1], "ro", alpha=0.2)
    ax1.set(
        xlim=(-6, 6),
        ylim=(-2.5, 2.5),
        ylabel="x[1]",
        title="Gaussian-shell samples using NUTS",
    )

    # The x-axis is shared, so only the lower panel carries the x label.
    ax2.plot(ns_samples[:, 0], ns_samples[:, 1], "ro", alpha=0.2)
    ax2.set(
        xlim=(-6, 6),
        ylim=(-2.5, 2.5),
        xlabel="x[0]",
        ylabel="x[1]",
        title="Gaussian-shell samples using Nested Sampler",
    )

    plt.savefig("gaussian_shells_plot.pdf")

if __name__ == "__main__":
    # The example is pinned to the NumPyro release it was written for.
    assert numpyro.__version__.startswith("0.15.0")
    parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
    # run_inference reads args.num_samples and args.num_warmup, so both
    # options must be defined here.
    parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
    parser.add_argument(
        "--enum",
        action="store_true",
        default=False,
        help="whether to enumerate over the discrete latent variable",
    )
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    # Select the backend before any computation is run.
    numpyro.set_platform(args.device)

    main(args)
```

Gallery generated by Sphinx-Gallery