Pyro Primitives
param

param(name, init_value=None, **kwargs)

Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().

Parameters:
  name (str) – name of the site.
  init_value (numpy.ndarray) – initial value specified by the user. Because NumPyro has no global parameter store, the user or the inference algorithm is responsible for using this value to initialize the optimizer.
Returns: the value of the parameter. Unless wrapped inside a handler like substitute, this simply returns the initial value.
sample

sample(name, fn, obs=None, rng_key=None, sample_shape=())

Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.

Note
By design, the sample primitive is meant to be used inside a NumPyro model, where the seed handler injects a random state into fn. In that situation, the rng_key keyword has no effect.

Parameters:
  name (str) – name of the sample site.
  fn – Python callable.
  obs (numpy.ndarray) – observed value.
  rng_key (jax.random.PRNGKey) – an optional random key for fn.
  sample_shape – shape of the samples to be drawn.
Returns: a sample from the stochastic function fn.
plate

class plate(name, size, subsample_size=None, dim=None)

Construct for annotating conditionally independent variables. Within a plate context manager, sample sites are automatically broadcast to the size of the plate. Additionally, certain inference algorithms may apply a scale factor if subsample_size is specified.

Parameters:
  name (str) – name of the plate.
  size (int) – size of the plate.
  subsample_size (int) – optional size of the minibatch. Inference algorithms can use it to apply a scaling factor, e.g. when computing the ELBO on a minibatch.
  dim (int) – optional argument specifying which dimension of the tensor is used as the plate dimension. If None (default), the leftmost available dimension is allocated.
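Assuming the common choice of scale factor size / subsample_size (an assumption of this sketch), the minibatch log-likelihood is multiplied by that ratio so that it estimates the full-data log-likelihood. The worked example uses a hypothetical Normal likelihood and a constant dataset, for which the rescaled minibatch term matches the full sum.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def scaled_log_likelihood(minibatch, size, loc=0.0, scale=1.0):
    """Minibatch log-likelihood rescaled by size / subsample_size."""
    subsample_size = len(minibatch)
    plate_scale = size / subsample_size
    return plate_scale * sum(normal_log_prob(x, loc, scale) for x in minibatch)


size, subsample_size = 1000, 100
print(size / subsample_size)  # 10.0

# Hypothetical dataset in which every point is 0.5, so any minibatch is representative.
data = [0.5] * size
full = sum(normal_log_prob(x) for x in data)
estimate = scaled_log_likelihood(data[:subsample_size], size)
print(math.isclose(full, estimate))  # True
```

For real, non-constant data the rescaled term is only an unbiased estimate over random minibatches, not an exact match.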
factor

factor(name, log_factor)

Factor statement that adds an arbitrary log probability factor to a probabilistic model.

Parameters:
  name (str) – name of the trivial sample site that carries the factor.
  log_factor (numpy.ndarray) – a possibly batched log probability factor.
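module

module(name, nn, input_shape=None)

Declare a stax-style neural network inside a model so that its parameters are registered for optimization via param() statements.

Parameters:
  name (str) – name of the module to be registered.
  nn (tuple) – a pair (init_fn, apply_fn), as returned by a stax constructor.
  input_shape (tuple) – shape of the input taken by the neural network.
Returns: an apply_fn with bound parameters that takes an array as input and returns the output array transformed by the neural network.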
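A toy version of the stax (init_fn, apply_fn) convention shows how a module-like helper can register parameters once and return a bound apply_fn. toy_dense, toy_param, toy_module and PARAM_STORE are hypothetical; the sketch assumes stax's init_fn(rng, input_shape) -> (output_shape, params) signature, and uses a random.Random instance in place of a JAX PRNG key and Python lists in place of arrays.

```python
import random
from functools import partial

# Hypothetical parameter store standing in for the values param() registers.
PARAM_STORE = {}


def toy_param(name, init_value=None):
    """Register init_value under `name` the first time; afterwards return the stored value."""
    if name not in PARAM_STORE and init_value is not None:
        PARAM_STORE[name] = init_value
    return PARAM_STORE.get(name)


def toy_dense(out_dim):
    """Stax-style dense layer returning an (init_fn, apply_fn) pair."""

    def init_fn(rng, input_shape):
        in_dim = input_shape[-1]
        W = [[rng.gauss(0.0, 0.1) for _ in range(out_dim)] for _ in range(in_dim)]
        b = [0.0] * out_dim
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fn(params, x):
        W, b = params
        return [sum(xi * W[i][j] for i, xi in enumerate(x)) + b[j] for j in range(out_dim)]

    return init_fn, apply_fn


def toy_module(name, nn, input_shape, rng):
    """Initialize nn's parameters once, register them, and bind them to apply_fn."""
    init_fn, apply_fn = nn
    params = toy_param(name)
    if params is None:
        _, params = init_fn(rng, input_shape)
        params = toy_param(name, params)
    return partial(apply_fn, params)


net = toy_module("net", toy_dense(2), input_shape=(3,), rng=random.Random(0))
out = net([1.0, 2.0, 3.0])
print(len(out))  # 2
```

Calling toy_module again with the same name reuses the stored parameters rather than reinitializing them, which is what lets an optimizer update a single set of weights across model calls.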