Pyro Primitives¶
param¶
- param(name, init_value=None, **kwargs)[source]¶
Annotate the given site as an optimizable parameter for use with jax.example_libraries.optimizers. For an example of how param statements can be used in inference algorithms, refer to SVI.
- Parameters:
name (str) – name of site.
init_value (jnp.ndarray or callable) – initial value specified by the user, or a lazy callable that accepts a JAX random PRNGKey and returns an array. Note that the onus of using this to initialize the optimizer is on the user's inference algorithm, since there is no global parameter store in NumPyro.
constraint (numpyro.distributions.constraints.Constraint) – NumPyro constraint, defaults to constraints.real.
event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of this will be considered batch dimensions; if the param statement is inside a subsampled plate, then the corresponding batch dimensions of the parameter will be correspondingly subsampled. If unspecified, all dimensions will be considered event dims and no subsampling will be performed.
- Returns:
value for the parameter. Unless wrapped inside a handler like substitute, this will simply return the initial value.
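For illustration, a minimal sketch of how param typically appears in an SVI guide (the model, guide, and site names here are hypothetical, not part of the API):

    import numpyro
    import numpyro.distributions as dist
    from numpyro.distributions import constraints

    def model(data):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        # Variational parameters registered via numpyro.param; an SVI optimizer updates them.
        loc_q = numpyro.param("loc_q", 0.0)
        scale_q = numpyro.param("scale_q", 1.0, constraint=constraints.positive)
        numpyro.sample("loc", dist.Normal(loc_q, scale_q))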
sample¶
- sample(name, fn, obs=None, rng_key=None, sample_shape=(), infer=None, obs_mask=None)[source]¶
Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.
Note
By design, the sample primitive is meant to be used inside a NumPyro model. A seed handler is then used to inject a random state into fn; in those situations, the rng_key keyword has no effect.
- Parameters:
name (str) – name of the sample site.
fn – a stochastic function that returns a sample.
obs (jnp.ndarray) – observed value
rng_key (jax.random.PRNGKey) – an optional random key for fn.
sample_shape – Shape of samples to be drawn.
infer (dict) – an optional dictionary containing additional information for inference algorithms. For example, if fn is a discrete distribution, setting infer={"enumerate": "parallel"} tells MCMC to marginalize out this discrete latent site.
obs_mask (jnp.ndarray) – Optional boolean array mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and the remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides in SVI. Note that this argument is not intended to be used with MCMC.
- Returns:
sample from the stochastic fn.
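As a minimal sketch (model and site names are illustrative), sample draws latents inside a model and conditions on data via obs; outside an inference algorithm, a seed handler supplies the random state:

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def model(data):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        with numpyro.plate("data", data.shape[0]):
            numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    # Running the model outside of an inference algorithm requires a seed handler.
    with numpyro.handlers.seed(rng_seed=0):
        model(jnp.ones(3))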
plate¶
- class plate(name, size, subsample_size=None, dim=None)[source]¶
A construct for annotating conditionally independent variables. Within a plate context manager, sample sites are automatically broadcast to the size of the plate. Additionally, a scale factor might be applied by certain inference algorithms if subsample_size is specified.
Note
This can be used to subsample minibatches of data:
with plate("data", len(data), subsample_size=100) as ind: batch = data[ind] assert len(batch) == 100
- Parameters:
name (str) – Name of the plate.
size (int) – Size of the plate.
subsample_size (int) – Optional argument denoting the size of the mini-batch. Inference algorithms can use this to apply a scaling factor, e.g. when computing the ELBO from a mini-batch.
dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the rightmost available dim is allocated.
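For illustration, a sketch of nested plates with explicit dim arguments (the model and site names are hypothetical; images is assumed to be a 0/1-valued array):

    import numpyro
    import numpyro.distributions as dist

    def model(images):  # images has shape (num_rows, num_cols)
        with numpyro.plate("rows", images.shape[0], dim=-2):
            with numpyro.plate("cols", images.shape[1], dim=-1):
                # The site is broadcast to the (num_rows, num_cols) shape of the two plates.
                numpyro.sample("pixel", dist.Bernoulli(0.5), obs=images)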
plate_stack¶
subsample¶
- subsample(data, event_dim)[source]¶
EXPERIMENTAL Subsampling statement to subsample data based on enclosing plates.
This is typically called on arguments to model() when subsampling is performed automatically by plates by passing the subsample_size kwarg. For example, the following are equivalent:

    # Version 1. using indexing
    def model(data):
        with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
            data = data[ind]
            # ...

    # Version 2. using numpyro.subsample()
    def model(data):
        with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
            data = numpyro.subsample(data, event_dim=0)
            # ...
- Parameters:
data (jnp.ndarray) – A tensor of batched data.
event_dim (int) – The event dimension of the data tensor. Dimensions to the left are considered batch dimensions.
- Returns:
A subsampled version of data
- Return type:
ndarray
deterministic¶
- deterministic(name, value)[source]¶
Used to designate deterministic sites in the model. Note that most effect handlers will not operate on deterministic sites (except trace()), so deterministic sites should be side-effect free. The use case for deterministic nodes is to record any values in the model execution trace.
- Parameters:
name (str) – name of the deterministic site.
value (jnp.ndarray) – deterministic value to record in the trace.
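A small sketch (names are illustrative) of recording a derived quantity so it appears in the trace and in Predictive output:

    import numpyro
    import numpyro.distributions as dist

    def model():
        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
        # Record a derived value; it carries no probability mass of its own.
        numpyro.deterministic("x_squared", x ** 2)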
prng_key¶
factor¶
get_mask¶
- get_mask()[source]¶
Records the effects of enclosing handlers.mask handlers. This is useful for avoiding expensive numpyro.factor() computations during prediction, when the log density need not be computed, e.g.:

    def model():
        # ...
        if numpyro.get_mask() is not False:
            log_density = my_expensive_computation()
            numpyro.factor("foo", log_density)
        # ...
- Returns:
The mask.
- Return type:
None, bool, or jnp.ndarray
module¶
- module(name, nn, input_shape=None)[source]¶
Declare a stax style neural network inside a model so that its parameters are registered for optimization via param() statements.
- Parameters:
name (str) – name of the module to be registered.
nn – a stax-style neural network, i.e. a pair of (init_fn, apply_fn).
input_shape (tuple) – shape of the input taken by the neural network.
- Returns:
an apply_fn with bound parameters that takes an array as an input and returns the neural network transformed output array.
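A minimal sketch of declaring a stax network in a model (the regression model around it is hypothetical):

    from jax.example_libraries import stax
    import numpyro
    import numpyro.distributions as dist

    def model(x, y=None):
        # Registers the network's parameters as numpyro.param sites under the name "net".
        net = numpyro.module(
            "net", stax.serial(stax.Dense(16), stax.Relu, stax.Dense(1)), input_shape=(1,)
        )
        mean = net(x).squeeze(-1)  # x is assumed to have shape (batch, 1)
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)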
flax_module¶
- flax_module(name, nn_module, *args, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]¶
Declare a flax style neural network inside a model so that its parameters are registered for optimization via param() statements.
Given a flax nn_module, in flax we evaluate the module with a given set of parameters via nn_module.apply(params, x). In a NumPyro model, the pattern will be:

    net = flax_module("net", nn_module)
    y = net(x)

or with dropout layers:

    net = flax_module("net", nn_module, apply_rng=["dropout"])
    rng_key = numpyro.prng_key()
    y = net(x, rngs={"dropout": rng_key})
- Parameters:
name (str) – name of the module to be registered.
nn_module (flax.linen.Module) – a flax Module which has .init and .apply methods
args – optional arguments to initialize flax neural network as an alternative to input_shape
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (list) – A list to indicate which extra rng _kinds_ are needed for nn_module. For example, when nn_module includes dropout layers, we need to set apply_rng=["dropout"]. Defaults to None, which means no extra rng key is needed. Please see the Flax Linen Intro for more information on how Flax deals with stochastic layers like dropout.
mutable (list) – A list to indicate mutable states of nn_module. For example, if your module has a BatchNorm layer, we will need to define mutable=["batch_stats"]. See the above Flax Linen Intro tutorial for more information.
kwargs – optional keyword arguments to initialize the flax neural network as an alternative to input_shape
- Returns:
a callable with bound parameters that takes an array as an input and returns the neural network transformed output array.
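For illustration, a sketch of the pattern inside a complete regression model (the surrounding model is hypothetical; it assumes flax is installed):

    from flax import linen as nn
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.module import flax_module

    def model(x, y=None):
        # The Dense layer's kernel and bias become numpyro.param sites under "net".
        net = flax_module("net", nn.Dense(features=1), input_shape=(x.shape[-1],))
        mean = net(x).squeeze(-1)
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)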
haiku_module¶
- haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kwargs)[source]¶
Declare a haiku style neural network inside a model so that its parameters are registered for optimization via param() statements.
Given a haiku nn_module, in haiku we evaluate the module with a given set of parameters via nn_module.apply(params, None, x). In a NumPyro model, the pattern will be:

    net = haiku_module("net", nn_module)
    y = net(x)  # or y = net(rng_key, x)

or with dropout layers:

    net = haiku_module("net", nn_module, apply_rng=True)
    rng_key = numpyro.prng_key()
    y = net(rng_key, x)
- Parameters:
name (str) – name of the module to be registered.
nn_module (haiku.Transformed or haiku.TransformedWithState) – a haiku Module which has .init and .apply methods
args – optional arguments to initialize the haiku neural network as an alternative to input_shape
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (bool) – A flag to indicate if the returned callable requires an rng argument (e.g. when nn_module includes dropout layers). Defaults to False, which means no rng argument is needed. If this is True, the signature of the returned callable nn = haiku_module(..., apply_rng=True) will be nn(rng_key, x) (rather than nn(x)).
kwargs – optional keyword arguments to initialize the haiku neural network as an alternative to input_shape
- Returns:
a callable with bound parameters that takes an array as an input and returns the neural network transformed output array.
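Similarly, a sketch of using a haiku-transformed linear layer inside a hypothetical regression model (assumes haiku is installed):

    import haiku as hk
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.module import haiku_module

    def model(x, y=None):
        # hk.transform wraps the forward pass; its parameters are registered under "net".
        nn_module = hk.transform(lambda inputs: hk.Linear(1)(inputs))
        net = haiku_module("net", nn_module, input_shape=(x.shape[-1],))
        mean = net(x).squeeze(-1)
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)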
random_flax_module¶
- random_flax_module(name, nn_module, prior, *args, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]¶
A primitive to place a prior over the parameters of the Flax module nn_module.
Note
Parameters of a Flax module are stored in a nested dict. For example, the module B defined as follows:
    class A(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, x):
            return nn.Dense(1, use_bias=False, name='dense')(x)

    class B(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, x):
            return A(name='inner')(x)
has parameters {"inner": {"dense": {"kernel": param_value}}}. In the argument prior, to specify the kernel parameter, we join the path to it using dots: prior={"inner.dense.kernel": param_prior}.
- Parameters:
name (str) – name of NumPyro module
nn_module (flax.linen.Module) – the module to be registered with NumPyro
prior (dict, Distribution or callable) –
a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:
net = random_flax_module("net", flax.linen.Dense(features=1), prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}, input_shape=(4,))
Alternatively, we can use a callable. For example the following are equivalent:
    prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal())
    prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}
args – optional arguments to initialize flax neural network as an alternative to input_shape
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (list) – A list to indicate which extra rng _kinds_ are needed for nn_module. For example, when nn_module includes dropout layers, we need to set apply_rng=["dropout"]. Defaults to None, which means no extra rng key is needed. Please see the Flax Linen Intro for more information on how Flax deals with stochastic layers like dropout.
mutable (list) – A list to indicate mutable states of nn_module. For example, if your module has a BatchNorm layer, we will need to define mutable=["batch_stats"]. See the above Flax Linen Intro tutorial for more information.
kwargs – optional keyword arguments to initialize the flax neural network as an alternative to input_shape
- Returns:
a sampled module
Example
    # NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb
    >>> import numpy as np; np.random.seed(0)
    >>> import tqdm
    >>> from flax import linen as nn
    >>> from jax import jit, random
    >>> import numpyro
    >>> import numpyro.distributions as dist
    >>> from numpyro.contrib.module import random_flax_module
    >>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible
    ...
    >>> class Net(nn.Module):
    ...     n_units: int
    ...
    ...     @nn.compact
    ...     def __call__(self, x):
    ...         x = nn.Dense(self.n_units)(x[..., None])
    ...         x = nn.relu(x)
    ...         x = nn.Dense(self.n_units)(x)
    ...         x = nn.relu(x)
    ...         mean = nn.Dense(1)(x)
    ...         rho = nn.Dense(1)(x)
    ...         return mean.squeeze(), rho.squeeze()
    ...
    >>> def generate_data(n_samples):
    ...     x = np.random.normal(size=n_samples)
    ...     y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2
    ...     return x, y
    ...
    >>> def model(x, y=None, batch_size=None):
    ...     module = Net(n_units=32)
    ...     net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
    ...     with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
    ...         batch_x = numpyro.subsample(x, event_dim=0)
    ...         batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
    ...         mean, rho = net(batch_x)
    ...         sigma = nn.softplus(rho)
    ...         numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)
    ...
    >>> n_train_data = 5000
    >>> x_train, y_train = generate_data(n_train_data)
    >>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
    >>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())
    >>> n_iterations = 3000
    >>> svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256)
    >>> params, losses = svi_result.params, svi_result.losses
    >>> n_test_data = 100
    >>> x_test, y_test = generate_data(n_test_data)
    >>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    >>> y_pred = predictive(random.PRNGKey(1), x_test[:100])["obs"].copy()
    >>> assert losses[-1] < 3000
    >>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1
random_haiku_module¶
- random_haiku_module(name, nn_module, prior, *args, input_shape=None, apply_rng=False, **kwargs)[source]¶
A primitive to place a prior over the parameters of the Haiku module nn_module.
- Parameters:
name (str) – name of NumPyro module
nn_module (haiku.Transformed or haiku.TransformedWithState) – the module to be registered with NumPyro
prior (dict, Distribution or callable) –
a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:
    net = random_haiku_module(
        "net",
        haiku.transform(lambda x: hk.Linear(1)(x)),
        prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()},
        input_shape=(4,),
    )
Alternatively, we can use a callable. For example the following are equivalent:
    prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal())
    prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}
args – optional arguments to initialize the haiku neural network as an alternative to input_shape
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (bool) – A flag to indicate if the returned callable requires an rng argument (e.g. when nn_module includes dropout layers). Defaults to False, which means no rng argument is needed. If this is True, the signature of the returned callable nn = haiku_module(..., apply_rng=True) will be nn(rng_key, x) (rather than nn(x)).
kwargs – optional keyword arguments to initialize the haiku neural network as an alternative to input_shape
- Returns:
a sampled module
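As a sketch, a Bayesian linear layer where a single Normal prior is placed over every weight and bias (the surrounding regression model is hypothetical):

    import haiku as hk
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.module import random_haiku_module

    def model(x, y=None):
        # Every parameter of the linear layer gets an independent Normal(0, 1) prior.
        net = random_haiku_module(
            "net",
            hk.transform(lambda inputs: hk.Linear(1)(inputs)),
            prior=dist.Normal(0.0, 1.0),
            input_shape=(x.shape[-1],),
        )
        mean = net(x).squeeze(-1)
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)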
scan¶
- scan(f, init, xs, length=None, reverse=False, history=1)[source]¶
This primitive scans a function over the leading array axes of xs while carrying along state. See jax.lax.scan() for more information.
Usage:
    >>> import numpy as np
    >>> import numpyro
    >>> import numpyro.distributions as dist
    >>> from numpyro.contrib.control_flow import scan
    >>>
    >>> def gaussian_hmm(y=None, T=10):
    ...     def transition(x_prev, y_curr):
    ...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
    ...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
    ...         return x_curr, (x_curr, y_curr)
    ...
    ...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
    ...     _, (x, y) = scan(transition, x0, y, length=T)
    ...     return (x, y)
    >>>
    >>> # here we do some quick tests
    >>> with numpyro.handlers.seed(rng_seed=0):
    ...     x, y = gaussian_hmm(np.arange(10.))
    >>> assert x.shape == (10,) and y.shape == (10,)
    >>> assert np.all(y == np.arange(10))
    >>>
    >>> with numpyro.handlers.seed(rng_seed=0):  # generative
    ...     x, y = gaussian_hmm()
    >>> assert x.shape == (10,) and y.shape == (10,)
Warning
This is an experimental utility function that allows users to use JAX control flow with NumPyro’s effect handlers. Currently, sample and deterministic sites within the scan body f are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.
Note
It is ambiguous how to align the scan dimension inside a plate context, so the following pattern won't be supported:
    with numpyro.plate('N', 10):
        last, ys = scan(f, init, xs)
All plate statements should be put inside f. For example, the corresponding working code is
    def g(*args, **kwargs):
        with numpyro.plate('N', 10):
            return f(*args, **kwargs)

    last, ys = scan(g, init, xs)
Note
Nested scan is currently not supported.
Note
We can scan over discrete latent variables in f. The joint density is evaluated using parallel-scan (reference [1]) over the time dimension, which reduces parallel complexity to O(log(length)).
A trace of scan with discrete latent variables will contain the following sites:
init sites: those sites belong to the first history traces of f. Sites at the i-th trace will have their names prefixed with '_PREV_' * (2 * history - 1 - i).
scanned sites: those sites collect the values of the remaining scan loop over f. An additional time dimension _time_foo will be added to those sites, where foo is the name of the first site that appears in f.
Not all transition functions f are supported. All of the restrictions from Pyro's enumeration tutorial [2] still apply here. In addition, no site outside of scan should depend on the first output of scan (the last carry value).
References
1. Temporal Parallelization of Bayesian Smoothers, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)
2. Inference with Discrete Latent Variables (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)
- Parameters:
f (callable) – a function to be scanned.
init – the initial carrying state
xs – the values over which we scan along the leading axis. This can be any JAX pytree (e.g. list/dict of arrays).
length – optional value specifying the length of xs; it can be used when xs is an empty pytree (e.g. None)
reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse
history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to numpyro.plate.
- Returns:
output of scan, quoted from jax.lax.scan() docs: "pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs".
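To make the enumeration note above concrete, a minimal sketch of scanning over a discrete latent state (a two-state Gaussian HMM; the transition matrix, emission locations, and site names are illustrative assumptions, not part of the API):

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.control_flow import scan

    trans_probs = jnp.array([[0.9, 0.1], [0.1, 0.9]])  # hypothetical transition matrix
    locs = jnp.array([-1.0, 1.0])  # hypothetical emission means

    def hmm(ys):
        def transition(x_prev, y_curr):
            # Discrete latent state; MCMC can marginalize it via parallel enumeration.
            x_curr = numpyro.sample(
                "x", dist.Categorical(probs=trans_probs[x_prev]),
                infer={"enumerate": "parallel"},
            )
            numpyro.sample("y", dist.Normal(locs[x_curr], 1.0), obs=y_curr)
            return x_curr, None

        x0 = numpyro.sample(
            "x_0", dist.Categorical(probs=jnp.ones(2) / 2),
            infer={"enumerate": "parallel"},
        )
        scan(transition, x0, ys)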
cond¶
- cond(pred, true_fun, false_fun, operand)[source]¶
This primitive conditionally applies true_fun or false_fun. See jax.lax.cond() for more information.
Usage:
    >>> import numpyro
    >>> import numpyro.distributions as dist
    >>> from jax import random
    >>> from numpyro.contrib.control_flow import cond
    >>> from numpyro.infer import SVI, Trace_ELBO
    >>>
    >>> def model():
    ...     def true_fun(_):
    ...         return numpyro.sample("x", dist.Normal(20.0))
    ...
    ...     def false_fun(_):
    ...         return numpyro.sample("x", dist.Normal(0.0))
    ...
    ...     cluster = numpyro.sample("cluster", dist.Normal())
    ...     return cond(cluster > 0, true_fun, false_fun, None)
    >>>
    >>> def guide():
    ...     m1 = numpyro.param("m1", 10.0)
    ...     s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
    ...     m2 = numpyro.param("m2", 10.0)
    ...     s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)
    ...
    ...     def true_fun(_):
    ...         return numpyro.sample("x", dist.Normal(m1, s1))
    ...
    ...     def false_fun(_):
    ...         return numpyro.sample("x", dist.Normal(m2, s2))
    ...
    ...     cluster = numpyro.sample("cluster", dist.Normal())
    ...     return cond(cluster > 0, true_fun, false_fun, None)
    >>>
    >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
    >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
Warning
This is an experimental utility function that allows users to use JAX control flow with NumPyro’s effect handlers. Currently, sample and deterministic sites within true_fun and false_fun are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.
Warning
The cond primitive does not currently support enumeration and cannot be used inside a numpyro.plate context.
Note
All sample sites must belong to the same distribution class. For example, the following is not supported:

    cond(
        True,
        lambda _: numpyro.sample("x", dist.Normal()),
        lambda _: numpyro.sample("x", dist.Laplace()),
        None,
    )
- Parameters:
pred (bool) – Boolean scalar type indicating which branch function to apply
true_fun (callable) – A function to be applied if pred is true.
false_fun (callable) – A function to be applied if pred is false.
operand – Operand input to either branch depending on pred. This can be any JAX PyTree (e.g. list / dict of arrays).
- Returns:
Output of the applied branch function.