Abstract
Deep learning models are usually trained with stochastic gradient descent-based algorithms, but these optimizers face inherent limitations, such as slow convergence and stringent assumptions for convergence. In particular, data heterogeneity arising from distributed settings poses significant challenges to their theoretical and numerical performance. Here we develop an algorithm called PISA (preconditioned inexact stochastic alternating direction method of multipliers). Grounded in rigorous theoretical guarantees, the algorithm converges under the sole assumption of Lipschitz continuity of the gradient on a bounded region, thereby removing the need for other conditions commonly imposed by stochastic methods. This capability enables the proposed algorithm to tackle the challenge of data heterogeneity effectively. Moreover, the algorithmic architecture enables scalable parallel computing and supports various preconditions, such as second-order information, second moment and orthogonalized momentum by Newton–Schulz iterations. Incorporating the last two preconditions in PISA yields two computationally efficient variants: SISA and NSISA. Comprehensive experimental evaluations for training or fine-tuning diverse deep models, including vision models, large language models, reinforcement learning models, generative adversarial networks and recurrent neural networks, demonstrate superior numerical performance of SISA and NSISA compared with various state-of-the-art optimizers.
Main
Deep models have led to extensive applications across a variety of industries. For instance, fine-tuning large language models (LLMs), such as Gemma1, GPT2 and Llama3, with user-specific data can significantly enhance the model performance. Diffusion models are increasingly used to generate personalized content, including images, audio and video, for entertainment. Vision models, such as ResNet4 and vision transformers5, are in high demand for tasks, such as image classification, segmentation and object detection on large-scale image datasets. However, the growing complexity and scale of these tasks present considerable challenges for widely used stochastic gradient descent (SGD)-based optimizers due to their inherent limitations. In this section, we review a subset of these methods, sufficient to motivate the development of our proposed approach.
SGD-based learning algorithms
It is known that the plain SGD-based algorithms, such as the original stochastic approximation approaches6,7, the parallelized SGD8 for machine learning and the newly developed federated averaging (FedAvg9,10) and local SGD (LocalSGD11) algorithms for federated learning, are frequently prone to high sensitivity to poor conditioning12 and slow convergence in high-dimensional and non-convex landscapes13. To overcome these drawbacks, SGD with momentum and adaptive learning rates has been proposed to enhance robustness and accelerate convergence. The former introduced first-order momentum to suppress the oscillation of SGD14, and the latter updated the learning rate iteratively based on historical information.
For instance, the adaptive gradient (AdaGrad15) interpolated the accumulated second-order moment of gradients to achieve the adaptive learning rate. The root mean-squared propagation (RMSProp16) exploited an exponential weighting technique to balance the distant historical information and the knowledge of the current second-order gradients, avoiding premature termination encountered by AdaGrad. The adaptive moment (Adam17) combined the first-order moment and adaptive learning rates, thereby exhibiting robustness to hyperparameters. The adaptive method setup-based gradient (AMSGrad18) took smaller learning rates than those of Adam by using a maximum value for normalizing the running average of the gradients, thereby fixing the convergence of Adam. Partially Adam (Padam19) used a similar strategy to AMSGrad with a difference in the rate power. Adam with dynamic bounds on learning rates (AdaBound20) designed a clipping mechanism on Adam-type learning rates by clipping the gradients larger than a threshold to avoid gradient explosion. Parallel restarted SGD (PRSGD21) simply benefited from a decayed power learning rate. A layerwise adaptive large batch optimization technique (Lamb22) performed per-dimension normalization with respect to the square root of the second moment used in Adam and set large batch sizes suggested by the layerwise adaptive rate scaling method (Lars23).
For some other SGD-based algorithms, one can refer to a per-dimension learning rate method based gradient descent (AdaDelta24), a variant of Adam based on the infinity norm (Adamax17), Adam with decoupled weight decay (AdamW25), a structure-aware preconditioning algorithm (Shampoo26), momentum orthogonalized by Newton–Schulz iterations (Muon27), Shampoo with Adam in the preconditioner’s eigenbasis (SOAP28), Adam with fewer learning rates (Adam-mini29) and those reviewed in a few studies30,31,32.
The aforementioned algorithms primarily leveraged the heavy ball acceleration technique to estimate the first-order and second-order moments of the gradient. An alternative approach to accelerating convergence is Nesterov acceleration, which has been theoretically proved to converge faster33,34. This has led to the development of several Nesterov-accelerated SGD algorithms. For example, Nesterov-accelerated Adam (NAdam35) integrated this technique into Adam to enhance convergence speed. More recently, an adaptive Nesterov momentum algorithm (Adan36) adopted Nesterov momentum estimation to refine the estimation of the first-order and second-order moments of the gradient, thereby improving adaptive learning-rate adjustments for faster convergence.
ADMM-based learning algorithms
The alternating direction method of multipliers (ADMM)37,38 is a promising framework for solving large-scale optimization problems because it decomposes complex problems into smaller and more manageable subproblems. This characteristic makes ADMM particularly well-suited for various challenges in distributed learning. It has demonstrated substantial potential in applications, such as image compressive sensing39, federated learning40,41,42, reinforcement learning43 and few-shot learning44. It is worth noting that there are many other distributed optimization frameworks, such as dual averaging45,46, push-sum47 and push–pull48. However, in the sequel, our focus is on the family of distributed algorithms, ADMM, as they are most directly relevant to the design and convergence analysis of our algorithm.
When applied to deep model training, early studies first relaxed optimization models before developing ADMM-based algorithms. Therefore, these methods can be classified as model-driven or model-specific approaches. They targeted the relaxations rather than the original problems, enabling better handling of the challenges associated with highly non-convex optimization landscapes. For instance, an ADMM-based algorithm in ref. 49 addressed a penalized formulation, avoiding pitfalls that hinder gradient-based methods in non-convex settings. Similarly, a deep learning ADMM algorithm50 and its variant51 were designed to enhance convergence by addressing penalization models as well. Moreover, the sigmoid-ADMM algorithm52 used the algorithmic gradient-free property to address saturation issues with sigmoid activation functions.
Despite their advantages, model-driven ADMM approaches face two critical issues. First, these algorithms exhibit high computational complexity because they require computing full gradients using the entire dataset and performing matrix inversions for weight updates, making them impractical for large-scale tasks. Furthermore, the reliance on the specific structure of optimization models limits their general applicability, as different deep models often necessitate distinct formulations.
To address the aforementioned limitations, data-driven ADMM algorithms have emerged as a promising alternative, aiming to reduce computational burdens and enhance adaptability to diverse tasks. For instance, the deterministic ADMM41,42, designed for distributed learning problems, achieved high accuracy but low computational efficiency due to the use of full dataset gradients. Stochastic ADMM (SADMM) methods tackled the inefficiency by replacing full gradients with stochastic approximations. Examples such as SADMM53 and distributed stochastic ADMM (PS-ADMM54) aimed to solve subproblems exactly and thus still involved high computational costs.
To further enhance convergence and reduce stochastic gradient variance, advanced techniques of variance reduction and Nesterov acceleration have been incorporated into SADMM. Representatives consist of stochastic average ADMM55, stochastic path-integrated differential estimator-based ADMM (SPIDER-ADMM56) and accelerated stochastic variance reduced gradient-based ADMM (ASVRG-ADMM) for convex57 and non-convex problems58. However, these enhanced methods still relied on access to full gradients for the aim of reducing stochastic gradient variance, posing challenges for large-scale applications.
Contributions
In this work, we propose a data-driven preconditioned inexact SADMM algorithm, termed PISA. It distinguishes itself from previous approaches and aims to reduce computational costs, relax convergence assumptions and enhance numerical performance. The key contributions are threefold.
A general algorithmic structure
The algorithm is based on a preconditioned inexact SADMM, combining simplicity with high generality to handle a wide range of deep learning applications. First, the use of preconditioning matrices allows us to incorporate various forms of useful information, such as first-moment and second-moment (yielding second-moment-based inexact SADMM (SISA), a variant of PISA), second-order information (for example, the Hessian) and orthogonalized momentum by Newton–Schulz iterations (leading to Newton–Schulz-based inexact SADMM (NSISA), another variant of PISA), into the updates, thereby enhancing the performance of the proposed algorithms. Moreover, the proposed algorithmic framework is inherently compatible with parallel computing, making it ideal for large-scale data settings.
Strong convergence theory under a sole assumption
PISA is proven to converge under a single assumption: the Lipschitz continuity of the gradient. Despite relying on stochastic gradients, it avoids many of the assumptions typically required by stochastic algorithms. As highlighted in Table 1, all algorithms except PISA and FedAvg10 draw independent and identically distributed (IID) samples to derive stochastic gradients for unbiased gradient estimation. However, real-world data are often heterogeneous (that is, non-IID), a phenomenon commonly referred to as statistical or data heterogeneity59,60,61,62, which poses significant challenges to the convergence of these algorithms designed for IID datasets.
In addition, the table presents two types of convergence for the listed algorithms. Type I convergence, F(wT) − F* = B, refers to the rate B at which the objective values {F(wT)} of the generated sequence {wT} approach F*, where T is the number of iterations and F* is the limit of the sequence {F(wT)} or a function value at a stationary point. For instance, B = O(1/T) indicates sublinear convergence, while B = O(γ^T) with γ ∈ (0, 1) implies a linear rate. Type II convergence describes how quickly the gradient norms {||∇F(wT)||} diminish, reflecting the stationarity of the iterates. Therefore, ASVRG-ADMM57 and PISA achieve the best Type I convergence rate. However, the former imposes several restrictive assumptions, while PISA requires a single assumption (that is, condition 12). We emphasize that condition 12 is closely related to the local Lipschitz continuity of the gradient, which is a mild assumption. It is weaker than condition 3, and any twice-continuously differentiable function satisfies condition 12. More importantly, we eliminate the need for conditions on the boundedness of the stochastic gradient, its second moment and its variance. This makes PISA well-suited for addressing the challenges associated with data heterogeneity, an open problem in federated learning60,62.
High numerical performance for various applications
The effectiveness of PISA and its two variants, SISA and NSISA, is demonstrated through comparisons with many state-of-the-art optimizers using several deep models: vision models, LLMs, reinforcement learning models, generative adversarial networks (GANs) and recurrent neural networks, highlighting its great potential for extensive applications. In particular, numerical experiments on heterogeneous datasets corroborate our claim that PISA effectively addresses this challenge.
Results
This section evaluates the performance of SISA and NSISA in comparison with various established optimizers to process several deep models: vision models, LLMs, reinforcement learning models, recurrent neural networks and GANs. All LLM-targeted experiments are conducted on four NVIDIA H100-80GB graphics processing units (GPUs), while the remaining experiments are run on a single such GPU. The source codes are available at https://github.com/Tracy-Wang7/PISA. Due to space limitations, details on hyperparameter setup, effects of key hyperparameters and more numerical evaluations, including experiments for reinforcement learning models and recurrent neural networks, are provided in Supplementary Section 1.
Data heterogeneity
To evaluate the effectiveness of the proposed algorithms in addressing the challenge of data heterogeneity, we design some experiments within a centralized federated learning framework. In this setting, m clients collaboratively learn a shared parameter under a central server, with each client i holding a local dataset \({{\mathcal{D}}}_{i}\). Note that our three algorithms are well-suited for centralized federated learning tasks. As shown in Algorithm 2, clients update their local parameters \(({\bf{w}}_{i}^{\ell +1},{\boldsymbol{\pi }}_{i}^{\ell +1})\) in parallel, and the server aggregates them to compute global parameter wℓ+2.
In the experiment, we focus on heterogeneous datasets \({{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\). The heterogeneity may arise from label distribution skew, feature distribution skew or quantity skew61. Empirical evidence has shown that label distribution skew presents a significantly greater challenge to current federated learning methods than the other two types. Therefore, we adopt this skew as the primary setting for evaluating data heterogeneity in our experiments.
To reflect the fact that most benchmark datasets contain 10 image classes, we set m = 10. Four datasets, MNIST, CIFAR-10, FMNIST and Adult, are used for the experiment. As shown in Table 2, the configuration 's-label' indicates that each client holds data containing s distinct label classes, where s = 1, 2, 3. For example, in the '2-label' setting, \({{\mathcal{D}}}_{1}\) may include labels 1 and 2, while \({{\mathcal{D}}}_{2}\) contains labels 2 and 3. To ensure better performance, we used a mini-batch training strategy with multiple local updates before each aggregation step for the five baseline algorithms. In contrast, SISA uses only one local update per aggregation step and still achieves competitive testing accuracy with far fewer local updates. Notably, on the MNIST dataset under the 1-label skew setting, the best accuracy achieved by the other algorithms is 54.33%, whereas SISA reaches 94.97%, demonstrating a significant improvement.
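For concreteness, the following is a minimal Python sketch of how an s-label (label-distribution-skew) partition can be generated; the function name and the round-robin assignment of labels to clients are illustrative assumptions rather than the exact protocol used in our experiments.

```python
import numpy as np

def s_label_partition(labels, m=10, s=2, seed=0):
    """Split sample indices among m clients so that each client sees only s label classes.

    labels: 1-D array of integer class labels for the whole dataset.
    Returns a list of m index arrays, one per client.
    """
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    # Assign s classes to each client in a round-robin fashion (illustrative choice).
    client_classes = [[classes[(i + j) % len(classes)] for j in range(s)] for i in range(m)]
    # For every class, split its samples evenly among the clients that hold it.
    holders = {c: [i for i in range(m) if c in client_classes[i]] for c in classes}
    client_idx = [[] for _ in range(m)]
    for c in classes:
        idx = rng.permutation(np.where(labels == c)[0])
        for part, client in zip(np.array_split(idx, len(holders[c])), holders[c]):
            client_idx[client].extend(part.tolist())
    return [np.array(ix) for ix in client_idx]

# Example: 1-label skew on a toy label vector; each client ends up with a single class.
toy_labels = np.repeat(np.arange(10), 100)
parts = s_label_partition(toy_labels, m=10, s=1)
print([len(p) for p in parts])
```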
Classification by vision models
We perform classification tasks on two well-known datasets, ImageNet63 and CIFAR-1064, and evaluate performance based on convergence speed and testing accuracy. Convergence performance is presented in Supplementary Section 1.3. We now focus on testing accuracy, following the experimental setup and hyperparameter configurations from ref. 65. In this experiment, a small batch size (for example, \(| {{\mathcal{D}}}_{i}| =128\) for every i ∈ [m]) is used during training. As SISA updates its hyperparameters frequently, it uses a different learning-rate scheduler from that in ref. 65. For fair comparison, each baseline optimizer is tested with both its default scheduler and the one used by SISA, and the best result is reported. On CIFAR-10, we train three models: VGG-1166, ResNet-3467 and DenseNet-12168. As shown in Table 3, SISA achieves the highest testing accuracy on ResNet-34 and DenseNet-121, and performs slightly worse than SGD-M on VGG-11. In addition, Fig. 1 illustrates the training loss and testing accuracy over training epochs for VGG-11 and ResNet-34, demonstrating the fast convergence and high accuracy of SISA. For ImageNet, we train a ResNet-18 model and report the testing accuracy in Table 3. SISA outperforms most adaptive methods but is slightly behind SGD-M and AdaBelief.
Training LLMs
In this subsection, we apply SISA and NSISA to train several GPT2 models69 and compare them with advanced optimizers, such as AdamW, Muon, Shampoo, SOAP and Adam-mini, that have proved effective in training LLMs. The comparison between SISA and AdamW is provided in Supplementary Section 1.4. We now compare NSISA with the other four optimizers, with the aim of training three GPT2 models: GPT2-Nano (125M), GPT2-Medium (330M) and GPT2-XL (1.5B). The first model follows the setup described in ref. 27, while the last two follow the configuration from ref. 29. We use the FineWeb dataset, a large collection of high-quality web text, to fine-tune the GPT2 models, enabling them to generate more coherent and contextually relevant web-style content. The hyperparameters of NSISA remain consistent across the experiments for the three GPT2 models.
The comparison of different algorithms in terms of memory overhead is given in Supplementary Section 1.4. The experimental results are presented in Fig. 2. We evaluate each optimizer's training performance from two perspectives: the number of training tokens consumed and the wall-clock time required. To reduce wall-clock time, we implemented NSISA using parallel computation. From Fig. 2, as the number of GPT2 parameters increases from Nano to Medium to XL, the validation loss gap between NSISA and the other baseline optimizers widens. In particular, for GPT2-XL, the last panel demonstrates a clear advantage for NSISA in terms of wall-clock time.
Generation tasks by GANs
GANs involve alternating updates between the generator and discriminator in a minimax game, making training notoriously unstable70. To assess optimizer stability, we compare SISA with adaptive methods such as Adam, RMSProp and AdaBelief, which are commonly recommended to enhance both stability and efficiency in training GANs65,71. We use two popular architectures, the Wasserstein GAN (WGAN72) and WGAN with gradient penalty (WGAN-GP71), with a small model and vanilla CNN generator on the CIFAR-10 dataset.
To evaluate performance, we compute the Fréchet inception distance (FID)73 every 10 training epochs, measuring the distance between 6,400 generated images and 60,000 real images. The FID score is a widely used metric for generative models, assessing both image quality and diversity. Lower FID values indicate better generated images. We follow the experimental setup from ref. 65 and run each optimizer five times independently. The corresponding mean and standard deviation of training FID over epochs are presented in Fig. 3. Clearly, SISA achieves the fastest convergence and the lowest FID. After training, 64,000 fake images are produced to compute the testing FID, and SISA outperforms the other three optimizers, demonstrating its superior stability and effectiveness in training GANs. Detailed testing FID scores can be found in Supplementary Section 1.7.
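The evaluation loop itself is simple; below is a minimal sketch assuming the torchmetrics implementation of FID, with an illustrative feature dimension and random stand-in tensors in place of real CIFAR-10 images and generator outputs.

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# FID between batches of real and generated images (float values in [0, 1], NCHW layout).
# feature=64 keeps this toy example numerically stable with few samples; larger feature
# dimensions (for example, 2048) are typical in full-scale evaluations.
fid = FrechetInceptionDistance(feature=64, normalize=True)

real = torch.rand(64, 3, 32, 32)   # stand-in for real CIFAR-10 images
fake = torch.rand(64, 3, 32, 32)   # stand-in for generator outputs

fid.update(real, real=True)
fid.update(fake, real=False)
print(float(fid.compute()))        # lower is better
```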
Conclusion
This paper introduces an ADMM-based algorithm that leverages stochastic gradients and solves subproblems inexactly, significantly accelerating computation. The preconditioning framework allows it to incorporate popular second-moment schemes, enhancing training performance. Theoretical guarantees, based solely on the Lipschitz continuity of the gradients, make the algorithm suitable for heterogeneous datasets, effectively addressing an open problem in stochastic optimization for distributed learning. The algorithm demonstrates superior generalization across a wide range of architectures, datasets and tasks, outperforming well-known deep learning optimizers.
Methods
Organization and notation
This section is organized as follows: in the next subsection, we introduce the main model and develop the proposed algorithm, PISA. The section 'Convergence of PISA' provides rigorous proofs of convergence. The section 'Second moment' specifies the precondition via the second moment to derive a variant of PISA, termed SISA.
Let [m] ≔ {1, 2, …, m}, where ‘ ≔ ’ means ‘define’. The cardinality of a set \({\mathcal{D}}\) is written as \(| {\mathcal{D}}|\). For two vectors w and v, their inner product is denoted by 〈w, v〉 ≔ ∑iwivi. Let ∥ ⋅ ∥ be the Euclidean norm for vectors, namely \(\parallel {\bf{w}}\parallel =\sqrt{\langle {\bf{w}},{\bf{w}}\rangle }\), and the spectral norm for matrices. A ball with a positive radius r is written as \({\mathbb{N}}(r):=\{{\bf{w}}:\parallel {\bf{w}}\parallel \le r\}\). A symmetric positive semi-definite matrix Q is written as Q ≽ 0. Then P ≽ Q means that P − Q ≽ 0. Denote the identity matrix by I and let 1 be the vector with all entries being 1. We write
Similar rules are also used for the definitions of Πℓ, Wℓ, Mℓ and σℓ.
Preconditioned inexact SADMM
We begin this subsection by introducing the mathematical optimization model for general distributed learning. Then we go through the development of the algorithm.
Model description
Suppose we are given a set of data as \({\mathcal{D}}:=\{{\mathbf{x}}_{t}:t=1,2,\ldots ,|{\mathcal{D}}|\}\), where xt is the tth sample. Let f(w; xt) be a function (such as a neural network) parameterized by w and sampled by xt. The total loss function on \({\mathcal{D}}\) is defined by \({\sum }_{{\bf{x}}_{t}\in {\mathcal{D}}}\,f({\bf{w}};{\bf{x}}_{t})/|{\mathcal{D}}|\). We then divide data \({\mathcal{D}}\) into m disjoint batches, namely, \({\mathcal{D}}={{\mathcal{D}}}_{1}\cup {{\mathcal{D}}}_{2}\cup \ldots \cup {{\mathcal{D}}}_{m}\) and \({{\mathcal{D}}}_{i}\cap {{\mathcal{D}}}_{{i}^{{\prime} }}=\varnothing\) for any two distinct i and \({i}^{{\prime} }\). Denote
Clearly, \({\sum }_{i=1}^{m}{\alpha }_{i}=1\). Now, we can rewrite the total loss as follows:
The task is to learn an optimal parameter to minimize the following regularized loss function:
where μ ≥ 0 is a penalty constant and \({\left\Vert {\bf{w}}\right\Vert }^{2}\) is the regularization term.
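With the batch weights \({\alpha }_{i}:=|{{\mathcal{D}}}_{i}|/|{\mathcal{D}}|\), the total loss and the regularized training problem can be summarized as follows (a sketch consistent with the definitions above; the exact scaling of the regularization term is an assumption):

$$\frac{1}{|{\mathcal{D}}|}\sum _{{\mathbf{x}}_{t}\in {\mathcal{D}}}f(\mathbf{w};{\mathbf{x}}_{t})=\mathop{\sum }\limits_{i=1}^{m}{\alpha }_{i}\,\frac{1}{|{{\mathcal{D}}}_{i}|}\sum _{{\mathbf{x}}_{t}\in {{\mathcal{D}}}_{i}}f(\mathbf{w};{\mathbf{x}}_{t}),\qquad \mathop{\min }\limits_{\mathbf{w}}\;\mathop{\sum }\limits_{i=1}^{m}{\alpha }_{i}\,\frac{1}{|{{\mathcal{D}}}_{i}|}\sum _{{\mathbf{x}}_{t}\in {{\mathcal{D}}}_{i}}f(\mathbf{w};{\mathbf{x}}_{t})+\mu {\left\Vert \mathbf{w}\right\Vert }^{2}.$$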
Main model
Throughout the paper, we focus on the following equivalent model of problem (2):
where λ ∈ [0, μ] and
In problem (3), m auxiliary variables wi are introduced in addition to the global parameter w. We emphasize that problems (2) and (3) are equivalent in terms of their optimal solutions but are expressed in different forms when λ ∈ [0, μ), and they are identical when λ = μ. Throughout this work, we assume that optimal function value F* is bounded from below, namely, F* > − ∞.
The algorithmic design
When using ADMM to solve problem (3), we need its associated augmented Lagrange function, which is defined as follows:
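Assuming the standard weighted augmented Lagrangian for the consensus constraints wi = w in problem (3), a plausible form is

$${\mathcal{L}}(\mathbf{w},{{W}},\Pi ;{\boldsymbol{\sigma }}):=\lambda {\left\Vert \mathbf{w}\right\Vert }^{2}+\mathop{\sum }\limits_{i=1}^{m}{\alpha }_{i}\left[{F}_{i}({\mathbf{w}}_{i})+\langle {\boldsymbol{\pi }}_{i},{\mathbf{w}}_{i}-\mathbf{w}\rangle +\frac{{\sigma }_{i}}{2}{\left\Vert {\mathbf{w}}_{i}-\mathbf{w}\right\Vert }^{2}\right],$$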
where σi > 0 and πi, i ∈ [m] are the Lagrange multipliers. Based on the above augmented Lagrange function, the conventional ADMM updates each variable in (w, W, Π) iteratively. However, we modify the framework as follows. Given an initial point (w0, W0, Π0; σ0), the algorithm performs the following steps iteratively for ℓ = 0, 1, 2, …:
for each i ∈ [m], where ρi > 0, both scalar \({\sigma }_{i}^{\ell +1}\) and matrix \({Q}_{i}^{\ell +1}\succcurlyeq 0\) will be updated properly. Hereafter, superscripts ℓ and ℓ + 1 in \({\sigma }_{i}^{\ell }\) and \({\sigma }_{i}^{\ell +1}\) stand for the iteration number rather than the power. Here \({Q}_{i}^{\ell +1}\) is commonly referred to as an (adaptively) preconditioning matrix in preconditioned gradient methods26,74,75,76.
Remark 1
The primary distinction between algorithmic framework (equation (6a–c)) and conventional ADMM lies in the inclusion of a term \(\frac{{\rho }_{i}}{2}\langle {\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1},{Q}_{i}^{\ell +1}({\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1})\rangle\). This term enables the incorporation of various forms of useful information, such as second-moment, second-order information (for example, Hessian) and orthogonalized momentum by Newton–Schulz iterations, thereby enhancing the performance of the proposed algorithms; see the section ‘Second moment’ for more details.
One can check that subproblem (6a) admits a closed-form solution outlined in equation (8). For subproblem (6b), to accelerate the computational speed, we solve it inexactly by
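A plausible form of this inexact update, obtained by linearizing Fi at wℓ+1 with a stochastic gradient while keeping the penalty and preconditioning terms (an assumption consistent with the description that follows), is

$${\mathbf{w}}_{i}^{\ell +1}=\mathop{\rm{argmin}}\limits_{{\mathbf{w}}_{i}}\;\left\langle \nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})+{\boldsymbol{\pi }}_{i}^{\ell },{\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1}\right\rangle +\frac{{\sigma }_{i}^{\ell +1}}{2}{\left\Vert {\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1}\right\Vert }^{2}+\frac{{\rho }_{i}}{2}\left\langle {\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1},{Q}_{i}^{\ell +1}({\mathbf{w}}_{i}-{\mathbf{w}}^{\ell +1})\right\rangle ,$$

whose minimizer admits the closed form

$${\mathbf{w}}_{i}^{\ell +1}={\mathbf{w}}^{\ell +1}-{\left({\sigma }_{i}^{\ell +1}{{I}}+{\rho }_{i}{Q}_{i}^{\ell +1}\right)}^{-1}\left(\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})+{\boldsymbol{\pi }}_{i}^{\ell }\right).$$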
Algorithm 1
PISA.
Divide \({\mathcal{D}}\) into m disjoint batches \(\{{{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\}\) and calculate αi by equation (1).
Initialize \({\mathbf{w}}^{0}={\mathbf{w}}_{i}^{0}={\boldsymbol{\pi }}_{i}^{0}=0\), γi ∈ [3/4, 1), and \(({\sigma }_{i}^{0},{\eta }_{i},{\rho }_{i}) > 0\) for each i ∈ [m].
for ℓ = 0, 1, 2, … do
for i = 1, 2, …, m do
end
end
This update admits three advantages. First, it solves problem (6b) via a closed-form solution, namely, the second equation in (7), reducing the computational complexity. Second, we approximate Fi(w) using its first-order approximation at wℓ+1 rather than \({\mathbf{w}}_{i}^{\ell }\), which encourages each batch parameter \({\mathbf{w}}_{i}^{\ell +1}\) to approach wℓ+1 quickly, thereby accelerating the overall convergence. Finally, \(\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\) serves as a stochastic approximation of the true gradient \(\nabla {F}_{i}({\mathbf{w}}^{\ell +1})=\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{D}}}_{i})\), as defined by equation (4), where \({{\mathcal{B}}}_{i}^{\ell +1}\) is a random sample from \({{\mathcal{D}}}_{i}\). By using sub-batch datasets \(\{{{\mathcal{B}}}_{1}^{\ell +1},\ldots ,{{\mathcal{B}}}_{m}^{\ell +1}\}\) in every iteration, rather than the full data \({\mathcal{D}}=\{{{\mathcal{D}}}_{1},\ldots ,{{\mathcal{D}}}_{m}\}\), the computational cost is significantly reduced. On the basis of these observations, we name our algorithm PISA, as described in Algorithm 1.
Another advantageous property of PISA is its ability to perform parallel computation, which stems from the parallelism used in solving subproblems in ADMM. At each iteration, m nodes (that is, i = 1, 2, ⋯ , m) update their parameters by equations (9)–(13) in parallel, thereby enabling the processing of large-scale datasets. Moreover, when specifying the preconditioning matrix, \({Q}_{i}^{\ell +1}\), as a diagonal matrix (as outlined in the section ‘Second moment’) and sampling \({{\mathcal{B}}}_{i}^{\ell +1}\) with small batch sizes, each node exhibits significantly low computational complexity, facilitating fast computation.
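To make this parallel structure concrete, the following is a minimal PyTorch-style sketch of one outer iteration; the local update, multiplier update and aggregation rule are schematic assumptions distilled from the description above (diagonal preconditioner, element-wise division), not a verbatim transcription of equations (9)–(13).

```python
import torch

def pisa_round(w, nodes, grads, sigma, rho, alpha):
    """One schematic outer iteration of the parallel structure described above.

    w:     global parameter (1-D tensor)
    nodes: list of per-node state dicts {'w': tensor, 'pi': tensor, 'q': tensor}
    grads: list of stochastic gradients of F_i evaluated at the global parameter
    sigma: per-node penalty scalars; rho: per-node preconditioner weights
    alpha: per-node weights alpha_i = |D_i| / |D|
    """
    m = len(nodes)
    for i in range(m):  # these m updates are independent and can run in parallel
        st = nodes[i]
        # Linearized, diagonally preconditioned local update (element-wise division).
        st['w'] = w - (grads[i] + st['pi']) / (sigma[i] + rho[i] * st['q'])
        # Multiplier (dual) update towards the consensus constraint w_i = w.
        st['pi'] = st['pi'] + sigma[i] * (st['w'] - w)
    # Server aggregation; the paper's closed-form rule in equation (8) also involves the
    # multipliers and penalties, so a plain weighted average is only a placeholder here.
    return sum(alpha[i] * nodes[i]['w'] for i in range(m))

# Toy usage with m = 2 nodes and a 4-dimensional parameter.
p, m = 4, 2
w = torch.zeros(p)
nodes = [{'w': torch.zeros(p), 'pi': torch.zeros(p), 'q': torch.ones(p)} for _ in range(m)]
grads = [torch.randn(p) for _ in range(m)]
w = pisa_round(w, nodes, grads, sigma=[1.0, 1.0], rho=[0.1, 0.1], alpha=[0.5, 0.5])
```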
Convergence of PISA
In this subsection, we aim to establish the convergence property of Algorithm 1. To proceed with that, we first define a critical bound by
Lemma 1
εi(r) < ∞ for any given r ∈ (0, ∞) and any i ∈ [m].
The proof of Lemma 1 is given in Supplementary Section 2.3. One can observe that εi(r) = 0 for any r > 0 if we take the full batch of data in each step, namely, choosing \({{\mathcal{B}}}_{i}^{\ell }={({{\mathcal{B}}}_{i}^{\ell })}^{{\prime} }={{\mathcal{D}}}_{i}\) for every i ∈ [m] and all ℓ ≥ 1. However, for a mini-batch dataset \({{\mathcal{B}}}_{i}^{\ell }\subset {{\mathcal{D}}}_{i}\), this parameter is related to the bound of the variance \({\mathbb{E}}{\left\Vert \nabla {F}_{i}(\mathbf{w};{{\mathcal{B}}}_{i})-\nabla {F}_{i}(\mathbf{w};{{\mathcal{D}}}_{i})\right\Vert }^{2}\), which is commonly assumed to be bounded for any w10,11,21,36,77. In the subsequent analysis, we verify that both generated sequences {wℓ} and \(\{{\mathbf{w}}_{i}^{\ell }\}\) fall into a bounded region \({\mathbb{N}}(\delta )\) for any i ∈ [m], with δ defined as in equation (16), thereby naturally yielding a finite εi(δ); see Lemma 2. In other words, we no longer need to assume the boundedness of the variance \({\mathbb{E}}{\left\Vert \nabla {F}_{i}(\mathbf{w};{{\mathcal{B}}}_{i})-\nabla {F}_{i}(\mathbf{w};{{\mathcal{D}}}_{i})\right\Vert }^{2}\) for any w, an assumption known to be somewhat restrictive, particularly for non-IID or heterogeneous datasets. Therefore, the theorems we establish in the sequel effectively address this critical challenge60,62, and our algorithm demonstrates robust performance in settings with heterogeneous data.
Convergence analysis
To establish convergence, we need the assumption below. It assumes that function f has a Lipschitz continuous gradient on a bounded region, namely, the gradient is locally Lipschitz continuous. This is a relatively mild condition: functions with globally Lipschitz continuous gradients and twice-continuously differentiable functions both satisfy it. The Lipschitz continuity of the gradient is commonly referred to as L-smoothness, so our assumption can be regarded as L-smoothness on a bounded region, which is weaker than L-smoothness.
Assumption 1
For each \(t\in [| {\mathcal{D}}| ]\), gradient ∇ f( ⋅ ; xt) is Lipschitz continuous with a constant c(xt) > 0 on \({\mathbb{N}}(2\delta )\). Denote \({c}_{i}:=\mathop{\max }\limits_{{{\bf{x}}}_{t}\in {{\mathcal{D}}}_{i}}c({{\bf{x}}}_{t})\) and ri ≔ ci + μ − λ for each i ∈ [m].
First, given a constant σ > 0, we define a set
where \(\gamma :=\mathop{\max }\limits_{i\in [m]}{\gamma }_{i}\), based on which we further define
This indicates that any point (w, W) ∈ Ω satisfies \(\{\mathbf{w},{\mathbf{w}}_{1},\ldots ,{\mathbf{w}}_{m}\}\subseteq {\mathbb{N}}(\delta )\). Using this δ, we initialize \({{\boldsymbol{\sigma }}}^{0}:=({\sigma }_{1}^{0},{\sigma }_{2}^{0},\cdots \,,{\sigma }_{m}^{0})\) by
It is easy to see that Ω is a bounded set due to Fi being bounded from below. Therefore, δ is bounded and so is εi(2δ) due to Lemma 1. Hence, σ0 in equation (17) is a well-defined constant, namely, σ0 can be set as a finite positive number. For notational simplicity, hereafter, we denote
Our first result shows the descent property of a merit function associated with \({{\mathcal{L}}}^{\ell }\).
Lemma 2
Let {(wℓ, Wℓ, Πℓ)} be the sequence generated by Algorithm 1with σ0 chosen as equation (17). Then the following statements are valid under Assumption 1.
(1) For any ℓ ≥ 0, sequence \(\{{\mathbf{w}}^{\ell },{\mathbf{w}}_{1}^{\ell },\ldots ,{\mathbf{w}}_{m}^{\ell }\}\subseteq {\mathbb{N}}(\delta )\).
(2) For any ℓ ≥ 0,
$${\widetilde{{\mathcal{L}}}}^{\ell }-{\widetilde{{\mathcal{L}}}}^{\ell +1}\ge \displaystyle{\mathop{\sum }\limits_{i=1}^{m}}{\alpha }_{i}\left[\dfrac{{\sigma }_{i}^{\ell }+2\lambda }{4}{\left\Vert \Delta {\mathbf{w}}^{\ell +1}\right\Vert }^{2}+\dfrac{{\sigma }_{i}^{\ell }}{4}{\left\Vert \Delta {\mathbf{w}}_{i}^{\ell +1}\right\Vert }^{2}\right],$$ (19)
where \({\widetilde{{\mathcal{L}}}}^{\ell }\) is defined by
$${\widetilde{{\mathcal{L}}}}^{\ell }:={{\mathcal{L}}}^{\ell }+\displaystyle{\mathop{\sum }\limits_{i=1}^{m}}{\alpha }_{i}\left[\dfrac{8}{{\sigma }_{i}^{\ell }}{\parallel {\rho }_{i}{Q}_{i}^{\ell }\Delta {\bar{\mathbf{w}}}_{i}^{\ell }\parallel }^{2}+\dfrac{{\gamma }_{i}^{\ell }}{16(1-{\gamma }_{i})}\right].$$ (20)
The proof of Lemma 2 is given in Supplementary Section 2.6. This lemma is derived from a deterministic perspective. Such a success lies in considering the worst case of bound εi(2δ) (that is, taking all possible selections of \(\{{{\mathcal{B}}}_{1}^{\ell },\ldots ,{{\mathcal{B}}}_{m}^{\ell }\}\) into account). On the basis of the above key lemma, the following theorem shows the sequence convergence of the algorithm.
Theorem 1
Let {(wℓ, Wℓ, Πℓ)} be the sequence generated by Algorithm 1with σ0 chosen as equation (17). Then the following statements are valid under Assumption 1.
(1) Sequences \(\{{{\mathcal{L}}}^{\ell }\}\) and \(\{{\widetilde{{\mathcal{L}}}}^{\ell }\}\) converge and for any i ∈ [m],
$$0=\mathop{\lim }\limits_{\ell \to \infty }\left\Vert \Delta {\mathbf{w}}^{\ell }\right\Vert =\mathop{\lim }\limits_{\ell \to \infty }\left\Vert \Delta {\mathbf{w}}_{i}^{\ell }\right\Vert =\mathop{\lim }\limits_{\ell \to \infty }\left\Vert \Delta {\bar{\mathbf{w}}}_{i}^{\ell }\right\Vert =\mathop{\lim }\limits_{\ell \to \infty }\left({\widetilde{{\mathcal{L}}}}^{\ell }-{{\mathcal{L}}}^{\ell }\right).$$ (21)
(2) Sequence \(\{({\mathbf{w}}^{\ell },{{\mathit{W}}}^{\ell },{\mathbb{E}}{{{\Pi }}}^{\ell })\}\) converges.
The proof of Theorem 1 is given in Supplementary Section 2.7. To ensure the convergence results, initial value σ0 is selected according to equation (17), which involves a hyperparameter δ. If a lower bound \(\underline{F}\) of \(\mathop{\rm{min}}\limits_{\mathbf{w}}{\sum }_{i=1}^{m}{\alpha }_{i}{F}_{i}(\mathbf{w})\) is known, then an upper bound \(\overline{\delta }\) for δ can be estimated from equations (15) and (16) by substituting Fi(w) with \(\underline{F}\). In practice, particularly in deep learning, many widely used loss functions, such as mean-squared error and cross-entropy, yield non-negative values. This observation allows us to set the lower bound as \(\underline{F}=0\). Once \(\overline{\delta }\) is estimated, it can be used in equation (17) to select σ0, without affecting the convergence guarantees. However, it is worth emphasizing that equation (17) is a sufficient but not necessary condition. Therefore, in practice, it is not essential to enforce this condition strictly when initializing σ0 in numerical experiments.
Complexity analysis
Besides the convergence established above, the algorithm exhibits the following rate of convergence under the same assumption and parameter setup.
Theorem 2
Let {(wℓ, Wℓ, Πℓ)} be the sequence generated by Algorithm 1with σ0 chosen as equation (17). Let w∞ be the limit of sequence {wℓ}. Then there is a constant C1 > 0 such that
and a constant C2 > 0 such that
The proof of Theorem 2 is given in Supplementary Section 2.8. This theorem means that sequence \(\{({\mathbf{w}}^{\ell },{{\mathit{W}}}^{\ell },{\mathbb{E}}{{{\Pi }}}^{\ell })\}\) converges to its limit at a linear rate. To achieve such a result, we only assume Assumption 1, without imposing any other commonly used assumptions, such as those presented in Table 1.
Precondition specification
In this section, we explore the preconditioning matrix, namely, matrix \({Q}_{i}^{\ell }\). A simple and computationally efficient choice is to set \({Q}_{i}^{\ell }={{I}}\), which enables fast computation of updating \({\mathbf{w}}_{i}^{\ell +1}\) via equation (12). However, this choice is too simple to extract useful information about Fi. Therefore, several alternatives can be adopted to set \({Q}_{i}^{\ell }\).
Second-order information
Second-order optimization methods, such as Newton-type and trust region methods, are known to enhance numerical performance by leveraging second-order information, the (generalized) Hessian. For instance, if each function Fi is twice-continuously differentiable, then one can set
where \({\nabla }^{2}{F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\) represents the Hessian of \({F}_{i}(\cdot ;{{\mathcal{B}}}_{i}^{\ell +1})\) at wℓ+1. With this choice, subproblem (7) becomes closely related to second-order methods, and the update takes the form
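Assuming this Hessian choice is combined with the generic preconditioned closed-form update sketched earlier, a plausible form of the resulting step is

$${\mathbf{w}}_{i}^{\ell +1}={\mathbf{w}}^{\ell +1}-{\left({\sigma }_{i}^{\ell +1}{{I}}+{\rho }_{i}{\nabla }^{2}{F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\right)}^{-1}\left(\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})+{\boldsymbol{\pi }}_{i}^{\ell }\right).$$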
This update corresponds to a Levenberg–Marquardt step78,79 or a regularized Newton step80,81 when \({\sigma }_{i}^{\ell +1} > 0\), and reduces to the classical Newton step if \({\sigma }_{i}^{\ell +1}=0\). While incorporating the Hessian can improve performance in terms of iteration count and solution quality, it often leads to significantly high computational complexity. To mitigate this, some other effective approaches exploit the second moment derived from historical updates to construct the preconditioning matrices.
Second moment
We note that using the second moment to determine an adaptive learning rate improves the learning performance of several popular algorithms, such as RMSProp16 and Adam17. Motivated by this, we specify the preconditioning matrix using the second moment as follows:
where Diag(m) is the diagonal matrix with the diagonal entries formed by m and \({\mathbf{m}}_{i}^{\ell +1}\) can be chosen flexibly as long as it satisfies that \(\parallel {\mathbf{m}}_{i}^{\ell +1}{\parallel }_{\infty }\le {\eta }_{i}^{2}\). Here, \(\parallel \mathbf{m}{\parallel }_{\infty }\) is the infinity norm of m. We can set \({\mathbf{m}}_{i}^{\ell +1}\) as follows
where \({\widetilde{\mathbf{m}}}_{i}^{\ell +1}\) can be updated by
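Plausible AdaGrad-, RMSProp- and Adam-style forms consistent with the surrounding text, written with \({\mathbf{g}}_{i}^{\ell +1}:=\nabla {F}_{i}({\mathbf{w}}^{\ell +1};{{\mathcal{B}}}_{i}^{\ell +1})\) and ⊙ denoting the element-wise product (both of which are assumptions here), are

$${\widetilde{\mathbf{m}}}_{i}^{\ell +1}={\widetilde{\mathbf{m}}}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\odot {\mathbf{g}}_{i}^{\ell +1},\qquad {\widetilde{\mathbf{m}}}_{i}^{\ell +1}={\beta }_{i}{\widetilde{\mathbf{m}}}_{i}^{\ell }+(1-{\beta }_{i}){\mathbf{g}}_{i}^{\ell +1}\odot {\mathbf{g}}_{i}^{\ell +1},\qquad {\mathbf{n}}_{i}^{\ell +1}={\beta }_{i}{\mathbf{n}}_{i}^{\ell }+(1-{\beta }_{i}){\mathbf{g}}_{i}^{\ell +1}\odot {\mathbf{g}}_{i}^{\ell +1},\;\;{\widetilde{\mathbf{m}}}_{i}^{\ell +1}=\frac{{\mathbf{n}}_{i}^{\ell +1}}{1-{\beta }_{i}^{\ell +1}},$$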
where \({\widetilde{\mathbf{m}}}_{i}^{0}\) and \({\mathbf{n}}_{i}^{0}\) are given, βi ∈ (0, 1), and \({\beta }_{i}^{\ell }\) stands for the ℓth power of βi. These three schemes resemble the ones used by AdaGrad15, RMSProp16 and Adam17, respectively. Putting equation (25) into Algorithm 1 gives rise to Algorithm 2. We term it SISA, an abbreviation for second-moment-based inexact SADMM. Compared with PISA in Algorithm 1, SISA admits three advantages.
(1) It is capable of incorporating various schemes of the second moment, which may enhance the numerical performance of SISA significantly.
(2) One can easily check that \({\eta }_{i}{{I}}\succcurlyeq {Q}_{i}^{\ell +1}\succcurlyeq 0\) for each batch i ∈ [m] and all ℓ ≥ 1. Therefore, equation (25) enables us to preserve the convergence property as follows.
Theorem 3
Let {(wℓ, Wℓ, Πℓ)} be the sequence generated by Algorithm 2with σ0 chosen as equation (17). Then under Assumption 1, all statements in Theorems 1and 2are valid.
(3) Such a choice of \({Q}_{i}^{\ell +1}\) enables faster computation compared with the update of \({\mathbf{w}}_{i}^{\ell +1}\) by equation (7). In fact, since operation u/v denotes element-wise division, the complexity of computing equation (28) is O(p), where p is the dimension of \({\mathbf{w}}_{i}^{\ell +1}\), whereas the complexity of computing equation (12) is O(p^3).
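The complexity difference can be seen directly by contrasting the two update forms; the short sketch below compares a diagonal preconditioner (element-wise division, O(p)) with a dense one (a linear solve, O(p^3)). The variable names and random tensors are illustrative only.

```python
import torch

p = 4096
g = torch.randn(p)          # stochastic gradient plus multiplier term
sigma, rho = 1.0, 0.1

# Diagonal preconditioner (SISA-style): O(p) element-wise division.
m_diag = torch.rand(p)      # stand-in for the bounded second-moment vector
step_diag = g / (sigma + rho * m_diag)

# Dense preconditioner (for example, a Hessian): O(p^3) linear solve.
Q = torch.eye(p) + 0.01 * torch.randn(p, p)
Q = 0.5 * (Q + Q.T)         # symmetrize to mimic a (generalized) Hessian
step_dense = torch.linalg.solve(sigma * torch.eye(p) + rho * Q, g)
```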
Algorithm 2
SISA
Divide \({\mathcal{D}}\) into m disjoint batches \(\{{{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\}\) and calculate αi by equation (1).
Initialize \({\mathbf{w}}^{0}={\mathbf{w}}_{i}^{0}={\boldsymbol{\pi }}_{i}^{0}=0\), γi ∈ [3/4, 1), and \(({\sigma }_{i}^{0},{\eta }_{i},{\rho }_{i}) > 0\) for each i ∈ [m].
for ℓ = 0, 1, 2, … do
for i = 1, 2, …, m do
end
end
Algorithm 3
NSISA
Divide \({\mathcal{D}}\) into m disjoint batches \(\{{{\mathcal{D}}}_{1},{{\mathcal{D}}}_{2},\ldots ,{{\mathcal{D}}}_{m}\}\) and calculate αi by equation (1).
Initialize \({\mathbf{w}}^{0}={\mathbf{w}}_{i}^{0}={\boldsymbol{\pi }}_{i}^{0}={\mathbf{b}}_{i}^{0}=0\), γi ∈ [3/4, 1), ϵi ∈ (0, 1) and \(({\sigma }_{i}^{0},{\rho }_{i},{\mu }_{i}) > 0\) for each i ∈ [m].
for ℓ = 0, 1, 2, … do
for i = 1, 2, …, m do
end
end
Orthogonalized momentum by Newton–Schulz iterations
Recently, the authors of ref. 27 proposed an algorithm called Muon, which orthogonalizes momentum using Newton–Schulz iterations. This approach has shown promising results in fine-tuning LLMs, outperforming many established optimizers. The underlying philosophy of Muon can also inform the design of the preconditioning matrix. Specifically, we consider the two-dimensional case, namely, the trainable variable w is a matrix. Then subproblem (7) in vector form becomes
where vec(w) denotes the column-wise vectorization of matrix w. Now, initializing \({\mathbf{b}}_{i}^{0}\) for all i ∈ [m] and μ > 0, we update the momentum by \({\mathbf{b}}_{i}^{\ell +1}=\mu {\mathbf{b}}_{i}^{\ell }+{\mathbf{g}}_{i}^{\ell +1}\). Let \({\mathbf{b}}_{i}^{\ell +1}={U}_{i}^{\ell +1}{\Lambda }_{i}^{\ell +1}{({V}_{i}^{\ell +1})}^{{\rm{\top }}}\) be the singular value decomposition of \({\mathbf{b}}_{i}^{\ell +1}\), where \({\Lambda }_{i}^{\ell +1}\) is a diagonal matrix and \({U}_{i}^{\ell +1}\) and \({V}_{i}^{\ell +1}\) are two orthogonal matrices. Compute \({\mathbf{o}}_{i}^{\ell +1}={U}_{i}^{\ell +1}{({V}_{i}^{\ell +1})}^{{\rm{\top }}}\) and \({\mathbf{p}}_{i}^{\ell +1}\) by
where \({\mathbf{m}}_{i}^{\ell +1}\) can be the second moment (for example, \({\mathbf{m}}_{i}^{\ell +1}=({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1})\odot ({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1})\) is used in the numerical experiments), ϵi ∈ (0, 1) (here, \({\epsilon }_{i}^{\ell }\) stands for the ℓth power of ϵi) and \({\mathbf{v}}_{i}^{\ell +1}\) is a matrix whose (k, j)th element is \({({\mathbf{v}}_{i}^{\ell +1})}_{{kj}}=1\) if \({({\boldsymbol{\pi }}_{i}^{\ell }+{\mathbf{o}}_{i}^{\ell +1})}_{kj}=0\) and \({({\mathbf{v}}_{i}^{\ell +1})}_{kj}=0\) otherwise. Then we set the preconditioning matrix by
Substituting the above choice into equation (30) yields equation (29). The idea of using equation (29) is inspired by ref. 27, where Newton–Schulz orthogonalization82,83 is used to efficiently approximate \({\mathbf{o}}_{i}^{\ell +1}\). Incorporating these steps into Algorithm 1 leads to Algorithm 3, which we refer to as NSISA. The implementation of NewtonSchulz(b) is provided in ref. 27; a minimal sketch is given after this paragraph, followed by the convergence result of NSISA.
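The sketch below follows the quintic Newton–Schulz iteration described in ref. 27; the coefficients and step count are taken from that reference and should be treated as illustrative rather than tuned.

```python
import torch

def newton_schulz(b, steps=5, eps=1e-7):
    """Approximately orthogonalize a momentum matrix b, that is, return roughly U V^T
    from the singular value decomposition b = U Lambda V^T, without computing the SVD."""
    a1, a2, a3 = 3.4445, -4.7750, 2.0315   # quintic coefficients from ref. 27
    x = b / (b.norm() + eps)               # scale so singular values are at most 1
    transposed = b.shape[0] > b.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        s = x @ x.T
        x = a1 * x + (a2 * s + a3 * s @ s) @ x
    return x.T if transposed else x

# Example: the result is approximately orthogonal (singular values pushed towards 1).
m = torch.randn(256, 128)
o = newton_schulz(m)
print(torch.dist(o.T @ o, torch.eye(128)))
```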
Theorem 4
Let {(wℓ, Wℓ, Πℓ)} be the sequence generated by Algorithm 3with σ0 chosen as equation (17). If Assumption 1holds and \({({\mathbf{p}}_{i}^{\ell })}_{kj}\in (0,{\eta }_{i})\) for any (k, j) and ℓ ≥ 0, all statements in Theorems 1and 2are valid.
Code availability
All code is available via GitHub at https://github.com/Tracy-Wang7/PISA and via Code Ocean at https://doi.org/10.24433/CO.3435996.v1 (ref. 91).
References
Touvron, H. et al. Llama 2: open foundation and fine-tuned chat models. Preprint at https://arxiv.org/abs/2307.09288 (2023).
OpenAI et al. GPT-4 technical report. Preprint at https://arxiv.org/abs/2303.08774 (2024).
Gemma Team et al. Gemma: open models based on Gemini research and technology. Preprint at https://arxiv.org/abs/2403.08295 (2024).
He, K., Zhang, X., Ren, S. & Sun, J. Deep residual learning for image recognition. In Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 770–778 (IEEE, 2016).
Dosovitskiy, A. et al. An image is worth 16 × 16 words: transformers for image recognition at scale. In International Conference on Learning Representations (ICLR, 2021).
Robbins, H. & Monro, S. A stochastic approximation method. Ann. Math. Stat. 22, 400–407 (1951).
Chung, K. L. On a stochastic approximation method. Ann. Math. Stat. 25, 463–483 (1954).
Zinkevich, M., Weimer, M., Li, L. & Smola, A. Parallelized stochastic gradient descent. In Advances in Neural Information Processing Systems (NeurIPS) 2595–2603 (Curran Associates, Inc., 2010).
McMahan, B., Moore, E., Ramage, D., Hampson, S. & Arcas, B. A. Y. Communication-efficient learning of deep networks from decentralized data. In Proc. 20th International Conference on Artificial Intelligence and Statistics (eds Singh, A. & Zhu, J.) 1273–1282 (PMLR, 2017).
Li, X., Huang, K., Yang, W., Wang, S. & Zhang, Z. On the convergence of FedAvg on non-IID data. In International Conference on Learning Representations (ICLR, 2020).
Stich, S. U. Local SGD converges fast and communicates little. In International Conference on Learning Representations (ICLR, 2019).
Novak, R., Bahri, Y., Abolafia, D. A., Pennington, J. & Sohl-Dickstein, J. Sensitivity and generalization in neural networks: an empirical study. Preprint at https://arxiv.org/abs/1802.08760 (2018).
Chen, X. et al. Symbolic discovery of optimization algorithms. Preprint at https://arxiv.org/abs/2302.06675 (2023).
Qian, N. On the momentum term in gradient descent learning algorithms. Neural Netw. 12, 145–151 (1999).
Duchi, J., Hazan, E. & Singer, Y. Adaptive subgradient methods for online learning and stochastic optimization. J. Mach. Learn. Res. 12, 2121–2159 (2011).
Tieleman, T. Lecture 6.5-rmsprop: divide the gradient by a running average of its recent magnitude. COURSERA: Neural Networks for Machine Learning 4, 26–31 (2012).
Kingma, D. P. & Ba, J. Adam: a method for stochastic optimization. In International Conference on Learning Representations (ICLR, 2015).
Reddi, S. J., Kale, S. & Kumar, S. On the convergence of Adam and beyond. In International Conference on Learning Representations (ICLR, 2018).
Chen, J. et al. Closing the generalization gap of adaptive gradient methods in training deep neural networks. In International Joint Conferences on Artificial Intelligence (IJCAI) 3267–3275 (International Joint Conferences on Artificial Intelligence Organization, 2018).
Luo, L., Xiong, Y., Liu, Y. & Sun, X. Adaptive gradient methods with dynamic bound of learning rate. In International Conference on Learning Representations (ICLR, 2019).
Yu, H., Yang, S. & Zhu, S. Parallel restarted SGD with faster convergence and less communication: demystifying why model averaging works for deep learning. In Association for the Advancement of Artificial Intelligence (AAAI) 5693–5700 (AAAI, 2019).
You, Y. et al. Large batch optimization for deep learning: training BERT in 76 minutes. In International Conference on Learning Representations (ICLR, 2020).
You, Y., Gitman, I. & Ginsburg, B. Scaling SGD batch size to 32k for ImageNet training. Preprint at https://arxiv.org/abs/1708.03888 (2017).
Zeiler, M. D. Adadelta: an adaptive learning rate method. Preprint at https://arxiv.org/abs/1212.5701 (2012).
Loshchilov, I. Decoupled weight decay regularization. In International Conference on Learning Representations (ICLR, 2018).
Gupta, V., Koren, T. & Singer, Y. Shampoo: preconditioned stochastic tensor optimization. In International Conference on Machine Learning (ICML) 1842–1850 (PMLR, 2018).
Jordan, K. et al. Muon: an optimizer for hidden layers in neural networks. Keller Jordan blog https://kellerjordan.github.io/posts/muon/ (2024).
Vyas, N. et al. SOAP: improving and stabilizing Shampoo using Adam for language modeling. In International Conference on Learning Representations (ICLR, 2025).
Zhang, Y. et al. Adam-mini: use fewer learning rates to gain more. In International Conference on Learning Representations (ICLR, 2025).
Bottou, L. Large-scale machine learning with stochastic gradient descent. In International Conference on Computational Statistics 177–186 (Springer, 2010).
Ruder, S. An overview of gradient descent optimization algorithms. Preprint at https://arxiv.org/abs/1609.04747 (2016).
Bottou, L., Curtis, F. E. & Nocedal, J. Optimization methods for large-scale machine learning. SIAM Rev. 60, 223–311 (2018).
Nesterov, Y. A method for solving the convex programming problem with convergence rate O(1/k²). Dokl. Akad. Nauk SSSR 269, 543–547 (1983).
Nesterov, Y. On an approach to the construction of optimal methods of minimization of smooth convex functions. Ekonomika i Matematicheskie Metody 24, 509–517 (1988).
Dozat, T. Incorporating Nesterov momentum into Adam. In International Conference on Learning Representations 1–4 (ICLR, 2016).
Xie, X., Zhou, P., Li, H., Lin, Z. & Yan, S. Adan: adaptive Nesterov momentum algorithm for faster optimizing deep models. IEEE Trans. Pattern Anal. Mach. Intell. 46, 1–13 (2024).
Gabay, D. & Mercier, B. A dual algorithm for the solution of nonlinear variational problems via finite element approximation. Comput. Math. Appl. 2, 17–40 (1976).
Boyd, S. et al. Distributed optimization and statistical learning via the alternating direction method of multipliers. Found. Trends Mach. Learn. 3, 1–122 (2011).
Yang, Y., Sun, J., Li, H. & Xu, Z. ADMM-CSNet: a deep learning approach for image compressive sensing. IEEE Trans. Pattern Anal. Mach. Intell. 42, 521–538 (2018).
Zhang, X., Hong, M., Dhople, S., Yin, W. & Liu, Y. FedPD: a federated learning framework with adaptivity to non-IID data. IEEE Tran. Signal Process. 69, 6055–6070 (2021).
Zhou, S. & Li, G. Y. Federated learning via inexact ADMM. IEEE Trans. Pattern. Anal. Mach. Intell. 45, 9699–9708 (2023).
Zhou, S. & Li, G. Y. FedGiA: an efficient hybrid algorithm for federated learning. IEEE Trans. Signal Process. 71, 1493–1508 (2023).
Xu, K., Zhou, S. & Li, G. Y. Federated reinforcement learning for resource allocation in V2X networks. IEEE J. Sel. Topics Signal Process. 24, 2799–2803 (2024).
Wang, O., Zhou, S. & Li, G. Y. Frameworks on few-shot learning with applications in wireless communication. IEEE Trans. Signal Process. 73, 3857–3871 (2025).
Duchi, J. C., Agarwal, A. & Wainwright, M. J. Dual averaging for distributed optimization: Convergence analysis and network scaling. IEEE Trans. Automat. Contr. 57, 592–606 (2011).
Hosseini, S., Chapman, A. & Mesbahi, M. Online distributed optimization via dual averaging. In Conference on Decision and Control (CDC) 1484–1489 (IEEE, 2013).
Tsianos, K. I., Lawlor, S. & Rabbat, M. G. Push-sum distributed dual averaging for convex optimization. In Conference on Decision and Control (CDC) 5453–5458 (IEEE, 2012).
Pu, S., Shi, W., Xu, J. & Nedić, A. Push–pull gradient methods for distributed optimization in networks. IEEE Trans. Autom. Control. 66, 1–16 (2020).
Taylor, G. et al. Training neural networks without gradients: a scalable ADMM approach. In International Conference on Machine Learning (ICML) 2722–2731 (PMLR, 2016).
Wang, J., Yu, F., Chen, X. & Zhao, L. ADMM for efficient deep learning with global convergence. In Proc. ACM SIGKDD Int. Conf. Knowl. Discov. Data Min. 111–119 (2019).
Ebrahimi, Z., Batista, G. & Deghat, M. AA-DLADMM: an accelerated ADMM-based framework for training deep neural networks. Preprint at https://arxiv.org/abs/2401.03619 (2024).
Zeng, J., Lin, S.-B., Yao, Y. & Zhou, D.-X. On ADMM in deep learning: convergence and saturation-avoidance. J. Mach. Learn. Res. 22, 9024–9090 (2021).
Ouyang, H., He, N., Tran, L. & Gray, A. Stochastic alternating direction method of multipliers. In International Conference on Machine Learning (ICML) 80–88 (PMLR, 2013).
Ding, J. et al. Stochastic ADMM based distributed machine learning with differential privacy. In SecureComm 257–277 (Springer, 2019).
Zhong, W. & Kwok, J. Fast stochastic alternating direction method of multipliers. In International Conference on Machine Learning (ICML) 46–54 (PMLR, 2014).
Huang, F., Chen, S. & Huang, H. Faster stochastic alternating direction method of multipliers for nonconvex optimization. In International Conference on Machine Learning (ICML) 2839–2848 (PMLR, 2019).
Liu, Y., Shang, F. & Cheng, J. Accelerated variance reduced stochastic ADMM. In Association for the Advancement of Artificial Intelligence (AAAI) 2287–2293 (AAAI, 2017).
Zeng, Y., Wang, Z., Bai, J. & Shen, X. An accelerated stochastic ADMM for nonconvex and nonsmooth finite-sum optimization. Automatica 163, 111554 (2024).
Li, T. et al. Federated optimization in heterogeneous networks. MLSys 2, 429–450 (2020).
Kairouz, P. et al. Advances and open problems in federated learning. Found. Trends Mach. Learn. 14, 1–210 (2021).
Li, Q., Diao, Y., Chen, Q. & He, B. Federated learning on non-IID data silos: an experimental study. In International Conference on Data Engineering (ICDE) 965–978 (IEEE, 2022).
Ye, M., Fang, X., Du, B., Yuen, P. C. & Tao, D. Heterogeneous federated learning: state-of-the-art and research challenges. ACM Comput. Surv. 56, 1–44 (2023).
Krizhevsky, A., Sutskever, I. & Hinton, G. E. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems (NeurIPS) 1097–1105 (Curran Associates, Inc., 2012).
Krizhevsky, A., Nair, V. & Hinton, G. CIFAR-10 (University of Toronto, 2010); https://www.cs.toronto.edu/~kriz/cifar.html
Zhuang, J. et al. AdaBelief optimizer: adapting stepsizes by the belief in observed gradients. In Advances in Neural Information Processing Systems (NeurIPS) 33, 18795–18806 (Curran Associates, Inc., 2020).
Simonyan, K. & Zisserman, A. Very deep convolutional networks for large-scale image recognition. In International Conference on Learning Representations (ICLR, 2014).
He, K., Zhang, X., Ren, S. & Sun, J. Identity mappings in deep residual networks. In European Conference on Computer Vision (ECCV) 630–645 (Springer, 2016).
Huang, G., Liu, Z., Van Der Maaten, L. & Weinberger, K. Q. Densely connected convolutional networks. In Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 4700–4708 (IEEE, 2017).
Radford, A. et al. Language models are unsupervised multitask learners. OpenAI Blog 1, 9 (2019).
Goodfellow, I. et al. Generative adversarial nets. In Advances in Neural Information Processing Systems (NeurIPS) 2672–2680 (Curran Associates, Inc., 2014).
Salimans, T. et al. Improved techniques for training GANs. In Advances in Neural Information Processing Systems (NeurIPS) 2234–2242 (Curran Associates, Inc., 2016).
Arjovsky, M., Chintala, S. & Bottou, L. Wasserstein generative adversarial networks. In International Conference on Machine Learning (ICML) 214–223 (PMLR, 2017).
Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B. & Hochreiter, S. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In Advances in Neural Information Processing Systems (NeurIPS) 6626–6637 (Curran Associates, Inc., 2017).
Li, X.-L. Preconditioned stochastic gradient descent. IEEE Trans. Neural Netw. Learn. Syst. 29, 1454–1466 (2017).
Agarwal, N. et al. Efficient full-matrix adaptive regularization. In International Conference on Machine Learning (ICML) 102–110 (PMLR, 2019).
Yong, H., Sun, Y. & Zhang, L. A general regret bound of preconditioned gradient method for DNN training. In Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 7866–7875 (IEEE, 2023).
Wang, J. & Joshi, G. Cooperative SGD: a unified framework for the design and analysis of local-update SGD algorithms. J. Mach. Learn. Res. 22, 1–50 (2021).
Levenberg, K. A method for the solution of certain non-linear problems in least squares. Q. Appl. Math. 2, 164–168 (1944).
Marquardt, D. W. An algorithm for least-squares estimation of nonlinear parameters. J. Soc. Ind. Appl. Math 11, 431–441 (1963).
Li, D.-H., Fukushima, M., Qi, L. & Yamashita, N. Regularized newton methods for convex minimization problems with singular solutions. Comput. Optim. Appl. 28, 131–147 (2004).
Polyak, R. A. Regularized Newton method for unconstrained convex optimization. Math. Program. 120, 125–145 (2009).
Kovarik, Z. Some iterative methods for improving orthonormality. SIAM J. Numer. Anal. 7, 386–389 (1970).
Björck, Å & Bowie, C. An iterative algorithm for computing the best estimate of an orthogonal matrix. SIAM J. Numer. Anal. 8, 358–364 (1971).
Deng, J. et al. ImageNet: a large-scale hierarchical image database. In Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 248–255 (IEEE, 2009).
Krizhevsky, A. & Hinton, G. Learning Multiple Layers of Features from Tiny Images. Technical Report TR-2009-0 (University of Toronto, 2009).
LeCun, Y., Bottou, L., Bengio, Y. & Haffner, P. Gradient-based learning applied to document recognition. Proc. IEEE 86, 2278–2324 (2002).
Xiao, H., Rasul, K. & Vollgraf, R. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. Preprint at https://arxiv.org/abs/1708.07747 (2017).
Chang, C.-C. & Lin, C.-J. Libsvm: a library for support vector machines. ACM Trans. Intell. Syst. Technol. 2, 27:1–27:27 (2011).
Penedo, G. et al. The fineweb datasets: decanting the web for the finest text data at scale. In Advances in Neural Information Processing Systems (NeurIPS) 37, 30811–30849 (Curran Associates, Inc., 2024).
Marcus, M., Santorini, B. & Marcinkiewicz, M. A. Building a large annotated corpus of English: the Penn Treebank. Comput. Linguist. 19, 313–330 (1993).
Wang, O. Preconditioned inexact stochastic ADMM for deep models (source code). Code Ocean https://doi.org/10.24433/CO.3435996.v1 (2025).
Mukkamala, M. C. & Hein, M. Variants of RMSProp and AdaGrad with logarithmic regret bounds. In International Conference on Machine Learning (ICML) 2545–2553 (PMLR, 2017).
Zou, F., Shen, L., Jie, Z., Zhang, W. & Liu, W. A sufficient condition for convergences of Adam and RMSProp. In Proc. IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 11127–11135 (IEEE, 2019).
Reddi, S. J. et al. Adaptive federated optimization. In International Conference on Learning Representations (ICLR, 2021).
Liu, L. et al. On the variance of the adaptive learning rate and beyond. In International Conference on Learning Representations (ICLR, 2020).
Balles, L. & Hennig, P. Dissecting Adam: the sign, magnitude and variance of stochastic gradients. In International Conference on Machine Learning (ICML) 404–413 (PMLR, 2018).
Zaheer, M., Reddi, S., Sachan, D., Kale, S. & Kumar, S. Adaptive methods for nonconvex optimization. In Advances in Neural Information Processing Systems (NeurIPS) 31, 9815–9825 (Curran Associates, Inc., 2018).
Acknowledgements
S.Z. is partly supported by the National Key Research and Development Program of China (grant no. 2023YFA1011100), the Fundamental Research Funds for the Central Universities and the Talent Fund of Beijing Jiaotong University. Z.L. acknowledges support from the National Natural Science Foundation of China (grant no. 12271022) and National Key Research and the Development Program of China (grant no. 2024YFA1012901).
Author information
Contributions
S.Z. and O.W. conceived the research. S.Z. developed the algorithms, established all theoretical results, guided the experiments and wrote the paper. O.W. designed and carried out all experiments and prepared the results section. Z.L. checked the theoretical proofs, revised the paper and offered technical insights. Y.Z. and G.Y.L. contributed to discussions and assisted with article preparation.
Ethics declarations
Competing interests
The authors declare no competing interests.
Peer review
Peer review information
Nature Machine Intelligence thanks the anonymous reviewers for their contribution to the peer review of this work.
Additional information
Publisher’s note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Supplementary information
Supplementary Information
Supplementary Figs. 1–4, Discussion and Tables 1–5.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Zhou, S., Wang, O., Luo, Z. et al. Preconditioned inexact stochastic ADMM for deep models. Nat Mach Intell (2026). https://doi.org/10.1038/s42256-026-01182-3