mlx_optimizers.Lamb


class Lamb(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.999], weight_decay: float = 0.0, eps: float = 1e-08)

Layerwise Adaptive Large Batch Optimization [1].

\[\begin{split}
m_0 &= 0, \quad v_0 = 0 \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t) \\
\hat{v}_t &= v_t / (1 - \beta_2^t) \\
r_t &= \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
\theta_{t+1} &= \theta_t - \eta \frac{\phi(\|\theta_t\|)}{\|r_t + \lambda \theta_t\|} \left(r_t + \lambda \theta_t\right)
\end{split}\]

[1] You, Yang, et al., 2019. Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes. https://arxiv.org/abs/1904.00962 (v5). Reference implementation: tensorflow/addons.
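In code, the update above amounts to the following minimal sketch of one LAMB step for a single parameter, written with mlx.core. The function name lamb_step and its argument layout are illustrative, not part of the mlx_optimizers API, and the fallback to a trust ratio of 1 when either norm is zero follows a common convention rather than anything stated in the equations.

import mlx.core as mx

def lamb_step(parameter, gradient, m, v, step,
              lr=1e-3, beta1=0.9, beta2=0.999, weight_decay=0.0, eps=1e-8):
    # Exponential moving averages of the gradient and its square.
    m = beta1 * m + (1 - beta1) * gradient
    v = beta2 * v + (1 - beta2) * mx.square(gradient)

    # Bias-corrected estimates (m-hat and v-hat above).
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)

    # Adam-style direction plus weight decay: r_t + lambda * theta_t.
    update = m_hat / (mx.sqrt(v_hat) + eps) + weight_decay * parameter

    # Layerwise trust ratio phi(||theta_t||) / ||r_t + lambda * theta_t||,
    # falling back to 1 when either norm is zero.
    w_norm = mx.sqrt(mx.sum(mx.square(parameter)))
    u_norm = mx.sqrt(mx.sum(mx.square(update)))
    trust = mx.where(mx.logical_and(w_norm > 0, u_norm > 0), w_norm / u_norm, 1.0)

    return parameter - lr * trust * update, m, v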

Parameters:
  • learning_rate (float or callable) – learning rate \(\eta\).

  • betas (Tuple[float, float], optional) – coefficients \((\beta_1, \beta_2)\) used for computing running averages of the gradient and its square. Default: (0.9, 0.999)

  • weight_decay (float, optional) – weight decay \(\lambda\). Default: 0.0

  • eps (float, optional) – term \(\epsilon\) added to the denominator to improve numerical stability. Default: 1e-8
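A short usage sketch follows. It assumes mlx and mlx-optimizers are installed, uses a toy linear model and random data as placeholders, and relies on the standard mlx optimizer interface (update, state); treat it as illustrative rather than canonical.

import mlx.core as mx
import mlx.nn as nn
import mlx_optimizers as optim

# Toy model and loss; any mlx.nn module works the same way.
model = nn.Linear(10, 1)

def loss_fn(model, X, y):
    return nn.losses.mse_loss(model(X), y)

optimizer = optim.Lamb(learning_rate=1e-3, betas=[0.9, 0.999], weight_decay=0.01)

# One training step on random data.
X = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))
loss, grads = nn.value_and_grad(model, loss_fn)(model, X, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)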

Figure: optimization trajectory of Lamb on the Rosenbrock test function (rosenbrock_Lamb.png).

Methods

__init__(learning_rate[, betas, ...])

apply_single(gradient, parameter, state)

Performs a single optimization step, updating \(m\) and \(v\).

init_single(parameter, state)

Initialize the optimizer state.
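Given the initial conditions \(m_0 = 0, v_0 = 0\) in the update rule, init_single presumably allocates zeroed moment buffers for each parameter. A hypothetical sketch (the actual state field names in mlx_optimizers may differ):

import mlx.core as mx

def init_state(parameter):
    # Zeroed first and second moment estimates, matching m_0 = v_0 = 0.
    return {"m": mx.zeros_like(parameter), "v": mx.zeros_like(parameter)}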