# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from copy import deepcopy
from functools import partial

from jax import random
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten

import numpyro
import numpyro.distributions as dist
from numpyro.primitives import mutable as numpyro_mutable

__all__ = [
"flax_module",
"haiku_module",
"random_flax_module",
"random_haiku_module",
]


def flax_module(
    name, nn_module, *, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
"""
Declare a :mod:`~flax` style neural network inside a
model so that its parameters are registered for optimization via
:func:`~numpyro.primitives.param` statements.
    Given a flax ``nn_module``, to evaluate the module with a given set of
    parameters we use ``nn_module.apply(params, x)``. In a NumPyro model,
    the pattern will be::

        net = flax_module("net", nn_module)
        y = net(x)

    or with dropout layers::

        net = flax_module("net", nn_module, apply_rng=["dropout"])
        rng_key = numpyro.prng_key()
        y = net(x, rngs={"dropout": rng_key})

:param str name: name of the module to be registered.
:param flax.linen.Module nn_module: a `flax` Module which has .init and .apply methods
:param tuple input_shape: shape of the input taken by the
neural network.
:param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
``nn_module``. For example, when ``nn_module`` includes dropout layers, we
need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra
rng key is needed. Please see
`Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for more information on how Flax deals with stochastic layers like dropout.
    :param list mutable: A list to indicate mutable states of ``nn_module``. For example,
        if your module has a BatchNorm layer, you will need to set
        ``mutable=["batch_stats"]``. See the above `Flax Linen Intro` tutorial
        for more information.
    :param kwargs: optional keyword arguments to initialize the flax neural
        network as an alternative to `input_shape`
:return: a callable with bound parameters that takes an array
as an input and returns the neural network transformed output
array.
"""
try:
import flax # noqa: F401
except ImportError as e:
raise ImportError(
"Looking like you want to use flax to declare "
"nn modules. This is an experimental feature. "
"You need to install `flax` to be able to use this feature. "
"It can be installed with `pip install flax`."
) from e
module_key = name + "$params"
nn_params = numpyro.param(module_key)
if mutable:
nn_state = numpyro_mutable(name + "$state")
assert nn_state is None or isinstance(nn_state, dict)
assert (nn_state is None) == (nn_params is None)
if nn_params is None:
# feed in dummy data to init params
args = (jnp.ones(input_shape),) if input_shape is not None else ()
rng_key = numpyro.prng_key()
# split rng_key into a dict of rng_kind: rng_key
rngs = {}
if apply_rng:
assert isinstance(apply_rng, list)
for kind in apply_rng:
rng_key, subkey = random.split(rng_key)
rngs[kind] = subkey
rngs["params"] = rng_key
nn_vars = flax.core.unfreeze(nn_module.init(rngs, *args, **kwargs))
if "params" not in nn_vars:
raise ValueError(
"Your nn_module does not have any parameter. Currently, it is not"
" supported in NumPyro. Please make a github issue if you need"
" that feature."
)
nn_params = nn_vars["params"]
if mutable:
nn_state = {k: v for k, v in nn_vars.items() if k != "params"}
assert set(mutable) == set(nn_state)
numpyro_mutable(name + "$state", nn_state)
# make sure that nn_params keep the same order after unflatten
params_flat, tree_def = tree_flatten(nn_params)
nn_params = tree_unflatten(tree_def, params_flat)
numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        params = {"params": params, **nn_state}
        out, new_state = nn_module.apply(params, *args, mutable=mutable, **kwargs)
        nn_state.update(**new_state)
        return out

    def apply_without_state(params, *args, **kwargs):
        return nn_module.apply({"params": params}, *args, **kwargs)

    apply_fn = apply_with_state if mutable else apply_without_state
    return partial(apply_fn, nn_params)
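

# A minimal usage sketch for ``flax_module`` (illustrative only; the model
# name ``_example_flax_model`` is hypothetical and not part of NumPyro's API).
# The Dense layer's parameters are registered under "net$params" and are
# optimized like any other ``numpyro.param`` site during SVI.
def _example_flax_model(x, y=None):
    import flax.linen as linen

    net = flax_module("net", linen.Dense(features=1), input_shape=(3,))
    mean = net(x).squeeze(-1)
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)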


def haiku_module(name, nn_module, *, input_shape=None, apply_rng=False, **kwargs):
"""
Declare a :mod:`~haiku` style neural network inside a
model so that its parameters are registered for optimization via
:func:`~numpyro.primitives.param` statements.
    Given a haiku ``nn_module``, to evaluate the module with a given set of
    parameters we use ``nn_module.apply(params, None, x)``. In a NumPyro
    model, the pattern will be::

        net = haiku_module("net", nn_module)
        y = net(x)  # or y = net(rng_key, x)

    or with dropout layers::

        net = haiku_module("net", nn_module, apply_rng=True)
        rng_key = numpyro.prng_key()
        y = net(rng_key, x)

:param str name: name of the module to be registered.
:param nn_module: a `haiku` Module which has .init and .apply methods
:type nn_module: haiku.Transformed or haiku.TransformedWithState
:param tuple input_shape: shape of the input taken by the
neural network.
:param bool apply_rng: A flag to indicate if the returned callable requires
an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults
to False, which means no rng argument is needed. If this is True, the signature
of the returned callable ``nn = haiku_module(..., apply_rng=True)`` will be
``nn(rng_key, x)`` (rather than ``nn(x)``).
    :param kwargs: optional keyword arguments to initialize the haiku neural
        network as an alternative to `input_shape`
:return: a callable with bound parameters that takes an array
as an input and returns the neural network transformed output
array.
"""
try:
import haiku as hk # noqa: F401
except ImportError as e:
raise ImportError(
"Looking like you want to use haiku to declare "
"nn modules. This is an experimental feature. "
"You need to install `haiku` to be able to use this feature. "
"It can be installed with `pip install dm-haiku`."
) from e
if not apply_rng:
nn_module = hk.without_apply_rng(nn_module)
module_key = name + "$params"
nn_params = numpyro.param(module_key)
with_state = isinstance(nn_module, hk.TransformedWithState)
if with_state:
nn_state = numpyro_mutable(name + "$state")
assert nn_state is None or isinstance(nn_state, dict)
assert (nn_state is None) == (nn_params is None)
if nn_params is None:
args = (jnp.ones(input_shape),) if input_shape is not None else ()
# feed in dummy data to init params
rng_key = numpyro.prng_key()
if with_state:
nn_params, nn_state = nn_module.init(rng_key, *args, **kwargs)
nn_state = dict(nn_state)
numpyro_mutable(name + "$state", nn_state)
else:
nn_params = nn_module.init(rng_key, *args, **kwargs)
            # haiku init returns an immutable dict; cast it to a mutable one
            # so that we are able to set priors for parameters
            nn_params = hk.data_structures.to_mutable_dict(nn_params)
# make sure that nn_params keep the same order after unflatten
params_flat, tree_def = tree_flatten(nn_params)
nn_params = tree_unflatten(tree_def, params_flat)
numpyro.param(module_key, nn_params)

    def apply_with_state(params, *args, **kwargs):
        out, new_state = nn_module.apply(params, nn_state, *args, **kwargs)
        nn_state.update(**new_state)
        return out

    apply_fn = apply_with_state if with_state else nn_module.apply
    return partial(apply_fn, nn_params)
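

# A minimal usage sketch for ``haiku_module`` (illustrative only; the model
# name ``_example_haiku_model`` is hypothetical). A transformed one-layer
# network is bound to parameters registered under "net$params", mirroring
# the flax variant above.
def _example_haiku_model(x, y=None):
    import haiku as hk

    nn = hk.transform(lambda a: hk.Linear(1)(a))
    net = haiku_module("net", nn, input_shape=(3,))
    mean = net(x).squeeze(-1)
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)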


# register an "empty" parameter which only stores its shape
# so that the optimizer can skip optimizing this parameter, while
# it still provides shape information for priors
ParamShape = namedtuple("ParamShape", ["shape"])
register_pytree_node(
    ParamShape, lambda x: ((None,), x.shape), lambda shape, x: ParamShape(shape)
)
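

# A quick sanity-check sketch of the ``ParamShape`` registration above
# (``_example_param_shape_roundtrip`` is hypothetical): flattening produces
# no array leaves, so optimizers have nothing to update, while the shape is
# stored in the treedef and recovered on unflatten.
def _example_param_shape_roundtrip():
    leaves, tree_def = tree_flatten(ParamShape((2, 3)))
    assert leaves == []  # the single ``None`` child contributes no leaves
    restored = tree_unflatten(tree_def, leaves)
    assert restored.shape == (2, 3)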


def _update_params(params, new_params, prior, prefix=""):
    """
    A helper to recursively set priors for the parameters in ``new_params``.
    """
for name, item in params.items():
flatten_name = ".".join([prefix, name]) if prefix else name
if isinstance(item, dict):
assert not isinstance(prior, dict) or flatten_name not in prior
new_item = new_params[name]
_update_params(item, new_item, prior, prefix=flatten_name)
elif (not isinstance(prior, dict)) or flatten_name in prior:
if isinstance(params[name], ParamShape):
param_shape = params[name].shape
else:
param_shape = jnp.shape(params[name])
params[name] = ParamShape(param_shape)
if isinstance(prior, dict):
d = prior[flatten_name]
elif callable(prior) and not isinstance(prior, dist.Distribution):
d = prior(flatten_name, param_shape)
else:
d = prior
param_batch_shape = param_shape[: len(param_shape) - d.event_dim]
# XXX: here we set all dimensions of prior to event dimensions.
new_params[name] = numpyro.sample(
flatten_name, d.expand(param_batch_shape).to_event()
)
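

# A minimal sketch of what ``_update_params`` does to a nested params dict
# (``_example_update_params`` is hypothetical; the seed handler stands in
# for a real inference context): leaves of ``params`` become ``ParamShape``
# markers, while the matching leaves of ``new_params`` are replaced by
# samples from the prior at sites named by their dotted paths.
def _example_update_params():
    params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros(3)}}
    new_params = deepcopy(params)
    with numpyro.handlers.seed(rng_seed=0):
        _update_params(params, new_params, dist.Normal())
    assert params["dense"]["kernel"] == ParamShape((2, 3))
    assert jnp.shape(new_params["dense"]["kernel"]) == (2, 3)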


def random_flax_module(
    name, nn_module, prior, *, input_shape=None, apply_rng=None, mutable=None, **kwargs
):
"""
A primitive to place a prior over the parameters of the Flax module `nn_module`.
    .. note::
        Parameters of a Flax module are stored in a nested dict. For example,
        the module `B` defined as follows::

            class A(flax.linen.Module):
                @flax.linen.compact
                def __call__(self, x):
                    return nn.Dense(1, use_bias=False, name='dense')(x)

            class B(flax.linen.Module):
                @flax.linen.compact
                def __call__(self, x):
                    return A(name='inner')(x)

        has parameters `{'inner': {'dense': {'kernel': param_value}}}`. In the argument
        `prior`, to specify the `kernel` parameter, we join the path to it using dots:
        `prior={"inner.dense.kernel": param_prior}`.

:param str name: name of NumPyro module
    :param flax.linen.Module nn_module: the module to be registered with NumPyro
    :param prior: a NumPyro distribution or a Python dict with parameter names as keys
        and respective distributions as values. For example::

            net = random_flax_module("net",
                                     flax.linen.Dense(features=1),
                                     prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
                                     input_shape=(4,))

        Alternatively, we can use a callable. For example, the following are equivalent::

            prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal())
            prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}

:type prior: dict, ~numpyro.distributions.Distribution or callable
:param tuple input_shape: shape of the input taken by the neural network.
:param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
``nn_module``. For example, when ``nn_module`` includes dropout layers, we
need to set ``apply_rng=["dropout"]``. Defaults to None, which means no extra
rng key is needed. Please see
`Flax Linen Intro <https://flax.readthedocs.io/en/latest/notebooks/linen_intro.html#Invoking-Modules>`_
        for more information on how Flax deals with stochastic layers like dropout.
    :param list mutable: A list to indicate mutable states of ``nn_module``. For example,
        if your module has a BatchNorm layer, you will need to set
        ``mutable=["batch_stats"]``. See the above `Flax Linen Intro` tutorial
        for more information.
    :param kwargs: optional keyword arguments to initialize the flax neural
        network as an alternative to `input_shape`
    :returns: a sampled module

    **Example**

    .. doctest::

        # NB: this example is ported from https://github.com/ctallec/pyvarinf/blob/master/main_regression.ipynb
>>> import numpy as np; np.random.seed(0)
>>> import tqdm
>>> from flax import linen as nn
>>> from jax import jit, random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.module import random_flax_module
>>> from numpyro.infer import Predictive, SVI, TraceMeanField_ELBO, autoguide, init_to_feasible
...
>>> class Net(nn.Module):
... n_units: int
...
... @nn.compact
... def __call__(self, x):
... x = nn.Dense(self.n_units)(x[..., None])
... x = nn.relu(x)
... x = nn.Dense(self.n_units)(x)
... x = nn.relu(x)
... mean = nn.Dense(1)(x)
... rho = nn.Dense(1)(x)
... return mean.squeeze(), rho.squeeze()
...
>>> def generate_data(n_samples):
... x = np.random.normal(size=n_samples)
... y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2
... return x, y
...
>>> def model(x, y=None, batch_size=None):
... module = Net(n_units=32)
... net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
... with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
... batch_x = numpyro.subsample(x, event_dim=0)
... batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
... mean, rho = net(batch_x)
... sigma = nn.softplus(rho)
... numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)
...
>>> n_train_data = 5000
>>> x_train, y_train = generate_data(n_train_data)
>>> guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
>>> svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())
>>> n_iterations = 3000
>>> svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256)
>>> params, losses = svi_result.params, svi_result.losses
>>> n_test_data = 100
>>> x_test, y_test = generate_data(n_test_data)
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
>>> y_pred = predictive(random.PRNGKey(1), x_test[:100])["obs"].copy()
>>> assert losses[-1] < 3000
>>> assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1
"""
nn = flax_module(
name,
nn_module,
input_shape=input_shape,
apply_rng=apply_rng,
mutable=mutable,
**kwargs
)
params = nn.args[0]
new_params = deepcopy(params)
with numpyro.handlers.scope(prefix=name):
_update_params(params, new_params, prior)
nn_new = partial(nn.func, new_params, *nn.args[1:], **nn.keywords)
return nn_new


def random_haiku_module(
    name, nn_module, prior, *, input_shape=None, apply_rng=False, **kwargs
):
"""
A primitive to place a prior over the parameters of the Haiku module `nn_module`.
:param str name: name of NumPyro module
:param nn_module: the module to be registered with NumPyro
:type nn_module: haiku.Transformed or haiku.TransformedWithState
    :param prior: a NumPyro distribution or a Python dict with parameter names as keys
        and respective distributions as values. For example::

            net = random_haiku_module("net",
                                      haiku.transform(lambda x: hk.Linear(1)(x)),
                                      prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()},
                                      input_shape=(4,))

        Alternatively, we can use a callable. For example, the following are equivalent::

            prior=(lambda name, shape: dist.Cauchy() if name.endswith(".b") else dist.Normal())
            prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}

:type prior: dict, ~numpyro.distributions.Distribution or callable
:param tuple input_shape: shape of the input taken by the neural network.
    :param bool apply_rng: A flag to indicate if the returned callable requires
        an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults
        to False, which means no rng argument is needed. If this is True, the signature
        of the returned callable ``nn = random_haiku_module(..., apply_rng=True)`` will be
        ``nn(rng_key, x)`` (rather than ``nn(x)``).
    :param kwargs: optional keyword arguments to initialize the haiku neural
        network as an alternative to `input_shape`
:returns: a sampled module
"""
nn = haiku_module(
name, nn_module, input_shape=input_shape, apply_rng=apply_rng, **kwargs
)
params = nn.args[0]
new_params = deepcopy(params)
with numpyro.handlers.scope(prefix=name):
_update_params(params, new_params, prior)
nn_new = partial(nn.func, new_params, *nn.args[1:], **nn.keywords)
return nn_new
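

# A minimal usage sketch for ``random_haiku_module`` (illustrative only;
# ``_example_random_haiku_model`` is hypothetical): every weight of the
# one-layer network receives a Normal(0, 1) prior, turning ``net`` into a
# small Bayesian neural network.
def _example_random_haiku_model(x, y=None):
    import haiku as hk

    nn = hk.transform(lambda a: hk.Linear(1)(a))
    net = random_haiku_module("net", nn, prior=dist.Normal(), input_shape=(3,))
    mean = net(x).squeeze(-1)
    numpyro.sample("obs", dist.Normal(mean, 1.0), obs=y)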