Abstract
Machine learning models often rely on simple spurious features – patterns in training data that correlate with targets but are not causally related to them, such as image backgrounds in foreground classification. This reliance typically leads to imbalanced test performance across minority and majority groups. In this work, we take a closer look at the fundamental cause of such imbalanced performance through the lens of memorization, which refers to the ability to predict accurately on atypical examples (minority groups) in the training set while failing to achieve the same accuracy on the test set. This paper systematically shows the ubiquitous existence of spurious features in a small set of neurons within the network, providing the first-ever evidence that memorization may contribute to imbalanced group performance. Through three sources of converging empirical evidence, we find that a small subset of neurons or channels memorizes minority-group information. Inspired by these findings, we hypothesize that spurious memorization, concentrated within a small subset of neurons, plays a key role in driving imbalanced group performance. To further substantiate this hypothesis, we show that eliminating these unnecessary spurious memorization patterns via a novel framework during training can significantly improve model performance on minority groups. Our experimental results across various architectures and benchmarks offer new insights into how neural networks encode core and spurious knowledge, laying the groundwork for future research in demystifying robustness to spurious correlation.
Introduction
Machine learning models often achieve high overall performance, yet struggle in minority groups due to spurious correlations – patterns that align with the class label in training data but have no causal relationship with the target1,2. For example, consider the task of distinguishing cows from camels in natural images, where it is common to find 95% of cow images with grass backgrounds and 95% of camel images on sand. Models trained using standard Empirical Risk Minimization (ERM) often focus on minimizing the average training error by depending on spurious background attributes (“grass” or “sand”) instead of the core characteristics (“cow” or “camel”). In such settings, models may yield good average accuracy but high error rates in minority groups (“cows on sand” or “camels on grass”)3,4. This illustrates a fundamental issue: even well-trained models can develop systematic biases from spurious attributes in their data, leading to an alarmingly consistent performance drop for minority groups where the spurious correlation does not hold. Indeed, in Fig. 1, we present both the training and test accuracy on the majority and minority groups of the Waterbirds benchmark for two popular models: ResNet-505 and ViT-Small6. It is clear from Fig. 1 that the test performance is poor on the minority groups (1 and 2). Moreover, we observe that the majority groups have a smaller gap between training and test accuracy, whereas the minority groups show a much larger gap. Understanding the underlying causes of this imbalanced performance between majority and minority groups is thus crucial for the reliable and safe deployment of such models in real-world scenarios7,8,9.
The minority groups are atypical examples to neural networks (NNs), as these small subsets of examples bear a similarity to majority groups due to the same spurious attribute, but have distinct labels. Recent efforts have shown that NNs often ‘memorize’ atypical examples, primarily in the final few layers of the model10,11, and possibly even in specific locations of the model12. Memorization, in this context, is defined as the neural network’s ability to accurately predict outcomes for atypical examples (e.g., mislabeled examples) in the training set through ERM training. This is in striking analogy to the spurious correlation issue, because 1) the minority examples are atypical examples by definition, and 2) the minority examples are often more accurately predicted during training but poorly predicted during testing, as demonstrated in Fig. 1. Therefore, a natural open question arises: Does memorization play a role in spurious correlations?
In this work, we present the first study to systematically understand the interplay of memorization and spurious correlations in deep overparametrized networks. We undertake our exploration through the following avenues: 1) Under what conditions do spurious correlations arise, or fail to arise, within NNs? 2) How do NNs handle atypical examples, often seen in minority groups, as opposed to typical examples from majority groups? and 3) Can NNs differentiate between these atypical and typical examples in their learning dynamics?
To achieve these goals, we show the existence of a phenomenon named spurious memorization. We define ‘spurious memorization’ as the ability of NNs to accurately predict outcomes for atypical (i.e., minority) examples during training by memorizing them in certain parts of the model. Indeed, we first identify that a small set of neurons is critical for memorizing minority examples. These critical neurons significantly affect the model performance on minority examples during training, but have only minimal influence on majority examples. Furthermore, we show that these critical neurons account for only a very small portion of the model parameters. Such memorization by a small portion of neurons makes the model performance on minority examples non-robust, leading to poor test accuracy on minority examples despite the high training accuracy. Overall, our study offers a potential explanation for the differing performance patterns of NNs when handling majority and minority examples.
Our systematic study is performed in two stages. In Stage I, to verify the existence of critical neurons, we use two experimental sources to trace spurious memorization at the neuron and layer level: unstructured tracing (assessing the role of neurons within the entire model for spurious memorization, using heuristics including weight magnitude and gradient) and structured tracing (assessing the role of neurons within each individual layer with similar heuristics). Specifically, by evaluating the impact of spurious memorization via unstructured and structured tracing at the magnitude and gradient level, we observe a substantial decrease in minority group accuracy, contrasting with a minimal effect on majority group accuracy. This suggests that at both the unstructured and structured levels, the learning of minority groups opposes the learning of majority groups, and indicates that 1) critical neurons for spurious memorization indeed exist within NNs; 2) both gradient and magnitude criteria are effective tools for identifying these critical neurons; and 3) NNs tend to memorize typical examples from majority groups on a global scale, whereas a miniature set of nodes (critical neurons) is involved in the memorization of minority examples to a much greater extent than other neurons. Overall, we provide converging empirical evidence confirming the existence of critical neurons for spurious memorization.
In Stage II, inspired by the observations above, we develop a framework to investigate and understand the essential role of critical neurons in the spurious memorization that incurs the imbalanced group performance of NNs. Specifically, we construct an auxiliary model, which is an adaptively pruned version of the target model, and then contrast the features of this auxiliary model with those of the target model. Our motivation comes from the recent empirical finding13 that pruning disproportionately degrades a network’s accuracy on rare and atypical examples (minority groups in our case). This allows the target model to identify and adapt to various spurious memorization at different stages of training, thereby progressively learning more balanced representations across different groups. Through extensive experiments with our training algorithm across a diverse range of architectures, model sizes, and benchmarks, we confirm that the critical neurons have emergent spurious memorization properties and are therefore more amenable to pruning. More importantly, we show that majority examples, being memorized by the entire network, often yield robust test performance, whereas minority examples, memorized by a limited set of critical neurons, show poor test performance precisely because their memorization rests on such a miniature subset of neurons. This provides a convincing explanation for the imbalanced group performance observed in the presence of spurious correlations.
Concretely, we summarize our contributions as follows: (1) To the best of our knowledge, we present the first systematic study on the role of different neurons in memorizing different group information, and confirm the existence of critical neurons where memorization of spurious correlations occurs. (2) We show that modifications to specific critical neurons can significantly affect model performance on the minority groups, while having almost negligible impact on the majority groups. (3) We propose spurious memorization as a new perspective on explaining the behavior of critical neurons in causing imbalanced group performance between majority and minority groups.
Results
Identifying the existence of critical neurons
We validate the existence of critical neurons in the presence of spurious correlations. We comprehensively examine the underlying behavior of ‘critical neurons’ on the Waterbirds dataset with the ResNet-50 backbone. Within this section, the term ‘neurons’ specifically refers to channels in a convolutional kernel. It is worth noting that the Waterbirds dataset comprises two majority groups and two minority groups. For clarity in our discussions and figures, we use the following notations, aligned with the dataset’s default setting: The majority groups are \({{{{\mathcal{G}}}}}_{0}\) (Landbird on Land) and \({{{{\mathcal{G}}}}}_{3}\) (Waterbird on Water), while the minority groups are \({{{{\mathcal{G}}}}}_{1}\) (Landbird on Water) and \({{{{\mathcal{G}}}}}_{2}\) (Waterbird on Land).
In the following discussion, we consider the model as f(θ, ⋅ ), with θ representing the collection of all neurons. Individual neurons are denoted as zi, for i ∈ [M]: = {1, ⋯ , M}, and θ can be expressed as θ = {z1, z2, ⋯ , zM}. For the training data, we use \({{{{\mathcal{D}}}}}_{0}\), \({{{{\mathcal{D}}}}}_{1}\), \({{{{\mathcal{D}}}}}_{2}\), \({{{{\mathcal{D}}}}}_{3}\) to represent the datasets, where \({{{{\mathcal{D}}}}}_{j}\) comprises examples from group \({{{{\mathcal{G}}}}}_{j}\), for each j ∈ {0, 1, 2, 3}, respectively. Finally, let \({{{{\mathcal{L}}}}}_{{{{\rm{CE}}}}}\) signify the cross-entropy loss. We emphasize that all the group accuracy evaluated before and after pruning in this section is evaluated on the training set, which strictly complies with the definition of memorization from Section Introduction.
To begin with, we adopt unstructured tracing to assess the effect of neurons on spurious memorization across the entire model, using weight magnitude and gradient as criteria. For the gradient-based criterion, we begin with a model trained by ERM. We then select the neurons with the largest gradient, measured in the ℓ2 norm, across the entire model. Zeroing out these neurons, we can then observe the resultant impact on group accuracy. To be specific, we compute the loss gradient for each of the 4 Waterbirds groups. The loss gradient v( ⋅ ) on group j w.r.t. neuron i is defined as

$$v(i,j)=\nabla_{\mathbf{z}_i}\,\mathcal{L}_{\mathrm{CE}}\big(\mathcal{D}_j,\,f(\boldsymbol{\theta},\cdot)\big). \qquad (1)$$
For each group j, we select the neurons i whose ∥v(i, j)∥2 values are among the top-k largest over all M neurons. In our experiments, we evaluate cases with k = 1, 2, 3. We demonstrate that pruning even just the top-1 largest-gradient neuron can significantly affect the minority group training accuracy. We denote the indices of these neurons as \({{{{\mathcal{I}}}}}_{j}\), where \({{{{\mathcal{I}}}}}_{j}\) is a subset of {1, ⋯ , M}. To assess the importance of these selected neurons in memorizing examples, we zero them out and calculate the change in group accuracy on the training set. The pruned model is identified as f(mj ⊙ θ, ⋅ ), where mj is a mask with the neurons in \({{{{\mathcal{I}}}}}_{j}\) being masked. The change in accuracy Δacc for each group j is given by \({\Delta }_{{{{\rm{acc}}}}}(j)=| {{{\rm{acc}}}}({{{{\mathcal{D}}}}}_{j},f({{{\boldsymbol{\theta }}}},\cdot ))-{{{\rm{acc}}}}({{{{\mathcal{D}}}}}_{j},f({{{{\bf{m}}}}}_{j}\odot {{{\boldsymbol{\theta }}}},\cdot ))|\), where acc represents the accuracy. In the experiments below, all group accuracy changes are relative to the following training accuracy: 97.34% (\({{{{\mathcal{G}}}}}_{0}\)), 47.83% (\({{{{\mathcal{G}}}}}_{1}\)), 69.64% (\({{{{\mathcal{G}}}}}_{2}\)), 97.63% (\({{{{\mathcal{G}}}}}_{3}\)). For completeness, we also report the baseline test accuracy: 96.98% (\({{{{\mathcal{G}}}}}_{0}\)), 35.68% (\({{{{\mathcal{G}}}}}_{1}\)), 56.98% (\({{{{\mathcal{G}}}}}_{2}\)), and 96.26% (\({{{{\mathcal{G}}}}}_{3}\)). The baseline test accuracy follows the same pattern as the training accuracy.
Similarly, when using magnitude as the selection criterion, the tracing procedure remains the same except that we zero-out the neurons with the largest magnitude measured in ℓ2 norm. That is, instead of ∥v(i, j)∥2, we select neurons with largest ∥zi∥2. It is worth noting that the magnitude-based selection approach here is group invariant – the magnitude used for selection does not vary with the model’s input.
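For concreteness, the following PyTorch sketch illustrates the unstructured tracing procedure under both criteria. The function names (channel_scores, zero_out_topk) and the data-loading interface are our own illustration, not the released implementation.

```python
import torch
import torch.nn.functional as F

def channel_scores(model, loader_j=None, criterion="magnitude"):
    """Score every conv channel ('neuron') by its l2 weight norm ("magnitude")
    or by the l2 norm of its loss gradient on group j ("gradient").
    Returns (score, param, channel) triples."""
    if criterion == "gradient":
        model.zero_grad()
        for x, y in loader_j:                     # accumulate grad of L_CE on D_j
            F.cross_entropy(model(x), y).backward()
    scores = []
    for p in model.parameters():
        if p.dim() != 4:                          # conv kernels only
            continue
        t = p.grad if criterion == "gradient" else p.detach()
        if t is None:                             # skip frozen parameters
            continue
        for c, s in enumerate(t.flatten(1).norm(dim=1).tolist()):
            scores.append((s, p, c))
    return scores

def zero_out_topk(model, loader_j=None, k=3, criterion="gradient"):
    """Prune (zero out) the k highest-scoring channels in place."""
    scores = channel_scores(model, loader_j, criterion)
    for _, p, c in sorted(scores, key=lambda t: t[0], reverse=True)[:k]:
        with torch.no_grad():
            p[c].zero_()                          # applies the mask m_j to theta
    return model
```

Δacc(j) is then obtained by evaluating the training accuracy on \({{{{\mathcal{D}}}}}_{j}\) before and after calling zero_out_topk.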
We demonstrate that zeroing out the top-1 to top-3 critical neurons can significantly impact the training accuracy of minority groups. A natural question arises: are three neurons sufficient? In essence, we investigate whether pruning additional neurons can amplify the performance drop. Thus, we conduct an ablation study varying the number of pruned neurons. The findings are summarized in Table 3 (in Supplementary materials). Similar trends to those observed in Fig. 2 persist despite altering the number of pruned neurons: the decline in performance among minority groups (\({{{{\mathcal{G}}}}}_{1}\) and \({{{{\mathcal{G}}}}}_{2}\)) exceeds that of majority groups (\({{{{\mathcal{G}}}}}_{0}\) and \({{{{\mathcal{G}}}}}_{3}\)), even when the number of pruned neurons is increased to 10.
Fig. 2: Within each group \({{{\mathcal{G}}}}\), three bars with gradated hues indicate the accuracy shift after zeroing out the neurons with the top-1, top-2, and top-3 largest gradients or magnitudes, respectively. Note that the minority groups \({{{{\mathcal{G}}}}}_{1}\) and \({{{{\mathcal{G}}}}}_{2}\) are emphasized with the star superscript (*).
We plot the change in accuracy, Δacc(j), for each group j in Fig. 2. For every group, we consider three scenarios: pruning the top-1, top-2, and top-3 neurons, corresponding to the three bars per group in Fig. 2. Note that we limit our reporting to results involving up to 3 critical neurons, based on experimental findings indicating that pruning the top-3 neurons is adequate; this decision is supported by the ablation on the number of pruned neurons in Table 3 (in Supplementary materials). It can be clearly observed that the accuracy of minority groups exhibits significant shifts, while the accuracy of majority groups is only minimally affected. Specifically, for the majority groups \({{{{\mathcal{G}}}}}_{0}\) and \({{{{\mathcal{G}}}}}_{3}\), the maximum group accuracy shift is 2.15%, attained when we zero out the top-3 neurons with the largest gradient; for the minority groups \({{{{\mathcal{G}}}}}_{1}\) and \({{{{\mathcal{G}}}}}_{2}\), the maximum group accuracy shift is 11.96%, attained when we zero out the top-2 neurons with the largest gradient. This sharp contrast underscores the critical role of the selected neurons in memorizing minority examples at both the gradient and magnitude levels. Moreover, the substantial gap in accuracy shifts between majority and minority groups provides initial evidence that the model’s performance on minority groups can depend on only a few neurons, occasionally as few as three or fewer.
Interestingly, we observe that the gradient-based and magnitude-based criteria yield similar effects. We show in the following that this is attributable to an overlap in the distribution of critical neurons identified by each criterion. To delve deeper, in Fig. 3 we analyze the relative magnitude ranking among all neurons for the neurons with the largest gradient, and the relative gradient ranking for the neurons with the largest magnitude. The left subfigure of Fig. 3 shows the magnitude ranking for the neurons with the top 0.01% largest gradients, and the right subfigure shows the gradient ranking for the top 0.01% largest-magnitude neurons. In both histograms, there is a noticeable clustering in the two rightmost bins (ranging from 95% to 100%). This suggests that the neurons with the highest magnitudes tend to exhibit large gradients, and that the neurons with the largest gradients often coincide with high weight magnitudes. This finding provides tantalizing evidence of the similar distribution of critical neurons under both criteria and explains the matching phenomenon observed between the two criteria.
Our experiments thus far offer preliminary evidence for the existence of critical neurons. To gain a more comprehensive understanding, we explore alternatives to pruning, in particular the effects of random initialization and random noise. These two experiments are motivated by our desire to investigate perturbation from two perspectives: perturbation of the original neuron weights and perturbation of the pruned neurons. By examining these perturbations, we obtain more credible supporting evidence for the existence of critical neurons, evaluating the sensitivity of group accuracy to specific neurons more comprehensively.
To implement random initialization, instead of performing pruning on the selected neurons, we opt to initialize them randomly using a zero-mean Gaussian random variable. That is, we replace the neuron weight zi with \({\tilde{{{{\bf{z}}}}}}_{i}={\epsilon }_{i}\) where \({\epsilon }_{i} \sim {{{\mathcal{N}}}}({{{\bf{0}}}},{{{{\boldsymbol{\sigma }}}}}^{2})\). The accuracy change is formulated as:

$$\Delta_{\mathrm{acc}}(j)=\left|\,\mathrm{acc}\big(\mathcal{D}_j,f(\boldsymbol{\theta},\cdot)\big)-\mathrm{acc}\big(\mathcal{D}_j,f(\tilde{\boldsymbol{\theta}},\cdot)\big)\right|, \qquad (2)$$
where \(\tilde{{{{\boldsymbol{\theta }}}}}={\{{{{{\bf{z}}}}}_{i}\}}_{i\notin {{{{\mathcal{I}}}}}_{j}}\cup {\{{\tilde{{{{\bf{z}}}}}}_{i}\}}_{i\in {{{{\mathcal{I}}}}}_{j}}\). The results are shown in Fig. 4.
To implement random noise, we add an extra noise term, which is a zero-mean Gaussian random variable, to the selected neurons, i.e., \({\tilde{{{{\bf{z}}}}}}_{i}={{{{\bf{z}}}}}_{i}+{\epsilon }_{i}\), where \({\epsilon }_{i} \sim {{{\mathcal{N}}}}({{{\bf{0}}}},{{{{\mathbf{\sigma }}}}}^{2})\). The result is shown in Fig. 5.
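The two perturbation variants admit an equally compact sketch, reusing the (param, channel) pairs selected by channel_scores above; the function name and the σ handling are illustrative.

```python
import torch

def perturb_selected(model, selected, mode="reinit", sigma=0.005):
    """Perturb the selected (param, channel) pairs instead of zeroing them.

    mode="reinit": z_i <- eps_i            (random initialization)
    mode="noise" : z_i <- z_i + eps_i      (additive random noise)
    with eps_i ~ N(0, sigma^2 I); sigma = 0.005 as in the main text."""
    with torch.no_grad():
        for p, c in selected:
            eps = sigma * torch.randn_like(p[c])
            if mode == "reinit":
                p[c].copy_(eps)   # converges to pruning as sigma -> 0
            else:
                p[c].add_(eps)    # original weights still shape the output
    return model
```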
In Figs. 4 and 5, we find that 1) the results from random initialization (σ = 0.005) closely resemble those from the pruning method; notably, the minority groups show a 2–12% shift in group accuracy compared to the majority groups’ 0–2.6% shift, since random initialization converges to pruning as the standard deviation of the Gaussian random variables \({\{{\epsilon }_{i}\}}_{i\in {{{{\mathcal{I}}}}}_{j}}\) decreases; and 2) with random noise added (σ = 0.005), the accuracy changes in minority groups still surpass those in majority groups. We also observe that the extent of accuracy change with random noise is much smaller than that observed with random initialization and pruning. This occurs because, in the presence of random noise, neuron values are not reset to a mean-zero state, allowing their initial values to influence the model’s performance across different groups. Moreover, we conduct an additional unstructured tracing experiment on CelebA, as shown in Table 11 (in Supplementary materials); the results for the minority groups of CelebA align with those obtained on the Waterbirds dataset. As shown in Table 13 (in Supplementary materials), our results on the Waterbirds dataset reveal that when modifying critical neurons, the training accuracy for minority groups (\({{{{\mathcal{G}}}}}_{1}\) and \({{{{\mathcal{G}}}}}_{2}\)) consistently drops across nearly all experimental setups, whereas the corresponding test accuracy remains relatively stable. This discrepancy reinforces the interpretation that these neurons play a memorization role. Furthermore, the raw values reported in Table 14 (in Supplementary materials) substantiate these findings, showing that for minority groups, training accuracy decreases consistently in all experimental settings and for every choice of k in the top-k analysis. Together, these results provide compelling evidence that the identified critical neurons are primarily responsible for memorization in minority groups, rather than affecting overall generalization. Additional experimental results on random initialization and random noise with various standard deviations can be found below.
For all random initialization (Fig. 11 in Supplementary materials) and noise-adding (Fig. 12 in Supplementary materials) experiments, we draw the random variables with multiple standard deviations to validate the existence of critical neurons. For each subfigure in Figs. 11 and 12 (in Supplementary materials), the results are averaged over 10 random seeds.
Overall, regardless of the scale variation in the accuracy shifts, our experiments using pruning, random initialization, and random noise consistently demonstrate that the accuracy of minority groups is significantly sensitive to the alteration of a handful of selected neurons. This finding suggests that a small subset of critical neurons contributes more significantly to the memorization of minority examples during training than other neurons. Moreover, it validates that both gradient-based and magnitude-based criteria are effective in identifying these critical neurons.
In unstructured tracing, we select neurons from the entire model without considering any sub-structures (i.e., layers, blocks) of networks. To gain a deeper understanding of how these sub-structures influence memorization, we use structured tracing for probing and comprehending the role of sub-structures in the networks. Specifically, we begin by fixing a particular layer, and then selecting neurons within the layer to assess the importance of these neurons in memorizing examples from groups. We still employ either gradient-based or magnitude-based criterion for neuron selection, but the scope of this specific experiment is confined to the individual layer. This process is identically repeated for each layer in the entire model.
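A sketch of the per-layer loop under the magnitude criterion is given below (the gradient criterion is analogous); eval_fn, which returns per-group training accuracy, and all other names are illustrative assumptions.

```python
import copy
import torch

def structured_tracing(model, eval_fn, k=3):
    """Structured tracing sketch: for each conv layer in turn, prune that
    layer's top-k largest-magnitude channels and record the per-group
    training accuracy via the user-supplied eval_fn(model) -> dict."""
    shifts = {}
    for name, p in model.named_parameters():
        if p.dim() != 4:                          # conv kernels only
            continue
        probe = copy.deepcopy(model)              # fresh copy per layer
        q = dict(probe.named_parameters())[name]
        topk = q.detach().flatten(1).norm(dim=1).topk(k).indices
        with torch.no_grad():
            q[topk] = 0.0                         # zero out within this layer only
        shifts[name] = eval_fn(probe)             # compare against eval_fn(model)
    return shifts
```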
In Fig. 7 (in Supplementary materials), we employ a heatmap to visualize how accuracy changes across different groups when we selectively zero-out a subset of neurons within a specific layer. What becomes evident is that deactivating the same number of neurons with the highest gradients or magnitudes within a layer consistently leads to a more significant shift in accuracy for the minority groups compared to the majority groups. This difference is clearly discernible in the brighter color associated with the minority groups \({{{{\mathcal{G}}}}}_{1}\) and \({{{{\mathcal{G}}}}}_{2}\) in the middle two rows. Furthermore, we notice that these within-layer critical neurons appear to be distributed across multiple layers in the early stages of the model, rather than being confined to the final few layers. This finding aligns with the literature which indicates that the memorization of atypical examples can be distributed and localized throughout the neural networks12.
Spurious memorization by critical neurons
In the aforementioned findings, our experiments have empirically demonstrated the presence of a small set of critical neurons involved in the memorization of minority examples during training. This underscores the role of spurious memorization as a significant factor in imbalanced group performance. In this section, we take a further step in demystifying the cause of imbalanced group performance under spurious correlation, particularly focusing on the discrepancy in the test accuracy between majority and minority groups.
To further validate the hypothesis that spurious memorization is a key factor in the imbalanced group performance, we investigate whether countering spurious memorization during training could lead to improved test accuracy on minority groups. Our findings affirmatively answer this question. By specifically targeting and removing spurious memorization via a specialized fine-tuning framework, we observe a consistent improvement in the test accuracy for minority groups. We report extensive experimental results across different model architectures, including ResNet-50 and ViT-Small, and on benchmark datasets including Waterbirds1,14 and CelebA15, providing comprehensive analysis on the effects of spurious memorization on imbalanced group performance.
Figure 9 (in Supplementary materials) summarizes our fine-tuning framework for analyzing spurious memorization. By default, our framework is built upon simCLR16, adhering to its key components such as data augmentations and the non-linear projection head. The primary distinction between ours and simCLR is centered around two models: a target model and an auxiliary model. The auxiliary model is essentially a pruned version of the target model, where certain critical neurons are masked while the remaining neurons retain the same weights as the target model. This allows the framework to feed two augmented images into separate models, yielding two distinct feature representations for contrasting with each other.
More specifically, we begin with the target model, represented as f(θ, ⋅ ), where θ denotes the model weights. These weights are initialized by pretraining the model using ERM. The next stage involves fine-tuning the target model. To this end, we construct a pruned model, f(m ⊙ θ, ⋅ ), with m being a masking vector. The mask is created based on criteria derived from either gradient or magnitude, as motivated by the previous findings. In our experiments, we zero out the top 0.01% of neurons ranked by the ℓ2-norm of their gradient or their weight magnitude, where the 0.01% ratio serves as a hyperparameter.
It is worth noting that for gradient calculation, we do not rely on group labels as in the previous findings; instead, we use the model’s predictions as pseudo labels for sample selection. During each epoch, we calculate the cross-entropy loss for each sample, select the top 256 samples with the highest loss, and randomly sample 128 of them to form the batch for gradient computation.
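A minimal sketch of this selection step follows; all names are illustrative, and we read ‘pseudo label’ as the model’s predicted class (computing the loss against the ground-truth class label instead is the other plausible reading).

```python
import torch
import torch.nn.functional as F

def hard_example_indices(model, dataset, pool_size=256, batch_size=128):
    """Select the gradient batch without group labels: score each sample by
    its CE loss against the model's own predicted class (pseudo label),
    keep the pool_size worst, then draw batch_size of them at random."""
    model.eval()
    losses = []
    with torch.no_grad():
        for idx in range(len(dataset)):
            x, _ = dataset[idx]                    # ground-truth label unused here
            logits = model(x.unsqueeze(0))
            pseudo = logits.argmax(dim=1)          # pseudo label
            losses.append((F.cross_entropy(logits, pseudo).item(), idx))
    pool = [i for _, i in sorted(losses, reverse=True)[:pool_size]]
    pick = torch.randperm(pool_size)[:batch_size]
    return [pool[i] for i in pick.tolist()]
```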
During training, contrasting output features of the two models – target and auxiliary – enables adaptive online identification of critical neurons for the current target model. This approach implicitly gives greater emphasis to these samples in the loss function, effectively tailoring the training process to bolster the model in recalling these challenging forgotten examples. In particular, the key innovation in our training framework is to restrict the target model’s tendency to memorize atypical examples using only a small set of neurons. Inspired by the observation13 that pruning reduces a model’s prediction accuracy on rare and atypical instances, we enforce feature alignment by adopting the NT-Xent loss17 as we find that these samples typically exhibit the greatest prediction disparities between the pruned and non-pruned models. Utilizing pruning as an experimental tool will amplify the prediction disparity between the pruned and non-pruned models, resulting in an implicit rebalancing of the loss.
Consider an arbitrary input image x, and denote by \({{{{\bf{x}}}}}^{{\prime} }\) and x″ two augmentations of x. The loss term on the input x (together with its positive and negative pairs) can be formulated as:

$$\mathcal{L}_{\mathrm{NT}}(\boldsymbol{\theta},\mathbf{x})=-\log \frac{\exp \big(\mathrm{sim}(\mathbf{r},\mathbf{r}_{p})/\tau \big)}{\exp \big(\mathrm{sim}(\mathbf{r},\mathbf{r}_{p})/\tau \big)+\sum_{k}\exp \big(\mathrm{sim}(\mathbf{r},\mathbf{r}_{k})/\tau \big)}, \qquad (3)$$
where r is the output feature of input x produced by the target model \(f({{{\boldsymbol{\theta }}}},{{{{\bf{x}}}}}^{{\prime} })\), rp is the output feature of the auxiliary model f(m ⊙ θ, x″), and rk are the output features of negative pairs. Here, τ is the loss temperature, and sim( ⋅ , ⋅ ) is the cosine similarity: sim(u, v) = u ⋅ v/(∥u∥ ⋅ ∥v∥).
Additionally, we incorporate a supervised loss for the target model. Interestingly, we find that the Mean Squared Error (MSE) loss has a more pronounced effect than the Cross Entropy (CE) loss in our experiments18. Thus, we adopt the MSE loss, defined as:

$$\mathcal{L}_{\mathrm{MSE}}(\boldsymbol{\theta},\mathbf{x},\mathbf{y})={\left\Vert \hat{\mathbf{y}}-\mathbf{y}\right\Vert }_{2}^{2}, \qquad (4)$$
where y is the one-hot vector of the ground-truth class of input x, and \(\hat{{{{\bf{y}}}}}\) is the model prediction vector. The final loss function is formulated as:

$$\mathcal{L}_{\mathrm{total}}(\boldsymbol{\theta},\mathbf{x},\mathbf{y})=\mathcal{L}_{\mathrm{NT}}(\boldsymbol{\theta},\mathbf{x})+\lambda \,\mathcal{L}_{\mathrm{MSE}}(\boldsymbol{\theta},\mathbf{x},\mathbf{y}), \qquad (5)$$
where λ > 0 is a hyperparameter for balancing loss terms.
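Putting the pieces together, a minimal sketch of the fine-tuning objective in Eq. (5) follows; we assume each model returns a (projected feature, prediction) pair, which is our own interface choice rather than the released code, and `auxiliary` is the pruned copy f(m ⊙ θ, ⋅ ) rebuilt at the start of each epoch.

```python
import torch
import torch.nn.functional as F

def nt_xent(r, rp, tau=0.5):
    """Simplified NT-Xent: row i of r (target branch) and row i of rp
    (auxiliary branch) form the positive pair; the remaining rows of rp
    act as negatives for row i."""
    r, rp = F.normalize(r, dim=1), F.normalize(rp, dim=1)  # -> cosine similarity
    logits = r @ rp.T / tau
    labels = torch.arange(r.size(0), device=r.device)      # positives on diagonal
    return F.cross_entropy(logits, labels)

def total_loss(target, auxiliary, x1, x2, y_onehot, tau=0.5, lam=1.0):
    """Eq. (5): L_total = L_NT + lambda * L_MSE."""
    r, y_hat = target(x1)                                  # augmentation x'
    rp, _ = auxiliary(x2)                                  # augmentation x''
    return nt_xent(r, rp, tau) + lam * F.mse_loss(y_hat, y_onehot)
```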
In summary, we initiate the process by pretraining the target model using ERM and subsequently fine-tune it further using the above-mentioned framework for a few additional epochs. At the start of each fine-tuning epoch, we create an auxiliary model by pruning a small portion of neurons from the target model based on either gradient or magnitude. The target and auxiliary models are then trained to align their output features with each other, fostering robust learning.
In this study, our primary objective is to investigate whether mitigating spurious memorization can enhance the test accuracy of minority groups. The findings are illustrated in Fig. 6, where we compare the Worst Group Accuracy (WGA) between standard ERM training and our proposed framework. The WGA for ERM training is evaluated on the test set using the best-performing checkpoint (selected on the validation set) among the first 100 epochs. Notably, we observe a significant increase in WGA under all scenarios. Specifically, with the ResNet-50 backbone, we observe 16.87% and 10.77% improvements in WGA for the Waterbirds and CelebA datasets, respectively. Similarly, with the ViT-Small model, we observe 23.83% and 9.45% improvements in WGA. Furthermore, a statistical summary is included in Table 12 (in Supplementary materials).
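For clarity, WGA is simply the minimum per-group test accuracy, as in the following one-liner:

```python
def worst_group_accuracy(group_accs):
    """WGA is the minimum per-group test accuracy, e.g.
    worst_group_accuracy({0: 0.97, 1: 0.36, 2: 0.57, 3: 0.96}) -> 0.36"""
    return min(group_accs.values())
```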
It is important to highlight that our auxiliary model is essentially a pruned version of the target model, with only 0.01% of the neurons being masked. Despite this seemingly small modification, the consistent performance boost in WGA across different architectures and datasets is strikingly remarkable. This improvement suggests that by strategically disrupting the spurious memorization mechanism through contrasting two model branches, we can guide the target model to learn atypical or minority examples more robustly. These findings lend further support to our hypothesis that spurious memorization contributes to imbalanced group performance. Additionally, we have conducted comprehensive ablation studies to explore various aspects of our framework, including the choice of kick-in and fine-tuning epoch, loss balancing term λ, pruning ratio, different gradient sources, and loss function.
In the following, we perform extensive ablation studies to offer a more comprehensive perspective on the framework. All ablation experiments are conducted on the Waterbirds dataset with the ResNet-50 model. (1) Loss Functions: we use MSE as one of the loss terms in model fine-tuning (see Eqs. (4) and (5)). Here we compare MSE with the Cross Entropy (CE) loss. Using CE, the final loss becomes \({{{{\mathcal{L}}}}}_{{{{\rm{total}}}}}({{{\boldsymbol{\theta }}}},{{{\bf{x}}}},{{{\bf{y}}}})={{{{\mathcal{L}}}}}_{{{{\rm{NT}}}}}({{{\boldsymbol{\theta }}}},{{{\bf{x}}}})+\lambda {{{{\mathcal{L}}}}}_{{{{\rm{CE}}}}}({{{\boldsymbol{\theta }}}},{{{\bf{x}}}},{{{\bf{y}}}})\), as compared to Eq. (5). The result is shown in Table 4 (in Supplementary materials). We observe that MSE is more effective in terms of WGA gain than CE under the same pruning percentage. Still, both choices manifest significant WGA gains against ERM, corroborating our hypothesis that spurious memorization in the critical neurons might play a critical role in imbalanced group performance. (2) Kick-in Epoch: in our training framework, we first pretrain the target model using ERM for 40 epochs and then switch to a fine-tuning stage with the loss function in Eq. (5); in other words, our framework kicks in at epoch 40. Here we test different choices of the kick-in epoch. The result is shown in Table 5 (in Supplementary materials). Overall, we see that epoch 30 is not effective, while the choices of 40, 50, and 60 all yield meaningful gains. The reason is that the ERM training of the target model has not yet converged at epoch 30. (3) Number of Fine-tuning Epochs: we then compare the number of epochs for the fine-tuning stage. The result is shown in Table 6 (in Supplementary materials). We observe no difference between using 20 or 30 fine-tuning epochs, because the best model checkpoint appears within 20 epochs; fine-tuning for more epochs is thus unnecessary. (4) Data Source for Gradient Calculation: we compare different data sources for the gradient-based criterion. As introduced before, the neuron gradient is computed on a selected subset of training data; by default, this subset consists of the examples worst predicted by the target model in terms of the CE loss. From Table 7 (in Supplementary materials), we observe that calculating the gradient on the worst-predicted examples from the minority groups does not show any benefit. Given that using minority groups as the gradient source requires access to group labels, which are sometimes unavailable, we suggest using the full training set as the gradient source. (5) Loss Term Ratios: we compare different choices of the loss term ratio λ in Eq. (5). The result is shown in Table 8 (in Supplementary materials). (6) Pruning Percentage: we compare different choices of pruning percentage for both the gradient-based and magnitude-based criteria. The result is shown in Table 9 (in Supplementary materials). (7) Combined Pruning Criteria: we test a mixed pruning criterion that combines the gradient-based criterion with the magnitude-based one. The result is shown in Table 10 (in Supplementary materials).
To interpret the outcome of the neural networks trained by ERM alone versus our fine-tuning strategy, we visualize GradCAM on ResNet-50 under both training schemes. The target layer is set to layer4.2.conv3.weight, and the target dimension in the output feature is set to dimension 0. Figure 10 (in Supplementary materials) clearly shows that with our fine-tuning strategy, the neural network shifts its focus from the spurious element (i.e., the background) to the main object (i.e., the bird).
Discussion
In this paper, we conduct a systematic investigation aimed at uncovering the root structural cause of imbalanced group performance in the presence of spurious correlations. This phenomenon is characterized by both majority and minority groups achieving high training accuracy, yet minority groups experiencing reduced test accuracy. Our comprehensive study verifies the presence of spurious memorization, a mechanism involving critical neurons that significantly influence the accuracy of minority examples while having minimal impact on majority examples. Building upon these key findings, we demonstrate that by intervening on these critical neurons, we can effectively mitigate the influence of spurious memorization and enhance performance on the worst group. Our findings shed light on the reasons behind NNs demonstrating robust performance on majority groups but struggling with minority groups, and establish spurious memorization as a pivotal factor contributing to imbalanced group performance. We hope that our discoveries offer valuable insights for practitioners and serve as a foundation for further exploration into the intricacies of memorization in the presence of spurious correlations.
Mitigating spurious correlations in machine learning and statistical models is a key step towards crafting more reliable and trustworthy medical AI. Our research uncovers that by eliminating spurious memorization, we can pinpoint critical neurons, whose modification significantly influences the model’s performance, particularly in recognizing minority groups. Concerning privacy risks, these are relatively low in our approach, as the analysis requires existing access to the dataset and the capability to train models. Looking forward, our future research will aim to address challenges within the broader scope of spurious correlations, extending beyond vision applications to include language datasets, among others. This expansion will help in developing AI solutions that are more versatile and universally applicable.
Methods
Datasets and models
In our study, we conduct experiments on two popular benchmark datasets for spurious correlation: Waterbirds1,14, and CelebA15. We comprehensively evaluate the extent to which spurious memorization exists in the large pre-trained models (ResNet-505 and ViT-Small6) on ImageNet19. Note that we report the average performance of 10 independent runs with different random seeds for experiments including unstructured tracing and structured tracing. In the experiments detailed in Section Results, we strictly adopt the standard dataset splits for both Waterbirds and CelebA, following the setting in ref. 20. Our adoption of ResNet or ViT models pre-trained on ImageNet is consistent with the main literature21,22,23. Furthermore, the high baseline accuracy achieved by pre-trained models is critical for studying memorization, which is a focal point of our study.
Identification of critical neurons
For identifying critical neurons, we utilize two key metrics: gradient-based and magnitude-based criteria. Here, gradient refers to the gradients calculated during backpropagation with respect to a specific data batch. Magnitude, on the other hand, is determined by the norm of neuron weights. Details of data batch selection are given in Section Results.
Neurons and layers
For our study on convolutional neural networks (e.g., using ResNet-50 as the backbone), we consider channels as the basic units in order to preserve the channel structure, as suggested in prior work12. On the other hand, for our study involving Vision Transformer (e.g., using ViT-Small as the backbone), we consider individual neurons as the basic units. Therefore, for ease of reference in our study, we use the term ‘neuron’ to collectively refer to both channels in ResNet-50 and neurons in ViT-Small.
Experimental setup
In all our experiments, we keep the experimental setup consistent. We use a single NVIDIA Titan RTX GPU. We conduct our experiments using PyTorch 1.13.1+cu117 and Python 3.10.4, to ensure reproducibility.
Data preprocessing
Our dataset preprocessing remains consistent across all datasets and experiments. For details, please refer to Table 1 (in Supplementary materials). These steps ensure that the resulting image size is 224 × 224 pixels, suitable for both ResNet-50 and ViT-S/16@224px. Following these augmentation steps, we normalize the image by subtracting the average pixel values (mean = [0.485, 0.456, 0.406]) and dividing by the standard deviation (std = [0.229, 0.224, 0.225]). This normalization procedure aligns with the approach used in CLIP24. No further data augmentation is applied after these steps.
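Under these specifications, the preprocessing pipeline can be sketched as follows; the Resize/CenterCrop choices are illustrative assumptions (the exact augmentation list is given in Table 1 in Supplementary materials), while the normalization constants are as stated.

```python
from torchvision import transforms

# Sketch of the preprocessing pipeline described above; Resize/CenterCrop
# are our assumptions, the normalization constants are as stated in the text.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),                   # 224 x 224 for both backbones
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```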
Hyperparameters
A comprehensive collection of hyperparameters and their values is presented in Table 2 (in Supplementary materials).
Implementation details
In Section Results, we implemented ERM with the following configuration. We utilized an Adam optimizer with a weight decay of 0 and a momentum of 0.9. The learning rate was set at a fixed value of 1 × 10−4, and the models were trained for a total of 100 epochs. To dynamically adjust the learning rate, we employed a ReduceLROnPlateau scheduler, which reduces the learning rate by a factor of 0.5 with a patience of 3 epochs. The batch size was set to 256, and the model’s parameters were initialized with torchvision.models.ResNet50_Weights.IMAGENET1K_V2. In our experiments involving gradient-based neuron modification, we first curated a set of 256 worst-performing samples and, from this set, randomly sampled 128 samples for calculating both the loss and the gradient. Our model selection process remained consistent across all methods: after each epoch, we evaluated the model’s performance on the validation set and selected the model with the highest worst-group accuracy as the final model for testing. Note that all accuracy metrics reported in this paper are derived from the test set.
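The stated ERM configuration corresponds to roughly the following PyTorch setup; we read “momentum of 0.9” as Adam’s β1 = 0.9 (Adam has no separate momentum argument), and the scheduler mode is our assumption.

```python
import torch
from torchvision import models

# ERM configuration as stated above; beta1 = 0.9 encodes the "momentum",
# and mode="max" assumes the scheduler tracks validation worst-group accuracy.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4,
                             betas=(0.9, 0.999), weight_decay=0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3)
```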
Figure 3 presents an analysis of the distribution and relative rankings of neurons in two aspects: their gradient magnitude for neurons with the highest magnitudes, and their magnitude for neurons with the largest gradients. Specifically, in a model’s convolutional layer, each neuron possesses a magnitude (the norm of the weight) and a gradient magnitude (the norm of the gradient). In the upper part of Fig. 3, we select neurons from the convolutional layer that are in the top 0.01% in terms of gradient magnitude. We then calculate the percentage rank of these selected neurons based on their magnitude compared to all neurons in the convolutional layer. Similarly, in the lower part of Fig. 3, we select neurons that are in the top 0.01% in terms of weight magnitude and calculate the percentage rank of these based on their gradient magnitude relative to all neurons in the convolutional layer. Both histograms in Fig. 3 show a significant concentration in the two rightmost bins (which represent the range from 95% to 100%). This indicates that neurons with the highest weight magnitudes tend to have large gradients, and neurons with the highest gradients often have substantial weight magnitudes. This observation provides intriguing evidence of a similar distribution pattern for critical neurons under both criteria, explaining the observed correlation between these two metrics.
All the methods we assess utilize an Adam optimizer with a weight decay of 0 and a momentum of 0.9. The learning rate is held constant at 2 × 10−4 throughout the training process, which lasts for 20 epochs. Additionally, we implement a ReduceLROnPlateau scheduler, which reduces the learning rate by a factor of 0.5 with a patience of 1 epoch. In our experiments involving gradient-based pruning, we selected 256 of the poorest-performing samples and then randomly sampled 128 from this group to calculate both the loss and the gradient. Model selection follows the same protocol as for ERM: after each epoch, we evaluate the model’s performance on the validation set and choose the checkpoint that achieves the highest worst-group accuracy as the final model for testing. For selecting positive and negative pairs within our framework, output features that originate from the same input image form positive pairs, whereas output features from different images within the batch are regarded as negative samples relative to the current feature.
Data availability
The Waterbirds dataset is available at https://github.com/kohpangwei/group_DRO, formed from https://www.vision.caltech.edu/datasets/cub_200_2011/ and http://places2.csail.mit.edu/. The CelebA dataset is available at http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html/. All requests from institution-affiliated researchers for access to processed data for purposes of study validation will be considered and should be directed to C.Y. (chenyu.you@yale.edu); they will be handled within 1 month. Source data are provided with this paper.
Code availability
The code that supports the findings of this study is available at https://github.com/aarentai/Silent-Majority.
References
Sagawa, S., Koh, P. W., Hashimoto, T. B. & Liang, P. Distributionally robust neural networks. In International Conference on Learning Representations (ICLR, 2020).
Geirhos, R. et al. Shortcut learning in deep neural networks. Nat. Mach. Intell. 2, 665–673 (2020).
Ribeiro, M. T., Singh, S. & Guestrin, C. "Why should I trust you?" Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 1135–1144 (ACM, 2016).
Beery, S., Van Horn, G. & Perona, P. Recognition in terra incognita. In Proceedings of the European conference on computer vision (ECCV), 456–473 (ECCV, 2018).
He, K., Zhang, X., Ren, S. & Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, 770–778 (IEEE, 2016).
Dosovitskiy, A. et al. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations (ICLR, 2021).
Blodgett, S. L., Green, L. & O’Connor, B. Demographic dialectal variation in social media: a case study of African-American English. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, 1119–1130 (Association for Computational Linguistics, 2016).
Buolamwini, J. & Gebru, T. Gender shades: Intersectional accuracy disparities in commercial gender classification. In Conference on fairness, accountability and transparency, 77–91 (PMLR, 2018).
Hashimoto, T., Srivastava, M., Namkoong, H. & Liang, P. Fairness without demographics in repeated loss minimization. In International Conference on Machine Learning, 1929–1938 (PMLR, 2018).
Baldock, R., Maennel, H. & Neyshabur, B. Deep learning through the lens of example difficulty. Adv. Neural Inf. Proces. Syst. 34, 10876–10889 (2021).
Stephenson, C. et al. On the geometry of generalization and memorization in deep neural networks. In International Conference on Learning Representations (ICLR, 2021).
Maini, P. et al. Can neural network memorization be localized? In Proceedings of the 40th International Conference on Machine Learning, 202, 23536–23557 (PMLR, 2023).
Hooker, S., Courville, A., Clark, G., Dauphin, Y. & Frome, A. What do compressed deep neural networks forget? In ICML Workshop on Human Interpretability in Machine Learning (WHI, 2019).
Wah, C., Branson, S., Welinder, P., Perona, P. & Belongie, S. The caltech-ucsd birds-200-2011 dataset. Technical Report CNS-TR-2011-001 (California Institute of Technology, Computation & Neural Systems Technical Report, 2011).
Liu, Z., Luo, P., Wang, X. & Tang, X. Deep learning face attributes in the wild. In Proceedings of the IEEE international conference on computer vision, 3730–3738 (IEEE, 2015).
Chen, T., Kornblith, S., Norouzi, M. & Hinton, G. A simple framework for contrastive learning of visual representations. In International conference on machine learning, 1597–1607 (PMLR, 2020).
Sohn, K. Improved deep metric learning with multi-class n-pair loss objective. Adv. Neural Inf. Proces. Syst. 29 (2016).
Hui, L. & Belkin, M. Evaluation of neural architectures trained with square loss vs cross-entropy in classification tasks. In International Conference on Learning Representations (ICLR, 2021).
Deng, J. et al. Imagenet: A large-scale hierarchical image database. In IEEE conference on computer vision and pattern recognition, 248–255 (IEEE, 2009).
Idrissi, B. Y., Arjovsky, M., Pezeshki, M. & Lopez-Paz, D. Simple data balancing achieves competitive worst-group-accuracy. In Conference on Causal Learning and Reasoning, 336–351 (PMLR, 2022).
Kirichenko, P., Izmailov, P. & Wilson, A. G. Last layer re-training is sufficient for robustness to spurious correlations. In The Eleventh International Conference on Learning Representations (ICLR, 2023).
Qiu, S., Potapczynski, A., Izmailov, P. & Wilson, A. G. Simple and fast group robustness by automatic feature reweighting. In International Conference on Machine Learning, 28448–28467 (PMLR, 2023).
Yang, Y., Nushi, B., Palangi, H. & Mirzasoleiman, B. Mitigating spurious correlations in multi-modal models during fine-tuning. In International Conference on Machine Learning, 39365–39379 (PMLR, 2023).
Radford, A. et al. Learning transferable visual models from natural language supervision. In International conference on machine learning, 8748–8763 (PMLR, 2021).
Acknowledgements
C.Y. and J.S.D. were supported by NIH grants R01CA206180 and R01HL121226. H.D. and S.J. were supported by NSF grant DMS-1912030. The authors want to thank all the anonymous reviewers and editors for their constructive comments and suggestions that substantially improved this paper. C.Y. is the corresponding author of this paper.
Author information
Authors and Affiliations
Contributions
C.Y., H.D., and Y.M. developed and implemented new machine learning methods, benchmarked machine learning models, and analyzed model behavior, all under the guidance of J.S.S., S.J., and J.S.D., and performed a validation study to evaluate its effects. All authors discussed the results and contributed to the final manuscript. C.Y., H.D., and Y.M. designed the study.
Corresponding author
Ethics declarations
Competing interests
The authors declare no competing interests.
Peer review
Peer review information
Nature Communications thanks the anonymous reviewer(s) for their contribution to the peer review of this work. A peer review file is available.
Additional information
Publisher’s note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Supplementary information
Source data
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License, which permits any non-commercial use, sharing, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if you modified the licensed material. You do not have permission under this licence to share adapted material derived from this article or parts of it. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by-nc-nd/4.0/.
About this article
Cite this article
You, C., Dai, H., Min, Y. et al. Uncovering memorization effect in the presence of spurious correlations. Nat Commun 16, 5424 (2025). https://doi.org/10.1038/s41467-025-61531-5