Introduction

Deep Convolutional Neural Networks (DCNNs) have achieved groundbreaking advances in domains such as computer vision, natural language processing, and speech recognition. In computer vision, DCNNs have achieved state-of-the-art performance in image classification, object detection, semantic segmentation, and image generation1,2,3,4,5,6. Recent research has also extended to medical image encryption, demonstrating their potential for enhancing security and efficiency7. In natural language processing, DCNNs combined with recurrent neural networks (RNNs) have significantly improved machine translation and text summarization systems8. Recent studies have also shown promise in modeling complex linguistic and cognitive processes9,10. In speech recognition, DCNNs have enabled more accurate acoustic models, contributing to better performance in automatic speech recognition systems11.

However, as deep learning tasks become increasingly complex, DCNNs tend to grow larger in scale, demanding substantial computational resources12,13. Model compression accelerates inference by reducing model scale and computational complexity, thereby facilitating deployment and real-time inference in practical applications. Research on model compression techniques for DCNNs has therefore become indispensable: it plays a pivotal role in making deep learning feasible on resource-constrained platforms such as mobile devices, embedded systems, and edge computing devices. Model compression techniques primarily include network pruning, parameter quantization, low-rank decomposition, and knowledge distillation14. Among them, network pruning and knowledge distillation are the two most widely used.

Network pruning reduces model parameters and computation by removing units while maintaining the original model performance as much as possible15. The pruning process typically involves three steps: (1) evaluate the importance of units and select the unimportant ones for pruning; (2) following a rule-based pruning strategy, iteratively remove relatively unimportant units to reduce the model scale; (3) fine-tune the pruned model to recover performance. The unit importance evaluation function, the pruning strategy, and the fine-tuning method are key to the effectiveness of network pruning. This paper focuses on iterative filter pruning of DCNNs, because pruning at the filter level of convolutional layers maps well onto hardware accelerators such as FPGAs and ASICs, enabling efficient computation16,17. In the experiments, we define six explainable feature importance evaluation functions to identify relatively unimportant filters. By comparing the outcomes of different evaluation functions, the most effective one can be selected for a given scenario, or several can be combined. Using multiple evaluation functions demonstrates that the proposed method is adaptable and generalizable.

Traditional rule-based network pruning is limited by the rules set by experts, making it difficult to achieve a globally optimal solution. These methods often rely heavily on domain-specific knowledge and can be highly subjective, limiting their adaptability and generalization across different network architectures. Reinforcement Learning (RL)–based automatic pruning offers a promising alternative by mitigating these issues and reducing subjectivity. RL enables the discovery of optimal pruning policies through interaction with the environment, allowing for more dynamic and adaptive pruning strategies. However, current RL-based research primarily uses a single agent to explore the optimal pruning rate for all convolutional layers, neglecting the complex interactions and effects between multiple layers18,19,20.

To address these limitations, this study proposes a novel approach that integrates Multi-Agent Reinforcement Learning (MARL) with Deep Convolutional Neural Networks (DCNNs) for more effective pruning. Six filter importance evaluation functions are defined to capture different aspects of filter importance. These functions serve as the basis for our MARL framework, where each agent is responsible for optimizing the pruning rate of a specific layer. By using the MARL algorithm, the agents can collaborate effectively, taking into account the interdependencies between different layers.

The QMIX algorithm21 is a well-established benchmark in MARL, known for its ability to solve global optimization problems through Centralized Training with Decentralized Execution (CTDE). QMIX stands for “Q-value Mixing,” which refers to the process of combining the Q-values of individual agents to form a joint Q-value that represents the overall utility or reward of the team’s actions. Motivated by the principles of MARL and the effectiveness of the QMIX algorithm, we propose a DCNN pruning framework based on the QMIX algorithm, as illustrated in Fig. 1. Each convolutional layer is treated as an agent, and the entire DCNN serves as the environment, forming a multi-agent system. Thus, the network pruning task is modeled as a cooperative game. The common goal of all agents is to ensure that the pruned network meets the pruning rate constraint while minimizing the model’s prediction loss as much as possible.

Fig. 1
figure 1

The environment of DCNN network pruning based on QMIX. Each convolutional layer corresponds to an agent; the whole DCNN is an environment. QMIX consists of L agent networks and a mixing network. The optimal pruning strategy of DCNN is explored using the QMIX algorithm.

In addition, Knowledge Distillation (KD)22 draws upon the educational concept of transferring knowledge from teacher to student. It transfers the knowledge of a large-scale Teacher Network (T-Net) to a small-scale Student Network (S-Net), where the T-Net is pre-trained with good performance. The S-Net can therefore efficiently learn the generalization ability of the T-Net while remaining small in scale. Motivated by this, we fine-tune the pruned model using KD, setting the original model as the T-Net and the pruned model as the S-Net.

In conclusion, our primary contribution can be summarized as follows:

  • We propose a filter pruning method based on the multi-agent reinforcement learning algorithm QMIX to explore the optimal pruning rate for each agent (layer). It is an automatic network pruning method.

  • We design an algorithm that integrates network pruning with knowledge distillation, i.e., fine-tuning the pruned model based on KD to accelerate performance improvement.

  • We validate the efficiency of the proposed method on VGG-16 and AlexNet over the CIFAR-10 and CIFAR-100 datasets. The experimental results show that QMIX_FP achieves competitive parameter compression and FLOPs reduction while maintaining accuracy similar to the original network.

The remainder of this paper is organized as follows. Section “Related work” reviews existing network pruning methods. Section “Methodology” describes our approach to automatic filter pruning based on multi-agent reinforcement learning. We present experimental results and a comprehensive analysis of the findings in Sect. “Experiments,” followed by Sect. “Conclusions.”

Related work

Network pruning has attracted significant attention in deep learning as an effective model compression technique. It reduces the model scale and computational complexity by removing redundant weights or entire structural components within the neural network while maintaining the performance of the original model to the greatest extent. Research on network pruning has experienced rapid development, including exploring pruning strategies, optimization objectives, automation methods, and so on23,24.

According to the granularity of the pruning unit, network pruning methods can be divided into unstructured pruning (fine-grained) and structured pruning (coarse-grained)14. Unstructured pruning methods remove individual neuron connections (weights) independently. In 2015, Han et al.25 conducted one of the first systematic studies of deep neural network pruning: connections whose absolute weight values fall below a threshold are pruned, and the now-standard three-stage iterative pruning process was proposed for the first time. This approach, however, can lead to highly sparse models that are difficult to deploy on standard hardware. The same team further proposed the “Deep Compression” method26, which adds weight quantization to facilitate weight sharing and applies Huffman coding to achieve better compression. Despite its effectiveness, the method requires additional steps such as quantization and coding, which can be complex and time-consuming. Yang et al.27 proposed an energy-aware weight pruning algorithm for CNNs that directly uses energy consumption estimates to guide the pruning process, followed by local fine-tuning with a closed-form least-squares solution to quickly restore accuracy. Sanh et al.28 proposed movement pruning, which selects weights that tend to move away from zero, and applied it to pre-trained language representations such as BERT. However, these approaches may not generalize well across different hardware platforms.

Conversely, structured pruning focuses on removing entire channels, filters, or layers. He et al.29 introduced an iterative two-step algorithm for channel pruning based on LASSO regression-based channel selection and least-squares reconstruction. This method can effectively reduce model size but may sacrifice some accuracy. Li et al.30 pruned filters together with their connected feature maps, significantly reducing computation costs while remaining compatible with existing efficient BLAS libraries for dense matrix multiplication. Liu et al.31 pushed the scaling factors in batch normalization layers towards zero with \(L_1\) regularization to identify and prune insignificant channels. This approach can yield more compact models but requires careful tuning of hyperparameters. Zhuang et al.32 improved this evaluation method by proposing a polarization regularizer, pushing the scaling factors of unimportant neurons to 0 and those of the others to a value \(a > 0\), which enhances the robustness of the pruning process. In summary, structured pruning methods maintain the regularity of the model structure, making them more hardware-friendly and easier to implement on existing platforms than unstructured pruning, although they may not achieve the same level of compression; this is why this paper adopts filter-based structured pruning.

According to the pruning process, network pruning methods can be categorized into traditional rule-based and automatic RL-based methods. Rule-based methods typically execute the three-stage iterative pruning process25, where the pruning rate at each step and the total number of iterations are designed in advance33,34. Instead of being confined by subjective rules and hyper-parameter design, some researchers have formulated network pruning as an RL problem to explore optimal solutions automatically. He et al.18 proposed AutoML for Model Compression (AMC), which leverages reinforcement learning to provide a layer-wise pruning strategy. Gupta et al.20 built upon AMC and contributed an improved framework based on the Deep Q-Network (DQN) that provides rewards at every pruning step. Camci et al.35 introduced a deep Q-learning-based pruning method called QLP, which performs weight-level unstructured pruning. Feng et al.19 proposed a filter-level pruning method based on the Deep Deterministic Policy Gradient (DDPG) algorithm. Unfortunately, in existing single-agent RL-based pruning methods, the agent prunes only one layer at each time step, which fails to capture the complex dependencies among multiple layers. This limitation can lead to suboptimal pruning decisions, as the performance of one layer can significantly influence that of subsequent layers. In contrast, multi-agent reinforcement learning (MARL) can better handle these challenges by allowing multiple agents to collaborate and make decisions jointly, thereby improving the overall pruning efficiency and effectiveness.

Existing research has been deeply involved in network pruning for DCNNs, but there are some limitations and untapped potentials in this area. In this paper, we aim to fill the gap in the existing work and propose automatic filter pruning by MARL, which has been under-explored.

Methodology

Our proposed network pruning method, named QMIX_FP, is shown in Fig. 2. This figure provides a comprehensive overview of our approach. Fig. 2b demonstrates explicitly how we utilize the MARL algorithm QMIX for pruning. Here, each agent aligns with a convolutional layer in the DCNN. The network architecture at time step \(t\) is represented by the state \(s_t\), while \(a_t\) represents the pruning decision made at that time step. The agents progressively test different pruning techniques until the target pruning rate is met, ultimately shaping the pruned network. Figure 2a breaks down the pruning steps at each time step. Once the pruning target is reached, Fig. 2c outlines the distillation phase, where we refine the pruned network and enhance its performance through knowledge distillation. In this section, we initially frame the task of filter-level pruning for DCNNs. We then provide a thorough examination of QMIX_FP, detailing our feature evaluation metrics, the QMIX-driven pruning mechanism, and the KD-based fine-tuning process.

Fig. 2
figure 2

An overview of our approach QMIX_FP. (a) Filters pruning process of Markov Decision Process. (b) QMIX-based pruning architecture. (c) Knowledge distillation-based fine-tuning structure.

Assume that the DCNN contains L convolutional layers. For the i-th convolutional layer \(L_i\), the weights of all filters are denoted as \(\varvec{w}_i \in \mathbb {R}^{n_{i} \times n_{i-1} \times d_i \times d_i}\), where \(d_i\) is the kernel size, and \(n_{i-1}\) and \(n_{i}\) are the numbers of input and output channels, respectively. \(\varvec{w}_{i,j} \in \mathbb {R}^{n_{i-1} \times d_i \times d_i}\) denotes the j-th filter of \(L_i\). Meanwhile, we denote the bias of \(L_i\) as \(\varvec{b}_i\), and the remaining parameters of the DCNN outside the convolutional layers as \(\varvec{w}_0\). Thereby, the parameters of the DCNN are represented as \(W=\left\{ \varvec{w}_0, \left( \varvec{w}_1, \varvec{b}_1\right) ,\left( \varvec{w}_2, \varvec{b}_2\right) , \ldots , \left( \varvec{w}_L, \varvec{b}_L\right) \right\}\). Consider the classification dataset \(D = \left\{ X = \left\{ \varvec{x}_0, \varvec{x}_1, \ldots , \varvec{x}_{N-1}\right\} , Y=\left\{ y_0, y_1, \ldots , y_{N-1}\right\} \right\}\), where \(y_n\) is the ground truth of input \(\varvec{x}_n\) and N is the number of samples. For the image classification task, the optimization objective during model training is to minimize the prediction loss function

$$\begin{aligned} \mathcal {L}(D \mid W) =\frac{1}{N}\sum _{n=0}^{N-1} \mathbb {I}(\hat{y}_n \ne y_n), \end{aligned}$$
(1)

where \(\mathbb {I}(\cdot )\) is the indicator function and \(\hat{y}_n\) is the predicted label of input \(\varvec{x}_n\).

The filter-level pruning involves removing relatively unimportant filters from the convolutional layers without changing the topology or depth of the DCNN. After pruning, the parameters of the pruned network are updated to \(W^{\prime }=\left\{ \varvec{w}_0^{\prime }, \left( \varvec{w}_1^{\prime }, \varvec{b}_1^{\prime }\right) ,\left( \varvec{w}_2^{\prime }, \varvec{b}_2^{\prime }\right) , \ldots , \left( \varvec{w}_L^{\prime }, \varvec{b}_L^{\prime }\right) \right\}\). For notational convenience, we constrain \(W^{\prime }\) to have the same dimension as W. Thus, the optimization objective of the filter pruning task is to minimize the increase in prediction loss of the pruned network while satisfying the overall pruning rate constraint B, i.e.,

$$\begin{aligned} \min _{W^{\prime }}\mathcal {L}\left( D \mid W^{\prime }\right) -\mathcal {L}(D \mid W) \quad { s.t. } \frac{\left\| W^{\prime }\right\| _0}{\left\| W\right\| _0} \le 1 - B, \end{aligned}$$
(2)

where \(\left\| \cdot \right\| _0\) is the \(l_0\) norm. Intuitively, the proportion of nonzero parameters retained in \(W^{\prime }\) should be no greater than \(1-B\). Here, \(B\) denotes the overall pruning rate constraint for the entire DCNN.

Feature evaluation

The feature evaluation function used to compute filter importance directly influences the effectiveness of DCNN pruning at a given pruning rate. Filters play an important role in image understanding and feature extraction in DCNNs. The input image data undergoes convolution with the filters in a convolutional layer, producing feature maps, where each filter corresponds to one feature map. Typically, an activation function is applied to these feature maps to introduce nonlinearity and increase the expressiveness of the model, producing activation maps. In this work, we choose the classic ReLU (Rectified Linear Unit) as the activation function.

Therefore, the importance of filters can be evaluated based on their inherent properties or on the information representations they generate, such as feature maps and activation maps. We denote the feature maps produced by the filters \(\varvec{w}_{i}\) as \(F_{i}=\left\{ \varvec{F}_{i,1},\varvec{F}_{i,2},\ldots ,\varvec{F}_{i,n_i}\right\}\) and the activation maps produced by \(F_{i}\) as \(A_{i}=\left\{ \varvec{A}_{i,1},\varvec{A}_{i,2},\ldots ,\varvec{A}_{i,n_i}\right\}\). Considering the explainability of network pruning, we define six feature importance evaluation functions based on feature attribution theory in explainable artificial intelligence36 to demonstrate the adaptability and generalizability of our proposed method. According to the objects used for calculating the importance of filter \(\varvec{w}_{i,j}\), we classify the explainable evaluation functions \(\textrm{Imp}(\cdot )\) into three categories: the filter-based evaluation function \(\textrm{Imp}(\varvec{w}_{i,j})\), the feature map-based evaluation function \(\textrm{Imp}(\varvec{F}_{i,j})\), and the activation map-based evaluation function \(\textrm{Imp}(\varvec{A}_{i,j})\). A larger \(\textrm{Imp}(\varvec{w}_{i,j})\) implies that the filter plays a more significant role in the inference process. Likewise, the larger \(\textrm{Imp}(\varvec{F}_{i,j})\) or \(\textrm{Imp}(\varvec{A}_{i,j})\) is, the more important the corresponding filter \(\varvec{w}_{i,j}\) is.

Filter-based

  • \(l_2\) Norm of Filter Weight

It is most intuitive to evaluate importance directly from the filter’s weight parameters. We define

$$\begin{aligned} \mathrm{Imp_{Weight}}(\varvec{w}_{i,j})=\frac{1}{|\varvec{w}_{i,j}|} \sum _{m=1}^{|\varvec{w}_{i,j}|} w_m^2, \end{aligned}$$
(3)

where \(|\varvec{w}_{i,j}|\) denotes the number of elements in the vectorized filter \(\varvec{w}_{i,j}\), and \(w_m\) is its m-th element.
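
As a concrete illustration, the following minimal sketch (PyTorch assumed; the layer sizes are placeholders) computes one weight-based importance score per filter as in Eq. (3).

```python
import torch
import torch.nn as nn

def weight_importance(conv: nn.Conv2d) -> torch.Tensor:
    """Mean squared weight per output filter, as in Eq. (3)."""
    w = conv.weight.detach()                  # shape: (n_i, n_{i-1}, d_i, d_i)
    return w.pow(2).flatten(1).mean(dim=1)    # one score per filter

conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)  # illustrative sizes
scores = weight_importance(conv)              # shape: (128,)
least_important = torch.argsort(scores)[:6]   # candidate filters for pruning
```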

  • Taylor Expansion of Loss Function

The extent of a filter’s contribution can be determined indirectly by calculating the change in the loss function \(\mathcal {L}(D \mid W)\) before and after removing this filter from the DCNN. A large change in the loss function signifies that the filter is of high importance. To reduce the computational complexity, we approximate this change using the first-order Taylor expansion of the loss function. Thus, we obtain the approximate importance of every filter by computing the gradient of \(\mathcal {L}(D \mid W)\) only once. The criterion is

$$\begin{aligned} \begin{aligned} \mathrm{Imp_{Taylor} }(\varvec{w}_{i,j})&= \left| \mathcal {L}\left( D| \varvec{w}_{i,j}\right) -\mathcal {L}\left( D | \varvec{w}_{i,j}=\varvec{0}\right) \right| \\&= \left| \mathcal {L}\left( D| \varvec{w}_{i,j} =\varvec{0}\right) + \frac{\partial \mathcal {L}\left( D| \varvec{w}_{i,j} =\varvec{0}\right) }{\partial \varvec{w}_{i,j}} \varvec{w}_{i,j} + R_1\left( \varvec{w}_{i,j}\right) -\mathcal {L}\left( D| \varvec{w}_{i,j}=\varvec{0} \right) \right| \\&\approx \left| \frac{\partial \mathcal {L}\left( D| \varvec{w}_{i,j} =\varvec{0}\right) }{\partial \varvec{w}_{i,j}} \varvec{w}_{i,j} \right| , \end{aligned} \end{aligned}$$
(4)

where \(\varvec{w}_{i,j}=\varvec{0}\) indicates that the filter \(\varvec{w}_{i,j}\) is removed, and \(R_1\left( \varvec{w}_{i,j}\right)\) is the higher-order remainder term, which is negligible.
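
A minimal sketch of this criterion is given below (PyTorch assumed; the tiny model, data, and loss are placeholders): after a single backward pass, the score of each filter is the absolute sum of gradient-times-weight over its elements, matching the last line of Eq. (4).

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()                               # one gradient computation suffices

conv = model[0]
taylor_scores = (conv.weight.grad * conv.weight).flatten(1).sum(dim=1).abs()
print(taylor_scores.shape)                    # torch.Size([16]), one score per filter
```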

Feature map-based

  • Information Bottleneck Value

Motivated by the Information Bottleneck(IB) principle37, we calculate the information bottleneck value of the feature map \(\varvec{F}_{i,j}\), i.e.,

$$\begin{aligned} \mathrm{Imp_{IB}}(\varvec{F}_{i,j}) = \mathbb {I}(\varvec{F}_{i,j};Y) - \beta \mathbb {I}(\varvec{F}_{i,j};X), \end{aligned}$$
(5)

where \(\mathbb {I}(\cdot \, ; \, \cdot )\) is mutual information, and \(\beta \in [0,1]\) is a balancing factor that trades off information compression against information retention. For simplicity, we use Information Gain (IG), derived from information entropy \(\mathbb {H}(\cdot )\), to estimate the mutual information, i.e.,

$$\begin{aligned} {\left\{ \begin{array}{ll} \mathbb {I}(\varvec{F}_{i,j};Y) = \textrm{IG}(Y , \varvec{F}_{i,j}) = \mathbb {H}(Y)-\mathbb {H}(Y \mid \varvec{F}_{i,j})\\ \mathbb {I}(\varvec{F}_{i,j};X) = \textrm{IG}(\varvec{F}_{i,j} , X) = \mathbb {H}(\varvec{F}_{i,j} )-\mathbb {H}(\varvec{F}_{i,j} \mid X ). \end{array}\right. } \end{aligned}$$
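
The sketch below illustrates one way to estimate Eq. (5) in practice: each feature map is summarized by its spatial mean per sample, the summaries are discretized into quantile bins, and information gain is used as a proxy for mutual information. The binning scheme and bin count are illustrative assumptions, not the paper's exact estimator.

```python
import numpy as np

def entropy(labels: np.ndarray) -> float:
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())

def ib_importance(fmap: np.ndarray, y: np.ndarray, beta: float = 0.5, bins: int = 8) -> float:
    """fmap: (N, H, W) responses of one filter over N samples; y: (N,) class labels."""
    summary = fmap.reshape(len(fmap), -1).mean(axis=1)      # one scalar per sample
    edges = np.quantile(summary, np.linspace(0, 1, bins + 1)[1:-1])
    f_bins = np.digitize(summary, edges)
    # I(F;Y) = H(Y) - H(Y|F), conditioning on the discretized feature value.
    h_y_given_f = sum((f_bins == b).mean() * entropy(y[f_bins == b])
                      for b in np.unique(f_bins))
    i_fy = entropy(y) - h_y_given_f
    # F is a deterministic function of X, so H(F|X) = 0 and I(F;X) reduces to H(F).
    i_fx = entropy(f_bins)
    return i_fy - beta * i_fx

score = ib_importance(np.random.rand(256, 8, 8), np.random.randint(0, 10, 256))
```
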
  • \(\gamma\) of Batch Normalization

Batch Normalization (BN) layers are typically inserted immediately after the convolutional layer and before the activation layer, ensuring that the inputs to the activation functions have a more stable distribution. Assume \(z_\textrm{in}\) and \(z_\textrm{out}\) are the input and output of the BN layer. The BN layer performs the following transformation:

$$\begin{aligned} z_{\text {out}} = \frac{z_{\text {in}} - \mathbb {E}[z_{\text {in}}]}{\sqrt{\operatorname {Var}[z_{\text {in}}] + \epsilon }} \cdot \gamma + \beta , \end{aligned}$$
(6)

where \(\mathbb {E}[z_\textrm{in}]\) and \(\operatorname {Var}[z_\textrm{in}]\) are the mean and the variance of the input, and \(\epsilon\) is a small constant to avoid division by zero. We refer to the channel importance evaluation criterion proposed by Liu et al.31 and define the importance of \(\varvec{F}_{i,j}\) as

$$\begin{aligned} \mathrm{Imp_{BN}} (\varvec{F}_{i,j}) = \gamma . \end{aligned}$$
(7)

A \(\gamma\) close to 0 means that \(\varvec{F}_{i,j}\) is strongly scaled down; conversely, a large \(\gamma\) means \(\varvec{F}_{i,j}\) is amplified, indicating a greater influence on feature learning and information transfer.
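
As a minimal sketch (PyTorch assumed; the conv/BN pairing and sizes are illustrative), the importance of each feature map can be read directly from the \(\gamma\) parameters of the BN layer that follows the convolution; the absolute value is taken here as a sign-insensitive variant of Eq. (7).

```python
import torch
import torch.nn as nn

block = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1),
                      nn.BatchNorm2d(128),
                      nn.ReLU())
bn = block[1]
bn_importance = bn.weight.detach().abs()               # gamma, one value per channel/filter
prune_candidates = torch.argsort(bn_importance)[:13]   # e.g. the ~10% smallest gammas
```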

  • Gradient

The gradient of the loss function with respect to the feature map reflects the extent to which small perturbations in \(\varvec{F}_{i,j}\) affect prediction loss. We define that

$$\begin{aligned} \mathrm{Imp_{Gradient}} (\varvec{F}_{i,j}) = \left| \frac{\partial \mathcal {L}(D \mid W)}{\partial \varvec{F}_{i,j}} \right| . \end{aligned}$$
(8)
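
The following minimal sketch (PyTorch assumed; the tiny model and data are placeholders) retains the gradient of the loss with respect to an intermediate feature map and averages its magnitude per filter, in the spirit of Eq. (8).

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, 3, padding=1)
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))

fmap = conv(x)              # feature maps, shape (8, 16, 32, 32)
fmap.retain_grad()          # keep gradients for this non-leaf tensor
loss = nn.CrossEntropyLoss()(head(fmap), y)
loss.backward()

grad_importance = fmap.grad.abs().mean(dim=(0, 2, 3))   # one score per filter
```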

Activation map-based

  • Average Percentage of Positive

This study draws on the Average Percentage of Zeros (APoZ) criterion proposed by Hu et al.38 and infers that the average percentage of positive values in the activation map \(\varvec{A}_{i,j}\) can evaluate the importance of the corresponding filter \(\varvec{w}_{i,j}\), i.e.,

$$\begin{aligned} \mathrm{Imp_{APoP}}(\varvec{A}_{i,j}) = 1 - \frac{\sum _{n} \sum _{m} \mathbb {I}(a_{n, m}=0)}{N \times M} , \end{aligned}$$
(9)

where \(\mathbb {I}(\cdot )\) is the indicator function, \(a_{n, m}\) denotes the m-th element of \(\varvec{A}_{i,j}\) for the n-th sample, and N and M are the number of samples and the number of elements per activation map, respectively. The larger \(\mathrm{Imp_{APoP}}(\varvec{A}_{i,j})\) is, the more important the filter \(\varvec{w}_{i,j}\).
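
A minimal sketch of this criterion is shown below (PyTorch assumed; shapes are illustrative): a forward hook on the ReLU counts the fraction of zero activations per channel over a batch, and the APoP score is one minus that fraction, as in Eq. (9).

```python
import torch
import torch.nn as nn

layer = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
stats = {}

def apop_hook(module, inputs, output):
    # output: (N, C, H, W); fraction of zero activations per channel
    zero_frac = (output == 0).float().mean(dim=(0, 2, 3))
    stats["apop"] = 1.0 - zero_frac

layer[1].register_forward_hook(apop_hook)
_ = layer(torch.randn(32, 3, 32, 32))
print(stats["apop"])        # one importance score per filter
```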

Network pruning

As shown in Fig. 1, we explore the DCNN pruning strategy based on the multi-agent reinforcement learning algorithm QMIX. The interaction between agents and environment is described by a Markov Decision Process (MDP), i.e., a tuple \((\mathcal {S},\mathcal {A}, P, R,\gamma )\), whose elements denote the state space, the action space, the state-transition probability matrix, the reward function, and the discount factor, respectively. Below, we describe the key elements of this formulation in detail.

  • State Representation

At time step t, the global state of the environment is denoted as \(s_t=\{s_t^1,s_t^2,\ldots ,s_t^L \}\), where \(s_t^i\) is the state of agent i, i.e., the state of the i-th convolutional layer. We characterize the state \(s_t^i\) with six features:

$$\begin{aligned} s_t^i = (n_i, n_{i-1}, \mathrm{Kernel\_{size}}_i, \mathrm{Mean\_{weight}}_i, \textrm{FLOPs}_i, \textrm{Params}_i), \end{aligned}$$
(10)

where \(n_i\) is the number of filters in the i-th layer, which is also the input channel of the \((i+1)\)-th layer. \(n_{i-1}\) is the input channel of the i-th layer, which is also the number of filters in the \((i-1)\)-th layer. \(\mathrm{Kernel\_{size}}_i\) is the kernel size and \(\mathrm{Mean\_{weight}}_i\) is the mean of weight. \(\textrm{FLOPs}_i\) and \(\textrm{Params}_i\) denote the Floating-point Operations(FLOPs) and the Parameters(Params) of the i-th convolutional layer, respectively.
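
For illustration, a minimal sketch of assembling the six-feature state of one layer is given below (PyTorch assumed); the FLOPs estimate assumes stride-1 convolutions producing an \(H \times W\) output and is an approximation, not the paper's exact accounting.

```python
import torch.nn as nn

def layer_state(conv: nn.Conv2d, h: int, w: int) -> list:
    """Return [n_i, n_{i-1}, kernel_size, mean_weight, FLOPs, Params] for one layer."""
    n_out, n_in, k, _ = conv.weight.shape
    params = n_out * n_in * k * k + (n_out if conv.bias is not None else 0)
    flops = 2 * n_out * n_in * k * k * h * w          # multiply-add per output element
    mean_weight = conv.weight.detach().mean().item()
    return [n_out, n_in, k, mean_weight, flops, params]

state_i = layer_state(nn.Conv2d(64, 128, 3, padding=1), h=32, w=32)
```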

  • Action Space

In our research, each agent corresponds to a single convolutional layer within the neural network, with its action defined as the pruning rate applied to that layer, represented by a simple numerical value. To facilitate precise control over the pruning process, we establish the action space as a discrete set of possible pruning rates, where \(\mathcal {A}=[0\%,5\%,10\%,15\%,20\%]\). At time step \(t\), the collective action across all layers is represented as \(a_t = \{a_t^1, a_t^2, \ldots , a_t^L\}\), where \(a_t^i\) indicates the pruning rate selected for the \(i\)-th convolutional layer from the predefined set.

These pruning rates are not static and can evolve throughout the reinforcement learning process. The ultimate goal is to iteratively fine-tune these rates to achieve a targeted overall pruning level \(B\) for the entire deep convolutional neural network (DCNN). This tuning is done with the dual goals of minimizing the network’s prediction error and optimizing its performance, ensuring that the pruned model remains both efficient and accurate.

  • State Transition

At time step t, agent i receives the current state \(s_t^i\) of the i-th convolutional layer. Firstly, the importance scores of all filters are evaluated and sorted. Then, agent i prunes the filters with lower importance according to the current action \(a_t^i\), which is output by agent network i. Notably, all agents perform this pruning process simultaneously. Finally, we fine-tune the pruned network and receive the global state of the environment \(s_{t+1}\). The above process describes a state transition in the environment, as shown in the upper part of Fig. 2a, which is repeated until the end of the episode.

  • Reward Function

Instant reward directly induces DCNN pruning toward optimality. Network pruning aims to minimize model performance drop while achieving the expected pruning rate. Therefore, we define instant reward as a linear combination of test accuracy change and network sparsity change, where the sparsity is the ratio of the total filters in the pruned network to the original network.

$$\begin{aligned} \begin{aligned} R_t = \lambda * (\textrm{TestAcc }_t-\textrm{TestAcc }_{t-1} ) + \mu * (\textrm{Sparsity}_t-\textrm{Sparsity}_{t-1}), \end{aligned} \end{aligned}$$
(11)

where \(\lambda\) and \(\mu\) are hyperparameters, and \(\textrm{TestAcc}_t\) and \(\textrm{Sparsity}_t\) are the prediction accuracy on the test dataset and the network sparsity at time step t, respectively. In addition, the total return of an episode is the cumulative sum of the instant rewards.
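
Expressed as code, the instant reward of Eq. (11) is simply a weighted sum of the two deltas; the values of \(\lambda\) and \(\mu\) below follow the experimental settings reported later (0.7 and 0.3), and the example numbers are made up.

```python
def instant_reward(acc_t: float, acc_prev: float,
                   sparsity_t: float, sparsity_prev: float,
                   lam: float = 0.7, mu: float = 0.3) -> float:
    """Instant reward of Eq. (11): weighted change in accuracy and sparsity."""
    return lam * (acc_t - acc_prev) + mu * (sparsity_t - sparsity_prev)

# Example: accuracy drops by 0.4 points while sparsity changes by 5 points.
r = instant_reward(0.912, 0.916, 0.55, 0.50)
```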

  • Loss Function

The MDP provides a mathematical framework for DCNN pruning. We solve the above MDP problem based on the QMIX algorithm, as shown in Fig. 2b. The QMIX architecture comprises \(L\) agent networks alongside a mixing network. Each agent \(i\) is equipped with its own agent network, which represents the state-action value function \(Q_i(\tau ^i, a^i)\), where \(\tau\) signifies the joint action-observation history. The mixing network, a feed-forward neural network, takes the outputs from the agent networks as inputs and combines them in a monotonic manner to generate the total value \(Q_{\text {tot}}\). Specifically, \(Q_{\text {tot}}\) is calculated as follows:

$$Q_{\text {tot}}(\tau , a, s) = \mathcal {M}(Q_1(\tau ^1, a^1), Q_2(\tau ^2, a^2), \ldots , Q_L(\tau ^L, a^L)),$$

where \(\mathcal {M}\) is the mixing function implemented by the mixing network. This function ensures that the combination of individual agent Q-values is monotonic and adheres to the problem’s constraints. s represents the global state, which is necessary for centralized training.

To train the QMIX model, we implement a Temporal-Difference loss function aimed at minimizing the discrepancy between the predicted Q-values and the target Q-values. The training is conducted in an end-to-end fashion to minimize the following loss function:

$$\begin{aligned} \mathcal {L}(\theta ) = \sum _{i=1}^b \left[ r+\gamma \max _{a^{\prime }} Q_{\textrm{tot}}\left( \tau ^{\prime }, a^{\prime }, s^{\prime } ; \theta ^{-}\right) -Q_{\textrm{tot}}(\tau , a, s ; \theta )\right] ^2, \end{aligned}$$
(12)

where b is the batch size of transitions sampled from the replay buffer, \((\cdot )^{\prime }\) indicates the value of \((\cdot )\) at the next time step, and \(\theta ^{-}\) are the parameters of the target network.
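
The sketch below (PyTorch assumed) shows the core of this computation: a hypernetwork-style mixing network whose weights are forced non-negative for monotonicity, and the TD loss of Eq. (12) on a sampled batch. The hidden sizes, hypernetwork layout, and placeholder tensors are simplifications, not the exact architecture used in this work.

```python
import torch
import torch.nn as nn

class Mixer(nn.Module):
    """Monotonic mixing of per-agent Q-values into Q_tot, conditioned on the global state."""
    def __init__(self, n_agents: int, state_dim: int, embed: int = 32):
        super().__init__()
        self.n_agents, self.embed = n_agents, embed
        self.hyper_w1 = nn.Linear(state_dim, n_agents * embed)  # state-conditioned weights
        self.hyper_b1 = nn.Linear(state_dim, embed)
        self.hyper_w2 = nn.Linear(state_dim, embed)
        self.hyper_b2 = nn.Linear(state_dim, 1)

    def forward(self, agent_qs, state):
        # agent_qs: (batch, n_agents); state: (batch, state_dim)
        w1 = self.hyper_w1(state).abs().view(-1, self.n_agents, self.embed)  # abs -> monotonic
        b1 = self.hyper_b1(state).unsqueeze(1)
        hidden = torch.relu(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)
        w2 = self.hyper_w2(state).abs().unsqueeze(-1)
        b2 = self.hyper_b2(state).unsqueeze(1)
        return (torch.bmm(hidden, w2) + b2).view(-1)             # Q_tot: (batch,)

batch, n_agents, state_dim, gamma = 16, 13, 13 * 6, 0.99
mixer, target_mixer = Mixer(n_agents, state_dim), Mixer(n_agents, state_dim)
q_taken = torch.randn(batch, n_agents)       # Q_i(tau^i, a^i) of the actions taken
q_next_max = torch.randn(batch, n_agents)    # per-agent greedy Q at the next step (valid by monotonicity)
s, s_next = torch.randn(batch, state_dim), torch.randn(batch, state_dim)
r = torch.randn(batch)

q_tot = mixer(q_taken, s)
with torch.no_grad():
    td_target = r + gamma * target_mixer(q_next_max, s_next)
loss = ((td_target - q_tot) ** 2).sum()      # Eq. (12)
```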

Fine-tuning

The QMIX algorithm follows the CTDE training principle. After training to convergence, each agent network makes decisions independently based on its own observations. Given the expected pruning rate, we prune the original pre-trained network according to the pruning strategy output by the trained QMIX agent networks. The pruned network typically suffers some performance decline as a result of the pruning process.

To accelerate the performance recovery of the pruned network, we design a model compression framework that integrates network pruning and knowledge distillation, as shown in Fig. 2c. Specifically, we set the original pre-trained network as the T-Net and the pruned network as the S-Net, and then fine-tune the pruned network based on knowledge distillation.

On the one hand, to guide S-Net to more fully learn the generalization ability of T-Net, we introduce label probability distributions with a temperature parameter T,

$$\begin{aligned} p_i\left( z_i, T\right) =\frac{\exp \left( z_i / T\right) }{\sum _{j=0}^{k-1} \exp \left( z_j / T\right) }, \end{aligned}$$
(13)

where \(z_i\) is the i-th element of the network's output logits and k is the number of classes. We compute the label probability distributions of the T-Net and S-Net with \(T=t\) \((t>1)\), and define the cross-entropy between the two distributions as the distillation loss,

$$\begin{aligned} \mathcal {L}_D\left( p\left( z_t, T\right) , p\left( z_s, T\right) \right) = - \sum _{i=0}^{k-1} p_i\left( z_{t i}, T=t\right) \log \left( p_i\left( z_{s i}, T=t\right) \right) , \end{aligned}$$
(14)

where \(z_t\) and \(z_s\) are the fully connected layer outputs of T-Net and S-Net, respectively, and \(p_i\) denotes the probability of the i-th label. On the other hand, S-Net needs to learn not only the knowledge from T-Net but also the ground truth. We define the cross-entropy loss between label probability distributions for \(T=1\) and the ground truth as the student loss,

$$\begin{aligned} \mathcal {L}_S\left( y, p\left( z_s, T\right) \right) = - \sum _{i=0}^{k-1}y_i \log \left( p_i\left( z_{s i}, T=1\right) \right) . \end{aligned}$$
(15)

With the above discussion, the optimization objective of fine-tuning the pruned network based on KD is a linear combination of the distillation loss and the student loss,

$$\begin{aligned} \mathcal {L} =\alpha * \mathcal {L}_D\left( p\left( z_t, T\right) , p\left( z_s, T\right) \right) +(1-\alpha ) * \mathcal {L}_S\left( y, p\left( z_s, T\right) \right) , \end{aligned}$$
(16)

where \(\alpha\) is a hyperparameter that controls the weights of the two losses.
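
A minimal sketch of this objective is given below (PyTorch assumed); it combines the softened cross-entropy of Eq. (14) with the standard student loss of Eq. (15) as in Eq. (16). The values T = 2.5 and \(\alpha\) = 0.3 follow the experimental settings reported later, and the placeholder logits stand in for T-Net and S-Net outputs.

```python
import torch
import torch.nn.functional as F

def kd_objective(z_t: torch.Tensor, z_s: torch.Tensor, y: torch.Tensor,
                 T: float = 2.5, alpha: float = 0.3) -> torch.Tensor:
    # Distillation loss (Eq. 14): cross-entropy between softened teacher and student outputs.
    p_teacher = F.softmax(z_t / T, dim=1)
    log_p_student = F.log_softmax(z_s / T, dim=1)
    loss_d = -(p_teacher * log_p_student).sum(dim=1).mean()
    # Student loss (Eq. 15): ordinary cross-entropy against the ground truth (T = 1).
    loss_s = F.cross_entropy(z_s, y)
    return alpha * loss_d + (1 - alpha) * loss_s    # Eq. (16)

z_t, z_s = torch.randn(8, 10), torch.randn(8, 10, requires_grad=True)
loss = kd_objective(z_t, z_s, torch.randint(0, 10, (8,)))
```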

In summary, Algorithm 1 describes the process of automatic filter pruning based on QMIX.

Algorithm 1
figure a

QMIX_FP: Automatic DCNN filter pruning based on QMIX.

Experiments

Experimental settings

To demonstrate the efficiency of our proposed method QMIX_FP, we experiment on the classical deep convolutional neural networks VGG-1639 and AlexNet1. VGG-16 consists of 13 convolutional layers and 3 fully connected layers. The filters are distributed as 64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512. AlexNet consists of 5 convolutional layers and 3 fully connected layers. The filter distribution is 64, 192, 384, 256, 256. We test their performance over CIFAR-10 and CIFAR-100 datasets40 to further evaluate network pruning effectiveness.

In our experimental setting, we adopt widely used metrics to evaluate network pruning effectiveness, including prediction accuracy, the number of filters, Parameters (Params), and Floating-point Operations (FLOPs). Notably, the prediction accuracy is task-specific: top-1 accuracy on CIFAR-10 and top-5 accuracy on CIFAR-100.

For the original pre-trained DCNNs, we set the batch size to 128 and update the network with stochastic gradient descent, training for 200 epochs. The initial learning rate is 0.1, the decay rate is 1e-4, and the momentum is 0.9. For network pruning based on QMIX, the discount factor is 0.99, and \(\lambda\) and \(\mu\) in the reward function are 0.7 and 0.3, respectively. Moreover, the QMIX training optimizer is Adam, the learning rate is 0.01, the replay buffer size is 512, the maximum number of episodes is 500, and the maximum number of time steps per episode is 20. The target network updates its parameters every 5 time steps. For network fine-tuning, we set the temperature T in the distillation loss to 2.5, the coefficient of the distillation loss \(\alpha\) to 0.3, and the number of training epochs to 100.
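
For reference, the settings above can be collected into a single configuration; the grouping and key names below are our own bookkeeping, not taken from the paper's code.

```python
QMIX_FP_CONFIG = {
    "pretrain": {"batch_size": 128, "optimizer": "SGD", "epochs": 200,
                 "lr": 0.1, "decay": 1e-4, "momentum": 0.9},
    "qmix": {"discount": 0.99, "lambda": 0.7, "mu": 0.3, "optimizer": "Adam",
             "lr": 0.01, "buffer_size": 512, "max_episodes": 500,
             "max_steps_per_episode": 20, "target_update_interval": 5},
    "finetune": {"temperature": 2.5, "alpha": 0.3, "epochs": 100},
}
```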

Results and analysis

We first conducted experimental analysis on the pruning result of VGG-16 on the benchmark dataset CIFAR-10. To further verify the robustness of our approach QMIX_FP, we successively replaced the DCNN with AlexNet and replaced the dataset with the more complex CIFAR-100, performing network pruning and analyzing results.

VGG-16 + CIFAR-10

To thoroughly investigate the performance and stability of QMIX_FP under different target pruning rates, we set three expected pruning rates of 50%, 60%, and 70% to observe the compression results at various sparsity levels. Under each pruning rate constraint, we prune VGG-16 with the six filter importance evaluation functions, respectively. The results are shown in Table 1. The “Base” row gives the metrics of the original pre-trained VGG-16. The “Acc. (pruned)” column shows the prediction accuracy of the pruned network. The “Acc. (fine-tuned)” column shows the prediction accuracy after fine-tuning based on knowledge distillation, with the value in square brackets giving the percentage of the original accuracy that is retained. The “Filters” column indicates the number of filters. “M” denotes millions in the FLOPs and Params columns.

Table 1 Pruning results of VGG-16 on CIFAR-10. The best pruning results are bolded, and the worst results are underlined.

Analyzing the prediction accuracy of the pruned network after fine-tuning, we see that at a pruning rate of 50%, the pruned network retains 99.60% of the original VGG-16 accuracy with the evaluation function \(\mathrm{Imp_{IB}}(\varvec{F}_{i,j})\), the highest level. At a pruning rate of 60%, the pruned network retains 98.73% of the original accuracy, averaged over the six importance evaluation methods. When the pruning rate climbs to 70%, the performance is 97.23% of the original. These results fully demonstrate the expected effect of network pruning, substantiating the efficacy of our QMIX-based pruning method and KD-based fine-tuning method. Comparing the six filter evaluation methods, we find that the pruned network obtained by the QMIX algorithm performs better only when evaluated based on the \(\gamma\) factor of the batch normalization layer, whereas it performs slightly less effectively when evaluated based on the APoP of the activation maps. Moreover, at a given pruning rate, the FLOPs and Params of the pruned network tend to be compressed by a larger amount. For example, at a pruning rate of 50%, the Params are compressed by about 77.50% with the evaluation function \(\mathrm{Imp_{Weight}}(\varvec{w}_{i,j})\); at a pruning rate of 60%, the FLOPs are compressed by about 76.65% with the evaluation function \(\mathrm{Imp_{Taylor}}(\varvec{w}_{i,j})\).

AlexNet + CIFAR-10

To verify that the network pruning applies to different DCNNs, we prune the pre-trained AlexNet and evaluate its performance on the benchmark dataset CIFAR-10. AlexNet’s architecture consists of five convolutional layers with a total of 1152 filters, resulting in fewer parameters and computations compared to VGG-16. The original pre-trained AlexNet achieves a prediction accuracy of 0.8585 on CIFAR-10, which is lower than VGG-16 owing to its lower model complexity. The pruning results of AlexNet under QMIX_FP are shown in Table 2, including the changes in prediction accuracy and network structure parameters.

Table 2 Pruning results of AlexNet on CIFAR-10. The best pruning results are bolded, and the worst results are underlined.

We observe that when the pruning rate of AlexNet is 50% or 60%, the pruned networks maintain approximately 99.00% of the original prediction accuracy. At a pruning rate of 70%, the performance is slightly degraded but still maintains over 95.49% of the original accuracy. In comparison, when the feature evaluation function depends on the information bottleneck value of the filter’s feature map, the pruned network suffers a greater loss of accuracy. Furthermore, at pruning rates of 50%, 60%, and 70%, the pruned network’s Params are approximately 31.19%, 23.22%, and 15.46% of the original network, respectively. AlexNet thus achieves a high level of compression in both space and time complexity, satisfying the requirements of practical application scenarios. It can be seen that the QMIX_FP algorithm is not only applicable to high-complexity networks but also effective for small-scale DCNNs, demonstrating its broad applicability to neural network architectures of different sizes.

VGG-16 + CIFAR-100

One metric for evaluating the effectiveness of network pruning is the prediction accuracy achieved on image classification datasets; hence, the dataset is also an influencing factor. We again perform network pruning on VGG-16 but evaluate its performance on the more complex CIFAR-100 dataset, where prediction accuracy is measured as top-5 accuracy. As shown in Table 3, the pruned network obtained with \(\mathrm{Imp_{IB}}(\varvec{F}_{i,j})\) is generally worse than with the other evaluation functions. Fine-tuning the pruned network through knowledge distillation improves the prediction accuracy to 95.51%–99.05% of the original pre-trained model. However, the model exhibits noticeable performance degradation under higher pruning rate constraints. The experimental results show that QMIX_FP remains effective on the complex CIFAR-100 dataset.

Table 3 Pruning results of VGG-16 on CIFAR-100. The best pruning results are bolded, and the worst results are underlined.

To intuitively judge the rationality of exploring the optimal pruning strategy via the multi-agent reinforcement learning algorithm QMIX, we visualize the changes in the number of filters within each convolutional layer of VGG-16 and AlexNet before and after pruning, as shown in Fig. 3. Figure 3a shows the filter distribution of VGG-16 at a pruning rate of 50% with the evaluation function \(\mathrm{Imp_{IB}}(\varvec{F}_{i,j})\). For the pruned network, the number of filters in each convolutional layer is 58, 64, 92, 128, 256, 256, 181, 109, 166, 247, 195, 116, 166. Figure 3b shows the filter distribution of AlexNet at a pruning rate of 50% with the evaluation function \(\mathrm{Imp_{Gradient}}(\varvec{F}_{i,j})\). After pruning, the number of filters in each convolutional layer is 56, 183, 133, 114, and 89. We find that the convolutional layers closer to the classifier have a higher proportion of redundant filters, which is also associated with the larger number of filters they contain.

To demonstrate the advantages of our proposed method QMIX_FP, we compare our pruning results with HRank41, HBFP42, and DDPG_FP19. Additionally, we compare against a baseline employing a randomized evaluation criterion, i.e., the importance score of each filter is a random number. We adopt the same experimental setup as the baseline methods and focus on the pruning results of VGG-16 only.

  • HRank is an iterative filter pruning approach. The main idea is that low-rank feature maps contain less information, and the corresponding filters are considered relatively unimportant and can be pruned.

  • HBFP utilizes network training history for filter pruning iteratively. It prunes one redundant filter from each filter pair, ensuring minimal information loss during network training.

  • DDPG_FP is a filter-level pruning method based on Deep Deterministic Policy Gradient (DDPG), a single-agent reinforcement learning algorithm. The agent explores the pruning rate of one convolutional layer for each time step, and the pruned network is fine-tuned by retraining.

HRank and HBFP are traditional rule-based pruning methods, while DDPG_FP and our method QMIX_FP are automatic RL-based pruning methods. As shown in Table 4, the pruned networks from the RL-based methods exhibit lower model degradation than those from the traditional methods. This is because the reinforcement learning algorithm can explore a larger pruning strategy space and continuously update the strategy according to reward feedback, effectively mitigating human subjectivity. Compared to DDPG_FP, QMIX_FP demonstrates more pronounced advantages on complex datasets, suggesting the robustness of the multi-agent reinforcement learning algorithm. In summary, the drop in prediction accuracy of QMIX_FP is lower than that of the other methods on both the CIFAR-10 and CIFAR-100 datasets, which demonstrates the model’s stability.

Ablation study

We conduct ablation studies to analyze different components of the QMIX_FP further. This includes the QMIX algorithm for exploring optimal pruning strategy and the knowledge distillation technology for fine-tuning the pruned networks. Here, we report the results of VGG-16 on CIFAR-10 with prediction accuracy on the test dataset. Similar experimental conclusions can be identified in other deep convolutional neural networks and datasets.

Firstly, the contribution of the QMIX-based network pruning process is shown in Fig. 4. Our proposed pruning method is labeled “QMIX Prune,” while “Iterative Prune” is a traditional rule-based pruning method. The horizontal axis in the figure represents the given pruning rate, and the vertical axis represents the test accuracy. “Base” is the original prediction accuracy of the pre-trained network. Analyzing the performance of the pruned networks, we find that regardless of the pruning rate or the feature importance evaluation function, the test accuracy of the QMIX-based method is higher than that of the traditional iterative method. Notably, the ordinary iterative method prunes a fixed number of filters, whereas the QMIX-based method tends to prune more: the expected pruning rate only serves as the episode termination criterion, while the actual pruning rate is determined by the per-layer pruning actions. This illustrates that the pruning strategy explored by the QMIX algorithm is more effective, as it prunes more filters while incurring smaller losses in test accuracy.

Fig. 3
figure 3

The changes in the number of filters within each convolutional layer of VGG-16 and AlexNet. In the bar chart, the gray bars represent the original number of filters in each layer, while the colored bars represent the number of filters in each layer of the pruned network.

Table 4 The comparative pruning results for VGG-16, where the Acc.\(\downarrow\) indicates the percentage decrease in prediction accuracy after pruning. The best pruning results are bolded.
Fig. 4
figure 4

The ablation study results on the QMIX algorithm.

Secondly, the contribution of the knowledge distillation-based fine-tuning process is shown in Fig. 5. Setting the temperature parameter T to 0 in Sect. “Fine-tuning” means that the S-Net cannot learn the generalization ability of the T-Net and instead improves model performance only through its own training. To verify the effectiveness of knowledge distillation, we compare our experimental results with the case of \(T=0\). In the figure, “Only S_Net” denotes fine-tuning the pruned network by plain retraining, while “S_Net with KD” denotes our method. The horizontal axis represents the training epoch, and the vertical axis represents the test accuracy. The fine-tuning method based on knowledge distillation converges faster and achieves a test accuracy approximately 1% higher at convergence than the retraining-based method.

Fig. 5
figure 5

The ablation study results on knowledge distillation (KD).

Conclusions

We propose an automatic filter pruning approach grounded in reinforcement learning, distinct from conventional rule-based or single-agent reinforcement learning-based approaches. Our method, termed QMIX_FP, leverages the multi-agent reinforcement learning algorithm QMIX to discover the optimal pruning strategy autonomously. This innovation not only streamlines the pruning process but also enhances the efficiency and effectiveness of DCNNs across various applications.

The efficacy of the proposed method was validated by applying it to VGG-16 and AlexNet on the CIFAR-10 and CIFAR-100 datasets, achieving superior performance across multiple benchmarks. QMIX_FP stands out as a versatile framework for network pruning, adaptable to a wide range of DCNN architectures and datasets. Its flexibility lies in the ease of adjusting the number of agents and state transition functions to align with the specific structural characteristics of the network in question.

The societal impact of QMIX_FP is significant, particularly in resource-constrained environments such as mobile devices and edge computing platforms. By enabling more efficient and compact neural network models, our method contributes to the advancement of technologies that require real-time processing and minimal power consumption. This is particularly relevant for applications in healthcare, autonomous driving, and environmental monitoring, where the deployment of sophisticated AI models on portable or embedded systems can lead to advances.

Furthermore, our work builds on and complements existing research, such as MR-DCAE43, which uses manifold regularization in deep convolutional autoencoders for unauthorized broadcast identification, and a real-time constellation image classification method for wireless communication signals based on the lightweight MobileViT network44. These studies highlight the growing trend of integrating advanced machine-learning techniques with domain-specific knowledge to solve complex real-world problems. In line with this trend, future research directions for QMIX_FP include integrating reinforcement learning with other model compression techniques, such as parameter quantization and low-rank decomposition. In particular, we aim to investigate the optimal quantization precision for different convolutional features, using the QMIX algorithm to address mixed-precision parameter quantization challenges. Through these efforts, we aim to further improve the practical applicability and efficiency of neural network models in various fields.