Optimizers
Optimizer classes defined here are light wrappers over the corresponding optimizers sourced from jax.experimental.optimizers, with an interface that is better suited for working with NumPyro inference algorithms.
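A JAX-style optimizer is a triple of functions `(init_fn, update_fn, get_params)`, and `update_fn` receives the step index `i`. The sketch below shows one way such a triple can be wrapped so the step count travels inside the state, matching the `Tuple[int, _OptState]` state type used by every method on this page. `ToyOptim` and `sgd_triple` are hypothetical names; this is a pure-Python illustration, not NumPyro's implementation.

```python
from typing import Any, Callable, Tuple


def sgd_triple(step_size: float):
    """Hypothetical JAX-style optimizer triple for plain SGD on a float."""
    def init_fn(params: float) -> float:
        return params  # the inner state is just the parameter value

    def update_fn(i: int, g: float, opt_state: float) -> float:
        return opt_state - step_size * g

    def get_params(opt_state: float) -> float:
        return opt_state

    return init_fn, update_fn, get_params


class ToyOptim:
    """Hypothetical wrapper: the state is a pair (step count, inner state)."""

    def __init__(self, triple_factory: Callable, *args: Any):
        self.init_fn, self.update_fn, self.get_params_fn = triple_factory(*args)

    def init(self, params) -> Tuple[int, Any]:
        return 0, self.init_fn(params)

    def update(self, g, state: Tuple[int, Any]) -> Tuple[int, Any]:
        i, opt_state = state
        return i + 1, self.update_fn(i, g, opt_state)

    def get_params(self, state: Tuple[int, Any]):
        _, opt_state = state
        return self.get_params_fn(opt_state)


# Minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
opt = ToyOptim(sgd_triple, 0.1)
state = opt.init(0.0)
for _ in range(100):
    x = opt.get_params(state)
    state = opt.update(2 * (x - 3), state)
print(state[0], round(opt.get_params(state), 6))  # prints: 100 3.0
```

Keeping `i` in the state lets the wrapper call the underlying `update_fn(i, g, opt_state)` without the caller tracking the step count.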
Adam
class Adam(*args, **kwargs)
Wrapper class for the JAX optimizer adam().
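For reference, the sketch below writes out one Adam step (Kingma and Ba) for a single scalar parameter in pure Python. The hyperparameter defaults `b1=0.9`, `b2=0.999`, `eps=1e-8` are assumed to match JAX's `adam()`, and `adam_step` is a hypothetical helper, not part of NumPyro or JAX.

```python
import math


def adam_step(i, g, x, m, v, step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update at 0-based step index i for a scalar parameter x."""
    m = (1 - b1) * g + b1 * m          # running mean of gradients
    v = (1 - b2) * g * g + b2 * v      # running mean of squared gradients
    m_hat = m / (1 - b1 ** (i + 1))    # bias corrections for the zero initialization
    v_hat = v / (1 - b2 ** (i + 1))
    x = x - step_size * m_hat / (math.sqrt(v_hat) + eps)
    return x, m, v


# Minimize f(x) = (x - 3)^2 starting from x = 0.
x, m, v = 0.0, 0.0, 0.0
for i in range(2000):
    x, m, v = adam_step(i, 2 * (x - 3), x, m, v, step_size=0.05)
```

The step index `i` is needed for the bias corrections, which is one reason the optimizer state carries an iteration counter.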
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
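The gradient-based behaviour described above can be sketched in three moves: read the parameters out of the state, evaluate the objective together with its gradient, then pass the gradient to `update`. NumPyro would obtain the gradient through JAX autodiff; to stay dependency-free, this sketch assumes a hypothetical `value_and_grad_fn` that returns `(value, gradient)` analytically, and a plain SGD rule stands in for the wrapped optimizer.

```python
from typing import Callable, Tuple

State = Tuple[int, float]  # (step count, inner state); here the inner state is x itself


def get_params(state: State) -> float:
    return state[1]


def update(g: float, state: State, step_size: float = 0.1) -> State:
    i, x = state
    return i + 1, x - step_size * g  # SGD stands in for the wrapped optimizer


def eval_and_update(value_and_grad_fn: Callable, state: State):
    """Evaluate the objective at the current parameters, then take one gradient step."""
    params = get_params(state)
    value, grad = value_and_grad_fn(params)
    return value, update(grad, state)


def quadratic(x: float):
    """Hypothetical objective f(x) = (x - 3)^2 together with its derivative."""
    return (x - 3) ** 2, 2 * (x - 3)


state: State = (0, 0.0)
value, state = eval_and_update(quadratic, state)
```

Note that the returned value is the objective at the parameters before the step, not after it.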
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
Adagrad
class Adagrad(*args, **kwargs)
Wrapper class for the JAX optimizer adagrad().
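The sketch below shows the textbook Adagrad rule for a scalar parameter: each step is divided by the square root of the running sum of squared gradients, so with a constant gradient the step shrinks like 1/sqrt(t). `adagrad_step` is a hypothetical helper, and JAX's `adagrad()` may differ in details (for example by also applying momentum), so treat this as an illustration of the idea only.

```python
import math


def adagrad_step(g, x, g_sq, step_size=0.5):
    """Textbook Adagrad step for a scalar parameter x."""
    g_sq = g_sq + g * g                              # running sum (not average) of squared gradients
    scale = 1.0 / math.sqrt(g_sq) if g_sq > 0 else 0.0
    x = x - step_size * g * scale
    return x, g_sq


x, g_sq = 0.0, 0.0
steps = []
for _ in range(3):
    x_new, g_sq = adagrad_step(1.0, x, g_sq)         # constant gradient exposes the decay
    steps.append(x - x_new)                          # size of each step taken
    x = x_new
```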
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
ClippedAdam
class ClippedAdam(*args, clip_norm=10.0, **kwargs)
Adam optimizer with gradient clipping.
Parameters: clip_norm (float) – all gradient values will be clipped element-wise to the interval [-clip_norm, clip_norm].
Reference: Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba, https://arxiv.org/abs/1412.6980
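Despite its name, clip_norm is applied to each gradient entry independently. It does not rescale the gradient by its overall norm. A minimal pure-Python sketch on a nested dictionary of gradients (the kind of nested structure NumPyro parameters usually form) follows. `clip_gradients` is a hypothetical helper; in NumPyro the equivalent operation would act on JAX arrays.

```python
def clip_gradients(grads, clip_norm=10.0):
    """Clamp every leaf value into [-clip_norm, clip_norm], preserving the nesting."""
    if isinstance(grads, dict):
        return {k: clip_gradients(v, clip_norm) for k, v in grads.items()}
    if isinstance(grads, (list, tuple)):
        return type(grads)(clip_gradients(v, clip_norm) for v in grads)
    return max(-clip_norm, min(clip_norm, grads))


grads = {"loc": [-25.0, 3.0, 12.5], "scale": 0.5}
print(clip_gradients(grads))
# {'loc': [-10.0, 3.0, 10.0], 'scale': 0.5}
```

Because entries are clamped separately, a clipped gradient can point in a different direction from the original one, unlike norm-based rescaling.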
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
Minimize
class Minimize(method='BFGS', **kwargs)
Wrapper class for the JAX minimizer jax.scipy.optimize.minimize(). The method argument and any additional keyword arguments are forwarded to that function.
Example:
>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer.autoguide import AutoLaplaceApproximation
>>> def model(x, y):
...     a = numpyro.sample("a", dist.Normal(0, 1))
...     b = numpyro.sample("b", dist.Normal(0, 1))
...     with numpyro.plate("N", y.shape[0]):
...         numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y)
>>> x = jnp.linspace(0, 10, 100)
>>> y = 3 * x + 2
>>> optimizer = numpyro.optim.Minimize()
>>> guide = AutoLaplaceApproximation(model)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> init_state = svi.init(random.PRNGKey(0), x, y)
>>> optimal_state, loss = svi.update(init_state, x, y)
>>> params = svi.get_params(optimal_state)  # get guide's parameters
>>> quantiles = guide.quantiles(params, 0.5)  # get posterior medians (equal to the means under this Gaussian approximation)
>>> assert_allclose(quantiles["a"], 2., atol=1e-3)
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
Momentum
class Momentum(*args, **kwargs)
Wrapper class for the JAX optimizer momentum().
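The sketch below illustrates heavy-ball momentum. It assumes the common formulation velocity ← mass · velocity + g, x ← x − step_size · velocity, which we believe matches JAX's `momentum(step_size, mass)`. `momentum_step` is a hypothetical helper written in pure Python.

```python
def momentum_step(g, x, velocity, step_size=0.1, mass=0.9):
    """Heavy-ball momentum: keep an exponentially decaying sum of past gradients."""
    velocity = mass * velocity + g
    x = x - step_size * velocity
    return x, velocity


# With a constant gradient the velocity approaches g / (1 - mass),
# so the step grows towards step_size * g / (1 - mass).
x, velocity = 0.0, 0.0
for _ in range(200):
    x, velocity = momentum_step(1.0, x, velocity)
```

With `mass=0.9`, a persistent gradient direction ends up with roughly ten times the plain SGD step.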
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
RMSProp
class RMSProp(*args, **kwargs)
Wrapper class for the JAX optimizer rmsprop().
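RMSProp divides each step by a running root-mean-square of recent gradients. The sketch below assumes the formulation avg ← gamma · avg + (1 − gamma) · g², x ← x − step_size · g / sqrt(avg + eps), with defaults `gamma=0.9`, `eps=1e-8` assumed to mirror JAX's `rmsprop()`. `rmsprop_step` is a hypothetical pure-Python helper.

```python
import math


def rmsprop_step(g, x, avg_sq_grad, step_size=0.01, gamma=0.9, eps=1e-8):
    """RMSProp: scale the step by a running RMS of recent gradients."""
    avg_sq_grad = gamma * avg_sq_grad + (1 - gamma) * g * g
    x = x - step_size * g / math.sqrt(avg_sq_grad + eps)
    return x, avg_sq_grad


# Because the step is normalized, gradients of very different magnitude
# produce (after warm-up) steps of about the same size, close to step_size.
last_step = {}
for g in (0.001, 1000.0):
    x, avg = 0.0, 0.0
    for _ in range(100):
        x_prev = x
        x, avg = rmsprop_step(g, x, avg)
    last_step[g] = x_prev - x
```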
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
RMSPropMomentum
class RMSPropMomentum(*args, **kwargs)
Wrapper class for the JAX optimizer rmsprop_momentum().
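RMSProp with momentum feeds the RMS-normalized step through a momentum buffer. This sketch assumes the formulation mom ← momentum · mom + step_size · g / sqrt(avg + eps), x ← x − mom, with defaults assumed to mirror JAX's `rmsprop_momentum()`. `rmsprop_momentum_step` is a hypothetical pure-Python helper, not the library code.

```python
import math


def rmsprop_momentum_step(g, x, avg_sq_grad, mom,
                          step_size=0.01, gamma=0.9, eps=1e-8, momentum=0.9):
    """RMSProp whose normalized step is accumulated in a momentum buffer."""
    avg_sq_grad = gamma * avg_sq_grad + (1 - gamma) * g * g
    mom = momentum * mom + step_size * g / math.sqrt(avg_sq_grad + eps)
    x = x - mom
    return x, avg_sq_grad, mom


# Under a constant gradient the normalized step tends to step_size, and the
# momentum buffer amplifies it towards step_size / (1 - momentum) = 0.1.
x, avg, mom = 0.0, 0.0, 0.0
for _ in range(300):
    x, avg, mom = rmsprop_momentum_step(5.0, x, avg, mom)
```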
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
SGD
class SGD(*args, **kwargs)
Wrapper class for the JAX optimizer sgd().
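The iteration counter in the state is what makes step-size schedules possible. JAX's optimizers accept either a constant step size or a callable that maps the step index to a step size. We assume, without having checked every version, that the wrapper's positional arguments reach sgd() unchanged. The sketch below mimics that convention in pure Python. `as_schedule`, `sgd_update` and `halving_every_10` are hypothetical names.

```python
from typing import Callable, Tuple, Union

Schedule = Callable[[int], float]


def as_schedule(step_size: Union[float, Schedule]) -> Schedule:
    """Treat a constant step size as a schedule that ignores the step index."""
    return step_size if callable(step_size) else (lambda i: step_size)


def sgd_update(step_size, i: int, g: float, x: float) -> float:
    """Plain SGD: x <- x - step_size(i) * g."""
    return x - as_schedule(step_size)(i) * g


def halving_every_10(i: int) -> float:
    """Hypothetical schedule: start at 0.1 and halve every 10 steps."""
    return 0.1 * 0.5 ** (i // 10)


state: Tuple[int, float] = (0, 0.0)  # (step count, parameter)
for _ in range(20):
    i, x = state
    state = (i + 1, sgd_update(halving_every_10, i, -1.0, x))
```

With a constant gradient of -1, the first ten steps each move x by 0.1 and the next ten by 0.05.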
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
SM3
class SM3(*args, **kwargs)
Wrapper class for the JAX optimizer sm3().
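SM3 (Anil et al., 2019) is a memory-efficient relative of Adagrad. For an n × m matrix parameter, it stores one accumulator per row and one per column, so it keeps n + m numbers instead of n · m. The sketch below writes one step of the SM3-II accumulator rule in pure Python: nu = min(row, col) + g², followed by a max over each row and each column. It omits the momentum term that JAX's sm3() may also apply. `sm3_step` is a hypothetical helper.

```python
import math


def sm3_step(g, x, row_acc, col_acc, step_size=0.1):
    """One SM3 step for a matrix x (list of lists) with gradient g."""
    n_rows, n_cols = len(g), len(g[0])
    # Per-entry accumulator, reconstructed from the row and column accumulators.
    nu = [[min(row_acc[r], col_acc[c]) + g[r][c] ** 2 for c in range(n_cols)]
          for r in range(n_rows)]
    # Adagrad-style step; entries with nu == 0 (zero gradient so far) are left unchanged.
    new_x = [[x[r][c] - (step_size * g[r][c] / math.sqrt(nu[r][c]) if nu[r][c] > 0 else 0.0)
              for c in range(n_cols)] for r in range(n_rows)]
    # Only n_rows + n_cols values are stored for the next step.
    row_acc = [max(nu[r]) for r in range(n_rows)]
    col_acc = [max(nu[r][c] for r in range(n_rows)) for c in range(n_cols)]
    return new_x, row_acc, col_acc


g = [[1.0, 2.0, 0.0],
     [0.0, 0.0, 3.0]]
x = [[0.0] * 3 for _ in range(2)]
x, row_acc, col_acc = sm3_step(g, x, [0.0, 0.0], [0.0, 0.0, 0.0])
```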
eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Performs an optimization step for the objective function fn. For most optimizers, the update is performed based on the gradient of the objective function with respect to the parameters held in the current state. However, for some optimizers such as Minimize, the update is performed by re-evaluating the function multiple times to find optimal parameters.
Parameters:
- fn – objective function.
- state – current optimizer state.
Returns: a pair of the output of the objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current values of the parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters:
- g – gradient information for parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.