Pyro Primitives


param(name, init_value=None, **kwargs)

Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().

Parameters:
  • name (str) – name of site.
  • init_value (numpy.ndarray) – initial value specified by the user. Note that the onus of using this to initialize the optimizer is on the user / inference algorithm, since there is no global parameter store in NumPyro.

Returns: value for the parameter. Unless wrapped inside a handler like substitute, this simply returns the initial value.


sample(name, fn, obs=None, rng_key=None, sample_shape=(), infer=None)

Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.


By design, the sample primitive is meant to be used inside a NumPyro model, where the seed handler injects a random state into fn. In that situation, the rng_key keyword has no effect.

Parameters:
  • name (str) – name of the sample site.
  • fn – a stochastic function that returns a sample.
  • obs (numpy.ndarray) – observed value
  • rng_key (jax.random.PRNGKey) – an optional random key for fn.
  • sample_shape – Shape of samples to be drawn.
  • infer (dict) – an optional dictionary containing additional information for inference algorithms. For example, if fn is a discrete distribution, set infer={'enumerate': 'parallel'} to tell MCMC to marginalize out this discrete latent site.

Returns: sample from the stochastic fn.


class plate(name, size, subsample_size=None, dim=None)

Construct for annotating conditionally independent variables. Within a plate context manager, sample sites are automatically broadcast to the size of the plate. Additionally, certain inference algorithms may apply a scale factor if subsample_size is specified.

Parameters:
  • name (str) – Name of the plate.
  • size (int) – Size of the plate.
  • subsample_size (int) – Optional argument denoting the size of the mini-batch. Inference algorithms can use it to apply a scaling factor, e.g. when computing the ELBO on a mini-batch.
  • dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the leftmost available dim is allocated.


plate_stack(prefix, sizes, rightmost_dim=-1)

Create a contiguous stack of plates with dimensions:

rightmost_dim - len(sizes) + 1, ..., rightmost_dim

Parameters:
  • prefix (str) – Name prefix for plates.
  • sizes (iterable) – An iterable of plate sizes.
  • rightmost_dim (int) – The rightmost dim, counting from the right.


deterministic(name, value)

Used to designate deterministic sites in the model. Note that most effect handlers will not operate on deterministic sites (except trace()), so deterministic sites should be side-effect free. The use case for deterministic nodes is to record any values in the model execution trace.

Parameters:
  • name (str) – name of the deterministic site.
  • value (numpy.ndarray) – deterministic value to record in the trace.


factor(name, log_factor)

Factor statement to add an arbitrary log probability factor to a probabilistic model.

Parameters:
  • name (str) – Name of the trivial sample.
  • log_factor (numpy.ndarray) – A possibly batched log probability factor.


module(name, nn, input_shape=None)

Declare a stax-style neural network inside a model so that its parameters are registered for optimization via param() statements.

Parameters:
  • name (str) – name of the module to be registered.
  • nn (tuple) – a tuple of (init_fn, apply_fn) obtained by a stax constructor function.
  • input_shape (tuple) – shape of the input taken by the neural network.

Returns: an apply_fn with bound parameters that takes an array as input and returns the output array transformed by the neural network.


scan(f, init, xs, length=None, reverse=False)

This primitive scans a function over the leading array axes of xs while carrying along state. See jax.lax.scan() for more information.


>>> import numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.control_flow import scan
>>> def gaussian_hmm(y=None, T=10):
...     def transition(x_prev, y_curr):
...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
...         return x_curr, (x_curr, y_curr)
...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
...     _, (x, y) = scan(transition, x0, y, length=T)
...     return (x, y)
>>> # here we do some quick tests
>>> with numpyro.handlers.seed(rng_seed=0):
...     x, y = gaussian_hmm(np.arange(10.))
>>> assert x.shape == (10,) and y.shape == (10,)
>>> assert np.all(y == np.arange(10))
>>> with numpyro.handlers.seed(rng_seed=0):  # generative
...     x, y = gaussian_hmm()
>>> assert x.shape == (10,) and y.shape == (10,)


This is an experimental utility function that allows users to use JAX control flow with NumPyro’s effect handlers. Currently, sample and deterministic sites within the scan body f are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.


Because it is ambiguous how the scan dimension should be aligned inside a plate context, the following pattern is not supported:

with numpyro.plate('N', 10):
    last, ys = scan(f, init, xs)

Instead, all plate statements should be placed inside f. For example, the corresponding working code is:

def g(*args, **kwargs):
    with numpyro.plate('N', 10):
        return f(*args, **kwargs)

last, ys = scan(g, init, xs)


Nested scan is currently not supported.


We can scan over discrete latent variables in f. The joint density is evaluated using a parallel scan (reference [1]) over the time dimension, which reduces the parallel complexity to O(log(length)).

Currently, only the equivalent of markov(history_size=1) is supported. A trace of scan with discrete latent variables will contain the following sites:

  • init sites: sites that belong to the first trace of f. Each of
    them has its name prefixed with _init/.
  • scanned sites: sites that collect the values of the remaining scan
    loop over f. An additional time dimension _time_foo is added to these sites, where foo is the name of the first site that appears in f.

Not all transition functions f are supported. All of the restrictions from Pyro's enumeration tutorial [2] still apply here. In addition, no site outside of scan should depend on the first output of scan (the last carry value).
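The O(log(length)) figure rests on the fact that marginalizing a chain of discrete transitions is an associative product of log-space matrices, which a parallel scan can combine in logarithmic depth. The following is a simplified JAX illustration of that idea for a plain hidden Markov model, not NumPyro's implementation; all function names, sizes and inputs are hypothetical. It checks the parallel result against the sequential forward algorithm.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_matmul(a, b):
    # Matrix product in log space, batched over leading dims: (..., i, j) @ (..., j, k).
    return logsumexp(a[..., :, :, None] + b[..., None, :, :], axis=-2)

def hmm_log_marginal_sequential(log_init, log_trans, log_emit):
    # Forward algorithm: T - 1 dependent steps.
    alpha = log_init + log_emit[0]
    for t in range(1, log_emit.shape[0]):
        alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + log_emit[t]
    return logsumexp(alpha)

def hmm_log_marginal_parallel(log_init, log_trans, log_emit):
    # M_t[i, j] = log p(z_t = j | z_{t-1} = i) + log p(y_t | z_t = j), t = 1..T-1.
    mats = log_trans[None, :, :] + log_emit[1:, None, :]
    # log_matmul is associative, so associative_scan combines the T - 1 factors
    # in O(log T) parallel depth instead of T - 1 sequential steps.
    prefix = jax.lax.associative_scan(log_matmul, mats)
    alpha0 = log_init + log_emit[0]
    return logsumexp(alpha0[:, None] + prefix[-1])

# Hypothetical 3-state HMM with 6 time steps and random emission log-likelihoods.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
log_init = jax.nn.log_softmax(jax.random.normal(k1, (3,)))
log_trans = jax.nn.log_softmax(jax.random.normal(k2, (3, 3)), axis=-1)
log_emit = jax.random.normal(k3, (6, 3))
seq = hmm_log_marginal_sequential(log_init, log_trans, log_emit)
par = hmm_log_marginal_parallel(log_init, log_trans, log_emit)
```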

References

  1. Temporal Parallelization of Bayesian Smoothers, Simo Sarkka, Angel F. Garcia-Fernandez
  2. Inference with Discrete Latent Variables

Parameters:
  • f (callable) – a function to be scanned.
  • init – the initial carrying state
  • xs – the values over which we scan along the leading axis. This can be any JAX pytree (e.g. list/dict of arrays).
  • length – optional value specifying the length of xs; it is required when xs is an empty pytree (e.g. None).
  • reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse

Returns: output of scan, quoted from the jax.lax.scan() docs: “pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs”.