Source code for numpyro.optim

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Optimizer classes defined here are light wrappers over the corresponding optimizers
sourced from :mod:`jax.experimental.optimizers` with an interface that is better
suited for working with NumPyro inference algorithms.
"""

from typing import Callable, Tuple, TypeVar

from jax.experimental import optimizers
import jax.numpy as jnp
from jax.tree_util import tree_map

__all__ = [
    'Adam',
    'Adagrad',
    'ClippedAdam',
    'Momentum',
    'RMSProp',
    'RMSPropMomentum',
    'SGD',
    'SM3',
]

_Params = TypeVar('_Params')
_OptState = TypeVar('_OptState')
_IterOptState = Tuple[int, _OptState]


class _NumpyroOptim(object):
    def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
        self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)

    def init(self, params: _Params) -> _IterOptState:
        """
        Initialize the optimizer with parameters designated to be optimized.

:param params: a collection (e.g. a dict) of arrays holding the initial parameter values.
        :return: initial optimizer state.
        """
        opt_state = self.init_fn(params)
        return jnp.array(0), opt_state

    def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
        """
        Gradient update for the optimizer.

        :param g: gradient information for parameters.
        :param state: current optimizer state.
        :return: new optimizer state after the update.
        """
        i, opt_state = state
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state

    def get_params(self, state: _IterOptState) -> _Params:
        """
        Get current parameter values.

        :param state: current optimizer state.
:return: a collection with the current values of the parameters.
        """
        _, opt_state = state
        return self.get_params_fn(opt_state)
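
# A minimal usage sketch (illustrative only, not part of the module): the
# init/update/get_params cycle above can be driven by hand with one of the
# optimizer classes defined below (Adam here).  The step size, parameter
# pytree, and gradients are made-up placeholder values.
#
#     import jax.numpy as jnp
#     from numpyro.optim import Adam
#
#     optimizer = Adam(step_size=1e-3)
#     params = {'w': jnp.zeros(3), 'b': jnp.zeros(())}
#     state = optimizer.init(params)
#     grads = {'w': jnp.ones(3), 'b': jnp.ones(())}   # e.g. obtained via jax.grad
#     state = optimizer.update(grads, state)          # note: gradients come first
#     params = optimizer.get_params(state)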


def _add_doc(fn):
    def _wrapped(cls):
        cls.__doc__ = 'Wrapper class for the JAX optimizer: :func:`~jax.experimental.optimizers.{}`'\
            .format(fn.__name__)
        return cls

    return _wrapped


@_add_doc(optimizers.adam)
class Adam(_NumpyroOptim):
    def __init__(self, *args, **kwargs):
        super(Adam, self).__init__(optimizers.adam, *args, **kwargs)


class ClippedAdam(_NumpyroOptim):
    """
    :class:`~numpyro.optim.Adam` optimizer with gradient clipping.

    :param float clip_norm: All gradient values will be clipped between
        `[-clip_norm, clip_norm]`.

    **Reference:**

    `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
    https://arxiv.org/abs/1412.6980
    """
    def __init__(self, *args, clip_norm=10., **kwargs):
        self.clip_norm = clip_norm
        super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs)

    def update(self, g, state):
        i, opt_state = state
        # clip gradient values to the interval [-clip_norm, clip_norm] before the Adam update
        g = tree_map(lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g)
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state


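# Usage sketch (illustrative, not part of the module): ClippedAdam accepts the
# same arguments as Adam plus the keyword-only `clip_norm`; the values below
# are arbitrary.
#
#     optimizer = ClippedAdam(step_size=1e-3, clip_norm=5.0)
#     state = optimizer.init(params)            # `params` as in _NumpyroOptim.init
#     state = optimizer.update(grads, state)    # gradient values clipped to [-5, 5]

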
@_add_doc(optimizers.adagrad)
class Adagrad(_NumpyroOptim):
    def __init__(self, *args, **kwargs):
        super(Adagrad, self).__init__(optimizers.adagrad, *args, **kwargs)


@_add_doc(optimizers.momentum)
class Momentum(_NumpyroOptim):
    def __init__(self, *args, **kwargs):
        super(Momentum, self).__init__(optimizers.momentum, *args, **kwargs)


@_add_doc(optimizers.rmsprop)
class RMSProp(_NumpyroOptim):
    def __init__(self, *args, **kwargs):
        super(RMSProp, self).__init__(optimizers.rmsprop, *args, **kwargs)


@_add_doc(optimizers.rmsprop_momentum)
class RMSPropMomentum(_NumpyroOptim):
    def __init__(self, *args, **kwargs):
        super(RMSPropMomentum, self).__init__(optimizers.rmsprop_momentum, *args, **kwargs)


@_add_doc(optimizers.sgd)
class SGD(_NumpyroOptim):
    def __init__(self, *args, **kwargs):
        super(SGD, self).__init__(optimizers.sgd, *args, **kwargs)


@_add_doc(optimizers.sm3)
class SM3(_NumpyroOptim):
    def __init__(self, *args, **kwargs):
        super(SM3, self).__init__(optimizers.sm3, *args, **kwargs)
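

# Example usage (an illustrative sketch, not part of the module): in practice
# these optimizers are passed to a NumPyro inference algorithm such as SVI
# rather than stepped by hand.  `model` and `guide` stand for user-defined
# NumPyro model and guide functions.
#
#     from numpyro.infer import SVI, Trace_ELBO
#     from numpyro.optim import Adam
#
#     optimizer = Adam(step_size=1e-3)
#     svi = SVI(model, guide, optimizer, Trace_ELBO())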