mlx_optimizers.Muon
- class Muon(learning_rate: float | typing.Callable[[mlx.core.array], mlx.core.array] = 0.02, momentum: float = 0.95, nesterov: bool = True, backend: str = 'newtonschulz5', backend_steps: int = 5, alternate_optimizer: mlx.optimizers.optimizers.Optimizer = AdamW(0.001))
MomentUm Orthogonalized by Newton-Schulz [1].
\[\begin{split}m_t &= \mu m_{t-1} + g_t \\ g_t &= \mu m_t + g_t \text{ if nesterov} \\ O_t &= \text{orthogonalize}(g_t) \\ \theta_{t} &= \theta_{t-1} - \eta (O_t + \lambda \theta_{t-1})\end{split}\]
[1] Keller Jordan, 2024. KellerJordan/Muon
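The update rule above can be sketched in plain Python for a single 2D parameter. This is an illustrative sketch, not the library's implementation: matrices are lists of rows, the helper names (`transpose`, `matmul`, `orthogonalize`, `muon_step`) are hypothetical, and the quintic coefficients are the ones used in the reference Muon code [1]. The formula includes \(\lambda\), but the constructor signature does not expose it, so the sketch takes it as a hypothetical `lam` argument defaulting to 0. When `nesterov` is off, the sketch assumes the momentum buffer itself is orthogonalized, as in the reference code.

```python
import math

# Quintic Newton-Schulz coefficients as used in the reference Muon code [1].
NS_A, NS_B, NS_C = 3.4445, -4.7750, 2.0315


def transpose(x):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*x)]


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    y_cols = transpose(y)
    return [[sum(p * q for p, q in zip(row, col)) for col in y_cols] for row in x]


def orthogonalize(g, steps=5, eps=1e-7):
    """Push the singular values of g toward 1 with a quintic iteration."""
    tall = len(g) > len(g[0])
    # Work on the wide orientation so that X X^T is the smaller product.
    x = transpose(g) if tall else [row[:] for row in g]
    # Scale by the Frobenius norm so every singular value is at most 1.
    norm = math.sqrt(sum(v * v for row in x for v in row)) + eps
    x = [[v / norm for v in row] for row in x]
    for _ in range(steps):
        a = matmul(x, transpose(x))  # A = X X^T
        a2 = matmul(a, a)            # A^2
        b = [[NS_B * p + NS_C * q for p, q in zip(ra, ra2)] for ra, ra2 in zip(a, a2)]
        bx = matmul(b, x)
        # X <- a X + (b A + c A^2) X
        x = [[NS_A * p + q for p, q in zip(rx, rbx)] for rx, rbx in zip(x, bx)]
    return transpose(x) if tall else x


def muon_step(theta, grad, m, lr=0.02, mu=0.95, nesterov=True, lam=0.0):
    """One update of a 2D parameter; returns (new_theta, new_momentum)."""
    # m_t = mu * m_{t-1} + g_t
    m = [[mu * mv + gv for mv, gv in zip(rm, rg)] for rm, rg in zip(m, grad)]
    if nesterov:
        # g_t = mu * m_t + g_t
        g = [[mu * mv + gv for mv, gv in zip(rm, rg)] for rm, rg in zip(m, grad)]
    else:
        g = m  # assumption: without Nesterov, the buffer itself is orthogonalized
    o = orthogonalize(g)
    # theta_t = theta_{t-1} - eta * (O_t + lambda * theta_{t-1})
    theta = [[t - lr * (ov + lam * t) for t, ov in zip(rt, ro)] for rt, ro in zip(theta, o)]
    return theta, m
```

With five steps the iteration does not converge to exactly orthogonal output. It leaves the singular values in a band around 1, so `orthogonalize` is best read as an approximate orthogonalization.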
- Parameters:
learning_rate (float or callable) – The learning rate \(\eta\). Default: 0.02
momentum (float, optional) – The momentum strength \(\mu\). Default: 0.95
nesterov (bool, optional) – Enables Nesterov momentum. Default: True
backend (str, optional) – The orthogonalization backend. Default: "newtonschulz5"
backend_steps (int, optional) – The number of steps for orthogonalization. Default: 5
alternate_optimizer (Optimizer, optional) – The alternate optimizer to use when the parameter is not a 2D tensor. Default: AdamW(0.001)
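The role of alternate_optimizer can be illustrated with a small routing sketch. It assumes only what the parameter description states: 2D parameters get the Muon update and everything else goes to the alternate optimizer. The function `choose_update` and the parameter names and shapes are hypothetical and not part of the library.

```python
def choose_update(shape):
    """Return which update rule a parameter of this shape would receive."""
    # Only 2D tensors (matrices) are orthogonalized; all other ranks fall back.
    return "muon" if len(shape) == 2 else "alternate"


# Hypothetical parameter shapes for a small model.
shapes = {
    "linear.weight": (64, 32),       # 2D -> Muon
    "linear.bias": (64,),            # 1D -> alternate optimizer
    "conv.weight": (16, 3, 3, 8),    # 4D -> alternate optimizer
}

routing = {name: choose_update(shape) for name, shape in shapes.items()}
for name, rule in routing.items():
    print(f"{name}: {rule}")
```

Under this rule, biases, normalization scales and convolution kernels are all trained by the alternate optimizer, so its learning rate (0.001 by default) matters as much as Muon's own.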
Methods
__init__([learning_rate, momentum, ...])
apply_single(gradient, parameter, state) – Apply Muon optimization update with Newton-Schulz orthogonalization.
init_single(parameter, state) – To be extended by the children classes to implement each optimizer's state initialization.