Introduction

Reinforcement Learning (RL), a core branch of machine learning, has opened up new avenues for solving complex problems through a mechanism where an agent continuously interacts with its environment through trial and error to maximise long-term cumulative rewards1. This mechanism has demonstrated tremendous potential in various cutting-edge fields. In robotics2, RL helps robots acquire complex manipulative skills. In autonomous driving technology3, it enables vehicles to navigate safely and efficiently in complex traffic environments. In game intelligence4, RL allows computers to master superior gaming skills. Additionally, in medical diagnosis, RL has played a crucial role in further enhancing the diagnostic accuracy of models5.

To establish the theoretical framework of RL, researchers have integrated knowledge from various disciplines, such as probability theory, dynamic programming, and control theory, driving continuous innovation in algorithms. The introduction of classic algorithms such as Q-learning6 and SARSA7 laid a solid foundation for modeling intelligent behavior. However, as application fields have expanded, traditional algorithms face challenges of high computational complexity and low sample efficiency when dealing with high-dimensional and continuous state and action spaces8. To overcome these limitations, researchers have begun exploring a new direction that combines Transformers with RL, leveraging the Transformer’s ability to capture long-range dependencies9 to handle high-dimensional and continuous state and action spaces. Among these efforts, the Decision Transformer10 proposed by Chen et al. has emerged as a representative achievement, marking a significant breakthrough in algorithm design. Combining Transformers with RL11,12 not only overcomes the limitations of traditional algorithms but also significantly improves the performance and efficiency of RL, opening up vast prospects for future research and applications.

Existing reinforcement learning methods based on Transformer can be primarily categorized into online and offline reinforcement learning, as shown in Table 1. In online reinforcement learning, an agent learns through real-time interaction with the environment; methods such as the Online Decision Transformer (ODT)13 rely on this interaction. However, such methods often come with high costs and potential risks in practical application scenarios14. Offline reinforcement learning methods, on the other hand, delve deeply into existing datasets to extract valuable decision-making information and patterns without the need for real-time environmental interaction. Wang et al. proposed the Bootstrapping Transformer (BooT)15, which introduces the bootstrapping idea into the Transformer to generate richer offline data for training. Konan et al. proposed the Contrastive Decision Transformer (ConDT)16, which leverages contrastive loss to cluster embedding features of different goal returns, deepening the excavation and utilization of offline data features. Clinton et al. proposed the Planning Transformer17, which includes Planning Tokens that encapsulate long-term information, mitigating the autoregressive compounding errors of Transformers in long-term tasks. The research on these methods has contributed to the refinement of Transformer algorithms in offline reinforcement learning, providing core technical support for the efficient deployment and development of this field in more practical scenarios.

Table 1 Summary of transformer-based reinforcement learning methods.

However, Transformer-based offline reinforcement learning faces limitations in trajectory stitching, a crucial capability that allows algorithms to explore better strategies by combining suboptimal trajectory segments. To tackle this issue, Zhang et al. introduced ContextFormer18, which effectively stitches suboptimal trajectory segments by integrating context-based imitation learning. Yamagata et al. attempted to reconnect trajectory segments by relabeling the dataset with the Q-learning Decision Transformer (QDT)19. Nevertheless, these methods overlook the influence of environmental stochasticity during trajectory stitching. Although the Advantage-based Clipping Transformer (ACT)20 algorithm proposed by Gao et al. enhances trajectory stitching and improves robustness to environmental stochasticity by assessing action quality through the advantage function, existing Transformer-based reinforcement learning methods commonly process the multimodal information in trajectories by simply stacking all modal information along the sequence length and feeding it into a single Transformer. This approach not only disperses the model’s attention resources when processing mixed information, making it difficult to deeply mine the inherent correlations within each single modality, but also results in inefficient and insufficient cross-modal interaction. That is, these methods fail to achieve deep interaction and fusion within and between modalities, limiting the algorithm’s ability to understand and adapt to complex and dynamically changing environments.

This study proposes an offline reinforcement learning method, namely CGM, that combines generalized advantage estimation and modality decomposition interaction (MDI) to address the issues identified above. This method enhances the model’s ability to stitch trajectories from offline datasets. Moreover, it fully captures the intra- and inter-modal interactions of multimodal information within trajectories, strengthening the agent’s comprehensive understanding of the data and improving its decision-making capabilities. The core contributions of this study are summarized as follows:

  (1) In this study, CGM employs generalized advantage estimation to improve the model’s trajectory stitching in offline datasets, enhancing the agent’s ability to learn optimal strategies.

  (2) CGM incorporates MDI, leveraging an encoder-decoder framework to maximise the exploitation of intra- and inter-modal correlations embedded in multimodal trajectory data. This strategy enhances the agent’s ability to understand and optimally use the trajectory data.

  (3) The encoder of the MDI employs ConvFormer-based intra-modal interaction to process uni-modal temporal information effectively, assisting the model in capturing the inherent local correlations within state and action sequences. Additionally, it utilizes dual-transformer-based inter-modal interaction to conduct multimodal information exchange separately for states and actions, thereby enhancing the model’s representation ability for state and action sequences.

  (4) This study compares the proposed approach with eight other state-of-the-art offline reinforcement learning methods on the D4RL dataset. The results demonstrate that our method effectively improves the agent’s decision-making capabilities and achieves significant improvements in offline tasks.


The organization of this paper is as follows. In Section “Related work”, we review related work pertinent to the method proposed in this paper. Section “Method” expounds on the detailed architecture and implementation specifics of our proposed offline reinforcement learning method, CGM, which combines generalized advantage estimation with modality decomposition interaction. Section “Experiments” meticulously outlines the experimental setup we designed, accompanied by an in-depth analysis and exhibition of the experimental results. Lastly, Section “Conclusion and future work” provides a concise conclusion summarizing the entire paper.

Related work

Decision transformer (DT)

Decision Transformer (DT)10 is a deep learning model that handles sequential decision-making tasks. It integrates traditional reinforcement learning methods with the Transformer architecture to address sequence-based decision problems. Specifically, DT selects trajectories \(\tau\) of length T from an offline dataset D and, at each time step t, calculates the cumulative sum of future rewards r from time step t to T, obtaining the return-to-go (RTG) \({ \mathop {R}\limits^{\frown}} _{t} = \sum\nolimits_{t^{\prime } = t}^{T} {r_{t^{\prime } } }\). The RTG represents the cumulative reward from the current time step t to the final time step T and is used to guide the agent in making optimal decisions in sequential decision tasks. DT then constructs a sequence from RTG, past state, and action information, using stacked attention and residual connections in the Transformer architecture to produce accurate outputs, as shown in Eq. (1).

$$\tau _{T} = \left\{ {{\mathop R\limits^{\frown }}_{1} ,s_{1} ,a_{1} ,{\mathop R\limits^{\frown }}_{2} ,s_{2} ,a_{2} , \ldots ,{\mathop R\limits^{\frown }}_{T} ,s_{T} ,a_{T} } \right\}$$
(1)

where s and a represent the state and the action, respectively. DT performs exceptionally well when handling long input sequences and generating accurate outputs. However, it faces limitations in capturing the inherent local dependency patterns within trajectories. To address this issue, Kim et al. and Shang et al. proposed Decision ConvFormer21 and StARformer22, respectively, which aim to enhance Decision Transformer’s ability to model local relationships in sequences. However, these improved methods also have their shortcomings. Firstly, they overlook the importance of multimodal information interactions within sequences23. Multimodal information interactions refer to the effective integration of different types of information to better understand complex environments. Secondly, these methods have limitations in their trajectory stitching capabilities24, which can lead to agents focusing on suboptimal trajectories that provide immediate rewards while neglecting those that might yield higher cumulative rewards. These issues pose significant challenges in many offline reinforcement learning applications, and addressing them would significantly improve the effectiveness of Decision Transformer in real-world problem-solving.
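As a concrete illustration of the return-to-go construction described above (a minimal sketch, not the authors' implementation), the RTG values and the interleaved sequence of Eq. (1) can be computed as follows:

```python
import numpy as np

def returns_to_go(rewards):
    """Undiscounted return-to-go: R_t = sum of rewards from step t to T."""
    rewards = np.asarray(rewards, dtype=float)
    return np.flip(np.cumsum(np.flip(rewards)))

def interleave_trajectory(rtg, states, actions):
    """Arrange the trajectory as (R_1, s_1, a_1, ..., R_T, s_T, a_T) per Eq. (1)."""
    return [x for triple in zip(rtg, states, actions) for x in triple]

# toy trajectory of length T = 3
rewards = [1.0, 0.0, 2.0]
rtg = returns_to_go(rewards)          # R_1 = 3, R_2 = 2, R_3 = 2
seq = interleave_trajectory(rtg, ["s1", "s2", "s3"], ["a1", "a2", "a3"])
```

During training, each modality in `seq` would additionally be embedded and position-encoded before entering the Transformer.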

Generalized advantage estimation

Generalized Advantage Estimation25 is a method used in reinforcement learning to estimate advantage values. It can be viewed as a hybrid of Temporal Difference (TD) and Monte Carlo estimation methods, aiming to improve the policy update process by more accurately estimating the advantage of actions. Generalized advantage estimation combines the strengths of TD methods and Monte Carlo methods by introducing a truncation parameter λ to balance the weight of long- and short-term rewards, as defined by Eqs. (2) and (3).

$$\delta_{t} = R(t) + \gamma *V\left( {s_{t + 1} } \right) - V\left( {s_{t} } \right)$$
(2)
$${\mathop{A}\limits^{\frown}}_t^{GAE(\gamma ,\lambda )} = \sum\limits_{l = 0}^\infty {{(\gamma \lambda )}^l}\delta _{t + l}$$
(3)

In Eqs. (2) and (3), \(\delta_{t}\) represents the temporal-difference (TD) error at time step t, \(R(t)\) represents the immediate reward at time step t, γ represents the discount factor, which weighs the importance of current and future rewards, \(V\left( {s_{t} } \right)\) represents the estimated state value function at time step t, λ represents a parameter that adjusts the weight between TD and Monte Carlo estimation, and \({\mathop{A}\limits^{\frown}}_{t}^{GAE(\gamma ,\lambda )}\) represents the advantage estimate at time step t, obtained by exponentially weighting the TD errors \(\delta_{t+l}\) over the subsequent time steps l along the given trajectory.
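Over a finite trajectory, the infinite sum in Eq. (3) is typically evaluated with the equivalent backward recursion \(A_t = \delta_t + \gamma\lambda A_{t+1}\). A minimal NumPy sketch (illustrative, not the paper's code):

```python
import numpy as np

def gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation per Eqs. (2)-(3).

    `values` has length T + 1, holding V(s_1), ..., V(s_{T+1}); the sum in
    Eq. (3) is computed by the backward recursion A_t = delta_t + gamma*lam*A_{t+1}.
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    T = len(rewards)
    deltas = rewards + gamma * values[1:] - values[:-1]   # Eq. (2)
    adv = np.zeros(T)
    acc = 0.0
    for t in reversed(range(T)):
        acc = deltas[t] + gamma * lam * acc               # Eq. (3)
        adv[t] = acc
    return adv
```

Setting `lam=0` recovers the one-step TD error (low variance, some bias), while `lam=1` with `gamma=1` recovers the Monte Carlo advantage (unbiased, high variance), matching the trade-off described in the text.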

Generalized advantage estimation significantly improves the efficiency and performance of policy learning in reinforcement learning. Specifically, generalized advantage estimation stabilizes and enhances the learning process in complex robotic manipulation tasks26 by reducing the variance of policy estimates. In autonomous vehicle navigation tasks, generalized advantage estimation effectively estimates the advantage of each action, thereby updating both the actor and critic networks to improve navigation performance. Additionally, Gao et al.20 utilized generalized advantage estimation to estimate state values, successfully addressing the stitching capability issues of decision Transformer models when handling continuous decision problems. This further demonstrates generalized advantage estimation’s broad application potentiality and practical support for complex tasks in reinforcement learning. These studies highlight generalized advantage estimation’s adaptability and advantages across various tasks and environments.

Method

The offline reinforcement learning method proposed in this paper, named CGM, combines generalized advantage estimation with modality decomposition interaction and consists of two components: the Advantage Estimation Module and the Modality Decomposition Interaction (MDI) Module, as illustrated in Fig. 1. The Advantage Estimation Module employs Generalized Advantage Estimation to estimate the advantages of states based on the dataset, aiming to improve the continuity of the original trajectories and enhance the model’s trajectory stitching capability. The MDI comprises an encoder and a decoder. The encoder utilizes intra-modal interaction based on ConvFormer and inter-modal interaction based on dual Transformers to fully capture intra-modal and inter-modal information of states and actions, thereby enhancing the high-level representations of states and actions. The decoder further integrates the advantage values with the high-level state representations to improve the accuracy of action prediction.

Fig. 1 The framework for CGM.

Specifically, we first feed the dataset \(D = \{ \tau \} = \left\{ {r_{t} ,s_{t} ,a_{t} } \right\}_{t=1}^{N}\) into the advantage estimation module and conduct advantage estimation on the states to obtain the advantage values AE. We construct a new dataset \(D^{AE} = \left\{ {\tau^{AE} } \right\} = \left\{ {AE_{t} ,s_{t} ,a_{t} } \right\}_{t = 1}^{N}\), using AE to replace the RTG value in the reinforcement learning sequence model. Next, we input \(D^{AE}\) into the MDI module for feature learning and action prediction. The state s and action a are fed into the encoder, where intra-modal interaction based on ConvFormer and inter-modal interaction based on a dual Transformer explore the profound correlation between actions and states, generating the state’s high-level feature representation \(\overline{H}^{S}\). Finally, in the decoder, we fuse the advantage value AE with \(\overline{H}^{S}\) to complete the prediction of action \(\mathop A\limits^{\frown }\), achieving policy optimization for offline reinforcement learning.

Advantage estimation module

The advantage estimation module utilizes Generalized Advantage Estimation25 to provide precise estimates of state values. This module deeply annotates and processes the offline dataset D, resulting in a new dataset \(D^{AE} = \left\{ {\tau^{AE} } \right\}\), where \(\tau^{AE} = \left\{ {AE_{t} ,s_{t} ,a_{t} } \right\}_{t = 1}^{N}\) and \(AE_{t}\), \(s_{t}\), \(a_{t}\) represent the advantage value, state, and action at time step t. The new dataset significantly improves the continuity and logical coherence of the trajectory data. The advantage estimation module comprises three components: value function approximation, generalized advantage estimation itself, and a predictor.

Value function approximation

This study discusses value function approximation in the context of reinforcement learning. The goal of parameterizing the value network V is to approximate the state value function to evaluate the long-term accumulated rewards or expected returns in a specific state. A more accurate estimation of state values can be obtained by updating and optimizing the value network V. To address the issue of overestimation27, we employ three-layer MLPs to construct a separate value network \(V_{\psi }\) together with its target network \(V_{{\overline{\psi }}}\). During training, small batches \(B = \left\{ {\left( {s_{i} ,a_{i} ,r_{i} ,{s_{i}^{\prime} }} \right)} \right\}_{i = 1}^{M}\) are sampled uniformly from the offline dataset D; the value network \(V_{\psi }\) is updated with the expectile regression loss \(L_{expectile\_regression}\), while the parameters of the target network \(V_{{\overline{\psi }}}\) are updated through a moving average. This process is governed by Eqs. (4)–(5).

$$L_{expectile\_regression} = L_{{\sigma_{1} }} \left( {r + \gamma V_{{\overline{\psi }}} \left( {s^{\prime} } \right) - V_{\psi } (s)} \right)$$
(4)
$$\overline{\psi }_{i}^{k + 1} = (1 - \alpha )\overline{\psi }_{i}^{k} + \alpha \psi_{i}^{k + 1}$$
(5)

In Eqs. (4) and (5), γ denotes the discount factor, \({s^{\prime}}\) represents the next state after transitioning from state s, \(\psi\) denotes the parameters of the value network \(V_{\psi }\), \(\overline{\psi }\) denotes the parameters of the target value network \(V_{{\overline{\psi }}}\), M represents the batch size, and \(\alpha\) represents the soft-update coefficient of the moving average. Additionally, \(L_{{\sigma_{1} }} (u) = |\sigma_{1} - {\mathbb{I}}(u < 0)|u^{2}\) is the expectile regression loss, where \({\mathbb{I}}(\cdot)\) denotes the indicator function and \(\sigma_{1}\) is a hyperparameter. By adjusting the value of \(\sigma_{1}\), a balance point can be found between the approximate policy value function and the optimal value function within the sample, allowing the model to retain the information of the in-sample optimal value function while compensating through the approximate policy value function, thereby achieving better performance and faster convergence during training.
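The two update rules can be sketched directly (a minimal NumPy illustration; function names are ours, not the paper's):

```python
import numpy as np

def expectile_loss(u, sigma=0.7):
    """Asymmetric expectile loss L_sigma(u) = |sigma - 1(u < 0)| * u^2 (cf. Eq. (4))."""
    u = np.asarray(u, dtype=float)
    return np.abs(sigma - (u < 0).astype(float)) * u ** 2

def soft_update(target_params, online_params, alpha=0.005):
    """Moving-average target update of Eq. (5): target <- (1 - alpha)*target + alpha*online."""
    return (1 - alpha) * np.asarray(target_params) + alpha * np.asarray(online_params)
```

Note the asymmetry: with `sigma > 0.5`, positive residuals (targets above the current estimate) are penalized more heavily than negative ones, which pushes \(V_\psi\) toward the upper expectile of the TD targets.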

Generalized advantage estimation

This paper employs generalized advantage estimation to estimate the state-value function, enhance the accuracy of predicting future rewards, and reduce model bias from relying on a single estimation method. The advantage of generalized advantage estimation lies in its clever combination of Temporal Difference (TD) methods and Monte Carlo estimation, which addresses the issues of Monte Carlo estimation being unbiased but having high variance and requiring a large number of samples while also overcoming the bias present in TD methods despite their lower variance. Specifically, generalized advantage estimation constructs a multi-step framework that weights and sums TD residuals across different time steps. This strategy effectively balances bias and variance, resulting in more accurate and stable estimates of state advantage values. In practice, GAE relies on a pre-trained state-value function V to estimate advantage values AE, as shown in Eq. (6).

$$\begin{gathered} AE_{t}^{\left( \lambda \right)} = \left( {1 - \lambda } \right)\sum\limits_{l = 1}^{\infty } {\lambda^{l - 1} \delta_{t}^{(l)} } , \hfill \\ \delta_{t}^{(l)} = - V_{\varphi } \left( {s_{t} } \right) + \sum\nolimits_{i = 0}^{l - 1} {\gamma^{i} r_{t + i} } + \gamma^{l} V_{\varphi } \left( {s_{t + l} } \right) \hfill \\ \end{gathered}$$
(6)

where \(AE_{t}^{(\lambda )}\) represents the advantage estimation at the current time step t, l denotes the number of steps along a given trajectory, γ represents the discount factor, \(\delta_{t}^{(l)}\) denotes the l-step advantage along the given trajectory. \(V_{\varphi }\) denotes the pre-trained state value function, \(r_{t}\) represents the reward function at time step t, \(\lambda \in [0,1]\) represents the hyperparameter used to control the balance between Monte Carlo estimation and Temporal Difference (TD) estimation. When λ is close to 0, Generalized Advantage Estimation tends more towards TD estimation, exhibiting lower variance but potentially some bias; when λ is close to 1, Generalized Advantage Estimation more closely resembles Monte Carlo estimation, being unbiased but with higher variance. By adjusting the value of λ, Generalized Advantage Estimation can strike an optimal balance between bias and variance, thereby significantly reducing the variance of the advantage estimation and making the action value assessment more stable.

Compared to the Return-to-Go (RTG) used in reinforcement learning sequence modeling, Generalized Advantage Estimation improves trajectory stitching ability more effectively. For a given trajectory \(\tau = \left\{ {s_{1} ,a_{1} ,r_{1}, \ldots s_{T} ,a_{T} ,r_{T} } \right\}\), RTG calculates the discounted sum of all rewards \(r_{t^{\prime}}\) from time step t to T by introducing the discount factor \(\gamma\), i.e., \({\mathop {R} \limits^{\frown }}_{t} = \sum\nolimits_{{t^{\prime}} = t}^{T} {\gamma^{t^{\prime} - t} r_{t^{\prime}}} \). RTG primarily estimates action values based on cumulative rewards from a single trajectory. Though unbiased, it is noise-sensitive and has high variance. In sparse or stochastic reward environments, accidental factors cause estimation fluctuations, and it struggles to capture long-term values, limiting trajectory stitching ability. The formula for Generalized Advantage Estimation, as shown in Eq. (6), integrates multi-step TD errors and introduces a hyperparameter λ to strike a balance between Monte Carlo estimation and TD estimation, effectively controlling bias and variance, thereby stabilizing the evaluation of state values. This stability is crucial for trajectory stitching, as it ensures accurate assessment of each state’s value during the stitching process, effectively bridging the state-action values of different trajectory segments, thus enhancing the coherence and consistency of the trajectories and ultimately improving overall performance.

By applying the generalized advantage estimation to label the offline dataset D, a new dataset \({D^{AE}} = \left\{ {\tau^{AE} } \right\}\) is obtained, denoted as \({\tau^{AE}} = \left\{ {AE_{t} ,s_{t} ,a_{t} } \right\}_{t = 1}^{N}\). In this dataset, \({AE_{t}}\) represents the advantage estimate at the t-th time step obtained using generalized advantage estimation. Subsequently, the labeled offline dataset \({D^{AE}} = \left\{ {\tau^{AE} } \right\}\) trains the modality decomposition interaction module for action prediction.
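The relabeling step itself is mechanical once the advantage estimates are available; a hypothetical sketch (names are illustrative):

```python
def relabel_with_advantage(trajectory, advantages):
    """Build tau_AE = {(AE_t, s_t, a_t)} by pairing each (s_t, a_t) with its
    advantage estimate from Eq. (6); the RTG token is simply replaced by AE_t."""
    return [(ae, s, a) for ae, (s, a) in zip(advantages, trajectory)]

tau_ae = relabel_with_advantage([("s1", "a1"), ("s2", "a2")], [0.4, -0.1])
```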

Predictor

We employ a predictor to evaluate the target advantage values for each state during the testing phase. Predictor c primarily comprises a 3-layer MLP and is trained using the expectile regression loss. The loss function is depicted in Eq. (7):

$$L = E_{{(AE_{t} ,s_{t} ) \sim D^{AE} }} \left[ {L_{{\sigma_{2} }} \left( {AE_{t} - c(s_{t} )} \right)} \right]$$
(7)

where \(\left( {AE_{t} ,s_{t} } \right) \sim D^{AE}\) denotes that \(\left( {AE_{t} ,s_{t} } \right)\) is uniformly sampled from \(D^{AE}\), and \(\sigma_{2}\) is a hyperparameter; as \(\sigma_{2} \to 1\), predictor c learns to approximate the maximum in-sample advantage attainable in a given state s.
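The claim that \(\sigma_{2} \to 1\) drives the predictor toward the in-sample maximum can be checked numerically with a scalar expectile fit (a toy illustration, not the paper's MLP predictor):

```python
import numpy as np

def fit_expectile(samples, sigma, lr=0.05, steps=4000):
    """Fit a scalar c minimizing E[|sigma - 1(u < 0)| * u^2] with u = sample - c.

    For sigma = 0.5 the minimizer is the mean; as sigma -> 1 it approaches
    the maximum sample, which is the intuition behind Eq. (7).
    """
    samples = np.asarray(samples, dtype=float)
    c = 0.0
    for _ in range(steps):
        u = samples - c
        grad = -2.0 * np.mean(np.abs(sigma - (u < 0)) * u)  # dL/dc
        c -= lr * grad
    return c
```

For advantages {0, 1, 2}, `sigma=0.5` converges to the mean 1.0, while `sigma=0.99` converges near the maximum 2.0.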

MDI module

The MDI module comprises three components: a modal decomposition module, an encoder, and a decoder, as shown in Fig. 2. The modal decomposition module decomposes the information within the sequence, segmenting it into distinct modal components. The encoder processes state and action information through intra-modal interaction based on ConvFormer and inter-modal interaction based on a dual Transformer. The intra-modal interaction, leveraging ConvFormer, effectively captures the inherent correlations within a single mode, enhancing the representation capability of single-modal information. The inter-modal interaction, utilizing a dual causal Transformer structure, captures the associations between state and action information to learn advanced representations of both thoroughly. The decoder further integrates multimodal information based on convolution and MLP to accurately predict actions.

Fig. 2 The framework for the modality decomposition interaction module.

Modal decomposition module

This study aims to perform sequence modeling on the trajectory information \(\tau = \left\{ {\left( {AE_{1} ,s_{1} ,a_{1} } \right),\left( {AE_{2} ,s_{2} ,a_{2} } \right), \ldots ,\left( {AE_{N} ,s_{N} ,a_{N} } \right)} \right\}_{t = 1}^{N}\), extracted from the \(D^{AE}\). To assist the agent in effectively integrating and utilizing this information, the states, actions, and advantage values of the trajectory are treated as distinct modalities. These modalities are then decomposed from the trajectory information, resulting in new, structured trajectory data \({\tau^{\prime}}\), as shown in Eq. (8).

$$\begin{gathered} AE = \left\{ {AE_{1} ,AE_{2} , \ldots AE_{N} } \right\} \hfill \\ S = \left\{ {{\text{s}}_{1} ,s_{2} , \ldots s_{N} } \right\} \hfill \\ A = \left\{ {a_{1} ,a_{2} , \ldots a_{N} } \right\} \hfill \\ \tau ^{\prime } = AE \cup S \cup A \hfill \\ \end{gathered}$$
(8)

where S, A, and AE represent the state, action, and advantage value information, respectively, after the modality decomposition. State-action pairs are separated from advantage values to enhance adaptability to changing tasks or strategies7 and to optimize pre-training. Following this, state S and action A are input into the encoder to comprehend the interactions between modalities, yielding high-level representations. Subsequently, the decoder harnesses the advantage values for action prediction, as depicted in the modality decomposition module shown in Fig. 2.
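The decomposition of Eq. (8) amounts to unzipping the per-step triples into three modal streams; a minimal sketch:

```python
def decompose(trajectory):
    """Split tau = {(AE_t, s_t, a_t)} into the modal streams AE, S, A of Eq. (8)."""
    AE, S, A = zip(*trajectory)
    return list(AE), list(S), list(A)

AE, S, A = decompose([(3.0, "s1", "a1"), (2.0, "s2", "a2")])
```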

Encoder

The encoder mainly comprises intra-modal interactions based on ConvFormer and inter-modal interactions based on a dual Transformer. Relevant research28 shows that intra-modal interactions are relatively less critical than inter-modal interactions. The encoder adopts a step-by-step processing strategy to minimize interference during the representation learning of the most essential interactions. First, intra-modal interaction operations based on ConvFormer are carried out for states and actions. ConvFormer combines the advantages of CNN’s local feature extraction and Transformer’s global modeling to enhance the feature representation of states and actions. Subsequently, inter-modal interactions based on a dual Transformer are performed on states and actions. The cross-attention mechanism of the dual Transformer is utilized to fully capture the complex interaction relationships between states and actions, thereby achieving a more comprehensive and accurate representation of the data.

  (1) Intra-modal interaction based on ConvFormer


In this paper, we conduct intra-modal interaction of sequential information based on ConvFormer. Unlike the Transformer structure, ConvFormer uses a convolutional module to replace the attention module in the token mixer. This is because the convolutional module can effectively integrate the temporal information between adjacent tokens, which is more conducive to capturing the inherent local dependencies in sequence21. Meanwhile, ConvFormer utilizes the feed-forward neural network (FFN) in Transformer to enable more extensive information transfer and integration of sequential information, which, to some extent, compensates for the limitations of traditional CNNs in long-range relationship modeling and thus enhances the model’s global understanding ability.

In the intra-modal interaction based on ConvFormer, to assist agents in better and more accurately understanding unimodal information of states and actions, this paper separately processes states and actions by feeding them into the ConvFormer architecture. Focusing on single modalities ensures that each modality’s features can be fully extracted and comprehended, thereby enhancing the model’s representation capability for each modality. The ConvFormer consists of convolutional and feed-forward layers, as illustrated in Fig. 3. Taking the state S as an example, \(S = \left\{ {s_{1} ,s_{2} \ldots s_{N} } \right\}\) is first fed into the convolutional layer, which mainly comprises a normalization layer and a convolution block. The normalization layer standardizes S to obtain the state sequence information \(X = \left\{ {x_{1} ,x_{2} \ldots x_{N} } \right\}\). The convolution block performs 1D convolution operations on \(X\) to capture the inherent local correlations within the state sequence. Precisely, the convolution block consists of two layers of convolutions, with each layer comprising a 1D convolution (1D Conv) and a ReLU activation function. The 1D convolution employs a filter \(W = \left\{ {w_{1} ,w_{2} \ldots w_{L} } \right\}\) of length L to perform a one-dimensional convolution operation on the input, producing the sequence \(Y = \left\{ {y_{1} ,y_{2} \ldots y_{N} } \right\}\), as detailed in Eq. (9).

$$y_{i} = \sum\limits_{j = 0}^{L - 1} {w_{j} x_{i + j} }$$
(9)

where \(i = 1,2, \ldots ,N\) indexes the positions in the sequence, with N the length of the sequence S. The sequence X passes through the convolutional layer and a residual connection, ultimately obtaining the result \(Z^{1}\). Finally, \(Z^{1}\) passes through a feed-forward layer to obtain an accurate and rich state representation \(H^{S}\). The feed-forward layer consists of a normalization layer (Norm) and a feed-forward neural network (Feed Forward), as shown in Eq. (10).

$$H^{S} = FFN\left( {LN\left( {Z^{1} } \right)} \right) + Z^{1}$$
(10)

where FFN() represents the feed-forward neural network and LN() represents the normalization layer. Meanwhile, action A is input into the intra-modal interaction module based on ConvFormer, following the same process as S, to precisely capture its local information and obtain a more refined representation of the action \(H^{A}\).
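Putting the pieces together, one ConvFormer block can be sketched as follows. This is an illustrative NumPy version using a per-channel (depthwise) filter and causal left padding, design choices the paper does not specify:

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-5):
    """Normalize each token over its feature dimension."""
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def depthwise_conv(x, w):
    """Causal 1D convolution per feature channel (cf. Eq. (9));
    x: (N, d) sequence, w: (L, d) filter; output keeps length N."""
    L = len(w)
    xp = np.pad(x, ((L - 1, 0), (0, 0)))  # left padding -> y_i sees only past tokens
    return np.stack([(xp[i:i + L] * w).sum(axis=0) for i in range(len(x))])

def convformer_block(S, w1, w2, W1, W2):
    """Norm -> two (conv + ReLU) layers -> residual, giving Z1;
    then H^S = FFN(LN(Z1)) + Z1 as in Eq. (10)."""
    X = layer_norm(S)
    Z1 = np.maximum(depthwise_conv(np.maximum(depthwise_conv(X, w1), 0.0), w2), 0.0) + S
    return np.maximum(layer_norm(Z1) @ W1, 0.0) @ W2 + Z1  # FFN(LN(Z1)) + Z1

N, d, L = 6, 4, 3
S = rng.normal(size=(N, d))
H_S = convformer_block(S,
                       rng.normal(size=(L, d)), rng.normal(size=(L, d)),
                       rng.normal(size=(d, 2 * d)), rng.normal(size=(2 * d, d)))
```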

Fig. 3 The internal structure of intra-modal interaction based on ConvFormer.

  (2) Inter-modal interaction based on a dual Transformer


In multimodal interaction, this paper adopts a dual Transformer architecture equipped with a causal masking attention mechanism, with the specific structure shown in Fig. 4. This architecture utilises two independent Transformer modules to better cope with complex reinforcement learning environments and achieve deep interaction between states, actions, and multimodal information within trajectories. This allows the model to capture the correlations between states and actions from multiple perspectives and levels, aiding the model in more comprehensively understanding the complex relationships between environmental states and actions. Furthermore, causal masking in multi-head cross-attention ensures that the model adheres to causality when processing time series, capturing the temporal dependencies and long-term trends of states and actions.

Fig. 4 The internal structure of inter-modal interaction based on dual Transformer.

In the inter-modal interaction based on a dual Transformer, firstly, sinusoidal positional encoding is applied to the state representation \(H^{S} = \left\{ {H_{1}^{S} , \ldots H_{N}^{S} } \right\}_{{\text{t}}}^{N}\) and action representation \(H^{A} = \left\{ {H_{1}^{A} , \ldots H_{N}^{A} } \right\}_{{\text{t}}}^{N}\), as shown in Eq. (11).

$$\begin{gathered} PE_{{\left( {pos,2i} \right)}} = \sin \left( {pos/10000^{2i/d} } \right) \hfill \\ PE_{{\left( {pos,2i + 1} \right)}} = \cos \left( {pos/10000^{2i/d} } \right) \hfill \\ \end{gathered}$$
(11)

where pos represents the position index, i represents the dimension index, and d represents the dimension of the embedding vector. Then, the representations of states and actions are separately input into the dual Transformer network, which primarily consists of two parallel Transformers. Each Transformer comprises a multi-head attention layer and a feed-forward neural network layer.
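The encoding of Eq. (11) can be implemented directly; the sketch below uses the standard base of 10000 from the original Transformer and assumes an even embedding dimension d:

```python
import numpy as np

def sinusoidal_pe(n_pos, d):
    """Sinusoidal positional encoding per Eq. (11): sin on even dimensions,
    cos on odd dimensions (even d assumed)."""
    pos = np.arange(n_pos)[:, None]           # position index
    i = np.arange(0, d, 2)[None, :]           # even dimension indices
    angle = pos / (10000.0 ** (i / d))
    pe = np.zeros((n_pos, d))
    pe[:, 0::2] = np.sin(angle)
    pe[:, 1::2] = np.cos(angle)
    return pe
```

In the MDI encoder this matrix would simply be added to \(H^{S}\) and \(H^{A}\) before the attention layers.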

Each multi-head attention layer has normalization layers and multi-head attention mechanisms. First, the normalization layers normalize the state representations \(H^{S} = \left\{ {H_{1}^{S} , \ldots H_{N}^{S} } \right\}_{{\text{t}}}^{N}\) and action representations \(H^{A} = \left\{ {H_{1}^{A} , \ldots H_{N}^{A} } \right\}_{{\text{t}}}^{N}\). Then, the state and action representations are interleaved into a common sequence \(\overline{H} = \{ H_{1}^{S} ,H_{1}^{A} \ldots H_{N}^{S} ,H_{N}^{A} \}_{{\text{t}}}^{2*N}\). In multi-head attention, cross-attention operations are performed between the states, actions, and this common sequence \(\overline{H}\). This promotes the states to focus on themselves while deeply capturing action information and simultaneously allows actions to mine state features. This bidirectional interaction breaks down modal barriers, facilitating the high-frequency flow of information between states and actions within the interleaved sequence. This achieves a fine-grained fusion of cross-modal features, preserving the unique information of each modality while more comprehensively capturing the potential correlations between states and actions. This significantly enhances the depth and richness of multimodal information interaction. Specifically, linear transformations are applied to \(H^{S}\) and \(H^{A}\) to obtain the query vectors \(Q^{S}\) and \(Q^{A}\), respectively. The common sequence \(\overline{H}\) is linearly transformed into key vectors K and value vectors V. This is shown in Eq. (12).

$$\begin{gathered} Q^{S} ,Q^{A} = f_{q} \left( {H^{S} } \right),f_{q} \left( {H^{A} } \right) \hfill \\ K,V = f_{k} \left( {\overline{H}} \right),f_{v} \left( {\overline{H}} \right) \hfill \\ \end{gathered}$$
(12)

where \(f_{q} ,f_{k} ,f_{v}\) denote the learnable linear transformations for queries, keys, and values, respectively. \(Q^{S}\) and \(Q^{A}\) are then fed into multi-head attention together with the shared K and V vectors for deep interaction, as defined in Eq. (13).

$$\begin{gathered} X_{S} = Multihead\left( {Q^{S} ,K,V} \right) \hfill \\ X_{A} = Multihead\left( {Q^{A} ,K,V} \right) \hfill \\ \end{gathered}$$
(13)

where Multihead denotes the multi-head attention calculation, and \(X_{S}\) and \(X_{A}\) represent the state and action representations obtained through multi-head cross-attention, respectively. During the process of calculating the attention values, a causal mask is intentionally introduced to ensure that the model does not rely on future information during generation, thereby maintaining the rationality and logic of the calculations. Additionally, to effectively avoid the vanishing gradient problem, a residual connection strategy is employed between the input and output of the multi-head cross-attention. Finally, these representations are sent to feed-forward neural network layers to produce higher-level representations of states and actions, denoted as \(\overline{H}^{S} = \left\{ {\overline{H}_{1}^{S} , \ldots \overline{H}_{N}^{S} } \right\}_{{\text{t}}}^{N}\) and \(\overline{H}^{A} = \left\{ {\overline{H}_{1}^{A} , \ldots \overline{H}_{N}^{A} } \right\}_{{\text{t}}}^{N}\), respectively. The structure of these feed-forward neural network layers is consistent with that used in the intra-modal interaction based on ConvFormer.
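The interleaving, causally masked cross-attention, and residual connections of Eqs. (12)–(13) can be sketched as follows. This is a minimal PyTorch illustration under stated assumptions, not the authors' implementation: the class name and layer sizes are hypothetical, and the causal mask is built so that the state at step t (interleaved index 2t) and the action at step t (index 2t+1) attend only to earlier or equal positions in \(\overline{H}\):

```python
import torch
import torch.nn as nn

class DualCrossAttention(nn.Module):
    """Sketch of one dual-Transformer cross-attention layer (hypothetical sizes)."""
    def __init__(self, d_model: int = 64, n_heads: int = 4):
        super().__init__()
        self.ln_s = nn.LayerNorm(d_model)
        self.ln_a = nn.LayerNorm(d_model)
        self.attn_s = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.attn_a = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, h_s: torch.Tensor, h_a: torch.Tensor):
        # h_s, h_a: (batch, N, d_model)
        b, n, d = h_s.shape
        s, a = self.ln_s(h_s), self.ln_a(h_a)
        # Interleave into the common sequence H_bar = {s1, a1, ..., sN, aN}
        h_bar = torch.stack([s, a], dim=2).reshape(b, 2 * n, d)
        # Causal masks (True = blocked): state t sits at index 2t, action t at 2t+1
        key_idx = torch.arange(2 * n)
        mask_s = key_idx[None, :] > (torch.arange(n) * 2)[:, None]
        mask_a = key_idx[None, :] > (torch.arange(n) * 2 + 1)[:, None]
        # Queries from each modality, keys/values from the interleaved sequence
        x_s, _ = self.attn_s(s, h_bar, h_bar, attn_mask=mask_s)
        x_a, _ = self.attn_a(a, h_bar, h_bar, attn_mask=mask_a)
        # Residual connections around the cross-attention, as described above
        return h_s + x_s, h_a + x_a
```

The outputs would then pass through the feed-forward layers to yield \(\overline{H}^{S}\) and \(\overline{H}^{A}\).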

Decoder

The decoder combines convolution and a multilayer perceptron (MLP) to efficiently handle modal interactions between advantage values and states, completing action prediction as illustrated in Fig. 5. Specifically, the convolutional module extracts single-modal information from advantage values AE, enhancing the model’s deep understanding of advantage information and strengthening its adaptability and response accuracy in complex decision-making environments. Meanwhile, the MLP performs action prediction based on the representations of advantage values and states. This structural design not only reduces computational complexity and lowers the cost of model training and inference but also maintains efficient feature transformation and prediction performance while ensuring modal interaction processing capabilities, successfully achieving a balance between complexity and effectiveness.

Fig. 5

The internal structure of the decoder.

First, the advantage value AE undergoes feature transformation via the embedding layer and incorporates positional encoding to generate \(A\overline{E}\) with positional information. Next, \(A\overline{E}\) is input into the convolution block, where normalization and convolution operations extract unimodal information—assisting the model in better understanding the advantage details. The convolution block adopts the same structure as the convolution block in ConvFormer-based intra-modal interaction, as shown in Eq. (14):

$$\begin{gathered} A\overline{E} = Embed\left( {AE} \right) + PE\left( {AE} \right), \hfill \\ \widehat{AE} = Conv\left( {LN\left( {A\overline{E}} \right)} \right) + A\overline{E} \hfill \\ \end{gathered}$$
(14)

where Embed represents the embedding layer, LN the layer normalization, and Conv the convolutional block. The state representations \(\overline{H}^{S}\) and advantage representations \(\widehat{AE}\) are then integrated into a single sequence and input into the MLP. The MLP captures the correlation between the advantage values and the high-level state representations obtained from the encoder, making fuller use of the multimodal information to predict the action \(\widehat{A}\), as shown in Eq. (15).

$$\widehat{A} = MLP\left( {\widehat{AE}_{1} ,\overline{H}_{1}^{S} , \ldots ,\widehat{AE}_{N} ,\overline{H}_{N}^{S} } \right)$$
(15)
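The decoder computation of Eqs. (14)–(15) can be sketched as follows. This is a minimal PyTorch illustration with illustrative sizes; per-time-step concatenation of the advantage and state streams is one reasonable reading of the single fused sequence fed to the MLP, not necessarily the authors' exact layout:

```python
import torch
import torch.nn as nn

class DecoderSketch(nn.Module):
    """Conv block over the advantage stream + MLP action head (hypothetical sizes)."""
    def __init__(self, d_model: int = 64, act_dim: int = 6, kernel: int = 3):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        # 1-D convolution along the time axis, as in the ConvFormer conv block
        self.conv = nn.Conv1d(d_model, d_model, kernel, padding=kernel // 2)
        self.mlp = nn.Sequential(
            nn.Linear(2 * d_model, 2 * d_model), nn.GELU(),
            nn.Linear(2 * d_model, act_dim),
        )

    def forward(self, ae_emb: torch.Tensor, h_s: torch.Tensor) -> torch.Tensor:
        # ae_emb: embedded advantage values with positional encoding, (batch, N, d)
        # Eq. (14): conv over the normalized stream, plus a residual connection
        x = self.conv(self.ln(ae_emb).transpose(1, 2)).transpose(1, 2) + ae_emb
        # Eq. (15): fuse advantage and state per step, then predict the action
        fused = torch.cat([x, h_s], dim=-1)   # (batch, N, 2*d)
        return self.mlp(fused)                # (batch, N, act_dim)
```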

CGM

The comprehensive offline reinforcement learning algorithm proposed in this paper, named CGM, integrates the advantage estimation module and the MDI module and is outlined in Algorithm 1. The algorithm comprises two primary stages: training and testing.

The algorithm’s training phase is structured into two distinct stages: advantage estimation and MDI. During the advantage estimation stage, the training initiates with gradient updates to optimize the value network. Subsequently, the refined value network estimates advantage values derived from state values, which are then applied to re-label the dataset. Following this, the predictor is trained utilizing the advantage values that have been learned from the value network.
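For reference, generalized advantage estimation over a single trajectory can be sketched as follows. This is the standard GAE recursion, not the paper's exact implementation; the discount `gamma` and smoothing parameter `lam` default values are illustrative:

```python
import numpy as np

def gae(rewards, values, gamma: float = 0.99, lam: float = 0.95) -> np.ndarray:
    """Generalized advantage estimation for one trajectory.

    `values` must contain len(rewards) + 1 entries (a bootstrap value
    for the final state is appended at the end).
    """
    adv = np.zeros(len(rewards))
    running = 0.0
    # Sweep backwards, accumulating discounted TD residuals
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        adv[t] = running
    return adv
```

The resulting advantage values would then be used to re-label the offline dataset.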

Shifting to the MDI stage, the MDI module encoder undergoes an initial pre-training phase, incorporating the forward dynamic prediction, inverse dynamic prediction, and random masking review control techniques proposed by Sun et al.29. This pre-training serves as the foundation for the subsequent training of the decoder, which aims to minimize the action reconstruction loss as defined in Eq. (16).

$$L = E_{{\tau^{\prime} \sim D^{AE} }} \left[ {\sum\limits_{t} {\left( {a_{t} - \widehat{A}_{t} } \right)^{2} } } \right]$$
(16)

where \(E_{{\tau^{\prime} \sim D^{AE} }}\) denotes the expectation over trajectories \({\tau^{\prime}}\) sampled from the dataset \(D^{AE}\), \(a_{t}\) is the actual action at time step t, and \(\widehat{A}_{t} = MDI_{\theta } \left( {h,s_{t} ,AE_{t} } \right)\) is the action predicted at time step t by the modality decomposition interaction module \(MDI_{\theta }\). Here, h denotes the historical information, i.e., the sequence of all states s and actions \(a\) from the initial time step up to time step t−1, while \(s_{t}\) and \(AE_{t}\) denote the state and the advantage value at time step t, respectively.
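The reconstruction loss of Eq. (16) reduces to a squared error between predicted and ground-truth actions. A minimal PyTorch sketch, assuming the expectation is approximated by averaging over the sampled batch and time steps (summing squared errors over the action dimensions):

```python
import torch

def reconstruction_loss(pred_actions: torch.Tensor,
                        true_actions: torch.Tensor) -> torch.Tensor:
    """Squared action reconstruction error.

    Both tensors have shape (batch, T, act_dim); squared errors are summed
    over action dimensions and averaged over batch and time.
    """
    return ((pred_actions - true_actions) ** 2).sum(dim=-1).mean()
```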

During the testing phase, the predictor is utilized to estimate the current state’s advantage. Subsequently, this advantage estimate, along with historical information and the current state, is fed into the model to predict the appropriate action. The predicted action is then executed in a simulation, and the simulated environment changes the state based on this predicted action. The updated state is then re-evaluated for its advantage, and this process continues in a loop until a termination criterion is met.
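The testing loop described above can be sketched as follows. Here `env`, `predictor`, and `model` stand for the simulated environment, the trained advantage predictor, and the MDI model; their interfaces in this sketch are hypothetical:

```python
def evaluate(env, predictor, model, max_steps: int = 1000) -> float:
    """Roll out one episode: estimate the advantage of the current state,
    predict an action from history + state + advantage, and step the env."""
    state = env.reset()
    history, total_reward = [], 0.0
    for _ in range(max_steps):
        adv = predictor(state)               # advantage estimate for s_t
        action = model(history, state, adv)  # action prediction by the model
        next_state, reward, done = env.step(action)
        history.append((state, action))      # extend the conditioning history
        total_reward += reward
        state = next_state                   # environment updates the state
        if done:                             # loop until termination
            break
    return total_reward
```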

Algorithm 1

CGM

Experiments

We conducted extensive experiments to evaluate offline reinforcement learning that combines generalized advantage estimation with MDI. The primary objectives were to test our method on the D4RL offline reinforcement learning benchmark and to compare it with state-of-the-art offline reinforcement learning approaches. Additionally, we set up two ablation experiments to assess the impact of our method’s key components on overall performance.

Experimental parameter settings

Datasets: This paper primarily evaluates the performance of the proposed method on three continuous control tasks in the MuJoCo simulation environment under D4RL, namely hopper, halfcheetah, and walker2d, as well as on the AntMaze environment task. For each MuJoCo task, we used three different levels of historical datasets: (1) the medium dataset, primarily composed of states and actions collected by the agent itself in previous experiments; (2) the medium-expert dataset, mainly consisting of experienced expert-level data, including high-quality states and actions from human experts; (3) the medium-replay dataset, primarily composed of historical data recorded by the agent while performing tasks in the simulation environment. For the AntMaze environment, we adopted two datasets: (1) umaze, which features a more basic environment setting with fewer obstacles and simple path planning requirements; (2) medium, which has a more complex environment setting with more obstacles and intricate path planning needs, increasing the difficulty of the task.

Comparison methods: We compared our approach against eight state-of-the-art offline reinforcement learning methods: (1) Behavioral Cloning (BC), (2) Offline Reinforcement Learning with Implicit Q-Learning (IQL)8, (3) Conservative Q-Learning for Offline Reinforcement Learning (CQL)30, (4) Decision Transformer (DT)10, (5) Trajectory Transformer (TT)31, (6) Decision Transducer (DTd)23, (7) RvS: What is Essential for Offline RL via Supervised Learning32, and (8) ACT: Empowering Decision Transformer with Dynamic Programming via Advantage Conditioning20. Our proposed method is denoted as CGM.

Model parameter settings: In the experiments, the parameter configuration of the advantage estimation module was consistent with that reported in Reference20. The specific architecture of the modality decomposition interaction module is detailed in Table 2, and these parameters were kept consistent across all tasks, enabling parameter sharing. Additionally, the hardware platform used in the experiments was equipped with an Intel(R) Xeon(R) Platinum 8352 V CPU and 16 GB of RAM, along with a vGPU-32 GB version for GPU. On the software side, the experiments were conducted on a 64-bit Linux operating system, with the PyTorch deep learning framework serving as the foundation for model implementation.

Table 2 The specific parameters of the modality decomposition interaction module.

During the specific model training process, a batch size of 64 was adopted, and the context length was set to 20. A dropout rate of 0.1 was introduced to effectively prevent model overfitting. The learning rate was carefully set to 0.0001, and to stabilize gradient changes during the initial training phase, 10,000 learning rate warmup steps were implemented. Furthermore, the gradient norm clip value was set to 0.25, and the weight decay parameter was set to 0.0001 to ensure the stability and convergence of model training. The beta parameters of the Adam optimizer were configured as [0.9, 0.999] to optimize the model training process. The entire training process was conducted for a total of 300 K or 500 K steps, ensuring that the model could fully learn the features and information in the data.
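The stated optimizer settings map directly onto a PyTorch setup. The following sketch uses a stand-in module in place of the CGM network and assumes a simple linear warmup schedule; the scheduler choice is illustrative, not confirmed by the paper:

```python
import torch

model = torch.nn.Linear(8, 8)   # stand-in for the CGM network
optim = torch.optim.Adam(model.parameters(), lr=1e-4,
                         betas=(0.9, 0.999), weight_decay=1e-4)

# Linear warmup over the first 10,000 steps, then a constant learning rate
warmup_steps = 10_000
sched = torch.optim.lr_scheduler.LambdaLR(
    optim, lambda step: min((step + 1) / warmup_steps, 1.0))

# One illustrative training step: clip gradients to norm 0.25, then update
loss = model(torch.randn(64, 8)).pow(2).mean()   # batch size 64
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
optim.step()
sched.step()
optim.zero_grad()
```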

Comparison experiments

In this experiment, we focused on conducting comparative experiments in the MuJoCo and AntMaze domains using the benchmarks of the D4RL dataset. We selected 11 tasks to comprehensively evaluate our proposed CGM method. The core evaluation metric was the normalized scores for MuJoCo and AntMaze tasks, which serve as an objective standard for measuring the performance of various offline reinforcement learning methods.
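The normalized score used here linearly rescales the raw return so that 0 corresponds to a random policy and 100 to an expert policy; the per-environment reference scores ship with the D4RL benchmark. A minimal sketch of the formula:

```python
def d4rl_normalized_score(score: float,
                          random_score: float,
                          expert_score: float) -> float:
    """D4RL normalization: 0 = random-policy return, 100 = expert return."""
    return 100.0 * (score - random_score) / (expert_score - random_score)
```

In practice, D4RL exposes this computation directly via each environment's `get_normalized_score` method, with the reference returns filled in per task.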

By comparing with eight other advanced offline reinforcement learning methods, as shown in Table 3, our method demonstrated significant advantages. Specifically, it stood out in seven key tasks such as halfcheetah-medium and walker2d-medium, achieving the highest normalized scores and confirming its superior performance on offline reinforcement learning tasks. Although our method did not achieve the optimal values in four tasks (hopper-medium, hopper-medium-replay, walker2d-medium-expert, and walker2d-medium-replay), it still obtained the highest average score across the MuJoCo and AntMaze domains, further highlighting its stability and overall advantage across various offline tasks.

Table 3 Comparison of experiment results on the MuJoCo and AntMaze motion benchmark.

The success of this method can be attributed to its unique multimodal information processing mechanism, which strongly supports the agent’s decision-making optimization by deeply exploring the intra-modal feature correlations and inter-modal complementarity of the trajectory data. The excellent performance in tasks such as halfcheetah-medium is a direct manifestation of the full effect of this mechanism. Moreover, CGM exhibits outstanding task adaptability, maintaining stable high performance across multiple tasks, which further validates the universality and effectiveness of the multimodal information processing mechanism. However, in tasks such as hopper-medium, hopper-medium-replay, walker2d-medium-expert, and walker2d-medium-replay, the proposed algorithm has not yet reached optimal performance, indicating that there is still room for further improvement in adaptability. Furthermore, the current method has only been evaluated on the D4RL dataset and has not yet been validated in real-world application scenarios. To carry its performance over from simulated environments to real-world scenarios, we will need to test and optimize the algorithm on more real-world tasks in the future to ensure its effectiveness and reliability in practical applications.

To more comprehensively evaluate the computational efficiency of our proposed method, we selected ACT, which shares the most structural similarity with our method, and the traditional DT from Transformer-based reinforcement learning as our benchmarks for comparison. On the D4RL dataset, we calculated and recorded the average runtime for these two methods, as well as our proposed method, and presented the comparison results in Table 4. It is noteworthy that, despite the increased model complexity of our method compared to DT, we have successfully achieved optimization in terms of computational efficiency relative to ACT. This achievement not only demonstrates the computational performance advantages of our method but also further validates its effectiveness in handling complex tasks.

Table 4 Comparison of average runtime on D4RL dataset.

Ablation experiments

To thoroughly validate the effectiveness of our approach in offline tasks, we conducted two sets of ablation experiments on the MuJoCo motion benchmark. The objectives of these experiments are to elucidate the significance of advantage estimation and to analyze the impact of the structural design choices incorporated within our methodology.

Experiment I: This experiment investigates the ablation effects of the advantage estimation module introduced in our research. To conclusively validate whether incorporating higher advantage values can significantly enhance the precision and efficiency of trajectory stitching, we specifically designed an ablation experiment framework named “No-AE” (No Advantage Estimation) and applied it to three representative reinforcement learning tasks: hopper, halfcheetah, and walker2d. Under the “No-AE” setting, the utilization of advantage values was completely excluded during training. Instead, the training relied solely on the Return-To-Go (RTG) values tailored for each task as feedback signals. The initial configurations of these RTG values are detailed in Table 5. The ablation study results of Experiment 1 are summarized in Table 6. Across all experiments, our proposed method achieved the best results, underscoring the effectiveness of leveraging advantage values to enhance the overall performance of the model.

Table 5 Initial RTG values set in ablation experiment I on the MuJoCo tasks.
Table 6 Results of ablation experiment I: average normalized scores for hopper, halfcheetah, and walker2d tasks across three dataset levels.

Furthermore, we plotted the evaluation of the nine tasks over 300 K steps, as shown in Fig. 6. Without advantage values, the model’s performance decreases slightly compared to when they are utilized, indicating that an offline dataset containing only suboptimal trajectories impairs the model’s trajectory stitching ability. By employing advantage values, the model can evaluate the value of each state-action pair and determine its superiority or inferiority relative to other choices. This is particularly crucial for stitching discontinuous trajectories because, in the absence of optimal trajectories, the model may struggle to make correct decisions during the stitching process. By utilizing advantage values, however, the model can better assess the value differences among state-action pairs, enabling more accurate trajectory stitching. The experimental results therefore indicate that employing advantage values enhances the model’s trajectory stitching ability on offline datasets and improves the accuracy of agent predictions, significantly strengthening the model’s performance and decision-making capabilities in practical applications.

Fig. 6

Evaluation curves of ablation experiment I on the nine tasks.

Experiment II: This paper conducts ablation experiments and analysis to assess the effectiveness of crucial components within the proposed encoder and decoder methods. In this experiment, six ablation studies are conducted on the structure of MDI:

  1. (a)

    Replacing the intra-modal interaction based on ConvFormer in the encoder with a convolutional module;

  2. (b)

    Substituting the intra-modal interaction module based on ConvFormer in the encoder with a Decision Transformer;

  3. (c)

    Replacing the intra-modal interaction based on Dual Transformer with a Decision Transformer;

  4. (d)

    Removing the intra-modal interaction module based on ConvFormer from the encoder and inputting states and actions directly into the intra-modal interaction based on Dual Transformer;

  5. (e)

    Eliminating the inter-modal interaction module based on Dual Transformer in the encoder;

  6. (f)

    Replacing the MLP layer in the decoder with a Decision Transformer.


In terms of performance, the line chart of the ablation results in Fig. 7 shows that the proposed method slightly underperforms the ablation setup that replaces the MLP layer in the decoder with a Decision Transformer only on the walker2d-medium and walker2d-medium-replay datasets; across the remaining seven datasets, it exhibits superior performance. Table 7 reveals that, among the hopper, halfcheetah, and walker2d tasks, the proposed method achieves the best average performance on hopper and halfcheetah, falling behind ablation (f), which replaces the MLP layer in the decoder with a Decision Transformer, only on walker2d. However, comparing the average runtime on the hopper, halfcheetah, and walker2d tasks shows that the computational efficiency of ablation (f) is lower than that of the proposed method. This demonstrates the effectiveness of the proposed method in information extraction, decision optimization, and operational efficiency. Additionally, the ablations on the ConvFormer-based intra-modal interaction and the Dual-Transformer-based inter-modal interaction highlight the crucial role of our approach in capturing multimodal information interaction: information from different modalities is complementary and correlated, and comprehensively exploiting it provides a more complete and accurate basis for decision-making. The optimal results achieved by the proposed method across various tasks further prove its ability to fully extract trajectory information and effectively optimize the agent’s decision-making capabilities.

Fig. 7

Evaluation curves of ablation experiment II on the nine tasks.

Table 7 The results of Ablation Experiment showcase the average normalized scores achieved by the halfcheetah, hopper and walker2d tasks, analyzed across three distinct levels of datasets.

Conclusion and future work

This paper proposes an offline reinforcement learning approach that combines an advantage estimation module with MDI to address two deficiencies of existing Transformer-based offline reinforcement learning methods: limited trajectory stitching capability and the neglect of intra-modal and inter-modal interactions among multimodal information. Experimental results demonstrate the significant superiority of our approach over other offline reinforcement learning methods on the D4RL benchmark, thereby proving its effectiveness in offline tasks.

The method proposed in this paper has the following characteristics:

  1. (1)

    It mainly comprises an advantage estimation and MDI modules to improve trajectory stitching capability and adequately extract the interactions between multimodal information in trajectories.

  2. (2)

    In the advantage estimation module, generalized advantage estimation estimates advantage values for states, and the dataset is re-labeled using these advantage values to enhance the model’s trajectory stitching capability.

  3. (3)

    The MDI module consists of an encoder and a decoder. In the encoder, ConvFormer-based intra-modal interaction and dual-Transformer-based inter-modal interaction are employed to fully extract the intra-modal and inter-modal relationships between states and actions, thoroughly learning the high-level representations of states and actions. The decoder further fuses the multi-modal information of advantage values and states to complete action prediction.

Currently, all experiments on the relevant datasets have been conducted in the D4RL simulation environment. Although the proposed method effectively optimizes model performance in this simulated setting, it still has certain limitations. In future work, we will apply the method to more real-world scenarios to test its effectiveness and adaptability in the real world.