
The optimizer classes defined here are light wrappers around the corresponding optimizers from jax.experimental.optimizers (moved to jax.example_libraries.optimizers in later JAX releases), with an interface better suited to NumPyro's inference algorithms.


class Adam(*args, **kwargs)

Wrapper class for the JAX optimizer: adam()
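The following pure-Python sketch shows what adam() computes, using the update rule from Kingma & Ba on a flat list of floats. It is not NumPyro's or JAX's code. The names `adam_init` and `adam_update` are hypothetical, and the defaults (b1=0.9, b2=0.999, eps=1e-8) are assumed to match JAX's adam().

```python
import math
from typing import List, Tuple

# Hypothetical state for one flat parameter vector: (params, first moment m, second moment v).
AdamState = Tuple[List[float], List[float], List[float]]


def adam_init(x0: List[float]) -> AdamState:
    """Start from the given parameters with zero moment estimates."""
    return list(x0), [0.0] * len(x0), [0.0] * len(x0)


def adam_update(i: int, g: List[float], state: AdamState,
                step_size: float = 1e-3, b1: float = 0.9,
                b2: float = 0.999, eps: float = 1e-8) -> AdamState:
    """One Adam step at 0-based iteration i."""
    x, m, v = state
    m = [b1 * mj + (1.0 - b1) * gj for mj, gj in zip(m, g)]
    v = [b2 * vj + (1.0 - b2) * gj * gj for vj, gj in zip(v, g)]
    # Bias-correct the zero-initialised moving averages.
    m_hat = [mj / (1.0 - b1 ** (i + 1)) for mj in m]
    v_hat = [vj / (1.0 - b2 ** (i + 1)) for vj in v]
    x = [xj - step_size * mh / (math.sqrt(vh) + eps)
         for xj, mh, vh in zip(x, m_hat, v_hat)]
    return x, m, v


state = adam_init([1.0, 1.0])
state = adam_update(0, [2.0, -0.5], state, step_size=0.1)
# state[0] is roughly [0.9, 1.1]
```

On the first step, the bias-corrected ratio m_hat / sqrt(v_hat) equals the sign of each gradient entry. Every coordinate therefore moves by about step_size, whatever the gradient's magnitude.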

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.
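The branching described above can be sketched in plain Python. The names `value_and_grad`, `sgd_update` and `State` are hypothetical stand-ins. NumPyro computes the gradients with JAX, and a jit-compiled version would need a traced conditional such as jax.lax.cond instead of a Python `if`. Treat this as an illustration of the semantics only.

```python
import math
from typing import Callable, List, Tuple

State = Tuple[int, List[float]]  # hypothetical: (step counter, parameters)


def sgd_update(g: List[float], state: State, step_size: float = 0.1) -> State:
    """Plain gradient step standing in for the wrapped optimizer's update()."""
    i, x = state
    return i + 1, [xj - step_size * gj for xj, gj in zip(x, g)]


def eval_and_stable_update(
    value_and_grad: Callable[[List[float]], Tuple[float, List[float]]],
    state: State,
) -> Tuple[float, State]:
    """Evaluate the objective; update only if the loss and all gradients are finite."""
    _, x = state
    loss, grad = value_and_grad(x)
    if not (math.isfinite(loss) and all(math.isfinite(gj) for gj in grad)):
        # Skip the step: return the input state unchanged and report nan.
        return math.nan, state
    return loss, sgd_update(grad, state)


def quadratic(x: List[float]) -> Tuple[float, List[float]]:
    """f(x) = sum(x_j^2) together with its gradient 2x."""
    return sum(xj * xj for xj in x), [2.0 * xj for xj in x]


def blows_up(x: List[float]) -> Tuple[float, List[float]]:
    """An objective whose value has overflowed to infinity."""
    return math.inf, [0.0 for _ in x]


loss, state = eval_and_stable_update(quadratic, (0, [1.0]))
# loss is 1.0 and state is approximately (1, [0.8])
bad_loss, same_state = eval_and_stable_update(blows_up, state)
# bad_loss is nan and same_state is the very same state object
```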

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.
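The pure-Python sketch below contrasts the two behaviours. Both function names are hypothetical. The inner loop of `minimize_style_eval_and_update` uses plain gradient descent only as a stand-in: Minimize itself delegates to a JAX minimizer (method='BFGS' by default). The tolerance and iteration cap here are arbitrary assumptions.

```python
from typing import Callable, List, Tuple

State = Tuple[int, List[float]]  # hypothetical: (step counter, parameters)
ValueAndGrad = Callable[[List[float]], Tuple[float, List[float]]]


def gradient_eval_and_update(fn: ValueAndGrad, state: State,
                             step_size: float = 0.1) -> Tuple[float, State]:
    """Typical optimizer: one evaluation, one gradient step."""
    i, x = state
    loss, grad = fn(x)
    return loss, (i + 1, [xj - step_size * gj for xj, gj in zip(x, grad)])


def minimize_style_eval_and_update(fn: ValueAndGrad, state: State,
                                   step_size: float = 0.1, tol: float = 1e-10,
                                   max_iter: int = 10_000) -> Tuple[float, State]:
    """Minimize-like: re-evaluate fn until the gradient vanishes, all in one call."""
    i, x = state
    loss, grad = fn(x)
    for _ in range(max_iter):
        if max(abs(gj) for gj in grad) < tol:
            break
        x = [xj - step_size * gj for xj, gj in zip(x, grad)]
        loss, grad = fn(x)
    # Report the objective at the final point; the counter still advances by one.
    return loss, (i + 1, x)


def shifted_quadratic(x: List[float]) -> Tuple[float, List[float]]:
    """f(x) = sum((x_j - 3)^2) with gradient 2 (x - 3)."""
    return sum((xj - 3.0) ** 2 for xj in x), [2.0 * (xj - 3.0) for xj in x]


loss_a, state_a = gradient_eval_and_update(shifted_quadratic, (0, [0.0]))
loss_b, state_b = minimize_style_eval_and_update(shifted_quadratic, (0, [0.0]))
# state_a moves x part of the way toward 3; state_b ends (numerically) at 3
```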

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.
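The state returned by init() and update() is a pair whose first element is an integer step counter, as the Tuple[int, _OptState] annotations indicate. The hypothetical class below mimics the init → update → get_params cycle with plain SGD on a dict of floats. It sketches the calling convention only, not NumPyro's implementation, and it assumes that the counter simply increments on each update.

```python
from typing import Dict, Tuple

Params = Dict[str, float]
OptState = Tuple[int, Params]  # (step counter, optimizer-specific state)


class ToySGD:
    """Hypothetical optimizer mimicking the init / update / get_params contract."""

    def __init__(self, step_size: float) -> None:
        self.step_size = step_size

    def init(self, params: Params) -> OptState:
        # Counter starts at 0; plain SGD needs no extra state beyond the params.
        return 0, dict(params)

    def update(self, g: Params, state: OptState) -> OptState:
        i, params = state
        new_params = {k: v - self.step_size * g[k] for k, v in params.items()}
        return i + 1, new_params

    def get_params(self, state: OptState) -> Params:
        _, params = state
        return dict(params)


# Minimize (a - 3)^2, supplying the gradient 2 * (a - 3) by hand.
opt = ToySGD(step_size=0.1)
state = opt.init({"a": 0.0})
for _ in range(100):
    params = opt.get_params(state)
    state = opt.update({"a": 2.0 * (params["a"] - 3.0)}, state)
# state[0] == 100 and opt.get_params(state)["a"] is close to 3.0
```

In real use, the gradient would come from jax.grad applied to the loss. For stateful optimizers such as Adam, the second element of the pair also holds moment estimates alongside the parameters, which is why get_params() is needed to read the parameters back.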


class Adagrad(*args, **kwargs)

Wrapper class for the JAX optimizer: adagrad()
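Below is a minimal pure-Python sketch of the AdaGrad rule, in which each coordinate's step is divided by the square root of its accumulated squared gradients. The names are hypothetical. JAX's adagrad() also takes a momentum argument, assumed here to default to 0.9; check the JAX source for your version. The sketch exposes that argument but defaults it to 0.0, so it reduces to textbook AdaGrad.

```python
import math
from typing import List, Tuple

# Hypothetical state: (params, running sum of squared gradients, momentum buffer).
AdagradState = Tuple[List[float], List[float], List[float]]


def adagrad_init(x0: List[float]) -> AdagradState:
    return list(x0), [0.0] * len(x0), [0.0] * len(x0)


def adagrad_update(g: List[float], state: AdagradState, step_size: float = 0.1,
                   momentum: float = 0.0) -> AdagradState:
    x, g_sq, m = state
    g_sq = [s + gj * gj for s, gj in zip(g_sq, g)]
    # Divide each coordinate by the root of its accumulated squared gradients.
    scaled = [gj / math.sqrt(s) if s > 0 else 0.0 for gj, s in zip(g, g_sq)]
    m = [(1.0 - momentum) * sj + momentum * mj for sj, mj in zip(scaled, m)]
    x = [xj - step_size * mj for xj, mj in zip(x, m)]
    return x, g_sq, m


state = adagrad_init([0.0, 0.0])
state = adagrad_update([3.0, -4.0], state)
state = adagrad_update([4.0, 0.0], state)
# state[0] is roughly [-0.18, 0.1]; state[1] == [25.0, 16.0]
```

Because the accumulator only grows, each coordinate's effective step shrinks over time. For example, the first coordinate's second step uses 4 / sqrt(25) = 0.8 rather than 1.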

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.


class ClippedAdam(*args, clip_norm=10.0, **kwargs)

Adam optimizer with gradient clipping.

Parameters: clip_norm (float) – every gradient value is clipped to the interval [-clip_norm, clip_norm] before the Adam update (default 10.0).

Reference:

Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba, https://arxiv.org/abs/1412.6980
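The clipping described above acts on each gradient entry independently, as the following sketch shows. The dict-of-lists gradient is a hypothetical stand-in for a JAX pytree of arrays, and `clip_elementwise` is not a NumPyro function. In ClippedAdam, the clipped gradients are then passed to the ordinary Adam update.

```python
from typing import Dict, List

Grads = Dict[str, List[float]]  # hypothetical stand-in for a pytree of arrays


def clip_elementwise(g: Grads, clip_norm: float = 10.0) -> Grads:
    """Clamp every gradient entry into [-clip_norm, clip_norm], leaving g untouched."""
    return {
        name: [min(max(v, -clip_norm), clip_norm) for v in values]
        for name, values in g.items()
    }


grads = {"loc": [25.0, -0.5], "scale": [-12.0]}
clipped = clip_elementwise(grads, clip_norm=10.0)
# {'loc': [10.0, -0.5], 'scale': [-10.0]}
```

Because each entry is clamped separately, a gradient such as (25.0, -0.5) becomes (10.0, -0.5), which changes its direction. Clipping by the global norm behaves differently: it rescales the whole vector and preserves its direction.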

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g, state)

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.


class Minimize(method='BFGS', **kwargs)

Wrapper class for the JAX minimizer: minimize(). The method argument selects the minimization algorithm (default 'BFGS').

Example:

>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer.autoguide import AutoLaplaceApproximation

>>> def model(x, y):
...     a = numpyro.sample("a", dist.Normal(0, 1))
...     b = numpyro.sample("b", dist.Normal(0, 1))
...     with numpyro.plate("N", y.shape[0]):
...         numpyro.sample("obs", dist.Normal(a + b * x, 0.1), obs=y)

>>> x = jnp.linspace(0, 10, 100)
>>> y = 3 * x + 2
>>> optimizer = numpyro.optim.Minimize()
>>> guide = AutoLaplaceApproximation(model)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> init_state = svi.init(random.PRNGKey(0), x, y)
>>> optimal_state, loss = svi.update(init_state, x, y)
>>> params = svi.get_params(optimal_state)  # get guide's parameters
>>> quantiles = guide.quantiles(params, 0.5)  # posterior medians (equal to the means for this Gaussian approximation)
>>> assert_allclose(quantiles["a"], 2., atol=1e-3)
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.


class Momentum(*args, **kwargs)

Wrapper class for the JAX optimizer: momentum()
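The heavy-ball update applied by momentum() can be sketched in pure Python. The sketch assumes JAX's convention of a `mass` parameter: velocity = mass * velocity + g, followed by x -= step_size * velocity. The function names are hypothetical, and mass=0.9 is an arbitrary choice here.

```python
from typing import List, Tuple

# Hypothetical state: (params, velocity).
MomentumState = Tuple[List[float], List[float]]


def momentum_init(x0: List[float]) -> MomentumState:
    return list(x0), [0.0] * len(x0)


def momentum_update(g: List[float], state: MomentumState, step_size: float = 0.1,
                    mass: float = 0.9) -> MomentumState:
    x, velocity = state
    # Velocity is a decaying sum of past gradients.
    velocity = [mass * vj + gj for vj, gj in zip(velocity, g)]
    x = [xj - step_size * vj for xj, vj in zip(x, velocity)]
    return x, velocity


state = momentum_init([0.0])
for _ in range(3):
    state = momentum_update([1.0], state)
# velocity grows 1.0 -> 1.9 -> 2.71, so x is roughly -(0.1 + 0.19 + 0.271) = -0.561
```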

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.


class RMSProp(*args, **kwargs)

Wrapper class for the JAX optimizer: rmsprop()
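RMSProp keeps an exponential moving average of squared gradients and divides each step by its square root. The pure-Python sketch below illustrates this. The names are hypothetical, and the defaults are assumptions modeled on JAX's rmsprop(): gamma=0.9 and eps=1e-8, with eps added inside the square root.

```python
import math
from typing import List, Tuple

# Hypothetical state: (params, moving average of squared gradients).
RMSPropState = Tuple[List[float], List[float]]


def rmsprop_init(x0: List[float]) -> RMSPropState:
    return list(x0), [0.0] * len(x0)


def rmsprop_update(g: List[float], state: RMSPropState, step_size: float = 0.01,
                   gamma: float = 0.9, eps: float = 1e-8) -> RMSPropState:
    x, avg_sq = state
    avg_sq = [gamma * a + (1.0 - gamma) * gj * gj for a, gj in zip(avg_sq, g)]
    # eps sits inside the square root in this sketch.
    x = [xj - step_size * gj / math.sqrt(a + eps) for xj, gj, a in zip(x, g, avg_sq)]
    return x, avg_sq


state = rmsprop_init([0.0])
state = rmsprop_update([2.0], state)
# first step is about step_size / sqrt(1 - gamma), i.e. roughly 0.0316 in magnitude
```

The average starts at zero and is not bias-corrected, so the first step is larger than step_size. With gamma=0.9 its magnitude is step_size / sqrt(1 - gamma), about 3.16 × step_size.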

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.


class RMSPropMomentum(*args, **kwargs)

Wrapper class for the JAX optimizer: rmsprop_momentum()
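The sketch below combines RMSProp scaling with a momentum buffer that accumulates the scaled steps. It is modeled on JAX's rmsprop_momentum(), with assumed defaults gamma=0.9, eps=1e-8 and momentum=0.9. All names are hypothetical.

```python
import math
from typing import List, Tuple

# Hypothetical state: (params, moving average of squared gradients, momentum buffer).
RMSPropMomentumState = Tuple[List[float], List[float], List[float]]


def rmsprop_momentum_init(x0: List[float]) -> RMSPropMomentumState:
    return list(x0), [0.0] * len(x0), [0.0] * len(x0)


def rmsprop_momentum_update(g: List[float], state: RMSPropMomentumState,
                            step_size: float = 0.01, gamma: float = 0.9,
                            eps: float = 1e-8,
                            momentum: float = 0.9) -> RMSPropMomentumState:
    x, avg_sq, mom = state
    avg_sq = [gamma * a + (1.0 - gamma) * gj * gj for a, gj in zip(avg_sq, g)]
    # The buffer accumulates already-scaled steps, so it is subtracted directly.
    mom = [momentum * mj + step_size * gj / math.sqrt(a + eps)
           for mj, gj, a in zip(mom, g, avg_sq)]
    x = [xj - mj for xj, mj in zip(x, mom)]
    return x, avg_sq, mom


state = rmsprop_momentum_init([0.0])
for _ in range(2):
    state = rmsprop_momentum_update([2.0], state)
```

Unlike plain RMSProp, successive scaled steps accumulate in the buffer. With a steady gradient, the step length therefore grows toward roughly step_size / (1 - momentum).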

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.


class SGD(*args, **kwargs)

Wrapper class for the JAX optimizer: sgd()
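JAX's example optimizers accept step_size either as a constant or as a callable schedule that maps the step index to a learning rate. The wrapper's (*args, **kwargs) signature suggests it forwards its arguments to sgd() unchanged; assuming it does, the same choice should be available here. The pure-Python sketch below shows one SGD step under both options. `as_schedule` and `sgd_step` are hypothetical names.

```python
from typing import Callable, List, Union

StepSize = Union[float, Callable[[int], float]]


def as_schedule(step_size: StepSize) -> Callable[[int], float]:
    """Wrap a constant learning rate so it can be called like a schedule."""
    if callable(step_size):
        return step_size
    return lambda i: step_size


def sgd_step(i: int, g: List[float], x: List[float], step_size: StepSize) -> List[float]:
    """x <- x - lr(i) * g, with lr looked up from the schedule at step i."""
    lr = as_schedule(step_size)(i)
    return [xj - lr * gj for xj, gj in zip(x, g)]


constant = sgd_step(1, [2.0], [1.0], step_size=0.5)
decaying = sgd_step(1, [2.0], [1.0], step_size=lambda i: 1.0 / (i + 1))
# both give [0.0]: the schedule yields 1 / (1 + 1) = 0.5 at step 1
```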

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.


class SM3(*args, **kwargs)

Wrapper class for the JAX optimizer: sm3()
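SM3 saves memory by keeping second-moment statistics per row and per column of a matrix parameter instead of per entry. The pure-Python sketch below handles a single matrix, and its names are hypothetical. The momentum mixing (default assumed to be 0.9) is modeled on JAX's sm3(), so treat the details as assumptions and check the JAX source.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
# Hypothetical state: (params, momentum buffer, per-row accumulator, per-column accumulator).
SM3State = Tuple[Matrix, Matrix, List[float], List[float]]


def sm3_init(x0: Matrix) -> SM3State:
    rows, cols = len(x0), len(x0[0])
    return ([row[:] for row in x0], [[0.0] * cols for _ in range(rows)],
            [0.0] * rows, [0.0] * cols)


def sm3_update(g: Matrix, state: SM3State, step_size: float = 0.1,
               momentum: float = 0.9) -> SM3State:
    x, m, row_acc, col_acc = state
    rows, cols = len(g), len(g[0])
    # Per-entry estimate: the smaller of its row/column statistic plus g^2.
    acc = [[min(row_acc[i], col_acc[j]) + g[i][j] ** 2 for j in range(cols)]
           for i in range(rows)]
    # Entries whose accumulator is still zero get no update.
    scaled = [[g[i][j] / math.sqrt(acc[i][j]) if acc[i][j] > 0 else 0.0
               for j in range(cols)] for i in range(rows)]
    m = [[(1.0 - momentum) * scaled[i][j] + momentum * m[i][j]
          for j in range(cols)] for i in range(rows)]
    x = [[x[i][j] - step_size * m[i][j] for j in range(cols)] for i in range(rows)]
    # Only rows + cols statistics are carried over, not the full rows x cols matrix.
    row_acc = [max(acc[i]) for i in range(rows)]
    col_acc = [max(acc[i][j] for i in range(rows)) for j in range(cols)]
    return x, m, row_acc, col_acc


state = sm3_init([[0.0, 0.0], [0.0, 0.0]])
state = sm3_update([[3.0, 0.0], [0.0, -4.0]], state)
# row and column accumulators are both [9.0, 16.0]
```

The momentum buffer m is still full-sized in this sketch. The memory savings apply to the second-moment statistics only.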

eval_and_stable_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Like eval_and_update(), but if the objective value or any of the gradients is not finite, the input state is returned unchanged and the objective output is set to nan.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

eval_and_update(fn: Callable, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Performs an optimization step for the objective function fn. For most optimizers, the update is based on the gradient of the objective function with respect to the current parameters. Some optimizers, such as Minimize, instead evaluate the function repeatedly within a single call to find the optimal parameters.

Parameters:
  • fn – objective function.
  • state – current optimizer state.

Returns: a pair of the objective function's output and the new optimizer state.

get_params(state: Tuple[int, _OptState]) → _Params

Get current parameter values.

Parameters: state – current optimizer state.
Returns: a collection with the current parameter values.

init(params: _Params) → Tuple[int, _OptState]

Initialize the optimizer with parameters designated to be optimized.

Parameters: params – a collection of numpy arrays.
Returns: the initial optimizer state.

update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]

Gradient update for the optimizer.

Parameters:
  • g – gradients with respect to the parameters.
  • state – current optimizer state.

Returns: the new optimizer state after the update.