1 Introduction

Nowadays, there is a growing focus on health, particularly in the wake of the COVID-19 pandemic. COVID-19, caused by the SARS-CoV-2 virus, is an acute respiratory disease that had a substantial mortality rate in the early phase of its outbreak (Wu et al., 2020; Roy et al., 2020; Habib and Kayani, 2023). In the early stages of the pandemic, many COVID-19 patients experienced clinical deterioration requiring hospitalization or admission to intensive care units (ICUs), placing heavy strain on national healthcare systems within a very short time (McMahon et al., 2020; Buonsenso et al., 2021). To address this challenge, a range of advanced technologies has been introduced. Decision support systems (DSS), for instance, have found application across diverse industries, including business marketing and investment (Scheepers and Scheepers, 2008), social media (Sadovykh and Sundaram, 2019; Meske and Bunde, 2023), transportation management (He et al., 2014), and other industries (Zolbanin et al., 2020; Tanergüçlü et al., 2012; Pessoa et al., 2022).

To confront the multifaceted challenges of the healthcare sector, clinical decision support systems (CDSS) have shown substantial promise in helping healthcare professionals deliver precise and efficient diagnoses and treatments, thereby advancing the principles of precision medicine (Tutun et al., 2022; Huang et al., 2020; Wijnhoven, 2022; Karthikeyan et al., 2021; Murri et al., 2022). During the COVID-19 pandemic, CDSS proved effective in disease diagnosis and in improving overall hospital operational efficiency (Mourad et al., 2021; Abbaspour Onari et al., 2021; Govindan et al., 2020). A CDSS is a computer-based tool that provides healthcare professionals with patient-specific information and knowledge, serving as a valuable resource for clinical decision-making (Kawamoto et al., 2005). These systems can be integrated with electronic health records (EHRs), offering real-time access to critical patient data, including medical notes, medication history, laboratory test results, and medical imaging studies. Furthermore, Artificial Intelligence (AI) and Machine Learning (ML)-based CDSS, which benefit from large volumes of EHR data, have been used to enhance healthcare delivery and support more precise medical decisions, showing potential in responding to the COVID-19 pandemic (Wijnhoven, 2022; Singh et al., 2021; Qjidaa et al., 2020; Karthikeyan et al., 2021). Therefore, in this study, we present a novel CDSS model that is applied to EHRs for disease risk prediction and aims to advance the existing Machine Learning for Decision Support Theory. This theory uses machine learning algorithms to analyze historical EHR data and predict patients’ outcomes, thereby aiding decision-making.

However, most of the aforementioned CDSSs focus primarily on diagnosing acute disease risks. Chronic diseases, by contrast, are widespread: taking the United States as an example, nearly half of all Americans (approximately 45%, or 133 million people) suffer from at least one chronic disease (Raghupathi and Raghupathi, 2018), and worldwide, chronic diseases cause 41 million deaths each year, equivalent to 74% of all deaths (Organization et al., 2022). This underlines the urgency for CDSS to also support chronic disease diagnosis from patients’ longitudinal hospital visit information. Typical chronic diseases include hypertension, diabetes, and stroke. In addition, most current CDSS models (Xu et al., 2018; Choi et al., 2016; Ma et al., 2017) mainly process structured sequential EHRs, neglecting the abundant clinical information embedded in unstructured medical notes. This paper therefore develops an AI-based CDSS for risk prediction, a typical task in precision medicine, that uses unstructured longitudinal medical notes collected during multiple hospital visits to predict acute, chronic, and mixed disease risks. To achieve this goal, we first need to address two research problems:

Q1) How can a CDSS effectively capture time-varying information from longitudinal EHR data?

Q2) How can a CDSS effectively extract clinically useful information contained in large volumes of unstructured medical notes for predictive model construction?

For the first problem, various CDSS models have been developed. Shang et al. (2019) proposed models based on recurrent neural networks (RNNs) and attention for learning temporal information. However, despite their predictive power in tasks such as disease risk prediction, disease diagnosis, and length-of-stay prediction, the hidden states learned by these RNN-based models cannot be interpreted probabilistically as patients’ health states, because they are purely data-driven black-box approaches (Krishnan et al., 2017). Hidden Markov model (HMM)-based disease progression models have been investigated to explicitly model patients’ state transitions (Alaa et al., 2017). Combining neural networks with HMMs, deep state-space models, which parameterise both the state transition function and the observation function with neural networks, have been developed (Li et al., 2021; Oezyurt et al., 2021; S. Niu et al., 2024; Niu et al., 2024). The main advantages of these models are that they formally integrate patients’ latent states with diverse observations, track changes in latent states by learning dynamic representations, provide a generative model to predict future representations and generate future observations, and offer a state transition model that accounts for the impact of external interventions. This paper therefore develops a deep state-space-based CDSS model for longitudinal patient data. As summarized in Table 1, unlike most existing works, our model is one of the first attempts to apply deep state-space modelling to unstructured text data.

To address the second problem, an increasing number of feature extraction methods have been developed, such as Clinical-BERT (Alsentzer et al., 2019). Among them, attention-based models, such as Recurrent Attentive and Intensive Modelling (RAIM) (Xu et al., 2018) and the Multimodal Attentional Neural Network (MNN) (Qiao et al., 2019), are widely used. Recently, several studies (Mullenbach et al., 2018; Wang et al., 2018; Niu et al., 2021a, b, 2024) have adopted a label-dependent approach based on the attention mechanism, integrating information from both medical notes and disease risk labels to improve prediction performance and achieve interpretability. In these models, label descriptions and words from medical notes are jointly embedded and then integrated via the attention mechanism. In this paper, we adopt label-dependent attention in our CDSS to encode unstructured medical notes into latent representations at each time step, which are then updated by the deep state-space model.

Table 1 The list of recent CDSS models
Fig. 1

An example of the input and output of our model for predicting a patient’s risks. The input comprises unstructured medical notes gathered from three hospital visits of the patient, together with the descriptions of risk labels, each of which is a sequence of textual phrases. The output is a 0/1 vector indicating the presence of each type of risk for the patient.

Our model, the Deep State-space with Label-dependent Attention Model (DSLAM), is composed of three modules: a prior module that learns the transition of patients’ latent states, a posterior module that approximates the posterior distribution of latent states via a label-dependent attention mechanism, and a likelihood module that generates predictions from latent states. Figure 1 illustrates the input (medical notes and descriptions of risk labels) and output (disease risk predictions) of our model. The colored arrows indicate that medical terms in the medical notes are semantically related to the descriptions of risk labels. The main contributions can be summarized as follows:

  • Our work introduces a pioneering CDSS approach for disease risk prediction using longitudinal medical notes, viewing patient risks as outcomes of a state-space transition process.

  • We emphasize the importance of accurately capturing latent states in studying longitudinal EHRs, providing clear guidance for future disease risk research and clinical decision-making, especially for patients with multiple hospital visits.

  • To the best of our knowledge, our model is the first work tailored to unstructured longitudinal medical notes that combines probabilistic models with deep neural networks. It enhances predictive efficacy and interpretability through a label-dependent attention mechanism.

  • To substantiate the effectiveness of our model, we conduct rigorous evaluations on two real-world EHR datasets, namely MIMIC-III (Johnson et al., 2016) and N2C2-2014 (Kumar et al., 2015) for both quantitative and qualitative evaluations.

2 Related Work

2.1 AI/ML-based Clinical Decision Support System

AI/ML-based CDSS have become increasingly popular in healthcare owing to their potential to improve diagnostic accuracy, treatment planning, and patient outcomes using different types of EHR data. Park et al. (2022) show that adopting a CDSS can significantly reduce hospital readmission rates for heart failure, acute myocardial infarction, and pneumonia patients, underscoring the importance of accurate and comprehensive EHR data. Miotto et al. (2018) compared different AI-based DSS across various types of EHR data, demonstrating better disease diagnosis accuracy than hand-crafted models. Berge et al. (2023) used an AI-based DSS to identify and classify allergies of concern for anesthesia and intensive care, which increased the detection of patient allergies and improved the quality of practice and patient safety during surgery or intensive care unit stays. Together, these studies show that the effectiveness of AI/ML-based CDSS in clinical practice depends on several key factors: the quality and comprehensiveness of the input data, the choice of AI/ML algorithms suited to different EHR types, and the integration of the DSS into actual clinical workflows.

2.2 Clinical Decision Support System Across Different EHR Modalities

For unstructured medical notes, word-embedding-based deep learning models (Mullenbach et al., 2018; Wang et al., 2018; Qiao et al., 2019) combined with various attention mechanisms (Vaswani et al., 2017) have been used to assist disease diagnosis. In addition, pre-trained language models (PLMs) have been widely adopted for various downstream natural language processing (NLP) tasks, including clinical decision support. For example, Niu et al. (2021a, 2021b) applied Clinical-BERT (Alsentzer et al., 2019) to extract features from medical notes for disease risk prediction. For laboratory test results, RNN-based models are used to process numerical time-series data. For example, RAIM (Xu et al., 2018) used multi-channel RNNs to extract hidden states from laboratory test results and electrocardiogram (ECG) waveforms for ICU length-of-stay prediction and patient mortality prediction. Oezyurt et al. (2021) built an attentive deep Markov model based on RNNs for mortality prediction using different laboratory test measurements. For medical image data, convolutional neural network (CNN)-based models are often applied to capture visual information. Arevalo et al. (2016) applied CNNs to detect tumours and classify them as benign or malignant. Monshi et al. (2021) built a CNN-based deep learning model, CovidXrayNet, for the diagnosis of COVID-19. Although these studies achieve competitive accuracy on different types of EHR data, only a few exploit the longitudinal information hidden in patients’ multiple hospital visits, which could improve clinical decision-making accuracy for prolonged diseases.

2.3 Modelling Longitudinal EHR Data

EHRs have become a valuable resource for researchers and clinicians because of the longitudinal information they contain about patients’ medical histories. Deep learning models have been developed to capture and analyze this complex longitudinal information, with applications such as medication recommendation, health risk prediction, disease risk prediction, and understanding disease progression. RNNs are one family of deep learning models used in EHR analysis. Shang et al. (2019) developed a model that uses RNNs to recommend medications from successive drug codes across a patient’s multiple hospital visits. In addition, Ma et al. (2020) used a time-aware attention mechanism to predict health risks, and Luo et al. (2020) built HiTANet, which uses a time-aware attention-based transformer to predict disease risks. Another model used in EHR analysis is Deep Predictive Clustering (DPC), which clusters RNN-encoded time-series data over time for temporal phenotyping; Lee and Van Der Schaar (2020) used DPC to understand disease progression. Deep state-space models have also been developed for longitudinal data analysis. For instance, Oezyurt et al. (2021) built an attentive Deep Markov Model to trace patients’ latent status and predict patient risks using laboratory test data. Li et al. (2021) proposed a causal hidden Markov model that learns separate latent representations using different supervised tasks, including medical image reconstruction and risk prediction. Despite these advances, few studies have employed deep state-space models to analyze unstructured longitudinal medical notes in EHR data. Such models could yield new insights into disease progression and improve patient outcomes, which motivates further research in this area.

Fig. 2

The workflow of the clinical decision support system, which processes longitudinal medical notes to generate disease risk predictions. The unstructured medical notes first undergo noise and non-English character removal. They are then fed sequentially into our deep state-space-based CDSS model, which generates disease risk predictions at each visit.

2.4 Attention Mechanism for Explainable AI

The attention mechanism has gained popularity for interpreting the outputs of deep neural networks (Vaswani et al., 2017). In predictive medicine and clinical decision-making, interpretability is a crucial factor that cannot be ignored, so the attention mechanism is often implemented in various ways to explain and justify clinical decisions. For instance, RAIM (Xu et al., 2018) designed multi-channel attention and recurrent attention mechanisms to understand the contribution of different input features from EHR data. MNN (Qiao et al., 2019) employed attentional bidirectional RNNs to identify essential features from medical notes and codes. Similarly, Choi et al. (2016) and Ma et al. (2017) used attention mechanisms with RNNs to identify important hospital visits and significant input features for disease diagnosis. Additionally, label-dependent attention mechanisms have been introduced to identify phrases in medical records that are more relevant to the target disease. The label-dependent attention method embeds the names or descriptions of the prediction task labels together with the data features and then integrates their embeddings through a cross-attention mechanism. For example, Akata et al. (2015) and Radford et al. (2021) introduced functions that jointly learn image and label embeddings for zero-shot image classification. Mullenbach et al. (2018) proposed a convolutional attention mechanism for medical note classification that uses label embeddings. Similarly, Wang et al. (2018) developed a joint word and label embedding for text classification. Niu et al. (2021a) used label embeddings to guide the integration of medical notes and time-series health status indicators for disease risk prediction. However, label-dependent attention has not been extensively explored for modelling longitudinal medical notes generated during multiple hospital visits.

3 Method

In this section, we introduce our model, DSLAM, which is specifically designed to model longitudinal unstructured medical notes for disease risk prediction. DSLAM builds upon the existing Machine Learning for Decision Support Theory and demonstrates superior performance in risk prediction. The clinical decision-making process is illustrated in Fig. 2: longitudinal medical notes are collected from EHRs, preprocessed to eliminate noise, and modelled longitudinally to generate disease risk predictions.

In the following subsections, we first present the preliminaries of our deep state-space method for modelling longitudinal EHRs, namely VAEs and the generic deep state-space model. We then formally define our problem, give an overview of our model, and explain its individual components in detail, i.e., the prior, posterior, and likelihood modules. In particular, we comprehensively explain the label-dependent attention mechanism adopted in the posterior module for encoding text data.

3.1 Preliminary

3.1.1 Variational Autoencoders (VAEs)

Variational Autoencoders (VAEs) (Kingma and Welling, 2013) are a class of generative models that combine the autoencoder framework with probabilistic graphical models to learn a probabilistic representation of input data. A typical VAE consists of two components: an encoder and a decoder. The encoder \(q_\phi (\varvec{z}\mid \varvec{x})\) approximates the true posterior distribution \(p_\theta (\varvec{z}\mid \varvec{x})\) by mapping the input data \(\varvec{x}\) to a distribution over the latent variable \(\varvec{z}\). The decoder models the likelihood of the data-generating process that produces \(\varvec{x}\) from \(\varvec{z}\), denoted as \(p_\theta (\varvec{x}\mid \varvec{z})\). Here, \(\phi \) and \(\theta \) denote the learnable parameters of the encoder and decoder neural networks, respectively.

The encoder approximates the posterior with a diagonal Gaussian distribution:

$$\begin{aligned} \log q_\phi (\varvec{z}\mid \varvec{x}) = \log \mathcal {N}(\varvec{z} ; \varvec{\mu }, \varvec{\sigma }^2\textbf{I}), \end{aligned}$$
(1)

where \(\varvec{\mu }\) and \(\varvec{\sigma }\) represent the mean and standard deviation of the approximate posterior, generated through the encoding of the input data via a neural network.

To sample the latent variable \(\varvec{z}\) from the approximate posterior, the re-parameterization trick (Kingma and Welling, 2013) is commonly used, which makes the sampling step differentiable during training:

$$\begin{aligned} \varvec{z} = \varvec{\mu } + \epsilon \odot \varvec{\sigma }. \end{aligned}$$
(2)
Fig. 3

The structure of the deep state-space model, which includes prior networks, posterior networks, and likelihood networks.

Here, \(\odot \) represents element-wise multiplication, and \(\epsilon \sim \mathcal {N}(0,\textbf{I})\). Finally, the encoder and decoder parameters are learned jointly by maximizing the evidence lower bound (ELBO) on \(\log p_\theta (\varvec{x})\):

$$\begin{aligned} \mathcal {L}_{ELBO} = - D_{KL}(q_\phi (\varvec{z}\mid \varvec{x}) \mid \mid p_\theta (\varvec{z})) + \mathbb {E}_{q_\phi }[\log p_\theta (\varvec{x} \mid \varvec{z})]. \end{aligned}$$
(3)

Here, \(D_{KL}\) denotes the Kullback-Leibler (KL) divergence. A common choice for the prior distribution \(p_\theta (\varvec{z})\) is the standard Gaussian distribution \(\mathcal {N}(0, \textbf{I})\).
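To make Eqs. 1-3 concrete, here is a minimal pure-Python sketch of a single VAE step. It assumes a diagonal Gaussian posterior, a standard Gaussian prior, and, hypothetically, a Gaussian decoder with unit noise; `decoder` stands in for the decoder network. The KL term uses the closed form for diagonal Gaussians, while the expected log-likelihood is estimated from one reparameterized sample.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]

def reparameterize(mu: Sequence[float], sigma: Sequence[float], rng: random.Random) -> Vector:
    """Eq. 2: z = mu + eps * sigma with eps ~ N(0, I); in an autodiff framework
    this lets gradients flow through mu and sigma."""
    return [m + rng.gauss(0.0, 1.0) * s for m, s in zip(mu, sigma)]

def kl_to_standard_normal(mu: Sequence[float], sigma: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(m * m + s * s - 1.0 - 2.0 * math.log(s) for m, s in zip(mu, sigma))

def gaussian_log_likelihood(x: Sequence[float], x_mean: Sequence[float], noise_std: float = 1.0) -> float:
    """log N(x; x_mean, noise_std^2 I): an assumed Gaussian decoder likelihood p(x | z)."""
    var = noise_std ** 2
    return sum(-0.5 * math.log(2 * math.pi * var) - (xi - mi) ** 2 / (2 * var)
               for xi, mi in zip(x, x_mean))

def elbo_single_sample(x: Sequence[float], mu: Sequence[float], sigma: Sequence[float],
                       decoder: Callable[[Vector], Vector], rng: random.Random) -> float:
    """One-sample Monte Carlo estimate of Eq. 3: E_q[log p(x | z)] - KL(q(z | x) || p(z))."""
    z = reparameterize(mu, sigma, rng)
    return gaussian_log_likelihood(x, decoder(z)) - kl_to_standard_normal(mu, sigma)
```

For example, `elbo_single_sample([0.5, -0.5], [0.1, -0.1], [0.5, 0.5], lambda z: z, random.Random(1))` evaluates the bound with an identity decoder; when the posterior equals the prior (`mu = 0`, `sigma = 1`), the KL term vanishes.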

3.1.2 Deep State-space Model

The deep state-space model builds on the basic framework of VAEs. Figure 3 illustrates its structure. Suppose there are \(T\) time steps \(\{1,\ldots ,t,\ldots ,T\}\) whose latent states need to be modelled. The prior distribution of \(\varvec{z}_t\) is generated by a prior network that acts as the transition function between consecutive latent states, denoted as \(p_{\theta }(\varvec{z}_{t} \mid \varvec{z}_{t-1})\).

In the generic deep state-space model (Rangapuram et al., 2018; Li et al., 2021), the posterior distribution \(q_\phi (\varvec{z}_t \mid \varvec{z}_{t-1}, \varvec{x}_t)\) is produced by a posterior (encoder) network that encodes the current input \(\varvec{x}_t\) together with the previous latent state \(\varvec{z}_{t-1}\). In the context of patients’ longitudinal EHRs, however, a chronic disease diagnosis is often related to all previous hospital visits. We therefore condition the posterior on all previous latent states, i.e., \(q_\phi (\varvec{z}_t \mid [\varvec{z}_{1},\ldots ,\varvec{z}_{t-1}], \varvec{x}_t)\).

For the likelihood of the deep state-space model, instead of reconstructing the original input \(\varvec{x}_t\), a likelihood network typically predicts a target \(\varvec{y}_t\), denoted as \(p_\theta (\varvec{y}_t\mid \varvec{z}_t)\).

To infer the parameters and latent states, the Bayesian variational learning approach is applied, where the learning objective is defined by the evidence lower bound on the log data likelihood (ELBO) (Krishnan et al., 2017):

$$\begin{aligned} \mathcal {L}_{ELBO}=&\mathbb {E}_{q_\phi (\varvec{z}_1\mid \varvec{x}_1)}[\log p_\theta (\varvec{y}_1\mid \varvec{z}_1)] \nonumber \\&-KL(q_\phi (\varvec{z}_1 \mid \varvec{x}_1) \mid \mid p_\theta (\varvec{z}_1))\nonumber \\ &+\sum _{t=2}^T \Big [\mathbb {E}_{q_\phi (\varvec{z}_t\mid \varvec{Z}_{t-1},\varvec{x}_t)}[\log p_\theta (\varvec{y}_t\mid \varvec{z}_t)]\nonumber \\&-KL(q_\phi (\varvec{z}_t\mid \varvec{Z}_{t-1},\varvec{x}_t)\mid \mid p_\theta (\varvec{z}_t\mid \varvec{z}_{t-1}))\Big ], \end{aligned}$$
(4)

where \(\varvec{Z}_{t-1}=[\varvec{z}_1,\ldots ,\varvec{z}_{t-1}]\), and KL(.) measures the Kullback-Leibler divergence between two distributions. \(p_\theta (\varvec{y}_t\mid \varvec{z}_t)\) is the likelihood of observing \(\varvec{y}_t\) given the latent state \(\varvec{z}_t\). When \(t>1\), \(q_\phi (\varvec{z}_t\mid \varvec{Z}_{t-1},\varvec{x}_t)\) and \(p_\theta (\varvec{z}_t\mid \varvec{z}_{t-1})\) are the posterior and the prior of \(\varvec{z}_t\), respectively. For \(\varvec{z}_1\), the prior and posterior are \(p_\theta (\varvec{z}_1)\) and \(q_\phi (\varvec{z}_1 \mid \varvec{x}_1)\). We adopt a Gaussian variational approximation, i.e., each latent state follows a Gaussian distribution whose mean and standard deviation are output by neural networks (for the posterior, computed from \(\varvec{x}_t\) and the previous latent states). The following subsections explain in detail how the neural networks approximating the prior, posterior, and likelihood are implemented.
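Eq. 4 decomposes into one "expected log-likelihood minus KL" term per visit. The sketch below assumes that every prior and posterior is a diagonal Gaussian given as a `(mean, std)` pair of lists and that the expected log-likelihoods have already been estimated (for example, from one reparameterized sample each); it simply accumulates the terms. For \(t>1\) the prior is the transition prior evaluated at a sampled \(\varvec{z}_{t-1}\), which gives a single-sample estimate of the expectation over \(\varvec{z}_{t-1}\).

```python
import math
from typing import List, Sequence, Tuple

Gaussian = Tuple[List[float], List[float]]  # (mean, std) of a diagonal Gaussian

def kl_diag_gaussians(mu_q: Sequence[float], sigma_q: Sequence[float],
                      mu_p: Sequence[float], sigma_p: Sequence[float]) -> float:
    """KL(N(mu_q, diag sigma_q^2) || N(mu_p, diag sigma_p^2)), summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total

def sequential_elbo(expected_log_liks: Sequence[float],
                    posteriors: Sequence[Gaussian],
                    priors: Sequence[Gaussian]) -> float:
    """Eq. 4: sum over visits t of E_q[log p(y_t | z_t)] - KL(q_t || p_t).

    priors[0] is p(z_1); priors[t] for t > 0 is p(z_t | z_{t-1}) at a sampled z_{t-1}.
    """
    elbo = 0.0
    for ll, (mu_q, s_q), (mu_p, s_p) in zip(expected_log_liks, posteriors, priors):
        elbo += ll - kl_diag_gaussians(mu_q, s_q, mu_p, s_p)
    return elbo
```

When a posterior coincides with its prior, the KL term is zero and that visit contributes only its log-likelihood; any mismatch strictly lowers the bound.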

3.2 Problem Definition

In the longitudinal study, each patient \(n\) is characterized by a sequence of observations \(\varvec{M}^{(n)} = \{\varvec{M}^{(n)}_1,\ldots ,\varvec{M}^{(n)}_t,\ldots ,\varvec{M}^{(n)}_T\}\), where each element \(\varvec{M}^{(n)}_t\) represents the medical notes containing \(N_t^{(n)}\) words collected at the \(t\)th hospital visit, and \(T\) is the total number of visits. Each word of \(\varvec{M}^{(n)}_t\) is represented as a one-hot \(\mid V\mid \)-dimensional vector, where \(\mid V \mid \) is the size of the vocabulary \(V\). Let \(\varvec{L}_Y =\{\varvec{L}_1,\ldots ,\varvec{L}_{v},\ldots ,\varvec{L}_{N_y}\}\) denote the descriptions of all diagnosed disease risk labels, where \(N_y\) is the total number of diagnosed disease risks. Each word in the description of the \(v\)th risk (i.e., \(\varvec{L}_{v}\)) is also represented as a one-hot vector. To observe the prediction performance of our model clearly and conveniently on coarse-grained disease risks, we group the \(N_y\) disease risks into \(N_Y\) predictive objectives (i.e., acute, chronic, and mixed disease risks), where \(N_y > N_Y\). Accordingly, \(\varvec{y}^{(n)}=\{\varvec{y}^{(n)}_1,\ldots , \varvec{y}^{(n)}_t,\ldots ,\varvec{y}^{(n)}_T\}\) indicates the presence of the different types of coarse-grained disease risks observed during multiple visits, where each \(N_Y\)-dimensional vector \(\varvec{y}^{(n)}_t\) contains 0/1 values. In the disease risk prediction problem, both \(\varvec{M}^{(n)}\) and \(\varvec{L}_Y\) are used to predict the values of \(\varvec{y}^{(n)}\). For simplicity, the superscript \((n)\) is omitted in the following descriptions. All notations used in the following subsections are listed in Table 2.
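How the \(N_y\) fine-grained risks are collapsed into the \(N_Y\) coarse objectives is not spelled out here. One natural reading, used in the sketch below purely as an assumption, is an OR rule: a coarse objective is positive at a visit if any fine-grained risk assigned to that type is present. The five-risk type mapping in the usage example is hypothetical.

```python
from typing import List, Sequence

def to_coarse_labels(fine_labels: Sequence[int],
                     fine_to_group: Sequence[str],
                     groups: Sequence[str]) -> List[int]:
    """Collapse an N_y-dim 0/1 vector into an N_Y-dim 0/1 vector.

    fine_to_group[i] names the coarse type (e.g. "acute") of fine-grained risk i.
    A coarse objective is 1 if any fine-grained risk mapped to it is present
    (assumed OR aggregation).
    """
    index = {g: i for i, g in enumerate(groups)}
    coarse = [0] * len(groups)
    for present, group in zip(fine_labels, fine_to_group):
        if present:
            coarse[index[group]] = 1
    return coarse
```

With the hypothetical mapping `["acute", "acute", "chronic", "mixed", "chronic"]` and groups `("acute", "chronic", "mixed")`, a visit with fine labels `[0, 1, 0, 0, 1]` becomes the coarse vector `[1, 1, 0]`; applying the function per visit yields the sequence \(\varvec{y}_1,\ldots ,\varvec{y}_T\).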

Table 2 Notations and descriptions
Fig. 4

Overview of our DSLAM model. It consists of three modules: the prior module generates the prior distribution of the current latent state from the previous time step, the posterior module uses a label-dependent attention mechanism to approximate the posterior of latent states, and the likelihood module generates risk predictions. Key variables are as follows: \(\varvec{M}_{t}\) denotes the medical notes of patient n at time t; \(\varvec{L}_Y\) denotes the descriptions of all target risks; \(\varvec{z}_{t-1}\) is the patient’s latent state at the \((t-1)\)th hospital visit; \(\varvec{Z}_{t-1}\) contains all latent states from 1 to \(t-1\) for patient n; \({\varvec{\mu }}^p_{\varvec{z}_t}\) and \({\varvec{\sigma }}^p_{\varvec{z}_t}\) are the mean and standard deviation of the prior; \({\varvec{\mu }}^q_{\varvec{z}_t}\) and \({\varvec{\sigma }}^q_{\varvec{z}_t}\) are the mean and standard deviation of the posterior; \(\varvec{\hat{y}}_t\) denotes the predicted risks.

3.3 The Overview of Our Deep State-space Model

Sequential Bayesian updating uses the latest observation to update the current latent state: its posterior distribution is computed via Bayes’ rule by combining the up-to-date prior from the state transition model with the likelihood. We apply this approach to estimate the health state \(\varvec{z}_t\) of a patient at each visit \(t\). An overview of DSLAM is shown in Fig. 4. It consists of a prior module that generates the prior distribution of the latent state \(\varvec{z}_t\); a posterior module that uses a label-dependent attention mechanism to encode information from \(\varvec{x}_t=(\varvec{M}_t,\varvec{L}_Y)\) and approximate the posterior distribution of \(\varvec{z}_t\); and a likelihood module that measures the probability of generating the observation \(\varvec{y}_t\). The parameters and latent states of DSLAM are inferred by maximizing the ELBO in Eq. 4.
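As intuition for sequential Bayesian updating, consider the classical linear-Gaussian special case, a one-dimensional Kalman filter, where the predict (transition prior) and update (Bayes’ rule) steps have closed forms. This is not DSLAM itself: DSLAM replaces the linear transition with the learned prior module and the exact Bayes update with the learned variational posterior. The transition coefficient `a`, process noise `q`, and observation noise `r` below are illustrative values only.

```python
from typing import List, Sequence, Tuple

def predict(mean: float, var: float, a: float = 1.0, q: float = 0.1) -> Tuple[float, float]:
    """Transition prior: z_t = a * z_{t-1} + w with w ~ N(0, q)."""
    return a * mean, a * a * var + q

def update(prior_mean: float, prior_var: float, y: float, r: float = 0.5) -> Tuple[float, float]:
    """Bayes' rule with Gaussian likelihood y_t ~ N(z_t, r); the posterior is Gaussian."""
    k = prior_var / (prior_var + r)  # Kalman gain: how much to trust the new observation
    post_mean = prior_mean + k * (y - prior_mean)
    post_var = (1.0 - k) * prior_var
    return post_mean, post_var

def filter_sequence(observations: Sequence[float], mean0: float = 0.0,
                    var0: float = 1.0) -> List[Tuple[float, float]]:
    """Posterior (mean, var) of the latent state after each visit."""
    mean, var = mean0, var0  # prior p(z_1)
    history = []
    for t, y in enumerate(observations):
        if t > 0:
            mean, var = predict(mean, var)  # prior p(z_t | z_{t-1})
        mean, var = update(mean, var, y)    # posterior p(z_t | y_1..t)
        history.append((mean, var))
    return history
```

After a single observation `y = 1.0`, the posterior mean moves from 0 to 2/3 and the variance shrinks from 1 to 1/3; repeated consistent observations keep pulling the mean towards them.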

3.4 The Prior Module

In the prior module, neural networks are used as the transition function to generate the prior distribution of the latent state \(\varvec{z}_t\) from the posterior of \(\varvec{z}_{t-1}\). The prior of \(\varvec{z}_t\) is represented as:

$$\begin{aligned} p_\theta (\varvec{z}_t \mid \varvec{z}_{t-1}) \sim \mathcal {N}(\varvec{\mu }^p_{\varvec{z}_{t}},\varvec{\sigma }^p_{\varvec{z}_{t}}), \end{aligned}$$
(5)

where \(\varvec{\mu }^p_{\varvec{z}_{t}}\) and \(\varvec{\sigma }^p_{\varvec{z}_{t}}\) are the mean and standard deviation of the prior, obtained from the gated recurrent unit (GRU) (Cho et al., 2014) and fully connected networks \(fc_1\) and \(fc_2\):

$$\begin{aligned} \varvec{\mu }^p_{\varvec{z}_{t}}=fc_1(GRU(\varvec{z}_{t-1})) \end{aligned}$$
(6)
$$\begin{aligned} \varvec{\sigma }^p_{\varvec{z}_{t}}=fc_2(GRU(\varvec{z}_{t-1})) \end{aligned}$$
(7)
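A minimal pure-Python sketch of the prior module (Eqs. 5-7) follows. Several details are assumptions not given in the text: the GRU hidden state is carried across visits, biases are omitted, the weights are small random values rather than trained ones, and a softplus is applied to the output of \(fc_2\) so that the standard deviation stays positive. The GRU cell uses the standard reset/update/candidate gate formulation.

```python
import math
import random
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]

def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Matrix-vector product w @ x."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def vec_add(a: Sequence[float], b: Sequence[float]) -> Vector:
    return [ai + bi for ai, bi in zip(a, b)]

def sigmoid(v: Sequence[float]) -> Vector:
    return [1.0 / (1.0 + math.exp(-a)) for a in v]

def softplus(v: Sequence[float]) -> Vector:
    return [math.log1p(math.exp(a)) for a in v]

def random_matrix(rows: int, cols: int, rng: random.Random, scale: float = 0.5) -> Matrix:
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]

class GRUPrior:
    """Hypothetical prior module: GRU over z_{t-1}, then fc_1 -> mean and fc_2 -> std."""

    def __init__(self, z_dim: int, h_dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        gates = ("reset", "update", "candidate")
        self.W: Dict[str, Matrix] = {g: random_matrix(h_dim, z_dim, rng) for g in gates}  # input weights
        self.U: Dict[str, Matrix] = {g: random_matrix(h_dim, h_dim, rng) for g in gates}  # recurrent weights
        self.fc1 = random_matrix(z_dim, h_dim, rng)
        self.fc2 = random_matrix(z_dim, h_dim, rng)

    def gru_cell(self, x: Sequence[float], h: Sequence[float]) -> Vector:
        r = sigmoid(vec_add(matvec(self.W["reset"], x), matvec(self.U["reset"], h)))
        u = sigmoid(vec_add(matvec(self.W["update"], x), matvec(self.U["update"], h)))
        # Candidate state: the reset gate scales the recurrent contribution
        n = [math.tanh(a + ri * b) for a, ri, b in
             zip(matvec(self.W["candidate"], x), r, matvec(self.U["candidate"], h))]
        # Interpolate between the candidate and the previous hidden state
        return [(1.0 - ui) * ni + ui * hi for ui, ni, hi in zip(u, n, h)]

    def __call__(self, z_prev: Sequence[float],
                 h_prev: Sequence[float]) -> Tuple[Vector, Vector, Vector]:
        h = self.gru_cell(z_prev, h_prev)
        mu = matvec(self.fc1, h)               # Eq. 6
        sigma = softplus(matvec(self.fc2, h))  # Eq. 7, plus an assumed softplus so sigma > 0
        return mu, sigma, h
```

Calling `prior(z_prev, h)` once per visit, starting from `h = [0.0] * h_dim`, returns the prior mean and standard deviation of \(\varvec{z}_t\) together with the updated hidden state for the next visit.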

3.5 The Posterior Module

In the posterior module, the variational approximation of the posterior \(q_\phi (\varvec{z}_{t}\mid \varvec{Z}_{t-1},\varvec{x}_{t})\) is:

$$\begin{aligned} q_\phi (\varvec{z}_t \mid \varvec{Z}_{t-1}, \varvec{x}_{t}) \sim \mathcal {N}(\varvec{\mu }^q_{\varvec{z}_t},\varvec{\sigma }^q_{\varvec{z}_t}), \end{aligned}$$
(8)

where \(\varvec{\mu }^q_{\varvec{z}_t}\) and \(\varvec{\sigma }^q_{\varvec{z}_t}\) denote the mean and standard deviation of the posterior obtained from the label-dependent attention network and fully connected networks using \(\varvec{Z}_{t-1}\) and \(\varvec{x}_{t}\).

The label-dependent attention network encodes \(\varvec{x}_t\) by 1) jointly embedding \(\varvec{M}_t\) and \(\varvec{L}_Y\), and 2) integrating their embeddings via the cross-attention mechanism. For the embedding step, the text encoder Clinical-BERT (Alsentzer et al., 2019) is used to transform medical notes and descriptions of labels into latent representations:

$$\begin{aligned} \varvec{E}^M_t = BERT(\varvec{M}_t) \end{aligned}$$
(9)
$$\begin{aligned} \varvec{E}^Y = AvgPool(BERT(\varvec{L}_Y)), \end{aligned}$$
(10)

For the integrating step, the similarity between \(\varvec{E}^M_t\) and \(\varvec{E}^Y\) is measured first, yielding a scaled dot-product similarity matrix \(\varvec{G}_t \in \mathbb {R}^{ N_t \times N_y}\):

$$\begin{aligned} \varvec{G}_t = \frac{({fc_3(\varvec{E}_t^M)})^T fc_4(\varvec{E}^Y)}{\sqrt{D}} \end{aligned}$$
(11)

where \(fc_3\) and \(fc_4\) are two fully connected networks, \(D\) is the embedding dimension, and \((.)^T\) is the matrix transpose operator. To help capture information contained in consecutive words, a max-pooling layer followed by a softmax activation is then applied to generate an attentive embedding vector of the medical notes \(\varvec{e}_t \in \mathbb {R}^{D}\):

$$\begin{aligned} \varvec{g}_t = SoftMax(MaxPool(\varvec{G}_t)) \end{aligned}$$
(12)
$$\begin{aligned} \varvec{e}_t = \sum _{n=1}^{N_t}g_{t,n}\varvec{E}_t^M[:,n] \end{aligned}$$
(13)

where \(\varvec{g}_t \in \mathbb {R}^{N_t}\) is the cross-attention score vector, \(\varvec{E}_t^M[:,n]\) is the \(n\)th column of \(\varvec{E}_t^M\), and \(g_{t,n}\) is the \(n\)th element of \(\varvec{g}_t\). Since \(\varvec{e}_t\) contains the integrated information from \(\varvec{M}_t\) and \(\varvec{L}_Y\), the next step is to combine \(\varvec{e}_t\) with \(\varvec{Z}_{t-1}\) to generate \(\varvec{c}_t\) via:

$$\begin{aligned} \varvec{c}_t= fc_5(fg(\varvec{e}_t \oplus BiGRU([\varvec{Z}_{t-1}]))) \end{aligned}$$
(14)

where \(fc_5\) is a fully connected network, \(fg\) is a forget gate inspired by the LSTM (Hochreiter and Schmidhuber, 1997) that filters the previous latent states to retain important information, BiGRU is a bidirectional GRU, and \(\oplus \) is the concatenation operator. By passing \(\varvec{c}_t\) into two parallel fully connected networks, \(fc_6\) and \(fc_7\), the mean and standard deviation of the posterior of \(\varvec{z}_t\) are then obtained. A sampled state vector is given by:

$$\begin{aligned} \varvec{\hat{z}}_t = \varvec{\mu }^q_{\varvec{z}_t} + \epsilon \odot \varvec{\sigma }^q_{\varvec{z}_t}, \end{aligned}$$
(15)

where \(\epsilon \sim \mathcal {N}(0,\textbf{I})\) is random noise.
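The sketch below walks through the label-dependent attention of Eqs. 10-13 and the sampling step of Eq. 15 on toy vectors. Assumptions: word and label-token embeddings are given directly (standing in for Clinical-BERT outputs), \(fc_3\) and \(fc_4\) are taken as identity maps, and the max-pooling in Eq. 12 is read as a maximum over labels for each word, which is one reading consistent with \(\varvec{g}_t\) having one entry per word. The forget gate, BiGRU, and \(fc_5\) to \(fc_7\) of Eq. 14 are omitted; `mu` and `sigma` are simply passed in.

```python
import math
import random
from typing import List, Sequence, Tuple

Vector = List[float]

def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def avg_pool(token_vectors: Sequence[Sequence[float]]) -> Vector:
    """Eq. 10: average the token embeddings of one label description."""
    n = len(token_vectors)
    return [sum(col) / n for col in zip(*token_vectors)]

def softmax(scores: Sequence[float]) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def label_dependent_attention(word_emb: Sequence[Sequence[float]],
                              label_emb: Sequence[Sequence[float]]) -> Tuple[Vector, Vector]:
    """word_emb: N_t word vectors (columns of E^M_t); label_emb: N_y label vectors (E^Y).
    Returns (g_t, e_t); fc_3 and fc_4 are taken as identity maps."""
    d = len(word_emb[0])
    # Eq. 11: scaled dot-product similarity matrix G_t of shape N_t x N_y
    G = [[dot(w, l) / math.sqrt(d) for l in label_emb] for w in word_emb]
    # Eq. 12: max over labels for each word (assumed pooling axis), softmax over words
    g = softmax([max(row) for row in G])
    # Eq. 13: attention-weighted sum of the word vectors
    e = [sum(gn * w[k] for gn, w in zip(g, word_emb)) for k in range(d)]
    return g, e

def sample_posterior(mu: Sequence[float], sigma: Sequence[float], rng: random.Random) -> Vector:
    """Eq. 15: z_hat = mu + eps * sigma with eps ~ N(0, I)."""
    return [m + rng.gauss(0.0, 1.0) * s for m, s in zip(mu, sigma)]
```

With words `[[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]` and two labels built by `avg_pool`, the word most similar to some label description receives the largest attention weight, and \(\varvec{e}_t\) is a convex combination of the word vectors.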

3.6 The Likelihood Module

In the likelihood module, a fully connected network \(fc_8\) is used to get the predictive values of risks \(\varvec{\hat{y}}_t\in \mathbb {R}^{N_Y}\):

$$\begin{aligned} \varvec{\hat{y}}_t = fc_8(\varvec{\hat{z}}_t). \end{aligned}$$
(16)

The log-likelihood can then be defined as:

$$\begin{aligned} \log p_\theta (\varvec{y}_t\mid \varvec{z}_t) =&\frac{1}{N_Y} \sum _{j=1}^{N_Y} \big (y_{t,j} \cdot \log (\hat{y}_{t,j}) \nonumber \\&+ (1-y_{t,j})\cdot \log (1-\hat{y}_{t,j})\big ) \end{aligned}$$
(17)
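A sketch of the likelihood module (Eqs. 16-17) follows. It assumes a sigmoid on top of \(fc_8\) so that each predicted risk lies in (0, 1), as the logarithms in Eq. 17 require. The function returns the mean Bernoulli log-likelihood over the \(N_Y\) objectives, which is at most zero; its negative is the usual binary cross-entropy loss. The clipping constant `eps` is an added numerical safeguard, not part of the equation.

```python
import math
from typing import List, Sequence

def fc8_sigmoid(z_hat: Sequence[float], weights: Sequence[Sequence[float]],
                bias: Sequence[float]) -> List[float]:
    """Hypothetical fc_8 followed by a sigmoid so each risk score lies in (0, 1)."""
    logits = [sum(w * z for w, z in zip(row, z_hat)) + b for row, b in zip(weights, bias)]
    return [1.0 / (1.0 + math.exp(-a)) for a in logits]

def bernoulli_log_likelihood(y: Sequence[int], y_hat: Sequence[float], eps: float = 1e-7) -> float:
    """Mean Bernoulli log-likelihood over the N_Y objectives; always <= 0."""
    total = 0.0
    for yj, pj in zip(y, y_hat):
        pj = min(max(pj, eps), 1.0 - eps)  # clip to avoid log(0)
        total += yj * math.log(pj) + (1 - yj) * math.log(1.0 - pj)
    return total / len(y)
```

Confident, correct predictions such as `[0.99, 0.01]` for labels `[1, 0]` score higher (closer to zero) than hesitant ones such as `[0.6, 0.4]`, which is what maximizing the ELBO rewards.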

The training procedure to optimize DSLAM by maximizing the ELBO defined in Eq. 4 is shown in Algorithm 1.

Algorithm 1

Deep state-space modelling with label-dependent attention model.

4 Experiments

4.1 Experimental Dataset

4.1.1 MIMIC-III Dataset

We applied our model to the publicly accessible EHR dataset MIMIC-III (Johnson et al., 2016) to demonstrate its effectiveness. MIMIC-III is a critical care dataset comprising de-identified health data from over 40,000 patients admitted to ICUs at the Beth Israel Deaconess Medical Center between 2001 and 2012. It includes comprehensive clinical information such as vital signs, laboratory tests, medications, diagnoses, and procedures. The dataset has been used to evaluate various prediction models, such as RAIM, GAMENet, MNN, and AttDMM (Xu et al., 2018; Shang et al., 2019; Qiao et al., 2019; Oezyurt et al., 2021). In this study, to assess the impact of using data from multiple visits, we extracted a subset of MIMIC-III containing 9,759 unstructured medical notes from patients with two or more hospital visits, with an average of 2.61 visits per patient. The medical notes include part of the patients’ de-identified demographic information, such as age and gender, and a brief hospital course for each visit. The average length of stay in ICUs is 4.9 days, and the mean gap between two successive visits is 349.5 days. There are 25 different disease risks in total, which can be grouped into three types: acute, chronic, and mixed. To provide an intuitive and comprehensive evaluation on these three types, we streamline the 25 disease risk prediction objectives into three prediction objectives, one per type. We used the tool described in Harutyunyan et al. (2019) to extract medical notes and risk indicators from the MIMIC-III data, and we removed stop-words and non-alphabet characters from the notes. Following the data splitting strategy of Harutyunyan et al. (2019), we obtained training and test sets with a ratio of 4:1 for performance evaluation.
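The note cleaning described above (removing stop-words and non-alphabet characters) can be sketched as follows. The tiny stop-word list is hypothetical and for illustration only, since the text does not say which list was used; lower-casing the notes is an additional assumption.

```python
import re
from typing import Iterable, List

# Hypothetical, deliberately tiny stop-word list for illustration only
STOP_WORDS = frozenset({"the", "a", "an", "of", "and", "was", "is", "with", "to", "on", "in", "for"})

def clean_note(text: str, stop_words: Iterable[str] = STOP_WORDS) -> List[str]:
    """Lower-case, replace every run of non-alphabetic characters with a space,
    split on whitespace, and drop stop-words."""
    stop = set(stop_words)
    tokens = re.sub(r"[^a-z]+", " ", text.lower()).split()
    return [tok for tok in tokens if tok not in stop]
```

For instance, `clean_note("Pt was admitted with CHF; BP 150/90 on admission.")` keeps `["pt", "admitted", "chf", "bp", "admission"]`: digits and punctuation disappear together with the stop-words.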

4.1.2 N2C2-2014 Dataset

The N2C2-2014 dataset (Kumar et al., 2015) was built for the de-identification of longitudinal medical records. It contains 1,304 de-identified longitudinal medical records belonging to 296 patients, with an average of 4.41 hospital visits per patient. The data consist mainly of unstructured medical notes, which include some de-identified demographic information and a brief hospital course. To improve data quality, we applied pre-processing to remove noise, stop-words, and other irrelevant elements. The target diseases are hyperlipidemia, hypertension, coronary artery disease, and diabetes. The dataset is split in a 4:1 ratio into training and test sets.
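For longitudinal records, a 4:1 split can be made at the patient level so that all visits of one patient fall on the same side of the split and no patient history leaks from training into testing. The text does not say whether the split is by patient or by record, so the sketch below assumes a patient-level split; `split_by_patient`, its parameters, and the seed are hypothetical.

```python
import random


def split_by_patient(records: list[tuple[str, str]], train_ratio: float = 0.8,
                     seed: int = 0) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Split (patient_id, note) records so that each patient lands in exactly one split."""
    # Sort before shuffling so the result depends only on the seed.
    patient_ids = sorted({pid for pid, _ in records})
    rng = random.Random(seed)
    rng.shuffle(patient_ids)
    n_train = round(len(patient_ids) * train_ratio)
    train_ids = set(patient_ids[:n_train])
    train = [rec for rec in records if rec[0] in train_ids]
    test = [rec for rec in records if rec[0] not in train_ids]
    return train, test
```

With the 296 N2C2-2014 patients and `train_ratio=0.8`, this assignment yields 237 training patients and 59 test patients.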

4.2 Baseline Methods

To evaluate the proposed method, we compared DSLAM against baseline methods from two primary categories. Class 1 methods depend solely on the data collected during the current hospital visit, while Class 2 methods combine information from previous hospital visits with the current visit.

The Class 1 baseline methods are listed as follows:

  • XGBOOST and SVM: Two conventional machine learning methods, XGBOOST and SVM, are used to build predictive models from word2vec representations of the medical notes.

  • CAML: It adopts a cross-attention mechanism to encode information contained in medical notes and labels (Mullenbach et al., 2018).

  • CAML\(^-\): The cross-attention module in CAML is replaced by self-attention (Vaswani et al., 2017).

  • CAML+\(\varvec{\mathcal {B}}\): The word2vec layer in CAML is replaced by Clinical-BERT.

  • CAML\(^-\)+\(\varvec{\mathcal {B}}\): The cross-attention module in CAML+\(\mathcal {B}\) is replaced by self-attention.

The Class 2 baseline methods are listed as follows:

  • \(\varvec{\mathcal {G}}\)+CAML+\(\varvec{\mathcal {B}}\): GRU is added to CAML+\(\mathcal {B}\) for modelling longitudinal medical notes.

  • \(\varvec{\mathcal {G}}\)+CAML\(^-\)+\(\varvec{\mathcal {B}}\): GRU is added to CAML\(^-\)+\(\mathcal {B}\) for modelling longitudinal medical notes.

  • RETAIN: RETAIN (Choi et al., 2016) is a representative RNN-based disease risk prediction model for longitudinal EHRs that employs a reverse time-aware attention module.

  • RETAIN+\(\varvec{\mathcal {B}}\): The encoder layer of RETAIN is replaced with Clinical-BERT.

  • DIPOLE: DIPOLE (Ma et al., 2017) improves longitudinal EHR modelling by using bidirectional RNNs with a dual time-aware attention mechanism in place of the reverse time-aware attention mechanism of RETAIN.

  • DIPOLE+\(\varvec{\mathcal {B}}\): In the interest of fairness in comparisons, we substitute the encoder layer of DIPOLE with Clinical-BERT.

  • DSLAM: This is the deep state-space model with label-dependent attention proposed in this paper.

  • DSLAM\(^-\): The label-dependent attention module of our DSLAM model is replaced by self-attention.

All compared models used the following hyperparameters: a learning rate of \(10^{-5}\), a token length of 300, a Clinical-BERT embedding size of 768, and a latent state size of 384. The models were trained with the Adam optimizer and a dropout rate of 0.3. To ensure robustness, each model was trained five times with a fixed set of five different seeds, and the average performance is reported. All models were implemented in PyTorch and run on an NVIDIA Tesla V100 GPU. The source code of our model is publicly available (Footnote 1).
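The training protocol above (fixed hyperparameters, five runs with different seeds, averaged metrics) can be sketched as a small harness. The hyperparameter values come from the text; the `TrainConfig` field names, the `average_over_seeds` helper, and the seed values are hypothetical, and `train_and_eval` stands in for the actual PyTorch training and evaluation loop, which is not shown here.

```python
from dataclasses import dataclass
from statistics import mean
from typing import Callable, Sequence


@dataclass(frozen=True)
class TrainConfig:
    # Values reported in the text; the field names are hypothetical.
    learning_rate: float = 1e-5
    token_length: int = 300
    bert_embedding_size: int = 768  # Clinical-BERT embedding size
    latent_state_size: int = 384
    dropout: float = 0.3
    optimizer: str = "adam"


# Hypothetical seed values; the text only states that five fixed seeds were used.
SEEDS = (0, 1, 2, 3, 4)


def average_over_seeds(
    train_and_eval: Callable[[TrainConfig, int], dict[str, float]],
    config: TrainConfig,
    seeds: Sequence[int] = SEEDS,
) -> dict[str, float]:
    """Train and evaluate once per seed, then average each metric across the runs."""
    runs = [train_and_eval(config, seed) for seed in seeds]
    return {metric: mean(run[metric] for run in runs) for metric in runs[0]}
```

Keeping the configuration frozen and passing the seed explicitly makes it easy to confirm that the five runs differ only in their random initialization.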

4.3 Evaluation Metrics

Risk prediction performance is measured by precision, recall, F1 score, and ROCAUC score:

  • Precision: Precision measures the fraction of a model’s positive predictions that are correct. It is defined as the ratio of true positives (TP) to the sum of true positives and false positives (FP):

    $$\begin{aligned} Precision = \frac{TP}{TP+FP} \end{aligned}$$
    (18)
  • Recall: Recall measures the fraction of actual positive instances in a dataset that a model correctly identifies. It is defined as the ratio of TP to the sum of true positives and false negatives (FN):

    $$\begin{aligned} Recall = \frac{TP}{TP+FN} \end{aligned}$$
    (19)
  • F1 score: The F1 score is a metric that balances precision and recall by taking their harmonic mean. It is defined as:

    $$\begin{aligned} F1 = \frac{2*Precision*Recall}{Precision+Recall} \end{aligned}$$
    (20)
  • ROCAUC score: ROCAUC score is a popular metric used to evaluate the performance of binary classification models. ROC stands for Receiver Operating Characteristic, which is a plot of the true positive rate (TPR) against the false positive rate (FPR) at different classification thresholds. AUC stands for Area Under the Curve, which represents the area under the ROC curve. The ROCAUC score can range from 0 to 1, with a score of 0.5 indicating that the model performs no better than random chance, and a score of 1 indicating perfect classification performance.
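Eqs. (18) to (20) and the ROC construction described above can be computed directly from counts and scores, as in the minimal sketch below. Reporting 0.0 for a zero denominator is our assumption. The ROC curve is built by sweeping the threshold through every distinct score, and the area under it is obtained with the trapezoidal rule; the function names are hypothetical.

```python
def precision_recall_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Precision, recall, and F1 as in Eqs. (18) to (20).

    Assumed convention: a metric whose denominator is zero is reported as 0.0.
    """
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def roc_curve_points(labels: list[int], scores: list[float]) -> list[tuple[float, float]]:
    """(FPR, TPR) points obtained by lowering the threshold through every distinct score."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC needs at least one positive and one negative example")
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    points = [(0.0, 0.0)]  # threshold above every score: nothing is predicted positive
    tp = fp = 0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Move all examples tied at this score across the threshold together.
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def auc_trapezoid(points: list[tuple[float, float]]) -> float:
    """Area under a piecewise-linear ROC curve via the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))
```

For example, with labels `[1, 1, 0, 0]` and scores `[0.9, 0.4, 0.5, 0.1]`, `auc_trapezoid(roc_curve_points(labels, scores))` gives 0.75, and `precision_recall_f1(8, 2, 4)` gives a precision of 0.8 and a recall of 2/3.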

For multi-label risk prediction, where a single visit can carry several risk types at once, we adopt micro and macro averaging for the precision, recall, F1 score, and ROCAUC score to obtain a comprehensive evaluation of our model:

$$\begin{aligned} \begin{aligned}&Micro\ Precision = \frac{\sum _i TP_i }{\sum _i TP_i + \sum _i FP_i},\\&Micro\ Recall = \frac{\sum _i TP_i }{\sum _i TP_i + \sum _i FN_i},\\&Micro\ F1 = \frac{2*Micro\ Precision*Micro\ Recall}{Micro\ Precision+Micro\ Recall},\\&Macro\ Precision = \sum _i Precision_i/L,\\&Macro\ Recall = \sum _i Recall_i/L,\\&Macro\ F1 = \frac{2*Macro\ Precision*Macro\ Recall}{Macro\ Precision+Macro\ Recall}, \end{aligned} \end{aligned}$$
(21)

where i indicates the class index and L is the number of classes.
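Eq. (21) can be implemented from per-class counts as sketched below; the names are hypothetical, and reporting 0.0 for a zero denominator is our assumption. Note that Eq. (21) defines Macro F1 as the harmonic mean of Macro Precision and Macro Recall. This can differ from the other common convention of averaging per-class F1 scores (for example, scikit-learn’s `f1_score` with `average="macro"`), so values computed under the two definitions are not directly comparable.

```python
def _safe_div(num: float, den: float) -> float:
    # Assumed convention: a zero denominator yields 0.0.
    return num / den if den else 0.0


def micro_macro_prf(tp: list[int], fp: list[int], fn: list[int]) -> dict[str, float]:
    """Micro and macro precision, recall, and F1 following Eq. (21)."""
    n_classes = len(tp)  # L in Eq. (21)
    # Micro: pool the counts of all classes before dividing.
    micro_p = _safe_div(sum(tp), sum(tp) + sum(fp))
    micro_r = _safe_div(sum(tp), sum(tp) + sum(fn))
    # Macro: compute each class's precision and recall, then take the plain mean.
    macro_p = sum(_safe_div(t, t + f) for t, f in zip(tp, fp)) / n_classes
    macro_r = sum(_safe_div(t, t + f) for t, f in zip(tp, fn)) / n_classes
    return {
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_f1": _safe_div(2 * micro_p * micro_r, micro_p + micro_r),
        "macro_precision": macro_p,
        "macro_recall": macro_r,
        # Harmonic mean of macro precision and macro recall, as in Eq. (21).
        "macro_f1": _safe_div(2 * macro_p * macro_r, macro_p + macro_r),
    }
```

For example, with `tp=[8, 1]`, `fp=[2, 1]`, and `fn=[0, 3]`, the micro scores are all 0.75, whereas the macro recall is 0.625 because the second class’s low recall counts as much as the first class’s.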

The Micro ROCAUC score pools the predictions for all classes into a single binary problem, treating each (sample, class) pair as one instance, and computes one ROC curve and its AUC. Because every instance counts equally, classes with more instances carry more weight. It is useful when the classes are imbalanced and the overall performance of the model is of primary interest.

The Macro ROCAUC score computes the ROCAUC score for each class separately and then averages them with equal weight. It is useful when each class is equally important and we want to evaluate the model’s performance on every class.
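The two averaging schemes can be sketched as follows, assuming `y_true` is a samples × classes matrix of 0/1 labels and `y_score` holds the matching predicted scores. For brevity, the AUC is computed in its pairwise (Mann-Whitney) form: the probability that a randomly chosen positive outscores a randomly chosen negative, with ties counted as one half. This equals the area under the empirical ROC curve. The function names are hypothetical.

```python
def pairwise_auc(labels: list[int], scores: list[float]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def micro_macro_roc_auc(y_true: list[list[int]],
                        y_score: list[list[float]]) -> tuple[float, float]:
    """Return (micro ROCAUC, macro ROCAUC) for a multi-label prediction matrix."""
    n_classes = len(y_true[0])
    # Micro: flatten every (sample, class) pair into one binary problem.
    micro = pairwise_auc([y for row in y_true for y in row],
                         [s for row in y_score for s in row])
    # Macro: one AUC per class, then an unweighted mean.
    per_class = [pairwise_auc([row[c] for row in y_true], [row[c] for row in y_score])
                 for c in range(n_classes)]
    return micro, sum(per_class) / n_classes
```

For example, with `y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]` and `y_score = [[0.9, 0.2], [0.3, 0.8], [0.6, 0.4], [0.1, 0.5]]`, the function returns `(0.9375, 0.875)`. The pairwise form is quadratic in the number of samples, which is fine for illustration but not for large evaluation sets.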

Table 3 Comparison of our DSLAM model with all baseline and ablated models on the MIMIC-III dataset, in terms of Micro and Macro Precision, Recall, F1 score, and ROCAUC score
Table 4 Comparison of our DSLAM model with all baseline and ablated models on the N2C2-2014 dataset, in terms of Micro and Macro Precision, Recall, F1 score, and ROCAUC score
Fig. 5

ROCAUC score for acute disease risk, chronic disease risk, and mixed disease risk on MIMIC-III dataset

4.4 The Performance of Risk Prediction

Tables 3 and 4 and Fig. 5 present the Precision, Recall, F1 score, and ROCAUC score of all compared models on the risk prediction task. Several observations can be made from these results.

  • Effectiveness of longitudinal models: Longitudinal models improve disease risk prediction by incorporating information from multiple visits over time. Deep neural network-based models outperform conventional machine learning models. Class 1 conventional models such as XGBOOST and SVM tend to achieve high recall at the expense of precision, which lowers their F1 and ROCAUC scores. In contrast, Class 2 models, which take longitudinal information into account, generally achieve higher ROCAUC scores and a better balance between precision and recall. Medical notes from multiple visits provide important contextual information that helps improve the accuracy of disease risk predictions. These observations indicate that longitudinal models are effective for disease risk prediction, as they give a more comprehensive picture of disease progression over time.

  • Comparing deep state-space models with RNN-based models: Tables 3 and 4 show that our deep state-space model DSLAM outperforms RNN-based models, such as RETAIN, DIPOLE, and CAML with GRU, in disease risk prediction accuracy. DSLAM obtains higher micro and macro F1 and ROCAUC scores on both the MIMIC-III and N2C2-2014 datasets. We further compared DSLAM with DIPOLE+\(\mathcal {B}\) on the MIMIC-III dataset separately for each type of disease risk. As shown in Fig. 5, DSLAM covers a larger area under the ROC curve, and thus achieves a higher ROCAUC score, than DIPOLE+\(\mathcal {B}\) for acute, chronic, and mixed disease risks alike. This highlights the effectiveness of deep state-space models over conventional RNN-based approaches for modelling longitudinal medical notes. Because the N2C2-2014 dataset only includes records for chronic disease risks, no separate per-type analysis is conducted on it.

  • Effectiveness of the label-dependent attention: Tables 3 and 4 show that the models with the label-dependent attention module, namely CAML, CAML+\(\mathcal {B}\), and \(\mathcal {G}\)+CAML+\(\mathcal {B}\), outperformed their counterparts in which the module is replaced (CAML\(^-\), CAML\(^-\)+\(\mathcal {B}\), and \(\mathcal {G}\)+CAML\(^-\)+\(\mathcal {B}\), respectively) in terms of F1 and ROCAUC values. This suggests that the label-dependent attention module improves a model’s ability to identify risks and increases its overall predictive power. Likewise, DSLAM with the label-dependent attention module achieved higher ROCAUC values and comparable F1 values compared with DSLAM\(^-\), which lacks the module, on both the MIMIC-III and N2C2-2014 datasets. These findings indicate that the label-dependent attention module is a valuable tool for improving a model’s overall evaluation performance and, thus, its ability to detect risks accurately.
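To make the mechanism behind these comparisons concrete, a generic label-wise attention step in the style of CAML (Mullenbach et al., 2018) is sketched below. Each risk label has a query vector; tokens are scored by their dot product with that query, a softmax turns the scores into attention weights, and the weighted sum of token states gives a label-specific note representation. This is an illustrative assumption, not the DSLAM module itself. In particular, DSLAM uses disease risk label descriptions, but how these are encoded into query vectors is not shown here, so the queries are passed in directly as a hypothetical input.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def label_attention(token_states: list[list[float]],
                    label_queries: list[list[float]]) -> tuple[list[list[float]], list[list[float]]]:
    """Label-wise attention over the tokens of one note.

    token_states:  n_tokens x d encoder outputs (e.g. from Clinical-BERT).
    label_queries: n_labels x d query vectors, one per risk label (hypothetical input).
    Returns (weights, reps): per-label attention weights over the tokens and
    per-label weighted sums of the token states.
    """
    d = len(token_states[0])
    weights, reps = [], []
    for query in label_queries:
        # Score each token against this label's query, then normalize.
        alpha = softmax([sum(h_j * q_j for h_j, q_j in zip(h, query)) for h in token_states])
        # Label-specific representation: attention-weighted sum of token states.
        rep = [sum(a * h[j] for a, h in zip(alpha, token_states)) for j in range(d)]
        weights.append(alpha)
        reps.append(rep)
    return weights, reps
```

Because each label gets its own weight vector over the tokens, the same note can emphasize different words for different risks. Replacing this step with label-agnostic self-attention removes that per-label focus, which is the difference the ablated variants probe.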

Fig. 6

This case study demonstrates the interpretable results produced by DSLAM. The figure presents the medical notes collected from four hospital visits of a randomly selected patient, together with the associated observed risks. The top 30% of words, ranked by their attention scores from the label-dependent attention, are highlighted in yellow. Common words that appear in multiple visits are highlighted in red

To conclude, our deep state-space model DSLAM outperformed all baseline models on the disease risk prediction task. The label-dependent attention module proved valuable for improving a model’s ability to identify risks and its overall predictive power, with higher recall values for risk prediction tasks in the healthcare domain.

5 Discussion

Our DSLAM model offers a high degree of interpretability, which is crucial for clinical decision-making. To illustrate this, we randomly selected a patient with multiple hospital visits from the evaluation dataset and used our model to analyze their medical records. The label-dependent cross-attention module learned in the risk prediction task generates attention scores that indicate the importance of individual words. In Fig. 6, the top 30% of words, ranked by attention score, are highlighted in yellow in fragments of the medical notes. Fig. 6 also shows the observed risks at each visit for the selected patient, which helps to check whether the top-ranked words are clinically relevant to these risks. For instance, at the first hospital visit, the patient had both acute and chronic disease risks. The specific diagnosed diseases are “respiratory disease”, “respiratory failure”, “coronary atherosclerosis”, and “diabetes mellitus”. In the medical notes, our model assigns higher attention scores to symptom-related terms such as “pulmonary”, “fever”, “cough”, and “sputum”, which relate to primary symptoms of respiratory diseases (Farzan, 1990). Medications such as “insulin”, “metoprolol”, and “simvastatin” are also highlighted: “insulin” is a widely adopted treatment for diabetes (Wilcox, 2005), “metoprolol” is a standard \(\beta \)-blocker for treating coronary atherosclerosis (Joseph et al., 2019), and “simvastatin” is commonly used for treating heart diseases (Niazi et al., 2020). All these diseases are recorded in the patient’s EHRs, and similar observations hold for the other hospital visits. Examining the disease risks recorded across the visits shows that this patient has persistent problems, including “respiratory disease” and “diabetes mellitus”.
Among the yellow-highlighted words, those that appear in more than one visit are highlighted in red. “Pulmonary” and “insulin”, which are associated with these persistent diseases, are important across almost all hospital visits. These findings indicate that our model can produce clinically interpretable results.
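The highlighting procedure described above can be sketched as follows: rank each visit’s tokens by attention score, keep the top 30%, and then mark the kept words that recur in more than one visit. Rounding the 30% cut-off up with `ceil`, and applying it to token positions rather than distinct words, are our assumptions; the helper names are hypothetical.

```python
import math
from collections import Counter


def top_fraction_words(words: list[str], scores: list[float], fraction: float = 0.3) -> set[str]:
    """Words at the token positions whose attention scores rank in the top `fraction`."""
    k = max(1, math.ceil(len(words) * fraction))  # assumed rounding: up, at least one token
    ranked = sorted(range(len(words)), key=lambda i: scores[i], reverse=True)
    return {words[i] for i in ranked[:k]}


def words_shared_across_visits(highlighted_per_visit: list[set[str]]) -> set[str]:
    """Highlighted words that occur in more than one visit (the red highlights)."""
    counts = Counter()
    for visit_words in highlighted_per_visit:
        counts.update(set(visit_words))  # count each word at most once per visit
    return {word for word, n in counts.items() if n > 1}
```

For example, if three visits keep `{"pulmonary", "insulin", "fever"}`, `{"pulmonary", "insulin", "cough"}`, and `{"insulin", "sputum"}`, the shared words are `{"pulmonary", "insulin"}`.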

Fig. 7

The workflow illustrates how to use our model DSLAM for disease risk prediction

6 Limitations and Future Work

In this paper, the theoretical framework is based on the Machine Learning for Decision Support Theory, which leverages machine learning algorithms to analyze historical EHRs to predict patient outcomes and aid decision-making. We propose a novel algorithm applied to EHRs that aims to advance existing theories in this domain. The study’s conceptual framework revolves around developing a CDSS that uses a deep state-space model to process longitudinal EHR data effectively. This involves:

1. Capturing time-varying information from the longitudinal EHR data.

2. Extracting clinically useful information from unstructured medical notes for predictive model construction.

3. Addressing the absence of specialized longitudinal CDSS capable of effectively tracing patients’ latent health states across multiple hospital visits for predicting various types of disease risks.

We also emphasize the model’s capability to offer interpretable, fine-grained insights into the factors driving disease risk predictions by leveraging a label-dependent attention network. A limitation, however, is that the latent states captured by the model are difficult to interpret directly and to connect with real-world medical concepts, particularly when supporting clinical decisions at a broader, coarser-grained level.

In future work, we plan to develop a deep and interpretable state transition model that unveils the underlying medical attributes associated with the latent states. This would make disease risk prediction results more comprehensible and easier to apply in the medical domain. Additionally, because EHRs contain abundant multimodal data (not just medical notes but also time-series lab test results, patient demographic information, chest X-ray (CXR) images, etc.), we intend to integrate more types of EHR data. This integration will account for the joint effects of diagnoses made by doctors from different departments or perspectives, leading to a more comprehensive evaluation of patients’ health states and more accurate clinical decision-making.

7 Conclusion

The COVID-19 pandemic has brought about significant challenges for the healthcare industry, including an increase in demand for medical resources, the need for rapid and accurate diagnosis, and the development of effective treatment strategies. To improve clinical decision-making processes and patient outcomes, there is growing interest in leveraging emerging technologies such as CDSS.

This paper proposes a deep state-space-based label-dependent attention model for clinical decision-making. Our proposed CDSS model uses longitudinal unstructured medical notes and disease risk label descriptions from EHRs to predict disease risks. It combines the representation power of deep neural networks with the structured representations of probabilistic models to effectively model unstructured medical notes generated across multiple hospital visits. The model also uses a label-dependent attention mechanism to improve predictive performance and generate interpretable results.

Compared to traditional CDSS models, our proposed model offers several advantages for clinical decision-making processes. Firstly, it can analyze large volumes of patient data and trace latent health state trajectories across multiple hospital visits. Secondly, the system can provide trustworthy decision support to clinicians, enabling them to make more informed decisions with explainable evidence, ultimately improving patient outcomes.

We evaluated the effectiveness of our proposed CDSS model using two real-world EHR datasets, MIMIC-III and N2C2-2014, and the results demonstrate its strong predictive power. The model also shows the ability to identify medical words or phrases with large attention scores, which are clinically meaningful and provide interpretable results. This feature is particularly important in the post-COVID-19 era, as clinicians must make more complex decisions based on a greater volume of patient data.

In summary, the proposed deep state-space-based label-dependent attention model presents a promising solution to the challenges facing the healthcare industry in the post-COVID-19 era. By leveraging the power of emerging technologies such as AI, this model can analyze large volumes of patient data and provide trustworthy decision support to clinicians. The model’s ability to effectively model unstructured medical notes generated from multiple hospital visits and provide interpretable results enables clinicians to make more informed decisions, ultimately improving patient outcomes. The use of CDSS in clinical decision-making is likely to become more prevalent in the coming years, and our proposed model offers insight into the future development of CDSS that can address the healthcare industry’s most pressing challenges.