NumPyro documentation¶
Getting Started with NumPyro¶
Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
What is NumPyro?¶
NumPyro is a small probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.
NumPyro is designed to be lightweight and focuses on providing a flexible substrate that users can build on:
Pyro Primitives: NumPyro programs can contain regular Python and NumPy code, in addition to Pyro primitives like sample and param. The model code should look very similar to Pyro except for some minor differences between PyTorch and NumPy's API. See the example below.
Inference algorithms: NumPyro currently supports Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the Verlet integrator, which includes multiple gradient computations. With JAX, we can compose jit and grad to compile the entire integration step into an XLA optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree building stage in NUTS (this is possible using Iterative NUTS). There is also a basic Variational Inference implementation for reparameterized distributions together with many flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI).
Distributions: The numpyro.distributions module provides distribution classes, constraints and bijective transforms. The distribution classes wrap over samplers implemented to work with JAX's functional pseudo-random number generator. The design of the distributions module largely follows from PyTorch. A major subset of the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in torch.distributions. In addition to distributions, constraints and transforms are very useful when operating on distribution classes with bounded support.
Effect handlers: As in Pyro, primitives like sample and param can be given nonstandard interpretations using effect handlers from the numpyro.handlers module, and these can be easily extended to implement custom inference algorithms and inference utilities.
A Simple Example - 8 Schools¶
Let us explore NumPyro using a simple example. We will use the eight schools example from Gelman et al., Bayesian Data Analysis: Sec. 5.5, 2003, which studies the effect of coaching on SAT performance in eight schools.
The data is given by:
>>> import numpy as np
>>> J = 8
>>> y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
>>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
where y are the treatment effects and sigma the standard errors. We build a hierarchical model for the study where we assume that the group-level parameters theta for each school are sampled from a Normal distribution with unknown mean mu and standard deviation tau, while the observed data are in turn generated from a Normal distribution with mean and standard deviation given by theta (true effect) and sigma, respectively. This allows us to estimate the population-level parameters mu and tau by pooling from all the observations, while still allowing for individual variation amongst the schools using the group-level theta parameters.
>>> import numpyro
>>> import numpyro.distributions as dist
>>> # Eight Schools example
... def eight_schools(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         theta = numpyro.sample('theta', dist.Normal(mu, tau))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler (NUTS). Note the usage of the extra_fields argument in MCMC.run. By default, we only collect samples from the target (posterior) distribution when we run inference using MCMC. However, additional fields such as the potential energy or the acceptance probability of a sample can be collected by using the extra_fields argument. For a list of possible fields that can be collected, see the HMCState object. In this example, we will additionally collect the potential_energy for each sample.
>>> from jax import random
>>> from numpyro.infer import MCMC, NUTS
>>> nuts_kernel = NUTS(eight_schools)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
We can print the summary of the MCMC run, and examine if we observed any divergences during inference. Additionally, since we collected the potential energy for each of the samples, we can easily compute the expected log joint density.
>>> mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.14      3.18      3.87     -0.76      9.50    115.42      1.01
       tau      4.12      3.58      3.12      0.51      8.56     90.64      1.02
  theta[0]      6.40      6.22      5.36     -2.54     15.27    176.75      1.00
  theta[1]      4.96      5.04      4.49     -1.98     14.22    217.12      1.00
  theta[2]      3.65      5.41      3.31     -3.47     13.77    247.64      1.00
  theta[3]      4.47      5.29      4.00     -3.22     12.92    213.36      1.01
  theta[4]      3.22      4.61      3.28     -3.72     10.93    242.14      1.01
  theta[5]      3.89      4.99      3.71     -3.39     12.54    206.27      1.00
  theta[6]      6.55      5.72      5.66     -1.43     15.78    124.57      1.00
  theta[7]      4.81      5.95      4.19     -3.90     13.40    299.66      1.00

Number of divergences: 19
>>> pe = mcmc.get_extra_fields()['potential_energy']
>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))
Expected log joint density: -54.55
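The r_hat column in the summary is a split Gelman Rubin statistic. As a rough illustration of what that number measures, here is a minimal pure-Python sketch using only the standard library. The function name split_rhat and the toy chains are hypothetical, and this is not NumPyro's own implementation, which may differ in detail.

```python
import math
from statistics import mean, variance


def split_rhat(chains):
    """Split R-hat for a list of equal-length chains of scalar draws."""
    # Split every chain into two halves, so that a drifting chain shows up
    # as disagreement between its own halves.
    halves = []
    for chain in chains:
        half = len(chain) // 2
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])

    n = len(halves[0])
    # W: average within-half sample variance.
    within = mean(variance(h) for h in halves)
    # B / n: sample variance of the half means.
    between_over_n = variance([mean(h) for h in halves])
    # Pooled variance estimate; R-hat compares it to W.
    pooled = (n - 1) / n * within + between_over_n
    return math.sqrt(pooled / within)


# Two toy chains that explore the same values: R-hat is close to 1.
mixed = [[0.0, 1.0] * 50, [1.0, 0.0] * 50]
# Two toy chains stuck in different regions: R-hat is far above 1.
stuck = [[0.0, 1.0] * 50, [10.0, 11.0] * 50]

print(split_rhat(mixed))
print(split_rhat(stuck))
```

When the chains agree, the between-chain term vanishes and the ratio sits near 1; when they disagree, the pooled variance is inflated relative to the within-chain variance and the statistic grows.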
Values above 1 for the split Gelman Rubin diagnostic (r_hat) indicate that the chain has not fully converged. The low values for the effective sample size (n_eff), particularly for tau, and the number of divergent transitions look problematic. Fortunately, this is a common pathology that can be rectified by using a non-centered parameterization for tau in our model. This is straightforward to do in NumPyro by using a TransformedDistribution instance together with a reparameterization effect handler. Let us rewrite the same model, but instead of sampling theta from a Normal(mu, tau), we will sample it from a base Normal(0, 1) distribution that is transformed using an AffineTransform. Note that by doing so, NumPyro runs HMC by generating samples theta_base for the base Normal(0, 1) distribution instead. We see that the resulting chain does not suffer from the same pathology: the Gelman Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good!
>>> from numpyro.infer.reparam import TransformReparam
>>> # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     with numpyro.plate('J', J):
...         with numpyro.handlers.reparam(config={'theta': TransformReparam()}):
...             theta = numpyro.sample(
...                 'theta',
...                 dist.TransformedDistribution(dist.Normal(0., 1.),
...                                              dist.transforms.AffineTransform(mu, tau)))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
>>> nuts_kernel = NUTS(eight_schools_noncentered)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
>>> mcmc.print_summary(exclude_deterministic=False)
                    mean       std    median      5.0%     95.0%     n_eff     r_hat
            mu      4.08      3.51      4.14     -1.69      9.71    720.43      1.00
           tau      3.96      3.31      3.09      0.01      8.34    488.63      1.00
      theta[0]      6.48      5.72      6.08     -2.53     14.96    801.59      1.00
      theta[1]      4.95      5.10      4.91     -3.70     12.82   1183.06      1.00
      theta[2]      3.65      5.58      3.72     -5.71     12.13    581.31      1.00
      theta[3]      4.56      5.04      4.32     -3.14     12.92   1282.60      1.00
      theta[4]      3.41      4.79      3.47     -4.16     10.79    801.25      1.00
      theta[5]      3.58      4.80      3.78     -3.95     11.55   1101.33      1.00
      theta[6]      6.31      5.17      5.75     -2.93     13.87   1081.11      1.00
      theta[7]      4.81      5.38      4.61     -3.29     14.05    954.14      1.00
 theta_base[0]      0.41      0.95      0.40     -1.09      1.95    851.45      1.00
 theta_base[1]      0.15      0.95      0.20     -1.42      1.66   1568.11      1.00
 theta_base[2]     -0.08      0.98     -0.10     -1.68      1.54   1037.16      1.00
 theta_base[3]      0.06      0.89      0.05     -1.42      1.47   1745.02      1.00
 theta_base[4]     -0.14      0.94     -0.16     -1.65      1.45    719.85      1.00
 theta_base[5]     -0.10      0.96     -0.14     -1.57      1.51   1128.45      1.00
 theta_base[6]      0.38      0.95      0.42     -1.32      1.82   1026.50      1.00
 theta_base[7]      0.10      0.97      0.10     -1.51      1.65   1190.98      1.00

Number of divergences: 0
>>> pe = mcmc.get_extra_fields()['potential_energy']
>>> # Compare with the earlier value
>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))
Expected log joint density: -46.09
Note that for the class of distributions with loc, scale parameters such as Normal, Cauchy, StudentT, we also provide a LocScaleReparam reparameterizer to achieve the same purpose. The corresponding code will be
from numpyro.infer.reparam import LocScaleReparam
# inside the model, in place of the TransformReparam block:
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
    theta = numpyro.sample('theta', dist.Normal(mu, tau))
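Both reparameterizers rest on the same identity: if theta_base is drawn from Normal(0, 1), then mu + tau * theta_base is distributed as Normal(mu, tau). The sketch below checks this numerically with the standard library only. The values of mu and tau are hypothetical and held fixed for illustration, whereas in the model they are themselves sampled.

```python
import random
from statistics import mean, stdev

mu, tau = 4.0, 3.0  # hypothetical values, fixed for illustration
rng = random.Random(0)

# Centered: draw theta directly from Normal(mu, tau).
centered = [rng.gauss(mu, tau) for _ in range(20000)]

# Non-centered: draw theta_base from Normal(0, 1), then shift and scale.
theta_base = [rng.gauss(0.0, 1.0) for _ in range(20000)]
noncentered = [mu + tau * z for z in theta_base]

# Both sets of draws should have mean close to mu and spread close to tau.
print(mean(centered), stdev(centered))
print(mean(noncentered), stdev(noncentered))
```

The distribution of theta is unchanged; what changes is the space the sampler works in. HMC now explores theta_base, whose scale no longer depends on tau.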
Now, let us assume that we have a new school for which we have not observed any test scores, but we would like to generate predictions. NumPyro provides a Predictive class for such a purpose. Note that in the absence of any observed data, we simply use the population-level parameters to generate predictions. The Predictive utility conditions the unobserved mu and tau sites to values drawn from the posterior distribution from our last MCMC run, and runs the model forward to generate predictions.
>>> from numpyro.infer import Predictive
>>> # New School
... def new_school():
...     mu = numpyro.sample('mu', dist.Normal(0, 5))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5))
...     return numpyro.sample('obs', dist.Normal(mu, tau))
>>> predictive = Predictive(new_school, mcmc.get_samples())
>>> samples_predictive = predictive(random.PRNGKey(1))
>>> print(np.mean(samples_predictive['obs']))
3.9886456
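Conceptually, the prediction step amounts to one forward draw of obs per posterior draw of (mu, tau). The following standard-library sketch mimics that idea. The helper predict_new_school and the listed draws are hypothetical and are not taken from the run above, and the real Predictive class does considerably more (vectorization, handling of arbitrary sites).

```python
import random


def predict_new_school(posterior_draws, seed=0):
    """Return one predictive draw of 'obs' per posterior draw of (mu, tau)."""
    rng = random.Random(seed)
    # mu and tau are held at their posterior values; only 'obs' is sampled.
    return [rng.gauss(mu, tau) for mu, tau in posterior_draws]


# Hypothetical posterior draws of (mu, tau), for illustration only.
draws = [(4.1, 3.0), (2.7, 1.2), (5.3, 6.5)]
predictions = predict_new_school(draws)
print(len(predictions))
```

Averaging such draws over many posterior samples gives a Monte Carlo estimate of the posterior predictive mean, which is the quantity printed in the example above.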
More Examples¶
For some more examples on specifying models and doing inference in NumPyro:
Bayesian Regression in NumPyro - Start here to get acquainted with writing a simple model in NumPyro, MCMC inference API, effect handlers and writing custom inference utilities.
Time Series Forecasting - Illustrates how to convert for loops in the model to JAX's lax.scan primitive for fast inference.
Baseball example - Using NUTS for a simple hierarchical model. Compare this with the baseball example in Pyro.
Hidden Markov Model in NumPyro as compared to Stan.
Variational Autoencoder - A simple example that uses Variational Inference with neural networks. Pyro implementation for comparison.
Gaussian Process - Provides a simple example to use NUTS to sample from the posterior over the hyper-parameters of a Gaussian Process.
Statistical Rethinking with NumPyro - Notebooks containing a translation of the code in the second edition of Richard McElreath's Statistical Rethinking book to NumPyro.
Other model examples can be found in the examples folder.
Pyro users will note that the API for model specification and inference is largely the same as Pyro, including the distributions API, by design. However, there are some important core differences (reflected in the internals) that users should be aware of. For example, in NumPyro there is no global parameter store or random state, which makes it possible for us to leverage JAX's JIT compilation. Also, users may need to write their models in a more functional style that works better with JAX. Refer to FAQs for a list of differences.
Installation¶
Limited Windows Support: Note that NumPyro is untested on Windows, and might require building jaxlib from source. See this JAX issue for more details. Alternatively, you can install Windows Subsystem for Linux and use NumPyro on it as on a Linux system. See also CUDA on Windows Subsystem for Linux and this forum post if you want to use GPUs on Windows.
To install NumPyro with the latest CPU version of JAX, you can use pip:
pip install numpyro
If compatibility issues arise during execution of the above command, you can instead force the installation of a known compatible CPU version of JAX with
pip install numpyro[cpu]
To use NumPyro on the GPU, you need to install CUDA first and then use the following pip command:
# change `cuda111` to your CUDA version number, e.g. for CUDA 10.2 use `cuda102`
pip install numpyro[cuda111] -f https://storage.googleapis.com/jax-releases/jax_releases.html
If you need further guidance, please have a look at the JAX GPU installation instructions.
To run NumPyro on Cloud TPUs, you can look at some JAX on Cloud TPU examples.
For Cloud TPU VM, you need to set up the TPU backend as detailed in the Cloud TPU VM JAX Quickstart Guide. After you have verified that the TPU backend is properly set up, you can install NumPyro using the pip install numpyro command.
Default Platform: JAX will use the GPU by default if a CUDA-supported jaxlib package is installed. You can use the set_platform utility, numpyro.set_platform("cpu"), to switch to CPU at the beginning of your program.
You can also install NumPyro from source:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
You can also install NumPyro with conda:
conda install -c conda-forge numpyro
Frequently Asked Questions¶
Unlike in Pyro, numpyro.sample('x', dist.Normal(0, 1)) does not work. Why?
You are most likely using a numpyro.sample statement outside an inference context. JAX does not have a global random state, and as such, distribution samplers need an explicit random number generator key (PRNGKey) to generate samples from. NumPyro's inference algorithms use the seed handler to thread in a random number generator key, behind the scenes.
Your options are:
Call the distribution directly and provide a PRNGKey, e.g. dist.Normal(0, 1).sample(PRNGKey(0))
Provide the rng_key argument to numpyro.sample, e.g. numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0)).
Wrap the code in a seed handler, used either as a context manager or as a function that wraps over the original callable, e.g.
```python
with handlers.seed(rng_seed=0):  # random.PRNGKey(0) is used
    x = numpyro.sample('x', dist.Beta(1, 1))    # uses a PRNGKey split from random.PRNGKey(0)
    y = numpyro.sample('y', dist.Bernoulli(x))  # uses different PRNGKey split from the last one
```
or as a higher order function:
```python
def fn():
    x = numpyro.sample('x', dist.Beta(1, 1))
    y = numpyro.sample('y', dist.Bernoulli(x))
    return y

print(handlers.seed(fn, rng_seed=0)())
```
Can I use the same Pyro model for doing inference in NumPyro?
As you may have noticed from the examples, NumPyro supports all Pyro primitives like sample, param, plate and module, and effect handlers. Additionally, we have ensured that the distributions API is based on torch.distributions, and the inference classes like SVI and MCMC have the same interface. This, along with the similarity in the API for NumPy and PyTorch operations, ensures that models containing Pyro primitive statements can be used with either backend with some minor changes. Some of the differences, along with the changes needed, are noted below:
Any torch operation in your model will need to be written in terms of the corresponding jax.numpy operation. Additionally, not all torch operations have a numpy counterpart (and vice-versa), and sometimes there are minor differences in the API.
pyro.sample statements outside an inference context will need to be wrapped in a seed handler, as mentioned above.
There is no global parameter store, and as such using numpyro.param outside an inference context will have no effect. To retrieve the optimized parameter values from SVI, use the SVI.get_params method. Note that you can still use param statements inside a model and NumPyro will use the substitute effect handler internally to substitute values from the optimizer when running the model in SVI.
PyTorch neural network modules will need to be rewritten as stax neural networks. See the VAE example for differences in syntax between the two backends.
JAX works best with functional code, particularly if we would like to leverage JIT compilation, which NumPyro does internally for many inference subroutines. As such, if your model has side-effects that are not visible to the JAX tracer, it may need to be rewritten in a more functional style.
For most small models, changes required to run inference in NumPyro should be minor. Additionally, we are working on pyro-api which allows you to write the same code and dispatch it to multiple backends, including NumPyro. This will necessarily be more restrictive, but has the advantage of being backend agnostic. See the documentation for an example, and let us know your feedback.
How can I contribute to the project?
Thanks for your interest in the project! You can take a look at beginner friendly issues that are marked with the good first issue tag on GitHub. Also, please feel free to reach out to us on the forum.
Future / Ongoing Work¶
In the near term, we plan to work on the following. Please open new issues for feature requests and enhancements:
Improving robustness of inference on different models, profiling and performance tuning.
Supporting more functionality as part of the pyro-api generic modeling interface.
More inference algorithms, particularly those that require second order derivatives or use HMC.
Integration with Funsor to support inference algorithms with delayed sampling.
Other areas motivated by Pyro’s research goals and application focus, and interest from the community.
Citing NumPyro¶
The motivating ideas behind NumPyro and a description of Iterative NUTS can be found in this paper, which appeared in the NeurIPS 2019 Program Transformations for Machine Learning Workshop.
If you use NumPyro, please consider citing:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
as well as
@article{bingham2018pyro,
author = {Bingham, Eli and Chen, Jonathan P. and Jankowiak, Martin and Obermeyer, Fritz and
Pradhan, Neeraj and Karaletsos, Theofanis and Singh, Rohit and Szerlip, Paul and
Horsfall, Paul and Goodman, Noah D.},
title = {{Pyro: Deep Universal Probabilistic Programming}},
journal = {arXiv preprint arXiv:1810.09538},
year = {2018}
}
Pyro Primitives¶
param¶
- param(name, init_value=None, **kwargs)[source]¶
Annotate the given site as an optimizable parameter for use with jax.experimental.optimizers. For an example of how param statements can be used in inference algorithms, refer to SVI.
- Parameters
name (str) – name of site.
init_value (numpy.ndarray or callable) – initial value specified by the user or a lazy callable that accepts a JAX random PRNGKey and returns an array. Note that the onus of using this to initialize the optimizer is on the user inference algorithm, since there is no global parameter store in NumPyro.
constraint (numpyro.distributions.constraints.Constraint) – NumPyro constraint, defaults to constraints.real.
event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimension to the left of this will be considered batch dimensions; if the param statement is inside a subsampled plate, then corresponding batch dimensions of the parameter will be correspondingly subsampled. If unspecified, all dimensions will be considered event dims and no subsampling will be performed.
- Returns
value for the parameter. Unless wrapped inside a handler like substitute, this will simply return the initial value.
sample¶
- sample(name, fn, obs=None, rng_key=None, sample_shape=(), infer=None, obs_mask=None)[source]¶
Returns a random sample from the stochastic function fn. This can have additional side effects when wrapped inside effect handlers like substitute.
Note
By design, the sample primitive is meant to be used inside a NumPyro model, where the seed handler is used to inject a random state into fn. In those situations, the rng_key keyword has no effect.
- Parameters
name (str) – name of the sample site.
fn – a stochastic function that returns a sample.
obs (numpy.ndarray) – observed value
rng_key (jax.random.PRNGKey) – an optional random key for fn.
sample_shape – Shape of samples to be drawn.
infer (dict) – an optional dictionary containing additional information for inference algorithms. For example, if fn is a discrete distribution, setting infer={'enumerate': 'parallel'} tells MCMC to marginalize this discrete latent site.
obs_mask (numpy.ndarray) – Optional boolean array mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.
- Returns
sample from the stochastic fn.
plate¶
- class plate(name, size, subsample_size=None, dim=None)[source]¶
Construct for annotating conditionally independent variables. Within a plate context manager, sample sites will be automatically broadcasted to the size of the plate. Additionally, a scale factor might be applied by certain inference algorithms if subsample_size is specified.
Note
This can be used to subsample minibatches of data:
    with plate("data", len(data), subsample_size=100) as ind:
        batch = data[ind]
        assert len(batch) == 100
- Parameters
name (str) – Name of the plate.
size (int) – Size of the plate.
subsample_size (int) – Optional argument denoting the size of the mini-batch. Inference algorithms can use this to apply a scaling factor, e.g. when computing the ELBO using a mini-batch.
dim (int) – Optional argument to specify which dimension in the tensor is used as the plate dim. If None (default), the rightmost available dim is allocated.
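To see why a scale factor is needed when subsampling, consider the log-likelihood of a dataset estimated from a mini-batch. The sketch below assumes the usual choice scale = size / subsample_size and uses only the standard library. The data values and the helper normal_logpdf are hypothetical, and this is an illustration of the idea rather than NumPyro's internals.

```python
import math


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


data = [0.3, -1.2, 0.8, 2.1, -0.4, 1.5]  # hypothetical observations
size, subsample_size = len(data), 2
scale = size / subsample_size  # assumed scaling factor

# Log-likelihood of the full dataset under Normal(0, 1).
full = sum(normal_logpdf(x, 0.0, 1.0) for x in data)

# Scaled log-likelihood of each disjoint mini-batch of two points.
estimates = []
for start in range(0, size, subsample_size):
    batch = data[start:start + subsample_size]
    estimates.append(scale * sum(normal_logpdf(x, 0.0, 1.0) for x in batch))

# Averaged over the mini-batches, the scaled estimates recover the full sum.
average = sum(estimates) / len(estimates)
print(abs(average - full) < 1e-9)
```

Each individual mini-batch estimate is noisy, but scaling keeps it on the same footing as the full-data term, which is what makes stochastic ELBO estimates comparable to the full objective.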
plate_stack¶
subsample¶
- subsample(data, event_dim)[source]¶
EXPERIMENTAL Subsampling statement to subsample data based on enclosing plates.
This is typically called on arguments to model() when subsampling is performed automatically by plates by passing the subsample_size kwarg. For example the following are equivalent:
    # Version 1. using indexing
    def model(data):
        with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim) as ind:
            data = data[ind]
            # ...

    # Version 2. using numpyro.subsample()
    def model(data):
        with numpyro.plate("data", len(data), subsample_size=10, dim=-data.ndim):
            data = numpyro.subsample(data, event_dim=0)
            # ...
- Parameters
data (numpy.ndarray) – A tensor of batched data.
event_dim (int) – The event dimension of the data tensor. Dimensions to the left are considered batch dimensions.
- Returns
A subsampled version of data
- Return type
deterministic¶
- deterministic(name, value)[source]¶
Used to designate deterministic sites in the model. Note that most effect handlers will not operate on deterministic sites (except trace()), so deterministic sites should be side-effect free. The use case for deterministic nodes is to record any values in the model execution trace.
- Parameters
name (str) – name of the deterministic site.
value (numpy.ndarray) – deterministic value to record in the trace.
prng_key¶
factor¶
- factor(name, log_factor)[source]¶
Factor statement to add arbitrary log probability factor to a probabilistic model.
- Parameters
name (str) – Name of the trivial sample.
log_factor (numpy.ndarray) – A possibly batched log probability factor.
get_mask¶
- get_mask()[source]¶
Records the effects of enclosing handlers.mask handlers. This is useful for avoiding expensive numpyro.factor() computations during prediction, when the log density need not be computed, e.g.:
    def model():
        # ...
        if numpyro.get_mask() is not False:
            log_density = my_expensive_computation()
            numpyro.factor("foo", log_density)
        # ...
- Returns
The mask.
- Return type
None, bool, or numpy.ndarray
module¶
- module(name, nn, input_shape=None)[source]¶
Declare a stax style neural network inside a model so that its parameters are registered for optimization via param() statements.
- Parameters
- Returns
an apply_fn with bound parameters that takes an array as an input and returns the neural network transformed output array.
flax_module¶
- flax_module(name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]¶
Declare a flax style neural network inside a model so that its parameters are registered for optimization via param() statements.
Given a flax nn_module, in flax to evaluate the module with a given set of parameters, we use: nn_module.apply(params, x). In a NumPyro model, the pattern will be:
    net = flax_module("net", nn_module)
    y = net(x)
or with dropout layers:
    net = flax_module("net", nn_module, apply_rng=["dropout"])
    rng_key = numpyro.prng_key()
    y = net(x, rngs={"dropout": rng_key})
- Parameters
name (str) – name of the module to be registered.
nn_module (flax.linen.Module) – a flax Module which has .init and .apply methods
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (list) – A list to indicate which extra rng _kinds_ are needed for nn_module. For example, when nn_module includes dropout layers, we need to set apply_rng=["dropout"]. Defaults to None, which means no extra rng key is needed. Please see Flax Linen Intro for more information on how Flax deals with stochastic layers like dropout.
mutable (list) – A list to indicate mutable states of nn_module. For example, if your module has a BatchNorm layer, we will need to define mutable=["batch_stats"]. See the above Flax Linen Intro tutorial for more information.
kwargs – optional keyword arguments to initialize flax neural network as an alternative to input_shape
- Returns
a callable with bound parameters that takes an array as an input and returns the neural network transformed output array.
haiku_module¶
- haiku_module(name, nn_module, *, input_shape=None, apply_rng=False, **kwargs)[source]¶
Declare a haiku style neural network inside a model so that its parameters are registered for optimization via param() statements.
Given a haiku nn_module, in haiku to evaluate the module with a given set of parameters, we use: nn_module.apply(params, None, x). In a NumPyro model, the pattern will be:
    net = haiku_module("net", nn_module)
    y = net(x)  # or y = net(rng_key, x)
or with dropout layers:
    net = haiku_module("net", nn_module, apply_rng=True)
    rng_key = numpyro.prng_key()
    y = net(rng_key, x)
- Parameters
name (str) – name of the module to be registered.
nn_module (haiku.Transformed or haiku.TransformedWithState) – a haiku Module which has .init and .apply methods
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (bool) – A flag to indicate if the returned callable requires an rng argument (e.g. when nn_module includes dropout layers). Defaults to False, which means no rng argument is needed. If this is True, the signature of the returned callable nn = haiku_module(..., apply_rng=True) will be nn(rng_key, x) (rather than nn(x)).
kwargs – optional keyword arguments to initialize haiku neural network as an alternative to input_shape
- Returns
a callable with bound parameters that takes an array as an input and returns the neural network transformed output array.
random_flax_module¶
- random_flax_module(name, nn_module, prior, *, input_shape=None, apply_rng=None, mutable=None, **kwargs)[source]¶
A primitive to place a prior over the parameters of the Flax module nn_module.
Note
Parameters of a Flax module are stored in a nested dict. For example, the module B defined as follows:
    class A(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, x):
            return flax.linen.Dense(1, use_bias=False, name='dense')(x)

    class B(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, x):
            return A(name='inner')(x)
has parameters {'inner': {'dense': {'kernel': param_value}}}. In the argument prior, to specify the kernel parameter, we join the path to it using dots: prior={"inner.dense.kernel": param_prior}.
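The dotted naming convention can be pictured as flattening the nested parameter dict. The helper flatten_params below is a hypothetical, standard-library illustration of that mapping and is not part of NumPyro or Flax.

```python
def flatten_params(params, prefix=""):
    """Flatten a nested parameter dict into {'a.b.c': value} form."""
    flat = {}
    for key, value in params.items():
        # Join the path to this entry using dots.
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_params(value, path))
        else:
            flat[path] = value
    return flat


# Placeholder string stands in for the actual parameter array.
nested = {"inner": {"dense": {"kernel": "param_value"}}}
print(flatten_params(nested))  # {'inner.dense.kernel': 'param_value'}
```

The keys of the flattened dict are the names to use as keys of the prior argument.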
- Parameters
name (str) – name of NumPyro module
nn_module (flax.linen.Module) – the module to be registered with NumPyro
prior (dict or Distribution) – a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:
    net = random_flax_module("net",
                             flax.linen.Dense(features=1),
                             prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
                             input_shape=(4,))
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (list) – A list to indicate which extra rng _kinds_ are needed for nn_module. For example, when nn_module includes dropout layers, we need to set apply_rng=["dropout"]. Defaults to None, which means no extra rng key is needed. Please see Flax Linen Intro for more information on how Flax deals with stochastic layers like dropout.
mutable (list) – A list to indicate mutable states of nn_module. For example, if your module has a BatchNorm layer, we will need to define mutable=["batch_stats"]. See the above Flax Linen Intro tutorial for more information.
kwargs – optional keyword arguments to initialize flax neural network as an alternative to input_shape
- Returns
a sampled module
Example
# NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb
>>> import numpy as np; np.random.seed(0)
>>> import tqdm
>>> from flax import linen as nn
>>> from jax import jit, random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.module import random_flax_module
>>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible
...
>>> class Net(nn.Module):
...     n_units: int
...
...     @nn.compact
...     def __call__(self, x):
...         x = nn.Dense(self.n_units)(x[..., None])
...         x = nn.relu(x)
...         x = nn.Dense(self.n_units)(x)
...         x = nn.relu(x)
...         mean = nn.Dense(1)(x)
...         rho = nn.Dense(1)(x)
...         return mean.squeeze(), rho.squeeze()
...
>>> def generate_data(n_samples):
...     x = np.random.normal(size=n_samples)
...     y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2
...     return x, y
...
>>> def model(x, y=None, batch_size=None):
...     module = Net(n_units=32)
...     net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
...     with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
...         batch_x = numpyro.subsample(x, event_dim=0)
...         batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
...         mean, rho = net(batch_x)
...         sigma = nn.softplus(rho)
...         numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)
...
>>> n_train_data = 5000
>>> x_train, y_train = generate_data(n_train_data)
>>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
>>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())
>>> n_iterations = 3000
>>> svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256)
>>> params, losses = svi_result.params, svi_result.losses
>>> n_test_data = 100
>>> x_test, y_test = generate_data(n_test_data)
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
>>> y_pred = predictive(random.PRNGKey(1), x_test[:100])["obs"].copy()
>>> assert losses[-1] < 3000
>>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1
random_haiku_module¶
- random_haiku_module(name, nn_module, prior, *, input_shape=None, apply_rng=False, **kwargs)[source]¶
A primitive to place a prior over the parameters of the Haiku module nn_module.
- Parameters
name (str) – name of NumPyro module
nn_module (haiku.Transformed or haiku.TransformedWithState) – the module to be registered with NumPyro
prior (dict or Distribution) –
a NumPyro distribution or a Python dict with parameter names as keys and respective distributions as values. For example:
net = random_haiku_module("net", haiku.transform(lambda x: hk.Linear(1)(x)), prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}, input_shape=(4,))
input_shape (tuple) – shape of the input taken by the neural network.
apply_rng (bool) – A flag to indicate if the returned callable requires an rng argument (e.g. when nn_module includes dropout layers). Defaults to False, which means no rng argument is needed. If this is True, the signature of the returned callable nn = haiku_module(..., apply_rng=True) will be nn(rng_key, x) (rather than nn(x)).
kwargs – optional keyword arguments to initialize the Haiku neural network as an alternative to input_shape
- Returns
a sampled module
scan¶
- scan(f, init, xs, length=None, reverse=False, history=1)[source]¶
This primitive scans a function over the leading array axes of xs while carrying along state. See jax.lax.scan() for more information.
Usage:
>>> import numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.control_flow import scan
>>>
>>> def gaussian_hmm(y=None, T=10):
...     def transition(x_prev, y_curr):
...         x_curr = numpyro.sample('x', dist.Normal(x_prev, 1))
...         y_curr = numpyro.sample('y', dist.Normal(x_curr, 1), obs=y_curr)
...         return x_curr, (x_curr, y_curr)
...
...     x0 = numpyro.sample('x_0', dist.Normal(0, 1))
...     _, (x, y) = scan(transition, x0, y, length=T)
...     return (x, y)
>>>
>>> # here we do some quick tests
>>> with numpyro.handlers.seed(rng_seed=0):
...     x, y = gaussian_hmm(np.arange(10.))
>>> assert x.shape == (10,) and y.shape == (10,)
>>> assert np.all(y == np.arange(10))
>>>
>>> with numpyro.handlers.seed(rng_seed=0):  # generative
...     x, y = gaussian_hmm()
>>> assert x.shape == (10,) and y.shape == (10,)
Warning
This is an experimental utility function that allows users to use JAX control flow with NumPyro’s effect handlers. Currently, sample and deterministic sites within the scan body f are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.
Note
It is ambiguous to align the scan dimension inside a plate context, so the following pattern is not supported:
with numpyro.plate('N', 10):
    last, ys = scan(f, init, xs)
All plate statements should be put inside f. For example, the corresponding working code is:
def g(*args, **kwargs):
    with numpyro.plate('N', 10):
        return f(*args, **kwargs)

last, ys = scan(g, init, xs)
Note
Nested scan is currently not supported.
Note
We can scan over discrete latent variables in f. The joint density is evaluated using parallel-scan (reference [1]) over the time dimension, which reduces parallel complexity to O(log(length)).
A trace of scan with discrete latent variables will contain the following sites:
init sites: the sites that belong to the first history traces of f. Sites at the i-th trace will have names prefixed with ‘_PREV_’ * (2 * history - 1 - i).
scanned sites: the sites that collect the values of the remaining scan loop over f. An additional time dimension _time_foo will be added to those sites, where foo is the name of the first site that appears in f.
Not all transition functions f are supported. All of the restrictions from Pyro’s enumeration tutorial [2] still apply here. In addition, no site outside of scan should depend on the first output of scan (the last carry value).
References
[1] Temporal Parallelization of Bayesian Smoothers, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)
[2] Inference with Discrete Latent Variables (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)
- Parameters
f (callable) – a function to be scanned.
init – the initial carrying state
xs – the values over which we scan along the leading axis. This can be any JAX pytree (e.g. list/dict of arrays).
length – optional value specifying the length of xs. It can also be supplied when xs is an empty pytree (e.g. None), where the length cannot be inferred from xs.
reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse
history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to numpyro.plate.
- Returns
output of scan, quoted from
jax.lax.scan()
docs: “pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs”.
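To make the "(c, [b])" return pair concrete, here is a minimal sketch that calls jax.lax.scan directly rather than the NumPyro primitive. The function name running_sum and the input values are illustrative assumptions; the point is only that the first output is the final carry and the second is the per-step outputs stacked along the leading axis.

```python
import jax.numpy as jnp
from jax import lax


def running_sum(carry, x):
    # f maps (carry, x) -> (new_carry, per_step_output)
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
init = jnp.zeros(())       # scalar carry with the same dtype as xs

# final_carry is the last carry value; partial_sums stacks the second
# output of running_sum over the leading axis of xs.
final_carry, partial_sums = lax.scan(running_sum, init, xs)
print(final_carry, partial_sums)
```

The numpyro.contrib.control_flow.scan primitive follows the same calling convention, with sample and deterministic sites allowed inside f.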
cond¶
- cond(pred, true_fun, false_fun, operand)[source]¶
This primitive conditionally applies true_fun or false_fun. See jax.lax.cond() for more information.
Usage:
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from jax import random
>>> from numpyro.contrib.control_flow import cond
>>> from numpyro.infer import SVI, Trace_ELBO
>>>
>>> def model():
...     def true_fun(_):
...         return numpyro.sample("x", dist.Normal(20.0))
...
...     def false_fun(_):
...         return numpyro.sample("x", dist.Normal(0.0))
...
...     cluster = numpyro.sample("cluster", dist.Normal())
...     return cond(cluster > 0, true_fun, false_fun, None)
>>>
>>> def guide():
...     m1 = numpyro.param("m1", 10.0)
...     s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
...     m2 = numpyro.param("m2", 10.0)
...     s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)
...
...     def true_fun(_):
...         return numpyro.sample("x", dist.Normal(m1, s1))
...
...     def false_fun(_):
...         return numpyro.sample("x", dist.Normal(m2, s2))
...
...     cluster = numpyro.sample("cluster", dist.Normal())
...     return cond(cluster > 0, true_fun, false_fun, None)
>>>
>>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
>>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
Warning
This is an experimental utility function that allows users to use JAX control flow with NumPyro’s effect handlers. Currently, sample and deterministic sites within true_fun and false_fun are supported. If you notice that any effect handlers or distributions are unsupported, please file an issue.
Warning
The cond primitive does not currently support enumeration and cannot be used inside a numpyro.plate context.
Note
All sample sites must belong to the same distribution class. For example, the following is not supported:
cond(
    True,
    lambda _: numpyro.sample("x", dist.Normal()),
    lambda _: numpyro.sample("x", dist.Laplace()),
    None,
)
- Parameters
pred (bool) – Boolean scalar type indicating which branch function to apply
true_fun (callable) – A function to be applied if pred is true.
false_fun (callable) – A function to be applied if pred is false.
operand – Operand input to either branch depending on pred. This can be any JAX PyTree (e.g. list / dict of arrays).
- Returns
Output of the applied branch function.
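The branch selection described above can be sketched with jax.lax.cond on its own, without any sample sites. The helper name double_or_negate is a hypothetical choice for illustration; both branches receive the same operand and must return outputs of matching shape and dtype.

```python
import jax.numpy as jnp
from jax import lax


def double_or_negate(x):
    # pred is a boolean scalar; the operand x is passed to whichever
    # branch is selected.
    return lax.cond(x > 0, lambda v: 2.0 * v, lambda v: -v, x)


pos = double_or_negate(jnp.array(3.0))   # true_fun is applied
neg = double_or_negate(jnp.array(-3.0))  # false_fun is applied
print(pos, neg)
```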
Distributions¶
Base Distribution¶
Distribution¶
- class Distribution(*args, **kwargs)[source]¶
Bases:
object
Base class for probability distributions in NumPyro. The design largely follows from torch.distributions.
- Parameters
batch_shape – The batch shape for the distribution. This designates independent (possibly non-identical) dimensions of a sample from the distribution. This is fixed for a distribution instance and is inferred from the shape of the distribution parameters.
event_shape – The event shape for the distribution. This designates the dependent dimensions of a sample from the distribution. These are collapsed when we evaluate the log probability density of a batch of samples using .log_prob.
validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
As an example:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> d = dist.Dirichlet(jnp.ones((2, 3, 4)))
>>> d.batch_shape
(2, 3)
>>> d.event_shape
(4,)
- arg_constraints = {}¶
- support = None¶
- has_enumerate_support = False¶
- reparametrized_params = []¶
- property batch_shape¶
Returns the shape over which the distribution parameters are batched.
- Returns
batch shape of the distribution.
- Return type
- property event_shape¶
Returns the shape of a single sample from the distribution without batching.
- Returns
event shape of the distribution.
- Return type
- property has_rsample¶
- shape(sample_shape=())[source]¶
The tensor shape of samples from this distribution.
Samples are of shape:
d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- sample_with_intermediates(key, sample_shape=())[source]¶
Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(value)[source]¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- to_event(reinterpreted_batch_ndims=None)[source]¶
Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
- Parameters
reinterpreted_batch_ndims – Number of rightmost batch dims to interpret as event dims.
- Returns
An instance of Independent distribution.
- Return type
- enumerate_support(expand=True)[source]¶
Returns an array with shape len(support) x batch_shape containing all values in the support.
- expand(batch_shape)[source]¶
Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
- Parameters
batch_shape (tuple) – batch shape to expand to.
- Returns
an instance of ExpandedDistribution.
- Return type
- expand_by(sample_shape)[source]¶
Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.
- Parameters
sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.
- Returns
An expanded version of this distribution.
- Return type
- mask(mask)[source]¶
Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution’s Distribution.batch_shape.
- Parameters
mask (bool or jnp.ndarray) – A boolean or boolean valued array (True includes a site, False excludes a site).
- Returns
A masked copy of this distribution.
- Return type
Example:
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO
>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)
>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
- classmethod infer_shapes(*args, **kwargs)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
- icdf(q)[source]¶
The inverse cumulative distribution function of this distribution.
- Parameters
q – quantile values, should belong to [0, 1].
- Returns
the samples whose cdf values equal q.
- property is_discrete¶
ExpandedDistribution¶
- class ExpandedDistribution(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {}¶
- property has_enumerate_support¶
- property has_rsample¶
- property support¶
- sample_with_intermediates(key, sample_shape=())[source]¶
Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(value)[source]¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- enumerate_support(expand=True)[source]¶
Returns an array with shape len(support) x batch_shape containing all values in the support.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
FoldedDistribution¶
- class FoldedDistribution(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.TransformedDistribution
Equivalent to TransformedDistribution(base_dist, AbsTransform()), but additionally supports log_prob().
- Parameters
base_dist (Distribution) – A univariate distribution to reflect.
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
ImproperUniform¶
- class ImproperUniform(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
A helper distribution with zero log_prob() over the support domain.
Note
sample method is not implemented for this distribution. In autoguide and mcmc, initial parameters for improper sites are derived from init_to_uniform or init_to_value strategies.
Usage:
>>> from numpyro import sample
>>> from numpyro.distributions import ImproperUniform, Normal, constraints
>>>
>>> def model():
...     # ordered vector with length 10
...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
...
...     # real matrix with shape (3, 4)
...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))
If you want to set an improper prior over all values greater than a, where a is another random variable, you might use:
>>> def model():
...     a = sample('a', Normal(0, 1))
...     x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))
Or, if you want to reparameterize it:
>>> from numpyro.distributions import TransformedDistribution, transforms
>>> from numpyro.handlers import reparam
>>> from numpyro.infer.reparam import TransformReparam
>>>
>>> def model():
...     a = sample('a', Normal(0, 1))
...     with reparam(config={'x': TransformReparam()}):
...         x = sample('x',
...                    TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
...                                            transforms.AffineTransform(a, 1)))
- Parameters
support (Constraint) – the support of this distribution.
batch_shape (tuple) – batch shape of this distribution. It is usually safe to set batch_shape=().
event_shape (tuple) – event shape of this distribution.
- arg_constraints = {}¶
- support = <numpyro.distributions.constraints._Dependent object>¶
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
Independent¶
- class Independent(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Reinterprets batch dimensions of a distribution as event dims by shifting the batch-event dim boundary further to the left.
From a practical standpoint, this is useful when changing the result of log_prob(). For example, a univariate Normal distribution can be interpreted as a multivariate Normal with diagonal covariance:
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> normal = dist.Normal(jnp.zeros(3), jnp.ones(3))
>>> [normal.batch_shape, normal.event_shape]
[(3,), ()]
>>> diag_normal = dist.Independent(normal, 1)
>>> [diag_normal.batch_shape, diag_normal.event_shape]
[(), (3,)]
- Parameters
base_distribution (numpyro.distributions.Distribution) – a distribution instance.
reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims.
- arg_constraints = {}¶
- property support¶
- property has_enumerate_support¶
- property reparameterized_params¶
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property has_rsample¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(value)[source]¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- expand(batch_shape)[source]¶
Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
- Parameters
batch_shape (tuple) – batch shape to expand to.
- Returns
an instance of ExpandedDistribution.
- Return type
MaskedDistribution¶
- class MaskedDistribution(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Masks a distribution by a boolean array that is broadcastable to the distribution’s Distribution.batch_shape. In the special case mask is False, computation of log_prob() is skipped, and constant zero values are returned instead.
- Parameters
mask (jnp.ndarray or bool) – A boolean or boolean-valued array.
- arg_constraints = {}¶
- property has_enumerate_support¶
- property has_rsample¶
- property support¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(value)[source]¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- enumerate_support(expand=True)[source]¶
Returns an array with shape len(support) x batch_shape containing all values in the support.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
TransformedDistribution¶
- class TransformedDistribution(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Returns a distribution instance obtained as a result of applying a sequence of transforms to a base distribution. For an example, see LogNormal and HalfNormal.
- Parameters
base_distribution – the base distribution over which to apply transforms.
transforms – a single transform or a list of transforms.
validate_args – Whether to enable validation of distribution parameters and arguments to .log_prob method.
- arg_constraints = {}¶
- property has_rsample¶
- property support¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- sample_with_intermediates(key, sample_shape=())[source]¶
Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
Delta¶
- class Delta(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'log_density': <numpyro.distributions.constraints._Real object>, 'v': <numpyro.distributions.constraints._Dependent object>}¶
- reparametrized_params = ['v', 'log_density']¶
- property support¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
Unit¶
- class Unit(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Trivial nonnormalized distribution representing the unit type.
The unit type has a single value with no data, i.e. value.size == 0.
This is used for numpyro.factor() statements.
- arg_constraints = {'log_factor': <numpyro.distributions.constraints._Real object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
Continuous Distributions¶
Beta¶
- class Beta(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['concentration1', 'concentration0']¶
- support = <numpyro.distributions.constraints._Interval object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
BetaProportion¶
- class BetaProportion(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.continuous.Beta
The BetaProportion distribution is a reparameterization of the conventional Beta distribution in terms of the variate mean and a precision parameter.
- Reference:
Beta regression for modelling rates and proportions, Silvia Ferrari and Francisco Cribari-Neto. Journal of Applied Statistics 31.7 (2004): 799-815.
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'mean': <numpyro.distributions.constraints._Interval object>}¶
- reparametrized_params = ['mean', 'concentration']¶
- support = <numpyro.distributions.constraints._Interval object>¶
Cauchy¶
- class Cauchy(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- reparametrized_params = ['loc', 'scale']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
Chi2¶
- class Chi2(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.continuous.Gamma
- arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['df']¶
Dirichlet¶
- class Dirichlet(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>}¶
- reparametrized_params = ['concentration']¶
- support = <numpyro.distributions.constraints._Simplex object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- static infer_shapes(concentration)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
Exponential¶
- class Exponential(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- reparametrized_params = ['rate']¶
- arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
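The shape rule stated for sample() and log_prob() can be illustrated with a small sketch. It uses plain NumPy as a stand-in rather than NumPyro's own sampler, and the rate values are hypothetical; it only shows how sample_shape, batch_shape and event_shape combine for a batched exponential.

```python
import numpy as np

# Hypothetical batched Exponential: three rates give batch_shape (3,).
rate = np.array([0.5, 1.0, 2.0])
batch_shape = rate.shape   # (3,)
event_shape = ()           # each draw is a scalar event
sample_shape = (5,)

rng = np.random.default_rng(0)
# NumPy parameterizes by scale = 1 / rate; the rates broadcast along the last axis.
draws = rng.exponential(scale=1.0 / rate, size=sample_shape + batch_shape + event_shape)

# Log density of Exponential(rate): log(rate) - rate * x, one value per draw.
log_prob = np.log(rate) - rate * draws
```

Because event_shape is empty here, the log density has the same shape as the draws, sample_shape + batch_shape.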
Gamma¶
- class Gamma(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- reparametrized_params = ['concentration', 'rate']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
Gumbel¶
- class Gumbel(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- reparametrized_params = ['loc', 'scale']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
GaussianRandomWalk¶
- class GaussianRandomWalk(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._IndependentConstraint object>¶
- reparametrized_params = ['scale']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
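A minimal NumPy sketch of what one draw from a Gaussian random walk looks like, assuming the walk starts at zero and accumulates iid Normal(0, scale) increments. The values of scale and num_steps are hypothetical, and this is not NumPyro's implementation.

```python
import numpy as np

scale, num_steps = 0.5, 100   # hypothetical values
rng = np.random.default_rng(0)

# iid Normal(0, scale) increments.
steps = scale * rng.standard_normal(num_steps)
# Cumulative sum of the increments: a single event of shape (num_steps,).
walk = np.cumsum(steps)
```

The whole path is one event, so its log density is a single number rather than one value per step.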
HalfCauchy¶
- class HalfCauchy(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- reparametrized_params = ['scale']¶
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
- icdf(q)[source]¶
The inverse cumulative distribution function of this distribution.
- Parameters
q – quantile values, should belong to [0, 1].
- Returns
the samples whose cdf values equal q.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
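For reference, the half-Cauchy cdf and its inverse have simple closed forms. The sketch below writes them with the standard library only; the function names are hypothetical helpers, not part of NumPyro.

```python
import math

def half_cauchy_cdf(value, scale=1.0):
    # F(x) = (2 / pi) * atan(x / scale) for x >= 0.
    return 2.0 / math.pi * math.atan(value / scale)

def half_cauchy_icdf(q, scale=1.0):
    # Inverse of the cdf above: x = scale * tan(pi * q / 2).
    return scale * math.tan(math.pi * q / 2.0)

# The median (q = 0.5) coincides with the scale parameter.
median = half_cauchy_icdf(0.5, scale=2.0)
round_trip = half_cauchy_cdf(half_cauchy_icdf(0.3, scale=2.0), scale=2.0)
```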
HalfNormal¶
- class HalfNormal(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- reparametrized_params = ['scale']¶
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- arg_constraints = {'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
- icdf(q)[source]¶
The inverse cumulative distribution function of this distribution.
- Parameters
q – quantile values, should belong to [0, 1].
- Returns
the samples whose cdf values equal q.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
InverseGamma¶
- class InverseGamma(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.TransformedDistribution
Note
We keep the same notation rate as in Pyro, but it plays the role of the scale parameter of InverseGamma in the literature (e.g. Wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution).
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['concentration', 'rate']¶
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
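The note about rate acting as a scale can be checked numerically. The following NumPy sketch (hypothetical parameter values, not NumPyro code) inverts Gamma(concentration, rate) draws and compares the empirical mean with the textbook inverse-gamma mean, scale / (concentration - 1), using rate in the place of scale.

```python
import numpy as np

concentration, rate = 5.0, 2.0   # hypothetical values
rng = np.random.default_rng(0)

# Gamma draws with the given concentration and rate (NumPy takes scale = 1 / rate).
gamma_draws = rng.gamma(shape=concentration, scale=1.0 / rate, size=200_000)
# The reciprocal of a Gamma(concentration, rate) variable is inverse-gamma distributed.
inv_gamma_draws = 1.0 / gamma_draws

# For concentration > 1 the inverse-gamma mean is rate / (concentration - 1).
expected_mean = rate / (concentration - 1.0)
empirical_mean = inv_gamma_draws.mean()
```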
Laplace¶
- class Laplace(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- reparametrized_params = ['loc', 'scale']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
LKJ¶
- class LKJ(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.TransformedDistribution
LKJ distribution for correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over correlation matrices.
When concentration > 1, the distribution favors samples with a large determinant. This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with a small determinant. This is useful when we know a priori that some underlying variables are correlated.
Sample code for using LKJ in the context of multivariate normal sample:
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def model(y):  # y has dimension N x d
        d = y.shape[1]
        N = y.shape[0]
        # Vector of variances for each of the d variables
        theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))

        concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
        corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
        sigma = jnp.sqrt(theta)
        # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat`
        cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))

        # Vector of expectations
        mu = jnp.zeros(d)

        with numpyro.plate("observations", N):
            obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
        return obs
- Parameters
dimension (int) – dimension of the matrices
concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['concentration']¶
- support = <numpyro.distributions.constraints._CorrMatrix object>¶
- property mean¶
Mean of the distribution.
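The effect of concentration on the LKJ density can be seen from the unnormalized weight \(\det(M)^{\eta - 1}\) alone. The NumPy sketch below (hypothetical helper names and values, not NumPyro code) compares a weakly and a strongly correlated 2 x 2 matrix, and also checks that scaling a correlation matrix by diag(sigma) on both sides equals an elementwise product with outer(sigma, sigma).

```python
import numpy as np

def lkj_unnormalized_density(corr_mat, concentration):
    # Unnormalized LKJ weight: det(M) ** (eta - 1).
    return np.linalg.det(corr_mat) ** (concentration - 1.0)

def corr2(r):
    # 2 x 2 correlation matrix with off-diagonal correlation r; det = 1 - r**2.
    return np.array([[1.0, r], [r, 1.0]])

weak, strong = corr2(0.1), corr2(0.9)

# Ratio of weights (weak correlation vs. strong correlation) for several eta values.
ratios = {
    eta: lkj_unnormalized_density(weak, eta) / lkj_unnormalized_density(strong, eta)
    for eta in (0.5, 1.0, 2.0)
}

# Two equivalent ways to turn a correlation matrix into a covariance matrix.
sigma = np.array([0.5, 2.0])
cov_matmul = np.diag(sigma) @ strong @ np.diag(sigma)
cov_outer = np.outer(sigma, sigma) * strong
```

The ratio is below 1 for eta < 1 (strong correlation favored), exactly 1 for eta == 1 (uniform), and above 1 for eta > 1 (weak correlation favored).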
LKJCholesky¶
- class LKJCholesky(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
LKJ distribution for lower Cholesky factors of correlation matrices. The distribution is controlled by the concentration parameter \(\eta\), which makes the probability of the correlation matrix \(M\) generated from a Cholesky factor proportional to \(\det(M)^{\eta - 1}\). Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.
When concentration > 1, the distribution favors samples with large diagonal entries (hence a large determinant). This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small diagonal entries (hence a small determinant). This is useful when we know a priori that some underlying variables are correlated.
Sample code for using LKJCholesky in the context of multivariate normal sample:
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def model(y):  # y has dimension N x d
        d = y.shape[1]
        N = y.shape[0]
        # Vector of variances for each of the d variables
        theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
        # Lower cholesky factor of a correlation matrix
        concentration = jnp.ones(1)  # Implies a uniform distribution over correlation matrices
        L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
        # Lower cholesky factor of the covariance matrix
        sigma = jnp.sqrt(theta)
        # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
        L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)

        # Vector of expectations
        mu = jnp.zeros(d)

        with numpyro.plate("observations", N):
            obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
        return obs
- Parameters
dimension (int) – dimension of the matrices
concentration (ndarray) – concentration/shape parameter of the distribution (often referred to as eta)
sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1] and offer the same distribution over correlation matrices, but they differ in how samples are generated. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski, Dorota Kurowicka, Harry Joe
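The link between diagonal entries of the Cholesky factor and the determinant can be verified directly. This NumPy sketch uses a hypothetical 3 x 3 correlation matrix and standard-deviation vector; it is an illustration, not NumPyro code.

```python
import numpy as np

# Hypothetical positive-definite correlation matrix.
corr_mat = np.array([[1.0, 0.3, -0.2],
                     [0.3, 1.0, 0.5],
                     [-0.2, 0.5, 1.0]])
L_omega = np.linalg.cholesky(corr_mat)   # lower-triangular factor

# det(M) equals the squared product of the diagonal of its Cholesky factor.
det_from_diag = np.prod(np.diag(L_omega)) ** 2
det_direct = np.linalg.det(corr_mat)

# Each row of the factor has unit norm because diag(M) is all ones.
row_norms = np.linalg.norm(L_omega, axis=1)

# Scaling rows by sigma is the same as left-multiplying by diag(sigma).
sigma = np.array([0.5, 1.0, 2.0])
scaled_rows = sigma[..., None] * L_omega
scaled_matmul = np.diag(sigma) @ L_omega
```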
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['concentration']¶
- support = <numpyro.distributions.constraints._CorrCholesky object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
LogNormal¶
- class LogNormal(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.TransformedDistribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- reparametrized_params = ['loc', 'scale']¶
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
Logistic¶
- class Logistic(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- reparametrized_params = ['loc', 'scale']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
MultivariateNormal¶
- class MultivariateNormal(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'covariance_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'loc': <numpyro.distributions.constraints._IndependentConstraint object>, 'precision_matrix': <numpyro.distributions.constraints._PositiveDefinite object>, 'scale_tril': <numpyro.distributions.constraints._LowerCholesky object>}¶
- support = <numpyro.distributions.constraints._IndependentConstraint object>¶
- reparametrized_params = ['loc', 'covariance_matrix', 'precision_matrix', 'scale_tril']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- static infer_shapes(loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
tuple
LowRankMultivariateNormal¶
- class LowRankMultivariateNormal(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'cov_diag': <numpyro.distributions.constraints._IndependentConstraint object>, 'cov_factor': <numpyro.distributions.constraints._IndependentConstraint object>, 'loc': <numpyro.distributions.constraints._IndependentConstraint object>}¶
- support = <numpyro.distributions.constraints._IndependentConstraint object>¶
- reparametrized_params = ['loc', 'cov_factor', 'cov_diag']¶
- property mean¶
Mean of the distribution.
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- static infer_shapes(loc, cov_factor, cov_diag)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
tuple
Normal¶
- class Normal(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- reparametrized_params = ['loc', 'scale']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
- icdf(q)[source]¶
The inverse cumulative distribution function of this distribution.
- Parameters
q – quantile values, should belong to [0, 1].
- Returns
the samples whose cdf values equal q.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
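The relationship between cdf and icdf can be illustrated with the standard library's NormalDist, used here as a stand-in for the NumPyro class; the parameter values are hypothetical.

```python
from statistics import NormalDist

# Stand-in for a Normal distribution with loc=1.0 and scale=2.0.
normal = NormalDist(mu=1.0, sigma=2.0)

q = normal.cdf(2.5)        # probability mass to the left of 2.5
x = normal.inv_cdf(q)      # inverse cdf recovers the original point
half = normal.cdf(1.0)     # the cdf at the location parameter is 0.5
```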
Pareto¶
- class Pareto(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.TransformedDistribution
- arg_constraints = {'alpha': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['scale', 'alpha']¶
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
SoftLaplace¶
- class SoftLaplace(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Smooth distribution with Laplace-like tail behavior.
This distribution corresponds to the log-concave density:
    z = (value - loc) / scale
    log_prob = log(2 / pi) - log(scale) - logaddexp(z, -z)
Like the Laplace density, this density has the heaviest possible tails (asymptotically) while still being log-concave. Unlike the Laplace distribution, this distribution is infinitely differentiable everywhere, and is thus suitable for HMC and Laplace approximation.
- Parameters
loc – Location parameter.
scale – Scale parameter.
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- reparametrized_params = ['loc', 'scale']¶
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
- icdf(value)[source]¶
The inverse cumulative distribution function of this distribution.
- Parameters
value – quantile values, should belong to [0, 1].
- Returns
the samples whose cdf values equal the given quantiles.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
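The SoftLaplace density formula given in the class description can be transcribed into NumPy and sanity-checked. The helper name, grid and parameter values below are hypothetical, and this is a sketch rather than NumPyro's implementation.

```python
import numpy as np

def soft_laplace_log_prob(value, loc=0.0, scale=1.0):
    # Direct transcription of the documented formula.
    z = (value - loc) / scale
    return np.log(2 / np.pi) - np.log(scale) - np.logaddexp(z, -z)

# Riemann-sum check that the density integrates to approximately one.
grid = np.linspace(-40.0, 40.0, 80_001)
density = np.exp(soft_laplace_log_prob(grid, loc=1.0, scale=2.0))
total_mass = np.sum(density) * (grid[1] - grid[0])

# At the mode (z = 0) the log density reduces to -log(pi * scale).
peak = soft_laplace_log_prob(0.0)
```

For large |z| the term logaddexp(z, -z) approaches |z|, which is where the Laplace-like tail behavior comes from.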
StudentT¶
- class StudentT(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'df': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._Real object>¶
- reparametrized_params = ['df', 'loc', 'scale']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
Uniform¶
- class Uniform(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}¶
- reparametrized_params = ['low', 'high']¶
- property support¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
- icdf(value)[source]¶
The inverse cumulative distribution function of this distribution.
- Parameters
value – quantile values, should belong to [0, 1].
- Returns
the samples whose cdf values equal the given quantiles.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- static infer_shapes(low=(), high=())[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
tuple
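As a rough illustration of shape inference for a distribution with scalar events, the sketch below assumes the batch shape is the broadcast of the input shapes and the event shape is empty. The function infer_shapes_sketch is a hypothetical NumPy stand-in, not the NumPyro method.

```python
import numpy as np

def infer_shapes_sketch(low=(), high=()):
    # Assumption: batch_shape is the broadcast of the two parameter shapes.
    batch_shape = np.broadcast_shapes(low, high)
    # Assumption: each draw is a scalar, so event_shape is empty.
    event_shape = ()
    return batch_shape, event_shape

# Shapes are passed as tuples; no arrays are ever allocated.
shapes = infer_shapes_sketch(low=(3, 1), high=(4,))
scalar_shapes = infer_shapes_sketch()
```

Working on shapes alone is what lets the result be computed without constructing the distribution.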
Weibull¶
- class Weibull(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._GreaterThan object>¶
- reparametrized_params = ['scale', 'concentration']¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- cdf(value)[source]¶
The cumulative distribution function of this distribution.
- Parameters
value – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at value.
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
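The Weibull cdf has the closed form 1 - exp(-(x / scale) ** concentration). The sketch below writes it with the standard library and checks it at the median; the helper name and parameter values are hypothetical.

```python
import math

def weibull_cdf(value, scale, concentration):
    # F(x) = 1 - exp(-(x / scale) ** concentration) for x >= 0.
    return 1.0 - math.exp(-((value / scale) ** concentration))

scale, concentration = 2.0, 1.5   # hypothetical values
# Solving F(x) = 0.5 gives the median: scale * ln(2) ** (1 / concentration).
median = scale * math.log(2.0) ** (1.0 / concentration)
cdf_at_median = weibull_cdf(median, scale, concentration)
```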
Discrete Distributions¶
BernoulliLogits¶
- class BernoulliLogits(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}¶
- support = <numpyro.distributions.constraints._Boolean object>¶
- has_enumerate_support = True¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
BernoulliProbs¶
- class BernoulliProbs(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}¶
- support = <numpyro.distributions.constraints._Boolean object>¶
- has_enumerate_support = True¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
BetaBinomial¶
- class BetaBinomial(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a beta-binomial pair. The probability of success (probs for the Binomial distribution) is unknown and randomly drawn from a Beta distribution prior to a certain number of Bernoulli trials given by total_count.
- Parameters
concentration1 (numpy.ndarray) – 1st concentration parameter (alpha) for the Beta distribution.
concentration0 (numpy.ndarray) – 2nd concentration parameter (beta) for the Beta distribution.
total_count (numpy.ndarray) – number of Bernoulli trials.
- arg_constraints = {'concentration0': <numpyro.distributions.constraints._GreaterThan object>, 'concentration1': <numpyro.distributions.constraints._GreaterThan object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
- has_enumerate_support = True¶
- enumerate_support(expand=True)¶
Returns an array with shape len(support) x batch_shape containing all values in the support.
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
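A minimal sketch of the beta-binomial probability mass function, written in plain Python rather than with NumPyro. It assumes the standard closed form log C(n, k) + log B(k + alpha, n - k + beta) - log B(alpha, beta), with concentration1 as alpha and concentration0 as beta as in the parameter list above. The function names are hypothetical.

```python
import math


def log_beta(a, b):
    """Log of the Beta function B(a, b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def beta_binomial_log_pmf(k, concentration1, concentration0, total_count):
    """Log pmf of k successes out of total_count trials."""
    log_comb = (
        math.lgamma(total_count + 1)
        - math.lgamma(k + 1)
        - math.lgamma(total_count - k + 1)
    )
    return (
        log_comb
        + log_beta(k + concentration1, total_count - k + concentration0)
        - log_beta(concentration1, concentration0)
    )


# Enumerate the whole support {0, ..., total_count} for a small example.
pmf = [math.exp(beta_binomial_log_pmf(k, 2.0, 3.0, 10)) for k in range(11)]
mean = sum(k * p for k, p in enumerate(pmf))
```

For these values the enumerated mean agrees with the textbook formula total_count * alpha / (alpha + beta).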
BinomialLogits¶
- class BinomialLogits(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
- has_enumerate_support = True¶
- enumerate_support(expand=True)¶
Returns an array with shape len(support) x batch_shape containing all values in the support.
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
BinomialProbs¶
- class BinomialProbs(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
- has_enumerate_support = True¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
CategoricalLogits¶
- class CategoricalLogits(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'logits': <numpyro.distributions.constraints._IndependentConstraint object>}¶
- has_enumerate_support = True¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
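For intuition, a categorical log probability can be computed from unnormalized logits with a log-softmax. The sketch below is plain Python with hypothetical function names and is not NumPyro's implementation.

```python
import math


def log_softmax(logits):
    """Normalize logits into log probabilities in a numerically stable way."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_norm for l in logits]


def categorical_logits_log_prob(value, logits):
    """Log probability of the category index value."""
    return log_softmax(logits)[value]


log_probs = log_softmax([1.0, 2.0, 3.0])
```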
CategoricalProbs¶
- class CategoricalProbs(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>}¶
- has_enumerate_support = True¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
DirichletMultinomial¶
- class DirichletMultinomial(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a dirichlet-multinomial pair. The probability of classes (probs for the Multinomial distribution) is unknown and randomly drawn from a Dirichlet distribution prior to a certain number of Categorical trials given by total_count.
- Parameters
concentration (numpy.ndarray) – concentration parameter (alpha) for the Dirichlet distribution.
total_count (numpy.ndarray) – number of Categorical trials.
- arg_constraints = {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
- static infer_shapes(concentration, total_count=())[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
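The following plain-Python sketch illustrates the kind of shape arithmetic infer_shapes performs. It rests on an assumption not stated in the text above: that the last dimension of concentration is the event dimension and the remaining dimensions broadcast with total_count to give the batch shape. Both helper functions are hypothetical.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples, aligned from the right."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


def dirichlet_multinomial_shapes(concentration, total_count=()):
    """Return (batch_shape, event_shape) from the shapes of the inputs."""
    batch_shape = broadcast_shapes(concentration[:-1], total_count)
    event_shape = concentration[-1:]
    return batch_shape, event_shape


print(dirichlet_multinomial_shapes((5, 3), (5,)))  # ((5,), (3,))
```

Only shapes are passed in, never arrays, which matches the note that the result depends on input shapes and not on the data.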
GammaPoisson¶
- class GammaPoisson(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Compound distribution comprising a gamma-poisson pair, also referred to as a gamma-poisson mixture. The rate parameter for the Poisson distribution is unknown and randomly drawn from a Gamma distribution.
- Parameters
concentration (numpy.ndarray) – shape parameter (alpha) of the Gamma distribution.
rate (numpy.ndarray) – rate parameter (beta) for the Gamma distribution.
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
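Integrating the Poisson rate against a Gamma(concentration, rate) prior gives a closed-form pmf. The sketch below evaluates it in plain Python; it assumes the standard negative-binomial form of the marginal and uses a hypothetical function name, so treat it as a cross-check rather than the library's code.

```python
import math


def gamma_poisson_log_pmf(k, concentration, rate):
    """Log pmf of a count k after marginalizing the Poisson rate over a Gamma prior."""
    return (
        math.lgamma(concentration + k)
        - math.lgamma(concentration)
        - math.lgamma(k + 1)
        + concentration * math.log(rate / (1.0 + rate))
        - k * math.log1p(rate)
    )


# Truncate the infinite support; the tail is negligible for these values.
pmf = [math.exp(gamma_poisson_log_pmf(k, 3.0, 2.0)) for k in range(200)]
mean = sum(k * p for k, p in enumerate(pmf))
```

With concentration 3 and rate 2 the truncated mean is close to concentration / rate = 1.5, the mean of the Gamma prior.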
GeometricLogits¶
- class GeometricLogits(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
GeometricProbs¶
- class GeometricProbs(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
MultinomialLogits¶
- class MultinomialLogits(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'logits': <numpyro.distributions.constraints._IndependentConstraint object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
- static infer_shapes(logits, total_count)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
MultinomialProbs¶
- class MultinomialProbs(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'probs': <numpyro.distributions.constraints._Simplex object>, 'total_count': <numpyro.distributions.constraints._IntegerGreaterThan object>}¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- property support¶
- static infer_shapes(probs, total_count)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
OrderedLogistic¶
- class OrderedLogistic(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.discrete.CategoricalProbs
A categorical distribution with ordered outcomes.
References:
Stan Functions Reference, v2.20 section 12.6, Stan Development Team
- Parameters
predictor (numpy.ndarray) – prediction in real domain; typically this is output of a linear model.
cutpoints (numpy.ndarray) – positions in real domain to separate categories.
- arg_constraints = {'cutpoints': <numpyro.distributions.constraints._OrderedVector object>, 'predictor': <numpyro.distributions.constraints._Real object>}¶
- static infer_shapes(predictor, cutpoints)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
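To make the roles of predictor and cutpoints concrete, here is a plain-Python sketch of how ordered cutpoints can be turned into category probabilities. It assumes the usual ordered-logistic construction, where the cumulative probability of falling at or below cutpoint c is sigmoid(c - predictor). The function names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def ordered_logistic_probs(predictor, cutpoints):
    """Category probabilities for len(cutpoints) + 1 ordered outcomes."""
    cumulative = [sigmoid(c - predictor) for c in cutpoints]
    bounds = [0.0] + cumulative + [1.0]
    # Each category gets the probability mass between adjacent cutpoints.
    return [upper - lower for lower, upper in zip(bounds[:-1], bounds[1:])]


probs = ordered_logistic_probs(0.5, [-1.0, 0.0, 2.0])
```

Because the cutpoints are increasing, every difference is positive, which is why the cutpoints argument carries an ordered-vector constraint.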
NegativeBinomial¶
NegativeBinomialLogits¶
- class NegativeBinomialLogits(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.conjugate.GammaPoisson
- arg_constraints = {'logits': <numpyro.distributions.constraints._Real object>, 'total_count': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
NegativeBinomialProbs¶
- class NegativeBinomialProbs(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.conjugate.GammaPoisson
- arg_constraints = {'probs': <numpyro.distributions.constraints._Interval object>, 'total_count': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
NegativeBinomial2¶
- class NegativeBinomial2(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.conjugate.GammaPoisson
Another parameterization of GammaPoisson in which rate is replaced by mean.
- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'mean': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
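One way to read this reparameterization: since a GammaPoisson has mean concentration / rate, fixing the mean determines rate = concentration / mean. The sketch below checks that arithmetic in plain Python; the helper names are hypothetical and the conversion is an assumption about how the two parameterizations relate, not a quote of the library source.

```python
def negative_binomial2_to_gamma_poisson(mean, concentration):
    """Map (mean, concentration) to the (concentration, rate) pair of GammaPoisson."""
    rate = concentration / mean
    return concentration, rate


def gamma_poisson_moments(concentration, rate):
    """Mean and variance of a gamma-poisson compound distribution."""
    mean = concentration / rate
    variance = concentration / rate**2 * (1.0 + rate)
    return mean, variance


concentration, rate = negative_binomial2_to_gamma_poisson(mean=4.0, concentration=2.0)
moments = gamma_poisson_moments(concentration, rate)
print(rate, moments)  # 0.5 (4.0, 12.0)
```

The variance equals mean + mean**2 / concentration, so smaller concentration values mean more overdispersion relative to a Poisson.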
Poisson¶
- class Poisson(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Creates a Poisson distribution parameterized by rate, the rate parameter.
Samples are nonnegative integers, with a pmf given by
\[\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}\]
- Parameters
rate (numpy.ndarray) – The rate parameter
is_sparse (bool) – Whether to assume value is mostly zero when computing log_prob(), which can speed up computation when data is sparse.
- arg_constraints = {'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
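The pmf given above translates directly into a log-space formula, log pmf(k) = k * log(rate) - rate - log(k!). A minimal plain-Python sketch, with a hypothetical function name and no dependency on NumPyro:

```python
import math


def poisson_log_pmf(k, rate):
    """Log of rate**k * exp(-rate) / k!, using lgamma for log(k!)."""
    return k * math.log(rate) - rate - math.lgamma(k + 1)


# Truncate the infinite support; the tail beyond 100 is negligible for rate 2.
pmf = [math.exp(poisson_log_pmf(k, 2.0)) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
```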
PRNGIdentity¶
- class PRNGIdentity(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Distribution over PRNGKey(). This can be used to draw a batch of PRNGKey() using the seed handler. Only the sample method is supported.
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
ZeroInflatedDistribution¶
- ZeroInflatedDistribution(base_dist, *, gate=None, gate_logits=None, validate_args=None)[source]¶
Generic Zero Inflated distribution.
- Parameters
base_dist (Distribution) – the base distribution.
gate (numpy.ndarray) – probability of extra zeros given via a Bernoulli distribution.
gate_logits (numpy.ndarray) – logits of extra zeros given via a Bernoulli distribution.
ZeroInflatedPoisson¶
- class ZeroInflatedPoisson(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.discrete.ZeroInflatedProbs
A Zero Inflated Poisson distribution.
- Parameters
gate (numpy.ndarray) – probability of extra zeros.
rate (numpy.ndarray) – rate of Poisson distribution.
- arg_constraints = {'gate': <numpyro.distributions.constraints._Interval object>, 'rate': <numpyro.distributions.constraints._GreaterThan object>}¶
- support = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
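A zero-inflated pmf mixes a point mass at zero (with probability gate) with the base distribution. The plain-Python sketch below shows the resulting log pmf for a Poisson base; it assumes that standard mixture form, and the function names are hypothetical.

```python
import math


def poisson_log_pmf(k, rate):
    return k * math.log(rate) - rate - math.lgamma(k + 1)


def zero_inflated_poisson_log_pmf(k, gate, rate):
    """Log pmf where gate is the probability of an extra (structural) zero."""
    base = poisson_log_pmf(k, rate)
    if k == 0:
        # A zero can come from the gate or from the Poisson itself.
        return math.log(gate + (1.0 - gate) * math.exp(base))
    return math.log1p(-gate) + base


pmf = [math.exp(zero_inflated_poisson_log_pmf(k, 0.3, 2.0)) for k in range(100)]
mean = sum(k * p for k, p in enumerate(pmf))
```

Under this form the mean shrinks to (1 - gate) * rate, since the extra zeros contribute nothing to the sum.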
Mixture Distributions¶
MixtureSameFamily¶
- class MixtureSameFamily(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Marginalized Finite Mixture distribution of vectorized components.
The components being a vectorized distribution implies that all components are from the same family, represented by a single Distribution object.
- Parameters
mixing_distribution (numpyro.distribution.Distribution) – The mixing distribution to select the components. Needs to be a categorical.
component_distribution (numpyro.distribution.Distribution) – Vectorized component distribution.
As an example:
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> mixing_dist = dist.Categorical(probs=jnp.ones(3) / 3.)
>>> component_dist = dist.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
>>> mixture = dist.MixtureSameFamily(mixing_dist, component_dist)
>>> mixture.sample(jax.random.PRNGKey(42)).shape
()
- property mixture_size¶
Returns the number of distributions in the mixture
- Returns
number of mixture components.
- Return type
- property mixing_distribution¶
Returns the mixing distribution
- Returns
Categorical distribution
- Return type
Categorical
- property mixture_dim¶
- property component_distribution¶
Return the vectorized distribution of components being mixed.
- Returns
Component distribution
- Return type
- property support¶
- property is_discrete¶
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
- cdf(samples)[source]¶
The cumulative distribution function of this mixture distribution.
- Parameters
samples – samples from this distribution.
- Returns
output of the cumulative distribution function evaluated at samples.
- Raises
NotImplementedError if the component distribution does not implement the cdf method.
- sample_with_intermediates(key, sample_shape=())[source]¶
Same as
sample
except that the sampled mixture components are also returned.
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
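"Marginalized" here means the component index is summed out, so the mixture log density is a log-sum-exp over components of log weight plus component log density. A plain-Python sketch for Normal components, with hypothetical helper names and no NumPyro calls:

```python
import math


def normal_log_pdf(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def mixture_log_prob(x, weights, locs, scales):
    """Log density of a finite mixture of Normals with the component index summed out."""
    terms = [
        math.log(w) + normal_log_pdf(x, loc, scale)
        for w, loc, scale in zip(weights, locs, scales)
    ]
    return logsumexp(terms)


# Three identical standard-normal components with equal weights.
lp = mixture_log_prob(0.3, [1.0 / 3.0] * 3, [0.0] * 3, [1.0] * 3)
```

When every component is the same distribution, the mixture collapses to that distribution, which gives a quick sanity check.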
Directional Distributions¶
ProjectedNormal¶
- class ProjectedNormal(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Projected isotropic normal distribution of arbitrary dimension.
This distribution over directional data is qualitatively similar to the von Mises and von Mises-Fisher distributions, but permits tractable variational inference via reparametrized gradients.
To use this distribution with autoguides and HMC, use handlers.reparam with a ProjectedNormalReparam reparametrizer in the model, e.g.:
@handlers.reparam(config={"direction": ProjectedNormalReparam()})
def model():
    direction = numpyro.sample("direction", ProjectedNormal(zeros(3)))
    ...
Note
This implements log_prob() only for dimensions {2,3}.
- [1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017) “The General Projected Normal Distribution of Arbitrary Dimension: Modeling and Bayesian Inference” https://projecteuclid.org/euclid.ba/1453211962
- arg_constraints = {'concentration': <numpyro.distributions.constraints._IndependentConstraint object>}¶
- reparametrized_params = ['concentration']¶
- support = <numpyro.distributions.constraints._Sphere object>¶
- property mean¶
Note this is the mean in the sense of a centroid in the submanifold that minimizes expected squared geodesic distance.
- property mode¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(value)[source]¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- static infer_shapes(concentration)[source]¶
Infers batch_shape and event_shape given shapes of args to __init__().
Note
This assumes distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.
- Parameters
*args – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.
**kwargs – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.
- Returns
A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.
- Return type
SineBivariateVonMises¶
- class SineBivariateVonMises(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Unimodal distribution of two dependent angles on the 2-torus (S^1 ⨂ S^1) given by
\[C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))\]
and
\[C = (2\pi)^2 \sum_{i=0}^{\infty} {2i \choose i} \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),\]
where \(I_i(\cdot)\) is the modified Bessel function of the first kind, the mu’s are the locations of the distribution, the kappa’s are the concentrations, and rho gives the correlation between angles \(x_1\) and \(x_2\). This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
To infer parameters, use NUTS or HMC with priors that avoid parameterizations where the distribution becomes bimodal; see note below.
Note
Sample efficiency drops as
\[\frac{\rho}{\kappa_1\kappa_2} \rightarrow 1\]
because the distribution becomes increasingly bimodal.
Note
The correlation and weighted_correlation params are mutually exclusive.
Note
In the context of SVI, this distribution can be used as a likelihood but not for latent variables.
- ** References: **
Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuk, E. (2002)
- Parameters
phi_loc (np.ndarray) – location of first angle
psi_loc (np.ndarray) – location of second angle
phi_concentration (np.ndarray) – concentration of first angle
psi_concentration (np.ndarray) – concentration of second angle
correlation (np.ndarray) – correlation between the two angles
weighted_correlation (np.ndarray) – set correlation to weighted_corr * sqrt(phi_conc*psi_conc) to avoid bimodality (see note).
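The relationship between weighted_correlation and correlation, together with the exponent of the density given above, can be sketched in plain Python. The function names are hypothetical, the normalizing constant C is deliberately left out, and nothing here calls NumPyro.

```python
import math


def correlation_from_weighted(weighted_correlation, phi_concentration, psi_concentration):
    """Scale a weight by sqrt(kappa_1 * kappa_2) to obtain the correlation rho."""
    return weighted_correlation * math.sqrt(phi_concentration * psi_concentration)


def sine_bvm_unnormalized_log_density(
    phi, psi, phi_loc, psi_loc, phi_concentration, psi_concentration, correlation
):
    """Exponent of the sine bivariate von Mises density (log C is omitted)."""
    return (
        phi_concentration * math.cos(phi - phi_loc)
        + psi_concentration * math.cos(psi - psi_loc)
        + correlation * math.sin(phi - phi_loc) * math.sin(psi - psi_loc)
    )


rho = correlation_from_weighted(0.5, 4.0, 9.0)
peak = sine_bvm_unnormalized_log_density(0.1, -0.2, 0.1, -0.2, 4.0, 9.0, rho)
print(rho)  # 3.0
```

With a weight strictly between -1 and 1, rho**2 stays below the product of the two concentrations, which is the regime the weighted parameterization is meant to stay in.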
- arg_constraints = {'correlation': <numpyro.distributions.constraints._Real object>, 'phi_concentration': <numpyro.distributions.constraints._GreaterThan object>, 'phi_loc': <numpyro.distributions.constraints._Interval object>, 'psi_concentration': <numpyro.distributions.constraints._GreaterThan object>, 'psi_loc': <numpyro.distributions.constraints._Interval object>}¶
- support = <numpyro.distributions.constraints._IndependentConstraint object>¶
- max_sample_iter = 1000¶
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- sample(key, sample_shape=())[source]¶
- ** References: **
A New Unified Approach for the Simulation of a Wide Class of Directional Distributions John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
- property mean¶
Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]
SineSkewed¶
- class SineSkewed(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
Sine-skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric) base distribution. Torus distributions are distributions with support on products of circles (i.e., ⨂^d S^1 where S^1=[-pi,pi) ). So, a 0-torus is a point, the 1-torus is a circle, and the 2-torus is commonly associated with the donut shape.
The sine skewed X distribution is parameterized by a weight parameter for each dimension of the event of X. For example with a von Mises distribution over a circle (1-torus), the sine skewed von Mises distribution has one skew parameter. The skewness parameters can be inferred using HMC or NUTS. For example, the following will produce a prior over skewness for the 2-torus:
@numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()})
def model(obs):
    # Sine priors
    phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
    psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
    phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
    psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
    corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

    # Skewing prior
    ball_trans = L1BallTransform()
    skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
    skewness = ball_trans(skewness)  # constraint sum |skewness_i| <= 1

    with numpyro.plate('obs_plate'):
        sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
                                     phi_concentration=70 * phi_conc,
                                     psi_concentration=70 * psi_conc,
                                     weighted_correlation=corr_scale)
        return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)
To ensure the skewing does not alter the normalization constant of the (sine bivariate von Mises) base distribution, the skewness parameters are constrained. The constraint requires the sum of the absolute values of skewness to be less than or equal to one. We can use the L1BallTransform to achieve this.
In the context of SVI, this distribution can freely be used as a likelihood, but using it for latent variables will lead to slow inference for tori of dimension 2 and higher. This is because the base_dist cannot be reparameterized.
Note
An event in the base distribution must be on a d-torus, so the event_shape must be (d,).
Note
For the skewness parameter, the sum of the absolute values of its weights for an event must be less than or equal to one. See eq. 2.1 in [1].
- ** References: **
- Sine-skewed toroidal distributions and their application in protein bioinformatics
Ameijeiras-Alonso, J., Ley, C. (2019)
- Parameters
base_dist (numpyro.distributions.Distribution) – base density on a d-dimensional torus. Supported base distributions include: 1D VonMises, SineBivariateVonMises, 1D ProjectedNormal, and Uniform(-pi, pi).
skewness (jax.numpy.array) – skewness of the distribution.
- arg_constraints = {'skewness': <numpyro.distributions.constraints._L1Ball object>}¶
- support = <numpyro.distributions.constraints._IndependentConstraint object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(value)[source]¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Mean of the base distribution
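The role of the skewness constraint can be checked numerically. The sketch below uses only the Python standard library, not NumPyro. It assumes the sine-skewed density has the form base(x) * (1 + sum_i skewness_i * sin(x_i - loc_i)) (compare eq. 2.1 in [1]) and takes the uniform density on a circle as the base. The function name sine_skewed_uniform_pdf is hypothetical.

```python
import math


def sine_skewed_uniform_pdf(theta, skewness, loc=0.0):
    """Density of a sine-skewed uniform distribution on the circle [-pi, pi).

    Assumed form: base(theta) * (1 + skewness * sin(theta - loc)).
    """
    base = 1.0 / (2.0 * math.pi)  # uniform density on the circle
    return base * (1.0 + skewness * math.sin(theta - loc))


# Midpoint rule over one full period of the circle.
n = 10_000
h = 2.0 * math.pi / n
grid = [-math.pi + (k + 0.5) * h for k in range(n)]

skewness = 0.8  # |skewness| <= 1 satisfies the L1 constraint in one dimension
values = [sine_skewed_uniform_pdf(t, skewness, loc=0.5) for t in grid]

total = sum(values) * h  # numerical integral of the skewed density
lowest = min(values)     # smallest density value on the grid
```

Under these assumptions the integral stays at one and the density stays non-negative. With a skewness larger than one in absolute value, the factor 1 + skewness * sin(...) would become negative on part of the circle.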
VonMises¶
- class VonMises(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
The von Mises distribution, also known as the circular normal distribution.
This distribution is supported by a circular constraint from -pi to +pi. By default, the circular support behaves like constraints.interval(-math.pi, math.pi). To avoid issues at the boundaries of this interval during sampling, you should reparameterize this distribution using handlers.reparam with a CircularReparam reparametrizer in the model, e.g.:

    @handlers.reparam(config={"direction": CircularReparam()})
    def model():
        direction = numpyro.sample("direction", VonMises(0.0, 4.0))
        ...

- arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>, 'loc': <numpyro.distributions.constraints._Real object>}¶
- reparametrized_params = ['loc']¶
- support = <numpyro.distributions.constraints._Interval object>¶
- sample(key, sample_shape=())[source]¶
Generate sample from von Mises distribution
- Parameters
key – random number generator key
sample_shape – shape of samples
- Returns
samples from von Mises
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
- property mean¶
Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]
- property variance¶
Computes circular variance of distribution
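To see why the boundaries of the interval [-pi, pi) need care, the following standard-library sketch (not NumPyro code; wrap_to_pi and circular_mean are hypothetical helper names) wraps angles onto the interval and compares the arithmetic mean with the circular mean for two directions that lie close together on the circle but on opposite sides of the cut.

```python
import math


def wrap_to_pi(angle):
    """Map any angle in radians to the interval [-pi, pi)."""
    return (angle + math.pi) % (2.0 * math.pi) - math.pi


def circular_mean(angles):
    """Direction of the average unit vector of the given angles."""
    s = sum(math.sin(a) for a in angles)
    c = sum(math.cos(a) for a in angles)
    return math.atan2(s, c)


# Two nearby directions on either side of the cut at +/-pi.
angles = [3.0, -3.0]
arithmetic = sum(angles) / len(angles)  # 0.0, which points the opposite way
circular = circular_mean(angles)        # +/-pi, the direction both angles are near
wrapped = wrap_to_pi(3.5)               # 3.5 - 2*pi, back inside [-pi, pi)
```

The interval endpoints are an artifact of the parameterization: a sampler that treats -pi and +pi as far apart sees a barrier that does not exist on the circle.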
Truncated Distributions¶
LeftTruncatedDistribution¶
- class LeftTruncatedDistribution(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'low': <numpyro.distributions.constraints._Real object>}¶
- reparametrized_params = ['low']¶
- supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)¶
- property support¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
RightTruncatedDistribution¶
- class RightTruncatedDistribution(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'high': <numpyro.distributions.constraints._Real object>}¶
- reparametrized_params = ['high']¶
- supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)¶
- property support¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
TruncatedCauchy¶
- class TruncatedCauchy(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.truncated.LeftTruncatedDistribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['low', 'loc', 'scale']¶
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
TruncatedDistribution¶
- TruncatedDistribution(base_dist, low=None, high=None, validate_args=None)[source]¶
A function to generate a truncated distribution.
- Parameters
base_dist – The base distribution to be truncated. This should be a univariate distribution. Currently, only the following distributions are supported: Cauchy, Laplace, Logistic, Normal, and StudentT.
low – the value which is used to truncate the base distribution from below. Set this parameter to None to not truncate from below.
high – the value which is used to truncate the base distribution from above. Set this parameter to None to not truncate from above.
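One common way to sample from a truncated univariate distribution is inverse-CDF sampling restricted to the retained probability mass. The sketch below illustrates this for a Cauchy base using only the standard library. It is an illustration of the idea, not a description of how NumPyro implements truncation, and all function names are hypothetical.

```python
import math
import random


def cauchy_cdf(x, loc=0.0, scale=1.0):
    return 0.5 + math.atan((x - loc) / scale) / math.pi


def cauchy_icdf(u, loc=0.0, scale=1.0):
    return loc + scale * math.tan(math.pi * (u - 0.5))


def sample_truncated_cauchy(rng, loc=0.0, scale=1.0, low=None, high=None):
    """Draw one sample from a Cauchy restricted to [low, high].

    None means "no truncation on that side", mirroring low and high above.
    """
    u_low = 0.0 if low is None else cauchy_cdf(low, loc, scale)
    u_high = 1.0 if high is None else cauchy_cdf(high, loc, scale)
    u = rng.uniform(u_low, u_high)  # uniform over the retained probability mass
    return cauchy_icdf(u, loc, scale)


rng = random.Random(0)
left = [sample_truncated_cauchy(rng, low=1.0) for _ in range(1000)]
both = [sample_truncated_cauchy(rng, low=-1.0, high=2.0) for _ in range(1000)]
```

Every draw in left lies at or above 1.0, and every draw in both lies between -1.0 and 2.0 (up to floating-point rounding at the bounds).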
TruncatedNormal¶
- class TruncatedNormal(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.truncated.LeftTruncatedDistribution
- arg_constraints = {'loc': <numpyro.distributions.constraints._Real object>, 'low': <numpyro.distributions.constraints._Real object>, 'scale': <numpyro.distributions.constraints._GreaterThan object>}¶
- reparametrized_params = ['low', 'loc', 'scale']¶
- property mean¶
Mean of the distribution.
- property variance¶
Variance of the distribution.
TruncatedPolyaGamma¶
- class TruncatedPolyaGamma(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- truncation_point = 2.5¶
- num_log_prob_terms = 7¶
- num_gamma_variates = 8¶
- arg_constraints = {}¶
- support = <numpyro.distributions.constraints._Interval object>¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
TwoSidedTruncatedDistribution¶
- class TwoSidedTruncatedDistribution(*args, **kwargs)[source]¶
Bases:
numpyro.distributions.distribution.Distribution
- arg_constraints = {'high': <numpyro.distributions.constraints._Dependent object>, 'low': <numpyro.distributions.constraints._Dependent object>}¶
- reparametrized_params = ['low', 'high']¶
- supported_types = (<class 'numpyro.distributions.continuous.Cauchy'>, <class 'numpyro.distributions.continuous.Laplace'>, <class 'numpyro.distributions.continuous.Logistic'>, <class 'numpyro.distributions.continuous.Normal'>, <class 'numpyro.distributions.continuous.SoftLaplace'>, <class 'numpyro.distributions.continuous.StudentT'>)¶
- property support¶
- sample(key, sample_shape=())[source]¶
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.
- Parameters
key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
sample_shape (tuple) – the sample shape for the distribution.
- Returns
an array of shape sample_shape + batch_shape + event_shape
- Return type
- log_prob(*args, **kwargs)¶
Evaluates the log probability density for a batch of samples given by value.
- Parameters
value – A batch of samples from the distribution.
- Returns
an array with shape value.shape[:-self.event_shape]
- Return type
TensorFlow Distributions¶
Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see its Distribution docs.
Constraints¶
Constraint¶
- class Constraint[source]¶
Bases:
object
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
- is_discrete = False¶
- event_dim = 0¶
dependent¶
- dependent = <numpyro.distributions.constraints._Dependent object>¶
Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints.
- Parameters
is_discrete (bool) – Optional value of .is_discrete in case this can be computed statically. If not provided, access to the .is_discrete attribute will raise a NotImplementedError.
event_dim (int) – Optional value of .event_dim in case this can be computed statically. If not provided, access to the .event_dim attribute will raise a NotImplementedError.
greater_than¶
- greater_than(lower_bound)¶
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
integer_interval¶
- integer_interval(lower_bound, upper_bound)¶
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
integer_greater_than¶
- integer_greater_than(lower_bound)¶
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
interval¶
- interval(lower_bound, upper_bound)¶
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
less_than¶
- less_than(upper_bound)¶
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
multinomial¶
- multinomial(upper_bound)¶
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
nonnegative_integer¶
- nonnegative_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
positive_definite¶
- positive_definite = <numpyro.distributions.constraints._PositiveDefinite object>¶
positive_integer¶
- positive_integer = <numpyro.distributions.constraints._IntegerGreaterThan object>¶
positive_ordered_vector¶
- positive_ordered_vector = <numpyro.distributions.constraints._PositiveOrderedVector object>¶
Constrains to a positive real-valued tensor where the elements are monotonically increasing along the event_shape dimension.
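As a plain-Python illustration of what this constraint accepts (a sketch, not NumPyro's check; the helper name is hypothetical and reading "monotonically increasing" as strictly increasing is an assumption):

```python
def is_positive_ordered(vector):
    """True if every entry is positive and the entries strictly increase."""
    positive = all(v > 0 for v in vector)
    increasing = all(a < b for a, b in zip(vector, vector[1:]))
    return positive and increasing


ok = is_positive_ordered([0.5, 1.0, 3.0])           # positive and increasing
not_ordered = is_positive_ordered([1.0, 0.5])       # positive but decreasing
not_positive = is_positive_ordered([-1.0, 2.0])     # increasing but not positive
```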
real_vector¶
- real_vector = <numpyro.distributions.constraints._IndependentConstraint object>¶
Wraps a constraint by aggregating over reinterpreted_batch_ndims-many dims in check(), so that an event is valid only if all its independent entries are valid.
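The aggregation can be pictured with a small standard-library sketch (not NumPyro code; in_interval and check_event are hypothetical helpers): an element-wise check is applied to every entry of an event, and the event counts as valid only if all entries pass.

```python
def in_interval(x, low, high):
    """Element-wise check for a scalar interval constraint."""
    return low <= x <= high


def check_event(event, elementwise_check):
    """An event (here a flat list) is valid only if every entry is valid."""
    return all(elementwise_check(x) for x in event)


batch = [[0.1, 0.9], [0.5, 1.5]]  # two events, each of length 2
valid = [check_event(ev, lambda x: in_interval(x, 0.0, 1.0)) for ev in batch]
```

The result has one boolean per event rather than one per entry: the second event is rejected because a single entry (1.5) falls outside the interval.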
scaled_unit_lower_cholesky¶
- scaled_unit_lower_cholesky = <numpyro.distributions.constraints._ScaledUnitLowerCholesky object>¶
softplus_positive¶
- softplus_positive = <numpyro.distributions.constraints._SoftplusPositive object>¶
softplus_lower_cholesky¶
- softplus_lower_cholesky = <numpyro.distributions.constraints._SoftplusLowerCholesky object>¶
Transforms¶
Transform¶
- class Transform[source]¶
Bases:
object
- domain = <numpyro.distributions.constraints._Real object>¶
- codomain = <numpyro.distributions.constraints._Real object>¶
- property event_dim¶
- property inv¶
AbsTransform¶
- class AbsTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
- domain = <numpyro.distributions.constraints._Real object>¶
- codomain = <numpyro.distributions.constraints._GreaterThan object>¶
AffineTransform¶
- class AffineTransform(loc, scale, domain=<numpyro.distributions.constraints._Real object>)[source]¶
Bases:
numpyro.distributions.transforms.Transform
Note
When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.
- property codomain¶
CholeskyTransform¶
- class CholeskyTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = cholesky(x)\), where x is a positive definite matrix.
- domain = <numpyro.distributions.constraints._PositiveDefinite object>¶
- codomain = <numpyro.distributions.constraints._LowerCholesky object>¶
ComposeTransform¶
CorrCholeskyTransform¶
- class CorrCholeskyTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transforms an unconstrained real vector \(x\) with length \(D*(D-1)/2\) into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:
1. First we convert \(x\) into a lower triangular matrix with the following order:
\[\begin{split}\begin{bmatrix} 1 & 0 & 0 & 0 \\ x_0 & 1 & 0 & 0 \\ x_1 & x_2 & 1 & 0 \\ x_3 & x_4 & x_5 & 1 \end{bmatrix}\end{split}\]
2. For each row \(X_i\) of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform \(X_i\) into a unit Euclidean length vector using the following steps:
a. Scales into the interval \((-1, 1)\) domain: \(r_i = \tanh(X_i)\).
b. Transforms into an unsigned domain: \(z_i = r_i^2\).
c. Applies \(s_i = StickBreakingTransform(z_i)\).
d. Transforms back into signed domain: \(y_i = (sign(r_i), 1) * \sqrt{s_i}\).
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._CorrCholesky object>¶
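The row-wise steps described above can be sketched in plain Python (standard library only). This is an illustration under one assumption: the stick-breaking step treats each \(z\) value directly as the fraction broken off the remaining stick. The function names are hypothetical and the code is not NumPyro's implementation.

```python
import math


def signed_stick_breaking_row(x_row):
    """Map unconstrained reals X_i to a unit-length row of a Cholesky factor."""
    r = [math.tanh(x) for x in x_row]  # scale into (-1, 1)
    z = [v * v for v in r]             # unsigned values in [0, 1)
    s, remaining = [], 1.0             # stick-breaking on z
    for z_k in z:
        s.append(z_k * remaining)
        remaining *= 1.0 - z_k
    s.append(remaining)                # the leftover stick is the diagonal entry
    signs = [math.copysign(1.0, v) for v in r] + [1.0]
    return [sg * math.sqrt(v) for sg, v in zip(signs, s)]  # back to signed domain


def corr_cholesky(x, dim):
    """Build a dim x dim lower-triangular factor from dim*(dim-1)/2 reals."""
    rows, start = [], 0
    for i in range(dim):
        row = signed_stick_breaking_row(x[start:start + i])
        rows.append(row + [0.0] * (dim - i - 1))  # pad the upper triangle with zeros
        start += i
    return rows


L = corr_cholesky([0.3, -1.2, 0.7], dim=3)
# Diagonal of L @ L.T is the squared Euclidean norm of each row.
row_norms_sq = [sum(v * v for v in row) for row in L]
```

Each row has unit norm, so L @ L.T has ones on its diagonal, as a correlation matrix requires, and the sign of each off-diagonal entry follows the sign of the corresponding input.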
CorrMatrixCholeskyTransform¶
- class CorrMatrixCholeskyTransform[source]¶
Bases:
numpyro.distributions.transforms.CholeskyTransform
Transform via the mapping \(y = cholesky(x)\), where x is a correlation matrix.
- domain = <numpyro.distributions.constraints._CorrMatrix object>¶
- codomain = <numpyro.distributions.constraints._CorrCholesky object>¶
ExpTransform¶
InvCholeskyTransform¶
- class InvCholeskyTransform(domain=<numpyro.distributions.constraints._LowerCholesky object>)[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = x @ x.T\), where x is a lower triangular matrix with positive diagonal.
- property codomain¶
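The mapping \(y = x @ x.T\) and its inverse (the Cholesky factorization used by CholeskyTransform) can be checked by hand on a small matrix. The following standard-library sketch uses nested lists instead of arrays; the helper names are hypothetical.

```python
import math


def matmul_transpose(L):
    """Compute L @ L.T for a square matrix given as nested lists."""
    n = len(L)
    return [[sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]


def cholesky(A):
    """Lower-triangular L with positive diagonal such that L @ L.T == A."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            acc = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(acc) if i == j else acc / L[j][j]
    return L


x = [[2.0, 0.0], [0.5, 1.5]]  # lower triangular with positive diagonal
y = matmul_transpose(x)       # InvCholeskyTransform direction: y = x @ x.T
x_back = cholesky(y)          # CholeskyTransform direction recovers x
```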
L1BallTransform¶
- class L1BallTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transforms an unconstrained real vector \(x\) into the unit L1 ball.
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._L1Ball object>¶
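One way to build such a map is to squash each coordinate with tanh and then spend the remaining L1 "budget" stick-breaking style. The sketch below (standard library only) shows that construction; treat it as an assumption about a possible implementation, not as NumPyro's code. The name to_l1_ball is hypothetical.

```python
import math


def to_l1_ball(x):
    """Map an unconstrained real vector to a point with sum(|y_i|) <= 1."""
    y, remaining = [], 1.0
    for x_k in x:
        t = math.tanh(x_k)         # signed fraction in (-1, 1)
        y.append(t * remaining)    # take that fraction of what is left
        remaining *= 1.0 - abs(t)  # shrink the remaining L1 budget
    return y


y = to_l1_ball([2.0, -0.5, 10.0])
l1_norm = sum(abs(v) for v in y)
```

Because each coordinate uses only a fraction of the budget left by the previous ones, the absolute values sum to one minus the final remainder, which never exceeds one.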
LowerCholeskyAffine¶
- class LowerCholeskyAffine(loc, scale_tril)[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform via the mapping \(y = loc + scale\_tril\ @\ x\).
- Parameters
loc – a real vector.
scale_tril – a lower triangular matrix with positive diagonal.
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._IndependentConstraint object>¶
LowerCholeskyTransform¶
- class LowerCholeskyTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform a real vector to a lower triangular cholesky factor, where the strictly lower triangular submatrix is unconstrained and the diagonal is parameterized with an exponential transform.
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._LowerCholesky object>¶
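A minimal standard-library sketch of this idea follows. The layout of the input vector is an assumption made for illustration (strictly lower entries first, row by row, then the diagonal entries), and to_lower_cholesky is a hypothetical name.

```python
import math


def to_lower_cholesky(x, n):
    """Fill an n x n lower-triangular matrix from n*(n+1)/2 reals.

    Assumed layout: the first n*(n-1)/2 values fill the strictly lower
    triangle row by row; the last n values are exponentiated onto the diagonal.
    """
    n_off = n * (n - 1) // 2
    off_diag, diag = x[:n_off], x[n_off:]
    L = [[0.0] * n for _ in range(n)]
    k = 0
    for i in range(n):
        for j in range(i):
            L[i][j] = off_diag[k]    # unconstrained entries
            k += 1
        L[i][i] = math.exp(diag[i])  # exp keeps the diagonal positive
    return L


L = to_lower_cholesky([0.4, -2.0, 1.1, 0.0, -1.0, 0.5], n=3)
```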
OrderedTransform¶
- class OrderedTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform a real vector to an ordered vector.
References:
Stan Reference Manual v2.20, section 10.6, Stan Development Team
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._OrderedVector object>¶
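The construction in the Stan reference keeps the first coordinate and adds exponentiated increments to the rest. A standard-library sketch of the forward and inverse maps (helper names are hypothetical):

```python
import math


def to_ordered(x):
    """y[0] = x[0]; y[k] = y[k-1] + exp(x[k]) for k >= 1."""
    y = [x[0]]
    for x_k in x[1:]:
        y.append(y[-1] + math.exp(x_k))  # positive increment keeps the order
    return y


def from_ordered(y):
    """Inverse map: recover the log of each increment."""
    return [y[0]] + [math.log(b - a) for a, b in zip(y, y[1:])]


x = [-1.0, 0.0, -3.0, 2.0]
y = to_ordered(x)
x_back = from_ordered(y)
```

Since every increment exp(x[k]) is positive, the output is strictly increasing for any real input.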
PermuteTransform¶
- class PermuteTransform(permutation)[source]¶
Bases:
numpyro.distributions.transforms.Transform
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._IndependentConstraint object>¶
PowerTransform¶
- class PowerTransform(exponent)[source]¶
Bases:
numpyro.distributions.transforms.Transform
- domain = <numpyro.distributions.constraints._GreaterThan object>¶
- codomain = <numpyro.distributions.constraints._GreaterThan object>¶
ScaledUnitLowerCholeskyTransform¶
- class ScaledUnitLowerCholeskyTransform[source]¶
Bases:
numpyro.distributions.transforms.LowerCholeskyTransform
Like LowerCholeskyTransform this Transform transforms a real vector to a lower triangular cholesky factor. However it does so via a decomposition
\(y = loc + unit\_scale\_tril\ @\ scale\_diag\ @\ x\).
where \(unit\_scale\_tril\) has ones along the diagonal and \(scale\_diag\) is a diagonal matrix with all positive entries that is parameterized with a softplus transform.
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._ScaledUnitLowerCholesky object>¶
SigmoidTransform¶
SimplexToOrderedTransform¶
- class SimplexToOrderedTransform(anchor_point=0.0)[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform a simplex into an ordered vector (via difference in Logistic CDF between cutpoints). Used in [1] to induce a prior on latent cutpoints via transforming ordered category probabilities.
- Parameters
anchor_point – Anchor point is a nuisance parameter to improve the identifiability of the transform. For simplicity, we assume it is a scalar value, but it is broadcastable to x.shape[:-1]. For more details please refer to Section 2.2 in [1].
References:
Ordinal Regression Case Study, section 2.2, M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
- domain = <numpyro.distributions.constraints._Simplex object>¶
- codomain = <numpyro.distributions.constraints._OrderedVector object>¶
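The relationship between category probabilities and cutpoints can be sketched with the standard library alone. The sketch assumes that cutpoints are the logits of the cumulative probabilities shifted by the anchor point, so that probabilities are recovered as differences of the logistic CDF between consecutive cutpoints. The function names are hypothetical and the exact NumPyro implementation may differ.

```python
import math


def simplex_to_ordered(probs, anchor_point=0.0):
    """Map K category probabilities to K-1 ordered cutpoints."""
    cutpoints, cumulative = [], 0.0
    for p in probs[:-1]:
        cumulative += p                                    # P(category <= k)
        logit = math.log(cumulative / (1.0 - cumulative))  # inverse logistic CDF
        cutpoints.append(logit + anchor_point)
    return cutpoints


def ordered_to_simplex(cutpoints, anchor_point=0.0):
    """Inverse: differences of the logistic CDF between consecutive cutpoints."""
    cdf = [1.0 / (1.0 + math.exp(-(c - anchor_point))) for c in cutpoints]
    bounds = [0.0] + cdf + [1.0]
    return [hi - lo for lo, hi in zip(bounds, bounds[1:])]


probs = [0.2, 0.5, 0.3]
cutpoints = simplex_to_ordered(probs)
probs_back = ordered_to_simplex(cutpoints)
```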
SoftplusLowerCholeskyTransform¶
- class SoftplusLowerCholeskyTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform from unconstrained vector to lower-triangular matrices with nonnegative diagonal entries. This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._SoftplusLowerCholesky object>¶
SoftplusTransform¶
- class SoftplusTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
Transform from unconstrained space to positive domain via softplus \(y = \log(1 + \exp(x))\). The inverse is computed as \(x = \log(\exp(y) - 1)\).
- domain = <numpyro.distributions.constraints._Real object>¶
- codomain = <numpyro.distributions.constraints._SoftplusPositive object>¶
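The two formulas can be verified with a quick round trip. This standard-library sketch writes them with log1p and expm1, which evaluate the same expressions more accurately for small arguments; the function names are hypothetical.

```python
import math


def softplus(x):
    """y = log(1 + exp(x)), written with log1p for small exp(x)."""
    return math.log1p(math.exp(x))


def softplus_inv(y):
    """x = log(exp(y) - 1), written with expm1 for small y."""
    return math.log(math.expm1(y))


xs = [-3.0, 0.0, 2.5]
ys = [softplus(x) for x in xs]            # all strictly positive
xs_back = [softplus_inv(y) for y in ys]   # recovers the inputs
```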
StickBreakingTransform¶
- class StickBreakingTransform[source]¶
Bases:
numpyro.distributions.transforms.Transform
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._Simplex object>¶
Flows¶
InverseAutoregressiveTransform¶
- class InverseAutoregressiveTransform(autoregressive_nn, log_scale_min_clip=- 5.0, log_scale_max_clip=3.0)[source]¶
Bases:
numpyro.distributions.transforms.Transform
An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,
\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\)
where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).
References
Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934], Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- log_abs_det_jacobian(x, y, intermediates=None)[source]¶
Calculates the elementwise determinant of the log jacobian.
- Parameters
x (numpy.ndarray) – the input to the transform
y (numpy.ndarray) – the output of the transform
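To make the update \(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\) concrete, here is a standard-library sketch with a toy stand-in for the autoregressive network. The toy network (toy_autoregressive_nn) is entirely hypothetical; it only guarantees that \(\mu_t\) and \(\sigma_t\) depend on inputs before position t, which is what makes the Jacobian triangular. The sketch returns the summed log-determinant rather than per-element terms.

```python
import math


def toy_autoregressive_nn(x_prefix):
    """Hypothetical stand-in for the autoregressive network.

    Returns (mu_t, log_sigma_t) computed from the inputs before position t.
    """
    total = sum(x_prefix)
    return 0.5 * total, math.tanh(total)


def iaf_forward(x):
    y, log_det = [], 0.0
    for t, x_t in enumerate(x):
        mu, log_sigma = toy_autoregressive_nn(x[:t])
        y.append(mu + math.exp(log_sigma) * x_t)  # y_t = mu_t + sigma_t * x_t
        log_det += log_sigma                      # triangular Jacobian: sum of log sigma_t
    return y, log_det


def iaf_inverse(y):
    x = []
    for y_t in y:
        mu, log_sigma = toy_autoregressive_nn(x)  # uses only already recovered inputs
        x.append((y_t - mu) / math.exp(log_sigma))
    return x


x = [0.3, -1.0, 2.0]
y, log_det = iaf_forward(x)
x_back = iaf_inverse(y)
```

The forward pass can be evaluated for all positions at once because every \(\mu_t, \sigma_t\) depends only on the known input, whereas the inverse has to proceed one coordinate at a time.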
BlockNeuralAutoregressiveTransform¶
- class BlockNeuralAutoregressiveTransform(bn_arn)[source]¶
Bases:
numpyro.distributions.transforms.Transform
An implementation of Block Neural Autoregressive flow.
References
Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
- domain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- codomain = <numpyro.distributions.constraints._IndependentConstraint object>¶
- log_abs_det_jacobian(x, y, intermediates=None)[source]¶
Calculates the elementwise determinant of the log jacobian.
- Parameters
x (numpy.ndarray) – the input to the transform
y (numpy.ndarray) – the output of the transform
Inference¶
Markov Chain Monte Carlo (MCMC)¶
- class MCMC(sampler, *, num_warmup, num_samples, num_chains=1, thinning=1, postprocess_fn=None, chain_method='parallel', progress_bar=True, jit_model_args=False)[source]¶
Bases:
object
Provides access to Markov Chain Monte Carlo inference algorithms in NumPyro.
Note
chain_method is an experimental arg, which might be removed in a future version.
Note
Setting progress_bar=False will improve the speed for many cases. But it might require more memory than the other option.
Note
If setting num_chains greater than 1 in a Jupyter Notebook, then you will need to have installed ipywidgets in the environment from which you launched Jupyter in order for the progress bars to render correctly. If you are using Jupyter Notebook or Jupyter Lab, please also install the corresponding extension package like widgetsnbextension or jupyterlab_widgets.
- Parameters
sampler (MCMCKernel) – an instance of MCMCKernel that determines the sampler for running MCMC. Currently, only HMC and NUTS are available.
num_warmup (int) – Number of warmup steps.
num_samples (int) – Number of samples to generate from the Markov chain.
thinning (int) – Positive integer that controls the fraction of post-warmup samples that are retained. For example if thinning is 2 then every other sample is retained. Defaults to 1, i.e. no thinning.
num_chains (int) – Number of MCMC chains to run. By default, chains will be run in parallel using jax.pmap(). If there are not enough devices available, chains will be run in sequence.
postprocess_fn – Post-processing callable - used to convert a collection of unconstrained sample values returned from the sampler to constrained values that lie within the support of the sample sites. Additionally, this is used to return values at deterministic sites in the model.
chain_method (str) – One of ‘parallel’ (default), ‘sequential’, ‘vectorized’. The method ‘parallel’ is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for ‘parallel’, we fall back to ‘sequential’ method to draw chains sequentially. ‘vectorized’ method is an experimental feature which vectorizes the drawing method, hence allowing us to collect samples in parallel on a single device.
progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
jit_model_args (bool) – If set to True, this will compile the potential energy computation as a function of model arguments. As such, calling MCMC.run again on a same-sized but different dataset will not result in additional compilation cost. Note that currently, this does not take effect for the case num_chains > 1 and chain_method == 'parallel'.
- property post_warmup_state¶
The state before the sampling phase. If this attribute is not None, run() will skip the warmup phase and start with the state specified in this attribute.
Note
This attribute can be used to sequentially draw MCMC samples. For example,

    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0))
    first_100_samples = mcmc.get_samples()
    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(mcmc.post_warmup_state.rng_key)  # or mcmc.run(random.PRNGKey(1))
    second_100_samples = mcmc.get_samples()

- property last_state¶
The final MCMC state at the end of the sampling phase.
- warmup(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)[source]¶
Run the MCMC warmup adaptation phase. After this call, self.warmup_state will be set and the run() method will skip the warmup adaptation phase. To run warmup again for the new data, it is required to run warmup() again.
- Parameters
rng_key (random.PRNGKey) – Random number generator key to be used for the sampling.
args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
extra_fields (tuple or list) – Extra fields (aside from default_fields()) from the state object (e.g. numpyro.infer.hmc.HMCState for HMC) to collect during the MCMC run.
collect_warmup (bool) – Whether to collect samples from the warmup phase. Defaults to False.
init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the keyword arguments needed by the model.
- run(rng_key, *args, extra_fields=(), init_params=None, **kwargs)[source]¶
Run the MCMC samplers and collect samples.
- Parameters
rng_key (random.PRNGKey) – Random number generator key to be used for the sampling. For multi-chains, a batch of num_chains keys can be supplied. If rng_key does not have batch_size, it will be split into a batch of num_chains keys.
args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the arguments needed by the model.
extra_fields (tuple or list of str) – Extra fields (aside from “z”, “diverging”) to be collected during the MCMC run. Note that subfields can be accessed using dots, e.g. “adapt_state.step_size” can be used to collect step sizes at each step.
init_params – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.init() method. These are typically the keyword arguments needed by the model.
Note
JAX allows Python code to continue even when the compiled code has not finished yet. This can cause trouble when trying to profile the code for speed. See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling JAX programs.
- get_samples(group_by_chain=False)[source]¶
Get samples from the MCMC run.
- Parameters
group_by_chain (bool) – Whether to preserve the chain dimension. If True, all samples will have num_chains as the size of their leading dimension.
- Returns
Samples having the same data type as init_params. The data type is a dict keyed on site names if a model containing Pyro primitives is used, but can be any jaxlib.pytree(), more generally (e.g. when defining a potential_fn for HMC that takes list args).
Example:
You can then pass those samples to Predictive:

    posterior_samples = mcmc.get_samples()
    predictive = Predictive(model, posterior_samples=posterior_samples)
    samples = predictive(rng_key1, *model_args, **model_kwargs)

MCMC Kernels¶
MCMCKernel¶
- class MCMCKernel[source]¶
Bases:
abc.ABC
Defines the interface for the Markov transition kernel that is used for MCMC inference.
Example:

    >>> from collections import namedtuple
    >>> from jax import random
    >>> import jax.numpy as jnp
    >>> import numpyro
    >>> import numpyro.distributions as dist
    >>> from numpyro.infer import MCMC

    >>> MHState = namedtuple("MHState", ["u", "rng_key"])

    >>> class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    ...     sample_field = "u"
    ...
    ...     def __init__(self, potential_fn, step_size=0.1):
    ...         self.potential_fn = potential_fn
    ...         self.step_size = step_size
    ...
    ...     def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    ...         return MHState(init_params, rng_key)
    ...
    ...     def sample(self, state, model_args, model_kwargs):
    ...         u, rng_key = state
    ...         rng_key, key_proposal, key_accept = random.split(rng_key, 3)
    ...         u_proposal = dist.Normal(u, self.step_size).sample(key_proposal)
    ...         accept_prob = jnp.exp(self.potential_fn(u) - self.potential_fn(u_proposal))
    ...         u_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, u_proposal, u)
    ...         return MHState(u_new, rng_key)

    >>> def f(x):
    ...     return ((x - 2) ** 2).sum()

    >>> kernel = MetropolisHastings(f)
    >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    >>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
    >>> posterior_samples = mcmc.get_samples()
    >>> mcmc.print_summary()

- postprocess_fn(model_args, model_kwargs)[source]¶
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
- abstract init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]¶
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
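As a small illustration of the pytree requirement on the kernel state, the sketch below uses a hypothetical namedtuple state (the field names u and rng_key are only an example) and shows that JAX can flatten and rebuild it without any extra registration, since namedtuples are treated as pytrees out of the box:

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical kernel state: a position and a PRNG key.
MHState = namedtuple("MHState", ["u", "rng_key"])
state = MHState(u=jnp.zeros(2), rng_key=jax.random.PRNGKey(0))

# Flatten the state into its array leaves and a structure description ...
leaves, treedef = jax.tree_util.tree_flatten(state)
# ... and rebuild an equivalent state from them.
restored = jax.tree_util.tree_unflatten(treedef, leaves)

print(len(leaves))
print(type(restored).__name__)
```

A custom class would need to be registered as a pytree (for example with jax.tree_util.register_pytree_node) before it could play the same role.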
- abstract sample(state, model_args, model_kwargs)[source]¶
Given the current state, return the next state using the given transition kernel.
- property sample_field¶
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- property default_fields¶
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
BarkerMH¶
- class BarkerMH(model=None, potential_fn=None, step_size=1.0, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.4, init_strategy=<function init_to_uniform>)[source]¶
Bases:
numpyro.infer.mcmc.MCMCKernel
This is a gradient-based MCMC algorithm of Metropolis-Hastings type that uses a skew-symmetric proposal distribution that depends on the gradient of the potential (the Barker proposal; see reference [1]). In particular the proposal distribution is skewed in the direction of the gradient at the current sample.
We expect this algorithm to be particularly effective for low to moderate dimensional models, where it may be competitive with HMC and NUTS.
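To make the skewing idea concrete, here is a rough sketch of a single Barker proposal step in plain JAX. It is an illustration based on the description in reference [1], not NumPyro's implementation: the function name barker_proposal and the standard normal target are assumptions, and the Metropolis-Hastings accept/reject step and any mass matrix preconditioning are left out.

```python
import jax
import jax.numpy as jnp


def barker_proposal(rng_key, x, potential_fn, step_size):
    """Draw one Barker proposal for the position x (illustrative sketch)."""
    key_noise, key_sign = jax.random.split(rng_key)
    # The gradient of the log density is minus the gradient of the potential.
    grad_log_density = -jax.grad(potential_fn)(x)
    # Symmetric Gaussian increment, one per coordinate.
    z = step_size * jax.random.normal(key_noise, x.shape)
    # Keep the sign of each increment with probability sigmoid(z * gradient),
    # otherwise flip it; this skews the proposal along the gradient.
    keep_prob = jax.nn.sigmoid(z * grad_log_density)
    keep = jax.random.uniform(key_sign, x.shape) < keep_prob
    return x + jnp.where(keep, z, -z)


def potential(x):
    # Hypothetical target: standard normal, so the potential is 0.5 * |x|^2.
    return 0.5 * jnp.sum(x**2)


x = jnp.array([3.0, -3.0])
keys = jax.random.split(jax.random.PRNGKey(0), 2000)
proposals = jax.vmap(lambda k: barker_proposal(k, x, potential, 0.5))(keys)

# On average the proposals move towards the mode at the origin.
mean_shift = proposals.mean(axis=0) - x
print(mean_shift)
```

Because only the sign of each increment depends on the gradient, the size of a proposed move is bounded by the Gaussian increment, which is the source of the robustness discussed in the reference.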
Note
We recommend using this kernel with progress_bar=False in MCMC to reduce JAX’s dispatch overhead.
References:
The Barker proposal: combining robustness and efficiency in gradient-based MCMC. Samuel Livingstone, Giacomo Zanella.
- Parameters
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init() has the same type.
step_size (float) – (Initial) step size to use in the Barker proposal.
adapt_step_size (bool) – Whether to adapt the step size during warm-up. Defaults to adapt_step_size=True.
adapt_mass_matrix (bool) – Whether to adapt the mass matrix during warm-up. Defaults to adapt_mass_matrix=True.
dense_mass (bool) – Whether to use a dense (i.e. full-rank) or diagonal mass matrix. Defaults to dense_mass=False.
target_accept_prob (float) – The target acceptance probability that is used to guide step size adaptation. Defaults to target_accept_prob=0.4.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, BarkerMH
>>> def model():
...     x = numpyro.sample("x", dist.Normal().expand([10]))
...     numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10))
>>>
>>> kernel = BarkerMH(model)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=True)
>>> mcmc.run(jax.random.PRNGKey(0))
>>> mcmc.print_summary()
- property model¶
- property sample_field¶
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- get_diagnostics_str(state)[source]¶
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]¶
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- postprocess_fn(args, kwargs)[source]¶
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters
args – Arguments to the model.
kwargs – Keyword arguments to the model.
HMC¶
- class HMC(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=6.283185307179586, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]¶
Bases:
numpyro.infer.mcmc.MCMCKernel
Hamiltonian Monte Carlo inference, using fixed trajectory length, with provision for step size and mass matrix adaptation.
Note
Until the kernel is used in an MCMC run, postprocess_fn will return the identity function.
Note
The default init strategy init_to_uniform might not be a good strategy for some models. You might want to try other init strategies like init_to_median.
References:
MCMC Using Hamiltonian Dynamics, Radford M. Neal
- Parameters
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init() has the same type.
kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
inverse_mass_matrix (numpy.ndarray or dict) – Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with jax.tree_flatten, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If model is not None, we can instead specify a structured block mass matrix as a dictionary, where keys are tuples of site names and values are the corresponding blocks of the mass matrix. For more information about structured mass matrices, see the dense_mass argument.
adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
dense_mass (bool or list) – This flag controls whether mass matrix is dense (i.e. full-rank) or diagonal (defaults to dense_mass=False). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables “x”, “y”, “z” (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows:
dense_mass=[(“x”, “y”)]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z
dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z)
dense_mass=[(“x”, “y”, “z”)] (equivalent to dense_mass=True): use a dense mass matrix for the joint (x, y, z)
dense_mass=[(“x”,), (“y”,), (“z”,)]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks)
target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
trajectory_length (float) – Length of a MCMC trajectory for HMC. Default value is \(2\pi\).
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
find_heuristic_step_size (bool) – whether or not to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode, but forward mode can improve performance in some cases. In addition, some control flow utilities in JAX, such as jax.lax.while_loop or jax.lax.fori_loop, only support forward-mode differentiation. See JAX’s The Autodiff Cookbook for more information.
regularize_mass_matrix (bool) – whether or not to regularize the estimated mass matrix for numerical stability during warmup phase. Defaults to True. This flag does not take effect if adapt_mass_matrix == False.
- property model¶
- property sample_field¶
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by postprocess_fn() and for reporting results in MCMC.print_summary().
- property default_fields¶
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
- get_diagnostics_str(state)[source]¶
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})[source]¶
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- postprocess_fn(args, kwargs)[source]¶
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters
args – Arguments to the model.
kwargs – Keyword arguments to the model.
- sample(state, model_args, model_kwargs)[source]¶
Run HMC from the given HMCState and return the resulting HMCState.
- Parameters
state (HMCState) – Represents the current state.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns
Next state after running HMC.
NUTS¶
- class NUTS(model=None, potential_fn=None, kinetic_fn=None, step_size=1.0, inverse_mass_matrix=None, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False, target_accept_prob=0.8, trajectory_length=None, max_tree_depth=10, init_strategy=<function init_to_uniform>, find_heuristic_step_size=False, forward_mode_differentiation=False, regularize_mass_matrix=True)[source]¶
Bases:
numpyro.infer.hmc.HMC
Hamiltonian Monte Carlo inference, using the No U-Turn Sampler (NUTS) with adaptive path length and mass matrix adaptation.
Note
Until the kernel is used in an MCMC run, postprocess_fn will return the identity function.
Note
The default init strategy init_to_uniform might not be a good strategy for some models. You might want to try other init strategies like init_to_median.
References:
MCMC Using Hamiltonian Dynamics, Radford M. Neal
The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoffman, and Andrew Gelman.
A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
- Parameters
model – Python callable containing Pyro primitives. If model is provided, potential_fn will be inferred using the model.
potential_fn – Python callable that computes the potential energy given input parameters. The input parameters to potential_fn can be any python collection type, provided that init_params argument to init() has the same type.
kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix and momentum. If not provided, the default is euclidean kinetic energy.
step_size (float) – Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1.
inverse_mass_matrix (numpy.ndarray or dict) – Initial value for inverse mass matrix. This may be adapted during warmup if adapt_mass_matrix = True. If no value is specified, then it is initialized to the identity matrix. For a potential_fn with general JAX pytree parameters, the order of entries of the mass matrix is the order of the flattened version of pytree parameters obtained with jax.tree_flatten, which is a bit ambiguous (see more at https://jax.readthedocs.io/en/latest/pytrees.html). If model is not None, we can instead specify a structured block mass matrix as a dictionary, where keys are tuples of site names and values are the corresponding blocks of the mass matrix. For more information about structured mass matrices, see the dense_mass argument.
adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme.
adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme.
dense_mass (bool or list) – This flag controls whether mass matrix is dense (i.e. full-rank) or diagonal (defaults to dense_mass=False). To specify a structured mass matrix, users can provide a list of tuples of site names. Each tuple represents a block in the joint mass matrix. For example, assuming that the model has latent variables “x”, “y”, “z” (where each variable can be multi-dimensional), possible specifications and corresponding mass matrix structures are as follows:
dense_mass=[(“x”, “y”)]: use a dense mass matrix for the joint (x, y) and a diagonal mass matrix for z
dense_mass=[] (equivalent to dense_mass=False): use a diagonal mass matrix for the joint (x, y, z)
dense_mass=[(“x”, “y”, “z”)] (equivalent to dense_mass=True): use a dense mass matrix for the joint (x, y, z)
dense_mass=[(“x”,), (“y”,), (“z”,)]: use dense mass matrices for each of x, y, and z (i.e. block-diagonal with 3 blocks)
target_accept_prob (float) – Target acceptance probability for step size adaptation using Dual Averaging. Increasing this value will lead to a smaller step size, hence the sampling will be slower but more robust. Defaults to 0.8.
trajectory_length (float) – Length of a MCMC trajectory for HMC. This arg has no effect in NUTS sampler.
max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of NUTS sampler. Defaults to 10. This argument also accepts a tuple of integers (d1, d2), where d1 is the max tree depth during warmup phase and d2 is the max tree depth during post warmup phase.
init_strategy (callable) – a per-site initialization function. See Initialization Strategies section for available functions.
find_heuristic_step_size (bool) – whether or not to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode, but forward mode can improve performance in some cases. In addition, some control flow utilities in JAX, such as jax.lax.while_loop or jax.lax.fori_loop, only support forward-mode differentiation. See JAX’s The Autodiff Cookbook for more information.
HMCGibbs¶
- class HMCGibbs(inner_kernel, gibbs_fn, gibbs_sites)[source]¶
Bases:
numpyro.infer.mcmc.MCMCKernel
[EXPERIMENTAL INTERFACE]
HMC-within-Gibbs. This inference algorithm allows the user to combine general purpose gradient-based inference (HMC or NUTS) with custom Gibbs samplers.
Note that it is the user’s responsibility to provide a correct implementation of gibbs_fn that samples from the corresponding posterior conditional.
- Parameters
gibbs_fn – A Python callable that returns a dictionary of Gibbs samples conditioned on the HMC sites. Must include an argument rng_key that should be used for all sampling. Must also include arguments hmc_sites and gibbs_sites, each of which is a dictionary with keys that are site names and values that are sample values. Note that a given gibbs_fn may not need to make use of all these sample values.
gibbs_sites (list) – a list of site names for the latent variables that are covered by the Gibbs sampler.
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS, HMCGibbs
...
>>> def model():
...     x = numpyro.sample("x", dist.Normal(0.0, 2.0))
...     y = numpyro.sample("y", dist.Normal(0.0, 2.0))
...     numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))
...
>>> def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
...     y = hmc_sites['y']
...     new_x = dist.Normal(0.8 * (1-y), jnp.sqrt(0.8)).sample(rng_key)
...     return {'x': new_x}
...
>>> hmc_kernel = NUTS(model)
>>> kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])
>>> mcmc = MCMC(kernel, num_warmup=100, num_samples=100, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0))
>>> mcmc.print_summary()
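The conditional sampled by gibbs_fn in this example follows from standard Normal-Normal conjugacy. With the prior x ~ Normal(0, 2) (variance 4) and the single observation obs = 1 modeled as Normal(x + y, 1), conditioning on y gives:

```latex
\begin{aligned}
p(x \mid y, \mathrm{obs})
  &\propto \exp\!\left(-\frac{x^2}{2 \cdot 4}\right)
           \exp\!\left(-\frac{(1 - y - x)^2}{2 \cdot 1}\right), \\
\text{posterior precision}
  &= \frac{1}{4} + 1 = \frac{5}{4}
  \quad\Longrightarrow\quad
  \operatorname{Var}(x \mid y, \mathrm{obs}) = \frac{4}{5} = 0.8, \\
\mathbb{E}[x \mid y, \mathrm{obs}]
  &= 0.8 \cdot \left(\frac{0}{4} + \frac{1 - y}{1}\right) = 0.8\,(1 - y).
\end{aligned}
```

This matches the Normal(0.8 * (1 - y), sqrt(0.8)) draw in gibbs_fn, where the second argument is the standard deviation.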
- sample_field = 'z'¶
- property model¶
- get_diagnostics_str(state)[source]¶
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
- postprocess_fn(args, kwargs)[source]¶
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters
args – Arguments to the model.
kwargs – Keyword arguments to the model.
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]¶
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
DiscreteHMCGibbs¶
- class DiscreteHMCGibbs(inner_kernel, *, random_walk=False, modified=False)[source]¶
Bases:
numpyro.infer.hmc_gibbs.HMCGibbs
[EXPERIMENTAL INTERFACE]
A subclass of HMCGibbs which performs Metropolis updates for discrete latent sites.
Note
The site update order is randomly permuted at each step.
Note
This class supports enumeration of discrete latent variables. To marginalize out a discrete latent site, we can specify infer={‘enumerate’: ‘parallel’} keyword in its corresponding sample() statement.
- Parameters
random_walk (bool) – If False, Gibbs sampling will be used to draw a sample from the conditional p(gibbs_site | remaining sites). Otherwise, a sample will be drawn uniformly from the domain of gibbs_site. Defaults to False.
modified (bool) – whether to use a modified proposal, as suggested in reference [1], which always proposes a new state for the current Gibbs site. Defaults to False. The modified scheme appears in the literature under the name “modified Gibbs sampler” or “Metropolised Gibbs sampler”.
References:
Peskun’s theorem and a modified discrete-state Gibbs sampler, Liu, J. S. (1996)
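The difference between the plain and the modified Gibbs update can be illustrated on a single discrete site. The sketch below is a NumPy illustration of the scheme described in the reference, not NumPyro's implementation; the function names and the conditional probabilities p are assumptions. The modified update proposes a value different from the current one with probability proportional to p and accepts it with a Metropolis-Hastings correction.

```python
import numpy as np


def gibbs_kernel(p):
    # Plain Gibbs update: draw the new value from p regardless of the
    # current value, so every row of the transition matrix equals p.
    return np.tile(p, (len(p), 1))


def modified_gibbs_kernel(p):
    # Modified (Metropolised) Gibbs update for one discrete site.
    n = len(p)
    T = np.zeros((n, n))
    for x in range(n):
        for y in range(n):
            if y == x:
                continue
            # Propose y != x with probability p[y] / (1 - p[x]) ...
            proposal = p[y] / (1.0 - p[x])
            # ... and accept with the Metropolis-Hastings ratio.
            accept = min(1.0, (1.0 - p[x]) / (1.0 - p[y]))
            T[x, y] = proposal * accept
        # Remaining mass is the probability of staying at x.
        T[x, x] = 1.0 - T[x].sum()
    return T


# Hypothetical conditional distribution of the discrete site.
p = np.array([0.15, 0.3, 0.3, 0.25])
T_gibbs = gibbs_kernel(p)
T_mod = modified_gibbs_kernel(p)

# Both kernels leave p invariant, but the modified one stays put less often.
print(np.allclose(p @ T_mod, p))
print(np.diag(T_gibbs), np.diag(T_mod))
```

A lower probability of staying in place for every state is the sense in which the modified scheme improves on plain Gibbs under Peskun ordering.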
Example
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import DiscreteHMCGibbs, MCMC, NUTS
...
>>> def model(probs, locs):
...     c = numpyro.sample("c", dist.Categorical(probs))
...     numpyro.sample("x", dist.Normal(locs[c], 0.5))
...
>>> probs = jnp.array([0.15, 0.3, 0.3, 0.25])
>>> locs = jnp.array([-2, 0, 2, 4])
>>> kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=100000, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0), probs, locs)
>>> mcmc.print_summary()
>>> samples = mcmc.get_samples()["x"]
>>> assert abs(jnp.mean(samples) - 1.3) < 0.1
>>> assert abs(jnp.var(samples) - 4.36) < 0.5
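The reference values 1.3 and 4.36 in the assertions are the exact mean and variance of the Gaussian mixture that the model defines for x. They can be checked with the law of total variance, using the same probs, locs, and component scale as in the example:

```python
import numpy as np

probs = np.array([0.15, 0.3, 0.3, 0.25])
locs = np.array([-2.0, 0.0, 2.0, 4.0])
scale = 0.5  # standard deviation of each Normal component

# Mixture mean: weighted average of the component locations.
mean = np.sum(probs * locs)

# Mixture variance: within-component variance plus the variance
# of the component locations under probs.
var = scale**2 + np.sum(probs * locs**2) - mean**2

print(mean, var)
```

The tolerances in the example (0.1 and 0.5) leave room for Monte Carlo error in the sampled estimates.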
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)[source]¶
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
MixedHMC¶
- class MixedHMC(inner_kernel, *, num_discrete_updates=None, random_walk=False, modified=False)[source]¶
Bases:
numpyro.infer.hmc_gibbs.DiscreteHMCGibbs
Implementation of Mixed Hamiltonian Monte Carlo (reference [1]).
Note
The number of discrete sites to update at each MCMC iteration (n_D in reference [1]) is fixed at value 1.
References
Mixed Hamiltonian Monte Carlo for Mixed Discrete and Continuous Variables, Guangyao Zhou (2020)
Peskun’s theorem and a modified discrete-state Gibbs sampler, Liu, J. S. (1996)
- Parameters
inner_kernel – An HMC kernel.
num_discrete_updates (int) – Number of times to update discrete variables. Defaults to the number of discrete latent variables.
random_walk (bool) – If False, Gibbs sampling will be used to draw a sample from the conditional p(gibbs_site | remaining sites), where gibbs_site is one of the discrete sample sites in the model. Otherwise, a sample will be drawn uniformly from the domain of gibbs_site. Defaults to False.
modified (bool) – whether to use a modified proposal, as suggested in reference [2], which always proposes a new state for the current Gibbs site (i.e. discrete site). Defaults to False. The modified scheme appears in the literature under the name “modified Gibbs sampler” or “Metropolised Gibbs sampler”.
Example
>>> from jax import random >>> import jax.numpy as jnp >>> import numpyro >>> import numpyro.distributions as dist >>> from numpyro.infer import HMC, MCMC, MixedHMC ... >>> def model(probs, locs): ... c = numpyro.sample("c", dist