# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import namedtuple
from copy import deepcopy
from functools import partial
import jax
from jax import random
import jax.numpy as jnp
import jax.tree_util as jtu
import numpyro
import numpyro.distributions as dist
from numpyro.primitives import mutable as numpyro_mutable
# Public API of this module. For each supported backend (Flax Linen, Flax NNX,
# Equinox) there is a wrapper that registers the network's parameters with
# NumPyro, plus a ``random_*`` variant that places priors over them.
__all__ = [
    # flax.linen
    "flax_module",
    "random_flax_module",
    # flax.nnx
    "nnx_module",
    "random_nnx_module",
    # equinox
    "eqx_module",
    "random_eqx_module",
]
[docs]
def flax_module(
    name, nn_module, *args, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
    """
    Wrap a :mod:`~flax` (Linen) neural network so that its parameters are
    registered with NumPyro through :func:`~numpyro.primitives.param`
    statements and can therefore be optimized.

    In plain flax a module is evaluated with explicit parameters, i.e.
    ``nn_module.apply(params, x)``. Inside a NumPyro model the parameters are
    bound for you::

        net = flax_module("net", nn_module)
        y = net(x)

    When the module contains stochastic layers such as dropout, request the
    extra rng kinds up front and pass the keys at call time::

        net = flax_module("net", nn_module, apply_rng=["dropout"])
        rng_key = numpyro.prng_key()
        y = net(x, rngs={"dropout": rng_key})

    :param str name: name under which the module is registered. Parameters are
        stored at ``name + "$params"`` and mutable state at ``name + "$state"``.
    :param flax.linen.Module nn_module: a `flax` Module which has .init and .apply methods
    :param args: optional positional arguments used to initialize the network,
        as an alternative to `input_shape`
    :param tuple input_shape: shape of the input taken by the network. When
        given, ``jnp.ones(input_shape)`` is used as the dummy initialization
        input and `args` is ignored for initialization.
    :param list apply_rng: the extra rng _kinds_ ``nn_module`` needs, e.g.
        ``apply_rng=["dropout"]`` for dropout layers. Each kind receives its own
        subkey during initialization. Defaults to None (no extra rng keys). See
        `Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for how Flax handles stochastic layers like dropout.
    :param list mutable: the mutable variable collections of ``nn_module``,
        e.g. ``mutable=["batch_stats"]`` for a BatchNorm layer. The state is
        kept with NumPyro's ``mutable`` primitive and updated (with gradients
        stopped) after every call. See the `Flax Linen Intro` tutorial above.
    :param kwargs: optional keyword arguments used to initialize the network,
        as an alternative to `input_shape`
    :return: a callable with the parameters already bound that takes an array
        as input and returns the network's transformed output array.
    :raises ValueError: if initializing ``nn_module`` yields no ``"params"``
        collection.
    """
    try:
        import flax  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use flax to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `flax` to be able to use this feature. "
            "It can be installed with `pip install flax`."
        ) from e

    params_site = name + "$params"
    state_site = name + "$state"

    nn_params = numpyro.param(params_site)
    if mutable:
        nn_state = numpyro_mutable(state_site)
        assert nn_state is None or isinstance(nn_state, dict)
        # parameters and state are always created together
        assert (nn_state is None) == (nn_params is None)

    if nn_params is None:
        # First encounter: run ``nn_module.init`` on dummy data.
        init_args = args if input_shape is None else (jnp.ones(input_shape),)

        key = numpyro.prng_key()
        if key is None:
            key = random.key(0)

        # One fresh subkey per requested rng kind; the remainder seeds "params".
        rngs = {}
        if apply_rng:
            assert isinstance(apply_rng, list)
            for rng_kind in apply_rng:
                key, rngs[rng_kind] = random.split(key)
        rngs["params"] = key

        nn_vars = flax.core.unfreeze(nn_module.init(rngs, *init_args, **kwargs))
        if "params" not in nn_vars:
            raise ValueError(
                "Your nn_module does not have any parameter. Currently, it is not"
                " supported in NumPyro. Please make a github issue if you need"
                " that feature."
            )
        nn_params = nn_vars["params"]

        if mutable:
            nn_state = {
                collection: value
                for collection, value in nn_vars.items()
                if collection != "params"
            }
            assert set(mutable) == set(nn_state)
            numpyro_mutable(state_site, nn_state)

        # Round-trip through flatten/unflatten so the parameter dict has the
        # canonical key order it will have after any later unflatten.
        leaves, treedef = jax.tree.flatten(nn_params)
        nn_params = jax.tree.unflatten(treedef, leaves)
        numpyro.param(params_site, nn_params)

    if mutable:

        def apply_with_state(params, *args, **kwargs):
            variables = {"params": params, **nn_state}
            out, new_state = nn_module.apply(variables, mutable=mutable, *args, **kwargs)
            # state is updated in place, but never differentiated through
            nn_state.update(**jax.lax.stop_gradient(new_state))
            return out

        return partial(apply_with_state, nn_params)

    def apply_without_state(params, *args, **kwargs):
        return nn_module.apply({"params": params}, *args, **kwargs)

    return partial(apply_without_state, nn_params)
# ``ParamShape`` is an "empty" stand-in for a parameter that only records the
# shape of the value it replaced. As a pytree it has no leaves, so an optimizer
# walking the parameter tree has nothing to update, while the shape remains
# available for building priors.
ParamShape = namedtuple("ParamShape", ["shape"])


def _flatten_param_shape(param_shape):
    # No array children: the shape is carried as static auxiliary data.
    return (None,), param_shape.shape


def _unflatten_param_shape(shape, _children):
    # Rebuild from the auxiliary data alone; the (empty) children are ignored.
    return ParamShape(shape)


jtu.register_pytree_node(ParamShape, _flatten_param_shape, _unflatten_param_shape)
def _update_params(params, new_params, prior, prefix=""):
"""
A helper to recursively set prior to new_params.
"""
for name, item in params.items():
flatten_name = ".".join([str(prefix), str(name)]) if prefix else str(name)
if isinstance(item, dict):
assert not isinstance(prior, dict) or flatten_name not in prior
new_item = new_params[name]
_update_params(item, new_item, prior, prefix=flatten_name)
elif (not isinstance(prior, dict)) or flatten_name in prior:
if isinstance(params[name], ParamShape):
param_shape = params[name].shape
else:
param_shape = jnp.shape(params[name])
params[name] = ParamShape(param_shape)
if isinstance(prior, dict):
d = prior[flatten_name]
elif callable(prior) and not isinstance(prior, dist.Distribution):
d = prior(flatten_name, param_shape)
else:
d = prior
param_batch_shape = param_shape[: len(param_shape) - d.event_dim]
# XXX: here we set all dimensions of prior to event dimensions.
new_params[name] = numpyro.sample(
flatten_name, d.expand(param_batch_shape).to_event()
)
[docs]
def random_flax_module(
    name,
    nn_module,
    prior,
    *args,
    input_shape=None,
    apply_rng=None,
    mutable=None,
    **kwargs,
):
    """
    A primitive to place a prior over the parameters of the Flax module `nn_module`.

    .. note::
        Parameters of a Flax module are stored in a nested dict. For example,
        the module `B` defined as follows::

            class A(flax.linen.Module):
                @flax.linen.compact
                def __call__(self, x):
                    return flax.linen.Dense(1, use_bias=False, name='dense')(x)

            class B(flax.linen.Module):
                @flax.linen.compact
                def __call__(self, x):
                    return A(name='inner')(x)

        has parameters `{'inner': {'dense': {'kernel': param_value}}}`. In the argument
        `prior`, to specify `kernel` parameter, we join the path to it using dots:
        `prior={"inner.dense.kernel": param_prior}`.

    :param str name: name of NumPyro module
    :param flax.linen.Module nn_module: the module to be registered with NumPyro
    :param prior: a NumPyro distribution or a Python dict with parameter names as keys and
        respective distributions as values. For example::

            net = random_flax_module("net",
                                     flax.linen.Dense(features=1),
                                     prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
                                     input_shape=(4,))

        Alternatively, we can use a callable. For example the following are equivalent::

            prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal())
            prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}

        Parameters whose dotted name is missing from a dict prior are not sampled.
    :type prior: dict, ~numpyro.distributions.Distribution or callable
    :param args: optional arguments to initialize flax neural network
        as an alternative to `input_shape`
    :param tuple input_shape: shape of the input taken by the neural network.
    :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
        ``nn_module``. For example, when ``nn_module`` includes dropout layers, we
        need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra
        rng key is needed. Please see
        `Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for more information in how Flax deals with stochastic layers like dropout.
    :param list mutable: A list to indicate mutable states of ``nn_module``. For example,
        if your module has BatchNorm layer, we will need to define ``mutable=["batch_stats"]``.
        See the above `Flax Linen Intro` tutorial for more information.
    :param kwargs: optional keyword arguments to initialize flax neural network
        as an alternative to `input_shape`
    :returns: a sampled module

    **Example**

    .. doctest::

        # NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb
        >>> import numpy as np; np.random.seed(0)
        >>> import tqdm
        >>> from flax import linen as nn
        >>> from jax import jit, random
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.contrib.module import random_flax_module
        >>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible
        ...
        >>> class Net(nn.Module):
        ...     n_units: int
        ...
        ...     @nn.compact
        ...     def __call__(self, x):
        ...         x = nn.Dense(self.n_units)(x[..., None])
        ...         x = nn.relu(x)
        ...         x = nn.Dense(self.n_units)(x)
        ...         x = nn.relu(x)
        ...         mean = nn.Dense(1)(x)
        ...         rho = nn.Dense(1)(x)
        ...         return mean.squeeze(), rho.squeeze()
        ...
        >>> def generate_data(n_samples):
        ...     x = np.random.normal(size=n_samples)
        ...     y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2
        ...     return x, y
        ...
        >>> def model(x, y=None, batch_size=None):
        ...     module = Net(n_units=32)
        ...     net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
        ...     with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
        ...         batch_x = numpyro.subsample(x, event_dim=0)
        ...         batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
        ...         mean, rho = net(batch_x)
        ...         sigma = nn.softplus(rho)
        ...         numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)
        ...
        >>> n_train_data = 5000
        >>> x_train, y_train = generate_data(n_train_data)
        >>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
        >>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())
        >>> n_iterations = 3000
        >>> svi_result = svi.run(random.key(0), n_iterations, x_train, y_train, batch_size=256)
        >>> params, losses = svi_result.params, svi_result.losses
        >>> n_test_data = 100
        >>> x_test, y_test = generate_data(n_test_data)
        >>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
        >>> y_pred = predictive(random.key(1), x_test[:100])["obs"].copy()
        >>> assert losses[-1] < 3000
        >>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1
    """
    nn = flax_module(
        name,
        nn_module,
        *args,
        input_shape=input_shape,
        apply_rng=apply_rng,
        mutable=mutable,
        **kwargs,
    )
    # ``flax_module`` returns a functools.partial whose first bound argument
    # is the registered parameter dict.
    params = nn.args[0]
    new_params = deepcopy(params)
    with numpyro.handlers.scope(prefix=name):
        # Swaps the prior-covered leaves of ``params`` for ParamShape
        # placeholders in place, and fills ``new_params`` with sample sites.
        _update_params(params, new_params, prior)
    # Rebind the same apply function to the sampled parameters.
    nn_new = partial(nn.func, new_params, *nn.args[1:], **nn.keywords)
    return nn_new
[docs]
def nnx_module(name, nn_module):
    """
    Wrap an already-constructed :mod:`~flax.nnx` module so that its parameters
    are registered with NumPyro through :func:`~numpyro.primitives.param`
    statements and can therefore be optimized.

    NNX modules are created eagerly and called directly, so the usage inside a
    NumPyro model is::

        # Eager initialization outside the model
        module = nn_module(...)

        # Inside the model
        net = nnx_module("net", module)
        y = net(x)

    The module is split into ``nnx.Param`` variables and everything else.
    The params are registered at ``name + "$params"``, but only if the module
    has any. The remaining variables, if any, are held as mutable state at
    ``name + "$state"`` and refreshed (with gradients stopped) after each call.

    :param str name: name of the module to be registered.
    :param flax.nnx.Module nn_module: a pre-initialized `flax nnx` Module instance.
    :return: a callable that takes an array as an input and returns
        the neural network transformed output array.
    """
    try:
        from flax import nnx
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use flax.nnx to declare "
            "nn modules. This is an experimental feature. "
            "You need to install the latest version of `flax` to use this feature. "
            "It can be installed with `pip install git+https://github.com/google/flax.git`."
        ) from e

    graph_def, eager_params_state, eager_other_state = nnx.split(
        nn_module, nnx.Param, nnx.Not(nnx.Param)
    )

    params_site = name + "$params"
    eager_params_dict = nnx.to_pure_dict(eager_params_state)
    module_params = numpyro.param(params_site) if eager_params_state else None
    if eager_params_state and module_params is None:
        # not registered yet: seed the param store with the eager values
        module_params = numpyro.param(params_site, eager_params_dict)

    eager_other_dict = nnx.to_pure_dict(eager_other_state)
    mutable_holder = (
        numpyro_mutable(name + "$state", {"state": eager_other_dict})
        if eager_other_dict
        else None
    )

    def apply_fn(params, *call_args, **call_kwargs):
        # Load the current parameter / state values into the split states,
        # then rebuild a (copied) module from them.
        if params:
            nnx.replace_by_pure_dict(eager_params_state, params)
        if mutable_holder:
            nnx.replace_by_pure_dict(eager_other_state, mutable_holder["state"])
        model = nnx.merge(graph_def, eager_params_state, eager_other_state, copy=True)

        output = model(*call_args, **call_kwargs)

        if mutable_holder:
            # Persist any state the forward pass changed, without gradients.
            _, _, updated_state = nnx.split(model, nnx.Param, nnx.Not(nnx.Param))
            mutable_holder["state"] = nnx.to_pure_dict(
                jax.lax.stop_gradient(updated_state)
            )
        return output

    return partial(apply_fn, module_params)
[docs]
def random_nnx_module(
    name,
    nn_module,
    prior,
    scope_divider="/",
):
    """
    Build a :mod:`~flax.nnx` network whose parameters are random variables drawn
    from ``prior``. The result is suitable for use in MCMC samplers.

    :param str name: name of the module to be registered.
    :param flax.nnx.Module nn_module: a pre-initialized `flax nnx` Module instance.
    :param prior: a single distribution used for every parameter, a dict mapping
        parameter names to distributions, or a callable that receives a parameter
        name and shape and returns a distribution. For example::

            class Linear(nnx.Module):
                def __init__(self, din, dout, *, rngs):
                    self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
                    self.b = nnx.Param(jnp.zeros((dout,)))

                def __call__(self, x):
                    return x @ self.w + self.b

            # Eager initialization
            linear = Linear(din=4, dout=1, rngs=nnx.Rngs(params=random.key(0)))
            net = random_nnx_module("net", linear, prior={"w": dist.Normal(), "b": dist.Cauchy()})

        Alternatively, we can use a callable. For example the following are equivalent::

            prior=(lambda name, shape: dist.Cauchy() if name.endswith("b") else dist.Normal())
            prior={"w": dist.Normal(), "b": dist.Cauchy()}

    :param str scope_divider: separator placed between ``name`` and each parameter
        name by the scope handler that names the sample sites. Defaults to "/".
    :return: a callable that takes an array as an input and returns
        the neural network transformed output array.
    """
    wrapped = nnx_module(name, nn_module)
    # the first bound argument of the returned partial is the parameter dict
    registered_params = wrapped.args[0]
    sampled_params = deepcopy(registered_params)
    with numpyro.handlers.scope(prefix=name, divider=scope_divider):
        _update_params(registered_params, sampled_params, prior)
    return partial(wrapped.func, sampled_params, *wrapped.args[1:], **wrapped.keywords)
[docs]
def eqx_module(name, nn_module):
    """
    Declare an :mod:`equinox` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    Given an equinox ``nn_module``, to evaluate the module, we directly call it.
    In a NumPyro model, the pattern will be::

        # Eager initialization outside the model
        module = nn_module(...)

        # Inside the model
        net = eqx_module("net", module)
        y = jax.vmap(net)(x)

    In the case of stateful computation, the pattern is the following::

        # Eager initialization outside the model
        module, eager_state = eqx.nn.make_with_state(nn_module)(...)

        # Inside the model
        net = eqx_module("net", module)
        mutable_holder = numpyro_mutable("net$state", {"state": eager_state})
        batched_net = jax.vmap(net, in_axes=(0,None), out_axes=(0,None), axis_name='batch')
        y, new_state = batched_net(x, mutable_holder['state'])
        mutable_holder['state'] = new_state

    :param str name: name of the module to be registered; the inexact-array
        leaves are stored as a single param at ``name + "$params"``.
    :param eqx.Module nn_module: a pre-initialized `equinox` Module instance.
    :return: the equinox module rebuilt from the registered parameters and the
        original non-parameter (static) part. It is called like the original
        module.
    """
    try:
        import equinox as eqx
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use equinox to declare "
            "nn modules. This is an experimental feature. "
            "You need to install the latest version of `equinox` to use this feature. "
            "It can be installed with `pip install git+https://github.com/patrick-kidger/equinox.git`."
        ) from e

    # Floating-point arrays become parameters; everything else stays static.
    eager_params, static = eqx.partition(nn_module, filter_spec=eqx.is_inexact_array)
    # ``numpyro.param`` treats a callable ``init_value`` as a lazy initializer
    # and calls it with a PRNG key. A partitioned equinox module can itself be
    # callable (when its class defines ``__call__``), so it is wrapped in a
    # lambda that ignores the key. The lambda closes over ``eager_params``,
    # which is never rebound, instead of the name that receives the result.
    params = numpyro.param(name + "$params", lambda _rng_key: eager_params)
    return eqx.combine(params, static)
[docs]
def random_eqx_module(name, nn_module, prior, scope_divider="/"):
    """
    A primitive to create a random :mod:`equinox` style neural network
    which can be used in MCMC samplers. The parameters of the neural network
    will be sampled from ``prior``.

    For supplying a prior dictionary, the dictionary keys are based on their jax key path.
    To see the jax key paths for all of the leaves in your pytree model, you can run::

        key_paths = [jtu.keystr(path)[1:] for path, _ in jtu.tree_leaves_with_path(model_instance)]

    :param str name: name of the module to be registered.
    :param eqx.Module nn_module: a pre-initialized `equinox` Module instance.
    :param prior: a distribution or a dict of distributions or a callable.
        If it is a distribution, all parameters will be sampled from the same
        distribution. If it is a dict, it maps parameter names to distributions.
        If it is a callable, it takes parameter name and parameter shape as
        inputs and returns a distribution. For example::

            class Linear(eqx.Module):
                weight: jax.Array
                bias: jax.Array

                def __init__(self, in_size, out_size, key):
                    wkey, bkey = jax.random.split(key)
                    self.weight = jax.random.normal(wkey, (out_size, in_size))
                    self.bias = jax.random.normal(bkey, (out_size,))

                def __call__(self, x):
                    return self.weight @ x + self.bias

            # Eager initialization
            linear = Linear(in_size=3, out_size=1, key=random.key(0))
            nn_priors = {"weight": dist.Normal(), "bias": dist.Cauchy()}
            net = random_eqx_module("net", linear, prior=nn_priors)

        Alternatively, we can use a callable. For example the following are equivalent::

            prior=(lambda name, shape: dist.Cauchy() if name == 'bias' else dist.Normal())
            prior={"weight": dist.Normal(), "bias": dist.Cauchy()}

    :param str scope_divider: separator placed between ``name`` and each parameter's
        key path by the scope handler that names the sample sites. Defaults to "/".
    :return: the equinox module rebuilt with the sampled parameters (and the
        original static part). It is called like the original module.
    """
    try:
        import equinox as eqx
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use equinox to declare "
            "nn modules. This is an experimental feature. "
            "You need to install the latest version of `equinox` to use this feature. "
            "It can be installed with `pip install git+https://github.com/patrick-kidger/equinox.git`."
        ) from e

    nn = eqx_module(name, nn_module)
    params, static = eqx.partition(nn, filter_spec=eqx.is_inexact_array)
    # Flat {key-path: array} view, so prior names match the key paths above.
    params_dict = eqx_to_dict(params)
    new_params = deepcopy(params_dict)
    with numpyro.handlers.scope(prefix=name, divider=scope_divider):
        # NOTE(review): unlike the flax/nnx variants, ``params_dict`` is a fresh
        # dict here, so the ParamShape placeholders written by _update_params do
        # not reach the value registered by eqx_module — confirm that is intended.
        _update_params(params_dict, new_params, prior)
    return eqx.combine(eqx_from_dict(new_params, tree=params), static)
def eqx_to_dict(tree):
    """
    Flatten a pytree into a flat dict keyed by each leaf's JAX key path.

    Each key is ``jtu.keystr(path)`` with its first character removed (for an
    attribute path that is the leading ``.``). Entries appear in the pytree's
    flattening order. ``eqx_from_dict`` performs the inverse mapping.
    """
    paths_and_leaves, _ = jtu.tree_flatten_with_path(tree)
    return {jtu.keystr(path)[1:]: leaf for path, leaf in paths_and_leaves}
def eqx_from_dict(data: dict, tree):
    """
    Rebuild ``tree`` with each leaf replaced by the entry of ``data`` keyed by
    that leaf's key path (``jtu.keystr(path)[1:]``, as produced by
    ``eqx_to_dict``). This is the inverse of ``eqx_to_dict``.

    Raises ``KeyError`` if a leaf's key is missing from ``data``.
    """
    paths_and_leaves, treedef = jtu.tree_flatten_with_path(tree)
    replacements = [data[jtu.keystr(path)[1:]] for path, _ in paths_and_leaves]
    return jtu.tree_unflatten(treedef, replacements)