Optimizers
Optimizer classes defined here are light wrappers over the corresponding optimizers sourced from jax.example_libraries.optimizers, with an interface that is better suited for working with NumPyro inference algorithms.
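For example, a typical workflow constructs one of these optimizers and passes it to an inference algorithm such as SVI. The sketch below is illustrative only: the toy model, data, step size, and number of steps are not part of this module and can be replaced freely.
>>> import jax.numpy as jnp
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer.autoguide import AutoNormal
>>> def model(y):
...     loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
...     with numpyro.plate("N", y.shape[0]):
...         numpyro.sample("obs", dist.Normal(loc, 1.0), obs=y)
>>> y = jnp.array([1.0, 2.0, 3.0])
>>> optimizer = numpyro.optim.Adam(step_size=0.005)  # constructor args are forwarded to the wrapped JAX optimizer
>>> svi = SVI(model, AutoNormal(model), optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 1000, y, progress_bar=False)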
Adam
- class Adam(*args, **kwargs)[source]
Wrapper class for the JAX optimizer: adam()
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
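The eval_and_update()/get_params() interface above can also be driven by hand, outside of SVI. Below is a minimal sketch under the documented contract that fn returns a (loss, aux) pair; the quadratic objective, step size, and iteration count are purely illustrative, and init() here is the optimizer's state constructor used internally by SVI.
>>> import jax.numpy as jnp
>>> import numpyro
>>> optimizer = numpyro.optim.Adam(step_size=0.1)
>>> def objective(params):
...     loss = jnp.sum((params["w"] - 3.0) ** 2)  # minimized at w == 3
...     return loss, None  # (scalar loss to differentiate, auxiliary output)
>>> state = optimizer.init({"w": jnp.zeros(2)})
>>> for _ in range(200):
...     (loss, _), state = optimizer.eval_and_update(objective, state)
>>> w_opt = optimizer.get_params(state)["w"]  # approaches [3., 3.]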
Adagrad
- class Adagrad(*args, **kwargs)[source]
Wrapper class for the JAX optimizer: adagrad()
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
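To illustrate the eval_and_stable_update() behavior described above: when the loss or its gradient is non-finite, the optimizer state (and hence the parameters) is left untouched and the reported loss is nan. The objective below is deliberately pathological and purely illustrative.
>>> import jax.numpy as jnp
>>> import numpyro
>>> optimizer = numpyro.optim.Adagrad(step_size=1.0)
>>> def bad_objective(params):
...     return jnp.log(params["x"]), None  # log(0.) is -inf and its gradient is inf
>>> state = optimizer.init({"x": jnp.zeros(())})
>>> (loss, _), new_state = optimizer.eval_and_stable_update(bad_objective, state)  # loss is nan
>>> params_after = optimizer.get_params(new_state)  # {"x": 0.}: unchanged because the update was skipped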
ClippedAdam
- class ClippedAdam(*args, clip_norm=10.0, **kwargs)[source]
Adam optimizer with gradient clipping.
- Parameters:
clip_norm (float) – All gradient values will be clipped between [-clip_norm, clip_norm].
Reference:
Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba. https://arxiv.org/abs/1412.6980
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
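Construction mirrors Adam: positional and keyword arguments other than clip_norm are forwarded to the underlying adam() optimizer. The values below are illustrative.
>>> import numpyro
>>> optimizer = numpyro.optim.ClippedAdam(step_size=0.001, clip_norm=10.0)
>>> # drop-in replacement for Adam wherever an optimizer is passed, e.g. to SVI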
Minimize
- class Minimize(method='BFGS', **kwargs)[source]
Wrapper class for the JAX minimizer: minimize()
Example:
>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer.autoguide import AutoLaplaceApproximation
>>> def model(x, y):
...     a = numpyro.sample("a", dist.Normal(0, 1))
...     b = numpyro.sample("b", dist.Normal(0, 1))
...     with numpyro.plate("N", y.shape[0]):
...         numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y)
>>> x = jnp.linspace(0, 10, 100)
>>> y = 3 * x + 2
>>> optimizer = numpyro.optim.Minimize()
>>> guide = AutoLaplaceApproximation(model)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> init_state = svi.init(random.PRNGKey(0), x, y)
>>> optimal_state, loss = svi.update(init_state, x, y)
>>> params = svi.get_params(optimal_state)  # get guide's parameters
>>> quantiles = guide.quantiles(params, 0.5)  # get means of posterior samples
>>> assert_allclose(quantiles["a"], 2., atol=1e-3)
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation=False)[source]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
Momentum
- class Momentum(*args, **kwargs)[source]
Wrapper class for the JAX optimizer: momentum()
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
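As with the other wrappers, the constructor arguments are passed straight through to the wrapped JAX optimizer; for momentum() that means a step size and a mass (momentum) coefficient. The values below are illustrative.
>>> import numpyro
>>> optimizer = numpyro.optim.Momentum(step_size=0.01, mass=0.9)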
RMSProp
- class RMSProp(*args, **kwargs)[source]
Wrapper class for the JAX optimizer: rmsprop()
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
RMSPropMomentum
- class RMSPropMomentum(*args, **kwargs)[source]
Wrapper class for the JAX optimizer: rmsprop_momentum()
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
SGD
- class SGD(*args, **kwargs)[source]
Wrapper class for the JAX optimizer: sgd()
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
SM3
- class SM3(*args, **kwargs)[source]
Wrapper class for the JAX optimizer: sm3()
- eval_and_stable_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Like eval_and_update() but when the value of the objective function or the gradients are not finite, we will not update the input state and will set the objective output to nan.
- Parameters:
fn – objective function.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- eval_and_update(fn: Callable[[Any], tuple], state: tuple[int, _OptState], forward_mode_differentiation: bool = False)
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function w.r.t. the current state. However, for some optimizers such as Minimize, the update is performed by reevaluating the function multiple times to get optimal parameters.
- Parameters:
fn – an objective function returning a pair where the first item is a scalar loss function to be differentiated and the second item is an auxiliary output.
state – current optimizer state.
forward_mode_differentiation – boolean flag indicating whether to use forward mode differentiation.
- Returns:
a pair of the output of objective function and the new optimizer state.
- get_params(state: tuple[int, _OptState]) _Params
Get current parameter values.
- Parameters:
state – current optimizer state.
- Returns:
collection with current value for parameters.
Optax support
- optax_to_numpyro(transformation) _NumPyroOptim [source]
This function produces a numpyro.optim._NumPyroOptim instance from an optax.GradientTransformation so that it can be used with numpyro.infer.svi.SVI. It is a lightweight wrapper that recreates the (init_fn, update_fn, get_params_fn) interface defined by jax.example_libraries.optimizers.
- Parameters:
transformation – An optax.GradientTransformation instance to wrap.
- Returns:
An instance of numpyro.optim._NumPyroOptim wrapping the supplied Optax optimizer.
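For example, an Optax chain of gradient clipping followed by Adam can be adapted with optax_to_numpyro() and then passed to SVI like any of the optimizers above. The sketch below is illustrative: the toy model, the particular transformation, and all hyperparameters are arbitrary choices.
>>> import jax.numpy as jnp
>>> from jax import random
>>> import optax
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer.autoguide import AutoNormal
>>> from numpyro.optim import optax_to_numpyro
>>> def model(y):
...     loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
...     with numpyro.plate("N", y.shape[0]):
...         numpyro.sample("obs", dist.Normal(loc, 1.0), obs=y)
>>> y = jnp.array([1.0, 2.0, 3.0])
>>> transformation = optax.chain(optax.clip_by_global_norm(10.0), optax.adam(1e-3))
>>> optimizer = optax_to_numpyro(transformation)
>>> svi = SVI(model, AutoNormal(model), optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 1000, y, progress_bar=False)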