Optimizers
Optimizer classes defined here are light wrappers over the corresponding optimizers
sourced from jax.experimental.optimizers with an interface that is better
suited for working with NumPyro inference algorithms.
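The signatures below suggest that each wrapper threads a step counter through the optimizer state: init returns a Tuple[int, _OptState] and update returns a new one. The sketch below shows that pattern under stated assumptions. sgd_triple and WrappedOptim are hypothetical names, not NumPyro or JAX functions, and plain lists of floats stand in for numpy arrays. It illustrates the init / update / get_params cycle; it is not the library's actual implementation.

```python
from typing import Callable, List, Tuple

# Hypothetical types: plain float lists stand in for numpy arrays / pytrees.
Params = List[float]
OptState = List[float]


def sgd_triple(step_size: float):
    """Hypothetical (init_fn, update_fn, get_params_fn) triple in the JAX style."""
    def init_fn(params: Params) -> OptState:
        return list(params)

    def update_fn(i: int, g: Params, opt_state: OptState) -> OptState:
        # i is unused here; JAX-style optimizers use it for step-size schedules.
        return [x - step_size * gx for x, gx in zip(opt_state, g)]

    def get_params_fn(opt_state: OptState) -> Params:
        return list(opt_state)

    return init_fn, update_fn, get_params_fn


class WrappedOptim:
    """Hypothetical wrapper that pairs the raw optimizer state with a step count."""

    def __init__(self, optim_fn: Callable, *args, **kwargs):
        self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)

    def init(self, params: Params) -> Tuple[int, OptState]:
        return 0, self.init_fn(params)

    def update(self, g: Params, state: Tuple[int, OptState]) -> Tuple[int, OptState]:
        i, opt_state = state
        return i + 1, self.update_fn(i, g, opt_state)

    def get_params(self, state: Tuple[int, OptState]) -> Params:
        _, opt_state = state
        return self.get_params_fn(opt_state)


# Minimize f(x) = x**2, whose gradient is 2x, for three steps.
optim = WrappedOptim(sgd_triple, 0.1)
state = optim.init([1.0])
for _ in range(3):
    x = optim.get_params(state)
    state = optim.update([2.0 * x[0]], state)
```

Keeping the step count outside the raw optimizer state lets an inference loop pass one opaque value around while the wrapper supplies the iteration index that the underlying update function expects.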
Adam

class Adam(*args, **kwargs)
Wrapper class for the JAX optimizer: adam().

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Perform one gradient update step.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
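For reference, the Adam update rule from Kingma and Ba can be sketched in plain Python as follows. This is a minimal sketch, not the code behind adam(). It assumes the paper's suggested defaults (step size 1e-3, b1 = 0.9, b2 = 0.999, eps = 1e-8), and adam_step is a hypothetical name.

```python
import math
from typing import List, Tuple

Vector = List[float]


def adam_step(i: int, x: Vector, g: Vector, m: Vector, v: Vector,
              step_size: float = 1e-3, b1: float = 0.9, b2: float = 0.999,
              eps: float = 1e-8) -> Tuple[Vector, Vector, Vector]:
    """Hypothetical single Adam step on plain lists; i is the 0-based step count."""
    # Exponential moving averages of the gradient and the squared gradient.
    m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(v, g)]
    # Bias correction compensates for initializing m and v at zero.
    m_hat = [mi / (1 - b1 ** (i + 1)) for mi in m]
    v_hat = [vi / (1 - b2 ** (i + 1)) for vi in v]
    x = [xi - step_size * mh / (math.sqrt(vh) + eps)
         for xi, mh, vh in zip(x, m_hat, v_hat)]
    return x, m, v


x, m, v = adam_step(0, [1.0, -2.0], [0.5, -3.0], [0.0, 0.0], [0.0, 0.0])
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is close to the sign of the gradient, so each coordinate moves by roughly step_size regardless of the gradient's magnitude.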
Adagrad

class Adagrad(*args, **kwargs)
Wrapper class for the JAX optimizer: adagrad().

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Perform one gradient update step.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
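The classic Adagrad accumulator can be sketched as follows. Each coordinate keeps a running sum of its squared gradients and divides its step by the square root of that sum. This is a sketch of the textbook rule only; adagrad_step is a hypothetical name, and the JAX adagrad() may apply extra terms (such as momentum on the scaled gradient) that are omitted here.

```python
import math
from typing import List, Tuple

Vector = List[float]


def adagrad_step(x: Vector, g: Vector, g_sq: Vector,
                 step_size: float = 0.1) -> Tuple[Vector, Vector]:
    """Hypothetical single Adagrad step on plain lists."""
    # Running sum of squared gradients, never decayed.
    g_sq = [s + gi * gi for s, gi in zip(g_sq, g)]
    # Coordinates that have never seen a gradient are left untouched.
    x = [xi - step_size * (gi / math.sqrt(s) if s > 0 else 0.0)
         for xi, gi, s in zip(x, g, g_sq)]
    return x, g_sq


x, g_sq = adagrad_step([1.0, 1.0], [3.0, 0.0], [0.0, 0.0])
x, g_sq = adagrad_step(x, [4.0, 0.0], g_sq)
```

Because the accumulator only grows, the effective step size for a coordinate shrinks monotonically: the first coordinate moves by 0.1 and then by 0.08 even though its second gradient is larger.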
ClippedAdam

class ClippedAdam(*args, clip_norm=10.0, **kwargs)
Adam optimizer with gradient clipping.
Parameters: clip_norm (float) – all gradient values are clipped to the interval [-clip_norm, clip_norm].
Reference: Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba. https://arxiv.org/abs/1412.6980

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.
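As described above, clip_norm bounds each gradient value independently instead of rescaling the gradient by its norm. The following minimal sketch shows that elementwise clipping on a dictionary of parameter gradients. Plain lists stand in for arrays, and clip_gradients is a hypothetical name, not a NumPyro function. The clipped gradients would then be passed to an ordinary Adam step.

```python
from typing import Dict, List

Grads = Dict[str, List[float]]


def clip_gradients(grads: Grads, clip_norm: float = 10.0) -> Grads:
    """Hypothetical helper: clip every gradient value into [-clip_norm, clip_norm]."""
    return {
        name: [max(-clip_norm, min(clip_norm, value)) for value in values]
        for name, values in grads.items()
    }


grads = {"loc": [25.0, -0.5], "scale": [-12.0]}
clipped = clip_gradients(grads)
```

Because each coordinate is clipped on its own, a large gradient can change direction as well as magnitude: [25.0, -0.5] becomes [10.0, -0.5]. Global-norm clipping would instead scale the whole vector and preserve its direction.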
Momentum

class Momentum(*args, **kwargs)
Wrapper class for the JAX optimizer: momentum().

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Perform one gradient update step.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
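A heavy-ball momentum step keeps a velocity that accumulates past gradients and moves the parameters along it. The sketch below assumes the common convention velocity ← mass · velocity + g followed by x ← x − step_size · velocity. The exact form used by momentum() should be checked against the JAX source, and momentum_step is a hypothetical name.

```python
from typing import List, Tuple

Vector = List[float]


def momentum_step(x: Vector, g: Vector, velocity: Vector,
                  step_size: float = 0.1, mass: float = 0.9) -> Tuple[Vector, Vector]:
    """Hypothetical single heavy-ball momentum step on plain lists."""
    # Decay the old velocity and add the new gradient.
    velocity = [mass * vi + gi for vi, gi in zip(velocity, g)]
    x = [xi - step_size * vi for xi, vi in zip(x, velocity)]
    return x, velocity


# Under a constant gradient of 1.0 the velocity builds toward 1 / (1 - mass) = 10.
x, vel = [0.0], [0.0]
for _ in range(3):
    x, vel = momentum_step(x, [1.0], vel)
```

After three steps the velocity is 2.71 and x is -0.561. The steps keep growing while the gradient points the same way, which is what lets momentum speed up progress along consistent directions.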
RMSProp

class RMSProp(*args, **kwargs)
Wrapper class for the JAX optimizer: rmsprop().

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Perform one gradient update step.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
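RMSProp replaces Adagrad's ever-growing sum of squared gradients with an exponentially decaying average, so the step size does not shrink indefinitely. The following is a minimal sketch under assumed defaults gamma = 0.9 and eps = 1e-8. Here eps is added inside the square root; implementations differ on where eps enters. rmsprop_step is a hypothetical name.

```python
import math
from typing import List, Tuple

Vector = List[float]


def rmsprop_step(x: Vector, g: Vector, avg_sq: Vector, step_size: float = 0.01,
                 gamma: float = 0.9, eps: float = 1e-8) -> Tuple[Vector, Vector]:
    """Hypothetical single RMSProp step on plain lists."""
    # Exponentially decaying average of squared gradients.
    avg_sq = [gamma * a + (1 - gamma) * gi * gi for a, gi in zip(avg_sq, g)]
    x = [xi - step_size * gi / math.sqrt(a + eps) for xi, gi, a in zip(x, g, avg_sq)]
    return x, avg_sq


# With a constant gradient, avg_sq approaches g**2 and each step approaches step_size.
x, avg_sq = [0.0], [0.0]
for _ in range(200):
    prev = x[0]
    x, avg_sq = rmsprop_step(x, [2.0], avg_sq)
last_step = prev - x[0]
```

Because avg_sq starts at zero and this rule has no bias correction, the early steps are larger than step_size. With gamma = 0.9 the first step is about 3.2 times step_size.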
RMSPropMomentum

class RMSPropMomentum(*args, **kwargs)
Wrapper class for the JAX optimizer: rmsprop_momentum().

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Perform one gradient update step.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
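One common way to combine RMSProp with momentum is to let the momentum buffer accumulate the already-scaled RMSProp steps and then subtract the buffer from the parameters. The sketch below assumes that form and the defaults gamma = 0.9, momentum = 0.9, eps = 1e-8. It is not taken from the JAX source, and rmsprop_momentum_step is a hypothetical name.

```python
import math
from typing import List, Tuple

Vector = List[float]


def rmsprop_momentum_step(x: Vector, g: Vector, avg_sq: Vector, mom: Vector,
                          step_size: float = 0.01, gamma: float = 0.9,
                          eps: float = 1e-8, momentum: float = 0.9
                          ) -> Tuple[Vector, Vector, Vector]:
    """Hypothetical single RMSProp-with-momentum step on plain lists."""
    # Same decaying average of squared gradients as plain RMSProp.
    avg_sq = [gamma * a + (1 - gamma) * gi * gi for a, gi in zip(avg_sq, g)]
    # The momentum buffer accumulates scaled steps, not raw gradients.
    mom = [momentum * mi + step_size * gi / math.sqrt(a + eps)
           for mi, gi, a in zip(mom, g, avg_sq)]
    x = [xi - mi for xi, mi in zip(x, mom)]
    return x, avg_sq, mom


x, avg_sq, mom = [0.0], [0.0], [0.0]
for _ in range(2):
    x, avg_sq, mom = rmsprop_momentum_step(x, [2.0], avg_sq, mom)
```

Since the step size is applied before the momentum accumulation, a step-size schedule takes effect gradually through the buffer rather than rescaling the whole accumulated velocity at once.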
SGD

class SGD(*args, **kwargs)
Wrapper class for the JAX optimizer: sgd().

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Perform one gradient update step.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
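Plain SGD moves each parameter against its gradient by the step size. JAX-style optimizers generally accept either a constant step size or a callable that maps the step index to a step size, and the sketch below mirrors that idea. as_schedule and sgd_step are hypothetical helpers, not JAX or NumPyro functions.

```python
from typing import Callable, List, Union

Vector = List[float]
StepSize = Union[float, Callable[[int], float]]


def as_schedule(step_size: StepSize) -> Callable[[int], float]:
    """Hypothetical helper: treat a constant as a schedule that ignores the step index."""
    if callable(step_size):
        return step_size
    return lambda i: step_size


def sgd_step(i: int, x: Vector, g: Vector, step_size: StepSize) -> Vector:
    """Hypothetical single SGD step: x <- x - lr(i) * g."""
    lr = as_schedule(step_size)(i)
    return [xi - lr * gi for xi, gi in zip(x, g)]


# Minimize f(x) = x**2 (gradient 2x) with a decaying step size 0.1 / (1 + i).
x = [1.0]
for i in range(3):
    x = sgd_step(i, x, [2.0 * x[0]], lambda k: 0.1 / (1 + k))
```

Here x goes from 1.0 to 0.8, 0.72 and then 0.672. The step index is what a wrapper's integer state component would supply to such a schedule.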
SM3

class SM3(*args, **kwargs)
Wrapper class for the JAX optimizer: sm3().

get_params(state: Tuple[int, _OptState]) → _Params
Get the current parameter values.
Parameters: state – current optimizer state.
Returns: collection with the current value of each parameter.

init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with the parameters to be optimized.
Parameters: params – a collection of numpy arrays.
Returns: initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Perform one gradient update step.
Parameters:
- g – gradient information for the parameters.
- state – current optimizer state.
Returns: new optimizer state after the update.
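SM3 (Anil et al., "Memory-Efficient Adaptive Optimization") is an Adagrad-like method that avoids storing one accumulator per matrix entry. For a 2-D parameter it keeps one accumulator per row and one per column. The sketch below shows that row/column accumulator scheme (the SM3-II variant) on plain nested lists. It is a sketch of the technique only, and sm3_step is a hypothetical name. The JAX sm3() may differ in details, for example by also applying momentum.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def sm3_step(x: Matrix, g: Matrix, row_acc: Vector, col_acc: Vector,
             step_size: float = 0.1) -> Tuple[Matrix, Vector, Vector]:
    """Hypothetical single SM3-II step for a matrix parameter."""
    n_rows, n_cols = len(x), len(x[0])
    # Per-entry estimate: tightest bound from the row/column covers plus the new g**2.
    nu = [[min(row_acc[i], col_acc[j]) + g[i][j] ** 2 for j in range(n_cols)]
          for i in range(n_rows)]
    # Adagrad-style scaled step; entries with a zero estimate are not updated.
    new_x = [[x[i][j] - (step_size * g[i][j] / math.sqrt(nu[i][j]) if nu[i][j] > 0 else 0.0)
              for j in range(n_cols)] for i in range(n_rows)]
    # Only n_rows + n_cols numbers are stored: the max of nu over each row and column.
    new_row = [max(nu[i]) for i in range(n_rows)]
    new_col = [max(nu[i][j] for i in range(n_rows)) for j in range(n_cols)]
    return new_x, new_row, new_col


x0 = [[0.0, 0.0], [0.0, 0.0]]
g0 = [[1.0, 2.0], [0.0, 3.0]]
x1, rows, cols = sm3_step(x0, g0, [0.0, 0.0], [0.0, 0.0])
```

For an m × n matrix the optimizer state holds m + n accumulators instead of m · n. Taking the maximum when storing and the minimum when reading keeps each per-entry estimate at least as large as the true running sum of squared gradients.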