mlx_optimizers.Lamb
- class Lamb(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.999], weight_decay: float = 0.0, eps: float = 1e-08)
Layerwise Adaptive Large Batch Optimization [1].
\[\begin{split}m_0 &= 0, \quad v_0 = 0 \\ m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{m}_t &= m_t / (1 - \beta_1^t) \\ \hat{v}_t &= v_t / (1 - \beta_2^t) \\ r_t &= \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\ \theta_{t+1} &= \theta_t - \eta \frac{\phi(\|\theta_t\|)}{\|r_t + \lambda \theta_t\|} \left(r_t + \lambda \theta_t\right)\end{split}\]
[1] You, Yang, et al., 2019. Large Batch Optimization for Deep Learning: Training BERT in 76 Minutes. https://arxiv.org/abs/1904.00962v5 (reference implementation: tensorflow/addons).
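To make the symbol-to-code mapping concrete, here is a minimal sketch of the update above for a single parameter using mlx.core. The function name lamb_step, the explicit step argument t, and taking the scaling function \(\phi\) as the identity are illustrative assumptions, not part of this class's API.

```python
import mlx.core as mx

def lamb_step(parameter, gradient, m, v, t, lr=1e-3, betas=(0.9, 0.999),
              weight_decay=0.0, eps=1e-8):
    """One LAMB update for a single parameter, following the equations above."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * gradient             # first moment m_t
    v = b2 * v + (1 - b2) * mx.square(gradient)  # second moment v_t
    m_hat = m / (1 - b1**t)                      # bias-corrected moments
    v_hat = v / (1 - b2**t)
    r = m_hat / (mx.sqrt(v_hat) + eps)           # Adam-style direction r_t
    update = r + weight_decay * parameter        # r_t + lambda * theta_t
    # Trust ratio phi(||theta_t||) / ||r_t + lambda * theta_t||, with phi taken
    # as the identity here (an assumption; the paper also allows a clipped phi).
    w_norm = mx.linalg.norm(parameter)
    u_norm = mx.linalg.norm(update)
    ratio = mx.where(mx.logical_and(w_norm > 0, u_norm > 0), w_norm / u_norm, 1.0)
    return parameter - lr * ratio * update, m, v
```

The step count t must start at 1 so the bias corrections \(1 - \beta_i^t\) are nonzero.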
- Parameters:
learning_rate (float or callable) – learning rate \(\eta\).
betas (Tuple[float, float], optional) – coefficients \((\beta_1, \beta_2)\) used for computing running averages of the gradient and its square. Default: (0.9, 0.999)
weight_decay (float, optional) – weight decay \(\lambda\). Default: 0.0
eps (float, optional) – term \(\epsilon\) added to the denominator to improve numerical stability. Default: 1e-8
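A minimal usage sketch, assuming the package imports as mlx_optimizers and that Lamb follows MLX's standard optimizer interface (update(model, grads)); the toy model, data, and hyperparameter values are illustrative only.

```python
import mlx.core as mx
import mlx.nn as nn
import mlx_optimizers

model = nn.Linear(8, 1)
optimizer = mlx_optimizers.Lamb(
    learning_rate=1e-3, betas=[0.9, 0.999], weight_decay=0.0, eps=1e-8
)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((32, 8))
y = mx.random.normal((32, 1))

loss_and_grad = nn.value_and_grad(model, loss_fn)
for _ in range(10):
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)                # applies the LAMB step
    mx.eval(model.parameters(), optimizer.state)  # force lazy evaluation
```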
Methods

__init__(learning_rate[, betas, ...])
apply_single(gradient, parameter, state) – Performs a single optimization step, updating \(m\) and \(v\).
init_single(parameter, state) – Initialize optimizer state.
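For orientation, a minimal sketch of how these two hooks could look against MLX's mlx.optimizers.Optimizer base class, mirroring the equations above. The class name LambSketch, the per-parameter step counter t, and the identity choice for \(\phi\) are assumptions; the library's actual implementation may organize its state and bias correction differently.

```python
import mlx.core as mx
from mlx.optimizers import Optimizer

class LambSketch(Optimizer):
    """Sketch only: illustrates the init_single/apply_single contract."""

    def __init__(self, learning_rate, betas=[0.9, 0.999], weight_decay=0.0, eps=1e-8):
        super().__init__()
        # Handles a float or callable learning rate, as in MLX's built-in optimizers.
        self._maybe_schedule("learning_rate", learning_rate)
        self.betas = betas
        self.weight_decay = weight_decay
        self.eps = eps

    def init_single(self, parameter, state):
        # Per-parameter state: first/second moment estimates and a step counter.
        state["m"] = mx.zeros_like(parameter)
        state["v"] = mx.zeros_like(parameter)
        state["t"] = mx.array(0)

    def apply_single(self, gradient, parameter, state):
        b1, b2 = self.betas
        t = state["t"] = state["t"] + 1
        m = state["m"] = b1 * state["m"] + (1 - b1) * gradient
        v = state["v"] = b2 * state["v"] + (1 - b2) * mx.square(gradient)
        m_hat = m / (1 - mx.power(b1, t))
        v_hat = v / (1 - mx.power(b2, t))
        r = m_hat / (mx.sqrt(v_hat) + self.eps)
        update = r + self.weight_decay * parameter
        w_norm = mx.linalg.norm(parameter)
        u_norm = mx.linalg.norm(update)
        ratio = mx.where(mx.logical_and(w_norm > 0, u_norm > 0), w_norm / u_norm, 1.0)
        lr = self.learning_rate.astype(gradient.dtype)
        return parameter - lr * ratio * update
```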