Introduction

The global music streaming industry has experienced exponential growth over the past decade, driven by advances in mobile connectivity, affordable subscription models, and the proliferation of personalised content recommendations. Recent industry analyses indicate that platforms such as Spotify, Apple Music, and Amazon Music have achieved an estimated 25–30% compound annual growth rate (CAGR) in user adoption worldwide1. However, this expansion comes with a pressing challenge: customer churn. Industry reports indicate that up to 30–35% of new users abandon a streaming service within the first six months of subscription, primarily because of unsatisfactory personalisation, limited engagement, or pricing concerns. Churn not only shrinks the active user base but also significantly affects long-term revenue projections. Since acquiring a new customer is estimated to cost five to seven times more than retaining an existing one, churn prediction has emerged as a mission-critical function in the operational strategy of subscription-based platforms2.

To address this, several machine learning models have been examined. Gaddam and Kadali1 evaluated traditional classifiers such as Logistic Regression and Decision Trees for predicting churn based on static user profiles. Karwa et al.2 proposed an analytics pipeline linking churn detection to retention strategies using ensemble models. Joy et al.3 introduced a hybrid framework combining big data pipelines with explainable AI for enhanced interpretability. While these approaches have demonstrated moderate predictive success, they are restricted by their reliance on tabular data and failure to model relational dependencies among users, which may obscure complex behavioural dynamics.

This research develops a graph-based deep learning model that leverages both individual user features and community-based behavioural similarities to detect churn more effectively. The scope comprises scalable processing of multi-source user data, graph construction based on shared user traits, and the fusion of these representations using Graph Attention Networks (GAT) integrated with deep feature learning. The goal is not only to enhance predictive accuracy but also to maintain robustness in the class-imbalanced scenarios commonly found on subscription platforms.

Streaming services, such as KKBox and Spotify, often encounter user behavior influenced by social or demographic context, which conventional models struggle to capture. By representing users as nodes in a graph and connecting them via shared characteristics such as city and registration method, it becomes possible to model latent peer influences. This graph-based formulation enables attention mechanisms to weight the importance of neighbouring behaviours, refining churn predictions beyond what isolated features can provide.

Our contribution

The main contributions of this research are:

  • GAT-based relational learning combined with deep Multi-Layer Perceptrons was developed to model both graph-structured and tabular data, supporting digital marketing trend analysis for online subscription and viewing behaviour.

  • Constructed a synthetic user similarity graph that links users through shared traits, allowing contextual learning via attention propagation and enabling analysis of product and user interest.

  • Implemented a scalable big-data processing pipeline capable of aggregating behavioural features from over 30 GB of user logs and integrating them into a graph-ready format.

  • Adopted a novel training strategy using a weighted loss, which mitigates class imbalance (~9% churn rate) without data resampling, preserving distributional integrity.

The remaining sections of this article are arranged as follows: related literature is presented in Sect. 2; the proposed model, including its architecture, methodology, and algorithms, is discussed in Sect. 3; results, comparisons, and supporting evidence are evaluated in Sect. 4; and the research work is concluded in Sect. 5.

Literature survey

An Adaptive Profit-Centric Churn Prediction Engine built on ensemble methods (XGBoost, Random Forest, Gradient Boosting)4 is tightly integrated with economic measures such as profit and Customer Lifetime Value (CLV). It is proposed not only to boost churn classification (with an accuracy of ~97%) but also to directly optimise financial outcomes. The model captures customer interdependence by quantifying social linkages and co-usage patterns, and it uses R² convergence curves to track stability and adaptability over time. However, it is heavily limited by the need for high-quality economic data: uneven or noisy CLV estimates can undercut performance. Additionally, its complexity and dependence on customer network data may hinder use in privacy-sensitive or connectivity-limited environments. A systematic bibliometric study using Scopus data and VOSviewer tools5 maps AI's application in customer retention. It identifies five key research themes (churn prediction, service experience, sentiment analysis, big-data analytics, and privacy) through term co-occurrence and bibliographic coupling, and it articulates 15 future research objectives, emphasising novel directions such as ethics, privacy, and advanced AI techniques. It delivers insights on publication trends, authorship networks, influential journals, and geographic hotspots. However, the study does not evaluate actual AI model performance or real-world deployment challenges, and because it is focused solely on Scopus, it suffers from dataset bias and may not reflect the whole global landscape.

The SV-VAE model combines a Convolutional Neural Network with a modified Variational Autoencoder (VAE)6 to address high-dimensional and imbalanced churn datasets. The VAE compresses and denoises data before CNN classification, enhancing both accuracy and training speed. Evaluated on six benchmark datasets, it reports strong accuracy, precision, recall, and F1-score with reduced response time. The hybrid handles complex temporal patterns via latent feature capture and CNN feature extraction. However, the model requires significant computational resources, and tuning both the VAE and the CNN adds further overhead; its complexity also raises concerns about overfitting, particularly with limited training samples. A deep neural network7 tailored for telecom churn prediction uses multi-layer perceptron architectures with embedding layers for categorical variables and is proposed to detect complex nonlinear patterns in customer behaviour with high predictive accuracy. The network addresses class imbalance via weighted loss functions, though SMOTE and other resampling methods are not explicitly considered. Evaluation on real-world telecom data shows improved recall but limited precision trade-offs. However, the model's opaque inner workings limit its explainability for business stakeholders, and it is prone to overfitting on small datasets without regularisation.

A combination of Random Forest with deep learning (MLP/CNN)8 is integrated into a single pipeline for telecom churn prediction: engineered features are extracted first, and deep networks then perform the final classification. The hybrid approach reportedly enhances accuracy and offers modular adaptability across telecom datasets. However, managing two distinct pipelines increases implementation complexity and training time; the approach also relies heavily on feature engineering quality, limiting generalisation to new contexts, and the dual-stage process may suffer from error propagation between the feature and deep-learning stages. A BiLSTM-CNN hybrid9 captures sequential and spatial churn signals from telecom usage logs, with LSTM layers extracting temporal trends while CNN layers learn hierarchical patterns. Tested on benchmark telco datasets, the composite architecture reaches ~81% accuracy and balances context capture and feature abstraction more effectively than standalone architectures. However, it requires extensive hyperparameter tuning, training can be computationally intensive, and it may underperform when sequence lengths are irregular or the dataset lacks timestamped behaviour.

Clustering (K-means, hierarchical) combined with classification (Gradient Boosting, SVM, RF)10 is proposed to first group similar customers and then apply tailored classifiers, improving adaptability and prediction accuracy. The approach better handles heterogeneity and high dimensionality in telecom datasets. However, selecting the optimal cluster count is difficult and can bias the classifiers; it also increases pipeline complexity and training time, and clustering errors may propagate, affecting the accuracy of the classification stage. The Archimedes Optimisation Algorithm (AOA)11 is applied for feature selection, followed by CNN-Autoencoder (CNN-AE) deep models for churn prediction: informative features are optimally selected from high-dimensional telecom data, and churn is then predicted by the CNN-AE hybrid. The optimisation step reduces dimensionality and improves accuracy, while the autoencoder ensures robust latent representations. However, AOA feature selection is computationally costly and may become trapped in local optima, the CNN-AE depth increases training time and complexity, and the combined framework may be excessive for simpler datasets while lacking clear benchmarks against lighter models.

Decision Tree, Random Forest, and XGBoost12 were integrated via hyperparameter tuning (grid search) for feature selection in telecom churn prediction, aiming to enhance interpretability while retaining ensemble strength in handling imbalance and high dimensionality. The optimised models enable proactive retention strategies with scalable training. Nonetheless, the proposal still depends on clean, feature-engineered inputs rather than being an end-to-end solution; the interpretability gains of decision trees may be lost in tuned ensembles, and focusing on tree models may miss potential improvements from deep learning architectures. Random Forest classification13 is proposed to incorporate textual sentiment scores as churn predictors, capturing emotional drivers of attrition and enabling churn detection beyond transactional data through qualitative customer signals. However, reliance on sentiment accuracy makes it vulnerable to sarcasm, multilingual context, and misclassifications; it may perform poorly when text data is limited or unstructured, and Random Forest may not handle nuanced sentiment patterns as well as deep NLP models (a survey of existing work is provided in Table 1).

Table 1 Survey of existing work.

In the context of mobile shopping, social crowding can induce anxiety and cause distractions, and consequently diminishes shoppers' ability to process information easily on their mobile devices24,25. The challenges posed by high cognitive load during information processing are mitigated by learning compact, task-relevant prompts that retain all essential components of a task; by removing irrelevant information and domain-specific noise, the model generalises better across varied situations26,27. In panel data analysis, latent group structures are estimated to identify and understand hidden subgroup variances, which allows more flexible modelling of complex and heterogeneous data28,29. Similarly, in aspect-based sentiment analysis, a prediction-aware network continually adjusts itself based on intermediate outcomes, thereby increasing the precision of the links among aspects, opinions, and sentiments within the text30.


Proposed methodology

The overall workflow begins with data collection from the WSDM-KKBox churn dataset, which provides user logs, transactions, and demographics. Next, feature engineering derives static, temporal, and behavioural features from the multi-source inputs31. The dataset then undergoes cleaning and standardisation to handle missing values and scale features for deep learning32. A Hybrid GAT + MLP model is trained to learn from both user similarities (through a graph) and individual tabular patterns38,39. Class imbalance is addressed using a weighted binary cross-entropy loss, which up-weights the minority churn class. Finally, the model performs churn classification and is evaluated using metrics such as AUC and accuracy. Figure 1 shows the overall architecture of the proposed framework.

Fig. 1
figure 1

Proposed hybrid churn detection workflow.

Dataset description

The dataset used in this study is sourced from the WSDM-KKBox Churn Prediction Challenge and provides a large-scale, multi-relational view of user behaviour on a music streaming platform. It consists of four core CSV tables. \(\:train\_v2.csv\) includes the user identifier (msno) and the binary target label (is_churn), whose class distribution is shown in the pie chart in Fig. 4. \(\:transactions\_v2.csv\) captures detailed subscription and billing history, including transaction dates, payment plans, auto-renewal, and cancellation status, with its distribution shown in Fig. 4. \(\:user\_logs\_v2.csv\) is a time-series file exceeding 30 GB that logs daily listening activity per user over three years, including total seconds played (distribution shown in Fig. 2) and the number of unique tracks. \(\:members\_v3.csv\) contains user demographics such as city, gender, registration method, and registration timestamp, with the gender distribution across cities shown in Fig. 3. Together, these files provide both static and temporal insights into user engagement. The final merged dataset comprises 970,960 users, yielding a rich feature set through joins across all tables.

Fig. 2
figure 2

Log CSV statistical data distribution.

Fig. 3
figure 3

Members CSV statistical data distribution.

Fig. 4
figure 4

Transactions CSV statistical data distribution and pie chart; train CSV data analysis.

Data preparation

To prepare the data for churn prediction, an extensive preprocessing pipeline was implemented to aggregate and merge multiple relational tables into a unified user-level dataset. First, transaction data from \(\:transactions\_v2.csv\) was aggregated per user (\(\:msno\)) to extract key features such as the total number of transactions (\(\:total\_transactions\)), cumulative payment amount (\(\:total\_payment\)), number of cancellations (\(\:is\_cancel\_sum\)), and the most recent transaction date (\(\:last\_transaction\_date\)). Next, due to the large size of \(\:user\_logs\_v2.csv\) (over 30GB), chunk-wise processing was implemented, reading the file in segments of one million rows to manage memory usage efficiently. Behavioral features were computed for each user, comprising the total number of active days (\(\:log\_days\)), total listening time in seconds (\(\:total\_secs\_sum\)), and the number of unique songs played (\(\:total\_songs\_played\)). These features provided a condensed yet informative summary of user engagement across time. Finally, the aggregated transaction and log data were integrated with demographic attributes from \(\:members\_v3.csv\) and the churn labels from \(\:train\_v2.csv\), resulting in a comprehensive feature matrix representing 970,960 users. This merged dataset integrated static, transactional, and dynamic behavioral data, forming the foundation for model training.
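To make the aggregation steps concrete, a minimal pandas sketch is given below. It assumes the public KKBox column names (e.g., `actual_amount_paid`, `num_unq`) and uses illustrative variable names such as `txn_agg` and `features`; the exact aggregation code used in this study may differ.

```python
import pandas as pd

CHUNK_SIZE = 1_000_000  # one million rows per chunk, as described above

# Per-user aggregation of the transaction table
transactions = pd.read_csv("transactions_v2.csv")
txn_agg = transactions.groupby("msno").agg(
    total_transactions=("transaction_date", "count"),
    total_payment=("actual_amount_paid", "sum"),
    is_cancel_sum=("is_cancel", "sum"),
    last_transaction_date=("transaction_date", "max"),
).reset_index()

# Chunk-wise aggregation of the ~30 GB listening-log table
partials = []
for chunk in pd.read_csv("user_logs_v2.csv", chunksize=CHUNK_SIZE):
    partials.append(chunk.groupby("msno").agg(
        log_days=("date", "count"),             # one row per user per day
        total_secs_sum=("total_secs", "sum"),
        total_songs_played=("num_unq", "sum"),  # sum of daily unique-track counts
    ))
# Partial per-chunk aggregates are summed into user-level totals
log_agg = pd.concat(partials).groupby("msno").sum().reset_index()

# Merge behavioural aggregates with demographics and churn labels
members = pd.read_csv("members_v3.csv")
train = pd.read_csv("train_v2.csv")
features = (train.merge(txn_agg, on="msno", how="left")
                 .merge(log_agg, on="msno", how="left")
                 .merge(members, on="msno", how="left"))
```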

Feature engineering

To improve model performance and extract meaningful patterns from raw data, a series of feature engineering techniques were performed to derive temporal and categorical attributes. Time-based features were constructed by leveraging both registration and transaction timestamps. Specifically, the membership duration (\(\:membership\_days\)) was estimated as the difference between the user’s most recent transaction date and their initial registration date as mentioned in Eq. (1).

$$\:{membership\_days}_{i}={last\_transaction\_date}_{i}-{registration\_init\_time}_{i}$$
(1)

where \(\:i\) refers to the user index, and dates are transformed into standard datetime formats. Additionally, the year and month of registration were extracted from \(\:registration\_init\_time\) to generate two discrete features, \(\:registration\_year\) and \(\:registration\_month\). These help capture seasonal or yearly patterns in user onboarding behaviour. For categorical variables with low cardinality, such as city, gender, and \(\:registered\_via\), label encoding was implemented: every unique category was mapped to a corresponding integer, preserving ordinal relationships where appropriate. Let \(\:C\) be a categorical variable with \(\:n\) unique categories. Label encoding is defined by Eq. (2).

$$\:Encoded\left({C}_{i}\right)=j\:\:where\:{C}_{i}={category}_{j},\:\:\:\:j\in\:\left\{\text{0,1},\dots\:,n-1\right\}$$
(2)

This transformation allows categorical inputs to be used effectively within neural network models without inflating dimensionality, as one-hot encoding would. Together, these engineered features enriched the dataset with the temporal and structural insights necessary for churn prediction. Figure 5 illustrates the distributions of four critical engineered features extracted from the user activity, transaction, and churn datasets; each subplot reveals insights into user behaviour patterns relevant to churn prediction. Figure 5(a) shows the distribution of the \(\:membership\_days\) feature, which represents the total duration (in days) from user registration to the last transaction. The distribution is right-skewed, showing many users with relatively short platform tenure, while only a minority have long-standing membership histories. Several spikes suggest user acquisition in bursts, possibly driven by promotional campaigns or product updates. Figure 5(b) shows the log-transformed distribution of \(\:total\_secs\_sum\), which captures the total listening time per user. The log transformation was applied to handle extreme outliers and long-tail behaviour; the distribution resembles a normal curve post-transformation, which is ideal for input to deep learning models, as it stabilises training and prevents high-variance features from dominating. Figure 5(c) depicts the frequency of \(\:is\_cancel\_sum\), reflecting the number of times a user has cancelled their subscription. Most users have never cancelled (value 0), while a small portion have cancelled once; values beyond 4 were capped to minimise noise and sparsity in higher bins. This categorical-type feature helps identify users prone to churn recurrence. The Fig. 5(d) histogram illustrates the distribution of \(\:total\_transactions\), the total number of subscription-related transactions made by each user. The distribution is heavily concentrated at low transaction counts, suggesting short-term trial usage, while the outliers on the far right (users with more than 100 transactions) represent long-term or recurring subscribers, valuable for churn-resistance modelling.
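The following sketch illustrates how the time-based and label-encoded features described above could be derived with pandas and Scikit-learn. It assumes the merged dataframe from the previous step is named `features` and that raw dates are stored as YYYYMMDD integers (the KKBox convention); it is a sketch of the described steps, not the authors' exact code.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

def to_date(col):
    # KKBox stores dates as YYYYMMDD integers; missing values become NaT
    return pd.to_datetime(features[col].astype("Int64").astype(str),
                          format="%Y%m%d", errors="coerce")

# Membership duration in days (Eq. 1)
features["registration_init_time"] = to_date("registration_init_time")
features["last_transaction_date"] = to_date("last_transaction_date")
features["membership_days"] = (
    features["last_transaction_date"] - features["registration_init_time"]
).dt.days

# Discrete registration year/month to capture seasonal onboarding patterns
features["registration_year"] = features["registration_init_time"].dt.year
features["registration_month"] = features["registration_init_time"].dt.month

# Label encoding (Eq. 2) for low-cardinality categorical variables
for col in ["city", "gender", "registered_via"]:
    features[col] = LabelEncoder().fit_transform(features[col].astype(str))
```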

Fig. 5
figure 5

Distribution plots of key engineered features used for churn prediction: (a) membership duration in days, (b) log-transformed total listening time, (c) capped subscription cancellation count, and (d) total subscription transactions.

Data standardization

To ensure the integrity and compatibility of the dataset for deep learning models, systematic data cleaning and feature scaling procedures were performed on the extracted features. Missing values (\(\:NaN\)) in the merged dataset, primarily arising from users with incomplete transaction or log histories, were handled by imputation. All such missing entries were replaced with zero, under the assumption that the absence of activity indicates no engagement or payment behaviour during the observation period. This approach maintained dataset completeness without introducing bias through arbitrary statistical estimates. Subsequently, all numerical features were standardised by Z-score normalisation using the \(\:StandardScaler\) utility from the Scikit-learn library. This method transforms each feature \(\:x\) according to the formula given by Eq. (3).

$$\:{x}_{scaled}=\frac{x-\mu\:}{\sigma\:}$$
(3)

where \(\:\mu\:\) refers to the mean and \(\:\sigma\:\) is the standard deviation of the feature over all training instances. This transformation ensures that all features have zero mean and unit variance, which is crucial for stabilising the learning dynamics of deep models, particularly those involving gradient-based optimisation such as Graph Attention Networks (GAT). Proper scaling prevents certain features from dominating due to scale disparity and improves the convergence behaviour of the model.
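A short Scikit-learn sketch of the imputation and scaling step is shown below; the list of numeric columns is an illustrative subset of the engineered features.

```python
from sklearn.preprocessing import StandardScaler

# Zero-imputation for users with missing activity or payment history,
# followed by Z-score normalisation (Eq. 3) of the numeric features
numeric_cols = ["membership_days", "total_secs_sum", "total_songs_played",
                "log_days", "total_transactions", "total_payment",
                "is_cancel_sum"]              # illustrative subset
features[numeric_cols] = features[numeric_cols].fillna(0)
features[numeric_cols] = StandardScaler().fit_transform(features[numeric_cols])
```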

As visualised in Fig. 6, the application of \(\:StandardScaler\) significantly alters the distribution of the \(\:membership\_days\) feature. The original values, shown in blue, span a wide range up to approximately 5000 days, exhibiting severe skewness and high variance. After scaling (shown in red), the feature is tightly concentrated around zero with a normalised spread, indicating successful mean-centring and variance reduction. This transformation eliminates the impact of magnitude differences and prepares the feature for balanced learning across all input dimensions.

Fig. 6
figure 6

Impact of Z-score standardisation on the \(\:membership\_days\) feature.

Proposed GAT + MLP hybrid model

To efficiently model both the individual behavioural attributes of users and their latent group-based interactions, we present a hybrid GAT + MLP model. This architecture integrates the structural learning capabilities of Graph Attention Networks (GAT) with the expressive power of a Multi-Layer Perceptron (MLP), allowing the system to learn from both node-level features and their graph-based relationships. The model is designed to cope with the class imbalance, sparse graph connectivity, and nonlinear feature interactions commonly encountered in real-world churn datasets. We hypothesise that users who share similar demographics or onboarding behaviour exhibit correlated churn tendencies. To capture these relationships, a synthetic user similarity graph \(\:G=(V,E)\) is constructed: nodes (\(\:{v}_{i}\in\:V\)) represent individual users, and edges (\(\:\left({v}_{i},{v}_{j}\right)\in\:E\)) connect users who share the same city and registration method, as defined in Eq. (4).

$$\:\left({v}_{i},{v}_{j}\right)\in\:E\:\leftrightarrow\:\:{city}_{i}={city}_{j}\wedge\:\:{registere{d}_{via}}_{i}={registered\_via}_{j}$$
(4)
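A minimal sketch of this edge-construction rule follows. Because Eq. (4) links every pair of users in the same (city, registered_via) group, which is intractable for roughly one million nodes, the sketch samples at most k peers per user within each group; this sampling is an assumption made for tractability, not part of the reported method.

```python
import numpy as np
import torch

def build_similarity_edges(df, k=10, seed=0):
    """Link users sharing the same city and registered_via (Eq. 4)."""
    rng = np.random.default_rng(seed)
    src, dst = [], []
    for _, group in df.groupby(["city", "registered_via"]):
        idx = group.index.to_numpy()
        if len(idx) < 2:
            continue
        for i in idx:
            # Sample up to k peers from the same group (assumption for scale)
            peers = rng.choice(idx[idx != i], size=min(k, len(idx) - 1),
                               replace=False)
            src.extend([i] * len(peers))
            dst.extend(peers.tolist())
    # Make the graph undirected by adding both edge directions
    edge_index = torch.tensor([src + dst, dst + src], dtype=torch.long)
    return edge_index

# edge_index = build_similarity_edges(features.reset_index(drop=True))
```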

This results in localised clusters of similar users, allowing message passing across user neighbourhoods and enabling the model to exploit contextual patterns in churn behaviour. The proposed GAT + MLP model follows a dual-path architecture that learns graph-based embeddings in parallel with deep tabular feature representations and then fuses them for final classification. This design enables the model to extract both topological and semantic signals that influence churn behaviour. Each user node \(\:i\) with feature vector \(\:{h}_{i}\in\:{R}^{d}\) aggregates information from its neighbours \(\:j\in\:N\left(i\right)\) using attention-weighted message passing. The attention coefficient between node \(\:i\) and neighbour \(\:j\) is defined using Eqs. (5) and (6).

$$\:{e}_{ij}=LeakyRelu\left({a}^{\text{T}\:}\left[W{h}_{i} \left\| \right. W{h}_{j}\right]\right)$$
(5)
$$\:{\alpha\:}_{ij}=\frac{exp\left({e}_{ij}\right)}{{\sum\:}_{k\in\:N\left(i\right)}exp\left({e}_{ik}\right)}$$
(6)

The updated embedding of node \(\:i\) is then computed using the weighted aggregation of its neighbours given by Eq. (7).

$$\:{z}_{i}=\sigma\:\left(\sum\:_{j \in N\left(i\right)}{\alpha\:}_{ij}W{h}_{j}\right)$$
(7)

To enhance expressiveness, we employ multi-head attention with K = 4 heads as described in Eq. (8).

$$\:{z}_{i}={\sum\:}_{k=1}^{K}\sigma\:\left(\sum\:_{j\in\:N\left(i\right)}{\alpha\:}_{ij}^{\left(k\right)}{W}^{\left(k\right)}{h}_{j}\right)$$
(8)

Dropout with a rate of 0.6 is applied between GAT layers to prevent overfitting on sparse graph neighbourhoods. In parallel, each user’s 15-dimensional tabular feature vector \(\:{x}_{i}\) is processed through a deep MLP to capture nonlinear interactions among features that may not be evident in the graph topology. The MLP operations are given by Eqs. (9) and (10).

$$\:{t}_{1}=ReLU\left({BN}_{1}\left({W}_{1}{x}_{i}+{b}_{1}\right)\right)$$
(9)
$$\:{t}_{2}=Dropout\left(ReLU\left({BN}_{2}\left({W}_{2}{t}_{1}+{b}_{2}\right)\right)\right)$$
(10)

This captures nonlinear interactions among original user-level features that are not captured through the graph. The final user representation is formed by concatenating the GAT embedding \(\:{z}_{i}\in\:{R}^{32}\) with the MLP-transformed tabular vector \(\:{t}_{2}\in\:{R}^{15}\) as given in Eq. (11).

$$\:{u}_{i}=\left[{z}_{i}\left\| \right. {t}_{2}\right]\in\:{R}^{47}$$
(11)

This hybrid representation is passed through a final classification head as described in Eq. (12).

$$\:{\widehat{y}}_{i}=Sigmoid\left({W}_{3} \cdot \:Dropout\left(ReLU\left({u}_{i}\right)\right)+{b}_{3}\right)$$
(12)

where \(\:{\widehat{y}}_{i}\) is the predicted probability of user \(\:i\) churning. Figure 7 presents the complete architectural design of the proposed Hybrid GAT + MLP model, combining structural, behavioural, and demographic learning pathways for robust churn classification.

Fig. 7
figure 7

Detailed architecture of the proposed hybrid model.
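A condensed PyTorch Geometric sketch of this dual-path architecture is given below. The 32-dimensional graph embedding, 15-dimensional tabular path, four attention heads, and 0.6 dropout follow the text, while the hidden widths, the head-averaging in the second GAT layer, and the use of raw logits (with the sigmoid applied by the loss function) are assumptions of this sketch rather than details reported by the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class HybridGATMLP(nn.Module):
    """Sketch of the dual-path GAT + MLP architecture (Eqs. 5-12)."""

    def __init__(self, in_dim=15, gat_hidden=8, heads=4, mlp_hidden=64):
        super().__init__()
        # Graph path: two multi-head GAT layers (Eqs. 5-8)
        self.gat1 = GATConv(in_dim, gat_hidden, heads=heads, dropout=0.6)
        self.gat2 = GATConv(gat_hidden * heads, 32, heads=heads,
                            concat=False, dropout=0.6)  # heads averaged here
        # Tabular path: deep MLP with batch norm and dropout (Eqs. 9-10)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, mlp_hidden), nn.BatchNorm1d(mlp_hidden), nn.ReLU(),
            nn.Linear(mlp_hidden, 15), nn.BatchNorm1d(15), nn.ReLU(),
            nn.Dropout(0.3),
        )
        # Fusion and classification head (Eqs. 11-12); the sigmoid is applied
        # implicitly by BCEWithLogitsLoss during training
        self.head = nn.Sequential(nn.ReLU(), nn.Dropout(0.3),
                                  nn.Linear(32 + 15, 1))

    def forward(self, x, edge_index):
        z = F.elu(self.gat1(x, edge_index))   # attention-weighted aggregation
        z = self.gat2(z, edge_index)          # 32-d structural embedding z_i
        t = self.mlp(x)                       # 15-d tabular embedding t_2
        u = torch.cat([z, t], dim=1)          # u_i in R^47 (Eq. 11)
        return self.head(u).squeeze(-1)       # churn logit per user
```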

The churn prediction dataset exhibits a significant class imbalance, with approximately 9% of users labelled as churners and 91% as non-churners. To address this imbalance without altering the real-world data distribution through oversampling or undersampling techniques, we implement a weighted binary cross-entropy loss function given by Eq. (13).

$$\:L=-{w}_{1}ylog\left(\widehat{y}\right)-{w}_{0}\left(1-y\right)\text{l}\text{o}\text{g}(1-\widehat{y})$$
(13)

where \(\:y\in\:\left\{\text{0,1}\right\}\) is the ground-truth label, \(\:\widehat{y}\in\:\left(\text{0,1}\right)\) is the predicted probability from the model, \(\:{w}_{1}=10\) is the weight assigned to the positive (churn) class, and \(\:{w}_{0}=1\) is the weight for the negative (non-churn) class. This weighting forces the model to pay greater attention to correctly predicting the minority churn class while maintaining the integrity of the original data distribution. Algorithm 1 details the proposed churn detection process, and Table 2 lists the hyperparameters used to achieve better performance.
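In PyTorch, this weighting scheme can be realised with the `pos_weight` argument of `BCEWithLogitsLoss`, as sketched below under the assumption that the model outputs logits rather than probabilities.

```python
import torch
import torch.nn as nn

# Weighted binary cross-entropy (Eq. 13) with w1 = 10 (churn) and w0 = 1
# (non-churn): the ratio w1 / w0 is passed as pos_weight, which scales only
# the positive-class term of the loss.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10.0))

# Example: loss = criterion(model(x, edge_index)[train_mask], y[train_mask].float())
```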

Algorithm 1
figure a

The proposed churn detection algorithm.

Table 2 Hyperparameter table for proposed model.

Results and discussion

Experimental setup

To ensure reproducibility and take advantage of GPU acceleration, all experiments were conducted on Google Colaboratory with a high-RAM, GPU-enabled runtime. The environment included an NVIDIA Tesla T4 GPU (16 GB VRAM), an Intel Xeon CPU (~2.2 GHz), and 25.5 GB RAM, running Python 3.10.12 on Ubuntu 20.04. The proposed Hybrid GAT + MLP model was implemented using PyTorch 2.0, PyTorch Geometric 2.4.0, and supporting libraries such as Scikit-learn, Pandas, and NumPy; PyTorch Geometric dependencies were installed using Colab-compatible wheels. The model was trained using the Adam optimizer with a learning rate \(\:\eta\:=0.005\), weight decay \(\:{10}^{-4}\), a batch size of 512, and early stopping based on validation AUC. Class imbalance was addressed using a weighted cross-entropy loss (\(\:{w}_{churn}=10,\:{w}_{not\:churn}=1\)), with performance evaluated using AUC, accuracy, and F1-score. Table 3 shows the experimental setup in detail.
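A compact sketch of the training loop implied by this setup is shown below. The tensors `x`, `y`, `edge_index`, and the train/validation masks are assumed to come from the earlier preprocessing and graph-construction sketches; full-batch training is used for simplicity, and the early-stopping patience value is an assumption.

```python
import torch
from sklearn.metrics import roc_auc_score

# Optimiser settings from the text: Adam, lr = 0.005, weight decay = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)

best_auc, patience, wait = 0.0, 20, 0   # patience value is an assumption
for epoch in range(250):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x, edge_index)[train_mask], y[train_mask].float())
    loss.backward()
    optimizer.step()

    # Early stopping on validation AUC
    model.eval()
    with torch.no_grad():
        val_prob = torch.sigmoid(model(x, edge_index)[val_mask]).cpu().numpy()
    auc = roc_auc_score(y[val_mask].cpu().numpy(), val_prob)
    if auc > best_auc:
        best_auc, wait = auc, 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        wait += 1
        if wait >= patience:
            break
```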

Table 3 Experimental setup details.

Evaluation metrics

Quantitative metrics, including precision, recall, F1-score, ROC-AUC, and accuracy, are used for detailed evaluation. Precision is defined as the proportion of correct positive predictions among all predicted positives:

$$\:Precision=\:\frac{TP}{TP+FP}$$
(14)

Recall is defined as the proportion of correct positive predictions among the actual positives in the data:

$$\:Recall=\:\frac{TP}{TP+FN}$$
(15)

The F1-score is the harmonic mean of precision and recall:

$$\:F1\:Score=2\:\times\:\frac{Precision\:\times\:Recall}{Precision+Recall}$$
(16)
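These metrics can be computed with Scikit-learn as sketched below, assuming `y_test` holds the ground-truth labels and `test_prob` the predicted churn probabilities; the 0.89 decision threshold follows the value reported in the next subsection.

```python
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score)

THRESHOLD = 0.89  # decision threshold reported in the next subsection
test_pred = (test_prob >= THRESHOLD).astype(int)

print("Accuracy :", accuracy_score(y_test, test_pred))
print("Precision:", precision_score(y_test, test_pred))
print("Recall   :", recall_score(y_test, test_pred))
print("F1-score :", f1_score(y_test, test_pred))
print("ROC-AUC  :", roc_auc_score(y_test, test_prob))
```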

Quantitative performance evaluation

To evaluate the performance and effectiveness of the proposed Hybrid GAT + MLP model, we computed several standard classification metrics on the KKBox test set using the optimal decision threshold of 0.89, selected based on validation accuracy as shown in Table 4. The model achieved a final test accuracy of 95.87% and an Area Under the ROC Curve (AUC) of 0.9609, indicating strong discriminative power. Class-wise analysis revealed that for the majority class (Not Churn), the model attained precision = 0.98, recall = 0.98, and F1-score = 0.98, with a support count of 176,726 users. For the minority class (Churn), despite class imbalance, it achieved precision = 0.76, recall = 0.79, and F1-score = 0.78, across 17,466 users. The macro-averaged F1-score was 0.88, and the weighted F1-score was 0.96, reflecting the model’s balanced generalisation across both classes.

Table 4 Classification performance on KKBox test Set.

Qualitative performance results

Confusion matrix analysis

To further assess the predictive quality of the proposed Hybrid GAT + MLP model, a confusion matrix was constructed using the final test set results. As shown in Fig. 8, the model correctly identified 173,279 non-churners and 12,761 churners, while misclassifying 3,447 non-churners as churners and 4,705 churners as non-churners. These results emphasise the model’s effectiveness in recognising the minority class (churners), despite the underlying class imbalance. The false positive rate remains acceptably low, showing that the model does not excessively flag loyal users as churners. Meanwhile, the true positive count for churners suggests that the model captures complex churn-inducing patterns efficiently, supporting its utility in proactive customer retention strategies.
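A short sketch of how such a matrix can be produced with Scikit-learn follows, reusing the thresholded predictions `test_pred` assumed in the earlier metric sketch.

```python
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Rows: true class, columns: predicted class (0 = Not Churn, 1 = Churn)
cm = confusion_matrix(y_test, test_pred, labels=[0, 1])
ConfusionMatrixDisplay(cm, display_labels=["Not Churn", "Churn"]).plot(cmap="Blues")
plt.show()
```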

Fig. 8
figure 8

Final confusion matrix for the Hybrid GAT + MLP model on the KKBox test set.

Training dynamics

To monitor the learning behaviour of the Hybrid GAT + MLP model, we plotted the training loss and accuracy over 250 epochs. As shown in Fig. 9, the training loss steadily decreases from approximately 0.9 to under 0.27, indicating successful convergence of the optimisation process. The rapid loss reduction in the first 30 epochs reflects effective early learning, followed by a plateau suggesting convergence stability. Simultaneously, the training and validation accuracy curves show that the model achieves over 95% accuracy within the first 40 epochs and maintains high consistency throughout subsequent epochs. The absence of significant divergence between training and validation accuracy indicates minimal overfitting and validates the effectiveness of dropout regularisation and early stopping. These curves confirm that the model generalises well while maintaining learning stability.

Fig. 9
figure 9

Training loss and accuracy curves demonstrating stable convergence and high generalization performance of the Hybrid GAT + MLP model.

Receiver operating characteristic analysis

As illustrated in Fig. 10, the Hybrid GAT + MLP model exhibits a strong trade-off between the true positive rate (sensitivity) and the false positive rate, with the curve closely hugging the top-left corner of the plot. The Area Under the Curve (AUC) is 0.9583, indicating high discriminative power. This score is consistent with the other performance metrics and confirms the model’s robustness in identifying churn behaviour across various classification thresholds. The smooth and well-separated ROC profile validates the model’s effectiveness in balancing precision and recall, especially under class imbalance.
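The ROC curve itself can be generated as sketched below, again reusing the assumed `y_test` and `test_prob` arrays from the evaluation sketch.

```python
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

fpr, tpr, _ = roc_curve(y_test, test_prob)
plt.plot(fpr, tpr, label=f"Hybrid GAT + MLP (AUC = {auc(fpr, tpr):.4f})")
plt.plot([0, 1], [0, 1], linestyle="--", label="Chance")  # diagonal reference
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.legend()
plt.show()
```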

Fig. 10
figure 10

ROC curve showing strong classification performance of the Hybrid GAT + MLP model with an AUC of 0.9583 on the KKBox test set.

Precision-recall curve analysis

To estimate the model’s ability to identify churners under class imbalance, the precision-recall (PR) curve was generated. As depicted in Fig. 11, the proposed Hybrid GAT + MLP model achieved an average precision (AP) of 0.8235, which is significantly higher than the no-skill baseline of approximately 0.09 (the churn class prevalence). The model maintains high precision across a wide range of recall values, demonstrating its capability to retrieve relevant churn instances without producing excessive false positives. The curve’s smooth descent and large area under it indicate strong positive predictive power, which is especially important for high-risk churn intervention use cases. The high AP reinforces the model’s strength in classifying minority-class instances effectively despite the severe class imbalance in the dataset.
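The PR curve and average precision can be reproduced with Scikit-learn as in the sketch below, assuming the same `y_test` and `test_prob` arrays; the no-skill baseline equals the churn prevalence.

```python
from sklearn.metrics import precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt

precision, recall, _ = precision_recall_curve(y_test, test_prob)
ap = average_precision_score(y_test, test_prob)

plt.plot(recall, precision, label=f"Hybrid GAT + MLP (AP = {ap:.4f})")
plt.axhline(y_test.mean(), linestyle="--", label="No-skill baseline")  # churn prevalence
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
plt.show()
```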

Fig. 11
figure 11

Precision-recall curve of the hybrid GAT + MLP model.

Feature importance analysis

To interpret which attributes most influenced churn predictions, we examined feature importance scores derived from the Gini impurity metric. As shown in Fig. 12, the most influential feature is \(\:is\_cancel\_sum\), indicating the number of times a user has canceled their subscription. This is followed closely by \(\:avg\_payment\) and \(\:total\_payment\), suggesting that payment behavior—both in terms of consistency and volume—plays a crucial role in churn detection. Other significant features include \(\:total\_transactions,\:log\_days,\) and \(\:registration\_year\), capturing transactional frequency, engagement duration, and user tenure respectively. In contrast, features such as \(\:total\_songs\_played,\:total\_secs\_sum,\) and gender contributed minimally to the model’s decision process. These insights validate the feature engineering steps and emphasize that both behavioral and payment-related attributes are key drivers of churn outcomes (Table 5).
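The text does not state which surrogate model supplies the Gini-based scores; a common choice is to fit a Random Forest on the same engineered features and read its `feature_importances_`, as sketched below with assumed arrays `X_train`, `y_train`, and `feature_names`.

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Surrogate Random Forest fitted on the engineered features; its Gini-based
# feature_importances_ provide a ranking of the kind plotted in Fig. 12
rf = RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1)
rf.fit(X_train, y_train)

top15 = (pd.Series(rf.feature_importances_, index=feature_names)
           .sort_values(ascending=False)
           .head(15))
print(top15)
```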

Fig. 12
figure 12

Top 15 most important features for churn prediction ranked by Gini impurity.

Ablation study

Table 5 Ablation study results across models and datasets. We initially attempted to implement the approach on the Telco dataset, but graph-based models did not perform well on it, so we moved to the KKBox dataset, which is far better suited to the proposed graph-based approach.

To assess the impact of model architecture, feature engineering, and dataset scale, an ablation study was conducted across five configurations, as presented in Table 5. The Baseline GNN (GraphSAGE) achieved a test accuracy of 74.0%, AUC of 0.82, and F1-score of 0.61, indicating overfitting and poor generalization. Adding feature engineering in the Hybrid GNN (SAGE) slightly improved the performance to 75.5% accuracy, 0.83 AUC, and 0.63 F1-score. When GAT layers were introduced in the HybridGAT (Telco) configuration, performance further improved to 78.4% accuracy and 0.84 AUC, validating the benefit of attention mechanisms, though the F1-score remained at 0.61, limited by the dataset’s simplicity. The Ensemble (GAT + XGBoost) approach slightly underperformed with a 77.9% accuracy and a sharp drop in F1-score to 0.43, suggesting overfitting. Finally, the proposed HybridGAT (KKBox) model, when applied to a richer dataset with large-scale features and weighted loss, achieved 95.8% accuracy, 0.963 AUC, and 0.77 F1-score, establishing it as the best-performing configuration. This clearly illustrates that both model expressiveness and data richness are essential for robust churn prediction.

State-of-the-art comparison

Table 6 Comparison of proposed model with state-of-the-art methods.

Table 6 shows the state-of-the-art comparison on KKBox. Unlike telecom or banking churn datasets, this dataset is not widely used in prior work. TBformer33 and Entropy34 report results closest to those of our proposed model; however, our model outperforms these existing benchmark works.

Discussion

This research proposes a new Hybrid GAT + MLP model for churn prediction on imbalanced datasets. It integrates the GAT’s structural graph reasoning, which learns complex user interconnections, with the MLP’s ability to model rich interrelationships among individual features. The model is designed to mitigate class imbalance, low-density graphs, and the complexity of churn behaviour.

The attention-based aggregation enables the model to focus on the neighbouring vertices most relevant to churn prediction, avoiding an aggregation bottleneck, while the MLP works in parallel on user-level attributes to learn patterns that the graph does not capture. By combining graph structure with the underlying real-world behavioural signals, the model identifies user churn with greater precision.

Quantitative evaluation showed that the Hybrid GAT + MLP model performs remarkably well, achieving an accuracy of 95.87% and an AUC of 0.9609 on the KKBox dataset, surpassing many other churn prediction methods. The model also retained precision and recall on both classes, achieving a weighted F1-score of 0.96 despite the dataset’s imbalanced characteristics (9% churners and 91% non-churners). Furthermore, qualitative assessments such as the confusion matrix and ROC curve analyses confirmed that the model accurately separates churn and non-churn cases while producing few false positives.

The ablation study results highlight the performance gains contributed by the graph attention mechanism, the engineered features, and the expanded dataset. The Hybrid GAT + MLP model also surpassed current state-of-the-art techniques such as TBformer and Entropy-RFM, demonstrating its robustness and generalisation capacity. Furthermore, user behaviour features such as is_cancel_sum and avg_payment proved particularly important in predicting churn, which is essential for building proactive retention mechanisms.

Limitation

A limitation of this research is that evaluation focused on the KKBox dataset only, which restricts the model’s applicability to other churn detection services with different consumer patterns or subscription behaviours. Next, the problem of class imbalance35 was addressed by the weighted loss method36, but other techniques such as resampling or synthetic data generation, which are likely to produce different distributions, were not considered. Finally, explainable interpretability was not obtained for this model37.

Conclusion

Customer churn prediction remains a crucial challenge in subscription-based digital services, particularly in music streaming platforms where user behaviour is dynamic and engagement patterns are complex. High churn rates can significantly reduce profitability, making it essential to identify at-risk users early for targeted retention strategies. Conventional machine learning models often fall short in capturing relational dependencies and non-linear feature interactions, especially in the presence of imbalanced class distributions. To overcome these limitations, we present a Hybrid Graph Attention Network (Hybrid GAT + MLP) that integrates graph-based user similarity learning with deep tabular feature modelling. A synthetic similarity graph was constructed using demographic attributes to capture contextual user relationships, while a GAT encoder learned structural embeddings that were concatenated with handcrafted behavioural and transactional features. The model was trained end-to-end using a weighted cross-entropy loss to handle class imbalance without data distortion. Experimental results demonstrated the effectiveness of the model, achieving 95.8% accuracy and an AUC of 0.9626 on the KKBox test set. These metrics emphasise the system’s strong predictive performance and its ability to generalise across diverse user profiles. Additionally, the architecture’s ability to process large-scale data in a scalable and interpretable manner makes it well-suited for real-world deployment in streaming services, where retention interventions must be both accurate and timely. In future work, we plan to extend the model by integrating temporal graph structures that capture evolving user behaviour over time, as well as experimenting with self-supervised learning techniques to leverage unlabelled data. The model can also be evaluated on diverse churn detection datasets to improve generalisation, and further hyperparameter tuning may yield additional accuracy gains. Finally, explainable AI (XAI) methods such as SHAP and LIME should be incorporated to provide model interpretability.