Introduction

Stroke is a major public health problem and accounts for a significant proportion of morbidity and mortality worldwide1. According to the World Health Organization (WHO), more than 15 million people suffer a stroke each year; of these, around 5 million die and another 5 million are left permanently disabled. Beyond its impact on individual patients, stroke places a heavy burden on healthcare systems worldwide2, and this growing global burden underscores the urgent need for early, accurate and automated detection systems to help clinicians diagnose stroke and assess risk. Accurate, early prediction of stroke is of paramount importance for timely intervention and prevention, helping to avoid prolonged disability, improve quality of life and relieve pressure on healthcare resources. The development of effective stroke prediction models that use advanced machine learning and deep learning techniques has therefore gained significant attention in recent years3,4,5. The latest generation of these models exploits large amounts of clinical imaging data to improve diagnostic accuracy and enable early detection of stroke. By detecting signs of a stroke earlier, healthcare providers can implement appropriate preventative measures and targeted treatments, leading to better patient outcomes and lower healthcare costs6. However, stroke prediction remains challenging due to the complexity of the underlying pathophysiology, the heterogeneity of stroke subtypes and the limited availability of well-labelled, comprehensive datasets7.

Despite the success of deep learning in medical imaging, MRI analysis of stroke still faces challenges, particularly in handling class imbalance, ensuring interpretability, and validating clinical applicability. Most existing models overlook class imbalances or are not explainable, limiting their clinical relevance. This study introduces CBDA-ResNet50, a ResNet50-based model developed for stroke MRI classification. It incorporates class-balanced loss, data augmentation, and Grad-CAM visualizations to improve accuracy and interpretability.

Stroke datasets are often very unbalanced, with significantly more negative (non-stroke) than positive (stroke) cases, making it difficult for models to learn the minority class. To address this problem, CBDA-ResNet50 applies class-balanced loss and targeted augmentation techniques, such as random flips, rotation, resizing, and normalization, improving generalization and robustness. Based on the proven ResNet50 architecture, our model effectively learns complex clinical patterns while maintaining real-time and interpretable performance suitable for clinical use.

In this study, we compare CBDA-ResNet50 with different stroke prediction methods, including traditional machine learning approaches (logistic regression, support vector machines, decision trees, and random forests) and deep learning models (convolutional neural networks, long short-term memory networks, and DenseNet-121). Previous studies achieved accuracies ranging from 66% (decision trees) to 96.45% (GA + BiLSTM). However, many of these methods struggled with class imbalance, resulting in suboptimal predictions for the minority class. The proposed CBDA-ResNet50 model addresses class imbalance using WeightedRandomSampler and weighted cross-entropy, and employs data augmentation strategies such as intensity normalization to enhance performance and overcome the limitations of existing methods. Our key contributions are as follows:

  • We developed CBDA-ResNet50 for stroke risk prediction, which mitigates the class imbalance problem and enhances model generalization.

  • We applied Grad-CAM to interpret the decisions made by CBDA-ResNet50, thereby increasing the model’s transparency and reliability for clinical use.

  • We conducted a comparative analysis of several existing approaches to identify the most effective model for stroke prediction.

Study distribution

The rest of the article is organized as follows: Section 2 gives an overview of related studies, Sect. 3 describes the methodology, Sect. 4 presents the evaluation metrics, Sect. 5 the experimental results, Sect. 6 the discussion, and Sect. 7 concludes with future work.

Related work

Machine learning and deep learning have been extensively studied for stroke prediction and offer significant potential to improve diagnostic accuracy and prediction results. Gudadhe et al.8 proposed a machine learning-based approach to classifying different stroke subtypes using algorithms such as Naïve Bayes, random forest, and support vector machines. Their study showed that these models can distinguish stroke phenotypes from computed tomography (CT) scan data, helping to customize treatment strategies and improve patient outcomes. Similarly, Akter et al. developed a deep learning model with random forest classifiers that achieved 95.30% accuracy in stroke detection9. In addition, previous studies have investigated machine learning approaches such as K-nearest neighbors, Naïve Bayes, and decision trees to improve stroke prediction accuracy10.

The development of machine learning for stroke enables earlier detection, risk assessment and prognosis, as highlighted by Tursynova et al.11. Their findings support deep learning methods such as CNNs for classifying stroke types and for other neuroimaging applications. A study by Saleem et al.12 developed a stroke detection model using CT images of the brain that combines a genetic algorithm with bidirectional long short-term memory (BiLSTM) and achieves 96.5% accuracy, further emphasizing the effectiveness of combining feature selection with advanced deep learning methods. Another study on machine learning for stroke prediction combined algorithms such as support vector machines and decision trees to classify patients based on stroke risk factors, providing valuable decision support13.

While previous studies have achieved promising accuracy, they often overlook critical challenges such as class imbalance, limited dataset generalizability, and lack of interpretability. In medical imaging, advanced deep architectures such as EfficientNet, Vision Transformers (ViTs), and CNN-RNN hybrids have emerged. ViTs, for example, use self-attention to model long-range dependencies, while EfficientNet balances performance and computational efficiency through compound scaling. However, these models have seen only limited use in stroke MRI classification, and few studies integrate interpretability tools such as Grad-CAM. Our proposed CBDA-ResNet50 fills these gaps by addressing class imbalance through weighted sampling, improving generalization through MRI-specific augmentations, and increasing explainability through Grad-CAM-based feature visualization.

These studies show the potential of machine learning (ML) and deep learning (DL) for stroke prediction but often do not adequately address challenges such as class imbalance and limited data14. Advanced architectures such as Vision Transformers (ViTs), EfficientNet, and hybrid CNN-RNN models have shown promise in medical imaging: ViTs leverage self-attention for long-range dependencies, EfficientNet balances accuracy and efficiency through compound scaling, and CNN-RNN hybrids improve sequential image analysis. Although these models outperform traditional CNNs, BiLSTMs, and random forest classifiers in some cases, their application to stroke prediction is still limited. The unresolved problems of imbalance and data scarcity can greatly affect the performance and reliability of prediction models and make them much harder to deploy in real-world applications. For example, the datasets used to train such models often contain a disproportionately high number of negative (normal) cases compared to positive (stroke) cases, leading to a bias in favour of the majority class and reducing the predictive ability of the model for minority cases8.

To address the limitations of conventional deep learning approaches in stroke MRI analysis, particularly class imbalance, overfitting, and limited dataset size, we propose CBDA-ResNet50, an enhanced variant of the standard ResNet50 architecture. This model is specifically designed to improve stroke classification performance by integrating techniques tailored for imbalanced medical datasets. Unlike traditional models that rely on standard cross-entropy loss and are prone to biased learning toward the majority (non-stroke) class, CBDA-ResNet50 employs a weighted cross-entropy loss, which enables better sensitivity to the minority (stroke) class and promotes balanced classification outcomes14. To mitigate overfitting and enhance generalization, the model employs a comprehensive data augmentation strategy, which includes random flipping, rotation, color jittering, and resized cropping. These augmentations simulate diverse imaging scenarios and generate synthetic variations that strengthen the model’s ability to learn discriminative features from limited data15. Furthermore, ReduceLROnPlateau is implemented as a dynamic learning rate scheduler to adjust the learning rate adaptively during training, thereby improving convergence stability and efficiency. Through these combined strategies, CBDA-ResNet50 achieves robust and accurate stroke prediction, outperforming conventional machine learning (ML) and deep learning (DL) models, and offers a clinically viable solution for real-world diagnostic use.

The following section explains the development and application of the proposed CBDA-ResNet50 stroke prediction model. It outlines the dataset, data preprocessing, model architecture, the class-balanced loss mechanism, the data augmentation techniques, the training approach, and the evaluation metrics used to assess the model's performance.

Methodology

The study follows a systematic approach to create a model that classifies MRI brain images as either stroke or normal. As shown in Fig. 1, we proceed in a sequential flow that includes data input, splitting of the data, augmentation, training of the model, and evaluation.

Fig. 1

Proposed model block diagram.

The dataset is then preprocessed by applying operations such as cropping and normalization to the brain MRI images, and augmented by flipping and rotation. The preprocessed images are then converted into a format that ResNet50 can load, ensuring compatibility and rich feature extraction. The dataset is split into a training set containing 70% of the images (1751) and a test set containing 30% (750). Splitting the data this way allows the model to be trained on one portion and tested on data it has not seen, preventing overfitting and improving generalization. To further reduce overfitting, the training data was augmented with random horizontal flipping, random vertical flipping, random rotation, and random resized cropping. These transformations introduce randomness into the training set so that the model can better handle data it has never seen before.

A pre-trained ResNet50 forms the core of the model architecture. The last fully connected layer of ResNet50 is replaced with a new dense layer that performs the binary classification (stroke vs. normal). A dropout layer is attached before this fully connected layer, followed by a softmax activation that produces a probability distribution over the output classes, ensuring that the probabilities sum to 1.

Algorithm for proposed CBDA-ResNet50 model

The procedure of the proposed CBDA-ResNet50 model presented in algorithm 1 begins with data preprocessing and augmentation, which includes resizing, normalization, random flipping, random rotation, and splitting the dataset into training and validation subsets. The CBDA-ResNet50 model, pre-trained on ImageNet, is fine-tuned with a modified fully connected layer for binary classification. For model optimization, a cross-entropy loss weighted by class frequency and the Adam optimizer are used. The model is trained over 100 epochs, and the learning rate is dynamically adjusted when the validation loss reaches a plateau. We use the weighted cross-entropy loss to counteract class imbalance by ensuring that both the majority class and the minority class contribute equally, thus preventing a bias in favor of the dominant class. While focal loss is effective for extremely rare conditions (<1% prevalence), weighted cross-entropy ensures a balanced learning process without overly down-weighting easily classified cases, making it more suitable for stroke prediction. For optimization, we use ReduceLROnPlateau, which dynamically adjusts the learning rate when the validation loss stagnates, increasing stability in medical imaging tasks. Unlike SGD, which requires extensive tuning, or AdamW, which may over-regularize small datasets, Adam with ReduceLROnPlateau ensures adaptive learning and robust generalization. The best-performing model is saved based on validation accuracy.

Algorithm 1

Working of CBDA-ResNet50

Experimental setup

The experimental setup defines the steps and configurations used in the study:

  • Hardware: The model was trained on an NVIDIA CUDA-enabled GPU, which accelerates matrix operations and backpropagation. This hardware allowed efficient training on the MRI dataset and reduced the required computation time.

  • Software: Model development was performed in PyTorch 1.x, with Torchvision handling the dataset and transformations. Scikit-learn was used to calculate the performance metrics, and Seaborn and Matplotlib were used to visualize key training and evaluation metrics, including accuracy, loss, and ROC-AUC curves.

Dataset and data collection

The dataset used in this study is an MRI dataset of the brain from Kaggle. It consists of images divided into two classes: Stroke patients and normal patients. The dataset contains MRI scans of stroke patients and normal individuals, with images pre-organized into class-specific directories. The dataset is divided into a training set and a test set, with 70% for training and 30% for testing16.

Since MRI images are usually grayscale (single-channel), we convert them to three-channel (RGB) format by duplicating the grayscale values across all three channels. This ensures compatibility with the pre-trained ResNet50 model, which expects three-channel RGB input. The MRI images are preprocessed to improve model performance by removing noise, artifacts, and contrast variations. We applied normalization with the ImageNet channel statistics for intensity consistency, Gaussian smoothing to reduce noise while preserving detail, and adaptive histogram equalization for contrast enhancement. Augmentation techniques such as random rotation, flipping, and resizing further improved the diversity and robustness of the training data. These steps also ensure that CBDA-ResNet50 learns robust features and mitigates bias. The dataset of 2501 MRI images, consisting of 950 stroke images and 1551 normal images, exhibits a class imbalance that can affect the model. All images were resized to 224\(\times\)224 pixels and normalized according to ResNet50's input requirements. To address the imbalance, we used a weighted cross-entropy loss to emphasize stroke cases and implemented class-balanced sampling to obtain a more balanced distribution in the training phase, as shown in Fig. 2. These techniques increase the model's ability to generalize and improve stroke detection.

Fig. 2

Class distribution before and after augmentation: A comparison of the original number of samples, the added augmented samples and the final number of samples for the stroke, normal and total categories.

To balance the dataset, we augmented 1000 stroke images and 500 normal images by flipping, rotation, cropping, and color jittering. This controlled augmentation improved the class balance to 1950 stroke and 2051 normal images and reduced model bias while preserving data diversity.

The mathematical expression for the normalization of an image \(\textbf{I}\) is as follows:

$$\begin{aligned} \textbf{I}_{\text {norm}} = \frac{\textbf{I} - \varvec{\eta }}{\varvec{\delta }} \end{aligned}$$
(1)

where

$$\begin{aligned} \varvec{\eta } = [0.485, 0.456, 0.406] \quad \text {and} \quad \varvec{\delta } = [0.229, 0.224, 0.225] \end{aligned}$$

are the mean and standard deviation of the pixel values derived from the ImageNet dataset used to pre-train ResNet50.

Class-balanced loss function

The data set shows a class imbalance, with more samples from normal individuals than from stroke patients. To mitigate this problem, we apply a weighted cross-entropy loss function, where each class is assigned a weight inversely proportional to its frequency in the training set. The weight for the class \(i\) is computed as follows:

$$\begin{aligned} w_i = \frac{1}{f_i} \end{aligned}$$
(2)

where \(f_i\) represents the frequency of class \(i\) in the dataset. This ensures that minority classes contribute equally to the model’s learning process.
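Eq. (2) can be evaluated directly from the dataset's class counts given earlier (1551 normal, 950 stroke); the variable names below are illustrative:

```python
# Inverse-frequency class weights (Eq. 2), computed from the class counts.
counts = {"normal": 1551, "stroke": 950}
total = sum(counts.values())

freqs = {c: n / total for c, n in counts.items()}   # f_i = n_i / N
weights = {c: 1.0 / f for c, f in freqs.items()}    # w_i = 1 / f_i
```

With these counts, the minority (stroke) class receives a larger weight (2501/950 ≈ 2.63) than the majority class (2501/1551 ≈ 1.61), so both classes contribute comparably to the loss.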

The class-balanced cross-entropy loss function is then formulated as follows:

$$\begin{aligned} L = -\sum _{i=1}^{C} w_i \cdot y_i \cdot \log (\hat{y}_i) \end{aligned}$$
(3)

where:

  • \(C\) is the total number of classes,

  • \(y_i\) is the true label for class \(i\) (1 for the correct class, otherwise 0),

  • \(\hat{y}_i\) denotes the predicted probability for the class \(i\),

  • \(w_i\) ensures that the contributions of the minority class are strengthened while the contributions of the majority class are weighted down.

In the backpropagation method, the updates for the weights of the class-balanced cross-entropy are evaluated using the gradient:

$$\begin{aligned} \frac{\partial L}{\partial \theta } = -\sum _{i=1}^{C} w_i \cdot y_i \left( \frac{1}{\hat{y}_i} \cdot \frac{\partial \hat{y}_i}{\partial \theta } \right) \end{aligned}$$
(4)

where \(\theta\) represents the learnable parameters of the model. This gradient formulation amplifies the loss for minority-class samples and reduces the probability that predictions are biased towards the majority class. To implement this loss function, we use PyTorch's CrossEntropyLoss with user-defined class weights. The weights are chosen to ensure a balanced contribution from training cases with and without stroke. This approach improves the robustness of the model and its generalization to different patient populations.
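A sketch of this setup with PyTorch's `CrossEntropyLoss`; the exact weight values used in the study are not reported, so the inverse-frequency scaling below is illustrative:

```python
import torch
import torch.nn as nn

# Class weights inversely proportional to frequency, using the dataset's
# class counts (normal: 1551, stroke: 950).
counts = torch.tensor([1551.0, 950.0])     # [normal, stroke]
class_weights = counts.sum() / counts      # proportional to 1/f_i

criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.tensor([[2.0, -1.0], [0.5, 0.3]])  # toy model outputs
labels = torch.tensor([0, 1])                      # true classes
loss = criterion(logits, labels)                   # weighted CE over the batch
```

With `reduction="sum"`, a stroke sample with the same per-sample cross-entropy as a normal sample contributes more to the total loss, which is exactly the rebalancing effect described above.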

Data augmentation

To increase the size and diversity of the training data and improve model generalization, data augmentation was applied through a series of artificial, random image transformations. These augmentations simulate the natural variability found in MRI scans, helping to reduce overfitting. Specifically, the training images were augmented using random horizontal and vertical flipping (each with a 50% probability), random rotation within a range of ±30 degrees, random resized cropping to 224\(\times\)224 pixels, and color jittering with brightness, contrast, saturation, and hue factors set to 0.3. These transformations collectively enhance the robustness of the model to spatial and intensity variations in real-world stroke MRI data.

Random horizontal and vertical flipping The images are randomly flipped horizontally or vertically with a certain probability. This can be represented as follows:

$$\begin{aligned} I_{\text {flip}}(x, y)= & I(W - x, y) \end{aligned}$$
(5)
$$\begin{aligned} I_{\text {flip}}(x, y)= & I(x, H - y) \end{aligned}$$
(6)

where \(I(x, y)\) is the original image and \(H\) and \(W\) represent the height and width of the image. This technique introduces orientation invariance to account for the fact that stroke lesions can occur in both hemispheres of the brain.

Random rotation Random rotation is applied by rotating the image by an angle \(\theta\) drawn from a uniform distribution within a given range. The mathematical expression for this transformation is

$$\begin{aligned} I_{\text {rot}}(x', y') = I(x \cos \theta - y \sin \theta , x \sin \theta + y \cos \theta ) \end{aligned}$$
(7)

where \((x, y)\) are the original pixel coordinates and \((x', y')\) are the coordinates after rotation. This augmentation accounts for variation in head orientation across MRI scans.

Random Resized Cropping This transformation extracts a random image section and resizes it to a fixed size (e.g., 224\(\times\)224 pixels). The cropping is described mathematically as follows:

$$\begin{aligned} I_{\text {crop}}(x, y) = I(x_c + x, y_c + y) \quad \text {for} \quad -\frac{w}{2} \le x \le \frac{w}{2}, -\frac{h}{2} \le y \le \frac{h}{2} \end{aligned}$$
(8)

where \((x_c, y_c)\) denotes the center of the crop, and \(w\) and \(h\) are the width and height of the crop. The cropped area is then resized to the desired output dimensions.

Color Jittering Although MRI images are typically displayed in grayscale and have no inherent color variation, color jittering is used to simulate the variability in intensity and contrast caused by different MRI acquisition parameters, scanner types, and signal-processing workflows. This method increases the robustness of the model by ensuring good network performance across different datasets and image acquisition conditions. The transformations for the brightness and contrast adjustments are given by

$$\begin{aligned} I_{\text {aug}}(x, y)= & \alpha _b I(x, y) + \beta _b \quad \text {(brightness adjustment)} \end{aligned}$$
(9)
$$\begin{aligned} I_{\text {aug}}(x, y)= & \alpha _c \cdot \frac{I(x, y) - \eta }{\delta } + \eta \quad \text {(contrast adjustment)} \end{aligned}$$
(10)

where \(\alpha _b\) and \(\alpha _c\) are scaling factors for brightness and contrast respectively. Similar adjustments are made for saturation and hue to make the model robust to different lighting conditions. Combining these augmentation techniques improves the model’s ability to generalize across different images and avoid overfitting, especially given the relatively limited size of medical imaging datasets.

Several traditional machine learning models, such as logistic regression, decision trees, and support vector machines, struggle with imbalanced medical data: they learn the patterns of the dominant class but fail to capture features of the minority class. This leads to biased predictions and lower sensitivity to stroke. With class-balancing strategies such as the weighted cross-entropy loss, CBDA-ResNet50 ensures that stroke and non-stroke samples contribute equally to the training process, allowing the model to categorize minority-class patterns accurately. Similarly, advanced data augmentation (random flips, rotations, color-space transformations, and cropping) simulates variations in MRI scans and improves the generalization of ResNet50. In this way, overfitting is avoided and the model learns shift-invariant features that transfer across different image acquisition conditions. As a result, CBDA-ResNet50 outperforms traditional methods by effectively handling class imbalance and delivering robust, generalizable performance across different stroke cases.

Model architecture

The model used in this analysis is an adaptation of the ResNet50 architecture, which consists of several building blocks: convolutional layers, batch normalization, residual modules, pooling layers, and a fully connected layer with softmax activation for classification. Each component is critical to the performance of the model. The architecture is explained in detail below, with a focus on the mathematical formulations of each layer. The architecture is shown in Fig. 3.

Input layer

The input data comprises brain MRI scans that have been preprocessed by resizing to \(224 \times 224\) pixels and converting to 3-channel RGB format. Mathematically, the input image \(I(x, y)\) can be represented as follows:

$$\begin{aligned} I(x, y) \in \mathbb {R}^{224 \times 224 \times 3} \end{aligned}$$
(11)

where \(x \in [1, 224]\) and \(y \in [1, 224]\) denote the pixel coordinates and the third dimension represents the color channels (red, green, blue).

Stem layer

The ResNet50 model stem consists of an initial convolutional layer followed by batch normalization and ReLU activation. The convolution operation is given by:

$$\begin{aligned} I_{\text {conv}}(x, y) = \sum _{i=-k}^{k} \sum _{j=-k}^{k} W(i, j) \cdot I(x+i, y+j) + b \end{aligned}$$
(12)

where \(W(i, j)\) are the learnable convolution filters and \(b\) is the bias term. The convolution output is then batch-normalized to zero mean and unit variance:

$$\begin{aligned} I_{\text {BN}}(x, y) = \frac{I_{\text {conv}}(x, y) - \eta _{\text {batch}}}{\delta _{\text {batch}}} \end{aligned}$$
(13)

where \(\eta _{\text {batch}}\) and \(\delta _{\text {batch}}\) are the mean and standard deviation of the current data batch. A ReLU activation function is then applied to the normalized output:

$$\begin{aligned} I_{\text {ReLU}}(x, y) = \max (0, I_{\text {BN}}(x, y)) \end{aligned}$$
(14)

Residual modules

ResNet50 is characterized by residual modules that enable efficient training of deep networks by introducing skip connections. Each residual module has multiple convolutional layers, batch normalization, ReLU activations, and a skip connection. The residual connection can be formulated as follows:

$$\begin{aligned} I_{\text {out}}(x, y) = I_{\text {in}}(x, y) + F(I_{\text {in}}(x, y), W) \end{aligned}$$
(15)

where \(F(I_{\text {in}}(x, y), W)\) represents the transformation applied by the convolutional layers and \(I_{\text {in}}(x, y)\) is the input to the residual block. The skip connection alleviates the vanishing-gradient problem by allowing gradients to flow directly back to earlier layers during training.
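A minimal PyTorch module illustrating the skip connection of Eq. (15); this is a simplified two-convolution block for illustration, not ResNet50's actual bottleneck design:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Simplified residual block: out = ReLU(x + F(x))."""

    def __init__(self, channels: int):
        super().__init__()
        # F(x): conv -> BN -> ReLU -> conv -> BN
        self.f = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection adds the unchanged input back in (Eq. 15)
        return self.relu(x + self.f(x))

block = ResidualBlock(8).eval()
x = torch.randn(1, 8, 32, 32)
y = block(x)  # same shape as the input, by construction
```

Because the identity path bypasses `F`, gradients can reach earlier layers without being attenuated by the convolutional stack.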

Pooling and flatten layers

The output of the residual modules is passed to a Global Average Pooling layer, which reduces dimensionality by computing the average of all values in each feature map. Mathematically, Global Average Pooling can be expressed as follows:

$$\begin{aligned} I_{\text {pool}} = \frac{1}{H \times W} \sum _{i=1}^{H} \sum _{j=1}^{W} I_{\text {out}}(i, j) \end{aligned}$$
(16)

where \(W\) and \(H\) represent the width and height of the feature map.

The output of the pooling layer for the global average is a single vector whose length corresponds to the number of channels. This vector is then flattened and forwarded to the dense (fully connected) layer.

Dense layer and sigmoid activation

The fully connected (dense) layer combines all features extracted by the previous layers to make the final classification decision. Let \(f\) be the flattened feature vector and \(W_{\text {dense}}\) the weights of the dense layer; the output of the dense layer is then given by:

$$\begin{aligned} z = W_{\text {dense}} \cdot f + b_{\text {dense}} \end{aligned}$$
(17)

Finally, a softmax activation is applied to the output of the dense layer to convert the raw scores into class probabilities. The softmax function is given by:

$$\begin{aligned} P(y=k|z) = \frac{\exp (z_k)}{\sum _{i=1}^{C} \exp (z_i)} \end{aligned}$$
(18)

where \(z_k\) is the score for class \(k\), and \(C\) is the total number of classes (in this case, \(C = 2\), representing stroke and normal).
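The normalized-exponential form of Eq. (18) can be evaluated numerically; the logits below are invented for illustration:

```python
import math

def softmax(z):
    # Subtract the max for numerical stability; this does not change the result.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

# Toy dense-layer scores for the two classes [normal, stroke]
probs = softmax([1.2, 3.4])
```

The outputs are guaranteed to be positive and to sum to 1, which is what allows them to be read as class probabilities.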

Classification output

The final output of the model is the probability distribution over the two classes (stroke and normal). The class with the highest probability is selected as the predicted class:

$$\begin{aligned} \hat{y} = \arg \max _{k} P(y = k | z) \end{aligned}$$
(19)
Fig. 3

CBDA-ResNet50-Based Stroke Classification Model.

Training process

The training of the CBDA-ResNet50 model follows a structured approach using a pre-trained ResNet50 with user-defined modifications for binary classification of MRI brain images. While newer architectures such as DenseNet, EfficientNet, and MobileNetV3 offer design improvements, ResNet50 was chosen for its deep residual learning, which prevents vanishing gradients and enables efficient training on large medical imaging datasets. Although EfficientNet and MobileNetV3 reduce computational costs, they can sacrifice fine-grained stroke MRI features, while DenseNet-121 requires more GPU memory. Overfitting was mitigated by ReduceLROnPlateau for adaptive learning rate adjustment, dropout regularization (0.5) to improve generalization, data augmentation (flipping, rotation, cropping) for variability, and class-balanced sampling to avoid class bias. Together, these strategies improved the robustness of the model. Considering the size of the dataset and computational limitations, ResNet50 provides the best balance between complexity and performance.

Table 1 Comparison of CNN Backbones for Stroke MRI Classification.

Table 1 summarizes the most important differences between these architectures. In the next sections, the training process is described in detail, including the hyperparameters, the optimization strategies and the training methodology.

Hyperparameters

Tuning the hyperparameters is crucial for optimizing deep-learning models for medical image classification. We analyzed the learning rate, batch size, and class weights to evaluate their impact on CBDA-ResNet50. A learning rate of 0.0001 provided a balance between stability and convergence, while smaller batch sizes improved generalization but increased training time. Inverse-frequency class weights improved stroke sensitivity, while equal weights degraded performance. These hyperparameters were selected based on empirical tuning and previous research results, and are summarized in Table 2.

Table 2 Hyperparameters Used in Training and Their Impact on Model Performance.

Algorithm for the training of CBDA-ResNet50

Algorithm 2 describes the training process for the CBDA-ResNet50 model, starting with the initialization of ResNet50 with pre-trained ImageNet weights. Training and validation phases alternate for 100 epochs. During each epoch, the model computes predictions, calculates the weighted cross-entropy loss, and updates the weights using the Adam optimizer. If the validation loss does not decrease for 5 epochs, the learning rate is reduced by a factor of 0.1. Finally, the model with the best validation accuracy is saved.
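The plateau-based schedule can be reproduced in isolation; the single dummy parameter below stands in for the full model, and the constant validation loss simulates a plateau:

```python
import torch

# ReduceLROnPlateau lowers the learning rate by factor=0.1 once the
# monitored validation loss has not improved for patience=5 epochs.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=5)

for epoch in range(7):
    val_loss = 1.0            # stagnating validation loss -> plateau
    scheduler.step(val_loss)  # scheduler tracks the best loss seen so far

lr_after = optimizer.param_groups[0]["lr"]  # reduced from 1e-4 to 1e-5
```

After the first step establishes the best loss, five further non-improving epochs exhaust the patience, and the next step triggers the reduction.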

Algorithm 2

Training Algorithm for CBDA-ResNet50

Evaluation metrics

The performance of the CBDA-ResNet50 model was evaluated using the following metrics:

  • Accuracy The accuracy, calculated as shown below, is the ratio of the correct predictions to the total predictions:

    $$\begin{aligned} \text {Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} \end{aligned}$$
    (20)

    Accuracy provides an overall performance measure but can be less informative on imbalanced datasets.

  • Sensitivity This is the percentage of actual positive results that the model correctly recognizes.

    $$\begin{aligned} \text {Sensitivity} = \frac{TP}{TP + FN} \end{aligned}$$
    (21)

    High sensitivity is required to ensure that strokes are detected to limit the likelihood of false negatives.

  • Specificity Specificity measures the proportion of true negatives.

    $$\begin{aligned} \text {Specificity} = \frac{TN}{TN + FP} \end{aligned}$$
    (22)

    A high specificity means that normal patients are rarely misclassified as stroke cases, minimizing false-positive results.

  • Balanced accuracy Balanced accuracy is a crucial metric for evaluating classification performance, especially on imbalanced datasets. It is the average of sensitivity and specificity, ensuring that the majority and minority classes contribute equally to the result. The formula for calculating balanced accuracy is:

    $$\begin{aligned} \text {Balanced Accuracy} = \frac{\text {Sensitivity} + \text {Specificity}}{2} \end{aligned}$$
    (23)

    In this study, balanced accuracy was used to measure the model’s ability to discriminate between stroke cases and normal cases without preference for the majority class.

  • Receiver Operating Characteristic (ROC): The area under the ROC curve (AUC) summarizes the model’s performance across classification thresholds and reflects its ability to differentiate positive from negative cases.

Together, accuracy, sensitivity, and specificity indicate the overall effectiveness of the model in reliably differentiating strokes from other types of cerebral haemorrhage.
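The four count-based metrics in Eqs. (20)–(23) follow directly from the confusion-matrix entries; a small helper makes the relationships concrete. The counts below are illustrative, not the paper's actual results.

```python
# Compute the metrics of Eqs. (20)-(23) from raw confusion-matrix counts.
def classification_metrics(tp, tn, fp, fn):
    accuracy = (tp + tn) / (tp + tn + fp + fn)            # Eq. (20)
    sensitivity = tp / (tp + fn)                          # Eq. (21), true-positive rate
    specificity = tn / (tn + fp)                          # Eq. (22), true-negative rate
    balanced_accuracy = (sensitivity + specificity) / 2   # Eq. (23)
    return accuracy, sensitivity, specificity, balanced_accuracy

# Example with illustrative counts (not the study's confusion matrix):
acc, sens, spec, bal = classification_metrics(tp=90, tn=85, fp=5, fn=10)
```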

Experimental results and discussion

In this section, the experimental results of the stroke prediction model from the CBDA-ResNet50 deep learning model and the traditional machine learning method are presented. The accuracy, specificity, sensitivity and ROC-AUC are used as a measure for each model to evaluate its performance. Below is a more detailed analysis of the training process, performance metrics, and comparison of the models.

Deep learning model: CBDA-ResNet50

The CBDA-ResNet50 architecture was built in PyTorch from a ResNet50 backbone combined with the data augmentation pipeline. Various augmentation techniques were used, such as random horizontal and vertical flips, random rotations, resizing and color jittering, to mitigate overfitting and give the model better generalization capability. CBDA-ResNet50 delivered strong results with a validation accuracy of 98%. Figure 4 shows the accuracy and loss of the CBDA-ResNet50 model over the 100 epochs of training and validation. Both training and validation accuracy stabilized early, indicating that the model converged quickly and learned efficiently. The training loss also decreased rapidly, and the validation loss settled at around 0.1 without signs of overfitting.

Fig. 4

(a) Trends in training and validation loss across epochs, (b) Training and validation accuracy for the CBDA-ResNet50 model.

Figure 5 shows the ROC curve analysis of the proposed model, which achieves an AUC of 0.98. This high AUC indicates strong discriminative performance, with the ROC curve closely approaching the upper-left corner.

Fig. 5

ROC curve of CBDA-ResNet50 model.

Figure 6 illustrates the strong positive correlations between Accuracy, Specificity, Sensitivity and Balanced Accuracy for the CBDA-ResNet50 model, confirming its stability and consistent performance. The regression lines show that improvements in one metric correspond to proportional increases in the others, emphasizing the model’s reliability in stroke classification. The tightly clustered scatter points and smooth density distributions confirm the robustness of CBDA-ResNet50 and the minimal variance across evaluation metrics.

Fig. 6

Scatterplot matrix showing the relationships between accuracy, specificity, and sensitivity of the CBDA-ResNet50 model.

The evaluation of the CBDA-ResNet50 model is based on a confusion matrix, as shown in Fig. 7. The confusion matrix provides a comprehensive overview of the classification results by showing the number of correctly and incorrectly classified instances for both stroke and non-stroke cases. The diagonal elements are correctly classified instances, and the off-diagonal elements are misclassifications. The large diagonal values indicate the robustness and reliability of the model in stroke detection, with consistently high specificity and sensitivity for both classes. The low rate of false-positive and false-negative cases shows that the model effectively minimizes diagnostic errors, which is particularly important in medical applications. The CBDA-ResNet50 model achieves an overall classification accuracy of 98%. This high accuracy is primarily due to several key factors: the use of a well-labelled MRI dataset that significantly reduces misclassification, balanced sampling with weighted cross-entropy to mitigate class imbalance, the deep residual learning framework of ResNet50, and the augmentation techniques that improve the model’s generalizability.

Important classification statistics such as accuracy, specificity, and sensitivity were calculated to investigate the performance of the model further. The evaluation results confirm that the model consistently achieves high scores in all metrics and thus proves to be well-suited for stroke diagnosis. The ability of CBDA-ResNet50 to accurately predict strokes with a very low number of false positives results in a reliable diagnostic model that helps physicians make accurate and timely diagnoses.

Fig. 7

Confusion matrix illustrating the classification performance of the CBDA-ResNet50 model.

Statistical significance analysis

To analyze the statistical significance of the observed accuracy of 97.87%, several statistical tests were performed comparing CBDA-ResNet50 with the baseline CNN, BiLSTM, and RF models. Paired t-tests showed that the accuracy gain of CBDA-ResNet50 was significant in all scenarios, with p-values below 0.001. The 95% confidence intervals for the accuracy gains ranged from 3.07% over CNN to 8.27% over RF. The Wilcoxon signed-rank test did not reach significance (p = 0.0625), which may be attributed to the limited sample size reducing the test’s statistical power. Overall, these results suggest that the gains shown by CBDA-ResNet50 are not due to random variation but represent statistically significant improvements.

In addition, the McNemar test was applied to investigate the disagreements in classification between CBDA-ResNet50 and the baseline models. This test evaluates whether the observed improvements in accuracy stem from a genuine reduction in misclassification errors. The p-values obtained from the McNemar test show that CBDA-ResNet50 significantly reduces misclassification errors compared to all baseline models, with the strongest improvement observed over BiLSTM (p = 7.85e-05). As all p-values are below the significance threshold of 0.05, these results confirm that the improvements of CBDA-ResNet50 are not due to chance but represent a significant advance in stroke classification, as shown in Table 3.

Table 3 Statistical Significance Analysis of CBDA-ResNet50 Compared to Baseline Models.
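The three tests used above can be reproduced with SciPy as sketched below. The per-run accuracies and the discordant-pair counts are hypothetical placeholders, not the study's data; the McNemar test is implemented here as an exact binomial test on the discordant pairs.

```python
import numpy as np
from scipy import stats

# Illustrative per-run accuracies (hypothetical numbers, not the paper's):
cbda = np.array([0.978, 0.975, 0.980, 0.977, 0.979, 0.976])
cnn  = np.array([0.948, 0.945, 0.950, 0.946, 0.949, 0.944])

t_stat, t_p = stats.ttest_rel(cbda, cnn)   # paired t-test on matched runs
w_stat, w_p = stats.wilcoxon(cbda, cnn)    # Wilcoxon signed-rank test

# Exact McNemar test on the discordant pairs:
# b = cases CBDA-ResNet50 got right and the baseline got wrong; c = the reverse.
b, c = 25, 5                               # illustrative counts
mcnemar_p = stats.binomtest(min(b, c), b + c, 0.5).pvalue
```

For larger discordant-pair counts, the chi-squared approximation of McNemar's test gives nearly identical p-values.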

Model calibration

Although the AUC of 0.98 shows strong discriminatory power, it does not evaluate the calibration of the probability predictions. To assess this aspect, we calculated the Brier score and created a reliability-diagram calibration curve. The model achieved a Brier score of 0.01177, indicating highly calibrated probability estimates. In addition, the calibration curve closely follows the diagonal line, confirming that the probabilities predicted by CBDA-ResNet50 match the observed frequencies well, with only minor deviations. As shown in Fig. 8, CBDA-ResNet50 not only accurately classifies stroke MRI scans but also provides well-calibrated probability estimates, highlighting its suitability for clinical decision-making.

Fig. 8

Calibration curve for CBDA-ResNet50: The model’s predicted probabilities closely align with the observed frequencies, indicating well-calibrated results.
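The Brier score and reliability curve can be computed with scikit-learn as sketched below, using synthetic stand-ins for the validation labels and the model's predicted stroke probabilities.

```python
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

# Synthetic stand-ins for the validation labels and predicted probabilities.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
y_prob = np.clip(y_true * 0.9 + rng.normal(0, 0.1, size=500), 0.0, 1.0)

brier = brier_score_loss(y_true, y_prob)   # mean squared error of probabilities
frac_pos, mean_pred = calibration_curve(y_true, y_prob, n_bins=10)
# Plotting frac_pos against mean_pred, with the diagonal as reference,
# yields a reliability diagram of the kind shown in Fig. 8.
```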

Comparison with the state of the art (SOTA) for stroke prediction

The performance of the proposed CBDA-ResNet50 method for stroke prediction is compared with several machine learning-based methods from recent research. We compared the classification accuracy of the proposed model with other methods that have used the Kaggle dataset of brain CT images for stroke prediction. Table 4 provides a performance comparison between the CBDA-ResNet50 model and other state-of-the-art techniques. Notably, the proposed method achieves higher accuracy than previous studies using machine learning and deep learning techniques for stroke detection.

Table 4 Performance comparison of the proposed method with state-of-the-art (SOTA) techniques.

To ensure consistent evaluation, each deep learning model was also tuned using the validation set. The optimal configuration for BiLSTM included a learning rate of 0.001, a batch size of 32, the Adam optimizer, and 50 epochs. The CNN baseline used a learning rate of 0.0005, a batch size of 32, ReLU activation, and cross-entropy loss. For DenseNet-121, we fine-tuned the pre-trained ImageNet model with a learning rate of 0.0001, a batch size of 16, the Adam optimizer, and early stopping after 30 epochs. To compare the classification performance of the proposed model with these deep learning architectures, we generated ROC curves on the same validation set. As shown in Fig. 9, CBDA-ResNet50 outperforms all other models in terms of AUC.

Fig. 9

ROC curves comparing CBDA-ResNet50 with the deep learning baseline models.

Performance of traditional ML models

To contextualize the performance of CBDA-ResNet50, we compared it to a number of classical machine learning models, including Logistic Regression (LR), Support Vector Machine (SVM), K-Nearest Neighbors (KNN), Decision Tree (DT), Random Forest (RF), and Gradient Boosting (GB), trained on the same dataset. Hyperparameter tuning for the traditional machine learning models was performed using a 70/30 hold-out validation split. For Random Forest, the optimal configuration included 100 estimators and a maximum depth of 10. The SVM performed best with an RBF kernel and a regularization parameter of C=1.0. For Logistic Regression, the best results were achieved with L2 regularization and the ‘lbfgs’ solver. These values were selected based on validation accuracy. Table 5 summarizes the performance of each model in terms of accuracy, specificity and sensitivity.

Table 5 Performance comparison with traditional machine learning models.
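The 70/30 hold-out tuning with the reported best hyperparameters can be sketched in scikit-learn as follows; the synthetic classification data is a stand-in for the actual feature set.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Synthetic stand-in for the tabularized features; 70/30 hold-out split.
X, y = make_classification(n_samples=600, n_features=20, random_state=42)
X_tr, X_val, y_tr, y_val = train_test_split(
    X, y, test_size=0.30, stratify=y, random_state=42)

# The reported best configurations for RF, SVM, and LR.
models = {
    "RF": RandomForestClassifier(n_estimators=100, max_depth=10, random_state=42),
    "SVM": SVC(kernel="rbf", C=1.0),
    "LR": LogisticRegression(penalty="l2", solver="lbfgs", max_iter=1000),
}
val_acc = {name: m.fit(X_tr, y_tr).score(X_val, y_val) for name, m in models.items()}
```

In the actual study, each candidate configuration would be scored this way and the one with the highest validation accuracy retained.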

The analysis of the ROC curve, shown in Fig. 10, compares the performance of the traditional models. Random Forest achieved the highest AUC of 0.83, indicating strong performance at different classification thresholds. Logistic regression and SVM also performed well, with AUC values of 0.82 and 0.80, respectively. The decision tree had the lowest AUC of 0.58, indicating weaker performance.

Fig. 10

ROC Curves for traditional machine learning models.

Model interpretability and feature visualization

To ensure the interpretability of the model, we used gradient-weighted class activation mapping (Grad-CAM) to visualize stroke-relevant regions in MRI scans. Figure 11 shows Grad-CAM heatmaps, with warmer colors indicating regions of high model activation. The left image represents the original scan, while the right image highlights the regions that are most influential for stroke prediction. In correctly classified cases, the Grad-CAM activations align well with the stroke lesions, supporting the reliability of the model. In false-positive cases, Grad-CAM highlights regions with non-stroke artifacts, indicating the need for further refinement of feature extraction. In false-negative cases, weaker activations in the actual stroke regions point to potential difficulties in detecting low-contrast lesions. These results improve the clinical interpretability of CBDA-ResNet50 and show that AI-assisted stroke detection need not be a black-box approach.

Fig. 11

Left: Original CT scan, Right: Grad-CAM heatmap highlighting the key regions that influence the model’s stroke prediction.
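A minimal Grad-CAM computation of the kind used for Fig. 11 can be sketched in PyTorch by retaining the last convolutional feature map and its gradient, then weighting the activation maps by the pooled gradients; a tiny CNN stands in for CBDA-ResNet50 here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Tiny CNN stand-in for CBDA-ResNet50; the last conv feature map is kept
# so its gradient can be used for Grad-CAM.
class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3, padding=1)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        self.feat = self.conv(x)        # keep feature map for Grad-CAM
        self.feat.retain_grad()         # also keep its gradient
        pooled = F.adaptive_avg_pool2d(F.relu(self.feat), 1).flatten(1)
        return self.fc(pooled)

model = TinyCNN().eval()
x = torch.randn(1, 1, 64, 64)           # placeholder MRI slice
score = model(x)[0, 1]                  # logit of the "stroke" class
score.backward()                        # gradients w.r.t. the feature map

weights = model.feat.grad.mean(dim=(2, 3), keepdim=True)    # pooled gradients
cam = F.relu((weights * model.feat).sum(dim=1)).squeeze(0)  # (64, 64) heatmap
cam = cam / (cam.max() + 1e-8)          # normalize to [0, 1] for overlay
```

Overlaying `cam` (upsampled to the input resolution if needed) on the original slice produces heatmaps like those in Fig. 11.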

Computational complexity and performance comparison

To evaluate the practical feasibility of CBDA-ResNet50, we compared its computational cost with the other models in terms of FLOPs, number of parameters, and GPU training time. The model was trained on an NVIDIA CUDA-enabled GPU using PyTorch 1.x, Torchvision, and Scikit-learn, processing the MRI dataset efficiently; Seaborn and Matplotlib were used for metric visualization.

The higher complexity of CBDA-ResNet50 (25.5M parameters, 4.1G FLOPs, 61.5 s/epoch) yields a higher accuracy of 97.87% and a balanced accuracy of 98.27%, making it suitable for specificity-critical clinical applications, as shown in Table 6. Future optimizations such as knowledge distillation, pruning or quantization could increase efficiency while maintaining performance. Figure 12 compares CBDA-ResNet50 with the traditional machine learning and deep learning baselines.

Fig. 12

Computational complexity comparison between CBDA-ResNet50 and baseline models.

Table 6 Computational complexity comparison between CBDA-ResNet50 and baseline models.
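Parameter counts like those in Table 6 can be obtained by summing tensor sizes over a model's trainable parameters, as sketched below on a tiny stand-in model; FLOP counts typically require a separate profiling tool.

```python
import torch.nn as nn

# Count trainable parameters of a model; applied to the actual ResNet50
# backbone this yields the ~25.5M figure reported in Table 6.
def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Tiny stand-in model: (10*4 + 4) + (4*2 + 2) = 54 parameters.
tiny = nn.Sequential(nn.Linear(10, 4), nn.ReLU(), nn.Linear(4, 2))
n_params = count_parameters(tiny)
```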

Discussion

The results of this study demonstrate that the proposed CBDA-ResNet50 model achieves strong and clinically meaningful performance in classifying strokes from MRI data. With an accuracy of 97.8%, an AUC of 0.98, a sensitivity of 97.9%, and a specificity of 98.5%, the model effectively minimizes both false negatives and false positives. This is particularly important in the clinical setting, where misdiagnosis can delay treatment or lead to unnecessary interventions.

In the comparative analysis, CBDA-ResNet50 outperformed traditional machine learning models such as Random Forest, Logistic Regression and Decision Trees, which had lower balanced accuracy and weaker MCC and NPV values. These models could not extract hierarchical spatial features inherent in MRI scans. Classical deep learning models such as CNN, BiLSTM and DenseNet provided better performance. Nevertheless, they were outperformed by CBDA-ResNet50 due to the inclusion of class balancing strategies, advanced augmentation and optimized ResNet50 fine-tuning. The integration of WeightedRandomSampler and class-weighted losses mitigated the imbalance between classes. At the same time, augmentation techniques such as horizontal and vertical flips, rotations and color jittering improved generalization and reduced overfitting.

The model’s high NPV (0.9788) and MCC (0.96) confirm its reliability in discriminating non-stroke cases, ensuring efficient triage and minimizing unnecessary follow-up scans. The Grad-CAM visualization increased interpretability by identifying stroke-relevant brain regions, an essential feature for clinical confidence and integration. These results make CBDA-ResNet50 potentially useful for hospital PACS systems.

Despite its strengths, this study has several limitations. First, although balanced through sampling and augmentation, the dataset is relatively small. It comes from a publicly available platform, which may limit the model’s generalizability to different populations or imaging environments. Second, the model was trained and evaluated with a fixed, stratified split of 70/30, which may introduce sample-specific biases despite the balance. Future studies could improve statistical robustness and generalization by using k-fold cross-validation. Third, the model was developed based on MRI scans only, and its applicability to other modalities, such as CT, which is commonly used in emergency stroke assessment, has not been investigated. Finally, despite the application of class balancing and augmentation methods, residual biases may remain due to the inherent imbalance of the original dataset.

Future directions include extending the model to support multimodal data, integrating federated learning for cross-institutional training, and optimizing the architecture with hybrid CNN-RNN layers or transformer-based attention modules. Integration with EHR systems and real-time use in clinical workflows remains a priority, especially for early stroke detection.

Conclusion and future work

In this study, we proposed a stroke prediction model, CBDA-ResNet50, for diagnosing strokes from brain MRI images. The model integrates the ResNet50 backbone with class-balancing strategies and data augmentation to improve spatial feature extraction and mitigate data imbalance. As a result, CBDA-ResNet50 outperformed conventional machine learning and deep learning models. It achieved high accuracy, specificity and sensitivity and showed robust performance on additional metrics such as AUC, NPV and MCC. Grad-CAM further supports interpretability by highlighting the critical brain regions that influence predictions, making the model suitable for clinical decision support. Advantages of the model include high classification performance, strong generalization through augmentation, class-balanced learning, and explainability. However, limitations remain: the model was trained on a relatively small dataset from a single center and has not yet been evaluated on other imaging modalities, such as computed tomography, which are commonly used in emergency stroke assessment. Although balancing and augmentation strategies improve fairness, residual bias can still occur. Future work will explore more advanced architectures such as Vision Transformers and EfficientNet, apply SHAP-based explanations to multimodal data such as clinical records, and validate the model on multicenter datasets to ensure generalizability and real-world clinical use.