Introduction

Trajectory prediction is being applied increasingly widely in the field of autonomous driving. It aims to predict the future driving trajectory of a traffic agent based on its historical behavior and the surrounding environmental information. Improving the accuracy of trajectory prediction is of great significance for assisting autonomous driving decision-making and enhancing traffic efficiency. However, predicting traffic agent trajectories remains challenging due to the intricate spatio-temporal features inherent in trajectory information and the complex nature of the road environment [1].

Early trajectory prediction methods were primarily physics-based [2,3,4,5,6]. They utilized vehicle dynamics or kinematics models for prediction but suffered from lower prediction accuracy because they did not consider road-related factors. Subsequently, researchers turned to machine learning models [7,8,9,10] for traffic agent trajectory prediction, which provided new insights into the problem. However, machine learning methods often require pre-provided or hand-identified features and struggle to capture abstract and implicit cross-modal features. To effectively capture the rich spatio-temporal relationships inherent in traffic trajectories, most contemporary trajectory prediction methods are based on deep learning models. These methods can simultaneously consider interaction factors between traffic agents as well as between traffic agents and roads.

In the spatial dimension, early researchers utilized convolutional neural networks (CNN) to understand the spatial relationships in scenes [11, 12] and achieved notable success in extracting Euclidean spatial features. However, the spatial relationships among traffic agents are non-Euclidean. With the development of graph neural networks (GNN), researchers attempted to model the spatial interaction relationships among traffic agents as a graph [13,14,15,16] and used GNNs to learn their dependencies. Li et al. [15] captured the local spatial information between traffic agents by constructing a distance adjacency matrix and a spatial connectivity matrix and processing them with graph convolutional networks (GCN). Zhang et al. [16] employed graph attention networks (GAT) to focus on the most noteworthy interactions from surrounding traffic agents. GAT introduces an attention mechanism that allows a target traffic agent to assign different attention weights to its neighbors, but this also adds complexity and computational cost to the model. In contrast, GCN typically has a simpler structure and computational process, making it more efficient on large-scale traffic graph data. Whether employing GCN or GAT, both methods enable traffic agents to learn local spatial information from their neighboring traffic agents, but neither covers the learning of the global spatial features of traffic agents. The development of attention mechanisms has provided new momentum for trajectory prediction: researchers have used attention mechanisms to adaptively learn the global spatial features among traffic agents. References [17, 18] incorporated attention mechanisms to consider both local and global spatial features simultaneously, allowing traffic agents to adaptively learn the spatial features of all surrounding traffic agents at each time step. However, these methods do not explore the full-process spatial features over the entire historical time sequence and fail to effectively fuse spatial features across different dimensions.

Fig. 1 Schematic diagram of the local, global, and full-process spatial modeling. The local and global spatial modeling represent learning the spatial relationships between traffic agents at time t, while the full-process spatial modeling represents first fusing the historical trajectory features of traffic agents through LSTM and then learning the spatial relationships between the resulting agent features

In the temporal dimension, early researchers employed recurrent neural networks (RNN) for trajectory prediction. Alahi et al. [19] designed a social pooling mechanism to capture pedestrian neighbor information and used LSTM for trajectory prediction. Although RNNs can capture short-term temporal dependencies in trajectories, they struggle to effectively capture long-term temporal dependencies in trajectory sequences. The introduction of temporal convolutional networks (TCN) has proved beneficial in addressing this limitation: TCN utilizes stacked dilated causal convolutions to better comprehend the long-term temporal dependencies present in traffic agent trajectory sequences. Tang et al. [20] set different dilation rates for the dilated causal convolutions in TCN's odd and even layers to explore the long-term temporal dependencies of trajectory sequences, reducing the time cost of trajectory prediction and improving prediction accuracy. Later, transformers, with the aid of positional encoding and the self-attention mechanism, were shown to capture long-term dependencies in sequences without needing convolutional layers to gradually expand the receptive field, achieving significant success in learning long-term temporal dependencies. Chen et al. [17] utilized transformers to uncover the long-term temporal dependencies in historical trajectories of traffic agents, significantly improving trajectory prediction accuracy. In [21], LSTM's sequential step-by-step model was compared with attention-only models, including the transformer (TF) and the larger bidirectional transformer (BERT), for predicting the future trajectories of pedestrians. The results indicate that TF-based models exhibit superior performance, particularly in long-term prediction. These findings highlight the effectiveness of TF-based models for capturing both long- and short-term temporal dependencies in traffic trajectories across various temporal dimensions. However, it is worth noting that these methods do not integrate temporal features from different dimensions.

To comprehensively explore and integrate the multi-dimensional spatio-temporal features within trajectory sequences, this paper proposes a multi-dimensional spatio-temporal feature fusion trajectory prediction model (MDSTF). The proposed model addresses the trajectory prediction problem of traffic agents in mixed traffic scenarios. In the spatial dimension, MDSTF constructs a distance-based adjacency matrix and a heading-angle-based adjacency matrix and applies GCN to them, which enables the target traffic agent to learn spatial information from its neighboring traffic agents and facilitates the local spatial modeling. However, GCN cannot effectively capture the global spatial features of distant traffic agents. To address this, spatial attention is employed to better capture the global spatial features of traffic agents, contributing to the global spatial modeling. In addition, given the complex spatial relationships of traffic agents throughout the historical time series, the historical trajectory sequences are fed into an LSTM, and the hidden state output by the LSTM at the last time step serves as the feature representation of each traffic agent over the entire historical time series. This complex spatial relationship is further learned through spatial attention, forming the full-process spatial modeling. The schematic diagrams of the local, global, and full-process spatial modeling are shown in Fig. 1. In the temporal dimension, MDSTF adopts the Transformer encoder to model long-term temporal dependencies between different time steps using positional encoding and attention mechanisms. This effectively resolves the ambiguity of trajectory transitions in some extreme traffic scenarios. Simultaneously, considering that the behavior of traffic agents is largely influenced by short-term behavior, we fuse the short-term temporal dependencies obtained from the output of all LSTM time steps in LSAN with the long-term temporal dependencies captured by the Transformer. Considering the temporal dependence of the output future trajectory, MDSTF utilizes TCN to generate future trajectories. The main contributions of this paper are summarized as follows:

  • We present the MDSTF, which effectively captures the spatio-temporal dependencies of different dimensions and improves the prediction accuracy of future trajectories for traffic agents.

  • We extract and fuse multi-dimensional spatio-temporal features. In the spatial dimension, we model the local, global, and full-process spatial features of traffic agent trajectory sequences through the entire historical time series. In the temporal dimension, we also model the long-term and short-term temporal dependencies of traffic agent trajectory sequences.

  • We conduct experiments on the complex mixed traffic flow dataset, Apolloscape, and the results demonstrate the superiority of MDSTF in predicting long-term trajectories in extreme traffic scenarios or mixed traffic scenarios.

The remaining sections of this paper are organized as follows: Sect.  “Related work” provides an overview of related work in trajectory prediction. Section “Problem formulation” defines the problem of trajectory prediction. Section “Methodology” delves into the details of our MDSTF. Section “Experiment” focuses on evaluating the performance of the MDSTF model. Finally, Sect. “Conclusion” concludes the paper.

Related work

This section provides an overview of three major types of trajectory prediction methods for autonomous driving that have been proposed in the past two decades. These encompass various methodologies, including physics-based methods, machine learning-based approaches, and the currently prevalent deep learning methods. These methods represent the development and evolution of autonomous driving technology in the field of trajectory prediction, showcasing diverse theoretical foundations and practical application outcomes.

Trajectory prediction based on physics

The physics-based method involves utilizing physical models, such as vehicle dynamics or kinematics models, to predict the trajectory of traffic agents. Kinematics models, including constant velocity or constant acceleration models, are commonly employed due to their relatively simple structure. Physics-based methods typically encompass single-trajectory methods, Kalman filtering methods, and Monte Carlo methods, among others [2,3,4,5,6]. Single-trajectory methods directly apply the current state of the traffic agent to the vehicle's dynamics or kinematics model, making them simple and suitable for scenarios where the trajectory information is noise-free. Lytrivis et al. [2] and Miller et al. [3] used the constant turn rate and acceleration model and the constant-acceleration model, respectively, to predict future trajectories. Single-trajectory methods assume that the state of the vehicle is fully known and noise-free. In contrast, the Kalman filter (KF) method is capable of handling noise. It models the uncertainty or noise in the current vehicle state and its physical model using a Gaussian distribution, taking into account the uncertainty in trajectory prediction. Zhang et al. [4] proposed an approach based on vehicle-to-vehicle communication and Kalman filtering, enabling the target vehicle to avoid obstacles by predicting the trajectories of distant vehicles. Lefkopoulos et al. [5] introduced the interacting multiple model Kalman filter (IMM-KF) method, which fully considered the interactive factors among traffic agents and used physical models to predict the trajectories of traffic participants over the next few seconds. Unlike single-trajectory and Kalman filter methods, the Monte Carlo method can approximate the state distribution of traffic participants. It randomly samples the state information of traffic participants and applies physical models to generate potential future trajectories. A representative work is that of Wang et al. [6], who used the Monte Carlo method to predict trajectories and optimized the reference trajectory using model predictive control (MPC).
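To make the single-trajectory and Kalman-filter ideas above concrete, the following is a minimal sketch of a constant-velocity Kalman filter that filters a noisy observed history and then rolls the physical model forward; all matrices, noise levels, and observations here are illustrative assumptions of ours, not values from the cited works.

```python
import numpy as np

# Constant-velocity model: state [x, y, vx, vy], sampling interval dt.
dt = 0.5                                 # e.g., 2 Hz sampling
F = np.array([[1, 0, dt, 0],             # state transition matrix
              [0, 1, 0, dt],
              [0, 0, 1,  0],
              [0, 0, 0,  1]])
H = np.eye(2, 4)                         # we observe only (x, y)
Q = 0.01 * np.eye(4)                     # process noise covariance (assumed)
R = 0.10 * np.eye(2)                     # measurement noise covariance (assumed)

x = np.array([0.0, 0.0, 5.0, 1.0])       # initial state estimate
P = np.eye(4)                            # initial state covariance

def kf_step(x, P, z):
    """One predict-update cycle given an observed position z."""
    x_pred, P_pred = F @ x, F @ P @ F.T + Q
    S = H @ P_pred @ H.T + R
    K = P_pred @ H.T @ np.linalg.inv(S)  # Kalman gain
    x_new = x_pred + K @ (z - H @ x_pred)
    P_new = (np.eye(4) - K @ H) @ P_pred
    return x_new, P_new

# Filter a short observed history (illustrative noisy positions):
for z in [np.array([2.4, 0.6]), np.array([5.1, 0.9])]:
    x, P = kf_step(x, P, z)

# Prediction then applies the physical model alone:
future = []
for _ in range(6):                       # 6 steps = 3 s at 2 Hz
    x = F @ x
    future.append(x[:2].copy())
```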

Physics-based methods have advantages in terms of high computational efficiency and simplicity. However, they lack consideration for factors associated with the road, resulting in lower accuracy and larger errors in trajectory prediction. These methods are suitable for simple road scenarios and short-term predictions.

Trajectory prediction based on machine learning

Machine learning-based trajectory prediction methods commonly leverage algorithms such as the support vector machine (SVM), hidden Markov model (HMM), and dynamic Bayesian network (DBN) to forecast future trajectories. SVM works by transforming the input driving-behavior data into a high-dimensional space and performing linear classification in this space to identify driving behavior categories (such as going straight, turning left, or turning right), thereby predicting future driving trajectories. Mandalia et al. [7] achieved good prediction results by utilizing features such as steering wheel angle, position, and acceleration, and applying SVM to identify driving actions. However, SVM requires predefining the driver's actions, which may impact the final trajectory prediction results. HMM constructs a Markov chain by assuming that the state at the current moment depends only on the state of the previous moment. In trajectory prediction, HMM utilizes the historical states of traffic participants as the observed sequence to solve for the most likely future observation sequence. To predict driver control actions, Berndt et al. [8] used steering wheel angle and global coordinates as inputs to an HMM. Qiao et al. [9] proposed an HMM-based algorithm named HMTP*, which adaptively selects parameters to simulate real-world scenarios with dynamically changing velocities. However, HMM fails to consider interactions among traffic agents. DBN models the interaction between vehicle states and traffic participants using Bayesian networks. This approach combines the historical sequence information of traffic participants and employs inference and learning methods to predict future trajectories. Schreier et al. [10] employed DBN to identify driving actions and predicted vehicle trajectories using the kinematic model corresponding to each action. However, most DBN-based methods can only distinguish a limited number of actions, such as straight driving or lane changing, and their generalization capability is limited. A simplified sketch of the SVM-based pipeline is given below.
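The sketch below illustrates the SVM pipeline with scikit-learn; the features, labels, and the idea of rolling out a maneuver-conditioned kinematics model afterwards are our own illustrative assumptions, not details of [7].

```python
import numpy as np
from sklearn.svm import SVC

# Illustrative maneuver classification: features are
# [steering wheel angle, lateral offset, acceleration].
rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 3))          # placeholder feature vectors
y_train = rng.integers(0, 3, size=200)       # 0=straight, 1=left, 2=right

clf = SVC(kernel="rbf")                      # the RBF kernel maps features
clf.fit(X_train, y_train)                    # into a high-dimensional space

maneuver = clf.predict([[0.3, 0.1, -0.2]])   # classify a new behavior sample
# A kinematics model conditioned on `maneuver` would then generate
# the corresponding future trajectory.
```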

Machine learning-based trajectory prediction methods primarily rely on manual feature extraction, which is not only time-consuming and labor-intensive but also challenging to capture abstract and implicit features.

Trajectory prediction based on deep learning

Deep learning-based trajectory prediction methods not only effectively learn the relevant information of roads and traffic agents themselves but also fully utilize the interaction information among traffic agents to adapt to complex traffic scenarios. Most deep learning models for trajectory prediction primarily model the spatio-temporal features of traffic agent trajectories, uncovering underlying spatio-temporal dependencies.

In the spatial dimension, early researchers often employed convolutional neural networks (CNN) to model spatial relationships. They treated traffic scenarios as bird's-eye views and used them as inputs to CNNs, applying the networks to understand spatial dependencies within the scenes. For instance, Djuric et al. [11] encoded the surrounding environment of vehicles as rasterized images and used them as inputs to a CNN for predicting short-term vehicle trajectories. However, the method tended to produce relatively uniform, single-path trajectory predictions. Cui et al. [12] considered the uncertainty of traffic behavior and the potential situations vehicles may encounter on the road, and proposed a multimodal trajectory prediction method for autonomous driving based on deep convolutional networks. Experimental results demonstrated that multimodal prediction outperformed single-modal prediction, especially for longer prediction horizons. While CNN has achieved great success in extracting Euclidean spatial features, traffic trajectory data mostly exist in non-Euclidean spaces. With the development of graph neural networks (GNN), researchers began to view traffic agents in the scene as nodes in a graph, where the interaction relationships between traffic agents are regarded as edges between graph nodes. GNNs were then employed to learn the dependencies between traffic agents. The most commonly used GNNs are graph convolutional networks (GCN) and graph attention networks (GAT). GCN first represents each node in the graph structure as a vector and then combines these vectors through convolution operations to generate new features. Sheng et al. [13] represented the traffic scene as a graph and established a distance-based adjacency matrix. In this matrix, the element values are the reciprocals of the distances between node pairs, so that a shorter distance between two traffic agents corresponds to a stronger relationship. Subsequently, they employed a GCN to aggregate information from adjacent nodes. However, relying solely on a distance-based adjacency matrix may not fully capture the complex spatial relationships within traffic scenes. Xu et al. [14] addressed this issue by building a multi-view logical network, including logical associations between traffic agent categories, simplified trajectories, and heading angles. They used graph convolutional modules to extract multi-view logical features and uncovered both micro-level logical-physical features and macro-level global logical-physical features. In GCN, the weights of neighboring nodes with respect to the target node are fixed by the normalized adjacency matrix, whereas GAT dynamically learns different weights for neighboring nodes by introducing an attention mechanism, improving the model's attention to different neighbors. Zhang et al. [16] used GAT to model interactions between agents, as well as between agents and infrastructure, by constructing a distance adjacency matrix. This approach improved trajectory prediction performance. However, most spatial modeling using graph neural networks primarily considers the local spatial relationships and overlooks the global ones. The attention mechanism addresses the inability of graph neural networks to learn information from distant nodes, leading researchers to increasingly focus on modeling the global spatial context. Chen et al. [17] employed a spatial attention mechanism to adaptively learn the global spatial features. Furthermore, Chen et al. [18] not only utilized attention mechanisms to model the global social interactions at time t but also captured the relationships between representations of social interactions across different time steps. This model considers a more comprehensive spatial relationship.

In the temporal dimension, recurrent neural networks (RNNs), along with their variants such as long short-term memory (LSTM) and gated recurrent unit (GRU), are commonly used for extracting temporal features. Alahi et al. [19] designed a social pooling mechanism to capture information about surrounding pedestrians and used LSTM to extract pedestrian trajectory features. However, since a single RNN may struggle to extract features from complex traffic trajectory sequences, multi-RNN architectures have been widely adopted. Dai et al. [22] employed a combination of two LSTM networks to predict the trajectory of a target vehicle: one network simulated the trajectories of surrounding vehicles, while the other captured the interactions among these vehicles. Similarly, Xin et al. [23] used two LSTM networks, one to predict the target lane for the target vehicle and the other to predict the trajectory based on the vehicle's state and the predicted target lane. However, LSTM and GRU have limitations in processing time-series data in parallel and may not be suitable for handling long-term sequence problems. Transformer networks introduced self-attention mechanisms, which help the model capture long-range dependencies within sequences. Giuliari et al. [21] were the first to utilize the Transformer network for pedestrian trajectory prediction, achieving good results without the need for any complex interaction terms. Subsequently, Chen et al. [17] combined TCN, used to capture temporal features of traffic trajectories, with the Transformer for trajectory prediction, achieving state-of-the-art results on the ApolloScape trajectory dataset in 2021. TCN employs multiple dilated causal convolutional layers to process information from different time steps simultaneously, enabling high parallelism. The receptive field of the convolutional kernels in TCN is typically adjustable, allowing the model to learn dependency relationships across varying temporal ranges. Li et al. [15] directly used 2D convolutions with (\(1\times 3\)) kernels to capture the temporal dependencies of trajectory sequences, achieving the best accuracy on the 2019 ApolloScape dataset. To further improve the computation time and prediction accuracy of TCN, Tang et al. [20] set dilation rates of 1 and 2 for the odd and even dilated causal convolutional layers in TCN, respectively. Their model achieved faster training than that of Li et al. [15]. TCN can not only extract temporal features but can also be used for spatio-temporal information aggregation. Yang et al. [24] utilized TCN to directly aggregate spatial and temporal interactions, achieving good trajectory prediction performance.

In summary, despite numerous efforts in spatial-temporal modeling for traffic agent trajectory prediction, there is still a lack of in-depth exploration of spatial relationships across the entire historical time sequence. Furthermore, research on the fusion of spatio-temporal features from different dimensions is not sufficient.

Problem formulation

The trajectory prediction problem can be formulated as predicting the future trajectory of a traffic agent based on its historical trajectory sequence and the surrounding environmental data. Specifically, let X represent the historical trajectories of all traffic agents in the traffic scenario, and \(X_i\) represent the historical trajectory sequence of the traffic agent i:

$$\begin{aligned} X_{i}=\left\{ (x_{i}^{t}, y_{i}^{t},o_{i}^{t})\,|\,t=1,\ldots ,T_{obs} \right\} \end{aligned}$$
(1)

where \(x_{i}^{t}\) and \(y_{i}^{t}\) represent the x-axis and y-axis coordinates of traffic agent i at time t. \(o_{i}^{t}\) represents other information about the traffic agent i, such as its type, length, and width. \(T_{obs}\) denotes the length of the historical trajectory sequence. The goal of trajectory prediction is to predict the future traffic agent trajectory \({\hat{Y}} _{i} =\left\{ ({\hat{x}}_{i}^{t},{\hat{y}}_{i}^{t} )|t=T_{obs}+1,\ldots ,T_{obs}+T_{pred} \right\} \) within a certain time range, where \(T_{pred} \) represents the length of the predicted future trajectory sequence. The actual future trajectory can be represented as \(Y_{i}\).

Methodology

To capture the complex spatio-temporal features among traffic agents in mixed traffic scenes, MDSTF models them in different spatio-temporal dimensions. The framework of the MDSTF model, as shown in Fig. 2, mainly consists of three modules: a spatial module, a temporal module, and a trajectory prediction module. Firstly, the historical trajectories of individual traffic agents and the traffic graphs are extracted from the historical traffic scene as inputs to the spatial module. In the spatial module, a GCN is used to model local spatial features over the traffic graphs. Additionally, a global spatial attention mechanism (GSA) is introduced for global spatial feature modeling. Within the LSTM-spatial attention network (LSAN), we combine LSTM and spatial attention to model the full-process spatial features over the entire historical time sequence. A spatial gated fusion mechanism is then used for spatial feature fusion. In the temporal module, the transformer encoder is employed to capture long-term temporal dependencies, which are integrated with the short-term temporal dependencies obtained from the LSTM output of the LSAN. Finally, in the trajectory prediction module, two TCNs with different kernel sizes are used to generate the future trajectories of traffic agents.

Fig. 2 MDSTF trajectory prediction model

Model input

We predict the speed of traffic agents by using \(X_{i}^{t} =(x_{i}^{t}-x_{i}^{t-1},y_{i}^{t}-y_{i}^{t-1},o_{i}^{t})\) as the input to MDSTF, and then obtain the final coordinates of the traffic agents based on the predicted speed. The advantages of predicting speed are primarily as follows: (1) it eliminates the impact of geographical location differences, facilitating the use of data from different locations to train the model [15]; (2) by using \(x_{i}^{t}-x_{i}^{t-1} \) and \(y_{i}^{t}-y_{i}^{t-1} \) as inputs, the model can achieve better convergence while reducing the numerical space [25].
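As a small illustration of this input construction (shapes follow our notation; zero-padding the first step is an assumption of ours):

```python
import torch

def build_velocity_input(traj: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """traj: absolute coordinates, (T_obs, N, 2); other: attributes o_i^t, (T_obs, N, D_o).

    Returns X with X[t, i] = (x_i^t - x_i^{t-1}, y_i^t - y_i^{t-1}, o_i^t).
    """
    deltas = traj[1:] - traj[:-1]                                    # coordinate differences
    deltas = torch.cat([torch.zeros_like(traj[:1]), deltas], dim=0)  # pad step 0 (assumed)
    return torch.cat([deltas, other], dim=-1)                        # (T_obs, N, 2 + D_o)
```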

To represent complex spatio-temporal dependencies, we employ a fully connected layer to transform the input feature \(X_{i}^{t} \) of traffic agent i at time t into an embedding representation:

$$\begin{aligned} H_{i}^{t} =FC(X_{i}^{t}) \end{aligned}$$
(2)

where \(H_{i}^{t}\in R^{D} \) is the embedding feature vector, D represents the embedding dimension, and FC denotes the fully connected layer. So, the input trajectory feature of the spatio-temporal layer in MDSTF is \(H \in R^{T_{obs}\times N \times D}\), where N represents the number of traffic agents.

Traffic graph construction: for each time t, we consider each traffic agent as a node in the graph, and the relationships between traffic agents represent the edges of the graph, resulting in a traffic graph \(G^{t}=(V^{t}, E^{t})\). Here, the node set \(V^{t}\) represents all the traffic agents at time t, and \(E^{t}\) represents the relationships between traffic agents at that time. We use an adjacency matrix \(A^{t}\in R^{N\times N}\) to represent each graph, with N as the number of nodes. Each element \(A_{ij}^{t} \) in \(A^{t}\) indicates whether the nodes in the graph are adjacent: if \(v_{i}^{t}, v_{j}^{t}\in V^{t} \) and \((v_{i}^{t}, v_{j}^{t})\in E^{t} \), then \(A_{ij}^{t}=1 \); otherwise \(A_{ij}^{t}=0 \). Here, \(v_{i}^{t} \) and \(v_{j}^{t} \) represent traffic agents i and j at time t. We establish multiple traffic graphs for each time t, based on the distance adjacency matrix \(A_{D}^{t}\) and the heading angle adjacency matrix \(A_{H}^{t}\) (a construction sketch is given after the list below).

  • Distance adjacency matrix \(A_{D}^{t}\): Considering the impact of other traffic agents within a certain range on the target traffic agent, we construct an adjacency matrix \(A_{D}^{t}\) based on the Euclidean distance at each time t:

    $$\begin{aligned} (A_{D}^{t})_{ij}={\left\{ \begin{array}{ll}1, &{} d_{ij}^{t}<D_{threshold} \\ 0, &{} \text{otherwise}\end{array}\right. } \end{aligned}$$
    (3)

    where, \((A_{D}^{t})_{ij}\) represents the element of the adjacency matrix \(A_{D}^{t} \), \(d_{ij}^{t} \) denotes the Euclidean distance between traffic agent i and traffic agent j at time t, and \(D_{threshold} \) is the Euclidean distance threshold. If the distance between two traffic agents is less than the threshold, they are considered to have an interaction relationship. Considering the inconsistency in the influence range of different categories of traffic agents in real-life scenarios, we set different distance thresholds for different categories of traffic agents. In the experimental section of this paper, we conducted comparative experiments on the parameter setting of \(D_{threshold}\).

  • Heading angle adjacency matrix \(A_{H}^{t}\): In a traffic scenario, two traffic agents with the same heading angle often exhibit similar traffic behaviors. To capture this relationship, we construct an adjacency matrix \(A_{H}^{t}\) based on heading angles. Considering the possibility of inaccuracies and noise in the dataset used in this paper, we require the difference in heading angles between two traffic agents to fall within a certain threshold:

    $$\begin{aligned} (A_{H}^{t})_{ij}={\left\{ \begin{array}{ll}1, &{} |\theta _{i}-\theta _{j}| <\theta _{threshold}\\ 0, &{} \text{otherwise}\end{array}\right. } \end{aligned}$$
    (4)

    where \(\theta _{i}\) and \(\theta _{j}\) represent the heading angles of traffic agent i and j at time t, and \(\theta _{threshold} \) is the heading angle threshold.
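The construction of the two adjacency matrices can be sketched as follows; this is illustrative only, with per-class thresholds passed in as a vector and the absolute heading difference being our reading of Eq. 4:

```python
import torch

def build_adjacency(pos, heading, d_thr, theta_thr):
    """Build A_D^t and A_H^t for one time step (illustrative sketch).

    pos:       (N, 2) agent coordinates at time t
    heading:   (N,)   heading angles at time t
    d_thr:     (N,)   per-agent distance thresholds (class-dependent)
    theta_thr: scalar heading angle threshold
    """
    dist = torch.cdist(pos, pos)                    # pairwise Euclidean distances
    A_D = (dist < d_thr.unsqueeze(1)).float()       # Eq. (3)
    diff = (heading.unsqueeze(1) - heading.unsqueeze(0)).abs()
    A_H = (diff < theta_thr).float()                # Eq. (4)
    return A_D, A_H
```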

Spatial module

In traffic scenes, the motion of traffic agents is heavily influenced by the movement of surrounding objects. Therefore, modeling the spatial features of traffic agents is crucial. Most researchers have focused on the local and global spatial dependencies, neglecting the full-process spatial dependencies of traffic agents across the entire historical time series. To fully exploit the spatial dependencies among traffic agents, this study proposes a model based on multidimensional spatial feature modeling. Specifically, a graph convolutional network is employed to capture the local spatial features, while a spatial attention mechanism is utilized to capture the global spatial features for distant traffic agents. The combination of LSTM and spatial attention mechanism is used to model the full-process spatial features of the entire historical time sequence. Finally, a gating mechanism is employed to fuse all the spatial features.

Graph convolution network (GCN)

The behavior of a traffic agent is greatly influenced by other nearby traffic agents within a short distance range. To capture the local spatial relationships among traffic agents, we employ a graph convolutional network. Kipf [26] proposed a first-order approximation to simplify the Chebyshev spectral filtering, making standard convolutional operations on graphs [27] feasible in practice. Li et al. [28] proposed diffusion convolutional layers for spatiotemporal modeling, and later Wu et al. [29] further generalized the concept of diffusion convolution to the form of Eq. 5. So, in our model, the diffusion convolution at time t can be represented as:

$$\begin{aligned} L^{t}=\sum _{k=0}^{K}\left( P_{D}^{k} H^{t}W_{k1}+ P_{H}^{k} H^{t}W_{k2}\right) \end{aligned}$$
(5)

where \(H^{t}\in R^{N \times D}\) represents the trajectory features input at time t, \(L^{t}\in R^{N \times D}\) stands for the graph features output after processing at time t, \(P^{k}\) represents the k-th power of the transition matrix, \(W_{k1}\) and \(W_{k2}\) are learnable parameters, and K is the diffusion step. We use the distance adjacency matrix \(A_{D}^{t}\) and heading angle adjacency matrix \(A_{H}^{t}\) defined in Sect. “Model input” as inputs, with \(P_{D}=A_{D}^{t}/rowsum(A_{D}^{t})\) and \(P_{H}=A_{H}^{t}/rowsum(A_{H}^{t})\), where rowsum(\(\cdot \)) denotes row-wise summation. The graph convolution in Eq. 5 thus lets the target traffic agent aggregate feature information from traffic agents in neighborhoods of different orders.
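A minimal sketch of Eq. 5 for one time step might look as follows; the clamp guarding against empty rows is an assumption of ours:

```python
import torch

def diffusion_gcn(H_t, A_D, A_H, W1, W2, K=2):
    """Diffusion convolution of Eq. (5) at one time step (sketch).

    H_t: (N, D) agent features; A_D, A_H: (N, N) adjacency matrices;
    W1, W2: lists of K+1 learnable (D, D) weight matrices.
    """
    P_D = A_D / A_D.sum(dim=1, keepdim=True).clamp(min=1)  # row-normalized transition
    P_H = A_H / A_H.sum(dim=1, keepdim=True).clamp(min=1)
    N = H_t.size(0)
    PD_k = torch.eye(N)                                    # P^0 = I
    PH_k = torch.eye(N)
    L_t = torch.zeros_like(H_t @ W1[0])
    for k in range(K + 1):
        L_t = L_t + PD_k @ H_t @ W1[k] + PH_k @ H_t @ W2[k]
        PD_k, PH_k = PD_k @ P_D, PH_k @ P_H                # next power of P
    return L_t
```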

Global spatial attention (GSA)

The behavior of traffic agents is influenced not only by nearby traffic agents but also by distant ones. While GCNs can effectively aggregate information from neighboring nodes, they often fail to aggregate information from distant nodes. To address this issue, we use a multi-head self-attention mechanism to capture global information.

Given the input trajectory features \(H^{t} \in R^{ N \times D}\) at time t, subspaces \(Q_{m}^{G}\in R^{N \times d_{q} }\), \(K_{m}^{G}\in R^{N \times d_{k} }\), and \(V_{m}^{G}\in R^{N \times d_{v} }\) are generated through linear transformation in the following forms:

$$\begin{aligned} Q_{m}^{G}=H^{t} \cdot W_{qm}^{G} ,K_{m}^{G}=H^{t} \cdot W_{km}^{G},V_{m}^{G}=H^{t} \cdot W_{vm}^{G} \end{aligned}$$
(6)

where \(W_{qm}^{G}\in R^{D\times d_{q} }, W_{km}^{G}\in R^{D\times d_{k} }, W_{vm}^{G} \in R^{D\times d_{v} }\) are learnable parameters.

Subsequently, the attention scores between traffic agents i and j are obtained through multi-head scaled dot-product attention. Finally, based on the obtained attention weights, we aggregate the features of the surrounding traffic agents into the target traffic agent, which can be represented as:

$$\begin{aligned} G^{t}= Concat(head_{1},head_{2},\ldots ,head_{n} )W_{G} \end{aligned}$$
(7)

where,

$$\begin{aligned} head_{m}=softmax\left( \frac{Q_{m}^{G}(K_{m}^{G})^{T} }{\sqrt{d_{k} } } \right) V_{m}^{G} \end{aligned}$$
(8)

where n represents the number of attention heads, m represents the m-th attention head, Concat represents concatenation operation, softmax represents the activation function, and \(W_{G}\) denotes the weight matrix. \(G^{t}\) represents the features of each traffic agent after integrating global traffic agents at time t.
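Equations 6-8 amount to standard multi-head self-attention over the agent axis; a compact PyTorch sketch (all dimensions are illustrative):

```python
import torch
import torch.nn as nn

D, n_heads, N = 64, 8, 120                    # illustrative sizes
gsa = nn.MultiheadAttention(embed_dim=D, num_heads=n_heads, batch_first=True)

H_t = torch.randn(1, N, D)                    # (batch, N, D) agent features at time t
G_t, attn_weights = gsa(H_t, H_t, H_t)        # every agent attends to all agents
```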

Fig. 3 Diagram of traffic agents driving on the road

LSTM-spatial attention network (LSAN)

Although existing methods have conducted more detailed research on the spatial relationships of traffic agents, they are limited to the local and global relationships, and ignore the full-process spatial relationships of traffic agents across the entire historical time series. Taking the scenario illustrated in Fig. 3a as an example, if the orange and yellow vehicles maintain similar heading angles and driving speeds throughout the entire historical time series, the yellow vehicle would exert a greater influence and exhibit higher similarity to the orange vehicle. Conversely, in the scenario depicted in Fig. 3b, if the yellow vehicle enters a right-turn lane in the middle of the process, its influence and similarity to the orange vehicle would be relatively diminished compared to the scenario in Fig. 3a. To capture this spatial relationship, we propose a method that combines LSTM with a multi-head attention mechanism to model the full-process spatial features of traffic agents throughout the entire historical time sequence.

We input the trajectory feature \(H\in R^{T_{obs} \times N\times D} \) into LSTM and use the last hidden state \(T_{H}\in R^{N\times D} \) of LSTM as the temporal feature for each traffic agent after fusing the entire historical time sequence:

$$\begin{aligned} T_{S} ,T_{H}=LSTM(H) \end{aligned}$$
(9)

where \(T_{S}\in R^{T_{obs} \times N\times D}\) represents all hidden states output by the LSTM. To capture the full-process spatial features of traffic agents across the entire historical time sequence, we again adopt a multi-head attention mechanism. Given the temporal features \(T_{H}\in R^{N\times D} \), subspaces \(Q_{m}^{F}\in R^{N \times d_{q} } \), \(K_{m}^{F}\in R^{N \times d_{k} } \), and \(V_{m}^{F}\in R^{N \times d_{v} } \) are generated through linear transformation in the following forms:

$$\begin{aligned} Q_{m}^{F}=T_{H} \cdot W_{qm}^{F} ,K_{m}^{F}=T_{H} \cdot W_{km}^{F},V_{m}^{F}=T_{H} \cdot W_{vm}^{F} \end{aligned}$$
(10)

where \(W_{qm}^{F}\in R^{D\times d_{q} }, W_{km}^{F}\in R^{D\times d_{k} }, W_{vm}^{F} \in R^{D\times d_{v} }\) are learnable parameters. The attention scores between traffic agents i and j are then obtained through multi-head scaled dot-product attention, and finally, based on the acquired attention weights, the features of neighboring traffic agents are aggregated into the target traffic agent. This process can be expressed as:

$$\begin{aligned} F = Concat\left( head_{1},head_{2},\ldots ,head_{n}\right) W_{F} \end{aligned}$$
(11)

where,

$$\begin{aligned} head_{m}=softmax\left( \frac{Q_{m}^{F}(K_{m}^{F})^{T} }{\sqrt{d_{k} } }\right) V_{m}^{F} \end{aligned}$$
(12)

where \(W_{F}\) is a learnable parameter and F represents the full-process spatial features captured through LSTM and multi-head attention mechanisms.
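The LSAN computation of Eqs. 9-12 can be sketched as follows; treating the agent axis as the LSTM batch axis is our implementation assumption:

```python
import torch
import torch.nn as nn

T_obs, N, D = 6, 120, 64                      # illustrative sizes
lstm = nn.LSTM(input_size=D, hidden_size=D)   # summarizes each agent's history
attn = nn.MultiheadAttention(embed_dim=D, num_heads=8, batch_first=True)

H = torch.randn(T_obs, N, D)                  # history features (agents act as batch)
T_S, (h_n, _) = lstm(H)                       # T_S: all hidden states (Eq. 9)
T_H = h_n[-1].unsqueeze(0)                    # last hidden state per agent, (1, N, D)
F_full, _ = attn(T_H, T_H, T_H)               # full-process spatial features
```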

Spatial gated fusion

To integrate the local spatial features L, the global spatial features G, and the full-process spatial features F, we design a spatial gated fusion mechanism with a sigmoid activation function. This mechanism maps the importance of input features to weight values between 0 and 1 using the sigmoid activation function. Subsequently, these weight values are multiplied with the corresponding features to achieve a weighted fusion of different input features. Initially, we fuse the local and global features using this spatial gated fusion mechanism as follows:

$$\begin{aligned} z_{1} = \sigma (L \cdot W_{11} +G \cdot W_{12}+b_{g1}) \end{aligned}$$
(13)
$$\begin{aligned} GL = z_{1} \odot L+(1-z_{1} ) \odot G \end{aligned}$$
(14)

In the above equation, \(\sigma \) represents the sigmoid activation function, \(W_{11} \) and \(W_{12} \) are learnable weights, \(b_{g1}\) is the bias term, and \(\odot \) denotes element-wise multiplication. Following this, we use the same gated fusion mechanism to integrate the full-process spatial features:

$$\begin{aligned} z_{2} =\sigma (GL\cdot W_{21} +F\cdot W_{22}+b_{g2}) \end{aligned}$$
(15)
$$\begin{aligned} GLF = z_{2} \odot GL+(1-z_{2} ) \odot F \end{aligned}$$
(16)

GLF encompasses complex spatial features from the historical trajectories of traffic agents.
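A sketch of this gated fusion, applied once for L and G and once more for the intermediate result and F (module and variable names are ours):

```python
import torch
import torch.nn as nn

class SpatialGatedFusion(nn.Module):
    """Gated fusion of two feature maps, as in Eqs. (13)-(16) (sketch)."""
    def __init__(self, dim):
        super().__init__()
        self.w_a = nn.Linear(dim, dim, bias=False)
        self.w_b = nn.Linear(dim, dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, a, b):
        z = torch.sigmoid(self.w_a(a) + self.w_b(b) + self.bias)  # gate in (0, 1)
        return z * a + (1 - z) * b                                # weighted fusion

fuse_lg, fuse_f = SpatialGatedFusion(64), SpatialGatedFusion(64)
L, G, F_full = (torch.randn(6, 120, 64) for _ in range(3))
GLF = fuse_f(fuse_lg(L, G), F_full)           # GL = fuse(L, G); GLF = fuse(GL, F)
```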

Temporal module

The trajectory sequence contains rich temporal features, and the transformer has shown excellent performance on long time-series problems. The transformer network primarily consists of positional encoding, self-attention mechanisms, and a feedforward network. Positional encoding is introduced to provide positional information for each position in the input sequence. The self-attention mechanism enables the transformer network to calculate the correlations between each input position and every other position when processing sequential data. This allows the transformer network to directly capture dependencies between different positions in the input sequence without iterating over the sequence as RNNs do. In the MDSTF model, we use the encoder of the transformer network to capture the long-term temporal dependencies in traffic trajectory sequences.

Fig. 4 Residual network framework with dilated causal convolution

Positional encoding

Traffic trajectory data is a classic form of time-series data, and analyzing the temporal relationships within it can enhance the performance of traffic trajectory prediction. To capture the positional information of time-series data, the transformer model incorporates sine and cosine positional encoding, defined as follows:

$$\begin{aligned} \begin{aligned}&PE(pos,2k)=sin\left( \frac{pos}{10000^{2k/d} } \right) \\&PE(pos,2k+1)=cos\left( \frac{pos}{10000^{2k/d} } \right) \end{aligned} \end{aligned}$$
(17)

where pos represents the position in the sequence, d is the embedding dimension of the sequence, and k indexes the embedding dimensions: 2k denotes the even dimensions while \(2k+1\) denotes the odd dimensions. For each element in the sequence, we generate embeddings by applying sine and cosine functions of different frequencies.
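Eq. 17 translates directly into code; in the sketch below, the even dimension indices 2k are generated directly, and d is assumed even:

```python
import torch

def positional_encoding(T: int, d: int) -> torch.Tensor:
    """Sine/cosine positional encoding of Eq. (17), returned as (T, d)."""
    pos = torch.arange(T, dtype=torch.float32).unsqueeze(1)   # (T, 1)
    two_k = torch.arange(0, d, 2, dtype=torch.float32)        # even dims 2k
    div = torch.pow(10000.0, two_k / d)
    pe = torch.zeros(T, d)
    pe[:, 0::2] = torch.sin(pos / div)                        # PE(pos, 2k)
    pe[:, 1::2] = torch.cos(pos / div)                        # PE(pos, 2k+1)
    return pe
```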

Temporal information extraction

We use the encoder of transformer to capture the long-term temporal relationships in the traffic trajectory sequence. The encoder is primarily composed of Temporal Attention, FeedForward, Dropout, and Layer Normalization. The abstract representation of the Transformer encoder is as follows:

$$\begin{aligned} ST_{L}=LN(DP(FeedForward(Res))+Res) \end{aligned}$$
(18)

where \(ST_{L}\) represents the output of the transformer encoder, LN stands for layer normalization, DP represents the dropout operation, and FeedForward is a feedforward network consisting of two fully connected layers, with a ReLU activation after the first layer and no activation after the second. Res represents the intermediate feature of the encoder and can be expressed as follows:

$$\begin{aligned} FeedForward(Res)= ReLU(ResW_{1}+b_{1} )W_{2} +b_{2}, \end{aligned}$$
(19)
$$\begin{aligned} Res=LN(DP(TemporalAttention(X)) +X). \end{aligned}$$
(20)

Here, \(W_1,W_2,b_1,b_2\) are the corresponding learnable parameters. TemporalAttention denotes multi-head self-attention, and its process is simplified as follows:

$$\begin{aligned} TemporalAttention(X)=Concat(head_{1},head_{2},\ldots ,head_{n})W_{T} \end{aligned}$$
(21)
$$\begin{aligned} head_{m} =softmax\left( \frac{XW_{qm}^{T}(XW_{km}^{T})^{T} }{\sqrt{d_{k} } } \right) XW_{vm}^{T} \end{aligned}$$
(22)

where \(X=GLF+PE\), and PE denotes the positional encoding. \(W_{qm}^{T}\in R^{D\times d_{q} }, W_{km}^{T}\in R^{D\times d_{k} }, W_{vm}^{T} \in R^{D\times d_{v} }\), and \(W_{T}\) are learnable parameters. Through positional encoding and the attention mechanism, the temporal Transformer can learn bidirectional and long-term temporal dependencies in the traffic trajectory sequence.
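Since Eqs. 18-22 follow the standard attention/Add&Norm/feedforward encoder layout, a single PyTorch encoder layer reproduces the structure; the placeholder tensors below stand in for the quantities defined earlier:

```python
import torch
import torch.nn as nn

T_obs, N, D = 6, 120, 64                   # illustrative sizes
enc_layer = nn.TransformerEncoderLayer(d_model=D, nhead=8,
                                       dim_feedforward=4 * D, dropout=0.2)

GLF = torch.randn(T_obs, N, D)             # fused spatial features (placeholder)
PE = torch.randn(T_obs, 1, D)              # positional encoding (see sketch above)
T_S = torch.randn(T_obs, N, D)             # short-term LSTM output from LSAN

X = GLF + PE                               # X = GLF + PE
ST_L = enc_layer(X)                        # attention + feedforward, Eqs. (18)-(22)
ST_out = ST_L + T_S                        # fuse short-term dependencies, Eq. (23)
```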

Table 1 Experiment on the parameter value of the method in this paper

Finally, we integrate the short-term temporal dependencies \(T_S\) obtained from LSTM in LSAN:

$$\begin{aligned} ST_{out}=ST_{L} + T_{S}, \end{aligned}$$
(23)

\(ST_{out}\) represents the multi-dimensional spatio-temporal features of the traffic agent's trajectory sequence after passing through the spatial module and the temporal module. To prevent information loss, we use residual connections after each spatial-temporal layer.

Trajectory prediction module

Most researchers utilize Seq2Seq structures for the final trajectory prediction. However, Seq2Seq structures rely excessively on the context vectors generated by the encoder, and the loss of contextual information can lead to inaccurate prediction results. TCN, which is based on CNN, offers translation invariance and parallel computation and is not limited by a context window; it has demonstrated effectiveness in capturing temporal relationships within sequences. Therefore, we use TCN for the final trajectory prediction. As shown in Fig. 4, TCN consists of a series of stacked residual blocks that implement dilated causal convolution layers. Each dilated causal convolution layer uses convolutional kernels of the same size with different dilation rates. In our scenario, for a given input time series \(ST_{out}=(x_{1},x_{2},\ldots ,x_{T_{obs}} )\) and filters \( \Gamma =(f_{1},f_{2},\ldots ,f_{U} )\), the dilated causal convolution at time t is defined as:

$$\begin{aligned} ST_{out}*\Gamma (t)=\sum _{u=1}^{U}\Gamma (u)x_{t-(U-u)d} \end{aligned}$$
(24)

where d is the dilation rate, indicating the distance between convolutional kernels. Each residual block primarily consists of two-dimensional convolutions, WeightNorm normalization, ReLU activation function, and Dropout. TCN residual blocks capture information at different time scales by stacking multiple convolutional layers. They accelerate training and address the vanishing gradient problem through residual connections, and they utilize dilated convolutions to enlarge the receptive field. These designs enable TCN to handle long-term dependencies.

The dilation rate in TCN is usually set as an exponential power of 2. While this significantly improves computational speed, it may cause the loss of local information and a decrease in prediction accuracy. To address this issue and improve accuracy, we employ two distinct TCNs to extract temporal dependencies, setting different convolutional kernel sizes (\(k_1\) and \(k_2\)) and dilation rates for each TCN. The outputs of the two TCNs are concatenated and passed through two fully connected layers to obtain the final output. Given \(ST_{out}\), our trajectory prediction module is as follows:

$$\begin{aligned} {\hat{Y}} =FC(FC(Concat(\Theta _{1} (ST_{out}),\Theta _{2} (ST_{out}) ))) \end{aligned}$$
(25)

where \(\Theta _{1},\Theta _{2} \) are two independent TCN networks. Using TCN, we obtain the final predicted future trajectory \({\hat{Y}} \) from time \(T_{obs}+1 \) to \(T_{obs}+T_{pred} \).
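A sketch of the dual-TCN predictor follows; the two-layer depth, channel sizes, and the omission of the mapping from \(T_{obs}\) to \(T_{pred}\) steps are simplifications of ours:

```python
import torch
import torch.nn as nn

class CausalConv(nn.Module):
    """One dilated causal convolution layer (left-padded 1D convolution)."""
    def __init__(self, dim, k, d):
        super().__init__()
        self.pad = (k - 1) * d                       # causal left padding
        self.conv = nn.Conv1d(dim, dim, kernel_size=k, dilation=d)

    def forward(self, x):                            # x: (batch, dim, T)
        return self.conv(nn.functional.pad(x, (self.pad, 0)))

D, T_obs = 64, 6
tcn1 = nn.Sequential(CausalConv(D, 2, 1), nn.ReLU(), CausalConv(D, 2, 2))  # k1 = 2
tcn2 = nn.Sequential(CausalConv(D, 3, 1), nn.ReLU(), CausalConv(D, 3, 2))  # k2 = 3
head = nn.Sequential(nn.Linear(2 * D, D), nn.ReLU(), nn.Linear(D, 2))      # two FCs

ST_out = torch.randn(120, D, T_obs)                    # (agents, D, T_obs)
feat = torch.cat([tcn1(ST_out), tcn2(ST_out)], dim=1)  # concatenate TCN outputs
Y_hat = head(feat.transpose(1, 2))                     # per-step (x, y) predictions
```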

We use L2-loss as the loss function for the MDSTF model:

$$\begin{aligned} Loss=\sum _{t=T_{obs}+1}^{T_{obs}+T_{pred}} |{\hat{Y}}^{t}-Y^{t} | \end{aligned}$$
(26)

where \({\hat{Y}}^{t} \) is the predicted value of the MDSTF model, and \(Y^{t} \) is the ground truth.

Experiment

Dataset

We evaluate our model on the ApolloScape Trajectory dataset [30]. The dataset comprises images, point clouds, and manually annotated trajectories. It includes 53 min of training trajectory sequences captured at a rate of 2 frames per second and 50 min of testing trajectory sequences. The dataset primarily consists of complex traffic flows, with a mixture of vehicles (small and big vehicles), pedestrians, and cyclists (motorcyclists and bicyclists) in urban settings. Therefore, it holds significant relevance for studying mixed traffic scenarios.

Evaluation metrics

We use two metrics, the average displacement error (ADE) and the final displacement error (FDE), to evaluate the performance of our model. ADE is the average Euclidean distance between all predicted positions and the corresponding ground-truth positions, while FDE is the average Euclidean distance between the predicted position at the final time step and the corresponding ground-truth position. Following the official description of the dataset, since the heterogeneous traffic agents differ in scale, the weighted sum of average displacement error (WSADE) and the weighted sum of final displacement error (WSFDE) are used as the evaluation metrics:

$$\begin{aligned} \begin{aligned} WSADE&=D_{v}\cdot ADE_{v}+D_{p}\cdot ADE_{p}+D_{b}\cdot ADE_{b} \\ WSFDE&=D_{v}\cdot FDE_{v}+D_{p}\cdot FDE_{p}+D_{b}\cdot FDE_{b}. \end{aligned} \end{aligned}$$
(27)

Among them, \(D_v,D_p\) and \(D_b\) are related to the reciprocals of the average speeds of vehicles, pedestrians, and cyclists in the dataset, with values of 0.20, 0.58, and 0.22, respectively.
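The metric computation can be sketched as follows (class labels and array layout are illustrative):

```python
import numpy as np

WEIGHTS = {"veh": 0.20, "ped": 0.58, "cyc": 0.22}     # D_v, D_p, D_b

def wsade_wsfde(pred, gt, cls):
    """WSADE/WSFDE of Eq. (27) (sketch).

    pred, gt: (N, T_pred, 2) predicted and ground-truth positions;
    cls: length-N array with labels in {"veh", "ped", "cyc"}.
    """
    dist = np.linalg.norm(pred - gt, axis=-1)         # (N, T_pred) displacement errors
    wsade = sum(w * dist[cls == c].mean() for c, w in WEIGHTS.items())     # ADE terms
    wsfde = sum(w * dist[cls == c, -1].mean() for c, w in WEIGHTS.items()) # FDE terms
    return wsade, wsfde
```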

Fig. 5 WSADE and WSFDE values under different hyperparameters

Parameter comparison test

We conducted comparative experiments on the parameter settings involved in our model, including the distance threshold (\(D_{threshold}\)), the heading angle threshold (\(\theta _{threshold}\)), and the TCN convolution kernel sizes (\(k_1\) and \(k_2\)). The experimental results are shown in Table 1. Here, "Parameter" represents the name of the parameter studied in the experiment, "Object Type" indicates the type of traffic agent, and "Parameter Value" denotes the values we set for the parameter in the experiment. In addition, we also explored the embedding dimension (D), the number of stacked spatio-temporal layers (L), and the number of multi-head attention heads (n) mentioned in the paper. The experimental results for these parameters are shown in Fig. 5.

When constructing the distance adjacency matrix, we set different distance thresholds for different categories of traffic agents. Considering that the range affecting the movement of pedestrians, cyclists, and vehicles may vary in the real world, we set a smaller distance threshold for pedestrians compared to cyclists, and a smaller distance threshold for cyclists compared to vehicles.

According to Table 1, it is observed that the best experimental results are achieved when the distance thresholds \(D_{threshold}\) for pedestrians, cyclists, and vehicles are set to 10, 15, and 20, respectively. When adding 5 to their respective distance thresholds \(D_{threshold}\), both WSADE and WSFDE increase, indicating a deterioration in the experimental performance. This suggests that increasing the distance threshold leads to the inclusion of excessive traffic agents as invalid neighboring nodes in the traffic graph, thereby interfering with the accuracy of the model. Similarly, when subtracting 5 from their respective distance thresholds, WSADE and WSFDE also increase, resulting in a decrease in experimental performance. This highlights that excessively small distance thresholds cause the traffic graph to lack effective neighboring nodes, reducing the accuracy of the model.

To validate the rationale behind setting different distance thresholds for different traffic agent types, we conducted a comparative experiment in which the same distance threshold was used for all agent types. As shown in Table 1, when the distance threshold \(D_{threshold}\) was set to 20 for all agent types, the experimental results were optimal among the uniform settings. Decreasing or increasing this threshold increased WSADE and WSFDE, and even the best uniform setting still lagged noticeably behind the type-specific thresholds. This further validates the soundness of our approach of setting different distance thresholds for different traffic agent types.

Regarding the setting of the heading angle threshold (\(\theta _{threshold}\)), we do not differentiate between types of traffic agents. When \(\theta _{threshold}\) increased from \(\pi /12\) to \(\pi /4\), WSADE and WSFDE initially decreased and then increased. This could be attributed to potential deviations or noise in the recorded heading angles in the dataset: a smaller heading angle threshold may exclude some traffic agents that have similar states, whereas a larger heading angle threshold may include traffic agents with different traffic states as neighboring nodes, thereby impacting the accuracy of the model. The best experimental results were achieved when \(\theta _{threshold}=\pi /6\).

For the convolution kernel sizes \(k_1\) and \(k_2\) of the two TCNs, we considered the historical trajectory length of 6 in the ApolloScape dataset. Therefore, we selected three convolution kernel sizes of 2, 3, and 5 and tested their pairwise combinations. From Table 1, it can be observed that the different combinations yield similar results; however, the best experimental performance was achieved when \(k_1=2\) and \(k_2=3\).

Figure 5 illustrates the WSADE and WSFDE prediction results of MDSTF under different hyperparameter settings. When adjusting one parameter, the other parameters were set to their optimal values by default. From Fig. 5, it can be observed that appropriately increasing the number of spatio-temporal layers and attention heads can improve the model's performance. The best results were obtained with an embedding dimension of \(D=64\) and \(n=8\) attention heads. As for the spatio-temporal layers, increasing their number enhances performance up to \(L=5\), where the model achieves its best results; additional layers introduce excessive complexity and reduce computational efficiency.

In summary, we set the distance thresholds \(D_{threshold}\) for vehicles, cyclists, and pedestrians to 20, 15, and 10, respectively. The heading angle threshold \(\theta _{threshold}\) was set to \(\pi /6\). The embedding feature dimension is \(D=64\), the number of spatial-temporal layers is \(L=5\), and the number of multi-head attention heads is \(n=8\). For the TCN layers, one convolution kernel size is \(k_1=2\), while the other is \(k_2=3\). The dropout rate is uniformly set to 0.2.

We trained MDSTF using PyTorch on an NVIDIA GeForce 3080Ti GPU. The optimizer utilized was Adam, with an initial learning rate of 0.001. Throughout the training process, a batch size of 32 was employed.

Table 2 Ablation experiment of the model of this paper

Ablation experiment

To verify the effectiveness of each component of the model, we conducted several ablation experiments in this section to observe the performance variations. The experimental results are shown in Table 2.

The module names in Table 2 are defined as follows: LSAN represents the LSTM-spatial attention network module, GSA stands for the global spatial modeling module, DA indicates the distance adjacency matrix, HA represents the heading angle adjacency matrix, Single-TCN refers to the use of a single TCN, Double-TCN represents using two TCNs, PE represents the positional encoding of the transformer, and S-LSTM indicates whether the short-term temporal dependencies are fused with the Transformer encoder output. We use "\(\surd \)" to indicate that MDSTF includes the corresponding module.

From Table 2, we can draw the following conclusions:

Comparing A and B, the model with the full-process spatial modeling shows a significant improvement in accuracy over the model without it, with decreases of 3.34% and 3.68% in WSADE and WSFDE, respectively. This validates that the local and global spatial modeling alone cannot fully capture the spatial relationships between traffic agents, while the full-process spatial modeling provides a more comprehensive modeling of spatial dependencies.

Comparing A and C, the global spatial modeling also contributes to the model’s accuracy improvement, with decreases in both WSADE and WSFDE. This indicates that the status of traffic agents is not only influenced by nearby traffic agents but also affected by distant ones, thus validating the rationality of considering the global spatial features.

Comparing A, D, and E, the distance adjacency matrix (DA) and the heading angle adjacency matrix (HA) improve prediction accuracy and enhance features. The DA has a greater impact on the model’s improvement, while the HA has a smaller effect. This suggests that the future trajectory of traffic agents is more influenced by nearby agents.

Comparing A and F, we can see that using two TCNs with different kernel sizes improves the model’s prediction accuracy. This indicates the potential loss of the local information in TCN and validates that using two TCNs with different kernel sizes can complement each other in capturing temporal relationships.

Comparing A and G reveals that positional encoding in the temporal Transformer enhances prediction accuracy. Since the historical trajectory is a time series that carries temporal ordering, this demonstrates that position encoding allows the model to learn the temporal relationships within the trajectory sequence. Meanwhile, comparing A and H suggests that the Transformer alone cannot effectively capture the short-term temporal dependencies in traffic trajectories, indicating that traffic trajectories are more influenced by the short-term states of traffic agents.

In addition, to verify the rationality of using TCN as the trajectory prediction module, we compared the results of TCN with those of Seq2Seq structures using different encoding and decoding methods, as shown in Table 3. The results indicate that utilizing TCN for the final trajectory prediction outperforms the Seq2Seq structure.

Table 3 Comparison results of different trajectory prediction modules

Comparison baselines

To evaluate the performance of the MDSTF model, we compare it with representative state-of-the-art models that have been officially released by Apolloscape. The baseline models involved in the comparison are as follows:

  • TrafficPredict [30]: A real-time traffic prediction algorithm based on LSTM networks. It serves as the baseline model for the ApolloScape trajectory dataset.

  • Social LSTM (S-LSTM) [19]: It incorporates a social pooling mechanism to capture neighbor information and utilizes LSTM to extract pedestrian trajectory features.

  • Social GAN (S-GAN) [31]: The model uses a conditional GAN to predict socially reasonable future trajectories.

  • StarNet [32]: It predicts future trajectories by building a network of observed pedestrian trajectories and uses the star-planet topology to represent the interactions between pedestrians.

  • Transformer [21]: This model applies a standard Transformer to model pedestrian trajectories without any complex interaction terms.

  • TPNet [33]: Firstly, it generates a set of candidate trajectories for future predictions. Then, the final prediction results are obtained by classifying and refining these candidate trajectory sets.

  • GRIP++ [15]: This model represents the interaction between traffic agents using graphs and utilizes multiple graph convolutional blocks to aggregate features. Subsequently, an encoder-decoder LSTM model is employed to predict future traffic behavior.

  • MVHGN [14]: It constructs a multi-view logical network for multi-view logical feature extraction. It combines an adaptive spatial topology network and a macro-level region clustering network to capture micro-level logical-physical features and global logical-physical features, respectively. The prediction is performed using a Seq2Seq model with GRU.

  • S2TNet [17]: It captures spatiotemporal dependencies through spatial self-attention mechanisms and TCN (Temporal Convolutional Networks). The future trajectory is then predicted using a Transformer network.

The comparison results between MDSTF and the aforementioned baseline models are shown in Table 4.

First, the experimental results show that MDSTF outperforms all the baselines. Compared to the strongest baseline (S2TNet), MDSTF reduces the two most important metrics, WSADE and WSFDE, by 4.37% and 6.23%, respectively (a sketch of these weighted metrics is given below). For the three traffic-agent classes in the ApolloScape dataset (vehicles, pedestrians, and cyclists), MDSTF reduces ADE by 5.98%, 1.37%, and 5.82%, and FDE by 6.84%, 2.52%, and 9.58%, respectively. Second, the pronounced improvement in FDE over S2TNet indicates enhanced accuracy and robustness for long-term trajectory prediction. Lastly, MDSTF surpasses S2TNet even though S2TNet also considers global spatial features and long-term temporal dependencies, and it significantly outperforms MVHGN, which constructs multiple adjacency matrices for GCN, on both WSADE and WSFDE.
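
As a reference for how these metrics aggregate, the following is a minimal sketch of WSADE/WSFDE. The class weights used here (vehicle 0.20, pedestrian 0.58, cyclist 0.22) are the commonly cited ApolloScape settings and are an assumption, not a value quoted from this paper.

```python
# Sketch of the weighted metrics over per-class displacement errors.
import numpy as np

def ade(pred, gt):
    # pred, gt: (T_pred, 2) predicted / ground-truth positions; mean error.
    return np.linalg.norm(pred - gt, axis=-1).mean()

def fde(pred, gt):
    # Error at the final predicted time step only.
    return np.linalg.norm(pred[-1] - gt[-1])

# Assumed ApolloScape class weights (see lead-in note above).
WEIGHTS = {"vehicle": 0.20, "pedestrian": 0.58, "cyclist": 0.22}

def wsade(ade_per_class):
    # ade_per_class: dict mapping class name -> mean ADE over that class.
    return sum(WEIGHTS[c] * ade_per_class[c] for c in WEIGHTS)

def wsfde(fde_per_class):
    return sum(WEIGHTS[c] * fde_per_class[c] for c in WEIGHTS)
```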

The computational complexity of MDSTF is primarily determined by the LSTM, the GCN, three self-attention modules, and two TCNs. Compared to S2TNet, MDSTF mainly adds the time complexity of the LSTM (approximately \(O(T_{obs}D^{2})\)) and its space complexity (approximately \(O(D^{2})\)), as well as the time complexity of the GCN (approximately \(O(N+E)\)) and its space complexity (approximately \(O(ND)\)), where \(E\) denotes the number of edges in the constructed traffic graph. In future work, we will continue to optimize the model, balancing its computational complexity against its accuracy.
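
Collecting the terms above, the additional cost of MDSTF over S2TNet can be summarized as

\[
\Delta_{\text{time}} \approx O(T_{obs}D^{2}) + O(N+E), \qquad \Delta_{\text{space}} \approx O(D^{2}) + O(ND),
\]

where \(N\) is the number of traffic agents, \(D\) the hidden feature width, \(T_{obs}\) the observation length, and \(E\) the number of graph edges.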

Table 4 Comparative results on the ApolloScape dataset

Visualization and analyses

This paper primarily investigates the problem of future trajectory prediction for traffic agents in mixed traffic scenarios. We visualize several prediction results from the ApolloScape Trajectory dataset, as depicted in Figs. 6, 7, 8 and 9.

Single trajectory prediction

Figure 6 compares MDSTF with S2TNet on individual trajectories from four traffic-agent categories. MDSTF accurately predicts the trajectories of the various categories over a 3-second horizon, consistent with our experimental results in terms of WSADE. As the prediction horizon grows, MDSTF stays closer to the actual final position than S2TNet, consistent with our experimental results in terms of WSFDE, and it accumulates less error overall.

Fig. 6 Comparison results with S2TNet

Single trajectory prediction in extreme scenarios

In Fig. 7, we present the predicted trajectories of traffic agents in extreme or sharp-turn situations and compare them with the visualization results of S2TNet. In such scenarios, both observed and future trajectories exhibit irregular patterns. Figure 7 shows that MDSTF still yields satisfactory predictions compared to S2TNet. For instance, in Fig. 7a the vehicle was nearly stationary in the past, yet its future trajectory suddenly starts moving; MDSTF approximates this future trajectory reasonably well, whereas S2TNet predicts that the vehicle remains stationary. In these cases, MDSTF fits the real future trajectories of traffic agents as closely as possible.

Fig. 7 Comparison results with S2TNet in extreme scenarios

Fig. 8 Comparison results with S2TNet in mixed traffic scenarios

Fig. 9 Comparison results with S2TNet on DE metrics

Trajectory prediction results in mixed traffic scenarios

In Fig. 8, we visualize trajectory predictions in mixed scenarios with both few and many traffic agents, compared against S2TNet. In Fig. 8a, the closely spaced traffic agents A and B interact strongly. The future trajectory of agent A keeps a smaller distance interval than that of agent B, indicating that A is strongly influenced by B; this aligns with the concept behind the distance adjacency matrix we constructed. Furthermore, agents A and C have similar heading angles, and their future trajectories are highly similar, supporting the validity of the heading angle adjacency matrix we constructed. In Fig. 8b, which contains many traffic agents, the visualization shows that MDSTF predicts the future trajectories of a large number of agents better than S2TNet. In particular, for agents D and E, whose historical trajectories contain only the last time step yet who move significantly in the future, S2TNet predicts that they remain stationary, whereas MDSTF predicts their future trajectories accurately. These observations indicate that MDSTF performs well when predicting the future trajectories of multiple traffic agents in mixed traffic scenarios.

DE metrics

In addition to comparing WSADE and WSFDE on the ApolloScape dataset, we also evaluated the DE metric at every prediction step, where DE denotes the displacement error between the predicted and ground-truth positions at time t (a short sketch follows this paragraph). The experimental results are shown in Fig. 9. Besides the aforementioned categories, a further category, "others", covers miscellaneous agents. For vehicles, pedestrians, cyclists (motorcyclists and bicyclists), and "others" alike, MDSTF and S2TNet perform similarly over the initial prediction steps. In the later steps, however, MDSTF's DE is significantly lower than S2TNet's for vehicles and cyclists, while the improvement for pedestrians and "others" is more modest. Overall, this indicates that MDSTF holds a clear advantage over S2TNet in long-term prediction.
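
For concreteness, here is a minimal sketch of the per-step DE curve of the kind plotted in Fig. 9; the variable names are ours, not the paper's.

```python
# Sketch: displacement error at each prediction step for one agent.
import numpy as np

def de_per_step(pred, gt):
    # pred, gt: (T_pred, 2) predicted / ground-truth positions.
    return np.linalg.norm(pred - gt, axis=-1)   # (T_pred,) one error per step

pred = np.array([[0.0, 0.0], [1.0, 0.1], [2.1, 0.3]])
gt   = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
print(de_per_step(pred, gt))  # error typically grows with the horizon
```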

Conclusion

In this paper, we present MDSTF, a trajectory prediction model that incorporates multi-dimensional spatio-temporal features to predict the future trajectories of traffic agents. MDSTF effectively captures the local, global, and full-process spatial features from the historical trajectories of traffic agents, while also accounting for long-term and short-term temporal dependencies. Compared with the best baseline model (S2TNet), MDSTF captures the spatio-temporal features within trajectory sequences more effectively, reducing the two key metrics, WSADE and WSFDE, by 4.37% and 6.23%, respectively. Moreover, in challenging scenarios characterized by extreme or mixed traffic conditions, MDSTF excels at capturing the complex spatio-temporal dependencies among traffic agents, enabling accurate predictions of future trajectories. It is important to note that trajectory prediction relies not only on historical trajectory information but also, to a large extent, on road-related factors in the traffic environment, such as traffic lights and pedestrian crossings. In future research, we aim to incorporate road information to further explore potential spatio-temporal features in mixed traffic environments and enhance the accuracy of trajectory prediction.