Introduction

Benefiting from the advancements in artificial intelligence, a large number of computer vision techniques have been widely applied to various medical image segmentation tasks. Automatic and accurate medical image segmentation holds tremendous significance in medical diagnosis and treatment. Physicians can utilize segmentation results to design tailored treatment strategies, thereby facilitating patients’ recovery. Consequently, much attention has been devoted to developing promising techniques for medical image segmentation.

Convolutional Neural Networks (CNNs)1 are feedforward neural networks inspired by the visual cortex of biological organisms. Compared to traditional neural networks, CNNs fully leverage the convolution operation, making them particularly suitable for computer vision tasks such as image segmentation. In the field of medical image segmentation, a variety of network models utilize the convolution operation as a fundamental unit. Two main popular approaches, the Fully Convolutional Network (FCN)2 and U-Net3, have been presented for medical image segmentation. These CNN based methods have exhibited good performance since they focus on capturing local correlations of input data, owing to their ability to handle small receptive fields effectively. However, they are restricted in effectively capturing global contextual correlations of input data.

Recently, Transformer based methods4 have gained popularity due to their powerful ability to capture long-range dependencies in the field of Natural Language Processing (NLP). So far, Transformer-based methods have been successfully applied to various tasks5 such as object detection and recognition6,7,8, medical image segmentation9,10,11, depression and emotion recognition12,13,14, etc. For medical image segmentation, various works have explored the use of Transformer modules and their variants, such as Vision Transformer (ViT)15, Swin-Transformer16, Pyramid Vision Transformer (PVT)17, Swin-Unet18, and so on. These Transformer based models are capable of capturing global contextual correlations of input data, leading to their promising performance in medical image segmentation tasks. Nevertheless, they are limited in capturing the local correlations of input data.

In terms of the above observations, there exists a complementarity to some extent between CNN and Transformer based methods. Inspired by this, it is promising to integrate their complementary strengths in capturing the correlations of input data to further improve medical image segmentation. In particular, CNN based methods are well-suited for capturing local correlations of input data due to their ability to handle small receptive fields effectively, while Transformer based methods excel at extracting global contextual correlations. In this sense, the complementarity between CNN and Transformer methods allows for constructing a more robust network that captures both local and global features simultaneously, leading to improved performance in medical image segmentation tasks. To achieve this, we propose a novel deep learning model called MedFuseNet, which aims to fuse local and global deep feature representations with hybrid attention mechanisms for medical image segmentation. The proposed MedFuseNet takes full advantage of the respective strengths of CNN and Transformer methods to separately capture local and global correlations of input data for promoting performance on medical image segmentation tasks. For feature fusion and enhancement, the designed hybrid attention mechanisms incorporate four different attention modules: (1) an atrous spatial pyramid pooling (ASPP)19 module for the CNN branch, (2) a cross attention module in the encoder for fusing local and global features, (3) an adaptive cross attention (ACA) module in the skip connections for further fusion, and (4) a squeeze-and-excitation attention (SE-attention)20 module in the decoder for highlighting informative features. Two typical datasets, ACDC21 and Synapse22, are employed for experiments, and the experimental results show the validity of the proposed approach.

Related work

CNN based methods

CNN based methods have been widely applied to various segmentation applications. In the context of medical image segmentation, one of the prevalent CNN based methods is the Fully Convolutional Network (FCN)2 with an encoder-decoder architecture. The encoder of FCN aims at extracting high-level features from input data, and comprises several convolutional layers followed by pooling or strided convolutions. The decoder of FCN aims at recovering the spatial resolution of the original input by using transposed convolutions or upsampling operations. Nevertheless, FCN has some limitations in handling fine details and edge information, since it adopts a single receptive field size. To address this issue, researchers have designed various improved versions of FCN for medical image segmentation. One widely-used strategy for improving FCN is dilated convolution based methods that allow for an increased receptive field, such as Multi-scale Dilated Convolution Network (Md-Net)23, Pyramid Dilated Network (PyDiNet)24, and so on. Another popular strategy is the typical U-Net3. U-Net incorporates an encoder-decoder architecture with skip connections. The encoder in U-Net adopts convolution operations, while the decoder employs deconvolution, the transposed operation of convolution. Inspired by the success of the u-shaped network in U-Net, various methods have been devoted to improving the U-Net architecture, such as U-Net++25, U-Net 3+26, Atrous UNet+27, DenseRes-Unet28, EG-Unet29, FS-UNet30, CoLe-CNN+31, DeepLabv3+32, UCR-Net33, Dense-PSP-UNet34, Eres-UNet++35 and so on. UNet++25 aims to enhance performance by leveraging deeper and wider network structures, as well as interconnecting the encoder and decoder sub-networks via nested skip pathways, and achieved a higher average IoU than U-Net. The recently-developed Eres-UNet++35 integrates a high-efficiency channel attention module with U-Net++ for liver CT image segmentation. Although these CNN based methods have exhibited promising performance on medical image segmentation tasks, they still face a challenge: due to the inherent locality constraints of convolution operations, these network models focus on learning the local correlations of input data, and thus fail to capture the global contextual correlations effectively. As a result, they are limited in modeling long-range correlations of input data.

Transformer based methods

Attention based methods have garnered considerable interest in the field of computer vision due to their remarkable performance in NLP tasks. One of the most popular attention based methods is the Transformer4, which allows for concentrating on the relevant parts of input data. The original Transformer utilizes a self-attention mechanism to capture global dependencies and contextual relationships between words in a sequence. Building upon the success of Transformers in NLP4, a number of Transformer variants have been developed, such as Vision Transformer (ViT)15, Swin-Transformer (ST)16, Video Vision Transformer (ViViT)36, Detection-Transformer37, Pyramid Vision Transformer (PVT)17, Swin-Unet18, ST-UNet38, TFormer39, SwinGAN40, and so on. Tang et al.41 developed a Swin-Transformer-based deep learning model equipped with a hierarchical encoder for self-supervised pre-training and applied it to medical image segmentation. Cao et al.18 presented a Unet-like pure Transformer called Swin-Unet, consisting of a Transformer-based U-shaped encoder-decoder architecture with skip connections for medical image segmentation. Wang et al.17 proposed the Pyramid Vision Transformer (PVT), which utilizes a pyramid structure within the Transformer architecture; PVT comprises four stages, each consisting of an embedding layer and Transformer layers. These Transformer-based models aim at modeling the global contextual correlations of input data, thereby leading to impressive performance on various computer vision tasks. Nevertheless, they are restricted in capturing the local correlations of input data.

Hybrid CNN-Transformer methods

In recent years, a few works have presented hybrid CNN-Transformer architectures for medical image segmentation tasks. A representative method is TransUNet42, which integrates the Transformer module into a U-Net architecture to enhance contextual understanding. TransClaw U-Net43 is a variant of TransUNet that integrates the convolution operation with the Transformer operation in the encoding part. Although these methods attempt to leverage the strengths of both CNNs for local feature extraction and Transformers for capturing long-range dependencies, they are still restricted in fully exploiting the potential strengths of combining CNN and Transformer methods, leading to limited performance in medical image segmentation tasks. This is because these existing hybrid CNN-Transformer architectures may fail to effectively fuse the extracted CNN and Transformer feature maps.

To address this issue, this work proposes a new deep learning framework called MedFuseNet that combines the strengths of CNN and Transformer methods with hybrid attention mechanisms for medical image segmentation. Specifically, the proposed MedFuseNet employs CNNs and the Swin-Transformer to separately capture local and global correlations of input data. Meanwhile, four different attention modules are incorporated into our model for further feature fusion and enhancement on medical image segmentation tasks, as displayed in Fig. 1.

Our method

Let X = (\(X_1\), \(X_2\), ..., \(X_N\)) denote a set of medical images, where each image \(X_i\) consists of \(M_i\) pixels. Let \(p_j\) = (\(x_j\), \(y_j\), \(z_j\)) be the coordinate of the j-th pixel in \(X_i\). We define \(S_i\) = (\(s_1\), \(s_2\), ..., \(s_{M_i}\)) as the corresponding binary segmentation mask for image \(X_i\), where \(s_j \in \{0, 1\}\) indicates whether pixel \(p_j\) belongs to the background (\(s_j\) = 0) or the target object (\(s_j\) = 1).

For medical image segmentation, our goal is to learn a function F that maps an input image \(X_i\) into its corresponding segmentation mask \(S_i\). Specifically, we train a deep learning model to minimize the discrepancy between the predicted segmentation mask \(F(X_i)\) and the ground truth mask \(T_i\), i.e., \(F(X_i)\) \(\approx\) \(S_i\) \(\approx\) \(T_i\). Given an image x \(\in\) \(R^{H \times W \times C}\) with spatial size \(H\times W\) and C channels, we aim to predict the corresponding label map of size \(H\times W\), assigning a label to each pixel of the input image.

Overview

The pipeline of the proposed method, named MedFuseNet, is illustrated in Fig. 1. As shown in Fig. 1, the proposed MedFuseNet consists of three key components: (1) an encoder integrating a CNN branch equipped with atrous spatial pyramid pooling (ASPP)19 for local feature extraction, a Swin-Transformer branch for global feature extraction, and a cross attention module for local and global feature fusion, (2) a decoder with squeeze-and-excitation attention (SE-attention), and (3) skip connections with three CNN branches equipped with adaptive cross attention (ACA). The ASPP module is used to capture multi-scale contextual information from the feature maps extracted by the second CNN module (denoted by \(A_2\)), since it adopts multiple parallel atrous convolutional layers with different dilation rates. MedFuseNet has an encoder-decoder structure and can be trained in an end-to-end manner. To fully leverage the strengths of CNNs and Transformers, MedFuseNet designs a hybrid attention mechanism strategy to effectively fuse the local and global deep feature representations learned by the CNN and Swin-Transformer, respectively. The designed hybrid attention mechanisms comprise a cross attention module in the encoder, an ACA module in the skip connections and an SE-attention module in the decoder. In the following, we provide the details of these three components of our MedFuseNet.

Fig. 1
figure 1

The pipeline of the proposed MedFuseNet, which aims to fuse local and global deep feature representations with hybrid attention mechanisms for medical image segmentation. MedFuseNet consists of three key components: (1) an encoder integrating a CNN branch equipped with atrous spatial pyramid pooling (ASPP)19 for local feature learning, a Swin-Transformer (ST) branch for global feature learning, and a cross-attention module for fusing local and global features, (2) a decoder incorporating a squeeze-and-excitation attention (SE-attention) module20, and (3) skip connections with three CNN branches equipped with an adaptive cross attention (ACA) module.

Encoder

As shown in Fig. 1, the designed encoder in our MedFuseNet consists of a dual downsampling module with CNNs and Swin-Transformer16, and a cross attention module. The dual downsampling module contains a CNN branch for capturing local features, and a Swin-Transformer (ST) branch for capturing global features. To combine the extracted local and global features, a cross attention module is designed. These relevant modules are described below in detail.

CNN Branch for local feature learning

The CNN branch in the encoder initially utilizes a pre-trained CNN model as the backbone to capture local features of input data. Given an input image X \(\in\) \(R^{H\times W \times C}\) with spatial dimensions H \(\times\) W and C channels, it is fed into a four-level CNN network for local feature learning. In the CNN branch, the 224\(\times\)224\(\times\)3 input feature map is first processed by a 7\(\times\)7 convolutional layer, followed by a 3\(\times\)3 max-pooling layer. Then, three small convolutional blocks, each containing two 1\(\times\)1 convolutions and a 3\(\times\)3 convolution, convert the 224\(\times\)224\(\times\)3 input into a 56\(\times\)56\(\times\)256 feature map. Similar operations are performed at different spatial levels. Additionally, we add an atrous spatial pyramid pooling (ASPP) module19 to the second CNN module (denoted by \(A_2\)) to enrich the obtained features, such as boundary information. Specifically, the designed ASPP module is capable of capturing multi-scale contextual information from the extracted feature maps, since it adopts multiple parallel atrous convolutional layers with different dilation rates.
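To make the role of the ASPP module concrete, a minimal PyTorch sketch is given below. The dilation rates, channel sizes, and the image-level pooling branch follow the common DeepLab-style design19 and are illustrative assumptions rather than our exact configuration; in our encoder such a block is attached to the output of the second CNN module \(A_2\).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ASPP(nn.Module):
    """Minimal ASPP sketch: parallel atrous convolutions with different
    dilation rates plus an image-level pooling branch (rates are illustrative)."""
    def __init__(self, in_ch, out_ch, rates=(1, 6, 12, 18)):
        super().__init__()
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch,
                          kernel_size=1 if r == 1 else 3,
                          padding=0 if r == 1 else r,
                          dilation=r, bias=False),
                nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
            for r in rates])
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
        self.project = nn.Sequential(
            nn.Conv2d(out_ch * (len(rates) + 1), out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):                       # x: (B, in_ch, H, W)
        h, w = x.shape[2:]
        feats = [branch(x) for branch in self.branches]
        pooled = F.interpolate(self.image_pool(x), size=(h, w),
                               mode='bilinear', align_corners=False)
        return self.project(torch.cat(feats + [pooled], dim=1))
```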

In the last three CNN modules (\(A_2\), \(A_3\), \(A_4\)), the outputs of each CNN module are connected to the corresponding outputs of the used Swin-Transformer modules, followed by an adaptive cross attention (ACA) module for fusing the learned features by CNN and Swin-Transformer. The obtained feature maps are then connected to the decoder section through the skip connections.

This operation is adopted to recover any potential information loss during processing. The Swin-Transformer branch obtains additional valuable information by applying a 1\(\times\)1 convolution and the adaptive cross attention module to the output of the CNN module. The resulting feature map, which includes the added features, is then connected to the decoder section through a skip connection. This strategy helps to compensate for missing low-level information. By incorporating these mechanisms, our module aims to capture both global and local features effectively while preserving important boundary details.

Swin-Transformer branch for global feature learning

The Swin-Transformer branch in the encoder employs a pre-trained Swin-Transformer model as the backbone to capture global features of input data. The Swin-Transformer16 comprises several essential steps when processing an input image. First, patch partitioning is conducted: the input image X is divided into non-overlapping patches, which are reshaped into a sequential representation. Each patch, denoted as x \(\in\) \(R^{P^{2} \times C}\), corresponds to a specific spatial location, where P is the patch size along one spatial dimension and C is the number of channels in each patch. Each patch is taken as a token and processed individually. The total number of patches is given by N=\(\frac{H \times W}{P^{2}}\), where H and W are the height and width of the input image, respectively. Second, patch embedding is performed: each patch is embedded into a low-dimensional embedding space by a trainable linear projection, and position embeddings are incorporated into the patch embeddings to capture positional information. Third, shifted windows are leveraged to process the divided patches hierarchically.
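As a brief illustration of the patch partition and embedding steps, the sketch below implements the patch-wise linear projection with a strided convolution. The patch size of 4 and embedding dimension of 96 are common Swin defaults and are assumptions here; position embeddings are omitted for brevity.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Minimal patch-embedding sketch: splits the image into non-overlapping
    P x P patches and projects each patch to an embedding dimension.
    A strided convolution is equivalent to a per-patch linear projection."""
    def __init__(self, patch_size=4, in_ch=3, embed_dim=96):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (B, C, H, W)
        x = self.proj(x)                     # (B, embed_dim, H/P, W/P)
        return x.flatten(2).transpose(1, 2)  # (B, N, embed_dim), N = HW / P^2

# e.g. a 224 x 224 x 3 image with P = 4 yields N = 3136 tokens of dimension 96
```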

It is worth pointing out that the Swin-Transformer module consists of two successive improved Transformer blocks: window-based multi-head self-attention (W-MSA) and shifted window-based multi-head self-attention (SW-MSA). The W-MSA block applies self-attention within local windows of size M \(\times\) M, resulting in linear computational complexity with respect to image size. The SW-MSA block employs a shifted windowing configuration to facilitate cross-window connections. This process is represented mathematically by Equations (1-4).

$$\begin{aligned} \hat{Z}^{l} = \text {W-MSA}(\text {LN}(Z^{l-1}))+Z^{l-1} \end{aligned}$$
(1)
$$\begin{aligned} Z^{l} = \text {MLP}(\text {LN}(\hat{Z}^{l}))+\hat{Z}^{l} \end{aligned}$$
(2)
$$\begin{aligned} \hat{Z}^{l+1} = \text {SW-MSA}(\text {LN}(Z^{l}))+Z^{l} \end{aligned}$$
(3)
$$\begin{aligned} Z^{l+1} = \text {MLP}(\text {LN}(\hat{Z}^{l+1}))+\hat{Z}^{l+1} \end{aligned}$$
(4)

where \(\hat{Z}^{l}\) and \(Z^{l}\) denote the outputs of the (S)W-MSA module and the MLP module of the l-th block, respectively, and LN denotes layer normalization.
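The following sketch illustrates the pre-norm residual structure of Equations (1)-(4). For brevity, window partitioning and shifting are omitted and a plain multi-head self-attention stands in for W-MSA/SW-MSA, so this is a structural illustration under those assumptions rather than a faithful Swin block.

```python
import torch
import torch.nn as nn

class SwinBlockSketch(nn.Module):
    """Illustrative pre-norm residual block: LayerNorm -> attention -> residual,
    then LayerNorm -> MLP -> residual, mirroring Eqs. (1)-(4)."""
    def __init__(self, dim=96, num_heads=3, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads)  # stand-in for (S)W-MSA
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim))

    def forward(self, z):                         # z: (B, N, dim)
        y = self.norm1(z).transpose(0, 1)         # MultiheadAttention expects (N, B, dim)
        z = z + self.attn(y, y, y)[0].transpose(0, 1)  # Eq. (1)/(3): attention + residual
        z = z + self.mlp(self.norm2(z))                # Eq. (2)/(4): MLP + residual
        return z
```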

Cross attention module for fusing local and global features

To effectively fuse the learned local and global features from the corresponding CNN and Swin-Transformer branches, a cross attention module is designed, as illustrated in Fig. 2.

Fig. 2
figure 2

The flowchart of the designed cross attention module for fusing local (\(X_1\)) and global (\(X_2\)) features.

Given two input feature maps \(X_i\) (\(i=1,2\)), i.e., local features \(X_1\) and global features \(X_2\), the class token (denoted as CLT) related to \(X_1\) and \(X_2\) is defined as:

$$\begin{aligned} CLT=\text {GAP}(Norm(X_i)) \end{aligned}$$
(5)

where 'GAP' stands for the global average pooling operation and 'Norm' represents the normalization operation. We take the class token as the query vector Q, and all tokens as the key vector K and the value vector V. Then, a matrix multiplication between Q and K is performed, followed by a Softmax operation, to calculate the attention feature maps Attention(Q,K,V) for \(X_1\) and \(X_2\). Finally, the obtained attention feature maps for \(X_1\) and the original input feature map \(X_2\) are concatenated as the final output. This process is defined as:

$$\begin{aligned} Attention(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d}})V \end{aligned}$$
(6)
$$\begin{aligned} Output(X_1,X_2)=Concat(Attention(Q,K,V),X_2) \end{aligned}$$
(7)

When the original input feature map is \(X_1\), the above concatenation process is repeated. In this way, we obtain two results of \(Output(X_1,X_2)\), which are then summed to form an input of the decoder. The class token serves as the query vector in our attention mechanism, thereby reducing the computational burden to some extent in comparison with common attention models that process all tokens as queries, such as ViT and the Transformer. The designed cross attention module integrates the local and global features learned by the CNN and Swin-Transformer branches, and produces comprehensive feature vectors for downstream tasks.
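A minimal PyTorch sketch of this cross attention is given below. The linear projections for Q, K, and V, and the broadcast of the single attended vector across the token dimension before concatenation with \(X_2\), are illustrative choices to obtain concrete tensor shapes rather than the exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionSketch(nn.Module):
    """Sketch of Eqs. (5)-(7): a class token from GAP over the normalized tokens
    of one branch queries all tokens of that branch; the attended result is
    concatenated with the other branch's feature map."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x1, x2):               # x1, x2: (B, N, dim) token sequences
        clt = self.norm(x1).mean(dim=1, keepdim=True)   # Eq. (5): GAP class token
        q = self.q(clt)                                  # query: class token only
        k, v = self.k(x1), self.v(x1)                    # keys/values: all tokens
        attn = F.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        out = attn @ v                                   # Eq. (6): (B, 1, dim)
        out = out.expand(-1, x2.shape[1], -1)            # broadcast over tokens (assumed)
        return torch.cat([out, x2], dim=-1)              # Eq. (7): concat with x2
```

Applying the module a second time with the roles of \(X_1\) and \(X_2\) exchanged, and summing the two outputs, would yield the fused representation that is fed into the decoder, as described above.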

Skip connections

Skip connections serve as foundational components in encoder-decoder networks, effectively enhancing feature propagation within deep neural networks. Similar to U-Net3, the skip connections between the encoder and decoder exclusively connect features sharing identical spatial resolutions. Additionally, the features obtained at various levels within the encoder contain distinct clues, so fusing these features is beneficial for the final segmentation task. To this end, we connect three kinds of features from different spatial levels (\(A_2\), \(A_3\), \(A_4\)) in the encoder to retain a richer set of feature representations for downstream tasks, as shown in Fig. 1. A simple CNN operation, such as a \(1 \times 1 \times d\) convolution, is added to the ST branch within the skip connections, so that the feature maps of the ST branch have the same dimension as those of the CNN branch. To effectively integrate local feature maps from the CNN branch and global feature maps from the ST branch at a specific spatial level, we design an adaptive cross attention (ACA) module to dynamically adjust the importance of the features from these two branches for further fusion.

Fig. 3
figure 3

The flowchart of the designed adaptive cross attention module.

As depicted in Fig. 3, the designed adaptive cross attention (ACA) module contains three steps, as described below.

Initially, similar to the cross attention module in Fig. 2, for the feature maps \(M_1\) obtained from the CNN branch, we take their class token as the query vector Q, and all tokens as the key vector K and the value vector V. Here, all tokens refer to the whole feature maps extracted from the CNN branch. In this case, we can calculate the attention feature maps Attention(Q,K,V), and perform a concatenation between the output of the ST branch and the attention feature maps to obtain the output \(M_2\). This process is defined as:

$$\begin{aligned} \begin{aligned} M_2 = Concat(Attention(Q,K,V), Output(ST)) \end{aligned} \end{aligned}$$
(8)

Secondly, two non-linear transformations, denoted by \(F_{\text {gp}}\) and \(F_{\text {fc}}\), are applied to the sum of \(M_1\) and \(M_2\). Here, \(F_{\text {gp}}\) represents the global average pooling (GAP) operation, and \(F_{\text {fc}}\) denotes the fully-connected (FC) layer. This allows the ACA module to further refine the extracted features and capture relationships between them. This process is expressed as:

$$\begin{aligned} \begin{aligned} M_3 = F_{\text {fc}}(F_{\text {gp}}(M_1 + M_2)) \end{aligned} \end{aligned}$$
(9)

Finally, the output of the ACA module is decided by a combination of \(M_1\) and \(M_2\) based on the adaptive weights, which are calculated by applying the Softmax function on \(M_3\). In particular, this step leverages the Softmax function to adaptively assign the weight values to \(M_1\) and \(M_2\) with an element-wise product. Then, an element-wise summation is used to linearly merge \(M_1\) and \(M_2\) into the final output of the ACA module. The process is expressed as:

$$\begin{aligned} \begin{aligned} Output = Softmax(M_3) \otimes M_1 + Softmax(M_3) \otimes M_2 \end{aligned} \end{aligned}$$
(10)
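A compact PyTorch sketch of the ACA computation in Equations (8)-(10) follows. The 1\(\times\)1 projection that brings the concatenated \(M_2\) back to the channel count of \(M_1\), and the broadcasting of the class-token attention over the spatial grid, are assumptions introduced so that the element-wise operations are well defined; they are not necessarily the exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveCrossAttentionSketch(nn.Module):
    """Sketch of the ACA module (Eqs. 8-10) on 2D feature maps."""
    def __init__(self, channels):
        super().__init__()
        self.scale = channels ** -0.5
        self.proj = nn.Conv2d(2 * channels, channels, kernel_size=1)  # assumed projection
        self.fc = nn.Linear(channels, channels)

    def forward(self, m1, st_out):            # m1, st_out: (B, C, H, W)
        b, c, h, w = m1.shape
        tokens = m1.flatten(2).transpose(1, 2)               # (B, HW, C) CNN tokens
        clt = tokens.mean(dim=1, keepdim=True)               # class token as query
        attn = F.softmax(clt @ tokens.transpose(-2, -1) * self.scale, dim=-1)
        att_map = (attn @ tokens).transpose(1, 2).reshape(b, c, 1, 1).expand(-1, -1, h, w)
        m2 = self.proj(torch.cat([att_map, st_out], dim=1))  # Eq. (8) + assumed 1x1 conv
        m3 = self.fc(F.adaptive_avg_pool2d(m1 + m2, 1).flatten(1))  # Eq. (9): GAP + FC
        weights = torch.softmax(m3, dim=1).view(b, c, 1, 1)  # adaptive channel weights
        return weights * m1 + weights * m2                   # Eq. (10)
```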

Decoder

A decoder aims to take the low-resolution feature maps from the encoder and upsample them to the original input image size or a higher resolution. The decoder typically comprises multiple upsampling layers, commonly referred to as deconvolutions, that enhance the spatial resolution of the feature maps. To highlight informative features while suppressing less relevant ones, we incorporate a squeeze-and-excitation attention (SE-attention)20 module between two deconvolution layers B1 and B2. The SE-attention module aims to adaptively recalibrate the feature responses across channels, allowing the network to focus on more informative channels and suppress less relevant ones, thus enhancing the representational ability of feature maps.
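For reference, a standard squeeze-and-excitation block20 can be sketched as follows; the reduction ratio of 16 is the common default and is an assumption here rather than necessarily our exact setting.

```python
import torch
import torch.nn as nn

class SEAttention(nn.Module):
    """Squeeze-and-excitation block: channel descriptors from global average
    pooling pass through a two-layer bottleneck and re-weight the channels."""
    def __init__(self, channels, r=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // r), nn.ReLU(inplace=True),
            nn.Linear(channels // r, channels), nn.Sigmoid())

    def forward(self, x):                     # x: (B, C, H, W)
        b, c, _, _ = x.shape
        s = x.mean(dim=(2, 3))                # squeeze: global average pooling
        w = self.fc(s).view(b, c, 1, 1)       # excitation: per-channel weights
        return x * w                          # recalibrate channel responses
```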

Datasets

To verify the performance of the proposed method, two typical datasets, ACDC21 and Synapse22, are employed for experiments.

The ACDC dataset is specifically designed for cardiac diagnosis and aims to segment the left ventricle (LV), right ventricle (RV), and myocardium (Myo) in MRI images. This dataset consists of MRI images from 100 different patients, with 1304 slices in the training set, 182 slices in the validation set, and 40 slices in the testing set. We preprocess the ACDC images by applying intensity normalization, spatial resampling, Gaussian smoothing, and data augmentation with random rotations and flips. As done in42, we evaluate the performance of our method by reporting the average Dice Similarity Coefficient (DSC) on these three classes.

The Synapse dataset is a multi-organ segmentation dataset. Following22, we use 30 abdominal CT scans from the MICCAI 2015 Multi-Atlas Abdomen Labeling Challenge for experiments. This dataset contains a total of 3,779 axial contrast-enhanced abdominal clinical CT images. Each CT volume contains between 85 and 198 slices, each with a spatial resolution of 512 \(\times\) 512 pixels. The in-plane voxel resolution varies from 0.54 \(\times\) 0.54 mm² to 0.98 \(\times\) 0.98 mm², and the slice thickness ranges from 2.5 mm to 5.0 mm. We preprocess the Synapse images by using intensity normalization, artifact removal and data augmentation with random rotations and flips. As done in42, we evaluate the performance of our method by reporting the average Dice Similarity Coefficient (DSC) on eight abdominal organ classes.

Implementation details

All methods are implemented in Python 3.8 with PyTorch 1.8.1, and all deep learning experiments are performed on a single Nvidia RTX 3090 GPU with 24 GB of memory. During training, data augmentation techniques such as random rotations and flips are leveraged. The input image size is fixed at 224\(\times\)224 pixels. We initialize the network parameters of our MedFuseNet by using the weights of pre-trained CNN and Swin-Transformer models; specifically, we utilize the pre-trained ResNet5044 and Swin-S16 for the CNN and ST branches, respectively. The batch size is set to 4 and the learning rate is 0.01. The SGD optimizer is adopted with a momentum of 0.9 and a weight decay of 1e-4 for the backpropagation optimization process.
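The optimizer configuration described above can be set up as in the short sketch below, where a dummy module stands in for MedFuseNet.

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 1, kernel_size=1)  # placeholder; the full MedFuseNet is not reproduced here
optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)
```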

Evaluation metrics

To evaluate the performance of all methods, two typical evaluation metrics, the Dice Similarity Coefficient (DSC) and the Hausdorff Distance (HD), are employed, since these two metrics are widely used for comparisons in the existing literature on the ACDC and Synapse datasets.

The DSC metric is utilized to assess the degree of overlap between the predicted segmentation and the corresponding ground truth. It quantifies the similarity between two sets by measuring the overlap of their binary masks; the larger the DSC, the better the segmentation result. The HD metric is employed to evaluate the quality of segmentation boundaries. It measures the maximum distance between points on the predicted segmentation boundary and the nearest points on the ground truth boundary; the smaller the HD, the better the quality of the segmentation boundary. These two metrics are defined as:

$$\begin{aligned} \text {DSC} = \frac{2|P \cap G|}{|P| + |G|} \end{aligned}$$
(11)
$$\begin{aligned} \text {HD}(P, G) = max[\text {D}(P, G),\text {D}(G, P)] \end{aligned}$$
(12)
$$\begin{aligned} \text {D}(P, G)=\max _{p \in P} \min _{g \in G}\Vert p-g\Vert \end{aligned}$$
(13)

where p and g denote the coordinate vectors of two pixels, and P and G represent the coordinate sets of the predicted segmentation and the ground truth, respectively. The symbol \(\cap\) is the intersection operator for two sets, \(|\cdot |\) denotes the number of pixels in a set, and \(\Vert \cdot \Vert\) is the L2 norm of a vector.
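As a reference implementation of Equations (11)-(13), the sketch below computes DSC and HD from binary masks with NumPy and SciPy; for simplicity it measures HD over all foreground pixel coordinates rather than extracting the boundary points first.

```python
import numpy as np
from scipy.spatial.distance import directed_hausdorff

def dice_coefficient(pred, gt):
    """Eq. (11): overlap between binary prediction and ground truth masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    return 2.0 * np.logical_and(pred, gt).sum() / denom if denom > 0 else 1.0

def hausdorff_distance(pred, gt):
    """Eqs. (12)-(13): symmetric Hausdorff distance between the coordinate
    sets of two binary masks (boundary extraction omitted for brevity)."""
    p = np.argwhere(pred > 0)
    g = np.argwhere(gt > 0)
    return max(directed_hausdorff(p, g)[0], directed_hausdorff(g, p)[0])
```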

Results and analysis

To demonstrate the advantages of our method, we compare it with several state-of-the-art methods, as listed below.

U-Net3: a U-shaped neural network architecture for semantic segmentation that combines an encoder for feature extraction and a decoder for feature reconstruction.

R50 U-Net3: the ResNet-50 backbone is used in the encoder of U-Net.

AttnUNet45: attention gated networks that enable U-Net to learn salient regions in medical images.

R50 AttnUNet42: the ResNet-50 backbone is integrated with attention gated networks.

R50 ViT42: the ResNet-50 backbone is integrated with ViT15 for the hybrid encoder design.

ViT-CUP42: a Cascaded Upsampler (CUP) decoder is integrated with ViT15.

R50 ViT-CUP15: the ResNet50 backbone is integrated in the decoder with ViT-CUP.

Swin-Unet18: a U-Net-like pure Transformer-based U-shaped encoder-decoder architecture.

TransUnet42: an integration of both Transformers and U-Net for medical image segmentation.

DARR22: an unsupervised domain adaptation model called Domain Adaptive Relational Reasoning (DARR) for enhancing the generalization capabilities of 3D multi-organ segmentation models.

DeepLabv3+32: an extension of DeepLabv319 equipped with atrous separable convolution.

DuAT46: a Dual-Aggregation Transformer Network called DuAT equipped with the Global-to-Local Spatial Aggregation (GLSA) and Selective Boundary Aggregation (SBA) modules for medical image segmentation.

DAE-Former47: a Dual Attention-guided Efficient Transformer called DAE-Former that reformulates the self-attention mechanism to efficiently capture spatial and channel relations.

TransAttUnet48: a Transformer-based Attention Guided Network called TransAttUnet with multi-level guided attention and multi-scale skip connection for medical image segmentation.

TransClaw U-Net43: a claw U-Net with Transformers in which a bottom upsampling part is added to retain deep features for detail segmentation.

Table 1 Performance comparisons of different methods on the ACDC dataset.
Table 2 Performance comparisons of different methods on the Synapse dataset.

Table 1 and Table 2 individually present the segmentation performance comparisons of different methods on the ACDC and Synapse datasets, and the corresponding visualization results are depicted in Fig. 4 and Fig. 5. It is noted that, as done in previous comparison works, we employ only the DSC metric on the ACDC dataset, whereas we utilize both the DSC and HD metrics on the Synapse dataset.

From the results in Table 1 and Table 2, we can make the following observations:

  1. (1)

    Our method obtains the best performance on both the ACDC and Synapse datasets. In particular, our method achieves the highest average DSC of 89.73% on the ACDC dataset and 78.40% on the Synapse dataset. Moreover, on the Synapse dataset our method yields the lowest HD of 18.44. This demonstrates the advantages of our method on medical image segmentation tasks, and is attributed to the fact that our method combines the strengths of CNN and Swin-Transformer models in a complementary manner, fusing local and global deep feature representations with hybrid attention mechanisms to improve segmentation performance.

  2. (2)

    The competing methods, such as U-Net and Transformer-based methods, focus on either local or global feature learning alone for medical image segmentation. As a result, they produce lower performance than our method on both the ACDC and Synapse datasets, which indicates the validity of our method.

  3. (3)

    For specific organ segmentation, our method achieves the highest performance on almost all structures of the ACDC dataset: it gives the highest DSC of 86.29% and 94.54% for Myo and LV, respectively, and the second-best DSC of 88.36% for RV. On the Synapse dataset our method yields the highest DSC of 85.10%, 93.98%, and 78.74% for kidney (L), liver, and stomach, respectively. Although our method does not perform best for every organ, it produces the highest average DSC across the 8 organs, as listed in Table 2.

Fig. 4
figure 4

Visual comparison of different medical image segmentation methods on the ACDC dataset.

Fig. 5
figure 5

Visual comparison of different medical image segmentation methods on the Synapse dataset.

Ablation Study

In order to investigate the impact of various components in our model, we conduct an ablation study on the ACDC and Synapse datasets. Specifically, we explore the effectiveness of different modules such as the CNN branch, the Swin-Transformer (ST) branch, cross attention, atrous spatial pyramid pooling (ASPP), SE-Attention, adaptive cross attention (ACA), skip connections, and so on. The detailed results and analysis are presented below.

Table 3 The effect of different modules in our method on the ACDC dataset.
Table 4 The effect of different modules in our method on the Synapse dataset.

Effect of different modules

Table 3 and Table 4 separately display the segmentation performance on these two datasets when a single module is removed from our method (denoted by w/o). As shown in Table 3 and Table 4, the following observations can be made.

  1. (1)

    For the backbone networks in the encoder, the CNN branch (w/o Swin-Transformer) clearly outperforms the Swin-Transformer branch (w/o CNN), which shows the advantage of local feature learning in the CNN branch over global feature learning in the Swin-Transformer branch on these tasks. On the ACDC dataset the CNN branch gives an average DSC of 85.75%, whereas the Swin-Transformer branch achieves an average DSC of 47.16%. Likewise, on the Synapse dataset the CNN branch obtains an average DSC of 72.73%, whereas the Swin-Transformer branch yields an average DSC of 28.60%. These results highlight the necessity of incorporating both the CNN and Swin-Transformer backbones.

  2. (2)

    For the hybrid attention mechanisms, including cross attention, adaptive cross attention, ASPP, and SE-Attention, the results in Table 3 and Table 4 indicate the effectiveness of these attention modules. In particular, on the Synapse dataset whether the relevant attention modules are utilized or not has a relatively large impact on the final segmentation performance.

  3. (3)

    Regarding the impact of skip connections, it can be seen that the skip connection is an important component of the encoder-decoder architecture, connecting features between the encoder and decoder for feature propagation. On the Synapse dataset, leveraging the skip connection module yields an improvement of 12.74% in segmentation performance.

Effect of different number of skip connections

As skip connections are employed to facilitate the recovery of image information during the upsampling process, the number of skip connections may have an important impact on the final segmentation results. To evaluate their importance, we compare the performance obtained with different numbers of skip connections. Table 5 and Table 6 present the results on the ACDC and Synapse datasets, respectively. The results in Table 5 and Table 6 reveal that three skip connections (denoted by 3-skip) yield the best average DSC on both datasets; in particular, the 3-skip setting obtains the highest average DSC of 89.73% and 78.40% on the ACDC and Synapse datasets, respectively. These findings indicate that incorporating three skip connections yields the best performance for our method.

Table 5 The effect of different number of skip connections on the ACDC dataset.
Table 6 The effect of different number of skip connections on the Synapse dataset.
Table 7 The effect of different positions of ASPP and SE-Attention modules on the ACDC dataset.
Table 8 The effect of different positions of ASPP and SE-Attention modules on the Synapse dataset.

Effect of different positions of attention modules

In the encoder and decoder, two attention modules, atrous spatial pyramid pooling (ASPP) and SE-Attention, are adopted for feature enhancement. The ASPP module aims to capture multi-scale contextual information from the extracted feature maps, while the SE-Attention module aims to emphasize informative features and suppress less relevant ones. However, what are the optimal positions for these two attention modules within the encoder and decoder?

To address this question, we explore the effect of different positions of the attention modules on these two datasets, as listed in Table 7 and Table 8. The results show that the configuration in the first row of each table gives the optimal positions for these two attention modules, which are also marked in the flowchart of our proposed method illustrated in Fig. 1.

Effect of different pre-trained networks

To investigate the effect of different pre-trained networks on the CNN branch, we present a comparative analysis of different pre-trained networks employed as its backbone in our MedFuseNet. Specifically, we compare the performance of ResNet50, ResNet101, and WiderResNet50 as the backbone of the CNN branch in experiments on the ACDC and Synapse datasets.

Tables 9 and 10 provide the results of different pre-trained networks on these two datasets. The results indicate that the pre-trained ResNet50 obtains the highest overall performance in terms of average DSC, although WiderResNet50 and ResNet101 occasionally outperform ResNet50 on one or two organs. This is why we adopt the pre-trained ResNet50 as the backbone of the CNN branch.

Table 9 The effect of different pre-trained networks on the ACDC dataset.
Table 10 The effect of different pre-trained networks on the Synapse dataset.

Discussion and conclusion

In this work, we present a novel model called MedFuseNet for medical image segmentation, in which CNN-based local and ST-based global deep representations are effectively combined by hybrid attention mechanisms. For feature fusion and enhancement, four different attention modules, namely an ASPP module, a cross attention module, an ACA module, and an SE-attention module, are designed and combined in a suitable manner. Comprehensive experiments are performed on the public ACDC and Synapse datasets, and the results demonstrate the advantages of the proposed approach, which achieves superior performance to state-of-the-art methods. This shows that our method can effectively fuse local and global deep feature representations with the designed hybrid attention mechanisms.

We note that our method has only been verified on two public datasets, which are relatively small, whereas deep learning models usually rely on a large number of samples for training to obtain good generalization ability. In this sense, it is desirable to collect more samples for training our method in the future. In addition, it would be interesting to test our method in real-world clinical scenarios for assisting doctors' diagnosis and treatment.