# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Optimizer classes defined here are light wrappers over the corresponding optimizers
sourced from :mod:`jax.example_libraries.optimizers` with an interface that is better
suited for working with NumPyro inference algorithms.
"""

from collections import namedtuple
from collections.abc import Callable
from typing import Any, TypeVar

from jax import lax, value_and_grad
from jax.example_libraries import optimizers
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.tree_util import register_pytree_node, tree_map

__all__ = [

_Params = TypeVar("_Params")
_OptState = TypeVar("_OptState")
_IterOptState = tuple[int, _OptState]

class _NumPyroOptim(object):
    def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
        self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)

    def init(self, params: _Params) -> _IterOptState:
        Initialize the optimizer with parameters designated to be optimized.

        :param params: a collection of numpy arrays.
        :return: initial optimizer state.
        opt_state = self.init_fn(params)
        return jnp.array(0), opt_state

    def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
        Gradient update for the optimizer.

        :param g: gradient information for parameters.
        :param state: current optimizer state.
        :return: new optimizer state after the update.
        i, opt_state = state
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state

    def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
        Performs an optimization step for the objective function `fn`.
        For most optimizers, the update is performed based on the gradient
        of the objective function w.r.t. the current state. However, for
        some optimizers such as :class:`Minimize`, the update is performed
        by reevaluating the function multiple times to get optimal

        :param fn: an objective function returning a pair where the first item
            is a scalar loss function to be differentiated and the second item
            is an auxiliary output.
        :param state: current optimizer state.
        :return: a pair of the output of objective function and the new optimizer state.
        params = self.get_params(state)
        (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
        return (out, aux), self.update(grads, state)

    def eval_and_stable_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
        Like :meth:`eval_and_update` but when the value of the objective function
        or the gradients are not finite, we will not update the input `state`
        and will set the objective output to `nan`.

        :param fn: objective function.
        :param state: current optimizer state.
        :return: a pair of the output of objective function and the new optimizer state.
        params = self.get_params(state)
        (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
        out, state = lax.cond(
            jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
            lambda _: (out, self.update(grads, state)),
            lambda _: (jnp.nan, state),
        return (out, aux), state

    def get_params(self, state: _IterOptState) -> _Params:
        Get current parameter values.

        :param state: current optimizer state.
        :return: collection with current value for parameters.
        _, opt_state = state
        return self.get_params_fn(opt_state)

def _add_doc(fn):
    def _wrapped(cls):
        cls.__doc__ = "Wrapper class for the JAX optimizer: :func:`~jax.example_libraries.optimizers.{}`".format(
        return cls

    return _wrapped

@_add_doc(optimizers.adam)
class Adam(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.adam, *args, **kwargs)
class ClippedAdam(_NumPyroOptim):
    """
    :class:`~numpyro.optim.Adam` optimizer with gradient clipping.

    :param float clip_norm: All gradient values will be clipped between
        `[-clip_norm, clip_norm]`.

    **Reference:**

    `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
    """

    def __init__(self, *args, clip_norm=10.0, **kwargs):
        self.clip_norm = clip_norm
        super().__init__(optimizers.adam, *args, **kwargs)

    def update(self, g, state):
        i, opt_state = state
        # Clip each gradient leaf elementwise to [-clip_norm, clip_norm] before
        # the Adam update. NOTE: positional min/max are used because the
        # ``a_min``/``a_max`` keyword aliases were removed from ``jnp.clip``
        # in recent JAX releases.
        g = tree_map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state
@_add_doc(optimizers.adagrad)
class Adagrad(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.adagrad, *args, **kwargs)
@_add_doc(optimizers.momentum)
class Momentum(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.momentum, *args, **kwargs)
@_add_doc(optimizers.rmsprop)
class RMSProp(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.rmsprop, *args, **kwargs)
@_add_doc(optimizers.rmsprop_momentum)
class RMSPropMomentum(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.rmsprop_momentum, *args, **kwargs)
@_add_doc(optimizers.sgd)
class SGD(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.sgd, *args, **kwargs)
@_add_doc(optimizers.sm3)
class SM3(_NumPyroOptim):
    def __init__(self, *args, **kwargs):
        super().__init__(optimizers.sm3, *args, **kwargs)
# TODO: currently, jax.scipy.optimize.minimize only supports 1D input, # so we need to add the following mechanism to transform params to flat_params # and pass `unravel_fn` arround. # When arbitrary pytree is supported in JAX, we can just simply use # identity functions for `init_fn` and `get_params`. _MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"]) register_pytree_node( _MinimizeState, lambda state: ((state.flat_params,), (state.unravel_fn,)), lambda data, xs: _MinimizeState(xs[0], data[0]), ) def _minimize_wrapper(): def init_fn(params): flat_params, unravel_fn = ravel_pytree(params) return _MinimizeState(flat_params, unravel_fn) def update_fn(i, grad_tree, opt_state): # we don't use update_fn in Minimize, so let it do nothing return opt_state def get_params(opt_state): flat_params, unravel_fn = opt_state return unravel_fn(flat_params) return init_fn, update_fn, get_params
class Minimize(_NumPyroOptim):
    """
    Wrapper class for the JAX minimizer: :func:`~jax.scipy.optimize.minimize`.

    .. warning:: This optimizer is intended to be used with static guides such
        as empty guides (maximum likelihood estimate), delta guides (MAP
        estimate), or :class:`~numpyro.infer.autoguide.AutoLaplaceApproximation`.
        Using this in stochastic setting is either expensive or hard to converge.

    **Example:**

    .. doctest::

        >>> from numpy.testing import assert_allclose
        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import SVI, Trace_ELBO
        >>> from numpyro.infer.autoguide import AutoLaplaceApproximation

        >>> def model(x, y):
        ...     a = numpyro.sample("a", dist.Normal(0, 1))
        ...     b = numpyro.sample("b", dist.Normal(0, 1))
        ...     with numpyro.plate("N", y.shape[0]):
        ...         numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y)

        >>> x = jnp.linspace(0, 10, 100)
        >>> y = 3 * x + 2
        >>> optimizer = numpyro.optim.Minimize()
        >>> guide = AutoLaplaceApproximation(model)
        >>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
        >>> init_state = svi.init(random.PRNGKey(0), x, y)
        >>> optimal_state, loss = svi.update(init_state, x, y)
        >>> params = svi.get_params(optimal_state)  # get guide's parameters
        >>> quantiles = guide.quantiles(params, 0.5)  # get means of posterior samples
        >>> assert_allclose(quantiles["a"], 2., atol=1e-3)
        >>> assert_allclose(quantiles["b"], 3., atol=1e-3)
    """

    def __init__(self, method="BFGS", **kwargs):
        super().__init__(_minimize_wrapper)
        self._method = method
        self._kwargs = kwargs

    def eval_and_update(self, fn: Callable[[Any], tuple], state: _IterOptState):
        """
        Run the scipy-style minimizer to convergence on `fn` and return the
        optimal value together with the updated state. Unlike gradient-based
        optimizers, the whole optimization happens inside this single call;
        :meth:`update` is never used.

        :param fn: objective returning a ``(loss, aux)`` pair; ``aux`` must be
            ``None`` because mutable state is not supported here.
        :param state: current optimizer state.
        :return: a pair of ``((loss, None), new_state)``.
        """
        i, (flat_params, unravel_fn) = state

        def loss_fn(x):
            x = unravel_fn(x)
            out, aux = fn(x)
            if aux is not None:
                raise ValueError(
                    "Minimize does not support models with mutable states."
                )
            return out

        results = minimize(
            loss_fn, flat_params, (), method=self._method, **self._kwargs
        )
        # BUG FIX: the optimal objective value lives in `results.fun`; the
        # original line `flat_params, out = results.x,` dropped it.
        flat_params, out = results.x, results.fun
        state = (i + 1, _MinimizeState(flat_params, unravel_fn))
        return (out, None), state
def optax_to_numpyro(transformation) -> _NumPyroOptim:
    """
    This function produces a ``numpyro.optim._NumPyroOptim`` instance from an
    ``optax.GradientTransformation`` so that it can be used with
    ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the
    ``(init_fn, update_fn, get_params_fn)`` interface defined by
    :mod:`jax.example_libraries.optimizers`.

    :param transformation: An ``optax.GradientTransformation`` instance to wrap.
    :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied
        Optax optimizer.
    """
    # Imported lazily so that optax stays an optional dependency of numpyro.
    import optax

    def init_fn(params):
        opt_state = transformation.init(params)
        # Optax transformations carry params separately from optimizer state,
        # so both are bundled into the wrapped state.
        return params, opt_state

    def update_fn(step, grads, state):
        params, opt_state = state
        updates, opt_state = transformation.update(grads, opt_state, params)
        updated_params = optax.apply_updates(params, updates)
        return updated_params, opt_state

    def get_params_fn(state):
        params, _ = state
        return params

    # The identity factory just passes the three functions through to
    # _NumPyroOptim's (init_fn, update_fn, get_params_fn) protocol.
    return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)