# Source code for numpyro.optim

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Optimizer classes defined here are light wrappers over the corresponding optimizers
sourced from :mod:`jax.example_libraries.optimizers` with an interface that is better
suited for working with NumPyro inference algorithms.
"""

from collections import namedtuple
from typing import Any, Callable, Tuple, TypeVar

import jax
from jax import lax, value_and_grad

from numpyro.util import _versiontuple

# Pick the module that provides the JAX reference optimizers: JAX >= 0.2.25
# (per the check below) exposes ``jax.example_libraries.optimizers``; older
# releases only provide ``jax.experimental.optimizers``. Exactly one branch
# must run, otherwise the fallback import would shadow (or break) the first.
if _versiontuple(jax.__version__) >= (0, 2, 25):
    from jax.example_libraries import optimizers
else:
    from jax.experimental import optimizers  # pytype: disable=import-error

from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.tree_util import register_pytree_node, tree_map

# Public API of this module: the optimizer wrappers and the Optax adapter.
__all__ = [
    "Adam",
    "Adagrad",
    "ClippedAdam",
    "Minimize",
    "Momentum",
    "RMSProp",
    "RMSPropMomentum",
    "SGD",
    "SM3",
    "optax_to_numpyro",
]

# Generic placeholders for the parameter pytree and the wrapped optimizer's
# own state object.
_Params, _OptState = TypeVar("_Params"), TypeVar("_OptState")
# NumPyro optimizer state: an (iteration counter, wrapped optimizer state) pair.
_IterOptState = Tuple[int, _OptState]

class _NumPyroOptim(object):
    """
    Base class adapting an ``(init_fn, update_fn, get_params_fn)`` optimizer
    triple to the interface used by NumPyro inference algorithms.

    The state handled by this class is a pair ``(i, opt_state)`` where ``i`` is
    the iteration counter and ``opt_state`` is the state produced by the wrapped
    ``init_fn`` / ``update_fn``.

    :param optim_fn: a factory returning ``(init_fn, update_fn, get_params_fn)``,
        e.g. one of the functions in :mod:`jax.example_libraries.optimizers`.
    :param args: positional arguments forwarded to ``optim_fn``.
    :param kwargs: keyword arguments forwarded to ``optim_fn``.
    """

    def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
        self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)

    def init(self, params: _Params) -> _IterOptState:
        """
        Initialize the optimizer with parameters designated to be optimized.

        :param params: a collection of numpy arrays.
        :return: initial optimizer state.
        """
        opt_state = self.init_fn(params)
        # the iteration counter starts at 0
        return jnp.array(0), opt_state

    def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
        """
        Gradient update for the optimizer.

        :param g: gradient information for parameters.
        :param state: current optimizer state.
        :return: new optimizer state after the update.
        """
        i, opt_state = state
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state

    def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState):
        """
        Performs an optimization step for the objective function `fn`.
        For most optimizers, the update is performed based on the gradient
        of the objective function w.r.t. the current state. However, for
        some optimizers such as :class:`Minimize`, the update is performed
        by reevaluating the function multiple times to get optimal
        parameters.

        :param fn: an objective function returning a pair where the first item
            is a scalar loss function to be differentiated and the second item
            is an auxiliary output.
        :param state: current optimizer state.
        :return: a pair of the output of objective function and the new optimizer state.
        """
        params = self.get_params(state)
        (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
        return (out, aux), self.update(grads, state)

    def eval_and_stable_update(self, fn: Callable[[Any], Tuple], state: _IterOptState):
        """
        Like :meth:`eval_and_update` but when the value of the objective function
        or the gradients are not finite, we will not update the input `state`
        and will set the objective output to `nan`.

        :param fn: objective function.
        :param state: current optimizer state.
        :return: a pair of the output of objective function and the new optimizer state.
        """
        params = self.get_params(state)
        (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
        # Apply the update only when the loss and every gradient entry are finite.
        out, state = lax.cond(
            jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
            lambda _: (out, self.update(grads, state)),
            lambda _: (jnp.nan, state),
            # operand handed to the branch functions; both ignore it, but each
            # lambda takes one argument so it must be supplied
            None,
        )
        return (out, aux), state

    def get_params(self, state: _IterOptState) -> _Params:
        """
        Get current parameter values.

        :param state: current optimizer state.
        :return: collection with current value for parameters.
        """
        _, opt_state = state
        return self.get_params_fn(opt_state)

def _add_doc(fn):
    """
    Build a class decorator that replaces the class docstring with a short
    note pointing at the wrapped JAX optimizer.

    :param fn: the JAX optimizer factory being wrapped; only ``fn.__name__`` is used.
    :return: a decorator that sets ``cls.__doc__`` and returns the same class.
    """

    def _wrapped(cls):
        cls.__doc__ = "Wrapper class for the JAX optimizer: :func:`~jax.example_libraries.optimizers.{}`".format(
            fn.__name__
        )
        return cls

    return _wrapped

@_add_doc(optimizers.adam)
class Adam(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        # all arguments are forwarded to ``optimizers.adam``
        super().__init__(optimizers.adam, *args, **kwargs)
class ClippedAdam(_NumPyroOptim):
    """
    :class:`~numpyro.optim.Adam` optimizer with gradient clipping.

    :param float clip_norm: All gradient values will be clipped between `[-clip_norm, clip_norm]`.

    **Reference:**

    `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
    https://arxiv.org/abs/1412.6980
    """

    def __init__(self, *args, clip_norm=10.0, **kwargs):
        self.clip_norm = clip_norm
        super().__init__(optimizers.adam, *args, **kwargs)

    def update(self, g, state):
        """
        Clip every gradient entry to ``[-clip_norm, clip_norm]`` and then apply
        the Adam update.

        :param g: gradient information for parameters.
        :param state: current optimizer state ``(i, opt_state)``.
        :return: new optimizer state after the update.
        """
        i, opt_state = state
        # Element-wise value clipping. The bounds are passed positionally because
        # the ``a_min``/``a_max`` keywords are deprecated in newer JAX releases.
        g = tree_map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state
@_add_doc(optimizers.adagrad)
class Adagrad(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        # all arguments are forwarded to ``optimizers.adagrad``
        super().__init__(optimizers.adagrad, *args, **kwargs)
@_add_doc(optimizers.momentum)
class Momentum(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        # all arguments are forwarded to ``optimizers.momentum``
        super().__init__(optimizers.momentum, *args, **kwargs)
@_add_doc(optimizers.rmsprop)
class RMSProp(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        # all arguments are forwarded to ``optimizers.rmsprop``
        super().__init__(optimizers.rmsprop, *args, **kwargs)
@_add_doc(optimizers.rmsprop_momentum)
class RMSPropMomentum(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        # all arguments are forwarded to ``optimizers.rmsprop_momentum``
        super().__init__(optimizers.rmsprop_momentum, *args, **kwargs)
@_add_doc(optimizers.sgd)
class SGD(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        # all arguments are forwarded to ``optimizers.sgd``
        super().__init__(optimizers.sgd, *args, **kwargs)
@_add_doc(optimizers.sm3)
class SM3(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        # all arguments are forwarded to ``optimizers.sm3``
        super().__init__(optimizers.sm3, *args, **kwargs)
# TODO: currently, jax.scipy.optimize.minimize only supports 1D input,
# so we need to add the following mechanism to transform params to flat_params
# and pass `unravel_fn` around.
# When arbitrary pytree is supported in JAX, we can just simply use
# identity functions for `init_fn` and `get_params`.
_MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"])

# Register the state explicitly so that ``flat_params`` is the only pytree leaf,
# while ``unravel_fn`` (a Python callable) is carried as auxiliary data.
register_pytree_node(
    _MinimizeState,
    lambda state: ((state.flat_params,), (state.unravel_fn,)),
    lambda data, xs: _MinimizeState(xs[0], data[0]),
)


def _minimize_wrapper():
    """
    Build an ``(init_fn, update_fn, get_params)`` triple for :class:`Minimize`.

    ``init_fn`` flattens the parameter pytree into a 1D array and keeps the
    matching ``unravel_fn``; ``get_params`` inverts that flattening.
    ``update_fn`` is a no-op because :class:`Minimize` performs its own update.

    :return: the tuple ``(init_fn, update_fn, get_params)``.
    """

    def init_fn(params):
        flat_params, unravel_fn = ravel_pytree(params)
        return _MinimizeState(flat_params, unravel_fn)

    def update_fn(i, grad_tree, opt_state):
        # we don't use update_fn in Minimize, so let it do nothing
        return opt_state

    def get_params(opt_state):
        flat_params, unravel_fn = opt_state
        return unravel_fn(flat_params)

    return init_fn, update_fn, get_params
class Minimize(_NumPyroOptim):
    """
    Wrapper class for the JAX minimizer: :func:`~jax.scipy.optimize.minimize`.

    .. warning::
        This optimizer is intended to be used with static guides such as empty
        guides (maximum likelihood estimate), delta guides (MAP estimate), or
        :class:`~numpyro.infer.autoguide.AutoLaplaceApproximation`.
        Using this in stochastic setting is either expensive or hard to converge.

    **Example:**

    .. doctest::

        >>> from numpy.testing import assert_allclose
        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import SVI, Trace_ELBO
        >>> from numpyro.infer.autoguide import AutoLaplaceApproximation

        >>> def model(x, y):
        ...     a = numpyro.sample("a", dist.Normal(0, 1))
        ...     b = numpyro.sample("b", dist.Normal(0, 1))
        ...     with numpyro.plate("N", y.shape[0]):
        ...         numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y)

        >>> x = jnp.linspace(0, 10, 100)
        >>> y = 3 * x + 2
        >>> optimizer = numpyro.optim.Minimize()
        >>> guide = AutoLaplaceApproximation(model)
        >>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
        >>> init_state = svi.init(random.PRNGKey(0), x, y)
        >>> optimal_state, loss = svi.update(init_state, x, y)
        >>> params = svi.get_params(optimal_state)  # get guide's parameters
        >>> quantiles = guide.quantiles(params, 0.5)  # get posterior medians
        >>> assert_allclose(quantiles["a"], 2., atol=1e-3)
        >>> assert_allclose(quantiles["b"], 3., atol=1e-3)
    """

    def __init__(self, method="BFGS", **kwargs):
        super().__init__(_minimize_wrapper)
        # forwarded verbatim to jax.scipy.optimize.minimize on every update
        self._method = method
        self._kwargs = kwargs

    def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState):
        """
        Run :func:`jax.scipy.optimize.minimize` on `fn`, starting from the
        current parameters, and store the minimizer as the new state.

        :param fn: an objective function returning a pair ``(loss, aux)``;
            ``aux`` must be ``None``.
        :param state: current optimizer state.
        :return: a pair ``((loss, None), new_state)`` where ``loss`` is the
            objective value reported by the minimizer.
        :raises ValueError: if `fn` returns a non-``None`` auxiliary output.
        """
        i, (flat_params, unravel_fn) = state

        def loss_fn(x):
            # minimize works on the flat 1D vector; rebuild the pytree for fn
            x = unravel_fn(x)
            out, aux = fn(x)
            if aux is not None:
                raise ValueError(
                    "Minimize does not support models with mutable states."
                )
            return out

        results = minimize(
            loss_fn, flat_params, (), method=self._method, **self._kwargs
        )
        # optimal parameters and the objective value at those parameters
        flat_params, out = results.x, results.fun
        state = (i + 1, _MinimizeState(flat_params, unravel_fn))
        return (out, None), state
def optax_to_numpyro(transformation) -> _NumPyroOptim:
    """
    This function produces a ``numpyro.optim._NumPyroOptim`` instance from an
    ``optax.GradientTransformation`` so that it can be used with
    ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the
    ``(init_fn, update_fn, get_params_fn)`` interface defined by
    :mod:`jax.example_libraries.optimizers`.

    :param transformation: An ``optax.GradientTransformation`` instance to wrap.
    :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied
        Optax optimizer.
    """
    import optax

    # The wrapped optimizer state is a pair ``(params, optax_state)``.
    def init_fn(params):
        opt_state = transformation.init(params)
        return params, opt_state

    def update_fn(step, grads, state):
        # ``step`` is required by the (init, update, get_params) interface but
        # is not used here.
        params, opt_state = state
        updates, opt_state = transformation.update(grads, opt_state, params)
        updated_params = optax.apply_updates(params, updates)
        return updated_params, opt_state

    def get_params_fn(state):
        params, _ = state
        return params

    # The "factory" passed to _NumPyroOptim just returns the three functions as-is.
    return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)