Pyro Primitives

param(name, init_value)

Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().

Parameters:
  • name (str) – name of site.
  • init_value (numpy.ndarray) – initial value specified by the user. Note that it is up to the user or the inference algorithm to pass this value to the optimizer, since there is no global parameter store in NumPyro.
Returns:

value for the parameter. Unless wrapped inside a handler like substitute, this will simply return the initial value.

sample(name, fn, obs=None)

Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.

Parameters:
  • name (str) – name of the sample site
  • fn – a stochastic function (Python callable), such as a distribution from numpyro.distributions
  • obs (numpy.ndarray) – observed value
Returns:

sample from the stochastic fn.

Effect Handlers

This module provides a small set of effect handlers in NumPyro that are modeled after Pyro’s poutine module. For a tutorial on effect handlers more generally, see Poutine: A Guide to Programming with Effect Handlers in Pyro. These simple effect handlers can be composed with one another, and new ones can be added, to implement custom inference utilities and algorithms.

Example

As an example, we use the seed, trace and substitute handlers to define the log_likelihood function below. We first create a logistic regression model and sample from the posterior distribution over the regression parameters using mcmc(). The log_likelihood function uses effect handlers to run the model with its sample sites substituted by the values from a single posterior draw, and computes the log density of the observed labels under that draw. The expected_log_likelihood function computes this log likelihood for every draw from the joint posterior and aggregates the results, using JAX’s auto-vectorization transform vmap so that we do not need to loop over the posterior draws.
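
In formula form, the aggregate that expected_log_likelihood returns through logsumexp(log_lk_vals) - np.log(n) is the log of the likelihood averaged over posterior draws. The symbols are introduced here only for illustration: the posterior samples are assumed to hold n draws θ_1, ..., θ_n, the data set holds N pairs (x_j, y_j), and ℓ_i is the value log_likelihood returns for draw i.

```latex
\ell_i = \sum_{j=1}^{N} \log p(y_j \mid x_j, \theta_i),
\qquad
\log\left(\frac{1}{n}\sum_{i=1}^{n} \exp(\ell_i)\right)
  = \operatorname{logsumexp}(\ell_1, \dots, \ell_n) - \log n
```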

>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
...     coefs = sample('coefs', dist.Normal(np.zeros(D), np.ones(D)))
...     intercept = sample('intercept', dist.Normal(0., 10.))
...     logits = np.sum(coefs * data + intercept, axis=-1)
...     return sample('obs', dist.Bernoulli(logits=logits), obs=labels)

>>> data = random.normal(random.PRNGKey(0), (N, D))
>>> true_coefs = np.arange(1., D + 1.)
>>> logits = np.sum(true_coefs * data, axis=-1)
>>> labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

>>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), logistic_regression, data, labels)
>>> num_warmup, num_samples = 1000, 1000
>>> samples = mcmc(num_warmup, num_samples, init_params,
...                potential_fn=potential_fn,
...                constrain_fn=constrain_fn)  
warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]


                   mean         sd       5.5%      94.5%      n_eff       Rhat
    coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
    coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
    coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
   intercept      -0.03       0.02      -0.06       0.00     402.53       1.00

>>> def log_likelihood(rng, params, model, *args, **kwargs):
...     model = substitute(seed(model, rng), params)
...     model_trace = trace(model).get_trace(*args, **kwargs)
...     obs_node = model_trace['obs']
...     return np.sum(obs_node['fn'].log_prob(obs_node['value']))

>>> def expected_log_likelihood(rng, params, model, *args, **kwargs):
...     n = list(params.values())[0].shape[0]
...     log_lk_fn = vmap(lambda rng, params: log_likelihood(rng, params, model, *args, **kwargs))
...     log_lk_vals = log_lk_fn(random.split(rng, n), params)
...     return logsumexp(log_lk_vals) - np.log(n)

>>> print(expected_log_likelihood(random.PRNGKey(2), samples, logistic_regression, data, labels))  
-876.172

class block(fn=None, hide_fn=<function block.<lambda>>)

Bases: numpyro.handlers.Messenger

Given a callable fn, return another callable that selectively hides primitive sites where hide_fn returns True from other effect handlers on the stack.

Parameters:
  • fn – Python callable with NumPyro primitives.
  • hide_fn – function which when given a dictionary containing site-level metadata returns whether it should be blocked.

Example:

>>> def model():
...     a = sample('a', dist.Normal(0., 1.))
...     return sample('b', dist.Normal(a, 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> block_all = block(model)
>>> block_a = block(model, lambda site: site['name'] == 'a')
>>> trace_block_all = trace(block_all).get_trace()
>>> assert not {'a', 'b'}.intersection(trace_block_all.keys())
>>> trace_block_a = trace(block_a).get_trace()
>>> assert 'a' not in trace_block_a
>>> assert 'b' in trace_block_a

process_message(msg)

class replay(fn, guide_trace)

Bases: numpyro.handlers.Messenger

Given a callable fn and an execution trace guide_trace, return a callable which substitutes sample calls in fn with values from the corresponding site names in guide_trace.

Parameters:
  • fn – Python callable with NumPyro primitives.
  • guide_trace – an OrderedDict containing execution metadata.

Example

>>> def model():
...     sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> print(exec_trace['a']['value'])  
-0.20584235
>>> replayed_trace = trace(replay(model, exec_trace)).get_trace()
>>> print(replayed_trace['a']['value'])
-0.20584235
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']

process_message(msg)

class seed(fn, rng)

Bases: numpyro.handlers.Messenger

JAX uses a functional pseudo random number generator that requires passing in a seed PRNGKey() to every stochastic function. The seed handler allows us to initially seed a stochastic function with a PRNGKey(). Every call to the sample() primitive inside the function results in a splitting of this initial seed so that we use a fresh seed for each subsequent call without having to explicitly pass in a PRNGKey to each sample call.
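
The splitting behaviour described above can be sketched without any dependencies. This is only a toy illustration under stated assumptions: split, normal and ToySeed are hypothetical names, keys are plain integers, and the hash-based split is a stand-in that does not reproduce the algorithm behind JAX's PRNGKey.

```python
import hashlib
import random


def split(key):
    """Toy stand-in for key splitting: derive two child keys from a parent key."""
    children = []
    for i in range(2):
        digest = hashlib.sha256(f"{key}-{i}".encode()).hexdigest()
        children.append(int(digest[:16], 16))
    return tuple(children)


def normal(key):
    """Draw a standard normal that depends only on the explicit key."""
    return random.Random(key).gauss(0.0, 1.0)


class ToySeed:
    """Hypothetical seed handler: holds one key and splits it on every draw."""

    def __init__(self, key):
        self.key = key

    def sample(self, fn):
        # Keep one child as the new state and spend the other on this draw,
        # so the caller never has to pass a key explicitly.
        self.key, subkey = split(self.key)
        return fn(subkey)


first, second = ToySeed(0), ToySeed(0)
draws_first = [first.sample(normal) for _ in range(3)]
draws_second = [second.sample(normal) for _ in range(3)]
```

Two runs seeded with the same initial key produce the same sequence of draws, while successive draws within a run each use a different key.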

process_message(msg)

class substitute(fn=None, param_map=None)

Bases: numpyro.handlers.Messenger

Given a callable fn and a dict param_map keyed by site names, return a callable which substitutes all primitive calls in fn with values from param_map whose key matches the site name. If the site name is not present in param_map, there is no side effect.

Parameters:
  • fn – Python callable with NumPyro primitives.
  • param_map (dict) – dictionary of numpy.ndarray values keyed by site names.

Example:

>>> def model():
...     sample('a', dist.Normal(0., 1.))

>>> model = seed(model, random.PRNGKey(0))
>>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1

process_message(msg)

class trace(fn=None)

Bases: numpyro.handlers.Messenger

Returns a handler that records the inputs and outputs at primitive calls inside fn.

Example

>>> def model():
...     sample('a', dist.Normal(0., 1.))

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> pp.pprint(exec_trace)  
OrderedDict([('a',
              {'args': (),
               'fn': <numpyro.distributions.continuous.Normal object at 0x7f9e689b1eb8>,
               'is_observed': False,
               'kwargs': {'random_state': DeviceArray([0, 0], dtype=uint32)},
               'name': 'a',
               'type': 'sample',
               'value': DeviceArray(-0.20584235, dtype=float32)})])

postprocess_message(msg)

get_trace(*args, **kwargs)

Run the wrapped callable and return the recorded trace.

Parameters:
  • *args – arguments to the callable.
  • **kwargs – keyword arguments to the callable.
Returns:

OrderedDict containing the execution trace.