Abstract
Accurate vulnerability prediction is crucial for identifying potential security risks in software, especially in the context of imbalanced and complex real-world datasets. Traditional methods, such as single-task learning and ensemble approaches, often struggle with these challenges, particularly in detecting rare but critical vulnerabilities. To address this, we propose MTLPT (Multi-Task Learning with Position Encoding and Lightweight Transformer for Vulnerability Prediction), a novel multi-task learning framework that uses custom lightweight Transformer blocks and position encoding layers to capture long-range dependencies and complex patterns in source code. MTLPT improves sensitivity to rare vulnerabilities and incorporates a dynamically weighted loss function to adjust for imbalanced data. Our experiments on real-world vulnerability datasets demonstrate that MTLPT outperforms traditional methods on key performance metrics such as recall, F1-score, AUC, and MCC. Ablation studies further validate the contributions of the lightweight Transformer blocks, position encoding layers, and dynamically weighted loss function, confirming their role in enhancing the model’s predictive accuracy and efficiency.
Introduction
With the rapid development of the Internet, computer software has permeated all aspects of life, greatly facilitating people’s lives but also introducing more potential security risks. According to the “Vulnerability and Threat Trends Report 2023”1 published by Skybox Security, the National Vulnerability Database (NVD) added 25,096 vulnerabilities in 2022, roughly 25% more than the 20,196 new vulnerabilities reported in 2021. Not only is the number of vulnerabilities increasing, the growth rate is also accelerating: the cumulative total is approaching 200,000, and most new vulnerabilities are rated medium or high risk. Beyond the growth in numbers, security vulnerabilities are also becoming more complex and diverse, posing a serious threat to the normal and secure operation of computer systems2.
Software vulnerabilities are a major source of network security threats, and discovering and repairing them in time is key to ensuring software quality and security3,4,5. However, as software reaches an ever wider range of users, requirements grow more complicated, the scale and complexity of software increase, and vulnerability mining and analysis become harder. In recent years, research on vulnerability prediction has made significant progress. Researchers have used various machine learning algorithms, such as Decision Trees6, SVMs7, and neural networks8, to predict potential vulnerabilities from a large number of software attributes. However, existing methods often have difficulty identifying minority-category vulnerabilities in highly unbalanced real-world vulnerability datasets, and traditional deep learning and machine learning methods cannot fully learn vulnerability features9. These minority-category vulnerabilities are few in number, but they are often the most dangerous.
Multi-Task Learning (MTL) is a machine learning method that solves multiple related learning tasks simultaneously. By exploiting the commonalities and differences between tasks, it improves the learning efficiency and prediction accuracy of individual tasks10,11,12. Multi-task learning has been widely studied and applied13 in computer vision, natural language processing, recommendation systems, and other fields. It also has great potential in vulnerability mining, which involves multiple correlated sub-tasks such as vulnerability identification, vulnerability classification, and exploitability determination. Through multi-task learning, knowledge can be shared and transferred between these sub-tasks, improving the effectiveness of vulnerability mining14,15. Vulnerability prediction uses machine learning or data mining methods to build a prediction model from the characteristics or historical data of the software and to evaluate the current status or future trend of vulnerabilities in software components. This helps software developers and maintainers discover and repair vulnerabilities in a timely manner, improving the quality and safety of the software16. Because vulnerability prediction is an important part of vulnerability mining, multi-task learning can also be used to improve vulnerability prediction models17.
However, vulnerability prediction based on multi-task learning also faces challenges, such as how to design a reasonable model structure, how to balance the weights of different tasks, and how to reduce the false positive and false negative rates of the model18. In addition, multi-task learning models often require more parameters and computing resources, which makes it difficult to keep the model lightweight and easy to deploy. Therefore, designing a reasonable model structure, adjusting task weights effectively, and implementing a lightweight vulnerability prediction method based on multi-task learning remain open problems.
To solve these problems, we propose MTLPT: a Multi-Task Learning framework with Position Encoding and Lightweight Transformer for Vulnerability Prediction, based on a dynamic weight allocation strategy. The MTLPT method can achieve lightweight deployment while maintaining high performance, and has excellent performance for unbalanced real-world vulnerability datasets.
We rigorously evaluate MTLPT on a subset of the Draper19 dataset, which comprises 1.27 million samples collected and annotated from real-world code; the subset is obtained by undersampling. We address the following three research questions:
RQ1: How well does our multi-task MTLPT method predict the five most common categories of code vulnerabilities in unbalanced real-world data?
Result: MTLPT shows excellent learning ability by capturing the complex relationships and dependency patterns between the five most common vulnerability categories, which are unevenly distributed in real-world data. Compared with traditional single-task learning models and ensemble learning strategies, MTLPT achieves a 10% to 50% improvement in all performance indicators. In addition, MTLPT reduces the number of model parameters while maintaining high predictive performance, keeping the model lightweight. This not only improves the efficiency of model deployment, but also provides effective technical support for early vulnerability discovery and repair, demonstrating its application potential and practical value in vulnerability prediction.
RQ2: Does the dynamic-weight loss function in the multi-task framework of MTLPT effectively alleviate the imbalance of real-world vulnerability data?
Result: The design of the loss function based on dynamic weights takes into account the importance of different prediction tasks, allowing the MTLPT model to make adjustments on prediction tasks that are easily overlooked due to their small data volume, improving the predictive performance of these tasks. The introduction of this loss function not only alleviates the problem of data imbalance, but also ensures that the MTLPT model maintains excellent performance when performing other prediction tasks.
RQ3: What is the contribution of the components of our proposed MTLPT method to dealing with unbalanced data?
Result: Each component in the MTLPT method not only independently contributes to the performance of the model, but their synergistic effect significantly improves the performance on almost all evaluation indicators, especially when predicting unbalanced real-world vulnerability categories. This paper fully demonstrates the significant effects of the key components of the MTLPT model (multi-task learning framework, loss function \(L_t\) based on dynamic weights, and custom lightweight Transformer block and position encoding layer) in dealing with unbalanced data through ablation experiments.
To summarize, our key contributions are as follows:
(1) MTLPT Framework for Vulnerability Prediction: We propose the MTLPT framework, which integrates multi-task learning with custom lightweight Transformer blocks and position encoding layers, achieving superior performance on imbalanced vulnerability prediction tasks, with up to 50% improvement over traditional methods.
(2) Dynamic Weight Loss Function: We introduce a dynamic weight-based loss function to address the challenge of imbalanced data, enhancing the model’s ability to learn from underrepresented vulnerability categories without sacrificing overall performance.
(3) Synergistic Effect of Model Components: Through extensive ablation experiments, we demonstrate that the combination of multitask learning, dynamic weighting, and custom Transformer blocks significantly improves predictive accuracy and efficiency in real-world vulnerability datasets.
The remainder of this paper is organized as follows. The “Related works” section reviews research on deep learning for vulnerability prediction. The “Method” section introduces the MTLPT algorithm, which is based on a dynamic weight allocation strategy, custom lightweight Transformer blocks, and a position encoding layer. The “Experiment” section presents experiments and analysis on a subset of an unbalanced real-world vulnerability dataset. The fifth section discusses potential problems and challenges. Finally, the “Conclusion” section concludes the paper.
Related works
Vulnerability prediction plays a crucial role in the software development lifecycle. Its purpose is to identify potential security threats early in the development process20, reduce repair costs, and avoid potential risks. Over the past few decades, traditional methods have focused on static and dynamic code analysis to identify potential vulnerabilities. Static analyzers, such as Clang21, can analyze a project without running the program. In contrast, dynamic analyzers identify potential weaknesses by repeatedly running the program with various test inputs on a real or simulated processor. Both kinds of analyzer are rule-based tools, which means that their analytical capabilities are limited by pre-set rules, so they cannot guarantee complete test coverage of the codebase. As the complexity of software increases, these methods suffer from high false positive rates and low detection performance22, and they cannot capture the semantics of the code well. In addition, they are poorly suited to unbalanced real-world vulnerability data, where some categories have very few samples. In recent years, with the rapid development of deep learning and machine learning in various fields, researchers have begun to use these technologies to predict potential vulnerabilities in software23 by learning vulnerability features from historical data. The deep learning and machine learning methods used in vulnerability prediction can be grouped into two main types: (1) Single-Task Learning methods, and (2) a combined category of Ensemble Learning and Multi-Task Learning methods.
Single task learning methods
Single-Task Learning (STL) methods play an important role in software vulnerability prediction. These methods usually use one specific deep learning or machine learning technique to solve one specific prediction task. The basic idea is that focusing on optimizing a single problem can achieve higher performance on it. STL methods cover a variety of algorithms, such as Long Short-Term Memory networks24 (LSTM), Convolutional Neural Networks25 (CNN), and Recurrent Neural Networks26 (RNN); each improves the accuracy of a specific task by focusing on one type of data processing.
For example, LSTM24 uses its internal cell state and gating mechanism to retain long-term dependency information, making it particularly suitable for time series or other sequence data, such as software development logs or code execution paths. These capabilities make LSTM well suited to predicting vulnerabilities that depend on historical data. CNN25 processes images or image-like data structures through its layer structure: convolutional, pooling, and fully connected layers. In software vulnerability prediction, CNN can be applied to structured representations of source code, for example by treating code blocks as image data to extract local features. The convolutional layer slides a window (the convolution kernel) over the input, effectively capturing local dependencies and pattern features such as keywords, operators, and control structures in the code. These capabilities make CNN well suited to data with clear local structure, such as specific vulnerability patterns in code snippets. RNN26 is a class of neural networks designed to process sequence data. Unlike feedforward neural networks, an RNN introduces loops in its architecture, allowing the network to use the output from the previous time step as part of the current input, which makes it particularly suitable for data with strong sequential dependencies. In software vulnerability prediction, RNN can be applied to analyze execution paths or development logs in order to understand and predict the behavior of the code. Because an RNN carries information forward from earlier parts of the input sequence, it is suited to predicting vulnerabilities that depend on historical information.
For example, by analyzing the changes in a function across multiple versions, RNN can help identify vulnerabilities introduced by historical code modifications.
Recent research has also introduced new advancements in STL methods. Liu et al.27 propose an enhanced neural network framework designed to improve vulnerability prediction by focusing on fine-tuning for specific vulnerability types, especially rare ones that are often overlooked by traditional STL methods. While this approach achieves better performance for certain rare vulnerabilities, it remains limited in addressing broader class imbalance and often requires significant model-specific tuning. In another notable development, Zhang et al.28 introduce a semantic-aware STL model that integrates source code structure with higher-level semantic understanding, significantly improving vulnerability prediction accuracy, particularly for complex and interdependent vulnerabilities. However, despite these advancements, these methods do not effectively leverage correlations between different vulnerability types, which limits their capability in comprehensively learning from real-world datasets.
As a result, single task learning methods often struggle with highly imbalanced data distributions. These methods cannot effectively learn the features of minority classes, leading to insufficient identification of rare but dangerous vulnerability types. Although single task models may show high accuracy on majority classes, their performance on critical vulnerability types that pose significant risks remains suboptimal. Furthermore, the inability to utilize correlations between different vulnerabilities—such as the relationship between buffer overflow and integer overflow vulnerabilities—further limits their effectiveness. These limitations underscore the need for a more comprehensive approach, such as Multi-task Learning (MTL), which is better equipped to address these challenges through shared learning across tasks.
Ensemble learning methods and multi-task learning methods
Ensemble learning is a machine learning strategy that improves the accuracy and stability of prediction, classification, or regression tasks by combining multiple learning algorithms. The basic idea is that a group of “experts” (i.e., different models) collectively makes decisions that are usually better than those of any single expert. Ensemble learning methods fall mainly into two categories: Bagging29 (Bootstrap Aggregating), such as Random Forest30 (RF), and Boosting31, such as AdaBoost32. Both improve overall accuracy by combining the predictions of multiple simple models, for example by having multiple decision trees vote on whether a code snippet may contain vulnerabilities31. Ensemble learning can improve the stability and accuracy of prediction in some cases, but under severely unbalanced data distributions these methods are still dominated by majority-class samples and tend to ignore the minority classes, i.e., the vulnerability types that may cause serious security risks33.
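The voting idea mentioned above can be illustrated with a minimal sketch. The three rule functions and the feature names (`calls_strcpy`, `unchecked_index`, `frees_twice`) are hypothetical stand-ins for trained decision trees; they are not taken from the cited works.

```python
from collections import Counter

# Hypothetical one-rule "experts"; each returns 1 (vulnerable) or 0 (not vulnerable).
def rule_unsafe_copy(features):
    return 1 if features.get("calls_strcpy", False) else 0

def rule_unchecked_index(features):
    return 1 if features.get("unchecked_index", False) else 0

def rule_double_free(features):
    return 1 if features.get("frees_twice", False) else 0

def majority_vote(classifiers, features):
    """Return the label predicted by most classifiers for one code snippet."""
    votes = Counter(clf(features) for clf in classifiers)
    return votes.most_common(1)[0][0]

# An odd number of binary voters avoids ties.
ENSEMBLE = [rule_unsafe_copy, rule_unchecked_index, rule_double_free]
```

With this toy ensemble, a snippet is flagged only when at least two of the three rules fire, which also shows why a rare pattern recognized by a single member can be outvoted.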
Multi-task learning (MTL) is a machine learning approach that enhances learning efficiency and model generalization by sharing knowledge across multiple related tasks. Unlike traditional single-task learning methods, MTL jointly optimizes the loss functions of multiple tasks, handling them simultaneously and effectively capturing inter-task correlations. This approach significantly improves model performance, especially when tasks share common information. In recent years, MTL methods have made significant progress in various domains. For example, Y. Xin et al. proposed the MmAP34 method for cross-domain multi-task learning, which aligns text and visual modalities to enhance the model’s generalization ability across different tasks. Z. Chen et al. introduced MoE35, which flexibly optimizes the matching between tasks and experts, maximizing the positive relationships between tasks. This model achieved optimal results in visual multi-task learning. Additionally, J. Bonato et al. introduced the MIND36 algorithm, a multi-task network distillation method for class-incremental learning. By using parameter isolation and distillation techniques, MIND improved learning performance on new tasks while preserving memory of old tasks. Compared to these existing methods, our work employs a more flexible dynamic weight allocation strategy to address the class imbalance issue in vulnerability data. Furthermore, by customizing lightweight Transformer blocks and position encoding layers, we enhance the model’s ability to learn source code features, achieving lightweight deployment of the model.
Recent research in malware detection and software vulnerability prediction has explored a variety of techniques to address the challenges posed by unbalanced and unknown classes. While our work shares some concepts with several notable studies in the field, such as DPNSA37, LTW38, and DMMal39, key differences set our approach apart.
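As a rough illustration of joint optimization, the sketch below combines per-task binary cross-entropy losses into one weighted sum. The inverse-frequency weighting is an assumption chosen for illustration only; it is not the dynamic weight loss \(L_t\) of MTLPT, and the task names in the usage note are hypothetical.

```python
import math

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over the samples of one task."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # clip to keep log() finite
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)

def inverse_frequency_weights(labels_per_task):
    """Assumed scheme: weight each task by 1 / positive rate, normalised to sum to 1."""
    raw = {
        task: 1.0 / max(sum(labels) / len(labels), 1e-7)
        for task, labels in labels_per_task.items()
    }
    total = sum(raw.values())
    return {task: weight / total for task, weight in raw.items()}

def joint_loss(labels_per_task, probs_per_task, weights):
    """Weighted sum of the per-task losses, minimised jointly in MTL."""
    return sum(
        weights[task] * binary_cross_entropy(labels_per_task[task], probs_per_task[task])
        for task in labels_per_task
    )
```

For instance, with labels `{"common": [1, 1, 0, 0], "rare": [1, 0, 0, 0]}` the rare task receives twice the weight of the common one, so its errors contribute more to the joint loss.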
Chai et al.37 introduce the DPNSA method, which focuses on few-shot malware detection using dynamic convolution and sample adaptation techniques. This method aims to improve the model’s ability to detect unknown malware by dynamically adjusting its parameters based on the input samples. While this method introduces innovations in handling small sample sizes, it is specifically tailored to malware detection rather than software vulnerability prediction. In contrast, our proposed MTLPT method focuses on multi-task learning to address the issue of unbalanced vulnerability data in software, where multiple vulnerability types are learned simultaneously. Moreover, our approach incorporates a dynamic weight allocation strategy that adjusts for imbalanced data, improving the model’s ability to predict minority vulnerabilities—a challenge that DPNSA does not directly address.
Liu et al.38 introduce the LTW approach, a prominent technique aimed at mitigating negative transfer in multi-task learning. This method dynamically adjusts the weights of different tasks during training to control the influence of individual tasks, thus improving the overall performance of the multi-task model. The LTW framework has been shown to effectively reduce the impact of tasks that may negatively affect others, particularly in computational chemistry with 128 tasks. In contrast to LTW, which adjusts task weights based on task performance, our approach incorporates a dynamic weight allocation strategy tailored to handle imbalanced data in software vulnerability prediction. While LTW focuses on mitigating negative transfer by adjusting task weights, our MTLPT model specifically addresses the issue of class imbalance in real-world vulnerability data, allowing the model to focus more on rare but critical vulnerabilities while reducing the model’s complexity through lightweight design. Furthermore, our method integrates custom Transformer blocks and position encoding layers, which are not part of LTW’s framework, to enhance the model’s ability to learn the structural and semantic features of source code.
Chai et al.39 propose the DMMal method, which introduces improvements at both the data level (through multi-channel malware image generation) and the model level (through sharpness-aware minimization for better generalization). Although these techniques provide valuable insights for few-shot learning, DMMal’s focus is on malware classification using image data derived from malware samples. Our work, however, addresses a different challenge: software vulnerability prediction from source code data. We focus on multi-task learning, leveraging custom lightweight Transformer blocks and position encoding layers to capture the structural and semantic features of source code. Unlike DMMal, which synthesizes new images for training, our approach works directly with raw source code data and adjusts the learning process to handle multiple vulnerability types, balancing the model’s focus on rare vulnerabilities and reducing parameter complexity for better deployment.
In summary, existing methods for software vulnerability prediction, including STL and Ensemble Learning40, have made significant contributions. However, they face limitations in handling complex source code structures, imbalanced data, and the inability to capture relationships between different vulnerability types. Recent MTL approaches, such as DPNSA37, LTW38, and DMMal39, introduce improvements to tackle these issues, but they focus primarily on specific challenges like few-shot learning or mitigating negative transfer. Our proposed MTLPT method differentiates itself by employing a Multi-Task Learning framework to jointly predict multiple vulnerability types, addressing the imbalance in real-world data through a dynamic weight allocation strategy. Moreover, MTLPT incorporates custom lightweight Transformer blocks and position encoding layers, which allow the model to effectively capture the structural and semantic features of source code while reducing model complexity for lightweight deployment. This approach not only enhances the model’s generalization but also enables efficient handling of multiple vulnerabilities simultaneously, offering a more robust solution for real-world vulnerability prediction tasks.
Method
We propose MTLPT, a multi-task learning vulnerability prediction technique based on position encoding and lightweight Transformer blocks, as shown in Fig. 1. It predicts vulnerabilities in given source code across the five most common vulnerability categories in the real-world dataset. The technique combines a custom position encoding layer, which enhances the understanding of the structure and semantics of program code, with a lightweight Transformer block, which captures long-distance dependencies while reducing model parameters to improve efficiency and performance. Together, these components strengthen the extraction of vulnerability features from real-world source code, improve generalization, address the very large differences in sample size between vulnerability categories, alleviate the impact of unbalanced data, improve prediction accuracy, and reduce model parameters for lightweight deployment, all of which supports early discovery and repair of real-world vulnerabilities.
Next, we will explain in detail our proposed new method, focusing on the model architecture, the backbone networks, and the MTLPT algorithm, respectively.
Model architecture
Traditional vulnerability prediction models are mostly STL models23, which typically focus on a specific task, leading to excellent performance on that task but an inability to generalize the learned features to other tasks. When multiple related tasks are present, STL models require training independent models for each task, resulting in redundant use of computational resources, increased storage needs, and low training and data utilization efficiency. Vulnerabilities are often related to multiple factors in the source code and share certain similar features across different vulnerabilities, making it difficult for STL models to fully learn vulnerability features, leading to low accuracy and high false positive rates. In contrast, MTL models can share parameters between multiple tasks to learn a general feature representation. They can utilize data from related tasks to assist tasks with less data, improving data utilization efficiency and offering high flexibility and scalability. Nevertheless, determining the weight of each task and the total loss function is challenging for MTL models. Inappropriate weight distribution can lead to neglect of certain tasks, thereby reducing performance4. The design and tuning of multi-task learning models are also crucial; they must consider the interactions between tasks. Applying a deeper or wider network architecture may better learn effective representations but also increases model complexity, parameter count, and reduces model efficiency. Vulnerabilities are data with complex features involving multiple dimensions of source code, including syntactic structure, semantic information, control flow, data flow, and developers’ coding habits. They are characterized by semantic complexity, high dimensionality, data sparsity, and imbalance32. Vulnerabilities in code also have complex dependencies; they may not only depend on local code snippets but also involve complex dependencies across modules and libraries41. 
To accurately identify and locate vulnerabilities, a model needs to understand these complex dependencies, semantic features, and contextual information, and to mitigate data sparsity and imbalance. As shown in Fig. 1, our proposed MTLPT model is a multi-task learning model that receives a sequence of source code tokens through the input layer. The embedding layer and the custom position encoding layer then transform the discrete tokens into continuous embedding vectors and add position information, so that the model understands not only the meaning of code tokens but also their position in the source code. Next, custom lightweight Transformer blocks apply multi-head attention, which lets the model attend to multiple positions simultaneously and capture relationships between different parts of the code; this is important for understanding complex interactions such as variable usage and function calls. Layer normalization and residual connections within the blocks accelerate training and improve model performance, and feedforward networks provide additional capacity to capture complex code patterns. A one-dimensional convolutional layer then extracts local features, such as specific vulnerable code patterns, and a global average pooling layer reduces the output dimensionality and the parameter count. Finally, a task-specific output layer for each task predicts one of the top five vulnerability categories in the real-world vulnerability dataset.
Our proposed MTLPT model is designed within a multi-task learning framework to handle data with semantic complexity, high dimensionality, sparsity, and complex dependencies, and it mitigates imbalance in vulnerability data through knowledge shared between tasks, while offering flexibility and scalability. Below we explain in detail the embedding layer and positional encoding layer, the custom lightweight Transformer blocks, and the CNN layer.
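The last stages of the pipeline described above (one-dimensional convolution, global average pooling, and one output layer per task) can be sketched in NumPy as follows. This is only an illustrative forward pass under stated assumptions: the task names, kernel width, ReLU after the convolution, and sigmoid outputs are all hypothetical choices, not details confirmed by the text.

```python
import numpy as np

# Hypothetical placeholders for the five vulnerability categories.
TASKS = ("task_1", "task_2", "task_3", "task_4", "task_5")

def conv1d_relu(x, kernel, bias):
    """'Valid' 1D convolution over the token axis followed by ReLU.

    x: (seq_len, channels_in); kernel: (width, channels_in, channels_out).
    """
    width = kernel.shape[0]
    windows = np.stack([x[i:i + width] for i in range(x.shape[0] - width + 1)])
    out = np.einsum("nwc,wco->no", windows, kernel) + bias
    return np.maximum(0.0, out)

def global_average_pool(features):
    """Average over the sequence axis: (n_windows, channels) -> (channels,)."""
    return features.mean(axis=0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def predict_all_tasks(shared_features, conv_kernel, conv_bias, heads):
    """Shared trunk output -> conv -> pooling -> one probability per task."""
    local = conv1d_relu(shared_features, conv_kernel, conv_bias)
    pooled = global_average_pool(local)
    return {task: float(sigmoid(pooled @ w + b)) for task, (w, b) in heads.items()}
```

Because the pooled vector is shared, each additional task costs only one small weight vector, which is one way a multi-task design can stay compact.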
Embedding layer and positional encoding layer
The source code is highly structured, and its syntactic and structural characteristics determine the execution flow42. As shown in Fig. 2, a simple C language code segment, including the declaration and initialization of variable x (line 2), conditional judgment (line 3), and function call based on the conditional judgment (line 4), forms a logical block. The security of this block may be affected by the return value of the getUserInput function. Without considering position encoding, the model would struggle to understand the direct dependency between lines 2 and 3 due to the self-attention mechanism considering all tokens equally across the entire sequence. By incorporating a custom position encoding layer, the model can recognize that line 2 directly precedes line 3, thereby understanding the role of variable x in the conditional judgment and capturing syntactic and structural dependencies. Furthermore, in Fig. 2’s C language code segment, the semantics of line 4 are closely related to its position within the code; it resides within a processing logic for \(x<10\), which depends on the return value of getUserInput function from line 2. With the custom position encoding layer, the model can better comprehend the logical relationship between line 4 and preceding code without confusing it with line 7, which is logically distant. Our proposed custom position encoding provides a unique position code PE for each position \(\theta\) in the source code token sequence, where each position code PE is generated by combining sine values \(H_z\) and cosine values \(H_y\) for each position \(\theta\):
where \(d_i\) is the dimension index at position \(\theta\) in the embedding vector, even indices use formula (1) to calculate sine values, and odd indices use formula (2) to calculate cosine values, with \(h_m\) being the total dimension of embedding vectors. After obtaining sine values \(H_z\) and cosine values \(H_y\) for each position \(\theta\), we calculate an interleaved vector composed of sine and cosine values with a dimension of 2 \(\times h_m\). Assuming that sine values calculated for each position \(\theta\) are ( \(H_{z_0}\), \(H_{z_1}\),…,\(H_{z_{2d_i-1}}\) ), and cosine values are ( \(H_{y_{2d_i}}\),…, \(H_{y_{2h_m-1}}\) ), we can obtain a complete position encoding vector \(PE_\theta\):
For each position index \(\theta\), the MTLPT model adds the position encoding vector \(PE_\theta\) to the corresponding embedding vector to obtain an embedding vector E that carries position information. This lets the MTLPT model use position information for richer input representations, helping it learn patterns and dependencies within source code token sequences and improving performance on sequential tasks.
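A minimal NumPy sketch of such an encoding is given below. Since the formulas are not reproduced in this text, the sketch assumes the standard sinusoidal form of Vaswani et al., with sine on even dimension indices, cosine on odd ones, and a base of 10000; it also keeps the encoding at width \(h_m\) so that it can be added elementwise to the embedding. The function names are illustrative.

```python
import numpy as np

def sinusoidal_position_encoding(seq_len: int, h_m: int) -> np.ndarray:
    """Return a (seq_len, h_m) matrix whose row theta is PE_theta."""
    positions = np.arange(seq_len)[:, np.newaxis]   # theta
    dims = np.arange(h_m)[np.newaxis, :]            # d_i
    # Assumed frequency schedule: each (even, odd) pair shares one frequency.
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / h_m)
    angles = positions * angle_rates
    pe = np.zeros((seq_len, h_m))
    pe[:, 0::2] = np.sin(angles[:, 0::2])           # even indices: sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])           # odd indices: cosine
    return pe

def add_position_information(token_embeddings: np.ndarray) -> np.ndarray:
    """E = embedding + PE, computed for every position in the sequence."""
    seq_len, h_m = token_embeddings.shape
    return token_embeddings + sinusoidal_position_encoding(seq_len, h_m)
```

Each position receives a distinct vector, so two identical tokens on different lines of the code yield different inputs to the attention layers.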
Custom lightweight transformer block
Since its introduction by Vaswani43 et al. in 2017, the Transformer model has become the mainstream model for processing sequential data, especially in the field of Natural Language Processing (NLP). It is characterized by high computational efficiency, the ability to capture relationships between any two elements, learning complex representations, and scalability. However, traditional Transformer models often have a large number of parameters and consume significant computational resources, which can become a bottleneck when dealing with large-scale or complex datasets such as real-world vulnerability data. To address these challenges and fully leverage the advantages of MTL models, we have defined a lightweight Transformer block specifically designed for source code vulnerability prediction. Our custom lightweight Transformer block aims to reduce the number of parameters and improve computational efficiency while retaining the core advantages and performance of Transformers. It includes a multi-head self-attention (MHA) layer with four heads to reduce computational complexity; a Dropout layer that randomly discards part of the features during training to prevent the MTLPT model from over-relying on specific vulnerability features in the training data, reducing overfitting; a Layer Normalization (LN) layer to maintain internal covariate stability and accelerate MTLPT model convergence; a feedforward network (FFN) consisting of two fully connected layers, a Dropout layer, and ReLU activation function to ensure that the same fully connected neural network is applied at each position \(\theta\), further enhancing the representational capabilities of the MTLPT model; and a residual connection (Add & Norm) to help alleviate gradient vanishing or explosion problems in deep networks, allowing information to flow through the network and ensuring effective learning within the MTLPT model. 
The output \(O_t\) of our designed lightweight Transformer block can be calculated using the following formula:
where \(O_{ffn}\) is the output of the FFN part after LN processing, which can be calculated using:
where FN is the FFN computation function, and \(O_{attn}\) is the output after attention mechanism and first residual connection and layer normalization:
where Q, K, and V are different representations of the same input obtained by applying different weight matrices to input data, MultiHead is the output of multi-head attention mechanism. MultiHead can be calculated as follows:
where W are linear projection matrices that map inputs to different representation spaces, and \(W^O\) is a matrix used for mapping after output concatenation.
By introducing our custom lightweight Transformer block, the MTLPT model can efficiently learn complex features related to source code vulnerabilities while maintaining high efficiency.
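The equations for this block are not reproduced in this text. As a reading aid, the multi-head attention step can be written in the standard form introduced with the Transformer43, here with four heads. The symbols X (block input) and \(d_k\) (per-head key dimension) are introduced for illustration only, Dropout is omitted for brevity, and the residual form shown for \(O_{ffn}\) is an assumption based on the description above, not a confirmed reproduction of the original equations:

\[
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(QW_i^Q)(KW_i^K)^{\top}}{\sqrt{d_k}}\right) VW_i^V, \qquad
\mathrm{MultiHead}(Q,K,V) = \mathrm{Concat}(\mathrm{head}_1,\ldots,\mathrm{head}_4)\,W^O
\]

\[
O_{attn} = \mathrm{LN}\big(X + \mathrm{MultiHead}(Q,K,V)\big), \qquad
O_{ffn} = \mathrm{LN}\big(O_{attn} + FN(O_{attn})\big)
\]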
CNN
CNN25 is a significant model in the field of deep learning. By emulating the working principles of the human visual system, it can automatically and effectively learn complex feature representations from vast amounts of data and is widely used in various visual tasks. Although CNNs were initially designed for image data processing, they have also been successfully applied to text and language processing tasks, such as sentiment analysis, text classification, and language translation. CNNs are capable of capturing local dependencies and structural features within sentences or documents. Our research indicates their suitability for highly structured sequential data tasks like vulnerability prediction, and some studies have already reported results with this approach8. We place a one-dimensional CNN layer with 256 filters and a kernel size of 3 after the lightweight Transformer block to capture local patterns from the Transformer layer’s output, enhancing the model’s ability to recognize structural patterns in source code, such as indentation patterns and bracket matching. This approach, which considers local neighborhoods, strengthens MTLPT’s perception of context and balances model performance and efficiency. Furthermore, by adding a global average pooling layer after the CNN layer, we reduce data dimensions and computational load. This design enables the MTLPT model not only to learn and capture local and global patterns related to vulnerabilities in source code but also to improve efficiency and effectiveness through shared underlying feature representations. After processing through the one-dimensional CNN layer, the MTLPT model obtains a feature map rich in local and global information. Finally, decisions for tasks ( \(T_1,...,T_v\))(v=5) are made through five dense layers and a softmax layer, resulting in vulnerability prediction outcomes.
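To make the convolution and pooling step concrete, the following sketch applies a single one-dimensional filter followed by global average pooling in plain Python. It is an illustration only: the function names are hypothetical, and a filter spanning three consecutive positions, ReLU activation, and "valid" padding are assumptions. The actual model would use 256 learned filters in a deep learning framework.

```python
def conv1d_relu(sequence, kernel, bias=0.0):
    """Slide one filter over a sequence of feature vectors ('valid' padding).

    sequence: list of n vectors of length d (the Transformer block output).
    kernel:   list of k vectors of length d (one filter covering k positions).
    Returns n - k + 1 activations after ReLU (ReLU is an assumption here).
    """
    k = len(kernel)
    outputs = []
    for start in range(len(sequence) - k + 1):
        total = bias
        for offset in range(k):
            total += sum(x * w for x, w in
                         zip(sequence[start + offset], kernel[offset]))
        outputs.append(max(0.0, total))  # ReLU
    return outputs

def global_average_pool(activations):
    """Collapse the position axis of one filter's output to a single value."""
    return sum(activations) / len(activations)

# Toy input: 5 positions with 2 features each, and one filter of width 3.
seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [2.0, 1.0]]
filt = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
feature_map = conv1d_relu(seq, filt)
pooled = global_average_pool(feature_map)
```

With 256 filters, the pooling step would yield a 256-dimensional vector regardless of sequence length, which is what reduces the data dimensions before the dense layers.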
Assuming the sequence length is n and input dimension is d, since the multi-head attention mechanism within our custom lightweight Transformer block is the primary contributor to time complexity in our entire MTLPT model, we only need to consider the complexity of the Transformer block, which is \(O(n^2d)\), representing the time complexity of our entire model.
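As a rough worked example of this estimate, the snippet below counts the multiply-add operations needed to form the attention score matrix for the sequence length (500) and embedding dimension (128) reported in the parameter settings. The helper name is hypothetical, and constant factors, the number of heads, and the remaining layers are ignored.

```python
def attention_score_ops(n, d):
    """Multiply-adds needed to compute the n x n score matrix Q K^T."""
    return n * n * d

ops = attention_score_ops(n=500, d=128)
print(f"{ops:,} multiply-adds")  # 500 * 500 * 128

# Doubling the sequence length quadruples the cost, as O(n^2 d) predicts.
ratio = attention_score_ops(1000, 128) / ops
```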
MTLPT algorithm
The MTLPT algorithm we propose combines a multi-task learning framework with a custom position encoding layer, a lightweight Transformer block, and a CNN to learn the structure and semantics of program code from real-world vulnerability data. It aims to improve generalization, to reduce the number of model parameters for better efficiency, and to alleviate the extreme class imbalance found in real-world vulnerability datasets.
Multi-task learning is a machine learning strategy that improves generalization by training a single model on multiple related tasks simultaneously. Sharing parameters between tasks forces the model to learn general cross-task features. Tasks with little data benefit in particular from tasks with larger data volumes, which helps the model learn features from minority class samples4. Fig. 1 shows a schematic of the multi-task learning algorithm of MTLPT. We divided the five most common vulnerability types in the real-world unbalanced dataset Draper19 (CWE-119, CWE-120, CWE-469, CWE-476, and CWE-other) into tasks (\(T_1\),...,\(T_v\))(v=5). For each task \(T_v\) there is a dataset \(D_{vul}\) = {\(D_{train}\), \(D_{validate}\), \(D_{test}\)}, consisting of the training, validation, and test datasets. Each task obtains its parameters \(\theta _v\) (distinct from the position index \(\theta\) used earlier) by learning jointly on the preprocessed \(D_{vul}\), and the parameters \(\theta _v\) learned for task \(T_v\) are influenced by the other tasks.
However, different tasks may interact negatively. When their optimal representations differ greatly, model performance can drop significantly. The choice of task weights also has a large impact on an MTL model: improper weights may bias the model towards some tasks at the expense of others, so that the neglected tasks fail to learn their features well, which reduces accuracy and generalization ability. In response to these challenges, we defined a loss function based on dynamic weight adjustment and a lightweight Transformer block based on the attention mechanism. The loss function dynamically adjusts the weight of each task in the total loss according to the rate of change of that task’s loss, and the multi-head attention mechanism lets the model automatically learn the correlation between tasks, reducing interference between them. Assuming the number of tasks is N, the loss function \(L_t\) based on dynamic weights is defined as follows:
where \(L_i\) is the loss function for each task, and the corresponding weight is \(\delta _i\). Assuming the performance indicator of each task at the end of each iteration is \(M_i\), the definition of the corresponding weight \(\delta _i\) of each task’s loss function is as follows:
where S is a parameter that controls the smoothness of the weight distribution (S \(\in (0,5]\)). The larger S is, the smoother the weight distribution; as S approaches 0, a single task receives the vast majority of the weight.
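The weighting equations are not reproduced in this text, so the sketch below shows only one form that matches the described behavior: a softmax over the per-task indicators \(M_i\) with S acting as a temperature, and a total loss equal to the weighted sum of the task losses. The softmax form, the interpretation of \(M_i\), and the function names are all assumptions for illustration, not the authors’ definition.

```python
import math

def dynamic_weights(indicators, S=1.0):
    """Softmax of M_i / S (assumed form).

    A larger S flattens the weights; an S near 0 concentrates almost all
    weight on the task with the largest indicator.
    """
    if not 0 < S <= 5:
        raise ValueError("S must lie in (0, 5]")
    scaled = [m / S for m in indicators]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def total_loss(task_losses, weights):
    """L_t as the weighted sum of the per-task losses L_i."""
    return sum(w * l for w, l in zip(weights, task_losses))

M = [0.2, 0.4, 0.9, 0.3, 0.5]  # hypothetical per-task indicators
smooth = dynamic_weights(M, S=5.0)   # nearly uniform weights
sharp = dynamic_weights(M, S=0.05)   # one task dominates
```

In a training loop, the weights would be recomputed at the end of each iteration from the latest indicators and then applied to the next iteration’s task losses.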
The MTLPT algorithm has been tuned to improve the accuracy of deep learning networks in real-world vulnerability prediction tasks. In the data preparation stage, the algorithm first loads the training, validation, and test sets from the specified data file. These data then go through multiple preprocessing steps to fit the input requirements of the MTL model: data cleaning, data cropping, tokenization, creation of sequence matrices, sorting, conversion of category labels to one-hot encoding, and data reconstruction. In the model architecture stage, a position encoding layer and a lightweight Transformer block are added, along with convolutional and pooling layers, to improve the ability to capture dependencies, reduce model parameters, and prevent overfitting. In the compilation stage, the dynamic-weight loss function \(L_t\) and the Adam optimizer are used, which further improves model performance.
In the training phase, the MTLPT algorithm uses the training data and labels together with validation data evaluated during training. Combined with a custom early stopping strategy, this validates the quality and reliability of the model as training proceeds and improves training efficiency. In the evaluation phase, the model is tested on the test dataset to assess the generalization ability of MTLPT on real-world datasets and to measure key performance indicators such as F1-score and Recall. Together, these stages cover data processing, model building, and performance evaluation.
Our implementation is summarized in the pseudocode given in Algorithm .
Experiment
In the experimental section, we conducted hyperparameter tuning for the proposed MTLPT model, performed comparative experiments with commonly used STL models and state-of-the-art approaches in vulnerability prediction, and carried out ablation studies in the end.
Experimental environment
The proposed framework was implemented in Python 3.8, and the experiments were conducted on the Microsoft Windows 10 operating system. All experiments were run on a CPU with GPU acceleration. The specific parameters are displayed in Table 1.
Dataset
The publicly available Draper dataset, compiled by Russell et al.19, accounts for the complexity and diversity of programs by analyzing software packages at the function level, a granularity that captures the overall flow of each subprogram. It is a large collection of real function-level code samples selected from millions of C and C++ functions compiled from the SATE IV Juliet Test Suite, Debian Linux distributions, and public Git repositories on GitHub. The dataset contains 1.27 million entries, the large majority of which are non-vulnerable, with each CWE type accounting for less than 10% of the total. Addressing data imbalance is therefore a priority for Draper. In this experiment, we undersampled the original dataset: we randomly selected 10% of the data without CWE vulnerabilities and recombined it with all of the original data containing CWE vulnerabilities to form the undersampled Draper subset. The distribution of data after sampling is shown in Table 2. After sampling, the proportion of each vulnerability category in terms of total data volume increased: CWE-120 by approximately 20%, CWE-other by approximately 15%, CWE-119 by approximately 10%, CWE-476 by approximately 5%, and CWE-469 by approximately 1%.
As demonstrated in the scatter plot distribution of CWE category numbers in the original Draper dataset (a) and the subset after undersampling (b) shown in Fig. 3, it is evident that before sampling, the data of CWE categories were extremely imbalanced with very low proportions across all categories. After our undersampling treatment, the situation of data imbalance was greatly alleviated, but the issue still persists. For instance, the proportion of CWE-469 vulnerability category data remains very low, necessitating further measures to mitigate this imbalance. In this paper, we use the subset after undersampling divided into training dataset (80%), validation dataset (10%), and test dataset (10%) to train and evaluate the model.
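The undersampling and splitting procedure described in this section can be sketched as follows. The record layout (a dict with a `labels` list of five CWE flags), the function names, and the fixed seed are assumptions made for illustration; the paper does not specify how the random selection or the split was implemented.

```python
import random

def undersample(records, keep_fraction=0.10, seed=42):
    """Keep every record with at least one CWE label and a random
    keep_fraction of the records with none."""
    rng = random.Random(seed)
    vulnerable = [r for r in records if any(r["labels"])]
    clean = [r for r in records if not any(r["labels"])]
    kept = rng.sample(clean, int(len(clean) * keep_fraction))
    subset = vulnerable + kept
    rng.shuffle(subset)
    return subset

def split_80_10_10(subset):
    """Split into training (80%), validation (10%), and test (10%) parts."""
    n_train = int(len(subset) * 0.8)
    n_val = int(len(subset) * 0.1)
    return (subset[:n_train],
            subset[n_train:n_train + n_val],
            subset[n_train + n_val:])

# Toy data: 100 vulnerable functions and 900 clean ones, five CWE flags each.
data = [{"code": f"f{i}", "labels": [i < 100, False, False, False, False]}
        for i in range(1000)]
subset = undersample(data)
train_set, val_set, test_set = split_80_10_10(subset)
```

In this toy example the share of vulnerable functions rises from 10% of the data to more than half of the subset, which mirrors the effect of undersampling on the class proportions, though not the actual figures in Table 2.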
Data preprocessing
In the data preprocessing phase, we adopted a series of refined data preprocessing steps to prepare and optimize the dataset, ensuring that the MTLPT model can effectively learn and predict source code vulnerabilities. This process is divided into three main stages: data cleaning, tokenization and serialization of text data, and preparation of target variables. The following will detail these steps and their application in data preprocessing.
First, the dataset was cleaned by removing redundant data such as duplicate functions and garbled code, and the non-numerical logical data in the training, validation, and test datasets were transformed into numerical data that the model can handle, improving the consistency and processability of the data. Then, the source code was tokenized by word, that is, converted into a series of tokens from which numerical model inputs are constructed, and a vocabulary of maximum size \(L_{max}\), sorted by word frequency, was created from all source code texts in the training dataset. Afterwards, each dataset (training, validation, and test) was transformed into an equal-length input matrix \((M, I_{max})\) according to its tokens, where M is the number of samples and \(I_{max}\) is the fixed sequence length, ensuring consistent model input. Finally, the category columns in the dataset were converted into one-hot encoding format, with \(N_c\) categories. Each category had a clear target vector representation \(V_1,...,V_c\) (c=5), further enhancing the expressive power of the data and laying the foundation for subsequent model training and evaluation.
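The tokenization, vocabulary, padding, and one-hot steps above can be illustrated with a short standard-library sketch. The regular expression, the reserved ids (0 for padding, 1 for unknown tokens), and all function names are assumptions; the paper does not state which tokenizer or id scheme was used.

```python
import re
from collections import Counter

def tokenize(code):
    """Split source text into identifier, number, and symbol tokens."""
    return re.findall(r"[A-Za-z_]\w*|\d+|\S", code)

def build_vocab(train_texts, max_size):
    """Map the most frequent training tokens to ids starting at 2.

    Id 0 is reserved for padding and id 1 for unknown tokens.
    """
    counts = Counter(tok for text in train_texts for tok in tokenize(text))
    most_common = counts.most_common(max_size - 2)
    return {tok: idx + 2 for idx, (tok, _) in enumerate(most_common)}

def to_matrix(texts, vocab, length):
    """Encode each text as exactly `length` ids, truncating or zero-padding."""
    rows = []
    for text in texts:
        ids = [vocab.get(tok, 1) for tok in tokenize(text)][:length]
        rows.append(ids + [0] * (length - len(ids)))
    return rows

def one_hot(label, num_classes=2):
    """Turn an integer class label into a one-hot target vector."""
    return [1 if i == label else 0 for i in range(num_classes)]

train_texts = ["int x = getUserInput();", "if (x < 10) { process(x); }"]
vocab = build_vocab(train_texts, max_size=10000)
matrix = to_matrix(train_texts, vocab, length=12)
```

Building the vocabulary from the training texts only, as the paragraph describes, keeps information from the validation and test sets out of the input encoding.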
Parameter settings
For our MTLPT method, we chose 10,000 as the vocabulary size and fixed the length of the input samples at 500. We used the Adam optimizer with a learning rate of 0.001 for parameter updates. In terms of the MTLPT model architecture, we used a custom lightweight Transformer block with a Dropout rate of 0.5, a 4-head multi-head attention mechanism, an internal feed-forward network of dimension 512, an embedding layer and position encoding layer of dimension 128, and a one-dimensional convolutional layer with 256 filters and a kernel size of 3. For the hyperparameters of the baseline methods, we followed the best settings specified by the original authors.
To evaluate the effectiveness of the proposed MTLPT model, we compared it against both traditional machine learning methods and state-of-the-art approaches for vulnerability prediction. Random Forest (RF), a widely used ensemble learning method, was implemented with 100 estimators to ensure robust performance, serving as a benchmark for non-deep learning methods. For deep learning baselines, we employed CNN, RNN, and LSTM, which were implemented with consistent global hyperparameters, including a vocabulary size of 10,000, a sequence length of 500, and a fixed random seed for reproducibility. The CNN used an embedding layer, followed by a convolutional layer with 512 filters, max pooling, and dense layers for output. The LSTM model featured a single LSTM layer with 128 units, followed by a dense layer for classification, while the RNN employed a SimpleRNN layer with the same structure. All models were trained using the Adam optimizer with a categorical cross-entropy loss function for 40 epochs. In addition, state-of-the-art methods, including AST + ML16, Code2Vec44, and Code2Vec + MLP16, were included using performance data reported in the literature. These approaches provide a strong benchmark for comparing MTLPT’s ability to effectively predict vulnerabilities in real-world datasets.
Basic model comparison experiment
We trained the MTLPT model using our proposed loss function \(L_t\) based on dynamic weights and the undersampled Draper subset. As shown in Fig. 4, Task1 to Task5 are the training curves of the prediction tasks for the vulnerability categories CWE-119, CWE-120, CWE-469, CWE-476, and CWE-other. In the first 15 iterations, the total loss of the proposed method dropped quickly and then became relatively stable, and the loss of each task also stabilized within 15 iterations. This result demonstrates the adaptability of our learning model to the real-world imbalanced vulnerability (RIV) dataset.
(1) For RQ1: How is the performance of our MTLPT method based on multi-task learning for predicting the five most common types of vulnerabilities in unbalanced real-world code?
Because RIV is an unbalanced dataset, we evaluated the performance of our proposed MTLPT method for predicting real-world vulnerability data with several complementary indicators. We compared it against a traditional, widely used machine learning model (RF30), deep learning models (LSTM24, RNN26, CNN25), and state-of-the-art approaches to vulnerability prediction (AST + ML16, Code2Vec44, and Code2Vec + MLP16). The comparison was conducted on the vulnerability categories CWE-119, CWE-120, CWE-469, CWE-476, and CWE-other, using Recall, F1-score, Area Under the Curve (AUC), and Matthews Correlation Coefficient (MCC)45. Recall measures the ability of the model to recognize all relevant instances. F1-score balances Precision and Recall, which is valuable on unbalanced datasets. AUC summarizes performance across all classification thresholds. MCC is particularly informative because it yields a high score only when the prediction does well in all four confusion matrix categories (true positives, false negatives, true negatives, and false positives), which makes it well suited to unbalanced datasets. Based on these four indicators, we can effectively evaluate the performance of our proposed method for predicting the five most common types of vulnerabilities in unbalanced real-world code. The specific evaluation results are listed and compared in Table 3.
As can be seen from Table 3, MTLPT outperforms the other models on nearly all indicators across the CWE category prediction tasks, including the traditional machine learning model (RF30), deep learning models (LSTM24, RNN26, CNN25), and state-of-the-art approaches (AST+ML16, Code2Vec44, Code2Vec+MLP16). For these five major categories of vulnerability prediction tasks, the recall of MTLPT is significantly higher than that of most other methods. For example, in the CWE-119 prediction task, the recall of MTLPT reaches approximately 79%, while CNN achieves only 36.59%, a difference of more than 40 percentage points. Among state-of-the-art methods, Code2Vec+MLP achieves a recall of 87.3%, but MTLPT provides a more balanced performance across metrics such as F1-score and MCC, showcasing its robustness and adaptability.
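For reference, three of these indicators can be computed directly from the confusion matrix of a single task, as in the sketch below. The function name and the example counts are hypothetical and are not taken from Table 3; AUC is omitted because it requires the predicted scores rather than the confusion matrix alone.

```python
import math

def binary_metrics(tp, fp, tn, fn):
    """Recall, precision, F1, and MCC from one task's confusion matrix."""
    recall = tp / (tp + fn) if tp + fn else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"recall": recall, "precision": precision, "f1": f1, "mcc": mcc}

# Hypothetical counts for a rare class: 20 positives among 1000 samples.
rare = binary_metrics(tp=12, fp=60, tn=920, fn=8)

# A model that never predicts the rare class is 98% accurate,
# yet its recall, F1, and MCC are all zero.
never = binary_metrics(tp=0, fp=0, tn=980, fn=20)
```

The second case shows why accuracy alone is misleading on imbalanced data and why Recall, F1-score, and MCC are reported instead.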
In addition to recall, MTLPT surpasses other models in terms of F1-score and AUC. For instance, in the CWE-120 vulnerability category prediction task, MTLPT achieves an F1-score approximately 43% higher than CNN and an AUC of 90%, highlighting its ability to maintain high classification performance under varying thresholds. In terms of the MCC indicator, MTLPT also demonstrates superior performance. For example, in the CWE-469 vulnerability category, which accounts for only 1.3% of the dataset, MTLPT achieves an MCC of about 30%, outperforming all other baseline methods, including state-of-the-art approaches.
RF represents traditional machine learning models, which are effective in handling moderately imbalanced datasets due to their ensemble learning capabilities. However, RF struggles with high-dimensional and unstructured data, such as source code representations, resulting in low recall, F1-score, and MCC values in all vulnerability categories.
Deep learning models, including LSTM, RNN, and CNN, improve upon traditional methods by leveraging their ability to learn sequential and spatial patterns. LSTM and RNN are capable of capturing temporal dependencies, while CNN is effective in extracting local spatial features through convolutional layers. Despite these advantages, these models cannot share information across tasks, limiting their effectiveness on highly imbalanced datasets. For instance, in the CWE-469 category, LSTM, RNN, and CNN fail to detect any vulnerabilities, indicating their inability to leverage correlations from other tasks or effectively learn features of rare classes.
State-of-the-art approaches, such as AST+ML and Code2Vec, focus on structural and semantic code analysis. AST+ML uses abstract syntax trees to extract structural features, while Code2Vec encodes syntactic and semantic information into vector representations for classification tasks. Code2Vec+MLP enhances Code2Vec by introducing additional layers to improve feature learning. However, these methods still struggle with extreme data imbalance, as shown in the CWE-469 category, where they fail to achieve significant recall or MCC. MTLPT, by contrast, leverages a multi-task learning framework, which enables it to capture correlations across vulnerability categories and perform well even on highly imbalanced datasets.
In the CWE-469 vulnerability category, despite its extreme data imbalance (only 1.3% of the dataset), MTLPT effectively learns its features and achieves high recall and MCC, whereas single-task deep learning models and state-of-the-art approaches fail to perform well. This is because MTLPT integrates a custom lightweight Transformer block and position encoding layer, allowing it to better capture contextual and structural information from the data. Additionally, by sharing information across tasks, MTLPT alleviates the class imbalance issue, improving its ability to predict rare vulnerabilities.
In summary, MTLPT demonstrates strong robustness and adaptability to varying data distributions and outperforms all baseline models. Traditional machine learning models, such as RF30, lack the capability to process high-dimensional and unstructured data effectively. Deep learning models (LSTM24, RNN26, CNN25) show improvements in learning sequential and spatial patterns but struggle with extreme class imbalance due to their single-task nature. State-of-the-art approaches, including AST+ML16 and Code2Vec+MLP16, advance vulnerability prediction by leveraging structural and semantic information, yet they remain limited in handling underrepresented classes in imbalanced datasets. MTLPT addresses these shortcomings by integrating a custom lightweight Transformer block and position encoding layer into a multi-task learning framework, enabling it to capture both local and global dependencies, share information across tasks, and improve prediction accuracy for rare vulnerabilities. These advantages make MTLPT a powerful and generalizable solution for real-world vulnerability prediction tasks.
Figure 5 shows the confusion matrices of the five major vulnerability category prediction tasks on the Draper subset. From Fig. 5, it can be seen that the MTLPT model achieves balanced Precision and Recall while maintaining high accuracy. MTLPT uses a multi-task learning framework and a loss function based on dynamic weight adjustment to exploit the correlation between tasks through the shared representation learning layer, enhancing the model’s prediction ability for the different security vulnerability categories (CWE-119, CWE-120, CWE-469, CWE-476, CWE-other). In addition, the confusion matrix of CWE-469 in Fig. 5c shows that the MTLPT model reduces false positives (FP) and increases true negatives (TN), lowering the false alarm rate. In Fig. 5d and e, the model shows a high TN value, indicating that MTLPT can effectively identify and exclude instances that do not contain specific vulnerabilities.
Table 4 compares the total number of parameters of the baseline models with that of our model. MTLPT, which can effectively predict most CWE vulnerability categories, has about 19% fewer parameters than the CNN model and about 73% fewer than the RF model. This is because the custom lightweight Transformer block uses a lower feed-forward dimension and fewer attention heads, which reduces the parameter count while maintaining performance. The MTLPT model thus stays lightweight while achieving high performance, reducing the number of parameters to be trained and improving efficiency.
In terms of inference time, the MTLPT model demonstrates nearly a 50% improvement compared to the RF model, showcasing its efficiency in handling complex data with reduced computational complexity. While the inference time is slightly higher than that of the CNN model, the MTLPT achieves a balanced trade-off between computational cost and performance through its custom lightweight Transformer blocks, enabling efficient processing of real-world vulnerability datasets.
The memory usage of the MTLPT model exhibits a remarkable advantage, consuming less than 1% of the memory required by the RF model and reducing memory usage by over 90% compared to the CNN model. This improvement is primarily attributed to the optimized design of the MTLPT model, which leverages reduced dimensions in multi-head attention mechanisms and efficient weight-sharing strategies, significantly lowering memory consumption without compromising predictive accuracy.
Result: In summary, MTLPT can capture complex patterns and dependencies between the five most common types of vulnerabilities on unbalanced real-world vulnerability datasets, fully learn vulnerability features, and achieve a performance that is 10% to 50% higher than traditional single-task learning and ensemble learning methods. Moreover, the MTLPT model not only maintains high predictive performance but also demonstrates substantial lightweight advantages, reducing the total parameters by approximately 19% compared to CNN and over 73% compared to RF. It achieves exceptional efficiency in inference time, reducing computational costs by nearly 50% compared to RF, and exhibits remarkable memory savings, using less than 1% of the memory required by RF and over 90% less than CNN. These features make the MTLPT model highly suitable for deployment in resource-constrained environments, such as embedded systems or large-scale vulnerability detection tasks. By achieving lightweight model deployment and improving efficiency while maintaining robust performance, the MTLPT model effectively addresses real-world vulnerability prediction challenges.
(2) For RQ2: How effective is the dynamic weight-based loss function in the multi-task framework of the MTLPT method in alleviating the imbalance problem of real-world vulnerability data?
We trained MTLPT with and without the dynamic weight-based loss function \(L_t\) on the same preprocessed, undersampled Draper subset and compared their predictions. As shown in Table 3, on the CWE-469 prediction task MTLPT scores about 10% higher in MCC than MTLPT without \(L_t\), and it reaches a recall of about 58% and an F1-score of 26%, far ahead of MTLPT (without \(L_t\)) in the overall evaluation. The comparison between MTLPT and MTLPT (without \(L_t\)) in Table 3 shows that \(L_t\) plays an important role in enhancing model performance. With \(L_t\), MTLPT consistently produces better F1-score and MCC values, which is crucial when dealing with highly unbalanced data such as that shown in Table 2. This is because the dynamic weight-based loss function \(L_t\) guides the MTLPT model to pay more attention to underrepresented data classes, thereby alleviating the impact of data imbalance while maintaining the performance of the other category prediction tasks.
Result: In general, the dynamic weight-based loss function \(L_t\) in MTLPT can effectively guide the MTLPT model to pay more attention to prediction tasks with less data and insufficient representation, effectively alleviate the imbalance problem of real-world vulnerability data while maintaining the performance of other prediction tasks.
Ablation experiment
(3) For RQ3: What is the contribution of the components of our proposed MTLPT method to dealing with unbalanced data?
Ablation experiments observe the impact on final performance of removing parts of a model, thereby explaining the importance and role of those components. In this section, we use ablation experiments to evaluate how the multi-task learning framework of the MTLPT model, the key component PT (custom lightweight Transformer block and position encoding layer), and the dynamic weight-based loss function \(L_t\) contribute to model performance, and what this implies for future research.
We set the following four model configurations to compare on the same preprocessed dataset: (1) PT without MTL: a single-task learning model that only contains a custom lightweight Transformer block and position encoding layer. (2) MTLPT without PT: a basic MTL model that does not contain additional custom lightweight Transformer blocks and position encoding layers. (3) MTLPT without Dynamic weight: MTLPT model removes the dynamic weight strategy, that is, the dynamic weight-based loss function \(L_t\). (4) MTLPT model: our proposed MTL model based on a custom lightweight Transformer block and position encoding layer.
In Fig. 6, we compare the performance of the MTLPT model with PT without MTL. We observe significant improvements in all metrics for the model with a multi-task learning framework, except for the AUC. This indicates that our proposed multi-task learning framework enables the sharing of information across various vulnerability prediction tasks, enhancing the predictive ability for minority categories. Consequently, it improves the model’s generalization capability and adaptability to imbalanced data.
The comparison between the MTLPT model and MTLPT without PT reveals substantial improvements across all metrics. Particularly, the MTLPT model outperforms MTLPT without PT by 30.1% in Recall and 6.7% in F1-score. This highlights the effectiveness of our custom lightweight Transformer blocks and position encoding layer in enhancing the model’s understanding of sequential data. It also improves the recognition of vulnerability categories with limited data, which is crucial for handling imbalanced datasets. Specifically, the introduction of custom lightweight Transformer blocks and position encoding layers enhances the model’s ability to capture complex patterns in vulnerability data. Furthermore, the sequence information introduced by the position encoding layer further improves the model’s contextual understanding, effectively enhancing vulnerability prediction accuracy.
Additionally, comparing the MTLPT model with MTLPT without Dynamic Weight demonstrates that the inclusion of \(L_t\) results in better performance in MCC and F1-score, with improvements of 30.0% and 26.4%, respectively. However, there is a slight decrease in AUC and Recall. This suggests that \(L_t\) contributes to better balancing of the model’s performance across different vulnerability category prediction tasks, especially in highly imbalanced data environments. This observation is consistent with the results from the previous comparative experiments.
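The pattern described above, where Recall drops slightly while F1-score and MCC rise, is typical when a model trades a few true positives for far fewer false positives on an imbalanced class. The sketch below computes the three metrics from a binary confusion matrix. The counts are hypothetical and chosen only to reproduce the direction of the effect; they are not results from this study.

```python
import math


def binary_metrics(tp, fp, fn, tn):
    """Compute recall, F1-score, and MCC from confusion-matrix counts."""
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"recall": recall, "f1": f1, "mcc": mcc}


# Hypothetical counts: 100 positive and 900 negative samples in both cases.
model_a = binary_metrics(tp=90, fp=200, fn=10, tn=700)  # aggressive, many false alarms
model_b = binary_metrics(tp=80, fp=40, fn=20, tn=860)   # slightly lower recall, far fewer false alarms

for name, m in (("A", model_a), ("B", model_b)):
    print(name, {k: round(v, 3) for k, v in m.items()})
```

Because F1-score and MCC both penalize false positives while Recall does not, model B scores higher on F1-score and MCC despite its lower Recall.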
Result: The ablation experiments in this section confirm the importance of three elements of the MTLPT model for handling imbalanced data: the multi-task learning framework, the PT component (custom lightweight Transformer blocks and position encoding layer), and the dynamic weight-based loss function \(L_t\). The synergistic effect of these components leads to near-optimal performance across all evaluation metrics, making the combined design effective and relevant for real-world vulnerability category prediction in imbalanced scenarios.
Discussion
In this paper, we propose the MTLPT method, and through comparative experiments and ablation study results, we demonstrate its outstanding performance in predicting the five most common real-world vulnerability types. Additionally, we highlight its contribution to addressing the issue of imbalanced data. Despite achieving significant performance improvements, we recognize several challenges in the field of vulnerability prediction:
(1) Handling Imbalanced Data: While MTLPT has made progress in dealing with imbalanced data, the inherent class imbalance in vulnerability datasets remains a challenging problem. To further enhance model prediction accuracy, future work will explore various strategies, including resampling techniques and modified class weights.
(2) Generalization to Diverse Vulnerabilities: Current vulnerability prediction methods often focus on specific types of vulnerabilities or employ a single generic model to handle multiple vulnerabilities. However, this approach may limit predictive performance and generalization when dealing with diverse and complex vulnerabilities32. Our MTLPT method focuses on the five most frequent vulnerability types in the real world, enhancing the model’s understanding of both differences and commonalities among vulnerabilities through multi-task learning. Nevertheless, the model’s ability to predict unlabeled vulnerability samples and unknown vulnerabilities remains limited. As a future research direction, we plan to introduce unsupervised learning and semi-supervised learning methods to better capture vulnerability patterns in unlabeled data. Additionally, combining more data sources and advanced feature extraction techniques will enrich the model’s learning capacity. This will enable MTLPT to adapt to all types of vulnerability prediction, including those not explicitly identified yet.
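The resampling and class-weight strategies named under Handling Imbalanced Data as future work can be sketched as follows. This is a generic illustration under stated assumptions, not part of MTLPT: the function names and the toy labels are hypothetical, and the weight formula is the common inverse-frequency heuristic n_samples / (n_classes * class_count).

```python
import random
from collections import Counter


def balanced_class_weights(labels):
    """Inverse-frequency weights: n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in counts.items()}


def random_oversample(samples, labels, seed=0):
    """Duplicate minority-class samples at random until all classes match the largest."""
    rng = random.Random(seed)
    by_class = {}
    for sample, label in zip(samples, labels):
        by_class.setdefault(label, []).append(sample)
    target = max(len(items) for items in by_class.values())
    resampled = []
    for label, items in by_class.items():
        extra = [rng.choice(items) for _ in range(target - len(items))]
        resampled.extend((sample, label) for sample in items + extra)
    rng.shuffle(resampled)
    return resampled


# Toy data: 90 samples of a common class and 10 of a rare class.
labels = ["common"] * 90 + ["rare"] * 10
samples = list(range(len(labels)))

weights = balanced_class_weights(labels)
resampled = random_oversample(samples, labels)
print(weights["rare"] > weights["common"])
print(Counter(label for _, label in resampled))
```

Class weights leave the data untouched and rescale each sample's loss, whereas oversampling changes the training distribution itself; the two are usually treated as alternatives rather than combined.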
In summary, the MTLPT method not only improves predictive performance for known vulnerability types but also holds promise for broader applications in vulnerability defense. By continuously researching and optimizing our approach, we anticipate that MTLPT will become a significant force driving the development of vulnerability prediction technologies.
Conclusion
In this paper, our proposed MTLPT method demonstrates outstanding performance in predicting the five most common real-world vulnerability types within imbalanced datasets. Our approach combines the advantages of multi-task learning with a loss function, \(L_t\), based on dynamic weights. The multi-task framework shares parameters across tasks and learns task correlations, while \(L_t\) dynamically allocates task weights and mitigates data imbalance issues. By introducing custom lightweight Transformer blocks and position encoding layers, MTLPT further extracts rich long-range dependencies and contextual information, deepening the learning of vulnerability features.
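As a rough illustration of dynamically allocating task weights, the sketch below gives larger weights to tasks whose current loss is higher by applying a softmax over the per-task losses. This is a generic stand-in and not the definition of \(L_t\) used by MTLPT, which is given in the method section of the paper; the function names, the temperature parameter, and the example losses are all hypothetical.

```python
import math


def dynamic_task_weights(task_losses, temperature=1.0):
    """Softmax over per-task losses, rescaled so the weights sum to the task count."""
    scaled = [loss / temperature for loss in task_losses]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    n_tasks = len(task_losses)
    return [n_tasks * e / total for e in exps]


def weighted_total_loss(task_losses, temperature=1.0):
    """Combine per-task losses using the dynamic weights."""
    weights = dynamic_task_weights(task_losses, temperature)
    return sum(w * loss for w, loss in zip(weights, task_losses))


# Hypothetical per-task losses for five tasks; the last task is struggling.
losses = [0.2, 0.2, 0.2, 0.2, 1.5]
weights = dynamic_task_weights(losses)
print([round(w, 3) for w in weights])
print(round(weighted_total_loss(losses), 3))
```

With equal losses every task receives weight 1, which recovers the unweighted sum; as one task falls behind, its weight grows and the optimizer is pushed toward it. In a real training loop the weights would normally be treated as constants when computing gradients.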
Comparative experiments between MTLPT and single-task learning models (such as LSTM24, RNN26, CNN25) and ensemble learning methods (such as RF30) reveal that MTLPT captures complex patterns and dependencies among the five most common vulnerability categories in real-world imbalanced datasets. It significantly outperforms traditional single-task learning and ensemble learning methods by 10% to 50%. State-of-the-art approaches, including AST+ML16 and Code2Vec+MLP16, advance vulnerability prediction by leveraging structural and semantic information, yet they remain limited in handling underrepresented classes in imbalanced datasets. MTLPT addresses these shortcomings by integrating a custom lightweight Transformer block and position encoding layer into a multi-task learning framework, enabling it to capture both local and global dependencies, share information across tasks, and improve prediction accuracy for rare vulnerabilities. Additionally, the MTLPT model maintains high performance while reducing model parameters, enabling lightweight deployment for effective vulnerability prediction across various real-world datasets.
Furthermore, the comparison between MTLPT and MTLPT without \(L_t\) demonstrates that our proposed dynamic weight-based loss function guides the MTLPT model to focus more on prediction tasks with limited representative data. This mitigates the data imbalance issue while maintaining performance on other prediction tasks. Our ablation study results emphasize the contributions of the custom components: the multi-task learning framework, the dynamic weight-based loss function (\(L_t\)), and the custom lightweight Transformer blocks with position encoding layers. Overall, MTLPT’s joint learning across multiple tasks not only enhances performance on specific tasks but also deepens the model’s understanding of the overall data distribution. The model’s design and its performance on real-world data illustrate how multi-task learning can be adapted to imbalanced vulnerability prediction, providing new perspectives and solutions for future research and applications.
Data availability
All data generated or analysed during this study are included in the supplementary information files.
References
Rosen, M.: Vulnerability and threat trends report 2023 (2023). https://www.skyboxsecurity.com/resources/report/vulnerability-threat-trends-report-2023/
Hanif, H., Nasir, M. H. N. M., Ab Razak, M. F., Firdaus, A. & Anuar, N. B. The rise of software vulnerability: Taxonomy of software vulnerabilities detection and machine learning approaches. J. Netw. Comput. Appl. 179, 103009 (2021).
Mim, R.S., Ahammed, T., Sakib, K.: Identifying vulnerable functions from source code using vulnerability reports (2023).
Zhang, Y. & Yang, Q. A survey on multi-task learning. IEEE Trans. Knowl. Data Eng. 34(12), 5586–5609 (2021).
Senanayake, J., Kalutarage, H., Al-Kadri, M. O., Petrovski, A. & Piras, L. Android source code vulnerability detection: A systematic literature review. ACM Comput. Surv. 55(9), 1–37 (2023).
Pochu, S. & Kathram, S. R. Applying machine learning techniques for early detection and prevention of software vulnerabilities. Multidiscip. Sci. J. 1(01), 1–7 (2021).
Yosifova, V., Tasheva, A., Trifonov, R.: Predicting vulnerability type in common vulnerabilities and exposures (cve) database with machine learning classifiers. In 2021 12th National Conference with International Participation (ELECTRONICA), 1–6 (IEEE, 2021).
Liu, K., Zhou, Y., Wang, Q., Zhu, X.: Vulnerability severity prediction with deep neural network. In 2019 5th International Conference on Big Data and Information Analytics (BigDIA), 114–119 (IEEE, 2019).
Aslan, Ö., Aktuğ, S. S., Ozkan-Okay, M., Yilmaz, A. A. & Akin, E. A comprehensive review of cyber security vulnerabilities, threats, attacks, and solutions. Electronics 12(6), 1333 (2023).
Caruana, R. Multitask learning. Mach. Learn. 28, 41–75 (1997).
Zhang, Y. & Yang, Q. An overview of multi-task learning. Natl. Sci. Rev. 5(1), 30–43 (2018).
Standley, T., Zamir, A., Chen, D., Guibas, L., Malik, J., Savarese, S.: Which tasks should be learned together in multi-task learning? In International Conference on Machine Learning, 9120–9132 (PMLR, 2020).
Thung, K.-H. & Wee, C.-Y. A brief review on multi-task learning. Multimed. Tools Appl. 77(22), 29705–29725 (2018).
Le, T.H.M., Hin, D., Croft, R., Babar, M.A.: Deepcva: Automated commit-level vulnerability assessment with deep multi-task learning. In 2021 36th IEEE/ACM International Conference on Automated Software Engineering (ASE), 717–729 (IEEE, 2021).
Huang, J., Zhou, K., Xiong, A. & Li, D. Smart contract vulnerability detection model based on multi-task learning. Sensors 22(5), 1829 (2022).
Bilgin, Z. et al. Vulnerability prediction from source code using machine learning. IEEE Access 8, 150672–150684 (2020).
Gong, X., Xing, Z., Li, X., Feng, Z., Han, Z.: Joint prediction of multiple vulnerability characteristics through multi-task learning. In 2019 24th International Conference on Engineering of Complex Computer Systems (ICECCS), 31–40 (IEEE, 2019).
Crawshaw, M.: Multi-task learning with deep neural networks: A survey. arXiv preprint arXiv:2009.09796 (2020).
Russell, R., Kim, L., Hamilton, L., Lazovich, T., Harer, J., Ozdemir, O., Ellingwood, P., McConley, M.: Automated vulnerability detection in source code using deep representation learning. In 2018 17th IEEE International Conference on Machine Learning and Applications (ICMLA), 757–762 (IEEE, 2018).
Liu, C., Chen, X., Li, X. & Xue, Y. Making vulnerability prediction more practical: Prediction, categorization, and localization. Inf. Softw. Technol. 171, 107458 (2024).
Fülöp, E., Pataki, N.: Temporal logic-driven symbolic execution with the clang static analyzer. In 2024 7th International Conference on Software and System Engineering (ICoSSE), 78–82 (IEEE, 2024).
Zhou, X. et al. Large language model for vulnerability detection and repair: Literature review and the road ahead. ACM Trans. Softw. Eng. Methodol. 34(5), 1–31 (2025).
Chakraborty, S., Krishna, R., Ding, Y. & Ray, B. Deep learning based vulnerability detection: Are we there yet?. IEEE Trans. Softw. Eng. 48(9), 3280–3296 (2021).
Bai, Y., Liu, L., Huang, Q. & Deng, J. A rapid vulnerability identification of open source software based on a two-way long-short-term memory network. Int. J. Comput. Sci. Math. 20(3), 243–258 (2024).
Zhang, X. & Wu, D. On the vulnerability of CNN classifiers in EEG-based BCIs. IEEE Trans. Neural Syst. Rehabil. Eng. 27(5), 814–825 (2019).
Zheng, J., Pang, J., Zhang, X., Zhou, X., Li, M., Wang, J.: Recurrent neural network based binary code vulnerability detection. In Proceedings of the 2019 2nd International Conference on Algorithms, Computing and Artificial Intelligence, 160–165 (2019).
Liu, Z., Tang, Z., Zhang, J., Xia, X., Yang, X.: Pre-training by predicting program dependencies for vulnerability analysis tasks. In Proceedings of the IEEE/ACM 46th International Conference on Software Engineering, 1–13 (2024).
Zhang, Y., Ma, J. & Jia, Y. MCAN: Multimodal cross-aware network for fake news detection by extracting semantic-physical feature consistency. J. Supercomput. 81(1), 1–36 (2025).
Jabeen, G. et al. Machine learning techniques for software vulnerability prediction: A comparative study. Appl. Intell. 52(15), 17614–17635 (2022).
Chernis, B., Verma, R.: Machine learning methods for software vulnerability detection. In Proceedings of the Fourth ACM International Workshop on Security and Privacy Analytics, 31–39 (2018).
Nong, Y., Fang, R., Yi, G., Zhao, K., Luo, X., Chen, F., Cai, H.: Vgx: Large-scale sample generation for boosting learning-based software vulnerability analyses. In Proceedings of the IEEE/ACM 46th International Conference on Software Engineering, 1–13 (2024).
Lomio, F., Iannone, E., De Lucia, A., Palomba, F. & Lenarduzzi, V. Just-in-time software vulnerability detection: Are we there yet?. J. Syst. Softw. 188, 111283 (2022).
Zhang, L. et al. A novel smart contract vulnerability detection method based on information graph and ensemble learning. Sensors 22(9), 3581 (2022).
Xin, Y., Du, J., Wang, Q., Yan, K., Ding, S.: Mmap: Multi-modal alignment prompt for cross-domain multi-task learning. In Proceedings of the AAAI Conference on Artificial Intelligence, vol. 38, 16076–16084 (2024).
Chen, Z., Shen, Y., Ding, M., Chen, Z., Zhao, H., Learned-Miller, E., Gan, C.: Mod-squad: Designing mixtures of experts as modular multi-task learners. In 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 11828–11837. https://doi.org/10.1109/CVPR52729.2023.01138 (2023).
Bonato, J., Pelosin, F., Sabetta, L., Nicolosi, A.: Mind: Multi-task incremental network distillation. In Proceedings of the AAAI Conference on Artificial Intelligence, vol. 38, 11105–11113 (2024).
Chai, Y., Du, L., Qiu, J., Yin, L. & Tian, Z. Dynamic prototype network based on sample adaptation for few-shot malware detection. IEEE Trans. Knowl. Data Eng. 35(5), 4754–4766. https://doi.org/10.1109/TKDE.2022.3142820 (2023).
Liu, S., Liang, Y., Gitter, A.: Loss-balanced task weighting to reduce negative transfer in multi-task learning. In Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, 9977–9978 (2019).
Chai, Y. et al. From data and model levels: Improve the performance of few-shot malware classification. IEEE Trans. Netw. Serv. Manag. 19(4), 4248–4261. https://doi.org/10.1109/TNSM.2022.3200866 (2022).
Zeng, P., Lin, G., Pan, L., Tai, Y. & Zhang, J. Software vulnerability analysis and discovery using deep learning techniques: A survey. IEEE Access 8, 197158–197172 (2020).
Yang, Y. et al. Dlap: A deep learning augmented large language model prompting framework for software vulnerability detection. J. Syst. Softw. 219, 112234 (2025).
Liu, R. et al. Vul-LMGNNs: Fusing language models and online-distilled graph neural networks for code vulnerability detection. Inf. Fus. 115, 102748 (2025).
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. Advances in Neural Information Processing Systems 30 (2017).
Alon, U., Zilberstein, M., Levy, O., Yahav, E.: code2vec: Learning distributed representations of code. In Proceedings of the ACM on Programming Languages 3(POPL), 1–29 (2019).
Itaya, Y., Tamura, J., Hayashi, K. & Yamamoto, K. Asymptotic properties of Matthews correlation coefficient. Stat. Med. 44(1–2), 10303 (2025).
Acknowledgements
This work was supported in part by Guangdong Science and Technology Innovation Strategy Special Funds (pdjh2024b231) and Scientific Research Projects of Guangdong Provincial Education Department (No.2022ZDZX1015).
Author information
Contributions
L.L. provided overall guidance and methodology for the paper. Z.F.H. designed the experiments and analyzed the experimental data. G.M.C. handled figure and table editing. Z.F.H. and T.F.C. were responsible for the paper’s layout, reference collection, and polishing. C.Y.Z. contributed to the entire content. All authors reviewed the manuscript.
Ethics declarations
Competing interests
We declare that we have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License, which permits any non-commercial use, sharing, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if you modified the licensed material. You do not have permission under this licence to share adapted material derived from this article or parts of it. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by-nc-nd/4.0/.
About this article
Cite this article
Liu, L., Hui, Z., Chen, G. et al. A lightweight transformer based multi task learning model with dynamic weight allocation for improved vulnerability prediction. Sci Rep 15, 28176 (2025). https://doi.org/10.1038/s41598-025-10650-6