mlx_optimizers.Shampoo

class Shampoo(learning_rate: float | Callable[[array], array], momentum: float = 0.0, weight_decay: float = 0.0, update_freq: int = 1, eps: float = 0.0001)

Preconditioned Stochastic Tensor Optimization (general tensor case) [1].

\[\begin{split}W_1 &= 0_{n_1 \times \dots \times n_k}; \quad \forall i \in [k]: H_0^i = \epsilon I_{n_i}\\ H_t^i &= H_{t-1}^i + G_t^{(i)}, \quad G_t^{(i)} = \mathrm{mat}_i(G_t)\,\mathrm{mat}_i(G_t)^\top\\ \tilde{G}_t &= G_t \times_1 (H_t^1)^{-1/(2k)} \times_2 \dots \times_k (H_t^k)^{-1/(2k)}\\ W_{t+1} &= W_t - \eta \tilde{G}_t\end{split}\]

[1] Gupta, Vineet, Tomer Koren, and Yoram Singer, 2018. Shampoo: Preconditioned stochastic tensor optimization. ICML 2018. https://arxiv.org/abs/1802.09568
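
The update rule above can be sketched in plain NumPy. This is an illustrative sketch only, not the library's MLX implementation: it omits momentum, weight decay, and update_freq, and the names `inverse_root` and `shampoo_step` are hypothetical helpers introduced here.

```python
import numpy as np


def inverse_root(H, power):
    """Return H^(-1/power) for a symmetric positive-definite matrix H."""
    eigvals, eigvecs = np.linalg.eigh(H)
    return (eigvecs * eigvals ** (-1.0 / power)) @ eigvecs.T


def shampoo_step(W, G, H, lr=0.1):
    """One Shampoo update for a k-dimensional tensor W with gradient G.

    H is a list of k preconditioners; H[i] has shape (n_i, n_i) and is
    updated in place. (Hypothetical helper, for illustration only.)
    """
    k = G.ndim
    G_tilde = G
    for i in range(k):
        # mat_i(G): move axis i to the front and flatten the remaining axes.
        G_i = np.moveaxis(G, i, 0).reshape(G.shape[i], -1)
        # H_t^i = H_{t-1}^i + G_t^(i), with G_t^(i) = mat_i(G) mat_i(G)^T
        H[i] += G_i @ G_i.T
        # (H_t^i)^(-1/(2k))
        P = inverse_root(H[i], 2 * k)
        # Mode-i product: contract P with axis i of G_tilde, then restore axis order.
        G_tilde = np.moveaxis(np.tensordot(P, G_tilde, axes=(1, i)), 0, i)
    # W_{t+1} = W_t - eta * G_tilde
    return W - lr * G_tilde


# Example: a 3-dimensional parameter tensor with H_0^i = eps * I.
eps = 1e-4
W = np.zeros((3, 4, 2))
H = [eps * np.eye(n) for n in W.shape]
G = np.random.default_rng(0).normal(size=W.shape)
W_next = shampoo_step(W, G, H, lr=0.1)
```

For a matrix parameter (k = 2) this reduces to the familiar form in which the gradient is multiplied on the left by the inverse fourth root of the first preconditioner and on the right by that of the second.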

Parameters:
  • learning_rate (float or callable) – learning rate \(\eta\).

  • momentum (float, optional) – momentum factor. Default: 0.0

  • weight_decay (float, optional) – weight decay factor. Default: 0.0

  • update_freq (int, optional) – frequency of updating the preconditioner. Default: 1

  • eps (float, optional) – term \(\epsilon\) that scales the initial preconditioners \(H_0^i = \epsilon I_{n_i}\) to improve numerical stability. Default: 1e-4

[Figure: Shampoo on the Rosenbrock function (rosenbrock_Shampoo.png)]

Methods

__init__(learning_rate[, momentum, ...])

apply_single(gradient, parameter, state)

Performs a single optimization step for one parameter, updating its preconditioners \(H_t^i\) stored in the optimizer state.

init_single(parameter, state)

Initializes the optimizer state for a single parameter.