mlx_optimizers.Muon

class Muon(learning_rate: float | Callable[[array], array] = 0.02, momentum: float = 0.95, nesterov: bool = True, backend: str = 'newtonschulz5', backend_steps: int = 5, alternate_optimizer: Optimizer = AdamW(learning_rate=0.001))

MomentUm Orthogonalized by Newton-schulz [1].

\[\begin{split}
m_t &= \mu m_{t-1} + g_t \\
g_t &= \mu m_t + g_t \quad \text{if nesterov} \\
O_t &= \text{orthogonalize}(g_t) \\
\theta_t &= \theta_{t-1} - \eta \, (O_t + \lambda \theta_{t-1})
\end{split}\]

[1] Keller Jordan, 2024. KellerJordan/Muon. https://github.com/KellerJordan/Muon
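The orthogonalize step replaces the momentum-buffered gradient with a near-orthogonal matrix of the same shape. A minimal sketch of the default "newtonschulz5" backend, following the quintic iteration from the reference code in [1] (the coefficients below come from that code; the library's internal implementation may differ in detail):

import mlx.core as mx

def newtonschulz5(G: mx.array, steps: int = 5, eps: float = 1e-7) -> mx.array:
    # Quintic polynomial coefficients from the reference Muon code [1].
    a, b, c = 3.4445, -4.7750, 2.0315
    # Normalize so the singular values of X are at most 1.
    X = G / (mx.linalg.norm(G) + eps)
    # Work with the short-and-wide orientation for cheaper X @ X.T.
    transposed = G.shape[0] > G.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X

Each iteration pushes the singular values of X toward 1, so after a few steps (backend_steps, default 5) the result approximates the nearest orthogonal matrix to the input.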

Parameters:
  • learning_rate (float or callable) – The learning rate \(\eta\). Default: 0.02

  • momentum (float, optional) – The momentum strength \(\mu\). Default: 0.95

  • nesterov (bool, optional) – Enables Nesterov momentum. Default: True

  • backend (str, optional) – The orthogonalization backend. Default: "newtonschulz5"

  • backend_steps (int, optional) – The number of steps for orthogonalization. Default: 5

  • alternate_optimizer (Optimizer, optional) – The alternate optimizer to use when the parameter is not a 2D tensor. Default: AdamW(0.001)

[Figure: Muon optimization trajectory on the Rosenbrock function]
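A minimal usage sketch (the model, data, and loss below are illustrative placeholders; the training-loop pattern follows standard MLX conventions):

import mlx.core as mx
import mlx.nn as nn
import mlx_optimizers as optim

# The 2D weight matrix is updated by Muon; other parameters
# (e.g. the bias vector) fall back to the alternate optimizer.
model = nn.Linear(10, 10)
optimizer = optim.Muon(learning_rate=0.02, momentum=0.95, nesterov=True)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((32, 10))
y = mx.random.normal((32, 10))
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)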

Methods

__init__([learning_rate, momentum, ...])

apply_single(gradient, parameter, state)
    Apply the Muon optimization update with Newton-Schulz orthogonalization.

init_single(parameter, state)
    To be extended by child classes to implement each optimizer's state initialization.