
Truncated and folded distributions

This tutorial will cover how to work with truncated and folded distributions in NumPyro. It is assumed that you’re already familiar with the basics of NumPyro. To get the most out of this tutorial you’ll need some background in probability.


Setup

To run this notebook, we are going to need the following imports:

[ ]:
!pip install -q git+https://github.com/pyro-ppl/numpyro.git
[2]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import poisson as sp_poisson

import jax
from jax import lax, random
import jax.numpy as jnp
from jax.scipy.special import ndtri
from jax.scipy.stats import norm, poisson

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import (
    Distribution,
    FoldedDistribution,
    SoftLaplace,
    StudentT,
    TruncatedDistribution,
    TruncatedNormal,
    constraints,
)
from numpyro.distributions.util import promote_shapes
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, Predictive

numpyro.enable_x64()
RNG = random.PRNGKey(0)
PRIOR_RNG, MCMC_RNG, PRED_RNG = random.split(RNG, 3)
MCMC_KWARGS = dict(
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
    chain_method="sequential",
)

1. What are truncated distributions?

The support of a probability distribution is the set of values in the domain with non-zero probability. For example, the support of the normal distribution is the whole real line (even if the density gets very small as we move away from the mean, technically speaking, it is never quite zero). The support of the uniform distribution, as coded in jax.random.uniform with the default arguments, is the interval \([0, 1)\), because any value outside of that interval has zero probability. The support of the Poisson distribution is the set of non-negative integers, etc.

Truncating a distribution makes its support smaller so that any value outside our desired domain has zero probability. In practice, this can be useful for modelling situations in which certain biases are introduced during data collection. For example, some physical detectors only get triggered when the signal is above some minimum threshold, or sometimes the detectors fail if the signal exceeds a certain value. As a result, the observed values are constrained to be within a limited range of values, even though the true signal does not have the same constraints. See, for example, section 3.1 of Information Theory, Inference, and Learning Algorithms by David MacKay. Naively, if \(S\) is the support of the original density \(p_Y(y)\), then by truncating to a new support \(T\subset S\) we are effectively defining a new random variable \(Z\) for which the density is

\[\begin{split}\begin{align} p_Z(z) \propto \begin{cases} p_Y(z) & \text{if $z$ is in $T$}\\ 0 & \text{if $z$ is outside $T$}\\ \end{cases} \end{align}\end{split}\]

The reason for writing a \(\propto\) (proportional to) sign instead of a strict equation is that, defined in the above way, the resulting function does not integrate to \(1\) and so it cannot be strictly considered a probability density. To make it into a probability density we need to re-distribute the truncated mass among the part of the distribution that remains. To do this, we simply re-weight every point by the same constant:

\[\begin{split}\begin{align} p_Z(z) = \begin{cases} \frac{1}{M}p_Y(z) & \text{if $z$ is in $T$}\\ 0 & \text{if $z$ is outside $T$}\\ \end{cases} \end{align}\end{split}\]

where \(M = \int_T p_Y(y)\mathrm{d}y\).

In practice, the truncation is often one-sided. This means that if, for example, the support before truncation is the interval \((a, b)\), then the support after truncation is of the form \((a, c)\) or \((c, b)\), with \(a < c < b\). The figure below illustrates a left-sided truncation at zero of a normal distribution \(N(1, 1)\).

truncated

The original distribution (left side) is truncated at the vertical dotted line. The truncated mass (orange region) is redistributed over the new support (right side image) so that the total area under the curve remains equal to 1 even after truncation. This method of re-weighting ensures that the density ratio between any two points, \(p(a)/p(b)\), remains the same before and after the reweighting is done (as long as the points are inside the new support, of course).
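To make the re-weighting concrete, here is a small numerical sketch using the jax.scipy.stats.norm module imported above (the evaluation points are arbitrary):

# Sketch: re-weighting a N(1, 1) density truncated to (0, inf).
loc, scale = 1.0, 1.0
m = 1 - norm.cdf(0.0, loc, scale)  # mass that survives the truncation
a, b = 0.5, 2.0
ratio_before = norm.pdf(a, loc, scale) / norm.pdf(b, loc, scale)
ratio_after = (norm.pdf(a, loc, scale) / m) / (norm.pdf(b, loc, scale) / m)
print(ratio_before, ratio_after)  # identical: re-weighting preserves density ratios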

Note: Truncated data is different from censored data. Censoring also hides values that are outside some desired support but, contrary to truncated data, we know when a value has been censored. A typical example is a household scale that does not report values above 300 pounds. Censored data will not be covered in this tutorial.

2. What is a folded distribution?

Folding is achieved by taking the absolute value of a random variable, \(Z = \lvert Y \rvert\). This obviously modifies the support of the original distribution since negative values now have zero probability:

\[\begin{split}\begin{align} p_Z(z) = \begin{cases} p_Y(z) + p_Y(-z) & \text{if $z\ge 0$}\\ 0 & \text{if $z\lt 0$}\\ \end{cases} \end{align}\end{split}\]

The figure below illustrates a folded normal distribution \(N(1, 1)\).

folded

As you can see, the resulting distribution is different from the truncated case. In particular, the density ratio between points, \(p(a)/p(b)\), is in general not the same after folding. For some examples in which folding is relevant, see references 3 and 4.

If the original distribution is symmetric around zero, then folding and truncating at zero have the same effect.
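Here is a quick numerical check of that statement, again with the jax.scipy.stats.norm module (the evaluation points are arbitrary):

# Sketch: for N(0, 1), folding and truncating at zero give the same density.
z = jnp.array([0.1, 0.5, 1.0, 2.0])
folded_pdf = norm.pdf(z) + norm.pdf(-z)            # folded N(0, 1)
truncated_pdf = norm.pdf(z) / (1 - norm.cdf(0.0))  # N(0, 1) truncated to (0, inf)
print(jnp.allclose(folded_pdf, truncated_pdf))     # True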

3. Sampling from truncated and folded distributions

Truncated distributions

Usually, we already have a sampler for the pre-truncated distribution (e.g. np.random.normal). So, a seemingly simple way of generating samples from the truncated distribution would be to sample from the original distribution, and then discard the samples that are outside the desired support. For example, if we wanted samples from a normal distribution truncated to the support \((-\infty, 1)\), we’d simply do:

upper = 1
samples = np.random.normal(size=1000)
truncated_samples = samples[samples < upper]

This is called rejection sampling but it is not very efficient. If the region we truncated had a sufficiently high probability mass, then we’d be discarding a lot of samples and it might be a while before we accumulate sufficient samples for the truncated distribution. For example, the above snippet would only result in approximately 840 truncated samples even though we initially drew 1000. This can easily get a lot worse for other combinations of parameters. A more efficient approach is to use a method known as inverse transform sampling. In this method, we first sample from a uniform distribution in (0, 1) and then transform those samples with the inverse cumulative distribution of our truncated distribution. This method ensures that no samples are wasted in the process, though it does have the slight complication that we need to calculate the inverse CDF (ICDF) of our truncated distribution. This might sound too complicated at first but, with a bit of algebra, we can often calculate the truncated ICDF in terms of the untruncated ICDF. The untruncated ICDF for many distributions is already available.
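The snippet below sketches how inverse transform sampling looks for this example, using the standard-normal ICDF ndtri imported above (the PRNG key is arbitrary):

# Sketch: inverse-transform sampling for a standard normal truncated to (-inf, 1).
upper = 1.0
m = norm.cdf(upper)                             # mass of the original density below `upper`
u = random.uniform(random.PRNGKey(1), (1000,))  # uniform samples in [0, 1)
truncated_samples = ndtri(u * m)                # all 1000 draws land below `upper`; none are wasted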

Folded distributions

This case is a lot simpler. Since we already have a sampler for the pre-folded distribution, all we need to do is to take the absolute value of those samples:

samples = np.random.normal(size=1000)
folded_samples = np.abs(samples)

4. Ready-to-use truncated and folded distributions

The later sections in this tutorial will show you how to construct your own truncated and folded distributions, but you don’t have to reinvent the wheel. NumPyro has a bunch of truncated distributions already implemented.

Suppose, for example, that you want a normal distribution truncated on the right. For that purpose, we use the TruncatedNormal distribution. The parameters of this distribution are loc and scale, corresponding to the loc and scale of the untruncated normal, and low and/or high, corresponding to the truncation points. Importantly, low and high are keyword-only arguments; only loc and scale can be passed positionally. This is how you can use this class in a model:

[3]:
def truncated_normal_model(num_observations, high, x=None):
    loc = numpyro.sample("loc", dist.Normal())
    scale = numpyro.sample("scale", dist.LogNormal())
    with numpyro.plate("observations", num_observations):
        numpyro.sample("x", TruncatedNormal(loc, scale, high=high), obs=x)

Let’s now check that we can use this model in a typical MCMC workflow.

Prior simulation

[4]:
high = 1.2
num_observations = 250
num_prior_samples = 100

prior = Predictive(truncated_normal_model, num_samples=num_prior_samples)
prior_samples = prior(PRIOR_RNG, num_observations, high)

Inference

To test our model, we run MCMC against some synthetic data. The synthetic data can be any arbitrary sample from the prior simulation.

[5]:
# -- select an arbitrary prior sample as true data
true_idx = 0
true_loc = prior_samples["loc"][true_idx]
true_scale = prior_samples["scale"][true_idx]
true_x = prior_samples["x"][true_idx]
[6]:
plt.hist(true_x.copy(), bins=20)
plt.axvline(high, linestyle=":", color="k")
plt.xlabel("x")
plt.show()
../_images/tutorials_truncated_distributions_12_0.png
[7]:
# --- Run MCMC and check estimates and diagnostics
mcmc = MCMC(NUTS(truncated_normal_model), **MCMC_KWARGS)
mcmc.run(MCMC_RNG, num_observations, high, true_x)
mcmc.print_summary()

# --- Compare to ground truth
print(f"True loc  : {true_loc:3.2}")
print(f"True scale: {true_scale:3.2}")
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1909.24it/s, 1 steps of size 5.65e-01. acc. prob=0.93]
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 10214.14it/s, 3 steps of size 5.16e-01. acc. prob=0.95]
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 15102.95it/s, 1 steps of size 6.42e-01. acc. prob=0.90]
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 16522.03it/s, 3 steps of size 6.39e-01. acc. prob=0.90]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       loc     -0.58      0.15     -0.59     -0.82     -0.35   2883.69      1.00
     scale      1.49      0.11      1.48      1.32      1.66   3037.78      1.00

Number of divergences: 0
True loc  : -0.56
True scale: 1.4

Removing the truncation

Once we have inferred the parameters of our model, a common task is to understand what the data would look like without the truncation. In this example, this is easily done by simply “pushing” the value of high to infinity.

[8]:
pred = Predictive(truncated_normal_model, posterior_samples=mcmc.get_samples())
pred_samples = pred(PRED_RNG, num_observations, high=float("inf"))

Let’s finally plot these samples and compare them to the original, observed data.

[9]:
# thin the samples to not saturate matplotlib
samples_thinned = pred_samples["x"].ravel()[::1000]
[10]:
f, axes = plt.subplots(1, 2, figsize=(15, 5), sharex=True)

axes[0].hist(
    samples_thinned.copy(), label="Untruncated posterior", bins=20, density=True
)
axes[0].set_title("Untruncated posterior")

vals, bins, _ = axes[1].hist(
    samples_thinned[samples_thinned < high].copy(),
    label="Tail of untruncated posterior",
    bins=10,
    density=True,
)
axes[1].hist(
    true_x.copy(), bins=bins, label="Observed, truncated data", density=True, alpha=0.5
)
axes[1].set_title("Comparison to observed data")

for ax in axes:
    ax.axvline(high, linestyle=":", color="k", label="Truncation point")
    ax.legend()

plt.show()
../_images/tutorials_truncated_distributions_19_0.png

The plot on the left shows data simulated from the posterior distribution with the truncation removed, so we can see how the data would look if it were not truncated. To sense-check this, we discard the simulated samples that are above the truncation point, make a histogram of those, and compare it to a histogram of the true data (right plot).

The TruncatedDistribution class

The source code for the TruncatedNormal in NumPyro uses a class called TruncatedDistribution which abstracts away the logic for sample and log_prob that we will discuss in the next sections. At the moment, though, this logic only works for continuous, symmetric distributions with real support.

We can use this class to quickly construct other truncated distributions. For example, if we need a truncated SoftLaplace we can use the following pattern:

[11]:
def TruncatedSoftLaplace(
    loc=0.0, scale=1.0, *, low=None, high=None, validate_args=None
):
    return TruncatedDistribution(
        base_dist=SoftLaplace(loc, scale),
        low=low,
        high=high,
        validate_args=validate_args,
    )
[12]:
def truncated_soft_laplace_model(num_observations, high, x=None):
    loc = numpyro.sample("loc", dist.Normal())
    scale = numpyro.sample("scale", dist.LogNormal())
    with numpyro.plate("obs", num_observations):
        numpyro.sample("x", TruncatedSoftLaplace(loc, scale, high=high), obs=x)

And, as before, we check that we can use this model in the steps of a typical workflow:

[13]:
high = 2.3
num_observations = 200
num_prior_samples = 100

prior = Predictive(truncated_soft_laplace_model, num_samples=num_prior_samples)
prior_samples = prior(PRIOR_RNG, num_observations, high)

true_idx = 0
true_x = prior_samples["x"][true_idx]
true_loc = prior_samples["loc"][true_idx]
true_scale = prior_samples["scale"][true_idx]

mcmc = MCMC(
    NUTS(truncated_soft_laplace_model),
    **MCMC_KWARGS,
)

mcmc.run(
    MCMC_RNG,
    num_observations,
    high,
    true_x,
)

mcmc.print_summary()

print(f"True loc  : {true_loc:3.2}")
print(f"True scale: {true_scale:3.2}")
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1745.70it/s, 1 steps of size 6.78e-01. acc. prob=0.93]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 9294.56it/s, 1 steps of size 7.02e-01. acc. prob=0.93]
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 10412.30it/s, 1 steps of size 7.20e-01. acc. prob=0.92]
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 10583.85it/s, 3 steps of size 7.01e-01. acc. prob=0.93]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       loc     -0.37      0.17     -0.38     -0.65     -0.10   4034.96      1.00
     scale      1.46      0.12      1.45      1.27      1.65   3618.77      1.00

Number of divergences: 0
True loc  : -0.56
True scale: 1.4

Important

The sample method of the TruncatedDistribution class relies on inverse-transform sampling. This has the implicit requirement that the base distribution should have an icdf method already available. If this is not the case, we will not be able to call the sample method on any instances of our distribution, nor use it with the Predictive class. However, the log_prob method only depends on the cdf method (which is more frequently available than the icdf). If the log_prob method is available, then we can use our distribution as prior/likelihood in a model.
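As a small sketch of this point, using the TruncatedSoftLaplace helper defined above (SoftLaplace happens to provide both a cdf and an icdf, so both methods work here):

d = TruncatedSoftLaplace(0.0, 1.0, high=2.3)
print(d.log_prob(jnp.array(0.5)))   # only requires the base distribution's cdf
print(d.sample(random.PRNGKey(1)))  # additionally requires the base distribution's icdf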

The FoldedDistribution class

Similar to truncated distributions, NumPyro has the FoldedDistribution class to help you quickly construct folded distributions. Popular examples of folded distributions are the so-called “half-normal”, “half-Student-t” or “half-Cauchy”. As the name suggests, these distributions keep only the positive half of the distribution. Implicit in the name of these “half” distributions is that they are centered at zero before folding. But, of course, you can fold a distribution even if it’s not centered at zero. For instance, this is how you would define a folded Student-t distribution.

[14]:
def FoldedStudentT(df, loc=0.0, scale=1.0):
    return FoldedDistribution(StudentT(df, loc=loc, scale=scale))
[15]:
def folded_student_model(num_observations, x=None):
    df = numpyro.sample("df", dist.Gamma(6, 2))
    loc = numpyro.sample("loc", dist.Normal())
    scale = numpyro.sample("scale", dist.LogNormal())
    with numpyro.plate("obs", num_observations):
        numpyro.sample("x", FoldedStudentT(df, loc, scale), obs=x)

And we check that we can use our distribution in a typical workflow:

[16]:
# --- prior sampling
num_observations = 500
num_prior_samples = 100
prior = Predictive(folded_student_model, num_samples=num_prior_samples)
prior_samples = prior(PRIOR_RNG, num_observations)


# --- choose any prior sample as the ground truth
true_idx = 0
true_df = prior_samples["df"][true_idx]
true_loc = prior_samples["loc"][true_idx]
true_scale = prior_samples["scale"][true_idx]
true_x = prior_samples["x"][true_idx]

# --- do inference with MCMC
mcmc = MCMC(
    NUTS(folded_student_model),
    **MCMC_KWARGS,
)
mcmc.run(MCMC_RNG, num_observations, true_x)

# --- Check diagostics
mcmc.print_summary()

# --- Compare to ground truth:
print(f"True df   : {true_df:3.2f}")
print(f"True loc  : {true_loc:3.2f}")
print(f"True scale: {true_scale:3.2f}")
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1343.54it/s, 7 steps of size 3.51e-01. acc. prob=0.75]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 3644.99it/s, 7 steps of size 3.56e-01. acc. prob=0.73]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 3137.13it/s, 7 steps of size 2.62e-01. acc. prob=0.91]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 3028.93it/s, 7 steps of size 1.85e-01. acc. prob=0.96]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        df      3.12      0.52      3.07      2.30      3.97   2057.60      1.00
       loc     -0.02      0.88     -0.03     -1.28      1.34    925.84      1.01
     scale      2.23      0.21      2.25      1.89      2.57   1677.38      1.00

Number of divergences: 33
True df   : 3.01
True loc  : 0.37
True scale: 2.41

5. Building your own truncated distribution

If the TruncatedDistribution and FoldedDistribution classes are not sufficient to solve your problem, you might want to look into writing your own truncated distribution from the ground up. This can be a tedious process, so this section will give you some guidance and examples to help you with it.

5.1 Recap of NumPyro distributions

A NumPyro distribution should subclass Distribution and implement a few basic ingredients:

Class attributes

The class attributes serve a few different purposes. Here we will mainly care about two:

  1. arg_constraints: Impose some requirements on the parameters of the distribution. Errors are raised at instantiation time if the parameters passed do not satisfy the constraints.
  2. support: It is used in some inference algorithms like MCMC and SVI with auto-guides, where we need to perform the algorithm in the unconstrained space. Knowing the support, we can automatically reparametrize things under the hood.

We’ll explain other class attributes as we go.
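For example, arg_constraints is what rejects invalid parameters at instantiation time when validation is enabled (a small sketch; the invalid scale is deliberate):

try:
    dist.Normal(loc=0.0, scale=-1.0, validate_args=True)  # scale violates constraints.positive
except ValueError as err:
    print(err)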

The __init__ method

This is where we define the parameters of the distribution. We also use jax and lax to promote the parameters to shapes that are valid for broadcasting. Calling the __init__ method of the parent class is also required because that’s where the validation of our parameters is done.

The log_prob method

Implementing the log_prob method ensures that we can do inference. As the name suggests, this method returns the logarithm of the density evaluated at the argument.

The sample method

This method is used for drawing independent samples from our distribution. It is particularly useful for doing prior and posterior predictive checks. Note, in particular, that this method is not needed if you only use your distribution as a prior in a model; the log_prob method will suffice.

The placeholder code for any of our implementations can be written as:

class MyDistribution(Distribution):
    # class attributes
    arg_constraints = {}
    support = None
    def __init__(self):
        pass

    def log_prob(self, value):
        pass

    def sample(self, key, sample_shape=()):
        pass

5.2 Example: Right-truncated normal

We are going to modify a normal distribution so that its new support is of the form (-inf, high), with high a real number. This could be done with the TruncatedNormal distribution but, for the sake of illustration, we are not going to rely on it. We’ll call our distribution RightTruncatedNormal. Let’s write the skeleton code and then proceed to fill in the blanks.

class RightTruncatedNormal(Distribution):
    # <class attributes>
    def __init__(self):
        pass

    def log_prob(self, value):
        pass

    def sample(self, key, sample_shape=()):
        pass

Class attributes

Remember that a non-truncated normal distribution is specified in NumPyro by two parameters, loc and scale, which correspond to the mean and standard deviation. Looking at the source code for the Normal distribution we see the following lines:

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
reparametrized_params = ["loc", "scale"]

The reparametrized_params attribute is used by variational inference algorithms when constructing gradient estimators. The parameters of many common distributions with continuous support (e.g. the Normal distribution) are reparameterizable, while the parameters of discrete distributions are not. Note that reparametrized_params is irrelevant for MCMC algorithms like HMC. See SVI Part III for more details.

We must adapt these attributes to our case by including the "high" parameter, but there are two issues we need to deal with:

  1. constraints.real is a bit too restrictive. We’d like jnp.inf to be a valid value for high (equivalent to no truncation), but at the moment infinity is not a valid real number. We deal with this situation by defining our own constraint. The source code for constraints.real is easy to imitate:

class _RightExtendedReal(constraints.Constraint):
    """
    Any number in the interval (-inf, inf].
    """
    def __call__(self, x):
        return (x == x) & (x != float("-inf"))

    def feasible_like(self, prototype):
        return jnp.zeros_like(prototype)

right_extended_real = _RightExtendedReal()
  2. support can no longer be a class attribute as it will depend on the value of high. So instead we implement it as a dependent property.

Our distribution then looks as follows:

class RightTruncatedNormal(Distribution):
    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "high": right_extended_real,
    }
    reparametrized_params = ["loc", "scale", "high"]

    # ...

    @constraints.dependent_property
    def support(self):
        return constraints.less_than(self.high)

The __init__ method

Once again we take inspiration from the source code for the normal distribution. The key point is the use of lax and jax to check the shapes of the arguments passed and make sure that such shapes are consistent for broadcasting. We follow the same pattern for our use case – all we need to do is include the high parameter.

In the source implementation of Normal, both parameters loc and scale are given defaults so that one recovers a standard normal distribution if no arguments are specified. In the same spirit, we choose float("inf") as a default for high, which is equivalent to no truncation.

# ...
    def __init__(self, loc=0.0, scale=1.0, high=float("inf"), validate_args=None):
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc),
            jnp.shape(scale),
            jnp.shape(high),
        )
        self.loc, self.scale, self.high = promote_shapes(loc, scale, high)
        super().__init__(batch_shape, validate_args=validate_args)
# ...
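As a quick illustration of the shape handling, parameters of different shapes broadcast into a single batch shape (a sketch that uses the complete class assembled at the end of this section; the shapes are arbitrary):

d = RightTruncatedNormal(loc=jnp.zeros(3), scale=1.0, high=jnp.ones((2, 1)))
print(d.batch_shape)  # (2, 3)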

The log_prob method

For a truncated distribution, the log density is given by

\[\begin{split}\begin{align} \log p_Z(z) = \begin{cases} \log p_Y(z) - \log M & \text{if $z$ is in $T$}\\ -\infty & \text{if $z$ is outside $T$}\\ \end{cases} \end{align}\end{split}\]

where, again, \(p_Z\) is the density of the truncated distribution, \(p_Y\) is the density before truncation, and \(M = \int_T p_Y(y)\mathrm{d}y\). For the specific case of truncating the normal distribution to the interval (-inf, high), the constant \(M\) is equal to the cumulative density evaluated at the truncation point. We can easily implement this log-density method because jax.scipy.stats already has a norm module that we can use.

# ...
    def log_prob(self, value):
        log_m = norm.logcdf(self.high, self.loc, self.scale)
        log_p = norm.logpdf(value, self.loc, self.scale)
        return jnp.where(value < self.high, log_p - log_m, -jnp.inf)
# ...

The sample method

To implement the sample method using inverse-transform sampling, we need to also implement the inverse cumulative distribution function. For this, we can use the ndtri function that lives inside jax.scipy.special. This function returns the inverse cdf for the standard normal distribution. We can do a bit of algebra to obtain the inverse cdf of the truncated, non-standard normal. First recall that if \(X\sim Normal(0, 1)\) and \(Y = \mu + \sigma X\), then \(Y\sim Normal(\mu, \sigma)\). Then if \(Z\) is the truncated \(Y\), its cumulative density is given by:

\[\begin{align} F_Z(y) &= \int_{-\infty}^{y}p_Z(r)dr\newline &= \frac{1}{M}\int_{-\infty}^{y}p_Y(s)ds \quad\text{if $y < high$} \newline &= \frac{1}{M}F_Y(y) \end{align}\]

And so its inverse is

\[\begin{align} F_Z^{-1}(u) = \left(\frac{1}{M}F_Y\right)^{-1}(u) = F_Y^{-1}(M u) = F_{\mu + \sigma X}^{-1}(Mu) = \mu + \sigma F_X^{-1}(Mu) \end{align}\]

The translation of the above math into code is

# ...
    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        minval = jnp.finfo(jnp.result_type(float)).tiny
        u = random.uniform(key, shape, minval=minval)
        return self.icdf(u)


    def icdf(self, u):
        m = norm.cdf(self.high, self.loc, self.scale)
        return self.loc + self.scale * ndtri(m * u)

With everything in place, the final implementation is as below.

[17]:
class _RightExtendedReal(constraints.Constraint):
    """
    Any number in the interval (-inf, inf].
    """

    def __call__(self, x):
        return (x == x) & (x != float("-inf"))

    def feasible_like(self, prototype):
        return jnp.zeros_like(prototype)


right_extended_real = _RightExtendedReal()


class RightTruncatedNormal(Distribution):
    """
    A truncated Normal distribution.
    :param numpy.ndarray loc: location parameter of the untruncated normal
    :param numpy.ndarray scale: scale parameter of the untruncated normal
    :param numpy.ndarray high: point at which the truncation happens
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "high": right_extended_real,
    }
    reparametrized_params = ["loc", "scale", "high"]

    def __init__(self, loc=0.0, scale=1.0, high=float("inf"), validate_args=True):
        batch_shape = lax.broadcast_shapes(
            jnp.shape(loc),
            jnp.shape(scale),
            jnp.shape(high),
        )
        self.loc, self.scale, self.high = promote_shapes(loc, scale, high)
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        log_m = norm.logcdf(self.high, self.loc, self.scale)
        log_p = norm.logpdf(value, self.loc, self.scale)
        return jnp.where(value < self.high, log_p - log_m, -jnp.inf)

    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        minval = jnp.finfo(jnp.result_type(float)).tiny
        u = random.uniform(key, shape, minval=minval)
        return self.icdf(u)

    def icdf(self, u):
        m = norm.cdf(self.high, self.loc, self.scale)
        return self.loc + self.scale * ndtri(m * u)

    @constraints.dependent_property
    def support(self):
        return constraints.less_than(self.high)

Let’s try it out!
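Before plugging the class into a model, a quick comparison against the built-in TruncatedNormal can catch mistakes in log_prob and sample. This is only a sketch; the parameter values, test points, and PRNG key are arbitrary:

# Sketch: sanity-check our class against NumPyro's built-in TruncatedNormal.
d_ours = RightTruncatedNormal(0.5, 2.0, high=1.0)
d_ref = TruncatedNormal(0.5, 2.0, high=1.0)
pts = jnp.array([-2.0, 0.0, 0.9])
print(jnp.allclose(d_ours.log_prob(pts), d_ref.log_prob(pts)))  # should be True
samples = d_ours.sample(random.PRNGKey(1), (1000,))
print((samples < 1.0).all())                                    # should be True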

[18]:
def truncated_normal_model(num_observations, x=None):
    loc = numpyro.sample("loc", dist.Normal())
    scale = numpyro.sample("scale", dist.LogNormal())
    high = numpyro.sample("high", dist.Normal())
    with numpyro.plate("observations", num_observations):
        numpyro.sample("x", RightTruncatedNormal(loc, scale, high), obs=x)
[19]:
num_observations = 1000
num_prior_samples = 100
prior = Predictive(truncated_normal_model, num_samples=num_prior_samples)
prior_samples = prior(PRIOR_RNG, num_observations)

As before, we run MCMC against some synthetic data. We select any random sample from the prior as the ground truth:

[20]:
true_idx = 0
true_loc = prior_samples["loc"][true_idx]
true_scale = prior_samples["scale"][true_idx]
true_high = prior_samples["high"][true_idx]
true_x = prior_samples["x"][true_idx]
[21]:
plt.hist(true_x.copy())
plt.axvline(true_high, linestyle=":", color="k")
plt.xlabel("x")
plt.show()
../_images/tutorials_truncated_distributions_39_0.png

Run MCMC and check the estimates:

[22]:
mcmc = MCMC(NUTS(truncated_normal_model), **MCMC_KWARGS)
mcmc.run(MCMC_RNG, num_observations, true_x)
mcmc.print_summary()
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1850.91it/s, 15 steps of size 8.88e-02. acc. prob=0.88]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7434.51it/s, 5 steps of size 1.56e-01. acc. prob=0.78]
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7792.94it/s, 54 steps of size 5.41e-02. acc. prob=0.91]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7404.07it/s, 9 steps of size 1.77e-01. acc. prob=0.78]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      high      0.88      0.01      0.88      0.88      0.89    590.13      1.01
       loc     -0.58      0.07     -0.58     -0.70     -0.46    671.04      1.01
     scale      1.40      0.05      1.40      1.32      1.48    678.30      1.01

Number of divergences: 6310

Compare estimates against the ground truth:

[23]:
print(f"True high : {true_high:3.2f}")
print(f"True loc  : {true_loc:3.2f}")
print(f"True scale: {true_scale:3.2f}")
True high : 0.88
True loc  : -0.56
True scale: 1.45

Note that, even though we can recover good estimates for the true values, we had a very high number of divergences. These divergences happen because the data can be outside of the support that we are allowing with our priors. To fix this, we can change the prior on high so that it depends on the observations:

[24]:
def truncated_normal_model_2(num_observations, x=None):
    loc = numpyro.sample("loc", dist.Normal())
    scale = numpyro.sample("scale", dist.LogNormal())
    if x is None:
        high = numpyro.sample("high", dist.Normal())
    else:
        # high is greater or equal to the max value in x:
        delta = numpyro.sample("delta", dist.HalfNormal())
        high = numpyro.deterministic("high", delta + x.max())

    with numpyro.plate("observations", num_observations):
        numpyro.sample("x", RightTruncatedNormal(loc, scale, high), obs=x)
[25]:
mcmc = MCMC(NUTS(truncated_normal_model_2), **MCMC_KWARGS)
mcmc.run(MCMC_RNG, num_observations, true_x)
mcmc.print_summary(exclude_deterministic=False)
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:03<00:00, 1089.76it/s, 15 steps of size 4.85e-01. acc. prob=0.93]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 8802.95it/s, 7 steps of size 5.19e-01. acc. prob=0.92]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 8975.35it/s, 3 steps of size 5.72e-01. acc. prob=0.89]
sample: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 8471.94it/s, 15 steps of size 3.76e-01. acc. prob=0.96]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     delta      0.01      0.01      0.00      0.00      0.01   6104.22      1.00
      high      0.88      0.01      0.88      0.88      0.89   6104.22      1.00
       loc     -0.58      0.08     -0.58     -0.71     -0.46   3319.65      1.00
     scale      1.40      0.06      1.40      1.31      1.49   3377.38      1.00

Number of divergences: 0

And the divergences are gone.

In practice, we usually want to understand how the data would look without the truncation. To do that in NumPyro, there is no need to write a separate model; we can simply rely on the condition handler to push the truncation point to infinity:

[26]:
model_without_truncation = numpyro.handlers.condition(
    truncated_normal_model,
    {"high": float("inf")},
)
estimates = mcmc.get_samples().copy()
estimates.pop("high")  # Drop to make sure these are not used
pred = Predictive(
    model_without_truncation,
    posterior_samples=estimates,
)
pred_samples = pred(PRED_RNG, num_observations=1000)
[27]:
# thin the samples for a faster histogram
samples_thinned = pred_samples["x"].ravel()[::1000]
[28]:
f, axes = plt.subplots(1, 2, figsize=(15, 5))

axes[0].hist(
    samples_thinned.copy(), label="Untruncated posterior", bins=20, density=True
)
axes[0].axvline(true_high, linestyle=":", color="k", label="Truncation point")
axes[0].set_title("Untruncated posterior")
axes[0].legend()

axes[1].hist(
    samples_thinned[samples_thinned < true_high].copy(),
    label="Tail of untruncated posterior",
    bins=20,
    density=True,
)
axes[1].hist(true_x.copy(), label="Observed, truncated data", density=True, alpha=0.5)
axes[1].axvline(true_high, linestyle=":", color="k", label="Truncation point")
axes[1].set_title("Comparison to observed data")
axes[1].legend()
plt.show()
../_images/tutorials_truncated_distributions_51_0.png

5.3 Example: Left-truncated Poisson

As a final example, we now implement a left-truncated Poisson distribution. Note that a right-truncated Poisson could be reformulated as a particular case of a categorical distribution, so we focus on the less trivial case.

Class attributes

For a truncated Poisson we need two parameters, the rate of the original Poisson distribution and a low parameter to indicate the truncation point. As this is a discrete distribution, we need to clarify whether or not the truncation point is included in the support. In this tutorial, we’ll take the convention that the truncation point low is part of the support.

The low parameter has to be given a ‘non-negative integer’ constraint. As it is a discrete parameter, it will not be possible to do inference for it using NUTS. This is likely not a problem since the truncation point is often known in advance. However, if we really must infer the low parameter, it is possible to do so with DiscreteHMCGibbs, though we are limited to priors with enumerate support.

Like in the case of the truncated normal, the support of this distribution will be defined as a property and not as a class attribute because it depends on the specific value of the low parameter.

class LeftTruncatedPoisson(Distribution):
    arg_constraints = {
        "low": constraints.nonnegative_integer,
        "rate": constraints.positive,
    }

    # ...
    @constraints.dependent_property(is_discrete=True)
    def support(self):
        return constraints.integer_greater_than(self.low - 1)

The is_discrete argument passed in the dependent_property decorator is used to tell the inference algorithms which variables are discrete latent variables.

The __init__ method

Here we just follow the same pattern as in the previous example.

# ...
def __init__(self, rate=1.0, low=0, validate_args=None):
    batch_shape = lax.broadcast_shapes(
        jnp.shape(low), jnp.shape(rate)
    )
    self.low, self.rate = promote_shapes(low, rate)
    super().__init__(batch_shape, validate_args=validate_args)
# ...

The log_prob method

The logic is very similar to the truncated normal case. But this time we are truncating on the left, so the correct normalization is the complementary cumulative density:

\[\begin{align} M = \sum_{n=L}^{\infty} p_Y(n) = 1 - \sum_{n=0}^{L - 1} p_Y(n) = 1 - F_Y(L - 1) \end{align}\]

For the code, we can rely on the poisson module that lives inside jax.scipy.stats.

# ...
def log_prob(self, value):
    m = 1 - poisson.cdf(self.low - 1, self.rate)
    log_p = poisson.logpmf(value, self.rate)
    return jnp.where(value >= self.low, log_p - jnp.log(m), -jnp.inf)
# ...
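Before moving on, here is a quick numeric check of the normalization identity above (a sketch with an arbitrary rate and truncation point):

# Sketch: the complementary CDF equals the brute-force sum of the tail.
rate, low = 5.0, 3
m = 1 - poisson.cdf(low - 1, rate)                       # complementary CDF
m_direct = poisson.pmf(jnp.arange(low, 60), rate).sum()  # sum of the (essentially entire) tail
print(m, m_direct)                                       # the two agree up to floating point error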

The sample method

Inverse-transform sampling also works for discrete distributions, with the “inverse” cdf of a discrete distribution defined as:

\[\begin{align} F^{-1}(u) = \min\left\{n\in \mathbb{N} \rvert F(n) \ge u\right\} \end{align}\]

Or, in plain English, \(F^{-1}(u)\) is the smallest number for which the cumulative density reaches \(u\). However, there’s currently no implementation of \(F^{-1}\) for the Poisson distribution in Jax (at least, at the moment of writing this tutorial), so we have to rely on our own implementation. Fortunately, we can take advantage of the discrete nature of the distribution and easily implement a “brute-force” version that will work for most cases. The brute-force approach consists of scanning the integers in the support in order, one by one, until the cumulative density reaches the argument \(u\). The implicit requirement is that we need a way to evaluate the cumulative density of the truncated distribution, but we can calculate that:

\[\begin{align} F_Z(z) &= \sum_{n=0}^z p_Z(n)\newline &= \frac{1}{M}\sum_{n=L}^z p_Y(n)\quad \text{assuming $z \ge L$}\newline &= \frac{1}{M}\left(\sum_{n=0}^z p_Y(n) - \sum_{n=0}^{L-1}p_Y(n)\right)\newline &= \frac{1}{M}\left(F_Y(z) - F_Y(L-1)\right) \end{align}\]

And, of course, the value of \(F_Z(z)\) is equal to zero if \(z < L\). (As in the previous example, \(Y\) denotes the original, untruncated variable and \(Z\) denotes the truncated variable.)

# ...
def sample(self, key, sample_shape=()):
    shape = sample_shape + self.batch_shape
    minval = jnp.finfo(jnp.result_type(float)).tiny
    u = random.uniform(key, shape, minval=minval)
    return self.icdf(u)

def icdf(self, u):
    def cond_fn(val):
        n, cdf = val
        return jnp.any(cdf < u)

    def body_fn(val):
        n, cdf = val
        n_new = jnp.where(cdf < u, n + 1, n)
        return n_new, self.cdf(n_new)

    low = self.low * jnp.ones_like(u)
    cdf = self.cdf(low)
    n, _ = lax.while_loop(cond_fn, body_fn, (low, cdf))
    return n.astype(jnp.result_type(int))

def cdf(self, value):
    m = 1 - poisson.cdf(self.low - 1, self.rate)
    f = poisson.cdf(value, self.rate) - poisson.cdf(self.low - 1, self.rate)
    return jnp.where(value >= self.low, f / m, 0)

A few comments with respect to the above implementation:

  • Even with double precision, if rate is much less than low, the above code will not work. Due to numerical limitations, one obtains that poisson.cdf(low - 1, rate) is equal (or very close) to 1.0. This makes it impossible to re-weight the distribution accurately because the normalization constant would be 0.0.

  • The brute-force icdf is of course very slow, particularly when rate is high. If you need faster sampling, one option would be to rely on a faster search algorithm. For example:

def icdf_faster(self, u):
    num_bins = 200 # Choose a reasonably large value
    bins = jnp.arange(num_bins)
    cdf = self.cdf(bins)
    indices = jnp.searchsorted(cdf, u)
    return bins[indices]

The obvious limitation here is that the number of bins has to be fixed a priori (Jax does not allow for dynamically sized arrays). Another option would be to rely on an approximate implementation, as proposed in this article.

  • Yet another alternative for the icdf is to rely on scipy’s implementation and make use of Jax’s host_callback module. This feature allows you to use Python functions without having to code them in Jax. This means that we can simply make use of scipy’s implementation of the Poisson ICDF! From the last equation, we can write the truncated icdf as:

\[\begin{align} F_Z^{-1}(u) = F_Y^{-1}(Mu + F_Y(L-1)) \end{align}\]

And in python:

def scipy_truncated_poisson_icdf(args): # Note: all arguments are passed inside a tuple
    rate, low, u = args
    rate = np.asarray(rate)
    low = np.asarray(low)
    u = np.asarray(u)
    density = sp_poisson(rate)
    low_cdf = density.cdf(low - 1)
    normalizer = 1.0 - low_cdf
    x = normalizer * u + low_cdf
    return density.ppf(x)

In principle, it wouldn’t be possible to use the above function in our NumPyro distribution because it is not coded in Jax. The jax.experimental.host_callback.call function solves precisely that problem. The code below shows you how to use it, but keep in mind that this is currently an experimental feature so you should expect changes to the module. See the host_callback docs for more details.

# ...
def icdf_scipy(self, u):
    result_shape = jax.ShapeDtypeStruct(
        u.shape,
        jnp.result_type(float) # int type not currently supported
    )
    result = jax.experimental.host_callback.call(
        scipy_truncated_poisson_icdf,
        (self.rate, self.low, u),
        result_shape=result_shape
    )
    return result.astype(jnp.result_type(int))
# ...

Putting it all together, the implementation is as below:

[29]:
def scipy_truncated_poisson_icdf(args):  # Note: all arguments are passed inside a tuple
    rate, low, u = args
    rate = np.asarray(rate)
    low = np.asarray(low)
    u = np.asarray(u)
    density = sp_poisson(rate)
    low_cdf = density.cdf(low - 1)
    normalizer = 1.0 - low_cdf
    x = normalizer * u + low_cdf
    return density.ppf(x)


class LeftTruncatedPoisson(Distribution):
    """
    A truncated Poisson distribution.
    :param numpy.ndarray low: lower bound at which truncation happens
    :param numpy.ndarray rate: rate of the Poisson distribution.
    """

    arg_constraints = {
        "low": constraints.nonnegative_integer,
        "rate": constraints.positive,
    }

    def __init__(self, rate=1.0, low=0, validate_args=None):
        batch_shape = lax.broadcast_shapes(jnp.shape(low), jnp.shape(rate))
        self.low, self.rate = promote_shapes(low, rate)
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        m = 1 - poisson.cdf(self.low - 1, self.rate)
        log_p = poisson.logpmf(value, self.rate)
        return jnp.where(value >= self.low, log_p - jnp.log(m), -jnp.inf)

    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        float_type = jnp.result_type(float)
        minval = jnp.finfo(float_type).tiny
        u = random.uniform(key, shape, minval=minval)
        # return self.icdf(u)        # Brute force
        # return self.icdf_faster(u) # For faster sampling.
        return self.icdf_scipy(u)  # Using `host_callback`

    def icdf(self, u):
        def cond_fn(val):
            n, cdf = val
            return jnp.any(cdf < u)

        def body_fn(val):
            n, cdf = val
            n_new = jnp.where(cdf < u, n + 1, n)
            return n_new, self.cdf(n_new)

        low = self.low * jnp.ones_like(u)
        cdf = self.cdf(low)
        n, _ = lax.while_loop(cond_fn, body_fn, (low, cdf))
        return n.astype(jnp.result_type(int))

    def icdf_faster(self, u):
        num_bins = 200  # Choose a reasonably large value
        bins = jnp.arange(num_bins)
        cdf = self.cdf(bins)
        indices = jnp.searchsorted(cdf, u)
        return bins[indices]

    def icdf_scipy(self, u):
        result_shape = jax.ShapeDtypeStruct(u.shape, jnp.result_type(float))
        result = jax.experimental.host_callback.call(
            scipy_truncated_poisson_icdf,
            (self.rate, self.low, u),
            result_shape=result_shape,
        )
        return result.astype(jnp.result_type(int))

    def cdf(self, value):
        m = 1 - poisson.cdf(self.low - 1, self.rate)
        f = poisson.cdf(value, self.rate) - poisson.cdf(self.low - 1, self.rate)
        return jnp.where(value >= self.low, f / m, 0)

    @constraints.dependent_property(is_discrete=True)
    def support(self):
        return constraints.integer_greater_than(self.low - 1)

Let’s try it out!
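A couple of quick checks before using the new distribution in a model (a sketch; the rate, truncation point, and PRNG key are arbitrary):

# Sketch: samples respect the truncation point and the pmf is properly renormalized.
d = LeftTruncatedPoisson(rate=5.0, low=3)
samples = d.sample(random.PRNGKey(1), (1000,))
print((samples >= 3).all())                          # no sample falls below the truncation point
print(jnp.exp(d.log_prob(jnp.arange(3, 60))).sum())  # the renormalized pmf sums to ~1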

[30]:
def discrete_distplot(samples, ax=None, **kwargs):
    """
    Utility function for plotting the samples as a barplot.
    """
    x, y = np.unique(samples, return_counts=True)
    y = y / sum(y)
    if ax is None:
        ax = plt.gca()

    ax.bar(x, y, **kwargs)
    return ax
[31]:
def truncated_poisson_model(num_observations, x=None):
    low = numpyro.sample("low", dist.Categorical(0.2 * jnp.ones((5,))))
    rate = numpyro.sample("rate", dist.LogNormal(1, 1))
    with numpyro.plate("observations", num_observations):
        numpyro.sample("x", LeftTruncatedPoisson(rate, low), obs=x)

Prior samples

[32]:
# -- prior samples
num_observations = 1000
num_prior_samples = 100
prior = Predictive(truncated_poisson_model, num_samples=num_prior_samples)
prior_samples = prior(PRIOR_RNG, num_observations)

Inference

As in the case of the truncated normal, here it is better to replace the prior on the low parameter so that it is consistent with the observed data. We’d like to have a categorical prior on low (so that we can use DiscreteHMCGibbs) whose highest category is equal to the minimum value of x (so that prior and data are consistent). However, we have to be careful in the way we write such a model because Jax does not allow for dynamically sized arrays. A simple way of coding this model is to specify the number of categories as an argument:

[33]:
def truncated_poisson_model(num_observations, x=None, k=5):
    zeros = jnp.zeros((k,))
    low = numpyro.sample("low", dist.Categorical(logits=zeros))
    rate = numpyro.sample("rate", dist.LogNormal(1, 1))
    with numpyro.plate("observations", num_observations):
        numpyro.sample("x", LeftTruncatedPoisson(rate, low), obs=x)
[34]:
# Take any prior sample as the true process.
true_idx = 6
true_low = prior_samples["low"][true_idx]
true_rate = prior_samples["rate"][true_idx]
true_x = prior_samples["x"][true_idx]
discrete_distplot(true_x.copy());
../_images/tutorials_truncated_distributions_62_0.png

To do inference, we set k = x.min() + 1. Note also the use of DiscreteHMCGibbs:

[35]:
mcmc = MCMC(DiscreteHMCGibbs(NUTS(truncated_poisson_model)), **MCMC_KWARGS)
mcmc.run(MCMC_RNG, num_observations, true_x, k=true_x.min() + 1)
mcmc.print_summary()
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:04<00:00, 808.70it/s, 3 steps of size 9.58e-01. acc. prob=0.93]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5916.30it/s, 3 steps of size 9.14e-01. acc. prob=0.93]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5082.16it/s, 3 steps of size 9.91e-01. acc. prob=0.92]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 6511.68it/s, 3 steps of size 8.66e-01. acc. prob=0.94]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
       low      4.13      2.43      4.00      0.00      7.00   7433.79      1.00
      rate     18.16      0.14     18.16     17.96     18.40   3074.46      1.00

[36]:
true_rate
[36]:
DeviceArray(18.2091848, dtype=float64)

As before, one needs to be extra careful when estimating the truncation point. If the truncation point is known, it is best to provide it.

[37]:
model_with_known_low = numpyro.handlers.condition(
    truncated_poisson_model, {"low": true_low}
)

Note that we can use NUTS directly because there is no need to infer any discrete parameters.

[38]:
mcmc = MCMC(
    NUTS(model_with_known_low),
    **MCMC_KWARGS,
)
[39]:
mcmc.run(MCMC_RNG, num_observations, true_x)
mcmc.print_summary()
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:03<00:00, 1185.13it/s, 1 steps of size 9.18e-01. acc. prob=0.93]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5786.32it/s, 3 steps of size 1.00e+00. acc. prob=0.92]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5919.13it/s, 1 steps of size 8.62e-01. acc. prob=0.94]
sample: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 7562.36it/s, 3 steps of size 9.01e-01. acc. prob=0.93]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      rate     18.17      0.13     18.17     17.95     18.39   3406.81      1.00

Number of divergences: 0

Removing the truncation

[40]:
model_without_truncation = numpyro.handlers.condition(
    truncated_poisson_model,
    {"low": 0},
)
pred = Predictive(model_without_truncation, posterior_samples=mcmc.get_samples())
pred_samples = pred(PRED_RNG, num_observations)
thinned_samples = pred_samples["x"][::500]
[41]:
discrete_distplot(thinned_samples.copy());
../_images/tutorials_truncated_distributions_73_0.png