Pyro Primitives

param

param(name, init_value=None, **kwargs)

Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().

Parameters:
  • name (str) – name of site.
  • init_value (numpy.ndarray) – initial value specified by the user. Note that the user or inference algorithm is responsible for using this value to initialize the optimizer, since there is no global parameter store in NumPyro.
Returns:

value for the parameter. Unless wrapped inside a handler like substitute, this will simply return the initial value.

sample

sample(name, fn, obs=None, sample_shape=())

Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.

Parameters:
  • name (str) – name of the sample site
  • fn – Python callable; the stochastic function to sample from
  • obs (numpy.ndarray) – observed value
  • sample_shape – Shape of samples to be drawn.
Returns:

sample from the stochastic fn.
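The roles of obs and sample_shape can be sketched in plain Python. The function toy_sample below is a hypothetical stand-in rather than NumPyro's sample: it assumes fn is a zero-argument callable, supports only an empty or one-dimensional sample_shape, and assumes that an observed site returns the observed value instead of a fresh draw.

```python
import random


def toy_sample(name, fn, obs=None, sample_shape=()):
    """Toy stand-in: return obs if given, otherwise draw from fn."""
    if obs is not None:
        return obs                      # observed site: no fresh draw
    if not sample_shape:
        return fn()                     # a single draw
    # Only a 1-D sample_shape is handled in this sketch.
    return [fn() for _ in range(sample_shape[0])]


def std_normal():
    """Hypothetical stochastic function: one standard normal draw."""
    return random.gauss(0.0, 1.0)


random.seed(0)
latent = toy_sample("z", std_normal)
batch = toy_sample("z_batch", std_normal, sample_shape=(4,))
observed = toy_sample("y", std_normal, obs=1.5)
```

Here latent is a single float, batch is a list of four independent draws, and observed is the supplied value 1.5.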

module

module(name, nn, input_shape=None)

Declare a stax-style neural network inside a model so that its parameters are registered for optimization via param() statements.

Parameters:
  • name (str) – name of the module to be registered.
  • nn (tuple) – a tuple of (init_fn, apply_fn) obtained by a stax constructor function.
  • input_shape (tuple) – shape of the input taken by the neural network.
Returns:

an apply_fn with bound parameters that takes an array as input and returns the output array produced by the neural network.
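The idea of returning an apply_fn with bound parameters can be sketched without any dependencies. Everything below is a toy under stated assumptions: toy_scale_layer and toy_module are hypothetical names, the layer follows the stax convention that init_fn(rng, input_shape) returns (output_shape, params) and apply_fn(params, inputs) returns the output, and the registration step via param() is omitted for brevity.

```python
from functools import partial


def toy_scale_layer(init_scale):
    """Hypothetical stax-style constructor returning (init_fn, apply_fn)."""

    def init_fn(rng, input_shape):
        # stax convention: return (output_shape, params); rng is unused here.
        return input_shape, {"scale": init_scale}

    def apply_fn(params, inputs):
        return [params["scale"] * x for x in inputs]

    return init_fn, apply_fn


def toy_module(name, nn, input_shape=None):
    """Toy stand-in: initialize parameters, then bind them to apply_fn."""
    init_fn, apply_fn = nn
    _, params = init_fn(None, input_shape)
    # A real implementation would register params under `name` at this point.
    return partial(apply_fn, params)


net = toy_module("scaler", toy_scale_layer(2.0), input_shape=(3,))
output = net([1.0, 2.0, 3.0])
```

Because the parameters are bound with functools.partial, the caller only passes the input array, which matches the return value described above.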