Pyro Primitives
param
param(name, init_value=None, **kwargs)
Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().
Parameters: - name (str) – name of site.
- init_value (numpy.ndarray) – initial value specified by the user. Note that the onus of using this to initialize the optimizer is on the user / inference algorithm, since there is no global parameter store in NumPyro.
- constraint (numpyro.distributions.constraints.Constraint) – NumPyro constraint, defaults to constraints.real.
- event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of these are considered batch dimensions; if the param statement is inside a subsampled plate, the corresponding batch dimensions of the parameter are subsampled accordingly. If unspecified, all dimensions are considered event dimensions and no subsampling is performed.
Returns: value for the parameter. Unless wrapped inside a handler like substitute, this simply returns the initial value.
sample
sample(name, fn, obs=None, rng_key=None, sample_shape=(), infer=None)
Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.
Note
By design, the sample primitive is meant to be used inside a NumPyro model, where the seed handler injects a random state into fn. In those situations, the rng_key keyword has no effect.
Parameters: - name (str) – name of the sample site.
- fn – a stochastic function that returns a sample.
- obs (numpy.ndarray) – observed value
- rng_key (jax.random.PRNGKey) – an optional random key for fn.
- sample_shape – Shape of samples to be drawn.
- infer (dict) – an optional dictionary containing additional information for inference algorithms. For example, if fn is a discrete distribution, setting infer={'enumerate': 'parallel'} tells MCMC to marginalize out this discrete latent site.
Returns: sample from the stochastic fn.
plate
class plate(name, size, subsample_size=None, dim=None)
Construct for annotating conditionally independent variables. Within a plate context manager, sample sites are automatically broadcast to the size of the plate. Additionally, certain inference algorithms may apply a scale factor if subsample_size is specified.
Note
This can be used to subsample minibatches of data:
with plate("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100
Parameters: - name (str) – Name of the plate.
- size (int) – Size of the plate.
- subsample_size (int) – Optional argument denoting the size of the mini-batch. Inference algorithms can use this to apply a scaling factor, e.g. when computing the ELBO on a mini-batch.
- dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the leftmost available dim is allocated.
plate_stack
subsample
subsample(data, event_dim)
EXPERIMENTAL Subsampling statement to subsample data based on enclosing plates.
This is typically called on arguments to model() when subsampling is performed automatically by plates, by passing the subsample_size kwarg. For example, the following are equivalent:
# Version 1. using indexing
def model(data):
    with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim) as ind:
        data = data[ind]
        # ...

# Version 2. using numpyro.subsample()
def model(data):
    with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim):
        data = numpyro.subsample(data, event_dim=0)
        # ...
Parameters: - data (numpy.ndarray) – A tensor of batched data.
- event_dim (int) – The event dimension of the data tensor. Dimensions to the left are considered batch dimensions.
Returns: A subsampled version of data
deterministic
deterministic(name, value)
Used to designate deterministic sites in the model. Note that most effect handlers will not operate on deterministic sites (except trace()), so deterministic sites should be side-effect free. Deterministic sites are used to record values in the model execution trace.
Parameters: - name (str) – name of the deterministic site.
- value (numpy.ndarray) – deterministic value to record in the trace.
factor
factor(name, log_factor)
Factor statement to add an arbitrary log probability factor to a probabilistic model.
Parameters: - name (str) – Name of the trivial sample site.
- log_factor (numpy.ndarray) – A possibly batched log probability factor.
module
module(name, nn, input_shape=None)
Declare a stax style neural network inside a model so that its parameters are registered for optimization via param() statements.
Returns: an apply_fn with bound parameters that takes an array as an input and returns the neural network transformed output array.
flax_module
flax_module(name, nn_module, *, input_shape=None)
Declare a flax style neural network inside a model so that its parameters are registered for optimization via param() statements.
Returns: a callable with bound parameters that takes an array as an input and returns the neural network transformed output array.
haiku_module
haiku_module(name, nn_module, *, input_shape=None)
Declare a haiku style neural network inside a model so that its parameters are registered for optimization via param() statements.
Returns: a callable with bound parameters that takes an array as an input and returns the neural network transformed output array.
random_flax_module
random_flax_module(name, nn_module, prior, *, input_shape=None)
A primitive to place a prior over the parameters of the Flax module nn_module.
Note
Parameters of a Flax module are stored in a nested dict. For example, the module B defined as follows:
class A(nn.Module):
    def apply(self, x):
        return nn.Dense(x, 1, bias=False, name='dense')

class B(nn.Module):
    def apply(self, x):
        return A(x, name='inner')

has parameters {'inner': {'dense': {'kernel': param_value}}}. In the argument prior, to specify the kernel parameter, we join the path to it using dots: prior={"inner.dense.kernel": param_prior}.
Parameters: - name (str) – name of NumPyro module
- nn_module (flax.nn.Module) – the module to be registered with NumPyro
- prior – a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:
net = random_flax_module("net", flax.nn.Dense.partial(features=1), prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}, input_shape=(4,))
- input_shape (tuple) – shape of the input taken by the neural network.
Returns: a sampled module
Example
# NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb
>>> import numpy as np; np.random.seed(0)
>>> import tqdm
>>> from flax import nn
>>> from jax import jit, random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.module import random_flax_module
>>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible
>>>
>>> class Net(nn.Module):
...     def apply(self, x, n_units):
...         x = nn.Dense(x[..., None], features=n_units)
...         x = nn.relu(x)
...         x = nn.Dense(x, features=n_units)
...         x = nn.relu(x)
...         mean = nn.Dense(x, features=1)
...         rho = nn.Dense(x, features=1)
...         return mean.squeeze(), rho.squeeze()
>>>
>>> def generate_data(n_samples):
...     x = np.random.normal(size=n_samples)
...     y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2
...     return x, y
>>>
>>> def model(x, y=None, batch_size=None):
...     module = Net.partial(n_units=32)
...     net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
...     with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
...         batch_x = numpyro.subsample(x, event_dim=0)
...         batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
...         mean, rho = net(batch_x)
...         sigma = nn.softplus(rho)
...         numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)
>>>
>>> n_train_data = 5000
>>> x_train, y_train = generate_data(n_train_data)
>>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
>>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())
>>>
>>> batch_size = 256
>>> n_iterations = 3000
>>> svi_state = svi.init(random.PRNGKey(0), x_train, y_train, batch_size=batch_size)
>>> update_fn = jit(svi.update, static_argnums=(3,))
>>> for i in tqdm.trange(n_iterations):
...     svi_state, loss = update_fn(svi_state, x_train, y_train, batch_size)
>>>
>>> params = svi.get_params(svi_state)
>>> n_test_data = 100
>>> x_test, y_test = generate_data(n_test_data)
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
>>> y_pred = predictive(random.PRNGKey(1), x_test[:100])["obs"].copy()
>>> assert loss < 3000
>>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1
random_haiku_module
random_haiku_module(name, nn_module, prior, *, input_shape=None)
A primitive to place a prior over the parameters of the Haiku module nn_module.
Parameters: - name (str) – name of NumPyro module
- nn_module (haiku.Module) – the module to be registered with NumPyro
- prior – a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:
net = random_haiku_module("net", hk.transform(lambda x: hk.Linear(1)(x)), prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}, input_shape=(4,))
- input_shape (tuple) – shape of the input taken by the neural network.
Returns: a sampled module
scan
scan(f, init, xs, length=None, reverse=False)
This primitive scans a function over the leading array axes of xs while carrying along state. See jax.lax.scan() for more information.
Usage:
>>> import numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.control_flow import scan
>>>
>>> def gaussian_hmm(y=None, T=10):
...     def transition(x_prev, y_curr):
...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
...         return x_curr, (x_curr, y_curr)
...
...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
...     _, (x, y) = scan(transition, x0, y, length=T)
...     return (x, y)
>>>
>>> # here we do some quick tests
>>> with numpyro.handlers.seed(rng_seed=0):
...     x, y = gaussian_hmm(np.arange(10.))
>>> assert x.shape == (10,) and y.shape == (10,)
>>> assert np.all(y == np.arange(10))
>>>
>>> with numpyro.handlers.seed(rng_seed=0):  # generative
...     x, y = gaussian_hmm()
>>> assert x.shape == (10,) and y.shape == (10,)
Warning
This is an experimental utility function that allows users to use JAX control flow with NumPyro’s effect handlers. Currently, sample and deterministic sites within the scan body f are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.
Note
It is ambiguous how to align the scan dimension inside a plate context, so the following pattern is not supported:
with numpyro.plate('N', 10):
    last, ys = scan(f, init, xs)
All plate statements should be put inside f. For example, the corresponding working code is:
def g(*args, **kwargs):
    with numpyro.plate('N', 10):
        return f(*args, **kwargs)

last, ys = scan(g, init, xs)
Note
Nested scan is currently not supported.
Note
We can scan over discrete latent variables in f. The joint density is evaluated using parallel-scan (reference [1]) over the time dimension, which reduces parallel complexity to O(log(length)).
Currently, only the equivalent of markov(history_size=1) is supported. A trace of scan with discrete latent variables will contain the following sites:
- init sites: those sites belong to the first trace of f. Each of them will have its name prefixed with _init/.
- scanned sites: those sites collect the values of the remaining scan loop over f. An additional time dimension _time_foo will be added to those sites, where foo is the name of the first site that appears in f.
Not all transition functions f are supported. All of the restrictions from Pyro's enumeration tutorial [2] still apply here. In addition, no site outside of scan should depend on the first output of scan (the last carry value).
References
[1] Temporal Parallelization of Bayesian Smoothers, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)
[2] Inference with Discrete Latent Variables (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)
Parameters: - f (callable) – a function to be scanned.
- init – the initial carrying state
- xs – the values over which we scan along the leading axis. This can be any JAX pytree (e.g. list/dict of arrays).
- length – optional value specifying the length of xs; it is required when xs is an empty pytree (e.g. None)
- reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse
Returns: output of scan, quoted from the jax.lax.scan() docs: “pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs”.