mlx_optimizers.Kron
- class Kron(learning_rate: float | Callable[[array], array], b1: float = 0.9, weight_decay: float = 0.0, precond_update_prob: float | Callable[[array], array] | None = None, max_size_triangular: int = 8192, min_ndim_triangular: int = 2, memory_save_mode: str | None = None, momentum_into_precond_update: bool = True)
Kronecker-Factored Preconditioned Stochastic Gradient Descent [1].
PSGD is a second-order optimizer that uses Hessian-based or whitening-based (\(gg^T\)) preconditioners and Lie groups to improve convergence. Kron uses Kronecker-factored preconditioners for tensors of any number of dimensions.
[1] Xi-Lin Li, 2015. Preconditioned Stochastic Gradient Descent. https://arxiv.org/abs/1512.04202. Reference implementation: lixilinx/psgd_torch.
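Since Kron exposes the same init_single/apply_single hooks as mlx.optimizers.Optimizer, it should drop into an ordinary MLX training loop. A minimal sketch, assuming the package is importable as mlx_optimizers; the model, data, and hyperparameters are illustrative:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx_optimizers import Kron

# Toy regression model and data (illustrative only).
model = nn.Linear(10, 2)
X = mx.random.normal((32, 10))
y = mx.random.normal((32, 2))

def loss_fn(model, X, y):
    return nn.losses.mse_loss(model(X), y)

optimizer = Kron(learning_rate=3e-4)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

for step in range(10):
    loss, grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)  # invokes apply_single per parameter
    mx.eval(model.parameters(), optimizer.state)
```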
- Parameters:
learning_rate (float or callable) – the learning rate.
b1 (float, optional) – coefficient used for computing the running average of the gradient. Default: 0.9
weight_decay (float, optional) – weight decay factor. Default: 0.0
precond_update_prob (float or callable, optional) – probability of updating the preconditioner on a given step. Default: None (flat exponential schedule)
max_size_triangular (int, optional) – maximum size for a dimension's preconditioner to be triangular. Default: 8192
min_ndim_triangular (int, optional) – minimum number of dimensions a layer needs to have triangular preconditioners. Default: 2
memory_save_mode (str, optional) – one of None, 'one_diag', or 'all_diag'. None: all preconditioners are triangular; 'one_diag': the largest or last dimension per layer gets a diagonal preconditioner; 'all_diag': all preconditioners are diagonal (see the sketch after this list). Default: None
momentum_into_precond_update (bool, optional) – whether to use momentum in the preconditioner update. Default: True
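The memory_save_mode and precond_update_prob options trade preconditioner quality against memory and compute. A hedged sketch of typical configurations; the hyperparameter values are illustrative, not recommendations:

```python
from mlx_optimizers import Kron

# Default: triangular preconditioners everywhere (best quality, most memory).
opt_full = Kron(learning_rate=1e-3)

# Diagonal preconditioner on the largest (or last) dim of each layer.
opt_one = Kron(learning_rate=1e-3, memory_save_mode="one_diag")

# Diagonal preconditioners everywhere (cheapest, weakest).
opt_diag = Kron(learning_rate=1e-3, memory_save_mode="all_diag")

# Fixed preconditioner-update probability instead of the default
# flat exponential schedule.
opt_fixed = Kron(learning_rate=1e-3, precond_update_prob=0.05)
```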
Methods
- __init__(learning_rate[, b1, weight_decay, ...])
- apply_single(gradient, parameter, state) – Performs a single optimization step, updating \(m\) and \(v\).
- init_single(parameter, state) – Initialize optimizer state.