mlx_optimizers.Shampoo
- class Shampoo(learning_rate: float | Callable[[array], array], momentum: float = 0.0, weight_decay: float = 0.0, update_freq: int = 1, eps: float = 0.0001)
Preconditioned Stochastic Tensor Optimization (general tensor case) [1].
\[\begin{split}W_1 &= 0_{n_1 \times \dots \times n_k}; \quad \forall i \in [k]: H_0^i = \epsilon I_{n_i}\\ H_t^i &= H_{t-1}^i + G_t^{(i)} {G_t^{(i)}}^\top\\ \tilde{G}_t &= G_t \times_1 (H_t^1)^{-1/2k} \times_2 \dots \times_k (H_t^k)^{-1/2k}\\ W_{t+1} &= W_t - \eta \tilde{G}_t\end{split}\]
[1] Gupta, Vineet, Tomer Koren, and Yoram Singer, 2018. Shampoo: Preconditioned stochastic tensor optimization. ICML 2018. https://arxiv.org/abs/1802.09568
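To make the tensor notation concrete, here is a minimal NumPy sketch of the matrix case (\(k = 2\)), where preconditioning the gradient along both axes reduces to \(L^{-1/4} G R^{-1/4}\). matrix_power and shampoo_step are illustrative helpers, not part of mlx_optimizers, and they omit momentum, weight decay, and the update_freq schedule.

import numpy as np

def matrix_power(H, p):
    # Symmetric matrix power H**p via eigendecomposition,
    # with eigenvalues floored for numerical safety.
    w, V = np.linalg.eigh(H)
    return (V * np.clip(w, 1e-12, None) ** p) @ V.T

def shampoo_step(W, G, L, R, lr=0.1):
    # One update for a 2-D weight: accumulate the per-axis statistics,
    # then precondition the gradient from both sides.
    L = L + G @ G.T                       # H_t^1 = H_{t-1}^1 + G_t G_t^T
    R = R + G.T @ G                       # H_t^2 = H_{t-1}^2 + G_t^T G_t
    G_tilde = matrix_power(L, -0.25) @ G @ matrix_power(R, -0.25)
    return W - lr * G_tilde, L, R         # W_{t+1} = W_t - eta * G_tilde

eps = 1e-4
W = np.zeros((4, 3))                      # W_1 = 0
L, R = eps * np.eye(4), eps * np.eye(3)   # H_0^i = eps * I
G = np.random.default_rng(0).normal(size=W.shape)
W, L, R = shampoo_step(W, G, L, R)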
- Parameters:
learning_rate (float or callable) – learning rate \(\eta\).
momentum (float, optional) – momentum factor. Default: 0.0
weight_decay (float, optional) – weight decay factor. Default: 0.0
update_freq (int, optional) – frequency of updating the preconditioner. Default: 1
eps (float, optional) – term \(\epsilon\) added to the denominator to improve numerical stability. Default: 1e-4
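A minimal usage sketch in a standard MLX training step; the linear model, MSE loss, and random data below are placeholders:

import mlx.core as mx
import mlx.nn as nn
import mlx_optimizers as optim

model = nn.Linear(64, 10)
optimizer = optim.Shampoo(learning_rate=0.1, momentum=0.9, update_freq=1)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((32, 64))
y = mx.random.normal((32, 10))

# Compute the loss and gradients, then let the optimizer update the model.
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)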
Methods

- __init__(learning_rate[, momentum, ...])
- apply_single(gradient, parameter, state) – Performs a single optimization step, updating the preconditioners and the parameter.
- init_single(parameter, state) – Initialize the optimizer state.
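init_single and apply_single are the per-parameter hooks that Shampoo overrides. A hypothetical toy optimizer (ToySGD, illustrative only, assuming the mlx.optimizers.Optimizer base class and its learning_rate setter) shows where each hook fits:

import mlx.core as mx
from mlx.optimizers import Optimizer

class ToySGD(Optimizer):
    # Minimal optimizer sketch; Shampoo overrides the same two hooks.

    def __init__(self, learning_rate: float):
        super().__init__()
        self.learning_rate = learning_rate

    def init_single(self, parameter, state):
        # Called once per parameter; Shampoo allocates one
        # preconditioner H^i per tensor axis here.
        pass

    def apply_single(self, gradient, parameter, state):
        # Called once per parameter per step; Shampoo preconditions
        # the gradient here before taking the step.
        return parameter - self.learning_rate * gradient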