Main

Deep models have enabled extensive applications across a variety of industries. For instance, fine-tuning large language models (LLMs), such as Gemma1, GPT2 and Llama3, with user-specific data can significantly enhance model performance. Diffusion models are increasingly used to generate personalized content, including images, audio and video, for entertainment. Vision models, such as ResNet4 and vision transformers5, are in high demand for tasks such as image classification, segmentation and object detection on large-scale image datasets. However, the growing complexity and scale of these tasks present considerable challenges for widely used stochastic gradient descent (SGD)-based optimizers due to their inherent limitations. In this section, we review a subset of these methods, sufficient to motivate the development of our proposed approach.

SGD-based learning algorithms

It is known that plain SGD-based algorithms, such as the original stochastic approximation approaches6,7, the parallelized SGD8 for machine learning and the newly developed federated averaging (FedAvg9,10) and local SGD (LocalSGD11) algorithms for federated learning, are frequently prone to high sensitivity to poor conditioning12 and to slow convergence in high-dimensional and non-convex landscapes13. To overcome these drawbacks, SGD with momentum and adaptive learning rates has been proposed to enhance robustness and accelerate convergence. The former introduces first-order momentum to suppress the oscillations of SGD14, and the latter updates the learning rate iteratively based on historical gradient information.
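
As a point of reference (written in generic notation rather than that of any particular cited paper), a common form of the heavy-ball momentum update is

$$\mathbf{v}^{k+1}=\beta \mathbf{v}^{k}+\mathbf{g}^{k},\qquad \mathbf{w}^{k+1}=\mathbf{w}^{k}-\alpha \mathbf{v}^{k+1},$$

where \(\mathbf{g}^{k}\) is a stochastic gradient at \(\mathbf{w}^{k}\), \(\beta \in [0,1)\) is the momentum coefficient and \(\alpha > 0\) is the learning rate; adaptive methods replace the scalar \(\alpha\) with a per-coordinate quantity built from past gradients.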

For instance, the adaptive gradient method (AdaGrad15) accumulated the second-order moments of past gradients to obtain adaptive learning rates. Root mean-squared propagation (RMSProp16) exploited an exponential weighting technique to balance distant historical information against the current second-order gradient information, avoiding the premature termination encountered by AdaGrad. The adaptive moment method (Adam17) combined the first-order moment with adaptive learning rates, thereby exhibiting robustness to hyperparameters. AMSGrad18 took smaller learning rates than those of Adam by using the maximum of past second moments to normalize the running average of the gradients, thereby fixing the convergence of Adam. Partially adaptive Adam (Padam19) used a strategy similar to AMSGrad, differing in the power of the rate. Adam with dynamic bounds on learning rates (AdaBound20) designed a clipping mechanism that bounds Adam-type learning rates, avoiding extreme updates. Parallel restarted SGD (PRSGD21) simply benefited from a decaying power learning rate. A layerwise adaptive large-batch optimization technique (Lamb22) performed per-dimension normalization with respect to the square root of the second moment used in Adam and adopted the large batch sizes suggested by the layerwise adaptive rate scaling method (Lars23).
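
To make the structure of such adaptive updates concrete, the following minimal NumPy sketch shows a textbook Adam-style step (an illustration, not the implementation of any cited method); the names lr, beta1, beta2 and eps are our own.

import numpy as np

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # One Adam-style update: first moment m, second moment v, bias correction at step t.
    m = beta1 * m + (1 - beta1) * g              # exponential average of gradients
    v = beta2 * v + (1 - beta2) * g * g          # exponential average of squared gradients
    m_hat = m / (1 - beta1 ** t)                 # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)                 # bias-corrected second moment
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)  # adaptive per-coordinate step
    return w, m, v

AMSGrad-style variants additionally track the running maximum of v_hat, whereas AdaBound-style variants clip the resulting per-coordinate learning rates.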

For some other SGD-based algorithms, one can refer to a per-dimension learning rate method based gradient descent (AdaDelta24), a variant of Adam based on the infinity norm (Adamax17), Adam with decoupled weight decay (AdamW25), a structure-aware preconditioning algorithm (Shampoo26), momentum orthogonalized by Newton–Schulz iterations (Muon27), Shampoo with Adam in the preconditioner’s eigenbasis (SOAP28), Adam with fewer learning rates (Adam-mini29) and those reviewed in a few studies30,31,32.

The aforementioned algorithms primarily leveraged the heavy ball acceleration technique to estimate the first-order and second-order moments of the gradient. An alternative approach to accelerating convergence is Nesterov acceleration, which has been theoretically proved to converge faster33,34. This has led to the development of several Nesterov-accelerated SGD algorithms. For example, Nesterov-accelerated Adam (NAdam35) integrated this technique into Adam to enhance convergence speed. More recently, an adaptive Nesterov momentum algorithm (Adan36) adopted Nesterov momentum estimation to refine the estimation of the first-order and second-order moments of the gradient, thereby improving adaptive learning-rate adjustments for faster convergence.
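
In its common look-ahead form (again in generic notation), Nesterov momentum evaluates the gradient at an extrapolated point:

$$\mathbf{v}^{k+1}=\beta \mathbf{v}^{k}-\alpha \nabla F(\mathbf{w}^{k}+\beta \mathbf{v}^{k}),\qquad \mathbf{w}^{k+1}=\mathbf{w}^{k}+\mathbf{v}^{k+1}.$$

It is this look-ahead gradient evaluation that NAdam35 and Adan36 adapt to the stochastic, moment-estimated setting.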

ADMM-based learning algorithms

The alternating direction method of multipliers (ADMM)37,38 is a promising framework for solving large-scale optimization problems because it decomposes complex problems into smaller, more manageable subproblems. This characteristic makes ADMM particularly well suited to various challenges in distributed learning. It has demonstrated substantial potential in applications such as image compressive sensing39, federated learning40,41,42, reinforcement learning43 and few-shot learning44. It is worth noting that there are many other distributed optimization frameworks, such as dual averaging45,46, push-sum47 and push–pull48. In the sequel, however, our focus is on the ADMM family of distributed algorithms, as they are most directly relevant to the design and convergence analysis of our algorithm.
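
For completeness, recall the classical two-block ADMM for \(\mathop{\min }_{\mathbf{x},\mathbf{z}}\,f(\mathbf{x})+g(\mathbf{z})\) subject to \(A\mathbf{x}+B\mathbf{z}=\mathbf{c}\), which alternates minimization of the augmented Lagrangian with a dual ascent step:

$$\begin{array}{l}{\mathbf{x}}^{k+1}={\rm{arg}}\mathop{\rm{min}}\limits_{\mathbf{x}}\,f(\mathbf{x})+\langle {\mathbf{y}}^{k},A\mathbf{x}+B{\mathbf{z}}^{k}-\mathbf{c}\rangle +\dfrac{\rho }{2}\parallel A\mathbf{x}+B{\mathbf{z}}^{k}-\mathbf{c}{\parallel }^{2},\\ {\mathbf{z}}^{k+1}={\rm{arg}}\mathop{\rm{min}}\limits_{\mathbf{z}}\,g(\mathbf{z})+\langle {\mathbf{y}}^{k},A{\mathbf{x}}^{k+1}+B\mathbf{z}-\mathbf{c}\rangle +\dfrac{\rho }{2}\parallel A{\mathbf{x}}^{k+1}+B\mathbf{z}-\mathbf{c}{\parallel }^{2},\\ {\mathbf{y}}^{k+1}={\mathbf{y}}^{k}+\rho (A{\mathbf{x}}^{k+1}+B{\mathbf{z}}^{k+1}-\mathbf{c}).\end{array}$$

The consensus model adopted later in this paper (problem (3)) is an instance of this template in which per-batch parameters are constrained to agree with a global parameter.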

When applied to deep model training, early studies first relaxed the optimization models before developing ADMM-based algorithms. These methods can therefore be classified as model-driven or model-specific approaches: they target the relaxations rather than the original problems, enabling better handling of the challenges associated with highly non-convex optimization landscapes. For instance, an ADMM-based algorithm in ref. 49 addressed a penalized formulation, avoiding pitfalls that hinder gradient-based methods in non-convex settings. Similarly, a deep learning ADMM algorithm50 and its variant51 were designed to enhance convergence by addressing penalization models as well. Moreover, the sigmoid-ADMM algorithm52 exploited the gradient-free property of the algorithm to address the saturation issues of sigmoid activation functions.

Despite their advantages, model-driven ADMM approaches face two critical issues. First, these algorithms exhibit high computational complexity because they require computing full gradients using the entire dataset and performing matrix inversions for weight updates, making them impractical for large-scale tasks. Furthermore, the reliance on the specific structure of optimization models limits their general applicability, as different deep models often necessitate distinct formulations.

To address the aforementioned limitations, data-driven ADMM algorithms have emerged as a promising alternative, aiming to reduce computational burdens and enhance adaptability to diverse tasks. For instance, the deterministic ADMM41,42, designed for distributed learning problems, achieved high accuracy but low computational efficiency due to the use of full dataset gradients. Stochastic ADMM (SADMM) methods tackled the inefficiency by replacing full gradients with stochastic approximations. Examples such as SADMM53 and distributed stochastic ADMM (PS-ADMM54) aimed to solve subproblems exactly and thus still involved high computational costs.

To further enhance convergence and reduce stochastic gradient variance, advanced techniques of variance reduction and Nesterov acceleration have been incorporated into SADMM. Representatives include stochastic average ADMM55, stochastic path-integrated differential estimator-based ADMM (SPIDER-ADMM56) and accelerated stochastic variance reduced gradient-based ADMM (ASVRG-ADMM) for convex57 and non-convex problems58. However, these enhanced methods still relied on access to full gradients to reduce the stochastic gradient variance, posing challenges for large-scale applications.

Contributions

In this work, we propose a data-driven preconditioned inexact SADMM algorithm, termed PISA. It distinguishes itself from previous approaches and aims to reduce computational costs, relax convergence assumptions and enhance numerical performance. The key contributions are threefold.

A general algorithmic structure

The algorithm is based on a preconditioned inexact SADMM, combining simplicity with high generality to handle a wide range of deep learning applications. First, the use of preconditioning matrices allows us to incorporate various forms of useful information, such as the first and second moments (yielding the second-moment-based inexact SADMM (SISA), a variant of PISA), second-order information (for example, the Hessian) and momentum orthogonalized by Newton–Schulz iterations (leading to the Newton–Schulz-based inexact SADMM (NSISA), another variant of PISA), into the updates, thereby enhancing the performance of the proposed algorithms. Moreover, the proposed algorithmic framework is inherently compatible with parallel computing, making it ideal for large-scale data settings.

Strong convergence theory under a sole assumption

PISA is proven to converge under a single assumption: the Lipschitz continuity of the gradient. Despite relying on stochastic gradients, it avoids many of the assumptions typically required by stochastic algorithms. As highlighted in Table 1, all algorithms, except for PISA and FedAvg10, draw independently and identically distributed (IID) samples to obtain unbiased stochastic gradient estimates. However, real-world data are often heterogeneous (that is, non-IID), a phenomenon commonly referred to as statistical or data heterogeneity59,60,61,62, which poses significant challenges to the convergence of algorithms designed for IID datasets.

Table 1 Assumptions imposed by different stochastic algorithms for convergence

In addition, the table presents the convexity assumptions and two types of convergence. Type I convergence, \(F({\mathbf{w}}^{T})-{F}^{* }=B\), refers to the rate B at which the objective values \(\{F({\mathbf{w}}^{T})\}\) of the generated sequence \(\{{\mathbf{w}}^{T}\}\) approach F*, where T is the number of iterations and F* is the limit of sequence \(\{F({\mathbf{w}}^{T})\}\) or a function value at a stationary point. For instance, B = O(1/T) indicates sublinear convergence, while \(B=O({\gamma }^{T})\) with γ ∈ (0, 1) implies a linear rate. Type II convergence describes how quickly the gradient norms \(\{\parallel \nabla F({\mathbf{w}}^{T})\parallel \}\) diminish, reflecting the stationarity of the iterates. Therefore, ASVRG-ADMM57 and PISA achieve the best Type I convergence rate. However, the former imposes several restrictive assumptions, while PISA requires a single assumption (that is, condition 12). We emphasize that condition 12 is closely related to the local Lipschitz continuity of the gradient, which is a mild assumption. It is weaker than condition 3, and any twice-continuously differentiable function satisfies condition 12. More importantly, we eliminate the need for conditions on the boundedness of the stochastic gradient, its second moment and its variance. This makes PISA well suited to addressing the challenges associated with data heterogeneity, an open problem in federated learning60,62.

High numerical performance for various applications

The effectiveness of PISA and its two variants, SISA and NSISA, is demonstrated through comparisons with many state-of-the-art optimizers on several deep models: vision models, LLMs, reinforcement learning models, generative adversarial networks (GANs) and recurrent neural networks, highlighting their great potential for extensive applications. In particular, numerical experiments on heterogeneous datasets corroborate our claim that PISA effectively addresses this challenge.

Results

This section evaluates the performance of SISA and NSISA in comparison with various established optimizers on several deep models: vision models, LLMs, reinforcement learning models, recurrent neural networks and GANs. All LLM-targeted experiments are conducted on four NVIDIA H100-80GB graphics processing units (GPUs), while the remaining experiments are run on a single such GPU. The source code is available at https://github.com/Tracy-Wang7/PISA. Due to space limitations, details on the hyperparameter setup, the effects of key hyperparameters and further numerical evaluations, including experiments for reinforcement learning models and recurrent neural networks, are provided in Supplementary Section 1.

Data heterogeneity

To evaluate the effectiveness of the proposed algorithms in addressing the challenge of data heterogeneity, we design experiments within a centralized federated learning framework. In this setting, m clients collaboratively learn a shared parameter under a central server, with each client i holding a local dataset \({{\mathcal{D}}}_{i}\). Note that our three algorithms are well suited to centralized federated learning tasks. As shown in Algorithm 2, clients update their local parameters \(({\bf{w}}_{i}^{\ell +1},{\boldsymbol{\pi }}_{i}^{\ell +1})\) in parallel, and the server aggregates them to compute the global parameter \({\mathbf{w}}^{\ell +2}\).

In the experiment, we focus on heterogeneous datasets \({{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\). The heterogeneity may arise from label distribution skew, feature distribution skew or quantity skew61. Empirical evidence has shown that label distribution skew presents a significantly greater challenge to current federated learning methods than the other two types. Therefore, we adopt this skew as the primary setting for evaluating data heterogeneity in our experiments.

To reflect the fact that most benchmark datasets contain 10 image classes, we set m = 10. Four datasets, MNIST, CIFAR-10, FMNIST and Adult, are used for the experiment. As shown in Table 2, the configuration 's-label' indicates that each client holds data containing s distinct label classes, where s = 1, 2, 3. For example, in the '2-label' setting, \({{\mathcal{D}}}_{1}\) may include labels 1 and 2, while \({{\mathcal{D}}}_{2}\) contains labels 2 and 3. To ensure better performance, we use a mini-batch training strategy with multiple local updates before each aggregation step for the five baseline algorithms. In contrast, SISA uses only one local update per aggregation step and still achieves competitive testing accuracy with far fewer local updates. Notably, on the MNIST dataset under the 1-label skew setting, the best accuracy achieved by the other algorithms is 54.33%, whereas SISA reaches 94.97%, demonstrating a significant improvement.
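
For concreteness, one simple way to construct such an s-label split is sketched below (our own illustrative code; labels is assumed to be an integer array of training labels, and classes are assigned to clients in a round-robin fashion).

import numpy as np

def s_label_split(labels, m=10, s=2, seed=0):
    # Assign s label classes to each of the m clients, then split each class's
    # sample indices among the clients that hold that class.
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    client_classes = [set() for _ in range(m)]
    k = 0
    for i in range(m):
        for _ in range(s):
            client_classes[i].add(int(classes[k % len(classes)]))
            k += 1
    client_indices = [[] for _ in range(m)]
    for c in classes:
        holders = [i for i in range(m) if int(c) in client_classes[i]]
        if not holders:          # a class held by no client is skipped
            continue
        idx = rng.permutation(np.where(labels == c)[0])
        for holder, part in zip(holders, np.array_split(idx, len(holders))):
            client_indices[holder].extend(part.tolist())
    return client_indices        # client_indices[i] indexes the samples of client i

With m = 10 and s = 2, each client receives samples from exactly two label classes, matching the '2-label' configuration described above.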

Table 2 Testing accuracy under data heterogeneity with skewed label distributions
Table 3 Top 1 accuracy (%) obtained by different algorithms

Classification by vision models

We perform classification tasks on two well-known datasets, ImageNet63 and CIFAR-1064, and evaluate performance based on convergence speed and testing accuracy. Convergence performance is presented in Supplementary Section 1.3. We now focus on testing accuracy, following the experimental setup and hyperparameter configurations from ref. 65. In this experiment, a small batch size (for example, \(| {{\mathcal{D}}}_{i}| =128\) for every i ∈ [m]) is used during training. As SISA updates its hyperparameters frequently, it uses a different learning-rate scheduler from that in ref. 65. For a fair comparison, each baseline optimizer is tested with both its default scheduler and the one used by SISA, and the best result is reported. On CIFAR-10, we train three models: VGG-1166, ResNet-3467 and DenseNet-12168. As shown in Table 3, SISA achieves the highest testing accuracy on ResNet-34 and DenseNet-121, and performs slightly worse than SGD-M on VGG-11. In addition, Fig. 1 illustrates the training loss and testing accuracy over training epochs for VGG-11 and ResNet-34, demonstrating the fast convergence and high accuracy of SISA. For ImageNet, we train a ResNet-18 model and report the testing accuracy in Table 3. SISA outperforms most adaptive methods but is slightly behind SGD-M and AdaBelief.

Fig. 1: Training loss and testing accuracy over epochs for vision models on CIFAR-10.
figure 1

The left panels present the training and the right panels present the testing performances over epochs for nine benchmarked algorithms on two vision models: VGG-11 (top panels) and ResNet-34 (bottom panels).

Training LLMs

In this subsection, we apply SISA and NSISA to train several GPT2 models69 and compare them with advanced optimizers, such as AdamW, Muon, Shampoo, SOAP and Adam-mini, that have proved effective in training LLMs. The comparison between SISA and AdamW is provided in Supplementary Section 1.4. We now compare NSISA with the other four optimizers on three GPT2 models: GPT2-Nano (125M), GPT2-Medium (330M) and GPT2-XL (1.5B). The first model follows the setup described in ref. 27, while the last two follow the configuration from ref. 29. We use the FineWeb dataset, a large collection of high-quality web text, to fine-tune the GPT2 models, enabling them to generate more coherent and contextually relevant web-style content. The hyperparameters of NSISA remain consistent across the experiments for the three GPT2 models.

The comparison of different algorithms in terms of memory overhead is given in Supplementary Section 1.4. The experimental results are presented in Fig. 2. We evaluate each optimizer's training performance from two perspectives: the number of training tokens consumed and the wall-clock time required. To reduce wall-clock time, we implement NSISA using parallel computation. As shown in Fig. 2, as the number of GPT2 parameters increases from Nano to Medium to XL, the validation loss gap between NSISA and the other baseline optimizers widens. In particular, for GPT2-XL, the bottom-right panel demonstrates an evident advantage of NSISA in terms of wall-clock time.

Fig. 2: Validation loss over tokens and wall-clock time (in seconds) for three GPT2 models.
figure 2

The left three panels present the validation losses over training tokens, and the right three panels present the wall-clock times for six benchmarked algorithms on three GPT2 models: GPT2-NANO (top panels), GPT2-MEDIUM (middle panels) and GPT2-XL (bottom panels).

Generation tasks by GANs

GANs involve alternating updates between the generator and discriminator in a minimax game, making training notoriously unstable70. To assess optimizer stability, we compare SISA with adaptive methods such as Adam, RMSProp and AdaBelief, which are commonly recommended to enhance both stability and efficiency in training GANs65,71. We use two popular architectures, the Wasserstein GAN (WGAN72) and WGAN with gradient penalty (WGAN-GP71), with a small model and vanilla CNN generator on the CIFAR-10 dataset.

To evaluate performance, we compute the Fréchet inception distance (FID)73 every 10 training epochs, measuring the distance between 6,400 generated images and 60,000 real images. The FID score is a widely used metric for generative models, assessing both image quality and diversity; lower FID values indicate better generated images. We follow the experimental setup from ref. 65 and run each optimizer five times independently. The corresponding mean and standard deviation of the training FID over epochs are presented in Fig. 3. Clearly, SISA achieves the fastest convergence and the lowest FID. After training, 64,000 fake images are produced to compute the testing FID, and SISA outperforms the other three optimizers, demonstrating its superior stability and effectiveness in training GANs. Detailed testing FID scores can be found in Supplementary Section 1.7.
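
For reference, the FID between the real and generated image sets is computed from the means and covariances \(({\mu }_{r},{\Sigma }_{r})\) and \(({\mu }_{g},{\Sigma }_{g})\) of their Inception features as

$${\rm{FID}}={\parallel {\mu }_{r}-{\mu }_{g}\parallel }^{2}+{\rm{Tr}}\left({\Sigma }_{r}+{\Sigma }_{g}-2{({\Sigma }_{r}{\Sigma }_{g})}^{1/2}\right),$$

so lower values correspond to generated feature statistics that are closer to those of the real data.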

Fig. 3: Training FID over epochs for GANs on CIFAR-10.
figure 3

The left and right panels, respectively, present the training FID over training epochs for four benchmarked algorithms on the WGAN and WGAN-GP architectures.

Conclusion

This paper introduces an ADMM-based algorithm that leverages stochastic gradients and solves subproblems inexactly, significantly accelerating computation. The preconditioning framework allows it to incorporate popular second-moment schemes, enhancing training performance. Theoretical guarantees, based solely on the Lipschitz continuity of the gradients, make the algorithm suitable for heterogeneous datasets, effectively addressing an open problem in stochastic optimization for distributed learning. The algorithm demonstrates superior generalization across a wide range of architectures, datasets and tasks, outperforming well-known deep learning optimizers.

Methods

Organization and notation

This section is organized as follows: in the next subsection, we introduce the main model and develop the proposed algorithm, PISA. The section 'Convergence of PISA' provides rigorous proofs of convergence. The section 'Second moment' specifies the preconditioner via the second moment to derive a variant of PISA, termed SISA.

Let [m] := {1, 2, …, m}, where ':=' means 'define'. The cardinality of a set \({\mathcal{D}}\) is written as \(| {\mathcal{D}}|\). For two vectors w and v, their inner product is denoted by \(\langle {\mathbf{w}},{\mathbf{v}}\rangle :={\sum }_{i}{w}_{i}{v}_{i}\). Let \(\parallel \cdot \parallel\) be the Euclidean norm for vectors, namely \(\parallel {\bf{w}}\parallel =\sqrt{\langle {\bf{w}},{\bf{w}}\rangle }\), and the spectral norm for matrices. A ball with a positive radius r is written as \({\mathbb{N}}(r):=\{{\bf{w}}:\parallel {\bf{w}}\parallel \le r\}\). A symmetric positive semi-definite matrix Q is written as \(Q\succcurlyeq 0\). Then \(P\succcurlyeq Q\) means that \(P-Q\succcurlyeq 0\). Denote the identity matrix by I and let 1 be the vector with all entries equal to 1. We write

$$\begin{array}{l}\begin{array}{l}{{\varPi }}\,=({\boldsymbol{\pi }}_{1},{\boldsymbol{\pi }}_{2},\ldots ,{\boldsymbol{\pi }}_{m}),\,\,\,{{W}}\,=({\bf{w}}_{1},{\bf{w}}_{2},\ldots ,{\bf{w}}_{m}),\\ {{M}}=({\bf{m}}_{1},{\bf{m}}_{2},\ldots ,{\bf{m}}_{m}),\,\,\,{\boldsymbol{\sigma }}=({\boldsymbol{\sigma }}_{1},{\boldsymbol{\sigma }}_{2},\ldots ,{\boldsymbol{\sigma }}_{m}).\end{array}\end{array}$$

Similar rules are also used for the definitions of the iteration-indexed quantities \({{\varPi }}^{\ell }\), \({{W}}^{\ell }\), \({{M}}^{\ell }\) and \({{\boldsymbol{\sigma }}}^{\ell }\).

Preconditioned inexact SADMM

We begin this subsection by introducing the mathematical optimization model for general distributed learning. Then we go through the development of the algorithm.

Model description

Suppose we are given a set of data \({\mathcal{D}}:=\{{\mathbf{x}}_{t}:t=1,2,\ldots ,|{\mathcal{D}}|\}\), where xt is the tth sample. Let f(w; xt) be a function (such as a neural network loss) parameterized by w and evaluated at sample xt. The total loss function on \({\mathcal{D}}\) is defined by \({\sum }_{{\bf{x}}_{t}\in {\mathcal{D}}}\,f({\bf{w}};{\bf{x}}_{t})/|{\mathcal{D}}|\). We then divide the data \({\mathcal{D}}\) into m disjoint batches, namely, \({\mathcal{D}}={{\mathcal{D}}}_{1}\cup {{\mathcal{D}}}_{2}\cup \ldots \cup {{\mathcal{D}}}_{m}\) and \({{\mathcal{D}}}_{i}\cap {{\mathcal{D}}}_{{i}^{{\prime} }}={\rm{\varnothing }}\) for any two distinct i and \({i}^{{\prime} }\). Denote

$$\displaystyle{{H}_{i}({\mathbf{w}};{\mathcal{D}}_{i}):=\frac{1}{|{\mathcal{D}}_{i}|}\mathop{\sum }\limits_{{\mathbf{x}}_{t}\in {\mathcal{D}}_{i}}f ({\mathbf{w}};{\mathbf{x}}_{t})\,\,\,{\mathrm{and}}\,\,\,{\alpha }_{i}:=\frac{|{\mathcal{D}}_{i}|}{|{\mathcal{D}}|}.}$$
(1)

Clearly, \({\sum }_{i=1}^{m}{\alpha }_{i}=1\). Now, we can rewrite the total loss as follows:

$$\displaystyle{\frac{1}{|{\mathcal{D}}|}\mathop{\sum }\limits_{{\mathbf{x}}_{t}\in {\mathcal{D}}}f({\mathbf{w}};{\mathbf{x}}_{t})=\frac{1}{|{\mathcal{D}}|} \mathop{\sum }\limits_{i=1}^{m}\mathop{\sum }\limits_{{\mathbf{x}}_{t}\in {{\mathcal{D}}}_{i}}f({\mathbf{w}};{{\mathbf{x}}_{t}})=\mathop{\sum }\limits_{i=1}^{m}{\alpha }_{i} {H}_{i}({\mathbf{w}};{{\mathcal{D}}}_{i}).}$$

The task is to learn an optimal parameter to minimize the following regularized loss function:

$${\mathop{\rm{min}}\limits_{\mathbf{w}}\,\mathop{\sum}\limits_{i=1}^{m}{\alpha}_{i}{H}_{i}({\mathbf{w}};{\mathcal{D}}_{i})+\frac{\mu}{2}\parallel {\mathbf{w}}{\parallel}^{2},}$$
(2)

where μ ≥ 0 is a penalty constant and \(\parallel {\mathbf{w}}{\parallel }^{2}\) is a regularization term.

Main model

Throughout the paper, we focus on the following equivalent model of problem (2):

$${{F}^{\ast}:=\mathop{\rm{min}}\limits_{{\mathbf{w}},W}\,\mathop{\sum }\limits_{i=1}^{m}{\alpha }_{i}{F}_{i}({\mathbf{w}}_{i})+\frac{\lambda }{2}\parallel {\mathbf{w}}{\parallel }^{2},\,\,\,{\rm{s.t.}}\,\,{\mathbf{w}}_{i}={\mathbf{w}},\,i\in [m],}$$
(3)

where λ ∈ [0, μ] and

$$\begin{array}{rl}{F}_{i}({\mathbf{w}}) & :={F}_{i}({\mathbf{w}};{\mathcal{D}}_{i}):={H}_{i}({\mathbf{w}};{\mathcal{D}}_{i})+\dfrac{\mu -\lambda }{2}\parallel {\mathbf{w}}{\parallel }^{2},\\ F({\mathbf{w}}) & :=\displaystyle{\mathop{\sum }\limits_{i=1}^{m}}{\alpha }_{i}{F}_{i}({\mathbf{w}}),\,\,\,{F}_{\lambda }({\mathbf{w}}):=F({\mathbf{w}})+\dfrac{\lambda }{2}\parallel {\mathbf{w}}{\parallel }^{2}.\end{array}$$
(4)

In problem (3), m auxiliary variables wi are introduced in addition to the global parameter w. We emphasize that problems (2) and (3) are equivalent in terms of their optimal solutions but are expressed in different forms when λ ∈ [0, μ), and they are identical when λ = μ. Throughout this work, we assume that the optimal function value F* is bounded from below, namely, F* > −∞.

The algorithmic design

When using ADMM to solve problem (3), we need its associated augmented Lagrange function, which is defined as follows:

$$\begin{array}{rcl}{\mathcal{L}}\left({\mathbf{w}},W,{{\varPi }};{\boldsymbol{\sigma }}\right) & := & \displaystyle{\mathop{\sum }\limits_{i=1}^{m}}{\alpha }_{i}{L}_{i}({\mathbf{w}},{\mathbf{w}}_{i},{\boldsymbol{\pi }}_{i};{\sigma }_{i})+\dfrac{\lambda }{2}\parallel {\mathbf{w}}{\parallel }^{2},\\ {L}_{i}({\mathbf{w}},{\mathbf{w}}_{i},{\boldsymbol{\pi }}_{i};{\sigma }_{i}) & := & {F}_{i}({\mathbf{w}}_{i})+\langle {\boldsymbol{\pi }}_{i},{\mathbf{w}}_{i}-{\mathbf{w}}\rangle +\dfrac{{\sigma }_{i}}{2}\parallel {\mathbf{w}}_{i}-{\mathbf{w}}{\parallel }^{2},\end{array}$$
(5)

where σi > 0 and πi, i ∈ [m], are the Lagrange multipliers. Based on the above augmented Lagrange function, the conventional ADMM updates each variable in (w, W, Π) iteratively. However, we modify the framework as follows. Given an initial point (w0, W0, Π0; σ0), the algorithm performs the following steps iteratively for ℓ = 0, 1, 2, …:

$${\mathbf{w}}^{\ell +1}={\rm{arg}} \mathop{\rm{min}} \limits_{\mathbf{w}}{\mathcal{L}}\left({\mathbf{w}},{W}^{\ell},{{{\Pi}}}^{\ell };{{\boldsymbol{\sigma}}}^{\ell}\right),$$
(6a)
$${\mathbf{w}}_{i}^{\ell +1}={\rm{arg}} \mathop {\rm{min}}\limits_{{\mathbf{w}}_{i}}{L}_{i}({\mathbf{w}}^{\ell +1},{\mathbf{w}}_{i},{\boldsymbol{\pi }}_{i}^{\ell };{\sigma }_{i}^{\ell +1})+\dfrac{{\rho }_{i}}{2}\langle {\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1},{Q}_{i}^{\ell +1}({\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1})\rangle ,$$
(6b)
$${\boldsymbol{\pi }}_{i}^{\ell +1}={\boldsymbol{\pi }}_{i}^{\ell }+{\sigma }_{i}^{\ell +1}({\mathbf{w}}_{i}^{\ell +1}-{\mathbf{w}}^{\ell +1}),$$
(6c)

for each i ∈ [m], where ρi > 0, and both the scalar \({\sigma }_{i}^{\ell +1}\) and the matrix \({Q}_{i}^{\ell +1}\succcurlyeq 0\) will be updated properly. Hereafter, the superscripts ℓ and ℓ + 1 in \({\sigma }_{i}^{\ell }\) and \({\sigma }_{i}^{\ell +1}\) stand for the iteration number rather than a power. Here \({Q}_{i}^{\ell +1}\) is commonly referred to as an (adaptive) preconditioning matrix in preconditioned gradient methods26,74,75,76.

Remark 1

The primary distinction between the algorithmic framework (equation (6a–c)) and conventional ADMM lies in the inclusion of the term \(\frac{{\rho }_{i}}{2}\langle {\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1},{Q}_{i}^{\ell +1}({\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1})\rangle\). This term enables the incorporation of various forms of useful information, such as the second moment, second-order information (for example, the Hessian) and momentum orthogonalized by Newton–Schulz iterations, thereby enhancing the performance of the proposed algorithms; see the section 'Second moment' for more details.

One can check that subproblem (6a) admits a closed-form solution outlined in equation (8). For subproblem (6b), to accelerate the computational speed, we solve it inexactly by

$$\begin{array}{rcl}{\mathbf{w}}_{i}^{\ell +1} & =&{\rm{arg}}\, \mathop{\rm{min}}\limits_{{\mathbf{w}}_{i}}\,\langle {\boldsymbol{\pi }}_{i}^{\ell },{\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1}\rangle +\dfrac{{\sigma }_{i}^{\ell +1}}{2}\parallel {\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1}{\parallel }^{2}\\ & +&{F}_{i}({\mathbf{w}}^{\ell +1})+\langle \nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1}),{\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1}\rangle \\ & +&\dfrac{{\rho }_{i}}{2}\left\langle {\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1},Q_{i}^{\ell +1}\left({\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1}\right)\right\rangle \\ & =&{\mathbf{w}}^{\ell +1}-{\left({\sigma }_{i}^{\ell +1}{I}+{\rho }_{i}{Q}_{i}^{\ell +1}\right)}^{-1}\left({\boldsymbol{\pi }}_{i}^{\ell }+\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\right).\end{array}$$
(7)

Algorithm 1

PISA.

Divide \({\mathcal{D}}\) into m disjoint batches \(\{{{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\}\) and calculate αi by equation (1).

Initialize \({\mathbf{w}}^{0}={\mathbf{w}}_{i}^{0}={\boldsymbol{\pi }}_{i}^{0}=0\), γi ∈ [3/4, 1) and \(({\sigma }_{i}^{0},{\eta }_{i},{\rho }_{i}) > 0\) for each i ∈ [m].

for ℓ = 0, 1, 2, … do

$$\mathbf{w}^{\ell +1}=\dfrac{{\sum }_{i=1}^{m}{\alpha }_{i}({\sigma }_{i}^{\ell }\mathbf{w}_{i}^{\ell }+\boldsymbol{\pi }_{i}^{\ell })}{{\sum }_{i=1}^{m}{\alpha }_{i}{\sigma }_{i}^{\ell }+\lambda }.$$
(8)

 for i = 1, 2, …, m do

$$\begin{array}{l}{\text{Randomly draw a mini-batch}}\,{\mathcal{B}}_{i}^{\ell +1}\subseteq{{\mathcal{D}}}_{i}.\\{\text{Calculate}}\,{\mathbf{g}}_{i}^{\ell +1}={\rm{\nabla }}{F}_{i}({\mathbf{w}}^{\ell +1};{\mathcal{B}}_{i}^{\ell +1}).\end{array}$$
(9)
$${\rm{Choose}}\,{Q}_{i}^{\ell +1}\,{\rm{to}}\,{\rm{satisfy}}\,{\eta }_{i}I\succcurlyeq {Q}_{i}^{\ell +1}\succcurlyeq 0.$$
(10)
$${\sigma }_{i}^{\ell +1}={\sigma }_{i}^{\ell }/{\gamma }_{i}.$$
(11)
$${\mathbf{w}}_{i}^{\ell +1}={\mathbf{w}}^{\ell +1}-{({\sigma }_{i}^{\ell +1}{I}+{\rho }_{i}{Q}_{i}^{\ell +1})}^{-1}({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}).$$
(12)
$${\boldsymbol{\pi }}_{i}^{\ell +1}={\boldsymbol{\pi }}_{i}^{\ell }+{\sigma }_{i}^{\ell +1}(\mathbf{w}_{i}^{\ell +1}-\mathbf{w}^{\ell +1}).$$
(13)

 end

end

This update admits three advantages. First, it solves problem (6b) by a closed-form solution, namely, the second equation in (7), reducing the computational complexity. Second, we approximate \({F}_{i}({\mathbf{w}}_{i})\) using its first-order approximation at \({\mathbf{w}}^{\ell +1}\) rather than at \({\mathbf{w}}_{i}^{\ell }\), which encourages each batch parameter \({\mathbf{w}}_{i}^{\ell +1}\) to tend to \({\mathbf{w}}^{\ell +1}\) quickly, thereby accelerating the overall convergence. Finally, \(\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\) serves as a stochastic approximation of the true gradient \(\nabla {F}_{i}({\mathbf{w}}^{\ell +1})=\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{D}}}_{i})\), as defined by equation (4), where \({{\mathcal{B}}}_{i}^{\ell +1}\) is a random sample from \({{\mathcal{D}}}_{i}\). By using sub-batch datasets \(\{{{\mathcal{B}}}_{1}^{\ell +1},\ldots ,{{\mathcal{B}}}_{m}^{\ell +1}\}\) in every iteration, rather than the full data \({\mathcal{D}}=\{{{\mathcal{D}}}_{1},\ldots ,{{\mathcal{D}}}_{m}\}\), the computational cost is significantly reduced. Overall, on the basis of these observations, we name our algorithm PISA, as described in Algorithm 1.

Another advantageous property of PISA is its ability to perform parallel computation, which stems from the parallelism used in solving subproblems in ADMM. At each iteration, m nodes (that is, i = 1, 2, …, m) update their parameters by equations (9)–(13) in parallel, thereby enabling the processing of large-scale datasets. Moreover, when the preconditioning matrix \({Q}_{i}^{\ell +1}\) is specified as a diagonal matrix (as outlined in the section 'Second moment') and \({{\mathcal{B}}}_{i}^{\ell +1}\) is sampled with a small batch size, each node exhibits significantly low computational complexity, facilitating fast computation.
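
To illustrate the structure of one PISA iteration (equations (8)–(13)), the following minimal NumPy sketch uses the simple choice \({Q}_{i}^{\ell +1}={{I}}\) (a valid preconditioner, as discussed in the section 'Precondition specification'); grad(i, w) is a user-supplied stochastic gradient oracle for batch i, and a single ρ and γ are shared across batches for brevity, whereas the algorithm allows per-batch values.

import numpy as np

def pisa_iteration(w_list, pi_list, sigma, alpha, grad, lam=0.0, rho=1.0, gamma=0.9):
    # One (serial) PISA iteration with Q_i = I; all lists have length m.
    m = len(w_list)
    # Equation (8): aggregate the batch parameters and multipliers into the global w.
    num = sum(alpha[i] * (sigma[i] * w_list[i] + pi_list[i]) for i in range(m))
    den = sum(alpha[i] * sigma[i] for i in range(m)) + lam
    w = num / den
    for i in range(m):                # in practice, these m updates run in parallel
        g = grad(i, w)                # equation (9): stochastic gradient on a mini-batch of D_i
        sigma[i] = sigma[i] / gamma   # equation (11): increase sigma_i geometrically
        # Equation (12) with Q_i = I: the matrix inverse reduces to a scalar division.
        w_list[i] = w - (pi_list[i] + g) / (sigma[i] + rho)
        # Equation (13): update the multiplier.
        pi_list[i] = pi_list[i] + sigma[i] * (w_list[i] - w)
    return w, w_list, pi_list, sigma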

Convergence of PISA

In this subsection, we aim to establish the convergence property of Algorithm 1. To proceed with that, we first define a critical bound by

$${\varepsilon }_{i}(r):=\mathop{\rm{sup}}\limits_{{{\mathcal{B}}}_{i},{{\mathcal{B}}}_{i}^{{\prime} }\subseteq {{\mathcal{D}}}_{i},\mathbf{w}\in {{{\mathbb{N}}}}(r)}64{\left\Vert \nabla {F}_{i}(\mathbf{w};{{\mathcal{B}}}_{i})-\nabla {F}_{i}(\mathbf{w};{{\mathcal{B}}}_{i}^{{\prime} })\right\Vert }^{2},\,\,\,\forall \,i\in [m].$$
(14)

Lemma 1

εi(r) < ∞ for any given r ∈ (0, ∞) and any i ∈ [m].

The proof of Lemma 1 is given in Supplementary Section 2.3. One can observe that εi(r) = 0 for any r > 0 if we take the full-batch data in each step, namely, choosing \({{\mathcal{B}}}_{i}^{\ell }={({{\mathcal{B}}}_{i}^{\ell })}^{{\prime} }={{\mathcal{D}}}_{i}\) for every i ∈ [m] and all ℓ ≥ 1. However, for a mini-batch dataset \({{\mathcal{B}}}_{i}^{\ell }\subset {{\mathcal{D}}}_{i}\), this parameter is related to the bound of the variance \({\mathbb{E}}{\left\Vert \nabla {F}_{i}(\mathbf{w};{{\mathcal{B}}}_{i})-\nabla {F}_{i}(\mathbf{w};{{\mathcal{D}}}_{i})\right\Vert }^{2}\), which is commonly assumed to be bounded for any w10,11,21,36,77. However, in the subsequent analysis, we verify that both generated sequences \(\{{\mathbf{w}}^{\ell }\}\) and \(\{{\mathbf{w}}_{i}^{\ell }\}\) fall into a bounded region \({\mathbb{N}}(\delta )\) for any i ∈ [m], with δ defined as in equation (16), thereby naturally yielding a finite εi(δ); see Lemma 2. In other words, we no longer need to assume the boundedness of the variance \({\mathbb{E}}{\left\Vert \nabla {F}_{i}(\mathbf{w};{{\mathcal{B}}}_{i})-\nabla {F}_{i}(\mathbf{w};{{\mathcal{D}}}_{i})\right\Vert }^{2}\) for any w. This assumption is known to be somewhat restrictive, particularly for non-IID or heterogeneous datasets. The theorems we establish in the sequel therefore effectively address this critical challenge60,62, and our algorithm demonstrates robust performance in settings with heterogeneous data.

Convergence analysis

To establish convergence, we need the assumption below. It assumes that function f has a Lipschitz continuous gradient on a bounded region, namely, the gradient is locally Lipschitz continuous. This is a relatively mild condition: functions with globally Lipschitz continuous gradients and twice-continuously differentiable functions satisfy it. The Lipschitz continuity of the gradient is commonly referred to as L-smoothness. Therefore, our assumption can be regarded as L-smoothness on a bounded region, which is weaker than global L-smoothness.

Assumption 1

For each \(t\in [| {\mathcal{D}}| ]\), the gradient \(\nabla f(\cdot ;{{\bf{x}}}_{t})\) is Lipschitz continuous with a constant c(xt) > 0 on \({\mathbb{N}}(2\delta )\). Denote \({c}_{i}:=\mathop{\max }\limits_{{{\bf{x}}}_{t}\in {{\mathcal{D}}}_{i}}c({{\bf{x}}}_{t})\) and ri := ci + μ − λ for each i ∈ [m].

First, given a constant σ > 0, we define a set

$$\Omega :=\left\{(\mathbf{w},{\mathit{W}}):\,\displaystyle{\mathop{\sum }\limits_{i=1}^{m}}{\alpha }_{i}\left({F}_{i}(\mathbf{w})+\dfrac{\lambda }{2}\parallel \mathbf{w}{\parallel }^{2}+\dfrac{\sigma }{2}\parallel {\mathbf{w}}_{i}-\mathbf{w}{\parallel }^{2}\right)\le F({\mathbf{w}}^{0})+\dfrac{1}{1-\gamma }\right\},$$
(15)

where \(\gamma :=\mathop{\max }\limits_{i\in [m]}{\gamma }_{i}\), based on which we further define

$$\delta :=\mathop{\rm{sup}}\limits_{(\mathbf{w},{\mathit{W}})\in \Omega }\left\{\parallel \mathbf{w}\parallel ,\parallel {\mathbf{w}}_{1}\parallel ,\parallel {\mathbf{w}}_{2}\parallel ,\ldots ,\parallel {\mathbf{w}}_{m}\parallel \right\}.$$
(16)

This indicates that any point (w, W) ∈ Ω satisfies \(\{\mathbf{w},{\mathbf{w}}_{1},\ldots ,{\mathbf{w}}_{m}\}\subseteq {\mathbb{N}}(\delta )\). Using this δ, we initialize \({{\boldsymbol{\sigma }}}^{0}:=({\sigma }_{1}^{0},{\sigma }_{2}^{0},\cdots \,,{\sigma }_{m}^{0})\) by

$${\sigma }^{0}:={\rm{min}} \{{\sigma }_{1}^{0},{\sigma }_{2}^{0},\cdots \,,{\sigma }_{m}^{0}\}\ge \,8\mathop{\rm{max}}\limits_{i\in [m]}\left\{\sigma ,\,{\rho }_{i}{\eta }_{i},\,{r}_{i},\,{\delta }^{-2},\,{\varepsilon }_{i}(2\delta )\right\}.$$
(17)

It is easy to see that Ω is a bounded set due to Fi being bounded from below. Therefore, δ is bounded and so is εi(2δ) due to Lemma 1. Hence, σ0 in equation (17) is a well-defined constant, namely, σ0 can be set as a finite positive number. For notational simplicity, hereafter, we denote

$$\begin{array}{llllll}\Delta {\mathbf{w}}^{\ell }:={\mathbf{w}}^{\ell }-{\mathbf{w}}^{\ell -1},\,&\Delta {\mathbf{w}}_{i}^{\ell }:={\mathbf{w}}_{i}^{\ell }-{\mathbf{w}}_{i}^{\ell -1},\\ \Delta {\boldsymbol{\pi }}_{i}^{\ell }:={\boldsymbol{\pi }}_{i}^{\ell }-{\boldsymbol{\pi }}_{i}^{\ell -1},\,&\Delta {\bar{\mathbf{w}}}_{i}^{\ell }:={\mathbf{w}}_{i}^{\ell }-{\mathbf{w}}^{\ell },\\ \Delta {\mathbf{g}}_{i}^{\ell }:={\mathbf{g}}_{i}^{\ell }-\nabla {F}_{i}({\mathbf{w}}^{\ell }),&{{\mathcal{L}}}^{\ell }:={\mathcal{L}}({\mathbf{w}}^{\ell },{{\mathit{W}}}^{\ell },{{\mathit{\Pi }}}^{\ell };{{\boldsymbol{\sigma }}}^{\ell }).\end{array}$$
(18)

Our first result shows the descent property of a merit function associated with \({{\mathcal{L}}}^{\ell }\).

Lemma 2

Let \(\{({\mathbf{w}}^{\ell },{{W}}^{\ell },{{\varPi }}^{\ell })\}\) be the sequence generated by Algorithm 1 with σ0 chosen as in equation (17). Then the following statements are valid under Assumption 1.

  1. (1)

    For any ℓ ≥ 0, sequence \(\{{\mathbf{w}}^{\ell },{\mathbf{w}}_{1}^{\ell },\ldots ,{\mathbf{w}}_{m}^{\ell }\}\subseteq {\mathbb{N}}(\delta )\).

  2. (2)

    For any ℓ ≥ 0,

    $${\widetilde{{\mathcal{L}}}}^{\ell }-{\widetilde{{\mathcal{L}}}}^{\ell +1}\ge \displaystyle{\mathop{\sum }\limits_{i=1}^{m}}{\alpha }_{i}\left[\dfrac{{\sigma }_{i}^{\ell }+2\lambda }{4}{\left\Vert \Delta {\mathbf{w}}^{\ell +1}\right\Vert }^{2}+\dfrac{{\sigma }_{i}^{\ell }}{4}{\left\Vert \Delta {\mathbf{w}}_{i}^{\ell +1}\right\Vert }^{2}\right],$$
    (19)

    where \({\widetilde{{\mathcal{L}}}}^{\ell }\) is defined by

    $${\widetilde{{\mathcal{L}}}}^{\ell }:={{\mathcal{L}}}^{\ell }+\displaystyle{\mathop{\sum }\limits_{i=1}^{m}}{\alpha }_{i}\left[\dfrac{8}{{\sigma }_{i}^{\ell }}{\left\Vert {\rho }_{i}{Q}_{i}^{\ell }\Delta {\bar{\mathbf{w}}}_{i}^{\ell }\right\Vert }^{2}+\dfrac{{\gamma }_{i}^{\ell }}{16(1-{\gamma }_{i})}\right].$$
    (20)

The proof of Lemma 2 is given in Supplementary Section 2.6. This lemma is derived from a deterministic perspective, which is achieved by considering the worst case of the bound εi(2δ) (that is, by taking all possible selections of \(\{{{\mathcal{B}}}_{1}^{\ell },\ldots ,{{\mathcal{B}}}_{m}^{\ell }\}\) into account). On the basis of this key lemma, the following theorem establishes the sequence convergence of the algorithm.

Theorem 1

Let \(\{({\mathbf{w}}^{\ell },{{W}}^{\ell },{{\varPi }}^{\ell })\}\) be the sequence generated by Algorithm 1 with σ0 chosen as in equation (17). Then the following statements are valid under Assumption 1.

  1. (1)

    Sequences \(\{{{\mathcal{L}}}^{\ell }\}\) and \(\{{\widetilde{{\mathcal{L}}}}^{\ell }\}\) converge and, for any i ∈ [m],

    $$0=\mathop{\rm{lim}}\limits_{\ell \to \infty }\left\Vert \Delta {\mathbf{w}}^{\ell }\right\Vert =\mathop{\rm{lim}}\limits_{\ell \to \infty }\left\Vert \Delta {\mathbf{w}}_{i}^{\ell }\right\Vert =\mathop{\rm{lim}}\limits_{\ell \to \infty }\left\Vert \Delta {\bar{\mathbf{w}}}_{i}^{\ell }\right\Vert =\mathop{\rm{lim}}\limits_{\ell \to \infty }\left({\widetilde{{\mathcal{L}}}}^{\ell }-{{\mathcal{L}}}^{\ell }\right).$$
    (21)
  2. (2)

    Sequence \(\{({\mathbf{w}}^{\ell },{{\mathit{W}}}^{\ell },{\mathbb{E}}{{{\Pi }}}^{\ell })\}\) converges.

The proof of Theorem 1 is given in Supplementary Section 2.7. To ensure the convergence results, initial value σ0 is selected according to equation (17), which involves a hyperparameter δ. If a lower bound \(\underline{F}\) of \(\mathop{\rm{min}}\limits_{\mathbf{w}}{\sum }_{i=1}^{m}{\alpha }_{i}{F}_{i}(\mathbf{w})\) is known, then an upper bound \(\overline{\delta }\) for δ can be estimated from equations (15) and (16) by substituting Fi(w) with \(\underline{F}\). In practice, particularly in deep learning, many widely used loss functions, such as mean-squared error and cross-entropy, yield non-negative values. This observation allows us to set the lower bound as \(\underline{F}=0\). Once \(\overline{\delta }\) is estimated, it can be used in equation (17) to select σ0, without affecting the convergence guarantees. However, it is worth emphasizing that equation (17) is a sufficient but not necessary condition. Therefore, in practice, it is not essential to enforce this condition strictly when initializing σ0 in numerical experiments.
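
As an illustration, assume λ > 0 and write \(C:=F({\mathbf{w}}^{0})+\frac{1}{1-\gamma }\). Substituting \({F}_{i}({\mathbf{w}})\ge \underline{F}=0\) into equation (15) yields \({\sum }_{i=1}^{m}{\alpha }_{i}(\frac{\lambda }{2}\parallel {\mathbf{w}}{\parallel }^{2}+\frac{\sigma }{2}\parallel {\mathbf{w}}_{i}-{\mathbf{w}}{\parallel }^{2})\le C\) for every \(({\mathbf{w}},{{W}})\in \Omega\), and hence, by equation (16), one crude but valid estimate is

$$\delta \le \overline{\delta }:=\sqrt{\frac{2C}{\lambda }}+\sqrt{\frac{2C}{\sigma \mathop{\min }\limits_{i\in [m]}{\alpha }_{i}}}.$$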

Complexity analysis

Besides the convergence established above, the algorithm exhibits the following rate of convergence under the same assumption and parameter setup.

Theorem 2

Let \(\{({\mathbf{w}}^{\ell },{{W}}^{\ell },{{\varPi }}^{\ell })\}\) be the sequence generated by Algorithm 1 with σ0 chosen as in equation (17). Let \({\mathbf{w}}^{\infty }\) be the limit of sequence \(\{{\mathbf{w}}^{\ell }\}\). Then there is a constant C1 > 0 such that

$$\max \left\{\left\Vert {\mathbf{w}}^{\ell }-{\mathbf{w}}^{\infty }\right\Vert ,\,\left\Vert {\mathbf{w}}_{i}^{\ell }-{\mathbf{w}}^{\infty }\right\Vert ,\,\left\Vert {\mathbb{E}}{\boldsymbol{\pi }}_{i}^{\ell }+\nabla {F}_{i}({\mathbf{w}}^{\infty })\right\Vert ,\,\forall \,i\in [m]\right\}\le {C}_{1}{\gamma }^{\ell }$$
(22)

and a constant C2 > 0 such that

$$\max \left\{{F}_{\lambda }({\mathbf{w}}^{\ell }),\,{{\mathcal{L}}}^{\ell },\,{\widetilde{{\mathcal{L}}}}^{\ell }\right\}-{F}_{\lambda }({\mathbf{w}}^{\infty })\le {C}_{2}{\gamma }^{\ell }.$$
(23)

The proof of Theorem 2 is given in Supplementary Section 2.8. This theorem means that sequence \(\{({\mathbf{w}}^{\ell },{{\mathit{W}}}^{\ell },{\mathbb{E}}{{{\Pi }}}^{\ell })\}\) converges to its limit at a linear rate. To achieve such a result, we require only Assumption 1, without imposing any other commonly used assumptions, such as those presented in Table 1.

Precondition specification

In this section, we explore the preconditioning matrix, namely, matrix \({Q}_{i}^{\ell }\). A simple and computationally efficient choice is to set \({Q}_{i}^{\ell }={{I}}\), which enables fast computation of updating \({\mathbf{w}}_{i}^{\ell +1}\) via equation (12). However, this choice is too simple to extract useful information about Fi. Therefore, several alternatives can be adopted to set \({Q}_{i}^{\ell }\).

Second-order information

Second-order optimization methods, such as Newton-type and trust region methods, are known to enhance numerical performance by leveraging second-order information, namely the (generalized) Hessian. For instance, if each function Fi is twice-continuously differentiable, then one can set

$${Q}_{i}^{\ell +1}={{\rm{\nabla }}}^{2}{F}_{i}({\mathbf{w}}^{\ell +1};{\mathcal{B}}_{i}^{\ell +1}),$$
(24)

where \({\nabla }^{2}{F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\) represents the Hessian of \({F}_{i}(\cdot ;{{\mathcal{B}}}_{i}^{\ell +1})\) at w+1. With this choice, subproblem (7) becomes closely related to second-order methods, and the update takes the form

$${\mathbf{w}}_{i}^{\ell +1}={\mathbf{w}}^{\ell +1}-{\left({\sigma }_{i}^{\ell +1}{I}+{\rho }_{i}{\nabla }^{2}{F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\right)}^{-1}\left({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\right).$$

This update corresponds to a Levenberg–Marquardt step78,79 or a regularized Newton step80,81 when \({\sigma }_{i}^{\ell +1} > 0\), and reduces to the classical Newton step if \({\sigma }_{i}^{\ell +1}=0\). While incorporating the Hessian can improve performance in terms of iteration count and solution quality, it often leads to significantly higher computational complexity. To mitigate this, other effective approaches exploit the second moment derived from historical updates to construct the preconditioning matrices.

Second moment

We note that using the second moment to determine an adaptive learning rate improves the learning performance of several popular algorithms, such as RMSProp16 and Adam17. Motivated by this, we specify the preconditioning matrix using the second moment as follows:

$${Q}_{i}^{\ell +1}=\mathrm{Diag}\left(\sqrt{{\mathbf{m}}_{i}^{\ell +1}}\right),$$
(25)

where Diag(m) is the diagonal matrix with diagonal entries formed by m, and \({\mathbf{m}}_{i}^{\ell +1}\) can be chosen flexibly as long as it satisfies \(\parallel {\mathbf{m}}_{i}^{\ell +1}{\parallel }_{\infty }\le {\eta }_{i}^{2}\). Here, \(\parallel \mathbf{m}{\parallel }_{\infty }\) is the infinity norm of m. We can set \({\mathbf{m}}_{i}^{\ell +1}\) as follows:

$${\mathbf{m}}_{i}^{\ell +1}={{\text{min}}}\left\{{\widetilde{\mathbf{m}}}_{i}^{\ell +1},\,{\eta }_{i}^{2}{\bf{1}}\right\},$$
(26)

where \({\widetilde{\mathbf{m}}}_{i}^{\ell +1}\) can be updated by

$$\begin{array}{rcl}\,{{\mathrm{Scheme}}\; {\mathrm{I}}\; :} & & {\widetilde{\mathbf{m}}}_{i}^{\ell +1}={\widetilde{\mathbf{m}}}_{i}^{\ell }+\left({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\right)\odot \left({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\right),\\ \,{{\mathrm{Scheme}}\; {\mathrm{II}}\; :} & & {\widetilde{\mathbf{m}}}_{i}^{\ell +1}={\beta }_{i}{\widetilde{\mathbf{m}}}_{i}^{\ell }+(1-{\beta }_{i})\left({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\right)\odot \left({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\right),\\ \,{{\mathrm{Scheme}}\; {\mathrm{III}}\; :} & & {\mathbf{n}}_{i}^{\ell +1}={\beta }_{i}{\mathbf{n}}_{i}^{\ell }+(1-{\beta }_{i})\left({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\right)\odot \left({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\right),\\ & & {\widetilde{\mathbf{m}}}_{i}^{\ell +1}={\mathbf{n}}_{i}^{\ell +1}/(1-{\beta }_{i}^{\ell +1}),\end{array}$$
(27)

where \({\widetilde{\mathbf{m}}}_{i}^{0}\) and \({\mathbf{n}}_{i}^{0}\) are given, βi ∈ (0, 1) and \({\beta }_{i}^{\ell }\) stands for the ℓth power of βi. These three schemes resemble those used by AdaGrad15, RMSProp16 and Adam17, respectively. Putting equation (25) into Algorithm 1 gives rise to Algorithm 2. We term it SISA, an abbreviation for the second-moment-based inexact SADMM. Compared with PISA in Algorithm 1, SISA admits three advantages.

  1. (1)

    It is capable of incorporating various schemes of the second moment, which may enhance the numerical performance of SISA significantly.

  2. (2)

    One can easily check that \({\eta }_{i}{{I}}\succcurlyeq {Q}_{i}^{\ell +1}\succcurlyeq 0\) for each batch i ∈ [m] and all ℓ ≥ 1. Therefore, equation (25) enables us to preserve the convergence property as follows.

Theorem 3

Let \(\{({\mathbf{w}}^{\ell },{{W}}^{\ell },{{\varPi }}^{\ell })\}\) be the sequence generated by Algorithm 2 with σ0 chosen as in equation (17). Then, under Assumption 1, all statements in Theorems 1 and 2 are valid.

  1. (3)

    Such a choice of \({Q}_{i}^{\ell +1}\) enables faster computation than updating \({\mathbf{w}}_{i}^{\ell +1}\) by equation (7). In fact, since the operation u/v denotes element-wise division, the complexity of computing equation (28) is O(p), where p is the dimension of \({\mathbf{w}}_{i}^{\ell +1}\), whereas the complexity of computing equation (12) is \(O({p}^{3})\); a minimal sketch of forming \({\mathbf{m}}_{i}^{\ell +1}\) is given after this list.
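
As an implementation aid, the following minimal NumPy sketch (our own illustration, with Scheme III as the default) shows how \({\mathbf{m}}_{i}^{\ell +1}\) in equations (26) and (27) can be formed from \({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\); state is a dictionary holding the arrays m_tilde and n, both initialized to zero.

import numpy as np

def second_moment(pi_plus_g, state, eta, beta=0.999, step=1, scheme="III"):
    # Update the running second moment (equation (27)) and cap it (equation (26)).
    sq = pi_plus_g * pi_plus_g                      # elementwise square of pi_i + g_i
    if scheme == "I":                               # AdaGrad-like accumulation
        state["m_tilde"] = state["m_tilde"] + sq
    elif scheme == "II":                            # RMSProp-like exponential average
        state["m_tilde"] = beta * state["m_tilde"] + (1 - beta) * sq
    else:                                           # Scheme III: Adam-like with bias correction
        state["n"] = beta * state["n"] + (1 - beta) * sq
        state["m_tilde"] = state["n"] / (1 - beta ** step)
    return np.minimum(state["m_tilde"], eta ** 2)   # enforce ||m||_inf <= eta^2

The returned vector plays the role of \({\mathbf{m}}_{i}^{\ell +1}\) in the SISA update of equation (28).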

Algorithm 2

SISA

Divide \({\mathcal{D}}\) into m disjoint batches \(\{{{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\}\) and calculate αi by equation (1).

Initialize \({\mathbf{w}}^{0}={\mathbf{w}}_{i}^{0}={\boldsymbol{\pi }}_{i}^{0}=0\), γi ∈ [3/4, 1) and \(({\sigma }_{i}^{0},{\eta }_{i},{\rho }_{i}) > 0\) for each i ∈ [m].

for ℓ = 0, 1, 2, … do

$${\mathbf{w}}^{\ell +1}=\dfrac{{\sum }_{i=1}^{m}{\alpha }_{i}\left({\sigma }_{i}^{\ell }{\mathbf{w}}_{i}^{\ell }+{\boldsymbol{\pi }}_{i}^{\ell }\right)}{{\sum }_{i=1}^{m}{\alpha }_{i}{\sigma }_{i}^{\ell }+\lambda }.$$

 for i = 1, 2, …, m do

$$\begin{array}{l}{\mathrm{Randomly}}\,{\mathrm{draw}}\,{\mathrm{a}}\,{\mathrm{mini}}{\mbox{-}}{\mathrm{batch}}\,{\mathcal{B}}_{i}^{{\ell} +1}\subseteq{\mathcal{D}}_{i}.\\{\mathrm{Compute}}\,{\mathbf{g}}_{i}^{{\ell} +1}=\nabla {F}_{i}({\mathbf{w}}^{{\ell} +1};{\mathcal{B}}_{i}^{\ell +1}).\\{\rm{Choose}}\, {{\mathbf{m}}_{i}^{\ell+1}}\,{{\mathrm{to}}\,{\mathrm{satisfy}}}\,{\|{\mathbf{m}}_{i}^{{\ell}+1}\|_{\infty}\leq {\eta}_{i}^{2}}.\\{\sigma}_{i}^{\ell +1} = {\sigma}_{i}^{\ell }/{\gamma}_{i}.\\ {\mathbf{w}}_{i}^{{\ell}+1} = {\mathbf{w}}^{\ell +1}-\dfrac{{\boldsymbol{\pi}}_{i}^{\ell }+{\mathbf{g}}_{i}^{{\ell} +1}}{{\sigma}_{i}^{\ell +1}+{\rho}_{i}\sqrt{{\mathbf{m}}_{i}^{\ell +1}}}.\\ {\boldsymbol{\pi}}_{i}^{\ell +1} = {\boldsymbol{\pi }}_{i}^{\ell}+{\sigma}_{i}^{\ell +1}({\mathbf{w}}_{i}^{\ell +1}-{\mathbf{w}}^{\ell +1}).\end{array}$$
(28)

 end

end

Algorithm 3

NSISA

Divide \({\mathcal{D}}\) into m disjoint batches \(\{{{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\}\) and calculate αi by equation (1).

Initialize \({\mathbf{w}}^{0}={\mathbf{w}}_{i}^{0}={\boldsymbol{\pi }}_{i}^{0}={\mathbf{b}}_{i}^{0}=0\), γi ∈ [3/4, 1), ϵi ∈ (0, 1) and \(({\sigma }_{i}^{0},{\rho }_{i},{\mu }_{i}) > 0\) for each i ∈ [m].

for ℓ = 0, 1, 2, … do

$${\mathbf{w}}^{\ell +1}=\dfrac{{\sum }_{i=1}^{m}{\alpha }_{i}\left({\sigma }_{i}^{\ell }{\mathbf{w}}_{i}^{\ell }+{\boldsymbol{\pi }}_{i}^{\ell }\right)}{{\sum }_{i=1}^{m}{\alpha }_{i}{\sigma }_{i}^{\ell }+\lambda }.$$

 for i = 1, 2, …, m do

$$\begin{array}{l}{\mathrm{Randomly}}\,{\mathrm{draw}}\,{\mathrm{a}}\,{\mathrm{mini}}{\hbox{-}}{\mathrm{batch}}\,{{\mathcal{B}}}_{i}^{\ell +1}\subseteq{{\mathcal{D}}}_{i}.\\{\mathrm{Compute}}\,{\mathbf{g}}_{i}^{\ell +1}=\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1}).\\{\mathbf{b}}_{i}^{\ell +1} = {\mu }_{i}{\mathbf{b}}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}.\\ {\mathbf{o}}_{i}^{\ell +1} = {\mathtt{NewtonSchulz}}({\mathbf{b}}_{i}^{\ell +1}).\\ {\sigma }_{i}^{\ell +1} = {\sigma }_{i}^{\ell }/{\gamma }_{i}.\\ {\mathbf{w}}_{i}^{\ell +1} = {\mathbf{w}}^{\ell +1}-\dfrac{{\boldsymbol{\pi}}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1}+{\epsilon}_{i}^{\ell +1}{\mathbf{v}}_{i}^{\ell +1}}{{\sigma}_{i}^{\ell +1}+{\rho }_{i}\sqrt{{\mathbf{m}}_{i}^{\ell +1}}}.\\ {\boldsymbol{\pi}}_{i}^{\ell +1} = {\boldsymbol{\pi}}_{i}^{\ell }+{\sigma}_{i}^{\ell +1}({\mathbf{w}}_{i}^{\ell +1}-{\mathbf{w}}^{\ell +1}).\end{array}$$
(29)

 end

end

Orthogonalized momentum by Newton–Schulz iterations

Recently, the authors of ref. 27 proposed an algorithm called Muon, which orthogonalizes momentum using Newton–Schulz iterations. This approach has shown promising results in fine-tuning LLMs, outperforming many established optimizers. The underlying philosophy of Muon can also inform the design of the preconditioning matrix. Specifically, we consider the two-dimensional case, namely, the trainable variable w is a matrix. Then subproblem (7) in vector form becomes

$${\rm{vec}}({\mathbf{w}}_{i}^{\ell +1})={\rm{vec}}(\mathbf{{w}}^{\ell +1})-{({\sigma }_{i}^{\ell +1}{I}+{\rho }_{i}{Q}_{i}^{\ell +1})}^{-1}({\rm{vec}}({\boldsymbol{\pi }}_{i}^{\ell })+{\rm{vec}}({\mathbf{g}}_{i}^{\ell +1})),$$
(30)

where vec(w) denotes the column-wise vectorization of matrix w. Now, initialize \({\mathbf{b}}_{i}^{0}\) for all i ∈ [m] and μ > 0, and update the momentum by \({\mathbf{b}}_{i}^{\ell +1}=\mu {\mathbf{b}}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\). Let \({\mathbf{b}}_{i}^{\ell +1}={U}_{i}^{\ell +1}{\Lambda }_{i}^{\ell +1}{({V}_{i}^{\ell +1})}^{{\rm{\top }}}\) be the singular value decomposition of \({\mathbf{b}}_{i}^{\ell +1}\), where \({\Lambda }_{i}^{\ell +1}\) is a diagonal matrix and \({U}_{i}^{\ell +1}\) and \({V}_{i}^{\ell +1}\) are two orthogonal matrices. Compute \({\mathbf{o}}_{i}^{\ell +1}={U}_{i}^{\ell +1}{({V}_{i}^{\ell +1})}^{{\rm{\top }}}\) and \({\mathbf{p}}_{i}^{\ell +1}\) by

$${\mathbf{p}}_{i}^{\ell +1}=\dfrac{({\sigma }_{i}^{\ell +1}+{\rho }_{i}\sqrt{{\mathbf{m}}_{i}^{\ell +1}})\odot ({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1})}{{\rho }_{i}({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1}+{\epsilon }_{i}^{\ell +1}{\mathbf{v}}_{i}^{\ell +1})}-\dfrac{{\sigma }_{i}^{\ell +1}}{{\rho }_{i}},$$

where \({\mathbf{m}}_{i}^{\ell +1}\) can be the second moment (for example, \({\mathbf{m}}_{i}^{\ell +1}=({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1})\odot ({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1})\) is used in the numerical experiments), ϵi ∈ (0, 1) (here, \({\epsilon }_{i}^{\ell }\) stands for the ℓth power of ϵi) and \({\mathbf{v}}_{i}^{\ell +1}\) is a matrix whose (k, j)th element is given by \({({\mathbf{v}}_{i}^{\ell +1})}_{{kj}}=1\) if \({({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1})}_{kj}=0\) and \({({\mathbf{v}}_{i}^{\ell +1})}_{kj}=0\) otherwise. Then we set the preconditioning matrix by

$${{Q}}_{i}^{\ell +1}={\mathrm{Diag}}({\mathrm{vec}}({\mathbf{p}}_{i}^{\ell +1})).$$

Substituting the above choice into equation (30) yields equation (29). The idea of using equation (29) is inspired by ref. 27, where Newton–Schulz orthogonalization82,83 is used to efficiently approximate \({\mathbf{o}}_{i}^{\ell +1}\). Incorporating these steps into Algorithm 1 leads to Algorithm 3, which we refer to as NSISA. The implementation of NewtonSchulz(b) is provided in ref. 27. Below is the convergence result of NSISA.
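
For intuition, a basic cubic Newton–Schulz iteration for approximating the orthogonal polar factor \(U{V}^{{\rm{\top }}}\) of a matrix is sketched below; ref. 27 uses a carefully tuned quintic variant, so this simplified version is only our own illustration and assumes a non-zero input matrix.

import numpy as np

def newton_schulz(b, steps=10):
    # Approximate the orthogonal polar factor U V^T of b via cubic Newton-Schulz.
    x = b / (np.linalg.norm(b) + 1e-12)   # scale so all singular values lie in [0, 1]
    transposed = x.shape[0] > x.shape[1]
    if transposed:                        # iterate on the wide orientation to keep x @ x.T small
        x = x.T
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x   # maps each singular value s to 1.5 s - 0.5 s^3
    return x.T if transposed else x

With such a routine, \({\mathbf{o}}_{i}^{\ell +1}\approx {U}_{i}^{\ell +1}{({V}_{i}^{\ell +1})}^{{\rm{\top }}}\) can be formed from the momentum \({\mathbf{b}}_{i}^{\ell +1}\) without an explicit singular value decomposition.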

Theorem 4

Let \(\{({\mathbf{w}}^{\ell },{{W}}^{\ell },{{\varPi }}^{\ell })\}\) be the sequence generated by Algorithm 3 with σ0 chosen as in equation (17). If Assumption 1 holds and \({({\mathbf{p}}_{i}^{\ell })}_{kj}\in (0,{\eta }_{i})\) for any (k, j) and ℓ ≥ 0, then all statements in Theorems 1 and 2 are valid.