Pyro Primitives
param
param(name, init_value=None, **kwargs)
Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to svi().
Parameters:
- name (str) – name of the site.
- init_value (numpy.ndarray) – initial value specified by the user. Note that the onus of using this value to initialize the optimizer is on the user or the inference algorithm, since there is no global parameter store in NumPyro.
Returns: value for the parameter. Unless wrapped inside a handler like substitute, this will simply return the initial value.
sample
sample(name, fn, obs=None, rng_key=None, sample_shape=())
Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.
Note
By design, the sample primitive is meant to be used inside a NumPyro model, and the seed handler is then used to inject a random state into fn. In that situation, the rng_key keyword has no effect.
Parameters:
- name (str) – name of the sample site.
- fn – Python callable.
- obs (numpy.ndarray) – observed value.
- rng_key (jax.random.PRNGKey) – an optional random key for fn.
- sample_shape – shape of the samples to be drawn.
Returns: sample from the stochastic fn.
plate
class plate(name, size, subsample_size=None, dim=None)
Construct for annotating conditionally independent variables. Within a plate context manager, sample sites are automatically broadcast to the size of the plate. Additionally, a scale factor might be applied by certain inference algorithms if subsample_size is specified.
Parameters:
- name (str) – Name of the plate.
- size (int) – Size of the plate.
- subsample_size (int) – Optional argument denoting the size of the mini-batch. Inference algorithms can use it to apply a scaling factor, e.g. when computing the ELBO on a mini-batch.
- dim (int) – Optional argument specifying which dimension of the tensor is used as the plate dimension. If None (default), the leftmost available dimension is allocated.
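The following sketch shows why a scale factor is useful when `subsample_size` is set. It assumes the factor is `size / subsample_size`, which makes the scaled mini-batch sum of log-likelihoods an unbiased estimate of the full-data sum under uniform subsampling without replacement. `toy_plate_log_likelihood` is a hypothetical helper, not part of NumPyro, and the standard normal likelihood is chosen only for illustration.

```python
import math
import random


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def toy_plate_log_likelihood(data, subsample_size=None, rng=None):
    """Sum of per-datum log-likelihoods inside a toy plate of size len(data).

    If subsample_size is given, only a random mini-batch is evaluated and the
    result is scaled by size / subsample_size (the assumed scale factor)."""
    size = len(data)
    if subsample_size is None:
        return sum(normal_log_prob(x) for x in data)
    if rng is None:
        rng = random.Random(0)
    batch = rng.sample(data, subsample_size)  # mini-batch drawn without replacement
    scale = size / subsample_size
    return scale * sum(normal_log_prob(x) for x in batch)


data_rng = random.Random(42)
data = [data_rng.gauss(0.0, 1.0) for _ in range(1000)]
full = toy_plate_log_likelihood(data)
estimate = toy_plate_log_likelihood(data, subsample_size=100, rng=random.Random(7))
print(full, estimate)  # the scaled mini-batch sum is a noisy estimate of the full sum
```

Without the scale factor, the mini-batch term would be roughly `subsample_size / size` times too small relative to the other terms of the ELBO.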
factor
factor(name, log_factor)
Factor statement to add an arbitrary log probability factor to a probabilistic model.
Parameters:
- name (str) – Name of the trivial sample site that carries the factor.
- log_factor (numpy.ndarray) – A possibly batched log probability factor.
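As a rough illustration, a factor statement contributes its (summed) `log_factor` directly to the model's log joint density, without any distribution of its own. The pure-Python sketch below assumes a model with i.i.d. standard normal sample sites plus one batched factor used as a soft constraint; `toy_log_joint` and the penalty are hypothetical.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def toy_log_joint(xs, log_factors):
    """Toy log joint density of a model with one batched sample site `xs`
    (i.i.d. standard normals) and one batched factor site.

    Each entry of `log_factors` is added to the log joint unchanged: a factor
    has no distribution of its own, it only reweights the model."""
    log_p = sum(normal_log_prob(x) for x in xs)  # regular sample sites
    log_p += sum(log_factors)  # batched factor: summed over the batch
    return log_p


# Hypothetical soft constraint: penalize each x in proportion to its squared distance from 1.
xs = [0.2, -0.4, 1.1]
penalty = [-10.0 * (x - 1.0) ** 2 for x in xs]
print(toy_log_joint(xs, penalty))
```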
module
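module(name, nn, input_shape=None)
Declare a stax-style neural network inside a model so that its parameters are registered for optimization via param() statements.
Parameters:
- name (str) – name of the module to be registered.
- nn (tuple) – a tuple of (init_fn, apply_fn) obtained by a stax constructor function.
- input_shape (tuple) – shape of the input taken by the neural network.
Returns: an apply_fn with bound parameters that takes an array as input and returns the neural network's transformed output array.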
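A stax-style network is an `(init_fn, apply_fn)` pair, where `init_fn(rng, input_shape)` returns `(output_shape, params)` and `apply_fn(params, inputs)` computes the output. The pure-Python sketch below imitates the behavior described above: initialize the parameters once (unless values are supplied from outside, as a substitute-like handler would do) and hand back an apply function with those parameters bound. `toy_dense` and `toy_module` are hypothetical; they use nested lists instead of arrays and a `random.Random` generator instead of a JAX key.

```python
import functools
import random


def toy_dense(out_dim):
    """Hypothetical stax-style constructor for a dense layer.

    Returns an (init_fn, apply_fn) pair: init_fn(rng, input_shape) gives
    (output_shape, params) and apply_fn(params, inputs) computes the output."""

    def init_fn(rng, input_shape):
        in_dim = input_shape[-1]
        weights = [[rng.gauss(0.0, 0.1) for _ in range(out_dim)] for _ in range(in_dim)]
        bias = [0.0] * out_dim
        return tuple(input_shape[:-1]) + (out_dim,), (weights, bias)

    def apply_fn(params, inputs):
        weights, bias = params
        # `inputs` is a single vector of length in_dim in this toy.
        return [
            sum(x * row[j] for x, row in zip(inputs, weights)) + bias[j]
            for j in range(len(bias))
        ]

    return init_fn, apply_fn


def toy_module(name, nn, input_shape=None, substituted=None, rng_seed=0):
    """Toy module: reuse parameters supplied for site `name` if any, otherwise
    initialize them with init_fn, then bind them to apply_fn."""
    init_fn, apply_fn = nn
    params = (substituted or {}).get(name)
    if params is None:
        if input_shape is None:
            raise ValueError("input_shape is required to initialize the parameters")
        _, params = init_fn(random.Random(rng_seed), input_shape)
    # The returned callable only needs the input, with the parameters already bound.
    return functools.partial(apply_fn, params)


net = toy_module("net", toy_dense(2), input_shape=(3,))
print(net([1.0, 2.0, 3.0]))  # two outputs computed with freshly initialized weights

fixed = ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.5, -0.5])
net_fixed = toy_module("net", toy_dense(2), substituted={"net": fixed})
print(net_fixed([1.0, 2.0, 3.0]))  # [4.5, 4.5]
```

Note that `input_shape` is only needed when the parameters have to be initialized; once values are supplied for the site, the toy skips `init_fn` entirely.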