mlx_optimizers.Kron

class Kron(learning_rate: float | Callable[[array], array], b1: float = 0.9, weight_decay: float = 0.0, precond_update_prob: float | Callable[[array], array] | None = None, max_size_triangular: int = 8192, min_ndim_triangular: int = 2, memory_save_mode: str | None = None, momentum_into_precond_update: bool = True)

Kronecker-Factored Preconditioned Stochastic Gradient Descent [1].

PSGD is a second-order optimizer that uses Hessian-based or whitening-based \((gg^T)\) preconditioners and Lie groups to improve convergence. Kron uses Kronecker-factored preconditioners for tensors of any number of dimensions.

[1] Xi-Lin Li. Preconditioned Stochastic Gradient Descent. 2015. https://arxiv.org/abs/1512.04202. Code: lixilinx/psgd_torch
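
A minimal usage sketch (hedged: it assumes the package is importable as mlx_optimizers and that Kron follows the standard mlx.optimizers update API; the model, loss, and data below are placeholders, not part of the Kron API):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx_optimizers as optx  # assumed import alias

    # Toy model and data, for illustration only.
    model = nn.Linear(10, 1)
    optimizer = optx.Kron(learning_rate=1e-3)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    x = mx.random.normal((32, 10))
    y = mx.random.normal((32, 1))

    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)  # Kron-preconditioned parameter update
    mx.eval(model.parameters(), optimizer.state)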

Parameters:
  • learning_rate (float or callable) – the learning rate.

  • b1 (float, optional) – coefficient used for computing running averages of the gradient. Default: 0.9

  • weight_decay (float, optional) – weight decay factor. Default: 0.0

  • precond_update_prob (float or callable, optional) – probability of updating the preconditioner on a given step; may be a constant or a callable schedule. Default: None (a flat-start, exponentially annealed schedule)

  • max_size_triangular (int, optional) – maximum size a dimension can have for its preconditioner to be triangular; larger dimensions fall back to diagonal preconditioners. Default: 8192

  • min_ndim_triangular (int, optional) – minimum number of dimensions a tensor must have for its preconditioners to be triangular; tensors with fewer dimensions use diagonal preconditioners. Default: 2

  • memory_save_mode (str, optional) – one of None, 'one_diag', or 'all_diag'. None: all preconditioners are triangular; 'one_diag': the largest or last dimension per layer uses a diagonal preconditioner; 'all_diag': all preconditioners are diagonal (a configuration sketch follows this list). Default: None

  • momentum_into_precond_update (bool, optional) – whether the momentum buffer, rather than the raw gradient, is fed into the preconditioner update. Default: True
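
For large models the triangular Kronecker factors can dominate memory. A hedged configuration sketch combining the memory-related options (the values are illustrative, not tuned recommendations):

    import mlx_optimizers as optx  # assumed import alias

    optimizer = optx.Kron(
        learning_rate=3e-4,
        b1=0.9,
        weight_decay=0.01,
        precond_update_prob=0.1,      # update the preconditioner on ~10% of steps
        max_size_triangular=4096,     # dims above this size get diagonal factors
        min_ndim_triangular=2,        # 1-D tensors (e.g. biases) stay diagonal
        memory_save_mode="one_diag",  # largest/last dim per layer kept diagonal
    )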

[Figure: optimization trajectory of Kron on the Rosenbrock function]

Methods

__init__(learning_rate[, b1, weight_decay, ...])

apply_single(gradient, parameter, state)

Performs a single optimization step, updating the momentum buffer and the preconditioner state.

init_single(parameter, state)

Initialize the optimizer state.
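
Because learning_rate accepts a callable mapping the step counter to a value, MLX's built-in schedules should compose with Kron (hedged: this assumes Kron inherits the standard schedule handling of mlx.optimizers.Optimizer):

    import mlx.optimizers as mxopt
    import mlx_optimizers as optx  # assumed import alias

    # Cosine-decay the learning rate from 1e-3 toward 0 over 10,000 steps.
    schedule = mxopt.cosine_decay(1e-3, decay_steps=10_000)
    optimizer = optx.Kron(learning_rate=schedule)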