Abstract
Alzheimer’s disease (AD) is a progressive neurodegenerative disorder that begins with subtle cognitive changes and advances to severe impairment. Early diagnosis is crucial for effective intervention and management. In this study, we propose an integrated framework that leverages ensemble transfer learning, generative modeling, and automatic ROI extraction to predict progression to Alzheimer’s disease in cognitively normal (CN) subjects. Using the Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset, we employ a three-stage process: (1) estimating the probability of transitioning from CN to mild cognitive impairment (MCI) using ensemble transfer learning, (2) generating future MRI images with a Vision Transformer-based generative adversarial network (ViT-GAN) to simulate disease progression after two years, and (3) predicting AD using a 3D convolutional neural network (CNN) whose probabilities are calibrated with isotonic regression, and interpreting critical regions of interest (ROIs) with Gradient-weighted Class Activation Mapping (Grad-CAM). Because of the available data, the training phase considers brain changes after two years; the method is nevertheless general and may be applied to intervals of three years or more when sufficient data are available. Our approach addresses the challenge of limited longitudinal data by creating high-quality synthetic images, and it improves model transparency by identifying key brain regions involved in disease progression. The proposed method achieves an accuracy of 0.85 and an F1-score of 0.86 in CN to AD prediction up to 10 years ahead, offering a potential tool for early diagnosis and personalized intervention strategies in Alzheimer’s disease.
Introduction
Alzheimer’s disease is an acquired, generalized, and usually progressive disorder that, according to some references, develops in three distinct stages1. The first stage, known as the preclinical phase, involves subtle changes in the brain, blood, and cerebrospinal fluid without noticeable symptoms in the patient2. This stage can begin 20 years before symptoms become evident3. The second stage, called mild cognitive impairment (MCI), is characterized by noticeable symptoms affecting cognitive abilities, although these do not significantly impact daily life. Not all individuals with MCI progress to Alzheimer’s disease, but an estimated 10-15 percent of them do so each year4,5. MCI patients are further categorized into progressive MCI (pMCI), who develop Alzheimer’s during a follow-up period (usually 1 to 3 years), and stable MCI (sMCI), who do not6. The final stage, Alzheimer’s dementia, involves clear symptoms of memory, cognitive, and behavioral impairments that significantly interfere with daily functioning7.
Diagnosing Alzheimer’s disease typically involves a comprehensive medical evaluation, including a medical history review, mental state test, physical tests, and neuroimaging techniques such as structural MRI, functional MRI, and PET. Among the biomarkers used for Alzheimer’s diagnosis and prediction, MRI is the most common neuroimaging technique for identifying Alzheimer’s-related brain atrophy8.
In recent years, with the advent of deep learning, methods such as convolutional neural networks (CNNs), generative models, and Transformers have been increasingly utilized for medical image processing, including Alzheimer’s diagnosis and prediction. A variety of inputs are used to diagnose or predict Alzheimer’s disease, including three-dimensional MRI images of the whole brain9,10,11, two-dimensional slices extracted from the 3D MRI12,13,14, several three-dimensional patches extracted from the image15,16,17, and regions of interest (ROIs) focusing on known regions18,19,20. On the one hand, using the 3D volume of the brain increases computational complexity; on the other hand, converting this volume into 2D slices or 3D patches loses some important image information21.
ROI-based methods focus on specific brain regions such as the hippocampus, but might miss disease characteristics that vary across different stages. Therefore, some papers select many brain regions as ROIs. For example, in22, 134 regions are selected for diagnosing Alzheimer’s disease, and in the end some of them are identified as the most informative ROIs. One of the challenges in these methods is therefore extracting the best ROIs. Since most machine learning methods, and deep learning methods in particular, are not interpretable, their reliability decreases. To this end, post hoc explainability methods have recently been widely used to make machine learning models interpretable23. For example, some papers use SHAP as a post hoc interpretability method for Alzheimer’s prediction24,25, and others use LIME to make their models explainable14,26. Another interpretability method is Grad-CAM, which is used with image-based deep learning models for AD diagnosis27,28. Grad-CAM is also used for automatic ROI extraction in29. In that study, a model is trained on one task (AD vs CN) to automatically extract ROIs using Grad-CAM. After that, the model is trained on the pMCI vs sMCI task using the given images and extracted ROIs. In this paper, customized transfer learning is also used to improve the results.
Results in recent articles show that, because medical image datasets are usually small, transfer learning with pre-trained models such as ResNet, AlexNet, VGGNet, and MobileNet has achieved good results for the diagnosis and prediction of Alzheimer’s disease30,31,32. Because these models were trained on two-dimensional natural images, some studies have instead used customized pre-trained models, with which researchers have been able to extract good features and achieve good results33,34.
Diagnosing Alzheimer’s disease is a complex process, and its complexity also depends on the stage at which the disease is diagnosed. It is more difficult to diagnose Alzheimer’s in the early stages because most of the symptoms are not yet clear35. Another problem in diagnosing Alzheimer’s disease before symptoms appear (when a person is in the CN category) is the lack of data to train a predictive model. One way to solve this problem is to generate data using generative methods. Techniques such as Generative Adversarial Networks (GANs), Deep Convolutional GANs (DCGANs), and diffusion models have demonstrated considerable success in the medical field36,37,38. These methods have been effectively employed for various purposes, including data augmentation39,40,41, addressing the issue of missing data42,43, and converting data across different modalities in multimodal approaches44. There are limitations in predicting Alzheimer’s disease from cognitively normal subjects. One of the limitations is data leakage. Training a model that can predict Alzheimer’s disease from CN subjects requires many CN subjects who convert to AD. Another limitation is generalizability: training and testing a single model on the same dataset reduces the generalizability of the model. A further limitation is experts’ lack of confidence in the results of two-class classification. In other words, when a threshold is used to classify data, a sample whose probability lies close to the threshold may be placed in a class almost arbitrarily.
In this study, we propose an integrated framework that leverages ensemble transfer learning and generative modeling to predict the progression of Alzheimer’s disease from cognitively normal subjects. To enhance generalizability, we utilize customized transfer learning methods, trained on various datasets and tasks, to improve predictive accuracy. We employ a combination of two pre-trained models to estimate the probability of a healthy individual converting to mild cognitive impairment (MCI). Given the substantial influence of brain age on Alzheimer’s disease, as indicated in the literature, one of the pre-trained models is the brain age estimation model proposed in45. The other is the pre-trained sMCI and pMCI classifier, which extracts good features. To improve our work, we use the method introduced in our previous article46, which applies interpretability methods to extract ROIs automatically. To overcome the lack of data for training CN to AD prediction, we use generative methods to produce an image of a healthy person’s brain two years later (healthy or MCI) and use this image to predict AD.
We chose a two-year interval for image generation because there is not enough data to train the model for longer intervals. The model recognizes the brain changes over the two-year follow-up well, and it accurately predicts Alzheimer’s from the two-year image even for test subjects who converted to MCI after more than two years. This indicates that the model generalizes well and that, given enough data to train the generative model for longer intervals, it would be more accurate. Since the distance between a healthy state and Alzheimer’s disease is large, which lowers the reliability of a model that predicts whether a person converts to AD, we refrain from announcing the results definitively and emphasize the obtained probability.
Since there is not enough data to train a model to predict Alzheimer’s disease directly from healthy people, we propose a method that combines a CN to MCI conversion model with an MCI to AD conversion model to estimate the probability of CN to AD conversion. To do this, we multiply the probability of CN to MCI conversion by the probability of AD obtained from the synthetic image. In the proposed method, an MRI image of the person two years after the baseline image is first generated, and the probability of Alzheimer’s disease is then estimated from this synthetic image. In other words, the goal is to see whether the person will develop MCI after two years, and the synthetic image is then fed into the MCI to AD prediction model to obtain the probability of AD progression. Ultimately, the probability of MCI for the CN subject is multiplied by the probability of AD for the synthetic image. Well-calibrated probabilities are crucial for reliable predictions in medical diagnosis tasks. Calibration methods adjust the predicted probabilities to better reflect the true probabilities observed in the data47,48. In this study, we use isotonic regression, a non-parametric method that ensures a monotonic relationship between predicted and true probabilities, to improve the calibration of our model’s predictions and correct biased probabilities. Also, to make the prediction model more robust, we add demographic features that the literature highlights for AD prediction. By combining advanced machine learning techniques with a focus on interpretability, our study aims to improve the accuracy of AD progression prediction and provide a valuable tool for early diagnosis and personalized intervention strategies. Our approach makes the following main contributions:
\(\bullet\) We propose an integrated approach to predict the development of AD in CN individuals up to 10 years prior to diagnosis.
\(\bullet\) Our proposed method predicts the AD conversion probability using the multiplication of CN to MCI probability and MCI to AD probability.
\(\bullet\) To increase the generalizability of our proposed method, we propose an accurate ensemble transfer deep learning method to predict MCI conversion from CN subjects, which combines a fine-tuned model for brain age estimation with a fine-tuned model for MCI to AD prediction.
\(\bullet\) We use a generative model (ViT-GAN) to generate a subject’s brain MRI image after two years, which captures the brain changes over that period, in order to predict AD from CN subjects accurately.
\(\bullet\) To ensure interpretability, we employ Gradient-weighted Class Activation Mapping (Grad-CAM) to identify the most critical regions of interest (ROIs) that influence the model’s decisions. These ROIs are further analyzed using a 3D CNN to compute the probability of developing Alzheimer’s disease.
Our integrated approach not only aims to improve the accuracy of AD progression prediction but also emphasizes interpretability, ensuring that the model’s decisions are transparent and clinically meaningful. Our study shows strong results in predicting Alzheimer’s progression from healthy subjects up to 10 years ahead, achieving an accuracy of 0.85 and an F1-score of 0.86. We also compare our proposed framework with a baseline model that classifies CN, MCI, and AD. The results show that the baseline model performs worse than our method, especially in AD detection.
In the remainder of this paper, we detail our dataset and the proposed integrated framework for predicting Alzheimer’s disease progression in the Methods section. The Results section presents the outcomes of our experiments, showcasing the performance of our models through various evaluation metrics. We also provide qualitative and quantitative analyses of our generative models. In the Discussion section, we discuss the proposed method and obtained results, and also interpret our findings. Finally, we conclude our work and suggest future works in the conclusion and future works section.
Method
Overview
As shown in Fig. 1, the images are pre-processed in the first step. Using B1 Correction, N3 [36], and Grad Warp [35], the artifacts created by the imaging machine are removed, and the brain is then extracted from the skull and scalp using FreeSurfer. Next, the images and metadata (including Age, Gender, Marital Status, and Education) are fed into the Ensemble Transfer learning model, which combines the results of two fine-tuned models explained in the “Ensemble Transfer Learning” sub-section. In the next phase, the pre-processed images are fed into the generative model to synthesize the brain image two years later. The details of image generation are explained in the “Image Generation” sub-section. In the last phase, the generated images are fed into a 3D CNN model trained on real sMCI and pMCI data to obtain the probability of MCI to AD conversion. In this phase, conversion from MCI to AD is obtained using ROIs. At the end of this phase, the probability of CN to MCI and the probability of MCI to AD are multiplied. The details of this phase are given in the last sub-section.
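The final multiplication step of the pipeline can be illustrated with a minimal sketch. The function name and the two example probabilities are hypothetical and are not taken from the paper; the sketch only assumes that both stages return a probability in [0, 1].

```python
def cn_to_ad_probability(p_cn_to_mci: float, p_mci_to_ad: float) -> float:
    """Combine the two stage probabilities by multiplication."""
    for name, p in (("p_cn_to_mci", p_cn_to_mci), ("p_mci_to_ad", p_mci_to_ad)):
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"{name} must lie in [0, 1], got {p}")
    return p_cn_to_mci * p_mci_to_ad


# Hypothetical stage outputs for one subject (illustrative values only).
p_mci = 0.8  # ensemble transfer learning model: baseline MRI + metadata
p_ad = 0.5   # 3D CNN applied to the generated two-year image
combined = cn_to_ad_probability(p_mci, p_ad)
```

Because both factors are at most 1, the combined value can never exceed either stage probability, which is one reason the framework reports a probability and not a hard label.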
Data description
In this paper, T1-weighted (T1w) MRI data exclusively from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset49 are used to train and test the proposed method. ADNI is a multicenter longitudinal study aimed at predicting and tracking Alzheimer’s disease, collecting subjects’ clinical, imaging, genetic, and biochemical biomarkers since 2004 across four phases: ADNI-1, ADNI-GO, ADNI-2, and ADNI-3. The parameters of the T1w MRI data include a field strength of 1.5 Tesla, a T1-weighted MPRAGE (Magnetization Prepared RApid Gradient Echo) pulse sequence, and a matrix size of \(256 \times 256\). The subjects in our analysis are derived from the post-processed ADNI data, which underwent B1 correction, N3 bias field correction, gradient warp correction, and brain extraction using FreeSurfer. The data used in this article consist of three categories: cognitively normal (CN), mild cognitive impairment (MCI), and Alzheimer’s disease (AD). Within the CN and MCI groups, we identify subcategories: healthy individuals who develop Alzheimer’s (pCN), those who do not (sCN), MCI individuals who convert to Alzheimer’s (pMCI), and those who remain stable (sMCI). The dataset contains 15 CN subjects who converted to AD over the following years. We extract these subjects, together with 15 subjects who converted to neither MCI nor AD, from the whole dataset. The remaining data, including CN subjects who converted to MCI but not AD, CN subjects who converted to neither MCI nor AD, MCI subjects who converted to AD, MCI subjects who did not convert to AD, and AD subjects, are used in the individual parts of the framework.
We also utilized a comprehensive dataset comprising MRI images and associated meta-data to predict the progression from cognitively normal (CN) to Alzheimer’s (AD). The meta-data includes key demographic features such as age, gender, marital status, and education level of the participants. This demographic information provides essential context for interpreting the MRI images and enhances the validity of our predictive models by accounting for potential confounding factors that may influence disease progression. Information about the utilized data is presented in Table 1.
To ensure the robustness of our deep learning models, we implemented a clear data preparation strategy. While our primary test data consisted of pCN and sCN images that were never used in the training phase, each deep learning component within the framework required its own validation data for performance evaluation. For this, we applied an 80-20 split to the training data, using 80% for training the models and 20% for validation.
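A minimal sketch of such an 80-20 split is given below. The identifiers, the fixed seed, and the helper name are illustrative assumptions; the text does not state whether the split is made by subject or by volume, so the sketch simply splits whatever identifiers it receives.

```python
import random


def split_train_validation(items, validation_fraction=0.2, seed=0):
    """Shuffle the items reproducibly and hold out a validation fraction."""
    if not 0.0 < validation_fraction < 1.0:
        raise ValueError("validation_fraction must be strictly between 0 and 1")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_validation = round(len(shuffled) * validation_fraction)
    return shuffled[n_validation:], shuffled[:n_validation]


# Hypothetical identifiers standing in for training volumes.
volume_ids = [f"vol_{i:04d}" for i in range(100)]
train_ids, validation_ids = split_train_validation(volume_ids)
```

Passing subject identifiers instead of volume identifiers would keep all scans of one subject on the same side of the split.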
Ensemble transfer learning
In this section, we explain how we estimate the probability of MCI conversion using the Ensemble Transfer learning model illustrated in Fig. 2. According to Fig. 2, the results of two fine-tuned models are combined to estimate the probability more accurately. The fine-tuned models are the age estimation model introduced in45 and the MCI to AD conversion model introduced in46. The reason for using brain age estimation is that, according to that article, the Brain Age Gap (BAG) is a good biomarker for Alzheimer’s detection and obtained good results on AD detection as the only biomarker. Since brain age is estimated from the MRI volume, to avoid redundancy the pre-trained model is used for feature extraction through transfer learning instead of using the brain age gap itself as a biomarker. Another reason is the generalizability of that model: the results of the article show that it performs well on datasets that are not in the training set. The other fine-tuned model performs well in the classification of sMCI and pMCI and has learned the features of MCI patients. As illustrated in Fig. 2, the output of the last feature extraction block (a 3D ResNet block in one model and a 3D CNN block in the other) is fed into a Flatten block, which consists of one flatten layer, a Leaky ReLU activation function with \(\alpha = 0.1\), a BatchNormalization layer, a Dropout layer with a rate of 0.4, and one fully connected dense layer with 32 features. After that, one 1D array for each demographic variable (Age, Gender, Education, and Marital Status) is concatenated to the feature vector. Therefore, 36 features are used for classification for each sample. At the end of each network, the softmax activation function is used for classification. Each model is fine-tuned separately using the training set. Figure 2 also illustrates the model fusion process, in which the two fine-tuned models (the brain age estimation model and the classifier for sMCI and pMCI) are combined using the Max(P) approach. Specifically, each model provides a probability of belonging to a class for the i-th test sample. The maximum of the probabilities from both models is selected and returned as the final output.
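The Max(P) fusion described above can be sketched as follows. This is one plausible reading, assuming the maximum is taken per class across the two models and the fused class is the one with the largest fused value; the function names and example probabilities are hypothetical.

```python
def max_p_fusion(probs_a, probs_b):
    """Element-wise maximum of two models' per-sample class probabilities."""
    if len(probs_a) != len(probs_b):
        raise ValueError("both models must score the same number of samples")
    fused = []
    for row_a, row_b in zip(probs_a, probs_b):
        if len(row_a) != len(row_b):
            raise ValueError("both models must score the same classes")
        fused.append([max(a, b) for a, b in zip(row_a, row_b)])
    return fused


def predict_labels(fused):
    """Pick the class index with the largest fused probability per sample."""
    return [max(range(len(row)), key=row.__getitem__) for row in fused]


# Hypothetical softmax outputs for two test samples and two classes.
resnet_probs = [[0.70, 0.30], [0.45, 0.55]]
cnn_probs = [[0.60, 0.40], [0.20, 0.80]]
fused = max_p_fusion(resnet_probs, cnn_probs)
labels = predict_labels(fused)
```

Note that the fused values of a sample no longer need to sum to 1, so they are scores for ranking the classes and would need renormalizing before being read as probabilities.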
Image generation
In this study, we focus on predicting the progression from CN to AD using generated images from the baseline image. This section details the process of image generation. As mentioned in the Introduction, we simulate changes in the brain by generating images of healthy subjects after two years. Our approach employs the ViT-GAN model, which generates future MRI images based solely on baseline images. This approach allows us to effectively train the model to generate predictive images while maintaining a focus on using baseline data for our final predictions. While our predictive model does not require longitudinal data, the ViT-GAN is trained using subjects who possess both baseline and two-year follow-up images. This selection ensures that our training set is robust and addresses potential issues related to data completeness. The choice of a two-year interval is primarily due to data availability limitations; however, the method is generalizable. If more extensive longitudinal data were available, such as three years or more, the results would likely be more accurate.
We applied normalization techniques that scaled the MRI intensity values to a standardized range of 0 to 1. This step is essential to enhance the model’s performance and ensure that it learns from data that is uniformly represented. To facilitate the training of our ViT-GAN model, all images were resized to dimensions of \(128 \times 128 \times 128\), which ensures that the input data maintains consistency in shape. For our generative model, we specifically utilized longitudinal data, where baseline images served as the input, and images taken two years later were generated as output. In the ADNI dataset, longitudinal processing creates a within-subject template and initializes each time point with the template to reduce individual variability. This design allows the network to effectively learn the changes that occur over time, thereby improving its predictive capability. By employing these pre-processing steps and standardization techniques, we aimed to mitigate inter-scanner variability, thereby enhancing the reliability and validity of our results.
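The intensity scaling step can be sketched as a per-volume min-max normalization. This is an assumption, since the text only states that intensities are scaled to the range 0 to 1; the resizing to \(128 \times 128 \times 128\) is not shown, and a small random array stands in for an MRI volume.

```python
import numpy as np


def min_max_normalize(volume: np.ndarray) -> np.ndarray:
    """Scale the intensities of one volume linearly to the range [0, 1]."""
    volume = volume.astype(np.float32)
    lowest, highest = volume.min(), volume.max()
    if highest == lowest:
        # A constant volume carries no contrast; map it to all zeros.
        return np.zeros_like(volume)
    return (volume - lowest) / (highest - lowest)


# Small synthetic stand-in for an MRI volume (illustrative only).
rng = np.random.default_rng(0)
volume = rng.integers(0, 4096, size=(16, 16, 16))
normalized = min_max_normalize(volume)
```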
The generative model consists of a 3D Vision Transformer encoder and a 3D CNN decoder for the generator, and a 3D CNN for the discriminator. The details of the model are shown in Fig. 3. According to Fig. 3, the \(128 \times 128 \times 128\) input image is divided into patches, and the patch embedding is computed by a 3D CNN. The output is then flattened, transposed, and fed into the Transformer encoder, which includes LayerNorm, Multi-Head Attention, and a Multi-Layer Perceptron. The output of the ViT is fed into the 3D CNN decoder, which includes 3D transpose convolution, Instance Normalization, and the ReLU activation function. The discriminator is a 3D CNN-based classifier consisting of downsampling layers, a Leaky ReLU activation function for each hidden layer, and a softmax activation function for the classification layer. Following the original GAN, two loss functions are used in this network, an adversarial loss and a reconstruction loss, which are introduced in Eq. 2 and Eq. 3, respectively. In these equations, x is the original 3D MRI volume at baseline, whereas z is the real 3D MRI volume two years later, which serves as the target output. D is the discriminator function and G is the generator function.
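For orientation, a commonly used form of these two losses is sketched below with the symbols defined above. This is an assumption based on the standard GAN objective, not a transcription of Eq. 2 and Eq. 3; in particular, the choice of the \(L_1\) norm in the reconstruction term is illustrative.

\[
\mathcal{L}_{\text{adv}}(G, D) = \mathbb{E}_{z}\left[\log D(z)\right] + \mathbb{E}_{x}\left[\log \left(1 - D(G(x))\right)\right]
\]

\[
\mathcal{L}_{\text{rec}}(G) = \mathbb{E}_{x, z}\left[\left\| z - G(x) \right\|_{1}\right]
\]

Under this reading, the discriminator is trained to increase \(\mathcal{L}_{\text{adv}}\), while the generator is trained to decrease it and to keep \(G(x)\) close to the real follow-up volume z.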
Alzheimer’s disease prediction
Since the data showing the conversion from CN to AD are not sufficient to train a classification model, we propose a method that combines CN to MCI and MCI to AD prediction to predict Alzheimer’s disease in cognitively normal subjects. In the proposed method, shown in Fig. 4, we first estimate the probability of CN to MCI conversion using the baseline MRI image. We then generate the follow-up MRI volume of the subject from the baseline MRI volume. At this phase, we estimate the probability of MCI to Alzheimer’s disease conversion from the generated images. The process of AD probability estimation is given in Algorithm 2. For this purpose, the generated MRI images are fed into the sMCI and pMCI classification model introduced in the ’Ensemble Transfer Learning’ part. We also use an ROI-based model for AD probability prediction. Since the ROIs differ across the stages of Alzheimer’s disease, automatic ROI extraction was proposed in46. Based on this method, ROIs are extracted from the MRI images of each patient using an explainable-AI method (Grad-CAM), from the feature importance that the Grad-CAM algorithm provides. According to this method, we first feed the MRI images into the sMCI/pMCI classification model, which is a 3D CNN model. The weight of each feature map, \(w_f\), is then obtained from the gradient of the last layer of the network (before the softmax layer) with respect to each feature map of the last CNN layer, based on Eq. 4. After that, according to Eq. 5, the feature importance (the heatmap of the image) is obtained by applying global average pooling to the weights of the feature maps. The details of the 3D CNN model used are given in Table A1 in Appendix-1.
Based on46, we calculate a mask for each image in which the most important parts of the image get a value of 1 and the pixel value of the other parts is 0. Finally, to obtain the ROI, the mask is multiplied by the MRI image.
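The heatmap, mask, and ROI steps can be sketched as follows. The sketch follows the standard Grad-CAM formulation (gradients averaged per feature map, then a ReLU of the weighted sum of feature maps), which may differ in detail from Eq. 4 and Eq. 5. The threshold of 0.5, the nearest-neighbour upsampling, and the random arrays standing in for network activations and gradients are all illustrative assumptions.

```python
import numpy as np


def grad_cam_heatmap(feature_maps: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Standard Grad-CAM for arrays of shape (F, D, H, W)."""
    weights = gradients.mean(axis=(1, 2, 3))           # one weight w_f per feature map
    cam = np.tensordot(weights, feature_maps, axes=1)  # weighted sum over f -> (D, H, W)
    cam = np.maximum(cam, 0.0)                         # keep positive evidence only
    peak = cam.max()
    return cam / peak if peak > 0 else cam


def roi_from_heatmap(image: np.ndarray, heatmap: np.ndarray, threshold: float = 0.5):
    """Upsample the heatmap to the image size, binarize it, and mask the image."""
    if any(i % h != 0 for i, h in zip(image.shape, heatmap.shape)):
        raise ValueError("image shape must be an integer multiple of heatmap shape")
    upsampled = heatmap
    for axis, (i, h) in enumerate(zip(image.shape, heatmap.shape)):
        upsampled = np.repeat(upsampled, i // h, axis=axis)  # nearest-neighbour upsampling
    mask = (upsampled >= threshold).astype(image.dtype)      # 1 inside the ROI, 0 elsewhere
    return mask, image * mask


# Random stand-ins for the last-layer activations, their gradients, and an MRI volume.
rng = np.random.default_rng(0)
feature_maps = rng.random((4, 2, 2, 2))
gradients = rng.random((4, 2, 2, 2))
image = rng.random((8, 8, 8))
heatmap = grad_cam_heatmap(feature_maps, gradients)
mask, roi = roi_from_heatmap(image, heatmap)
```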
To model the probability of transitioning from a cognitively normal (CN) state to Alzheimer’s Disease (AD) through the mild cognitive impairment (MCI) stage, we use conditional probabilities. We start with the law of total probability for the event of developing AD from a CN state (Eq. 6):
where, \(P(\text {MCI} | \text {CN})\) is the probability of developing MCI given that a person is currently cognitively normal, \(P(\text {AD} | \text {MCI})\) is the probability of developing AD given that a person is currently in the MCI stage, \(P(\text {AD} | \text {CN})\) is the probability of developing AD given that a person is currently cognitively normal.
Given that the direct transition from CN to AD without passing through MCI is low50, we assume:
Thus, the equation simplifies to:
Using the definition of conditional probability:
Since MCI and CN are mutually exclusive states and knowing that MCI is a prerequisite stage before AD, we have:
Thus, the equation becomes:
Combining the above, the final probability of transitioning from CN to AD is given by:
As mentioned before, a generative model is used to generate the MRI image of the brain two years into the future based on the current MRI and other relevant features. Let \(\hat{I}^{t+2}\) denote the generated MRI image of the brain two years from now. The generated MRI image \(\hat{I}^{t+2}\) is used to predict the probability of AD; let \(P(\text {AD} | \hat{I}^{t+2})\) be the probability of developing AD given the generated image. This probability can also be obtained from the softmax output of the second neural network model. To find the overall probability of progressing from CN to AD, we need to consider the conditional probability \(P(\text {AD} | \text {CN})\). Using the law of total probability and the conditional probability, we can express \(P(\text {AD} | \text {CN})\) as:
Here \(P(\text {AD} | \text {MCI})\) is replaced by \(P(\text {AD} | \hat{I}^{t+2})\), since the generated image represents the brain state of an MCI subject after two years.
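The chain of steps described in this derivation can be sketched as follows. This is a reconstruction from the surrounding definitions under the stated assumptions (a negligible direct CN to AD transition, and MCI as a prerequisite stage before AD); the exact form and numbering of the original equations are not assumed.

\[
\begin{aligned}
P(\text{AD} \mid \text{CN}) &= P(\text{AD} \mid \text{MCI}, \text{CN})\, P(\text{MCI} \mid \text{CN}) + P(\text{AD} \mid \neg\text{MCI}, \text{CN})\, P(\neg\text{MCI} \mid \text{CN}) \\
P(\text{AD} \mid \neg\text{MCI}, \text{CN}) &\approx 0 \\
P(\text{AD} \mid \text{CN}) &\approx P(\text{AD} \mid \text{MCI}, \text{CN})\, P(\text{MCI} \mid \text{CN}) \\
P(\text{AD} \mid \text{MCI}, \text{CN}) &= P(\text{AD} \mid \text{MCI}) \\
P(\text{AD} \mid \text{CN}) &\approx P(\text{MCI} \mid \text{CN})\, P(\text{AD} \mid \text{MCI}) \\
P(\text{AD} \mid \text{CN}) &\approx P(\text{MCI} \mid \text{CN})\, P(\text{AD} \mid \hat{I}^{t+2})
\end{aligned}
\]

As a worked example with illustrative values, \(P(\text{MCI} \mid \text{CN}) = 0.8\) and \(P(\text{AD} \mid \hat{I}^{t+2}) = 0.5\) give \(P(\text{AD} \mid \text{CN}) \approx 0.4\).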
Probability calibration using isotonic regression
To improve the calibration of the predicted probabilities, we applied isotonic regression. Isotonic regression is a non-parametric method that fits a piecewise constant, monotonic function to the predicted probabilities, adjusting them to better match the true probabilities observed in the data. This method is particularly useful for handling complex relationships not well-captured by parametric models like logistic regression. The isotonic regression model is trained on a calibration set separate from the training and test sets used for the main prediction tasks. Isotonic regression ensures that the fitted probabilities are monotonic, i.e., if \(p_i\) and \(p_j\) are predicted probabilities and \(p_i \le p_j\), then the calibrated probabilities \(q_i\) and \(q_j\) will satisfy \(q_i \le q_j\). The isotonic regression problem can be formulated as Eq. 14:
where \(y_i\) are the true labels, and \(q_i\) are the fitted probabilities.
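A minimal sketch of isotonic calibration is given below. It assumes the standard least-squares formulation, \(\min_{q} \sum_i (y_i - q_i)^2\) subject to \(q_i \le q_j\) whenever \(p_i \le p_j\), and solves it with the pool-adjacent-violators algorithm. The function names and the calibration data are hypothetical, and new scores are mapped with a step function, whereas library implementations often interpolate between fitted points.

```python
from bisect import bisect_left


def fit_isotonic(scores, labels):
    """Pool-adjacent-violators fit of a non-decreasing step function."""
    blocks = []  # each block: [sum_of_labels, count, largest_score_in_block]
    for score, label in sorted(zip(scores, labels)):
        if blocks and blocks[-1][2] == score:
            # Tied scores must share one calibrated value.
            blocks[-1][0] += label
            blocks[-1][1] += 1
        else:
            blocks.append([float(label), 1, score])
        # Merge backwards while the previous block's mean exceeds the last one's.
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            total, count, upper = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
            blocks[-1][2] = upper
    uppers = [block[2] for block in blocks]
    values = [block[0] / block[1] for block in blocks]
    return uppers, values


def calibrate(score, uppers, values):
    """Map a raw score to the fitted value of the block that covers it."""
    index = min(bisect_left(uppers, score), len(values) - 1)
    return values[index]


# Hypothetical calibration set: raw model scores and true binary labels.
scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
labels = [0, 0, 1, 0, 1, 0, 1, 1]
uppers, values = fit_isotonic(scores, labels)
calibrated = calibrate(0.35, uppers, values)
```

In this toy example the out-of-order labels at scores 0.3 to 0.6 are pooled into blocks with a calibrated value of 0.5, which is how the method corrects over-confident raw scores.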
Results
CN to MCI progression
As described in the Method section, the first phase of the proposed method is MCI progression prediction. In this phase, baseline MRI images and demographic features (including Age, Gender, Marital Status, and Education) as metadata are fed into the model to predict MCI. To train the model, we use data from cognitively normal (CN) subjects who do not convert to MCI as class 1, and CN subjects who convert to MCI (not to AD) within up to three years, in addition to sMCI patients, as class 2. The number and information of subjects are displayed in Table 1. This study is longitudinal; two years of MRI images of the subjects are used for training the model. 644 and 640 MRI volumes are used as training sets for class 1 and class 2, respectively. Demographic features such as Age, Gender, Education, and Marital Status from baseline are added as extra features.
Implementation details
The proposed CN to MCI prediction model comprises two fine-tuned models that are ensembled to obtain the final result. The first model consists of three Residual Blocks; each block includes two 3D convolutional layers with a kernel size of 3\(\times\)3\(\times\)3, an ELU activation function, Batch Normalization, a concatenation layer, and 2\(\times\)2\(\times\)2 Max-pooling. In addition to the ResNet Blocks, there are two Attention Blocks that contain attention layers. The number of features in the Residual Blocks is 8, 32, and 128, and the number of features in the Attention Blocks is 16 and 64, respectively. The other fine-tuned model, the 3D CNN model, comprises three 3D CNN blocks; each block consists of 3D convolutional layers with a kernel size of 3\(\times\)3\(\times\)3, the LeakyReLU activation function, Max-Pooling with a size of 2\(\times\)2\(\times\)2, and dropout with a rate of 0.3. The number of features in the 3D CNN Blocks is 32, 64, and 64, respectively. In both models, the output of the last 3D CNN Block is fed into the Flatten block for training and fine-tuning, which consists of a flatten layer, Batch Normalization, Leaky ReLU, a Dense layer with dimension 32, and a softmax activation function. The Adam optimization algorithm is employed with a learning rate of 0.001, in conjunction with the Cross Entropy loss function. Convergence is achieved after 100 epochs. This part of the proposed method is implemented with the Keras library in Python 3.11 on a Tesla T4 GPU and an Intel(R) Xeon(R) CPU @ 2.20GHz.
CN to MCI classification
In our classification task, we applied a threshold of 0.5 to determine the final class labels as Eq. 15, and obtained evaluation metrics such as Accuracy, Precision, Recall, and F1-score based on Eq. 16 to Eq. 19. The decision rule is as follows:
where \(\hat{y}\) is the predicted class label, and \(P(y=1|X)\) is the predicted probability of the positive class given the input features \(X\). This threshold is chosen because it balances the trade-off between precision and recall, and is commonly used in binary classification tasks.
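The decision rule and the four evaluation metrics can be sketched as follows, using the standard definitions based on true and false positives and negatives. Whether a probability exactly equal to 0.5 is assigned to the positive class is an assumption here, and the example probabilities and labels are illustrative.

```python
def classify(probabilities, threshold=0.5):
    """Assign class 1 when the positive-class probability reaches the threshold."""
    return [1 if p >= threshold else 0 for p in probabilities]


def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall, and F1-score for binary labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical predicted probabilities and true labels for six samples.
probs = [0.9, 0.4, 0.65, 0.2, 0.55, 0.1]
y_true = [1, 1, 1, 0, 0, 0]
y_pred = classify(probs)
metrics = binary_metrics(y_true, y_pred)
```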
To obtain the best performance, we investigated several configurations and compared their results for CN to MCI progression as an ablation study. In the first row of Table 2, demographic features are used as the only biomarker for classifying stable CN versus CN to MCI conversion. Several models, such as Logistic Regression, SVM, Decision Tree, Random Forest, and XGBoost, were trained on these data, and XGBoost obtained the best results, as shown in Table 2. The second row of Table 2 reports the results of the proposed model, a combination of the two fine-tuned models, using MRI volume as the biomarker. The last row of Table 2 reports the results of the proposed method on the whole data, a combination of MRI volume and demographic features. The results show that multimodal classification performs best. Also, the comparison between the second and third rows shows that using demographic features in addition to MRI images increases classification performance.
Table 3 compares the results obtained from each fine-tuned model, 3D ResNet Attention and 3D CNN. The results show that the 3D CNN performs better than 3D ResNet Attention, and that the combination of the two methods achieves the best result. As seen in Table 3, Ensemble Deep Learning, the combination of the two fine-tuned models, improves the results by about two percent.
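The text does not state how the outputs of the two fine-tuned models are combined. As one plausible reading, the sketch below assumes soft voting, that is, a weighted average of the two models' softmax outputs; the function name `soft_vote` and the equal default weighting are assumptions, not the authors' documented method.

```python
def soft_vote(probs_a, probs_b, weight_a=0.5):
    """Average two models' class-probability rows (assumed ensembling rule)."""
    if len(probs_a) != len(probs_b):
        raise ValueError("both models must score the same subjects")
    weight_b = 1.0 - weight_a
    return [
        [weight_a * a + weight_b * b for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(probs_a, probs_b)
    ]


# Made-up softmax outputs [P(stable CN), P(CN to MCI)] for two subjects.
resnet_attention_out = [[0.8, 0.2], [0.3, 0.7]]
cnn3d_out = [[0.4, 0.6], [0.1, 0.9]]
ensemble_out = soft_vote(resnet_attention_out, cnn3d_out)
```

Because each input row sums to one and the weights sum to one, every ensembled row is again a valid probability distribution.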
Interpretability of features
This section uses the XGBoost feature importance algorithm and the LIME (Local Interpretable Model-agnostic Explanations) model to rank demographic features. The results of the XGBoost feature importance algorithm are shown in Fig. 5, and the results of the LIME model are displayed in Fig. 6. Unlike XGBoost feature importance, which ranks features based on the whole test data, LIME computes feature importance for each individual sample. Fig. 6 displays the results for two random samples, one from class 1 and one from class 2. As the results show, Age and Gender are the most important features according to both algorithms. The two models differ in the ranking of the two other features: XGBoost ranks Marital Status third, whereas in LIME, Education is more important than Marital Status.
Image generation
To generate MRI volumes two years ahead without biasing the model toward healthy or diseased data, we used equal numbers of training samples from healthy and diseased subjects. Since our dataset is longitudinal, it contains input-output image sets in which the output is an MRI image acquired two years after the input MRI image. There are 295 input-output sets for healthy subjects and 445 input-output sets for MCI patients. We used the MRI volumes of all cognitively normal subjects and selected 295 sets from MCI patients. Therefore, the overall number of training sets is 590; that is, the model has 590 MRI images as input and 590 MRI images, acquired two years later, as output.
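The class balancing described above can be sketched as follows. The text does not say how the 295 MCI sets were chosen, so uniform random sampling with a fixed seed is an assumption here, and `balance_pairs` is a hypothetical helper name; the tuples stand in for (baseline, two-year follow-up) image pairs.

```python
import random


def balance_pairs(cn_pairs, mci_pairs, seed=0):
    """Subsample the larger group so both groups contribute equally many pairs."""
    rng = random.Random(seed)
    n = min(len(cn_pairs), len(mci_pairs))
    selected = rng.sample(cn_pairs, n) + rng.sample(mci_pairs, n)
    rng.shuffle(selected)  # mix the two groups before training
    return selected


# Placeholder identifiers matching the counts reported in the text.
cn_pairs = [("CN", i) for i in range(295)]
mci_pairs = [("MCI", i) for i in range(445)]
training_pairs = balance_pairs(cn_pairs, mci_pairs)
```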
Implementation details
In this study, we implemented a Generative Adversarial Network (GAN) architecture for the generation of 3D medical images. The model comprises two main components: the Generator and the Discriminator. The Generator begins by converting the input 3D image into \(8 \times 8 \times 8\) patches using a convolutional layer. These patches are then processed through a series of Transformer encoder blocks with 512 output channels and a stride and kernel size of 1. Each block includes multi-head self-attention and MLP modules. Positional embeddings are added to retain spatial information. Following the Transformer encoder, the patches are upsampled through several transposed convolutions, ultimately reconstructing the image to its original size. The final output layer applies a tanh activation function to ensure the pixel values are within the desired range. The Discriminator network evaluates the authenticity of the generated images. It consists of an initial convolutional layer followed by multiple down-sampling blocks composed of 3D convolutions, instance normalization, and LeakyReLU activation functions, forming a 3D CNN-based classifier with four downsampling layers with 32, 64, 128, and 256 output features, respectively. The final layer is a softmax activation function indicating whether the input image is real or generated. Data augmentation techniques, including random flipping, rotation, and noise addition, are applied to enhance the diversity of the training set. The training loop alternates between updating the Generator and the Discriminator using Adam optimizers with a learning rate of 0.00001. The loss functions used include Mean Squared Error (MSE) for the Discriminator, and a combination of MSE and L1 loss for the Generator. The model is implemented in PyTorch and trained using an NVIDIA RTX 3090 GPU.
Qualitative evaluation
The results of some samples in the training cycle are displayed in Fig. 7. We train the model for 300 iterations; the training results at iterations 2, 20, 50, and 300 are displayed, respectively.
Quantitative evaluation
For quantitative analysis, we report the PSNR and SSIM between the generated images and the real outputs (ground truth) of the test images, based on Eq. 20 and Eq. 21, in Table 4.
\(\text {PSNR} = 10 \log _{10} \left( \frac{\text {MAX}^2}{\text {MSE}} \right)\)
where \(\text {MAX}\) is the maximum possible pixel value of the image (e.g., 255 for 8-bit images) and \(\text {MSE}\) is the Mean Squared Error between the generated image and the ground truth image.
\(\text {SSIM}(x, y) = \frac{(2 \mu _x \mu _y + c_1)(2 \sigma _{xy} + c_2)}{(\mu _x^2 + \mu _y^2 + c_1)(\sigma _x^2 + \sigma _y^2 + c_2)}\)
where:
\(\mu _x = \text {mean of } x, \mu _y = \text {mean of } y,\)
\(\sigma _x^2 = \text {variance of } x,\sigma _y^2 = \text {variance of } y,\) \(\sigma _{xy} = \text {covariance of } x \text { and } y,\)
\(c_1 = (k_1 L)^2, c_2 = (k_2 L)^2,\)
\(L = \text {dynamic range of pixel values},\)
\(k_1 = 0.01, k_2 = 0.03.\)
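The two metrics can be illustrated with a short sketch that follows the standard PSNR and SSIM definitions and the constants listed above. It is a simplification: images are passed as flattened intensity sequences, and SSIM is computed once over the whole image rather than averaged over local windows as most library implementations do. The function names and the default dynamic range of 1.0 are assumptions.

```python
import math


def psnr(generated, target, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two equally sized images."""
    mse = sum((g - t) ** 2 for g, t in zip(generated, target)) / len(target)
    if mse == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value ** 2 / mse)


def global_ssim(x, y, dynamic_range=1.0, k1=0.01, k2=0.03):
    """Single-window SSIM over the whole image (no sliding window)."""
    n = len(x)
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    c1, c2 = (k1 * dynamic_range) ** 2, (k2 * dynamic_range) ** 2
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


# Tiny made-up "images", flattened to 1D and scaled to [0, 1].
reference = [0.0, 0.5, 1.0, 0.5]
shifted = [0.1, 0.6, 0.9, 0.4]
```

An SSIM reported on a 0 to 100 scale corresponds to this value multiplied by 100.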
Our test set includes 112 images, which is 20 percent of all images. The results show that the proposed method performed well in generating MRI volumes.
To validate the architectural choice and demonstrate the effectiveness of our proposed 3D-ViTGAN, we conducted a comparative analysis with several representative generative models: 3D U-Net, 3D StyleGAN, and 3D CycleGAN. These models span key architectural families including convolutional encoder-decoder frameworks, style-based generators, and adversarial domain translation methods. Evaluation was carried out using the ADNI dataset across 112 test subjects. As shown in Table 5, the proposed 3D-ViTGAN outperformed all baseline models in both PSNR and SSIM, achieving a PSNR of \(36.69 \pm 1.07\) and an SSIM of \(98.80 \pm 0.03\). The model maintained a competitive training time and real-time inference speed, demonstrating efficiency alongside accuracy. The superior performance can be attributed to the hybrid design of 3D-ViTGAN, which integrates CNN-based skip connections for local texture preservation with Performer attention modules in the bottleneck to capture long-range dependencies.
Alzheimer’s disease prediction
In this section, the results of predicting Alzheimer’s disease from cognitively normal subjects are explained. To this end, we created an ADNI dataset consisting of subjects who have converted to AD to date, listed in Table 1 as pCN. Our dataset contains 15 CN subjects who converted to AD in later years. We extracted these subjects, together with 15 subjects who did not convert to MCI or AD, from our whole data as a test set that is not used in any training step. There are 69 MRI images labeled pCN. Fig. 8 shows after how many years subjects developed MCI and AD. According to the left side of Fig. 8, 47 percent of subjects converted to MCI after one or two years, 11 percent after three or four years, 26 percent after five or six years, and 16 percent after seven or eight years. In the same way, each part of the right side of Fig. 8 indicates the percentage converting to AD after one or two, three or four, five or six, seven or eight, or nine or ten years. The figure shows that the test data cover different cases of healthy people developing MCI and Alzheimer’s disease within one to ten years.
Implementation details
We trained several models, such as 3D-ResNet, 3D Transformer, and the 3D CNN model (introduced in46), on our MCI to AD prediction dataset; the accuracy results are summarized in Table 6. According to the results, the 3D CNN consistently outperformed the other architectures. We attribute this to the limited training data available, which can significantly affect the performance of more complex models. Therefore, we use the 3D CNN model to classify sMCI and pMCI subjects. The details of the model are introduced in the ’CN to MCI progression’ sub-section. In our method, the ROIs of the images are first extracted and then tested using the sMCI/pMCI classification model. We tested the model on Google Colab with a Tesla T4 GTX, an Intel(R) Xeon(R) CPU @ 2.20GHz, and Python 3.11.
Left: years from CN to MCI. Right: years from CN to AD. Conversions are grouped into two-year bins; for example, the value ’4’ denotes patients who converted to MCI/AD between two and four years. The values in each section of the chart represent the percentage of people who develop MCI/AD in those years.
MCI and AD probability analysis
In addition to sCN and pCN subjects, we used as test data some patients who converted to MCI but did not develop Alzheimer’s. There are 64 such patients, with 140 MRI images in total. The probability results from our proposed method on two categories of test data are depicted in Fig. 9. The first category illustrates the probability of developing MCI in healthy individuals. The test data is divided into three groups: those who progress to MCI and eventually Alzheimer’s, those who develop MCI but not Alzheimer’s, and those who remain healthy. The results indicate that the model accurately predicts MCI in healthy individuals, with two-thirds of those who convert to MCI showing probabilities close to 1, and the rest close to zero. The second category involves the probability of developing Alzheimer’s. This is calculated by multiplying the probability of developing MCI by the probability that the generated image indicates Alzheimer’s. The results show a range from 0 to 1, reflecting the varying stages of individuals who progress to MCI but not Alzheimer’s (sMCI). This demonstrates the model’s robust performance in predicting Alzheimer’s progression.
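The multiplication of the two stage probabilities can be written as a small helper. This is a sketch only: `p_mci` stands for the output of the CN to MCI stage and `p_ad_given_generated` for the classifier's output on the generated two-year image; both names and the example values are hypothetical.

```python
def cn_to_ad_probability(p_mci, p_ad_given_generated):
    """Combine the stage probabilities into an overall CN to AD probability."""
    for name, value in (("p_mci", p_mci),
                        ("p_ad_given_generated", p_ad_given_generated)):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must lie in [0, 1], got {value}")
    return p_mci * p_ad_given_generated


# Made-up stage outputs for three subjects: (P(MCI), P(AD | generated image)).
stage_outputs = [(0.95, 0.90), (0.90, 0.30), (0.05, 0.40)]
overall = [cn_to_ad_probability(a, b) for a, b in stage_outputs]
```

The third subject shows why a nonzero second-stage output does not by itself produce a high overall probability: a low CN to MCI probability keeps the product near zero.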
To demonstrate the robustness of our proposed method, we present a detailed analysis in Fig. 10. The figure categorizes probabilities into three distinct groups: CN to MCI (blue), MCI to AD (green), and CN to AD (red). Each category is further divided into subcategories. For CN to MCI, the subcategories are CN subjects who do not convert to MCI and CN subjects who convert to MCI. The charts illustrate that most subjects are correctly classified within these subcategories. For MCI to AD, the subcategories include the probability of conversion for generated CN subjects, generated pMCI subjects (MCI converts to AD), and generated sMCI (MCI does not convert to AD) subjects. The results indicate that the probability of AD progression from generated MRI images, particularly for the pCN subcategory, is close to one, showcasing the high performance of our method in predicting AD progression. The first subcategory of MCI to AD and CN to AD shows that although the probability of MCI to AD conversion for generated sCN subjects is not zero, the overall probability of AD from CN is near zero for most sCN subjects. The results for the last subcategory demonstrate that the probabilities of MCI to AD and CN to AD vary between zero and 0.5 but do not exceed 0.5, indicating the proposed method’s robustness even in these more challenging cases.
We evaluated the calibration performance of our models using Expected Calibration Error (ECE) and Maximum Calibration Error (MCE). The results before and after applying Platt Scaling and Isotonic Regression are presented in Table 7.
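For reference, the two calibration metrics can be computed as in the sketch below. It assumes the common reliability-diagram formulation for a binary task: predictions are grouped into equal-width probability bins, and each bin's mean predicted probability is compared with its observed positive rate. The number of bins (10) and the function name are assumptions, since the text does not report the binning scheme.

```python
def calibration_errors(probs, labels, n_bins=10):
    """Return (ECE, MCE) for positive-class probabilities and 0/1 labels."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        index = min(int(p * n_bins), n_bins - 1)  # p == 1.0 falls in the last bin
        bins[index].append((p, y))

    total = len(probs)
    ece, mce = 0.0, 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        confidence = sum(p for p, _ in members) / len(members)
        frequency = sum(y for _, y in members) / len(members)
        gap = abs(frequency - confidence)
        ece += len(members) / total * gap  # size-weighted average gap
        mce = max(mce, gap)                # worst bin
    return ece, mce


# Made-up predictions: slightly under-confident at both ends.
example_ece, example_mce = calibration_errors([0.1, 0.1, 0.9, 0.9], [0, 0, 1, 1])
```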
The results show that Isotonic Regression significantly reduces the calibration errors of our models. Figure 11 shows the reliability diagrams before and after calibration, which demonstrate the effectiveness of isotonic regression in improving the calibration of the predicted probabilities. Also, Fig. 12 compares the distribution of the predicted probabilities on the test data before and after applying isotonic regression. Before calibration, the probabilities are spread out, with a noticeable cluster around lower probabilities, indicating that the model is somewhat uncertain in its predictions. After calibration, the distribution shifts, with tighter clustering around higher probability values. This illustrates that calibration has improved the model’s confidence and made the predictions more reliable.
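Isotonic regression fits a non-decreasing map from raw scores to calibrated probabilities. The sketch below implements the classic pool-adjacent-violators procedure in plain Python to show the idea; it returns a step function rather than the interpolated curve that library implementations typically produce, and the function names are hypothetical. It is not the authors' implementation.

```python
from collections import defaultdict


def fit_isotonic(scores, labels):
    """Pool-adjacent-violators fit; returns [(right_edge, calibrated_value), ...]."""
    grouped = defaultdict(lambda: [0.0, 0])  # score -> [label sum, count]
    for s, y in zip(scores, labels):
        grouped[s][0] += y
        grouped[s][1] += 1

    blocks = []  # each block: [label sum, count, largest score covered]
    for s in sorted(grouped):
        total, count = grouped[s]
        blocks.append([total, count, s])
        # Merge backwards while the previous block's mean exceeds the current one.
        while (len(blocks) > 1
               and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]):
            total, count, right = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count
            blocks[-1][2] = right
    return [(right, total / count) for total, count, right in blocks]


def calibrate(score, model):
    """Look up the calibrated probability for a raw score (step function)."""
    for right, value in model:
        if score <= right:
            return value
    return model[-1][1]  # scores above the training range use the last block


# Made-up raw scores and outcomes with one ordering violation (0.2 vs 0.3).
iso_model = fit_isotonic([0.1, 0.2, 0.3, 0.4], [0, 1, 0, 1])
```

In practice the map is fitted on a held-out calibration set and then applied to test-set probabilities, so that the reported calibration errors are not optimistic.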
CN to AD classification
To evaluate the whole proposed framework, we compared it against a baseline standalone 3D CNN model for classifying AD, MCI, and CN subjects. The baseline 3D CNN was trained without the generative modeling and ensemble transfer learning techniques introduced in our framework. In this comparison, we used the CN to MCI to AD and CN to MCI classification tasks, as they reflect the early and critical stages of Alzheimer’s progression. To compare the proposed method with the baseline model, we consider stable CN subjects as CN, subjects who convert from CN to MCI but not to AD as MCI, and CN to AD subjects as AD, and report the results for detected MCI and AD in two separate lines in Table 8.
Table 8 demonstrates the superior performance of the proposed method compared to the baseline 3D CNN for both CN to MCI and CN to AD classifications. The baseline model’s performance in detecting CN to MCI transitions yielded a precision of 0.75, a recall of 0.62, and an accuracy of 73%. The proposed method significantly improved these metrics, with precision and recall both at 0.83 and accuracy at 82%. For CN to AD detection, the baseline model performed poorly, with a recall of only 0.13 and an accuracy of 50%, whereas the proposed model achieved a much better precision of 0.84, recall of 0.83, and accuracy of 83%. After applying calibration, the CN to AD detection further improved, with precision increasing to 0.88 and specificity to 0.89. This shift is particularly important as it indicates a better balance between sensitivity (still at 0.83) and specificity, showing that the model is now more capable of correctly identifying true negatives, which is essential for reducing false positives in clinical settings. These results underscore the efficacy of the proposed method, especially after calibration, in providing more accurate and reliable predictions for Alzheimer’s disease progression. Also, the model’s confusion matrix on data from healthy individuals who develop Alzheimer’s and healthy individuals who develop MCI has been shown in Fig. 13.
To demonstrate the model’s general discriminative ability, we include the ROC curve, which shows an AUC of 0.85, indicating strong separability between classes across a range of possible thresholds in Fig. 14. We believe this supports the model’s value in real-world risk stratification, while allowing institutions to define their own clinical decision flow if needed.
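The AUC summarizes the ROC curve as the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. A minimal sketch of that pairwise definition follows; the function name and the example scores are illustrative only and unrelated to the study data.

```python
def roc_auc(scores, labels):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative case")
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0  # ties count as half
        for p in positives
        for n in negatives
    )
    return wins / (len(positives) * len(negatives))


# Made-up scores: one positive (0.35) is ranked below one negative (0.4).
example_auc = roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])
```

Because the value depends only on the ranking of scores, it is unchanged by any strictly increasing rescaling of the probabilities, which is why it complements the threshold-specific metrics.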
Stratified analysis confirmed that the model retains strong discrimination across age, sex, education, and marital status. Table 9 summarises model performance across sub-groups. Discrimination remained consistently high, with AUCs between 0.80 (70–80 y) and 0.93 (> 80 y). Accuracy, sensitivity, and specificity were well balanced in both age bands. Performance was higher in males than females (AUC 0.93 vs 0.77; DeLong \(p = 0.02\)). Educational attainment did not materially affect discrimination (\(\le 12\) y, AUC 0.87; \(> 12\) y, AUC 0.90). The apparent sensitivity/specificity trade-off reflects a decision-threshold effect rather than model bias. Married participants showed AUC 0.82, mirroring overall cohort performance.
To show the interpretability of the proposed method on MRI images, the Grad-CAM algorithm51 is used. This algorithm extracts the most important features of the image as a heatmap. The heatmaps obtained for some of the test data (CN subjects) are displayed in Fig. 15. According to Fig. 15, the most important parts are shown in red, followed by yellow. The focus on the hippocampus, medial temporal lobes, posterior cingulate cortex, and parietal cortex indicates a well-targeted approach to identifying key regions associated with Alzheimer’s pathology. The cerebellum’s inclusion, although less typical, might still provide useful information in the context of advanced disease stages.
To assess the generalizability of our proposed framework beyond the ADNI dataset, we conducted an external evaluation using a separate cohort. Specifically, we employed a subset of test data drawn from a publicly available brain age estimation dataset, which we had used in our brain age estimation model. This dataset, which was not used during training, represents a different population and imaging context from ADNI. Since some demographic variables included in our model were not available in this external dataset, we used a version of our model that was trained solely on T1-weighted MRI data, allowing consistent input across datasets. This model was trained on ADNI data and tested twice: once on the ADNI held-out test set and once on a randomly selected subset of 137 healthy subjects from the brain age dataset (to match the ADNI test set size). Importantly, the brain age dataset includes only healthy individuals. Therefore, we evaluated model performance by comparing the recall (true positive rate for detecting healthy cases) between the two datasets.
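The core Grad-CAM computation can be sketched with NumPy, assuming the feature maps of the last convolutional layer and the gradients of the class score with respect to them have already been extracted from the network (that extraction step is framework-specific and omitted). Array shapes, the function name, and the normalization to [0, 1] are assumptions made for illustration.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM heatmap from (channels, D, H, W) activations and gradients."""
    activations = np.asarray(activations, dtype=float)
    gradients = np.asarray(gradients, dtype=float)
    if activations.shape != gradients.shape:
        raise ValueError("activations and gradients must have the same shape")
    spatial_axes = tuple(range(1, activations.ndim))
    # Channel weights: gradients averaged over all spatial positions.
    weights = gradients.mean(axis=spatial_axes)
    # Weighted sum of feature maps, then ReLU to keep positive evidence only.
    cam = np.maximum(np.tensordot(weights, activations, axes=(0, 0)), 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Tiny made-up example: 2 channels over a 1 x 1 x 2 volume.
toy_activations = [[[[1.0, 2.0]]], [[[3.0, 1.0]]]]
toy_gradients = [[[[1.0, 1.0]]], [[[-1.0, -1.0]]]]
toy_heatmap = grad_cam(toy_activations, toy_gradients)
```

The resulting low-resolution map is normally upsampled to the input volume size before being overlaid on the MRI, with high values rendered in red.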
As shown in Table 10, this comparable performance despite differences in source population and scanner protocols provides preliminary evidence that our framework maintains reasonable generalization across independent cohorts, even without retraining. We acknowledge that broader external validation (e.g., across institutions or using multi-modal data) remains a future direction. However, this analysis demonstrates that our model is not overfitted to ADNI and can offer robust performance on unseen data.
Atlas-Based Saliency Patterns
We have used the Harvard-Oxford atlas 52 to show the contribution of brain areas for Pre-clinical \(\rightarrow\) MCI \(\rightarrow\) AD, and the results are shown in Table 11. Grad-CAM saliency maps computed with the Harvard-Oxford atlas show that, when the network is applied to pre-clinical scans projected two years ahead, the right thalamus receives the highest weight (mean \(\approx\) 1.0042), followed by the brain-stem (locus coeruleus), left thalamus, left hippocampus, and left amygdala.
This ordering aligns with contemporary neuropathological evidence that tau pathology and atrophy appear earliest in the locus coeruleus and thalamic nuclei and only subsequently spread to medial-temporal structures.
By contrast, as shown in Table 12, the same network applied to an MCI \(\rightarrow\) AD cohort was dominated by the left hippocampus, matching meta-analytic reports of pronounced left-lateral hippocampal atrophy at the time of clinical conversion. The divergence between thalamo–brainstem salience in the pre-clinical stage and hippocampal salience in MCI supports the stage-specific validity of the model. Anatomical boundaries were defined with the probabilistic Harvard-Oxford atlas, facilitating reproducible ROI assignment and cross-study comparison.
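Atlas-based ranking of saliency can be illustrated by averaging heatmap values within each atlas label. The sketch assumes the saliency map and the atlas label volume have been resampled to the same grid and flattened; the label IDs, region names, and values below are placeholders rather than the actual Harvard-Oxford label indices or study results.

```python
from collections import defaultdict


def regional_mean_saliency(saliency, atlas_labels, region_names):
    """Mean saliency per atlas region, sorted from most to least salient."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for value, label in zip(saliency, atlas_labels):
        if label == 0:
            continue  # label 0 is treated as background (assumption)
        sums[label] += value
        counts[label] += 1
    means = {
        region_names.get(label, f"label {label}"): sums[label] / counts[label]
        for label in sums
    }
    return sorted(means.items(), key=lambda item: item[1], reverse=True)


# Placeholder voxel values and labels for two regions plus background.
toy_saliency = [0.2, 0.8, 1.0, 0.4, 0.9]
toy_labels = [0, 1, 1, 2, 2]
toy_names = {1: "Right Thalamus", 2: "Left Hippocampus"}
toy_ranking = regional_mean_saliency(toy_saliency, toy_labels, toy_names)
```

With a probabilistic atlas, each voxel is first assigned to a region by thresholding or taking the most likely label, a choice that affects region boundaries and should be reported.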
Discussion
The proposed integrated framework for predicting AD progression from CN subjects offers significant advancements in early diagnosis and personalized intervention strategies for this neurodegenerative condition. By combining ensemble transfer learning, generative modeling, and advanced interpretability techniques, our approach addresses key challenges in the field, including limited longitudinal data, the need for high-fidelity synthetic data, and model transparency. The ensemble transfer learning strategy effectively merges the strengths of two fine-tuned models: the Age Estimation model and the MCI to AD conversion model. The Age Estimation model, known for its robustness and generalizability across diverse datasets, enhances our ability to detect early cognitive decline by providing essential features. Meanwhile, the MCI to AD conversion model, fine-tuned on sMCI and pMCI classification, adds specificity by focusing on the crucial transition from MCI to AD. This combination results in improved accuracy and robustness in predicting the transition from CN to MCI and subsequently to AD, as demonstrated by the high Accuracy and F1 scores in our evaluations.
GANs play a pivotal role in our framework by generating realistic MRI images that represent future stages of cognitive decline. This addresses the challenge of limited longitudinal data by creating high-quality synthetic images that provide a visual representation of brain changes over time (after two years). These generated images are crucial for subsequent analysis using 3D CNNs, enhancing the robustness of our predictions. In this study, we limited the generative modeling to a 2-year prediction interval due to the availability and consistency of follow-up MRI data in the ADNI dataset. Ideally, future work should extend the model to predict MRIs at additional time points (e.g., 3-, 4-, or 5-year follow-ups) to enable a more comprehensive analysis. We consider this an important direction for future research, especially as more extended follow-up data becomes available. Nevertheless, although the lack of training data restricted the generative model to simulating brain changes after two years, the proposed method proved effective with these data, even for predicting Alzheimer’s up to ten years ahead. If more longitudinal data beyond the two-year mark become available, the method can still be applied, and its performance is expected to improve, because brain changes become more distinguishable as one approaches the onset of Alzheimer’s disease. Furthermore, this could enable reliable predictions beyond the ten-year horizon.
If other datasets are used, we recommend adding a registration phase to the pre-processing step to register the input and output MRI images, so that the model learns the changes more effectively. Our methodology demonstrates that it is feasible to predict the progression from CN to AD without needing longitudinal data for final predictions. By utilizing baseline images and generating follow-up images through the ViT-GAN, we establish a novel pathway for assessing Alzheimer’s progression. Quantitative evaluations using metrics such as PSNR and SSIM confirm the high fidelity of the generated images, validating the effectiveness of our generative approach. In the prediction of MCI to AD progression using generated images, we trained several 3D-based models, including 3D CNN, 3D ResNet, and 3D Transformer, on our sMCI and pMCI datasets. The results demonstrate that the simpler 3D CNN model outperformed the more complex architectures, 3D ResNet and 3D Transformer. This performance difference may be due to the limited size of the training dataset, which tends to favor simpler models over those with higher complexity. Additionally, we conducted a subjective evaluation of the extracted ROIs using both the 3D CNN and 3D ResNet models.
The findings confirmed that 3D CNN is a reliable choice for classification tasks. However, since the primary contribution of our research lies in the comprehensive framework for predicting the progression from CN to AD, utilizing a 3D ResNet model within this framework can yield promising results. Calibration of predicted probabilities is crucial in medical diagnosis applications to ensure that the predicted risks accurately reflect the true probabilities of outcomes.
In this study, we applied isotonic regression for probability calibration, which significantly improved the calibration of our model’s predictions. This improvement is evidenced by the reduction in the Expected Calibration Error (ECE) and Maximum Calibration Error (MCE), as well as the more aligned reliability diagram after calibration. The isotonic regression method proved effective in handling the complex relationship between predicted and true probabilities, which was not well-captured by parametric models like logistic regression. The flexibility of isotonic regression to fit a monotonic function without assuming a specific form allowed it to better adjust the predicted probabilities, leading to more reliable and interpretable risk predictions. To demonstrate the performance of our proposed framework, we compare the results of the proposed method with a three-class classification model as a baseline model. The baseline model struggled with recall, particularly in CN to AD transitions, where its recall was only 0.13, resulting in poor detection of Alzheimer’s progression. However, the proposed model significantly enhanced both precision and recall, reaching 0.83, with calibration further improving the precision to 0.88 and specificity to 0.89. This balance between sensitivity and specificity demonstrates the proposed method’s robustness in minimizing false positives, which is crucial in clinical settings for early-stage Alzheimer’s diagnosis. The results show that the enhanced calibration achieved through isotonic regression improves the clinical utility of our predictive model, providing more accurate risk assessments for Alzheimer’s disease progression from cognitively normal subjects. This can facilitate earlier and more targeted interventions, ultimately improving patient outcomes.
To ensure that our framework learns from both 1.5 T and 3 T field strengths, we first trained the age–prediction network on a mixed set that included 3 T images, and then transferred its weights to the \(\textrm{CN}\!\rightarrow \!\textrm{MCI}\) classification task. This strategy allowed the feature extractor to internalise anatomical variation from both field strengths, thereby increasing the robustness of all downstream models.
In a further experiment prompted by the reviewer, we added 395 paired 3 T MRIs to the 3DViT–GAN training data. The resulting model showed no meaningful change in performance, with an average improvement of only \(0.05\%\) in PSNR (as shown in Table 13). We attribute this to (i) the redundancy of 3 T scans in ADNI, where most participants already have 1.5 T counterparts, and (ii) the strong intensity and geometric augmentations that already simulate higher-resolution variability. These observations indicate that the current framework performs reliably without full 3 T inclusion. Nevertheless, future work that incorporates larger, independent 3 T cohorts or datasets beyond ADNI could further enhance generative realism and improve diagnostic prediction, particularly for prodromal neurodegeneration.
Alzheimer’s disease prediction
The core of our framework lies in the integration of generated MRI images with advanced prediction models to estimate the probability of developing Alzheimer’s disease from a cognitively normal state. By leveraging both the generated MRI images and real MRI data, our model effectively captures the progression of the disease. The probabilistic approach adopted in our prediction model ensures a realistic and clinically relevant estimation of disease progression. The high accuracy, precision, recall, and F1-scores obtained in predicting both CN to MCI and CN to AD transitions underscore the efficacy of our method. Furthermore, the detailed analysis of MCI and AD probabilities illustrates the model’s ability to distinguish between different stages of cognitive decline, providing valuable insights into individual patient trajectories. The accurate classification of subjects transitioning from CN to AD highlights the model’s potential for early intervention and personalized treatment planning. Interpretability is a cornerstone of our framework, ensuring that the model’s decisions are transparent and clinically meaningful. Gradient-weighted Class Activation Mapping (Grad-CAM) identifies and visualizes the most critical regions of interest (ROIs) influencing the model’s predictions. This enhances the trustworthiness of our model and provides valuable insights into the underlying pathological changes associated with AD. The automatic extraction of ROIs using Grad-CAM aligns with current clinical understanding, highlighting areas such as the hippocampus and medial temporal lobe, which are known to be affected in early AD stages. In the end, using Isotonic Regression, we tried to correct the biased probabilities predicted by our model and make the probabilities more reliable.
Clinical implications
The promising results from our integrated framework underscore its potential clinical utility in predicting Alzheimer’s disease progression from cognitively normal individuals. We assess the reliability of our proposed method using Grad-CAM, whose results demonstrate that the proposed model bases its decisions on the parts of the brain that are affected in the very early stage of AD. Early identification of at-risk individuals can facilitate timely interventions, potentially slowing disease progression and improving patient outcomes. Our probabilistic approach to prediction acknowledges the inherent uncertainty in disease progression, providing a nuanced assessment that can guide personalized treatment plans.
In this study, we adopted a simplified assumption that the transition from CN to AD occurs exclusively through an intermediate MCI stage, effectively treating the probability \(P(\text {AD} \mid \lnot \text {MCI}) \approx 0\). This assumption aligns with the dominant progression pattern observed in large longitudinal datasets such as ADNI and was necessary to design a structured generative-classification framework. However, we acknowledge that in real-world clinical settings, some cases may bypass detectable MCI due to diagnostic limitations or rapid symptom progression. These edge cases are not fully captured by our model and represent an important limitation. Future work should consider incorporating probabilistic or soft transition models to better reflect the full clinical spectrum and capture atypical progression paths.
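The effect of this assumption can be made explicit with the law of total probability. The short derivation below is a sketch that interprets the classifier output on the generated image as an estimate of \(P(\text {AD} \mid \text {MCI})\); the numerical values are hypothetical and chosen only for illustration.
\(P(\text {AD}) = P(\text {AD} \mid \text {MCI})\, P(\text {MCI}) + P(\text {AD} \mid \lnot \text {MCI})\, P(\lnot \text {MCI})\)
\(P(\text {AD}) \approx P(\text {AD} \mid \text {MCI})\, P(\text {MCI}) \quad \text {when } P(\text {AD} \mid \lnot \text {MCI}) \approx 0\)
For example, with \(P(\text {MCI}) = 0.7\) and \(P(\text {AD} \mid \text {MCI}) = 0.6\), the framework reports \(P(\text {AD}) \approx 0.7 \times 0.6 = 0.42\). If the neglected term were not zero, say \(P(\text {AD} \mid \lnot \text {MCI}) = 0.05\), the exact value would be \(0.42 + 0.05 \times 0.3 = 0.435\), so the simplified product slightly underestimates the risk for subjects who bypass a detectable MCI stage.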
Conclusion and future works
Future research should focus on using other datasets to enhance the generalizability of our model. Also, to improve the prediction performance, we can generate brain images of subjects after three years or more if there is enough data. Integrating other biomarkers, such as PET (e.g., amyloid PET, FDG-PET), CSF biomarkers (e.g., A\(\beta\)42/40, p-tau), genetic factors (e.g., APOE-\(\varepsilon\)4), or blood-based markers could provide a more comprehensive view of disease progression. Also, we used Grad-CAM to qualitatively assess whether the model attends to clinically relevant brain regions, and found alignment with known Alzheimer’s-related structures (e.g., hippocampus, thalamus). However, we acknowledge that Grad-CAM has limitations, including potential estimation errors and sensitivity to model architecture. While it provides useful visual cues, deeper quantitative analyses and alternative explainability methods, such as SHAP or LIME, could offer more robust insights; these are planned for future work.
In conclusion, our study presents a novel and effective approach to predicting Alzheimer’s disease progression using an integrated framework of ensemble transfer learning, generative modeling, and advanced interpretability techniques. The combination of these advanced machine learning techniques with a focus on interpretability offers a valuable contribution to neurodegenerative disease research, paving the way for early diagnosis and personalized intervention strategies in Alzheimer’s disease.
Data availability
The datasets analysed during the current study are available in the ADNI repository, adni.loni.usc.edu
References
Rosén, C., Hansson, O., Blennow, K. & Zetterberg, H. Fluid biomarkers in alzheimer’s disease-current concepts. Mol. neurodegeneration 8, 1–11 (2013).
Association, A. 2019 alzheimer’s disease facts and figures. Alzheimer’s & dementia 15, 321–387 (2019).
Gaser, C. et al. Brainage in mild cognitive impaired patients: predicting the conversion to alzheimer’s disease. PloS one 8, e67346 (2013).
Gong, N.-J., Wong, C.-S., Chan, C.-C., Leung, L.-M. & Chu, Y.-C. Correlations between microstructural alterations and severity of cognitive deficiency in alzheimer’s disease and mild cognitive impairment: a diffusional kurtosis imaging study. Magn. Reson. Imaging 31, 688–694 (2013).
O’Dwyer, L. et al. Using support vector machines with multiple indices of diffusion for automated classification of mild cognitive impairment. PloS one 7, e32441 (2012).
Moscoso, A. et al. Prediction of alzheimer’s disease dementia with mri beyond the short-term: Implications for the design of predictive models. NeuroImage: Clin. 23, 101837 (2019).
Chong, M. S. & Sahadevan, S. Preclinical alzheimer’s disease: diagnosis and prediction of progression. The Lancet Neurol. 4, 576–579 (2005).
Sarazin, M. et al. The amnestic syndrome of hippocampal type in alzheimer’s disease: an mri study. J. Alzheimer’s disease 22, 285–294 (2010).
Candemir, S. et al. Predicting rate of cognitive decline at baseline using a deep neural network with multidata analysis. J. Med. Imaging 7, 044501 (2020).
Hu, Z., Wang, Z., Jin, Y. & Hou, W. Vgg-tswinformer: Transformer-based deep learning model for early alzheimer’s disease prediction. Comput. Methods Programs Biomed. 229, 107291 (2023).
Pan, D. et al. Deep learning for brain mri confirms patterned pathological progression in alzheimer’s disease. Adv. Sci. 10, 2204717 (2023).
Inan, M. S. K. et al. A slice selection guided deep integrated pipeline for alzheimer’s prediction from structural brain mri. Biomed. Signal Process. Control. 89, 105773 (2024).
Singh, A. & Kumar, R. Brain mri image analysis for alzheimer’s disease (ad) prediction using deep learning approaches. SN Comput. Sci. 5, 160 (2024).
Aghaei, A., Ebrahimi Moghaddam, M. & Malek, H. Interpretable ensemble deep learning model for early detection of alzheimer’s disease using local interpretable model-agnostic explanations. Int. J. Imaging Syst. Technol. 32, 1889–1902 (2022).
Zhang, X. et al. smri-patchnet: A novel efficient explainable patch-based deep learning network for alzheimer’s disease diagnosis with structural mri. IEEE Access 11, 108603–108616. https://doi.org/10.1109/ACCESS.2023.3321220 (2023).
Liu, F., Yuan, S., Li, W., Xu, Q. & Sheng, B. Patch-based deep multi-modal learning framework for alzheimer’s disease diagnosis using multi-view neuroimaging. Biomed. Signal Process. Control. 80, 104400 (2023).
Zhu, W., Sun, L., Huang, J., Han, L. & Zhang, D. Dual attention multi-instance deep learning for alzheimer’s disease diagnosis with structural mri. IEEE Transactions on Med. Imaging 40, 2354–2366 (2021).
Liu, S. et al. Generalizable deep learning model for early alzheimer’s disease detection from structural mris. Sci. reports 12, 17106 (2022).
Mueller, K., Meyer-Baese, A. & Erlebacher, G. Combining graph neural networks and roi-based convolutional neural networks to infer individualized graphs for alzheimer’s prediction. In Medical Imaging 2023: Biomedical Applications in Molecular, Structural, and Functional Imaging, vol. 12468, 25–30 (SPIE, 2023).
Zhao, Y. et al. Multi-view prediction of alzheimer’s disease progression with end-to-end integrated framework. J. Biomed. Informatics 125, 103978 (2022).
Wen, J. et al. Convolutional neural networks for classification of alzheimer’s disease: Overview and reproducible evaluation. Med. image analysis 63, 101694 (2020).
Kautzky, A. et al. Prediction of autopsy verified neuropathological change of alzheimer’s disease using machine learning and mri. Front. aging neuroscience 10, 406 (2018).
Martin, S. A., Townend, F. J., Barkhof, F. & Cole, J. H. Interpretable machine learning for dementia: a systematic review. Alzheimer’s & Dementia 19, 2135–2149 (2023).
Jahan, S. et al. Explainable ai-based alzheimer’s prediction and management using multimodal data. Plos one 18, e0294253 (2023).
Jia, M., Wu, Y., Xiang, C. & Fang, Y. Predicting alzheimer’s disease with interpretable machine learning. Dementia Geriatr. Cogn. Disord. 52, 249–257 (2023).
Parvin, S., Nimmy, S. F. & Kamal, M. S. Convolutional neural network based data interpretable framework for alzheimer’s treatment planning. Vis. Comput. for Ind. Biomed. Art 7, 1–12 (2024).
Tang, Z. et al. Interpretable classification of alzheimer’s disease pathologies with a convolutional neural network pipeline. Nat. communications 10, 2173 (2019).
Yousefzadeh, N. et al. Neuron-level explainable ai for alzheimer’s disease assessment from fundus images. Sci. Reports 14, 7710 (2024).
Luo, M. et al. Class activation attention transfer neural networks for mci conversion prediction. Comput. Biol. Medicine 156, 106700 (2023).
Shmulev, Y., Belyaev, M. & Alzheimer’s Disease Neuroimaging Initiative. Predicting conversion of mild cognitive impairments to alzheimer’s disease and exploring impact of neuroimaging. In Graphs in Biomedical Image Analysis and Integrating Medical Imaging and Non-Imaging Modalities: Second International Workshop, GRAIL 2018 and First International Workshop, Beyond MIC 2018, Held in Conjunction with MICCAI 2018, Granada, Spain, September 20, 2018, Proceedings 2, 83–91 (Springer, 2018).
Abrol, A. et al. Deep residual learning for neuroimaging: an application to predict progression to alzheimer’s disease. J. neuroscience methods 339, 108701 (2020).
Chaurasia, B. K., Raj, H., Rathour, S. S. & Singh, P. B. Transfer learning-driven ensemble model for detection of diabetic retinopathy disease. Med. & Biol. Eng. & Comput. 61, 2033–2049 (2023).
Rye, I., Vik, A., Kocinski, M., Lundervold, A. S. & Lundervold, A. J. Predicting conversion to alzheimer’s disease in individuals with mild cognitive impairment using clinically transferable features. Sci. Reports 12, 15566 (2022).
Mahmud, T. et al. Exploring deep transfer learning ensemble for improved diagnosis and classification of alzheimer’s disease. In International Conference on Brain Informatics, 109–120 (Springer, 2023).
Nobili, F. & Morbelli, S. [18f] fdg-pet as a biomarker for early alzheimer’s disease. The Open Nucl. Medicine J. 2 (2010).
SinhaRoy, R. & Sen, A. A hybrid deep learning framework to predict alzheimer’s disease progression using generative adversarial networks and deep convolutional neural networks. Arab. J. for Sci. Eng. 49, 3267–3284 (2024).
Wang, J. et al. Fedmed-gan: Federated domain translation on unsupervised cross-modality brain image synthesis. Neurocomputing 546, 126282 (2023).
Li, Y. et al. Zero-shot medical image translation via frequency-guided diffusion models. IEEE transactions on medical imaging (2023).
Alauthman, M. et al. Enhancing small medical dataset classification performance using gan. In Informatics, vol. 10, 28 (MDPI, 2023).
Wali, A., Ahmad, M., Naseer, A., Tamoor, M. & Gilani, S. Stynmedgan: Medical images augmentation using a new gan model for improved diagnosis of diseases. J. Intell. & Fuzzy Syst. 1–18 (2023).
Golhar, M. V., Bobrow, T. L., Ngamruengphong, S. & Durr, N. J. Gan inversion for data augmentation to improve colonoscopy lesion classification. IEEE J. Biomed. Heal. Informatics (2024).
Festag, S. & Spreckelsen, C. Medical multivariate time series imputation and forecasting based on a recurrent conditional wasserstein gan and attention. J. Biomed. Informatics 139, 104320 (2023).
Shahbazian, R. & Greco, S. Generative adversarial networks assist missing data imputation: A comprehensive survey & evaluation. IEEE Access (2023).
Liu, Y. et al. Assessing clinical progression from subjective cognitive decline to mild cognitive impairment with incomplete multi-modal neuroimages. Med. image analysis 75, 102266 (2022).
Aghaei, A., Ebrahimi Moghaddam, M. & Alzheimer’s Disease Neuroimaging Initiative. Brain age gap estimation using attention-based resnet method for alzheimer’s disease detection. Brain Informatics 11, 16 (2024).
Aghaei, A. & Moghaddam, M. E. Smart roi detection for alzheimer’s disease prediction using explainable ai. arXiv preprint arXiv:2303.10401 (2023).
Niculescu-Mizil, A. & Caruana, R. Predicting good probabilities with supervised learning. In Proceedings of the 22nd international conference on Machine learning, 625–632 (2005).
Zadrozny, B. & Elkan, C. Transforming classifier scores into accurate multiclass probability estimates. In Proceedings of the eighth ACM SIGKDD international conference on Knowledge discovery and data mining, 694–699 (2002).
Mueller, S. G. et al. The alzheimer’s disease neuroimaging initiative. Neuroimaging Clin. North Am. 15, 869 (2005).
Jack, C. R. et al. Tracking pathophysiological processes in alzheimer’s disease: an updated hypothetical model of dynamic biomarkers. The lancet neurology 12, 207–216 (2013).
Selvaraju, R. R. et al. Grad-cam: Visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE international conference on computer vision, 618–626 (2017).
Makris, N. et al. Decreased volume of left and total anterior insular lobule in schizophrenia. Schizophr. research 83, 155–171 (2006).
Acknowledgements
Data used in the preparation of this article were obtained from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database (adni.loni.usc.edu). As such, the investigators within the ADNI contributed to the design and implementation of ADNI and/or provided data but did not participate in the analysis or writing of this report. A complete listing of ADNI investigators can be found at: http://adni.loni.usc.edu/wp-content/uploads/how_to_apply/ADNI_Acknowledgement_List.pdf.
Author information
Contributions
A.A. wrote the manuscript text, A.A. and M.E.M. conceived the experiment(s), A.A. conducted the experiment(s), and A.A. and M.E.M. analyzed the results. All authors reviewed the manuscript.
Ethics declarations
Competing interests
The authors declare no competing interests.
Additional information
Publisher’s note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License, which permits any non-commercial use, sharing, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if you modified the licensed material. You do not have permission under this licence to share adapted material derived from this article or parts of it. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by-nc-nd/4.0/.
About this article
Cite this article
Aghaei, A., Moghaddam, M.E. An integrated predictive model for Alzheimer’s disease progression from cognitively normal subjects using generated MRI and interpretable AI. Sci Rep 15, 28340 (2025). https://doi.org/10.1038/s41598-025-13478-2