1 Introduction

The exponential growth of electronic health records (EHRs) has drawn significant interest from the machine learning and data mining communities. With a wealth of information from multiple sources and formats, these EHRs offer a vast dataset for developing evidence-based clinical decision-making tools. For example, the My Health Record system (Footnote 1) stores over 41.9 million EHRs on over 6.4 million patients. Despite criticisms by some that EHRs are often vendor-specific and sometimes limited in scope [1], their sheer size and diversity make them a valuable resource for deep learning technology. This is especially true in the intensive care unit (ICU), where critical decisions are driven by forecasts of patient outcomes based on pathological and physiological values [1]. As such, deep learning technology has been broadly applied to advance research in ICU decision support, particularly mortality estimation [2] and phenotype analysis [3]. In general, clinical decisions in ICUs are time-critical and highly dependent on physiological data. However, making accurate and rapid decisions in these fast-changing environments without enough real-time information on the severity of a patient’s illness can be very challenging for clinicians. As a result, numerous scoring systems have been developed and progressively refined to assist with rapid patient assessment. Examples include the sequential organ failure assessment (SOFA) score [4], APACHE II [5], and SAPS II [6]. The resulting scores reflect the current clinical condition of a patient based on a set of basic physiological indicators.

Fig. 1 Two different SOFA trajectories for two ICU patients. According to their SOFA score, Patient ID: 80030 (red) was initially in critical condition but gradually improved and was eventually discharged. In contrast, the condition of Patient ID: 45767 (blue) deteriorated, and they ultimately passed away

1.1 Motivation

These scoring systems serve as a simple calculation of a patient’s vital signs at various times but do not provide real-time information for critical decision-making in the ICU. The longer the time between updates, the less opportunity there is to respond to a deteriorating patient, which is why continuous monitoring of key indicators such as heart rate is essential. According to Bouch and Thompson [7], an instantaneous scoring system that covers a wider range of indicators is urgently needed to support better decision-making in the ICU. To demonstrate the potential impact of such a system, we present an example of two ICU patients whose SOFA scores are charted over time to show changes in their condition (as illustrated in Figure 1). If linked to medical interventions, these high-frequency SOFA measures could provide insight into the effectiveness of each treatment. This example highlights both the potential and the need for continuous prediction of illness severity scores as a new tool for patient monitoring.

Over the years, recurrent neural networks (RNNs) and their variants [8,9,10,11,12] have been explored as deep models for handling time series data, and many have achieved significant results on clinical prediction tasks such as mortality risk estimation. Given a sequence of multivariate features, the typical outlook of mortality risk with the prediction techniques of today is about 24 hours, barely enough time for clinicians to intervene. More importantly, short-term mortality risk predictions may have ethical implications. For example, the mortality risk to a patient over the next week may be, say, 80%, but the prediction for the next 24 hours may only be 5%. If faced with an unaffordable treatment, many patients and caregivers may choose not to continue with clinical services, unaware of the consequences of that decision beyond the next day. Thus, continuously predicting the medical trajectory not only offers more detailed information at a finer time granularity but could also help caregivers concentrate on planning effective treatments with better consideration of an illness’s true severity.

Despite their solid results to date, these learning models have some deficiencies. For instance, they normally treat all multivariate time-series variables as a single input stream without considering the correlations between the physiological variables. However, human organs are highly correlated with each other and with a patient’s deterioration. When one or two organs start to malfunction, others tend to follow over a short period. For example, systolic blood pressure is positively correlated with diastolic blood pressure and pulse pressure, whereas diastolic blood pressure is inversely correlated with pulse pressure. Also, a deterioration in the fraction of inspired oxygen can asynchronously affect cerebral blood flow. Thus, exploiting correlations between medical time-series variables can further improve classification performance for ICU prediction tasks. Few studies have investigated these feature correlations, even though they are very common and important in the healthcare domain.

1.2 Solution

In search of a better approach that does not suffer from these shortcomings, we designed a novel learning framework to monitor the severity of illness in ICU patients in real time. Within the framework, the medical time-series variables most frequently monitored in ICUs are assigned to different organ systems according to their physiological functions. For example, heart rate belongs to the cardiovascular system, while FiO2 is common to the heart and the lungs and is therefore treated as a shared feature across the cardiovascular and pulmonary systems. Given that ICU multivariate time series data is often characterized as sparse, irregular, and noisy, the features of each organ system are learned by controlling memory updates to a standard LSTM cell as a way to yield better performance. The framework has a multi-task architecture to capture feature correlations that exist both within one organ system and across multiple systems. Once each individual system has been modelled, the memories of all the separate LSTMs are fused with an attention mechanism; hence the name multi-task memory-fused RNN (MM-RNN). Most importantly, using a multi-task architecture to learn each organ system concurrently means that MM-RNN is able to visualize the evidence from all the single-task LSTMs on which the model based its predictions. The interpretability this provides to clinicians is highly valuable.

1.3 Contribution

In summary, the contributions of this research include:

  • a novel multi-task delta memory-fused RNN model (MM-RNN). The multi-task framework concurrently learns behaviours in individual human organ systems, while the delta memory captures asynchronous feature correlations within and across tasks over a longer temporal range.

  • a novel approach to identifying correlated temporal features in time-series data. The memories of all the separate LSTMs are fused to capture important correlations between features over time by integrating both task-specific and cross-task interactions in time series EHRs.

  • a method of providing practical and visual explanations for illness severity predictions based on outputs from the memory fusion network. The combination of an accurate prediction coupled with an interpretable explanation provides strong evidence for accurately forecasting problematic organ systems and a patient’s likely future condition.

  • comparative experiments with a real-world dataset, which show that MM-RNN delivers state-of-the-art performance with respect to continuous illness severity predictions.

The remainder of this paper is organized as follows. Related works are reviewed in Section 2. The proposed method is discussed in detail in Section 3. The experimental evaluations and related discussions are provided in Section 4. Section 5 presents our conclusions and discusses future work.

2 Related work

As governments commit to national EHR systems, the use of EHRs has been widely adopted in both hospital and health-care settings [13, 14]. EHR systems provide a wealth of information on a patient’s medical history, complications, and medicine use. Deep learning encompasses a wide variety of techniques. In this section, we review pertinent studies on both mining multivariate time series from EHRs and classifying medical data with deep learning techniques. EHRs in ICUs mainly consist of multivariate time series data that span a range of medical variables. In practice, time-series deep models have been widely applied to many analytics tasks, including mortality risk estimation, phenotyping analysis, and disease modelling.

Despite the diversity in applications, most models are evaluated with a standard set of metrics, such as AUC (risk event prediction [15, 16], diagnosis classification [17,18,19], disease risk identification [9, 10]), accuracy, precision, recall, and F1 score.

As the amount and availability of detailed clinical health records has exploded in recent years, there is great opportunity for revisiting and refining the definitions and boundaries of illnesses and diagnoses. Diseases are traditionally defined by a set of manual clinical descriptions, whereas computational phenotyping seeks to derive richer, data-driven descriptions of illnesses [17, 20]. With machine learning and data mining techniques, it is possible to discover natural clusters of clinical descriptors that lend themselves to a more fine-grained description of a disease. For example, detailed phenotypes represent a big step towards the eventual goal of personalised precision healthcare. Computational phenotypes can be seen as an archetypal clinical application of deep learning principles. Grounded in the philosophy of “letting the data speak”, the idea is to discover the latent relationships and hierarchical concepts in the raw data without any human supervision or prior bias. With the growing availability of huge amounts of clinical data, many researchers are using deep learning techniques to explore computational phenotypes.

2.1 Mortality prediction

As one type of clinical prediction task in an ICU [21], mortality prediction aims to identify a patient’s risk of death over the next 24-48 hours from their available medical data. This prediction task can be formulated as a binary classification problem, where the outcome of the classifier is literally life or death for a patient. Further, mortality prediction tasks can be divided into more specific subtasks based on other factors, such as whether the patient is being treated in hospital, short- versus long-term outlooks, and so on. Many studies have explored estimating mortality risk by modelling medical time-series data with neural networks and logistic regression models [22, 23], in many cases with performance that outstrips traditional models. Che et al. [11] used prior knowledge to regularize the parameters of a standard deep model and then trained a collection of neural networks for clinical prediction tasks with a scalable procedure. Aczon et al. [10] implemented a standard recurrent neural network and generated temporally dynamic predictions for in-ICU mortality by feeding the network with multiple data sources, such as physiological observations, laboratory results, administered drugs, and interventions. The trained model was deployed in a pediatric ICU and achieved better performance than the conventional models. To address data quality issues, Che et al. [9] proposed a novel variant of a gated recurrent unit, called GRU-D, that incorporates patterns of masked or missing time interval data into an end-to-end framework.

The advances deep learning models have made in mortality prediction are undoubtedly outstanding, but there is more to be done. A prediction for the next 24-48 hours does not provide critical care experts with enough information to make timely treatment decisions in a dynamically changing environment. Additionally, because of the black-box nature of deep neural networks, these models produce only quantitative predictions. How and why the model arrived at those predictions is not explained, which undermines trust in the results. When a patient’s life or death is at stake, practitioners must have full confidence that the reasoning behind the prediction is sound.

2.2 Multivariate time-series deep modeling

Multivariate time-series data contains more than one time-dependent variable. Many sophisticated models have been developed to tackle the associated tasks and challenges. Wang et al. [17] investigated disease-specific features embedded in descriptive data to analyze the phenotypes of patients. To improve model performance, Zhou et al. [24] and Nie et al. [25] leveraged consistencies between multiple modalities and task-specific features. These models are adequate for their purpose, but performance is not optimal because they neglect the temporal information embedded in the data. Multivariate time-series features properly reflect the severity of a patient’s illness. Hence, several attempts have been made to combine temporal features with RNNs. Pham et al. [26] combined the temporal information embedded in time-series data with an RNN for estimating the severity of an ICU patient’s illness. To improve performance, Che et al. [9], Lipton et al. [3], and Chen et al. [16] considered time intervals while integrating RNNs to investigate irregular EHR time-series data. The downside is that all these approaches process temporal features in a heuristic manner using a monotonically decreasing function. Therefore, they tend to suffer from over- and/or under-parameterization.

2.3 Multi-task learning

The aim of multi-task learning is to leverage useful information from similar learning tasks to improve model performance. It is most widely used in computer vision [27, 28]. Abdulnabi et al. [29] proposed a multi-task convolutional neural network (CNN) for attribute prediction, in which the tasks from multiple CNNs are used to learn the features of image segments. To identify the most useful information, a common layer determines the spatial correlations between image features. Instead of a multi-task CNN, Chen et al. [30] proposed a multi-task RNN framework to exploit the correlations between sub-tasks by jointly learning EEG signals for intention recognition tasks. To analyze multivariate time-series data, Chen et al. [16] proposed a multi-task learning method that captures the temporal correlations embedded in the data. This is a key feature of a multi-task analytic framework: capturing the intrinsic relationships between different learning tasks and using the significance of the common features to improve performance [31]. Harutyunyan et al. [32] extended the success of heterogeneous models for time-series learning. However, similar to the previous set of deep models, performance suffers from over-parameterization.

2.4 Sequence modeling

Many non-time-series models that incorporate recurrent architectures have recently been applied to sequence modelling problems. Gehring et al. [33] incorporated an attention mechanism into a CNN, achieving competitive results on a sequence mining task. Extending the basic encoder-decoder architecture along with attention mechanisms, Rocktäschel et al. [34] and Verga et al. [35] proposed a multi-head self-attention mechanism, which has demonstrated satisfactory performance in NLP tasks. Sharma et al. [36] applied a location-based softmax function to the hidden states of the LSTM layers to identify the more valuable elements in sequential inputs for action recognition. The success of attention mechanisms at identifying and selecting valuable features motivated our decision to incorporate attention into our memory fusion framework.

3 Methodology

This section outlines the MM-RNN framework as seen in Figure 2, designed to predict the severity of an ICU patient’s illness. We begin by explaining the pre-processing procedure, which includes cohort selection, data extraction, data cleaning, and feature extraction. Details of the architecture follow, along with a description of how the model is able to both learn the distinctive features in different organ systems from time-series EHRs and exploit the temporal correlations between those features across all organ systems. The section concludes with a description of the memory fusion method, which identifies and selects the features for the embeddings to produce more descriptive representations, in turn improving the final results.

3.1 Data pre-processing

The pre-processing procedure is best explained through an example with real-world data. For this purpose, we selected MIMIC-III V1.3 [37], which contains over 60,000 de-identified adult patients from the Beth Israel Deaconess Medical Center from 2001 to 2012. Following the cohort selection procedure [21], we excluded patients under the age of 15 and removed all ICU stays of less than 24 hours. Each ICU stay of a patient was then treated as an independent data observation. The resulting root cohort consisted of 45,321 records. Figure 3 illustrates the age distribution in the selected cohort.

Fig. 2 The proposed framework

Fig. 3 Age distribution in the selected cohort

Table 1 The 41 features extracted from MIMIC-III
Table 2 List of the organ systems and assigned features

From these data, we extracted 41 physiological variables and assigned each to one or more of six organ systems according to professional suggestions. More details are provided in Tables 1 and 2. It is worth noting that some variables pertain to multiple organ systems. For example, the fraction of inspired oxygen (FiO2, index = 16) is a commonly shared physiological feature belonging to the respiratory, cardiovascular, and coagulation systems. Some generic features, such as body temperature (index = 38), are shared by all organ systems. In addition to the above 41 physiological variables, we also extracted body weight and age as input features. We also followed the pre-processing methods in [21] to improve data quality, e.g., data cleaning and imputation. ICU data are usually low in quality due to missing values and irregular sampling routines. To address this issue, we applied a forward-fill imputation strategy as follows. Consider a missing value in the d-th variable at time step t:

  • If there is at least one valid observation at time \(t^\prime \), where \(t^\prime < t\), then \(x_{t,d}:=x_{t^\prime ,d}\).

  • If there are no previous observations, then the missing value is replaced by the median value of all measurements.
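The two imputation rules above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `forward_fill_impute`, the toy values, and the `medians` vector (assumed to be pre-computed over all measurements of each variable) are hypothetical.

```python
import numpy as np

def forward_fill_impute(x, medians):
    """Impute NaNs in a (t_i, D) array.

    Rule 1: carry the most recent valid observation forward.
    Rule 2: if no earlier observation exists, use the variable's median.
    """
    out = np.asarray(x, dtype=float).copy()
    for d in range(out.shape[1]):
        last = np.nan  # most recent valid observation of variable d
        for t in range(out.shape[0]):
            if np.isnan(out[t, d]):
                out[t, d] = medians[d] if np.isnan(last) else last
            else:
                last = out[t, d]
    return out

# Toy ICU stay with 4 time steps and 2 variables (illustrative values only).
stay = np.array([[np.nan, 36.5],
                 [80.0, np.nan],
                 [np.nan, np.nan],
                 [90.0, 37.0]])
medians = np.array([75.0, 36.8])  # hypothetical medians over all measurements
imputed = forward_fill_impute(stay, medians)
```

In this toy stay, the first heart-rate-like value has no earlier observation and falls back to the median, while every later gap reuses the last recorded value.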

The inspiration behind this strategy is that measurements are recorded at intervals proportional to the rate at which the values are believed or observed to change [16]. Before feeding the pre-processed data into the model, each time series is converted into a matrix with a variable number of rows, as shown in Figure 4. D denotes the number of features, while n represents the number of ICU stay records. \(t_i\) denotes the length of stay (i.e., the number of time steps) for the i-th data sample, \(i=1,\cdots , n\). Thus, the data samples can be represented by \(\varvec{X}=\{\varvec{x}_1, \varvec{x}_2, \cdots , \varvec{x}_n\}\), \(\varvec{x}_i\in \mathbb {R}^{t_{i} \times D}\).

Fig. 4
figure 4

The structure of data organization

3.2 Model description

Recurrent neural networks (RNNs) [38] can process sequential inputs of arbitrary length by recursively applying a transition function to a hidden vector \(\textbf{h}_{t}\). The hidden state \(\textbf{h}_{t}\) at time step t is computed with the transition function f as follows:

$$\begin{aligned} \textbf{h}_{t} = \left\{ \begin{array}{ll} 0 &{} t=0\\ f\left( \textbf{h}_{t-1}, x_{t}\right) &{} \text{otherwise} \end{array} \right. \end{aligned}$$

where \(x_{t}\) is the current input, and \(\textbf{h}_{t-1}\) is the previous hidden state. However, RNNs with a transition function of this form have difficulty learning long-range dependencies because the components of the gradient vector can vanish or explode exponentially over a long sequence. LSTM networks [39] address this vanishing gradient problem by incorporating gating functions. At each time step, an LSTM maintains a hidden vector h and a memory vector c to control state updates and outputs [40]. The LSTM unit at each time step t is defined as a collection of vectors in \(\mathbb {R}^d\). Each unit includes i, f, o, c and h, which are the input gate, forget gate, output gate, memory cell, and hidden state, respectively. The forget gate controls the amount of memory in each unit to be “forgotten”, the input gate governs the update of each unit, and the output gate controls the exposure of each internal memory state. The LSTM transition equations are defined as follows:

$$\begin{aligned}&i_t = \sigma ( W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} c_{t-1}+ b_i ),\nonumber \\&f_t = \sigma ( W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} c_{t-1}+ b_f ),\nonumber \\&o_t = \sigma ( W_{xo} x_t + W_{ho} h_{t-1} + W_{co} c_{t-1}+ b_o ),\nonumber \\&c_t = f_tc_{t-1} + i_t \tanh ( W_{xc}x_t + W_{hc}h_{t-1}+ b_c ),\\&h_t = o_t \tanh ( c_t),\nonumber \end{aligned}$$

where \(x_{t}\) is the input at time t \((t\in \{1,2,\cdots ,T\})\), the \(\varvec{W}\)s are weight matrices (e.g., \(W_{xi}\) is the weight matrix from the input to the input gate), and the \(\varvec{b}\)s are bias terms (e.g., \(b_i\) is the bias term for the input gate). \(\sigma (\cdot )\) denotes the logistic sigmoid function. As basic LSTMs are ill-suited to processing irregularly sampled data, we opted for a phased LSTM model [41], which extends the standard cell by adding a new time gate \(k_t\):

$$\begin{aligned}&k_{t} = \left\{ \begin{array}{ll} \frac{2\phi _{t}}{r_{on}},\quad &{} \text{if } \phi _{t}< \frac{1}{2}r_{on},\\ 2 - \frac{2\phi _{t}}{r_{on}},\quad &{} \text{if } \frac{1}{2}r_{on}<\phi _{t} < r_{on},\\ \alpha \phi _{t},\quad &{} \text{otherwise}, \end{array} \right. \end{aligned}$$

where \(\phi _{t} = \frac{(t-s)\, mod\, {T}}{T}\), \(r_{on}\) is the ratio of the open period to the total period of the time gate \(k_t\), and \(\alpha \) is a small leak rate that applies while the gate is closed. The other two parameters, s and T, represent the open period of the time gate \(k_t\) and the length of the input sequence, respectively. Based on (3), the update equations of the state output (\(\widetilde{c}_{t}\)) and hidden output (\(\widetilde{h}_{t}\)) at time t in a phased LSTM cell can be rewritten as:

$$\begin{aligned}&\widetilde{c}_{t} = k_{t}{c}_{t} + (1-k_t)c_{t-1},\nonumber \\&\widetilde{h}_{t} = k_{t}{h}_{t} +(1-k_{t})h_{t-1} \end{aligned}$$
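As a rough illustration of the time gate and the gated state update, the sketch below evaluates \(k_t\) for a single timestamp and blends proposed and previous states. It assumes that the proposed states \(c_t\) and \(h_t\) have already been produced by the standard LSTM equations; the function names, the default leak rate `alpha=1e-3`, and the toy parameter values are hypothetical and not taken from the paper.

```python
import numpy as np

def time_gate(t, s, T, r_on, alpha=1e-3):
    """Openness k_t of the time gate at timestamp t (piecewise linear)."""
    phi = ((t - s) % T) / T        # phase phi_t in [0, 1)
    if phi < 0.5 * r_on:           # first half of the open phase: gate opens
        return 2.0 * phi / r_on
    if phi < r_on:                 # second half of the open phase: gate closes
        return 2.0 - 2.0 * phi / r_on
    return alpha * phi             # closed phase: only a small leak

def phased_update(k_t, c_prop, h_prop, c_prev, h_prev):
    """Blend the proposed LSTM states with the previous states using k_t."""
    c_new = k_t * c_prop + (1.0 - k_t) * c_prev
    h_new = k_t * h_prop + (1.0 - k_t) * h_prev
    return c_new, h_new

# Toy values: gate parameters and 3-dimensional states (illustrative only).
s, T, r_on = 0.0, 10.0, 0.2
c_prev, h_prev = np.zeros(3), np.zeros(3)
c_prop, h_prop = np.ones(3), 0.5 * np.ones(3)

k_half_open = time_gate(0.5, s, T, r_on)   # inside the opening half
k_closed = time_gate(5.0, s, T, r_on)      # inside the closed phase
c_new, h_new = phased_update(k_half_open, c_prop, h_prop, c_prev, h_prev)
```

When the gate is closed, \(k_t\) is close to zero and the cell essentially keeps its previous memory, which is what lets each cell skip uninformative steps in irregularly sampled series.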

In our multi-task architecture, the medical time series data for each organ system is learned by one phased LSTM, and the learning procedures for each LSTM are independent. The complete update equations for the j-th task can be written as:

$$\begin{aligned}&i_t^{(j)} = \sigma ( W_{xi}^{(j)} x_t^{(j)} + W_{hi}^{(j)} h_{t-1}^{(j)} + W_{ci}^{(j)} c_{t-1}^{(j)}+ b_i^{(j)} ),\nonumber \\&f_t^{(j)} = \sigma ( W_{xf}^{(j)} x_t^{(j)} + W_{hf}^{(j)} h_{t-1}^{(j)} + W_{cf}^{(j)} c_{t-1}^{(j)}+ b_f^{(j)} ),\nonumber \\&o_t^{(j)} = \sigma ( W_{xo}^{(j)} x_t^{(j)} + W_{ho}^{(j)} h_{t-1}^{(j)} + W_{co}^{(j)} c_{t-1}^{(j)}+ b_o^{(j)} ),\\&c_{t}^{(j)} = k_{t}^{(j)}{c}_{t}^{(j)} + (1-k_t^{(j)})c_{t-1}^{(j)},\nonumber \\&h_{t}^{(j)} = k_{t}^{(j)}{h}_{t}^{(j)} +(1-k_{t}^{(j)})h_{t-1}^{(j)}\nonumber \end{aligned}$$

Unlike the phased LSTM in [41], we further concatenate the hidden states of all the tasks as an intermediate hidden output of MM-RNNs (learned tasks).

$$\begin{aligned} \varvec{h}_{T} = \{h_{T}^{1}, h_{T}^{2}, \cdots , h_{T}^{J}\} \end{aligned}$$

where J is the total number of tasks, and T is the length of the data sequence. This intermediate hidden state contains information about all the organ systems, which helps capture the feature correlations within each organ system.

To exploit feature correlations that exist across multiple systems (cross-task), we need to fuse all the memories of the multi-task LSTMs in (5) and capture the cross-task interactions. Inspired by [42, 43], we first concatenate the memories of the LSTMs of all tasks at time t into a cross-task vector:

$$\begin{aligned} \varvec{c}_{t} = \{c_{t}^{1}, c_{t}^{2}, \cdots , c_{t}^{J}\} \end{aligned}$$

As some feature correlations between multiple organ systems are asynchronous (e.g., a deterioration in the fraction of inspired oxygen can asynchronously affect cerebral blood flow), the memory fusion mechanism contains a parameter \(\Delta \) that sets the length of the time span to be fused. The greater the value of \(\Delta \), the more temporal information is considered. The new concatenation of memories across tasks is represented by

$$\begin{aligned} \varvec{c}_{t\pm \Delta } = \{c_{t-\Delta }^{1}, \cdots , c_{t-\Delta }^{J}, \cdots , c_{t}^{1}, \cdots , c_{t}^{J}, \cdots , c_{t+\Delta }^{1}, \cdots , c_{t+\Delta }^{J}\} \end{aligned}$$

It is worth noting that the proposed memory concatenation has more flexibility than the one in DMAN [42], which only considers two consecutive time steps. Our memory fusion, \(\varvec{c}_{t\pm \Delta }\), has a variable length and thus can exploit asynchronous feature correlations across tasks over a longer temporal range. \(\varvec{c}_{t\pm \Delta }\) is then fed into a neural network \(\mathcal {N}_a : \mathbb {R}^{\Delta * d} \mapsto \mathbb {R}^{\Delta * d}\) (\(d=\sum \limits _{j=1}^J d_j\), where \(d_j\) is the dimension of the j-th task’s memory) to highlight high-impact coefficients over the new concatenation in (8). Subsequently, the attended memories of the LSTMs can be calculated as:

$$\begin{aligned} \hat{\varvec{c}}_{t\pm \Delta } = \mathcal {N}_a(\varvec{c}_{t\pm \Delta }) \odot \varvec{c}_{t\pm \Delta }, \end{aligned}$$

where \(\odot \) denotes the element-wise product.

In this way, \(\mathcal {N}_a\) captures synchronous and asynchronous interactions between organ systems over a longer period. Using the same strategy of exploiting correlations in fused memories as in [42], a neural network \(\mathcal {N}_u : \mathbb {R}^{\Delta * d} \mapsto \mathbb {R}^{d_{mem}}\) is used to transform \(\hat{\varvec{c}}_{t\pm \Delta }\) into one unified memory across the tasks at time t: \(\hat{\varvec{u}}_t = \mathcal {N}_u(\hat{\varvec{c}}_{t\pm \Delta })\). Then, the output of the multi-task gated memory is updated with the following rule:

$$\begin{aligned} \varvec{u}_{t} = \varvec{\eta }_{1} \odot \varvec{u}_{t-1} + \varvec{\eta }_{2} \odot tanh(\hat{\varvec{u}}_{t}), \end{aligned}$$

where \(\varvec{\eta }_{1}\) and \(\varvec{\eta }_{2}\) control how much of the previous memory \(\varvec{u}_{t-1}\) to retain and how much of the new candidate \(\hat{\varvec{u}}_{t}\) to incorporate in the update at time t.
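The fusion steps in (8) to (10) can be sketched with NumPy as follows. This is a minimal sketch under several assumptions that the text does not specify: \(\mathcal {N}_a\) and \(\mathcal {N}_u\) are each reduced to a single dense layer, \(\mathcal {N}_a\) ends in a softmax, the gates \(\varvec{\eta }_{1}\) and \(\varvec{\eta }_{2}\) are sigmoid outputs of dense layers over the attended memories, the window is taken over the \(2\Delta + 1\) steps from \(t-\Delta \) to \(t+\Delta \), and indices are clipped at the sequence boundaries. All names, shapes, and random weights are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dense(n_out, n_in):
    """Hypothetical randomly initialised dense layer (weights, bias)."""
    return 0.1 * rng.standard_normal((n_out, n_in)), np.zeros(n_out)

def window_memories(c, t, delta):
    """Concatenate all task memories for steps t-delta..t+delta.
    c has shape (T, J, d_j); out-of-range steps are clipped (assumption)."""
    idx = np.clip(np.arange(t - delta, t + delta + 1), 0, c.shape[0] - 1)
    return c[idx].reshape(-1)

def fuse_step(c_win, u_prev, p):
    """One update of the unified memory u_t from a window of memories."""
    W, b = p["a"]
    a = softmax(W @ c_win + b)           # N_a: attention coefficients
    c_hat = a * c_win                    # attended memories, element-wise
    W, b = p["u"]
    u_hat = W @ c_hat + b                # N_u: candidate unified memory
    W, b = p["eta1"]
    eta1 = sigmoid(W @ c_hat + b)        # retain gate (assumed form)
    W, b = p["eta2"]
    eta2 = sigmoid(W @ c_hat + b)        # update gate (assumed form)
    return eta1 * u_prev + eta2 * np.tanh(u_hat), a

# Toy setting: 6 time steps, 3 organ systems, 4-dim memories, delta = 1.
T, J, d_j, delta, d_mem = 6, 3, 4, 1, 5
c = rng.standard_normal((T, J, d_j))
n_win = (2 * delta + 1) * J * d_j
params = {"a": dense(n_win, n_win), "u": dense(d_mem, n_win),
          "eta1": dense(d_mem, n_win), "eta2": dense(d_mem, n_win)}

u = np.zeros(d_mem)
attention = []
for t in range(T):
    u, a = fuse_step(window_memories(c, t, delta), u, params)
    attention.append(a)
attention = np.stack(attention)   # one row of coefficients per time step
```

The rows of `attention` are the kind of coefficients that can be plotted per organ system and per time step to inspect which memories influenced the unified memory.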

To exploit the feature correlations possible in an end-to-end multi-task deep framework, we derive the final output of MM-RNNs as a concatenation of \(\varvec{h}_{T}\) in (6) and \(\varvec{u}_T\) in (10) as follows:

$$\begin{aligned} \varvec{h}^{*} = \{\varvec{h}_{T},\varvec{u}_{T}\} \end{aligned}$$

As a matter of course, the attention mechanism highlights influential coefficients. These can be used to visualize the corresponding coefficients to help explain the prediction results. Figure 5 shows some synchronous and asynchronous correlations between organ systems.

Fig. 5 Two visualizations of the coefficients of the fused memory network. The x-axis represents the length of stay in hours, while each row on the y-axis denotes the values of the coefficients with respect to each organ system

The extensive use of attention mechanisms brings several advantages. The first is the ability to visualize the decomposition of organ malfunction into a series of steps instead of a single forward pass. Second, the attention over time series data provides additional insights whenever the model fails to estimate the outcome of an ICU patient. The advantages of visualization are two-fold:

  • results can be presented in simple and easy-to-digest graphical forms;

  • prediction results become recognizable, because huge amounts of data are summarized into insightful signs and the findings are communicated in a more appealing way.

3.3 Complexity

In this section, we discuss the asymptotic complexity of our framework and how it offers a higher degree of parallelism than frameworks that only use single-task LSTMs. Assume that all hidden dimensions are d, and T is the length of the input sequence. Since MM-RNN is implemented in parallel, we focus only on the most complex task, i.e., feature correlation with attention in a multi-task framework. According to [44], the time complexity of the LSTM is \(O(T*d^2)\). In addition, the complexity of a dot-product-based attention mechanism is \(O(T^2d)\) [45]. In terms of computational complexity, attention layers are faster than recurrent layers when the sequence length T is smaller than the representation dimensionality d. This condition holds here because the attention only considers nearby time steps within \(\pm \Delta \). The overall complexity is thus \(O(Td^2+T^2d)\), which can be simplified to \(O(d^2)\) per time step. Therefore, the complexity of our model is identical to that of a basic LSTM model.

4 Experiments

4.1 Data description and experiment design

We conducted extensive experiments to evaluate the performance of MM-RNNs using the publicly available benchmark dataset MIMIC-III (Footnote 2). Over several experiments, we compared our model to the state-of-the-art algorithms and several baselines. We also investigated the influence of using a memory fusion method with a multi-task framework.

4.2 Details of dataset and settings

All methods were tested on version 1.3 of the MIMIC-III dataset [37]. Only patients aged 16 or older who stayed in the ICU for more than 24 hours were included in the sample. For patients with multiple ICU admissions, we treated each admission as an independent data sample. We randomly selected 80% of the 45,321 ICU stays as the training data; the other 20% were used for testing and validation. To select the best parameters, we employed a 10-fold cross-validation scheme, and all experiments were repeated ten times.
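The splitting protocol can be sketched as follows. This is only an illustration under assumptions: the split is done at the level of ICU stays (since each stay is treated as an independent sample), the 10-fold cross-validation is applied to the training portion, and the function names and the fixed seed are hypothetical.

```python
import numpy as np

def train_test_split_indices(n, train_frac=0.8, seed=0):
    """Randomly split n sample indices into training and held-out parts."""
    perm = np.random.default_rng(seed).permutation(n)
    n_train = int(round(train_frac * n))
    return perm[:n_train], perm[n_train:]

def k_fold_indices(indices, k=10):
    """Yield (train, validation) index arrays for k-fold cross-validation."""
    folds = np.array_split(indices, k)
    for i in range(k):
        train = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train, folds[i]

# 45,321 ICU stays, as reported for the selected cohort.
train_idx, test_idx = train_test_split_indices(45321)
splits = list(k_fold_indices(train_idx, k=10))
```

Each of the ten `(train, validation)` pairs in `splits` can then be used to score one hyperparameter setting, with the held-out indices left untouched until the final evaluation.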

All neural networks were implemented with the TensorFlow and Keras frameworks and trained from scratch on two Nvidia 1080 Ti GPUs in a fully supervised manner. To minimize the cross-entropy loss, we employed stochastic gradient descent with the Adam optimizer [46]. The network parameters were optimized with a learning rate of \(10^{-4}\). The keep probability of the dropout operation was set to 0.5. The number of neurons in the MM-RNNs’ input and output layers was fixed at 41, and \(\lambda \) was \(4 \times 10^{-4}\). All implementation code is available from GitHub (Footnote 3).

4.3 Evaluation metrics

We designed a three-stage evaluation strategy to evaluate the proposed MM-RNNs.

Table 3 SOFA score ranges and their corresponding labels

In the first set of experiments, we evaluated the ability of the algorithms to identify the status of ICU patients, regardless of their disease category. This can be understood as testing an algorithm’s ability to predict mortality risk. All predicted cases were regarded as positive cases, and the true illness severity level was regarded as the (true) positive case. We used measures that focus on positive predictions, namely precision, recall (sensitivity), and F1 score. Precision measures how many of an algorithm’s positive predictions are correct. Recall is the model’s ability to identify positive cases. The F1 score is the harmonic mean of precision and recall.

In the second set of experiments, we examined the influence of the different aspects of the framework, i.e., its multi-task architecture and the memory fusion mechanism. To do this, we compared different integrations of the proposed algorithms with several baseline models. The evaluation metrics were precision, recall, F1 score, and AUC. The metrics for individual classes were calculated with a one-versus-all method.
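For concreteness, the one-versus-all computation of precision, recall, and F1 score can be sketched as below. The function name, the three-class toy labels, and the convention of returning 0 when a denominator is empty are assumptions made for this illustration; AUC is omitted because it requires predicted scores rather than hard labels.

```python
import numpy as np

def one_vs_all_metrics(y_true, y_pred, n_classes):
    """Per-class precision, recall and F1, treating each class as positive
    in turn and all remaining classes as negative."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    results = {}
    for c in range(n_classes):
        tp = int(np.sum((y_pred == c) & (y_true == c)))
        fp = int(np.sum((y_pred == c) & (y_true != c)))
        fn = int(np.sum((y_pred != c) & (y_true == c)))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results[c] = {"precision": precision, "recall": recall, "f1": f1}
    return results

# Toy severity labels for six stays (illustrative values only).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
metrics = one_vs_all_metrics(y_true, y_pred, n_classes=3)
```

In this toy example, class 1 is predicted three times and is correct twice, giving a precision of 2/3, a recall of 1, and an F1 score of 0.8.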

Lastly, we investigated the interpretability of the results by randomly selecting several cases and manually analyzing the visualizations against their corresponding medical reports.

4.4 Comparison methods

We chose the following state-of-the-art methods and baselines for comparison.

  1. GRU-ATT: Nguyen et al. [47], a GRU-based attention network for estimating mortality risk.

  2. HMT-RNN: Harutyunyan et al. [32], an RNN for predicting in-hospital mortality.

  3. pRNN: Aczon et al. [10], an RNN-based model for mortality prediction based on EHRs.

  4. RNN: A standard single-task RNN with the hyperparameters suggested by [47].

  5. MT-RNN: A multi-task RNN without the MFN, with the parameter settings suggested in [16].

  6. RNN-MFN: A single-task phased RNN with an MFN.

Beyond these state-of-the-art approaches, we also compared MM-RNNs with some representative classification baselines: support vector machines (SVMs), decision trees (DTs), linear discriminant analysis (LDA), random forests (RFs), and XGBoost. All parameters were fine-tuned using a grid-search scheme, and the best results with the optimal parameters are reported.

We chose to use SOFA scores as the method for monitoring changes in the condition of patients. However, any scoring system could be used. The estimated mortality risk was based on the highest SOFA score during the patient’s ICU stay, as shown in Table 3. The class settings were derived from categories developed by critical care experts [16].

4.5 Evaluation

Table 4 reports the classification results of all the methods in terms of accuracy. MM-RNN consistently outperformed the state-of-the-art methods and baseline models. Further, the MT-RNN and RNN-MFN variants of our model also achieved competitive results. These results imply that a multi-task structure and temporal correlation information do improve the performance of the prediction model. MM-RNN’s performance far surpassed that of GRU-ATT [47], HMT-RNN [32], and pRNN [10]. In addition, MM-RNN outperformed the baseline methods by nearly 28% in terms of accuracy. The differences in these results highlight the value of exploiting the temporal correlations between different organ systems through the memory fusion network.

Table 4 Overall performance comparison (Accuracy)
Table 5 Classification accuracy with raw ICU data, data with missing values, and imputed data

Next, we evaluated the effectiveness of the data imputation strategy and the impact of modelling missingness on both raw data and processed data. A robust model should exploit informative missingness properly and avoid introducing non-existent relations between missingness and predictions. Table 5 reports the classification results of all the methods. The imputation strategy improved data quality and, in turn, the performance of all the models. MM-RNN benefitted from using missingness, with performance improving as the correlations increased. Overall, MM-RNN yielded the best results, which suggests that it handles missing values in multivariate time-series data effectively.

Deep neural networks usually need large datasets to be well trained. Therefore, we tested the performance of a selection of models when trained with 30%, 60%, and 90% of the dataset. Each training set was constructed via random sampling; however, we ensured that the class distributions were the same across all three sets. The models compared were MM-RNN and all its baselines, plus GRU-ATT and HMT-RNN, which were the second and third best overall performers in classification on a one-hour time-window dataset. The results, shown in Figure 6, reveal that all methods achieved better performance with a greater number of training samples. Even at the lowest training proportion (30%), MM-RNN achieved 68% accuracy. This implies that our model is robust and less dependent on the training data size. However, the improvements in prediction accuracy and ROCs for the baseline methods were limited compared to those of the deep learning methods. MM-RNN delivered the best performance over all the set sizes, and the performance gap between MM-RNN and the baselines grew as more data became available.
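
One way to draw training subsets of different sizes while keeping the class distribution fixed is stratified sampling, sketched below. This is an assumed procedure for illustration only: the function name `stratified_subsample`, the seed, and the toy label counts are hypothetical, and the text does not say how the authors enforced equal class distributions.

```python
import random
from collections import defaultdict

def stratified_subsample(labels, fraction, seed=0):
    """Pick about `fraction` of the indices from each class, preserving class ratios."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    chosen = []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        # Keep at least one example per class so that no class disappears.
        n_keep = max(1, round(fraction * len(idxs)))
        chosen.extend(idxs[:n_keep])
    return sorted(chosen)

# Toy imbalanced label vector (hypothetical, not the MIMIC-III class counts).
labels = [0] * 100 + [1] * 50 + [2] * 10
subsets = {f: stratified_subsample(labels, f) for f in (0.3, 0.6, 0.9)}
```

Because sampling is done within each class, the 10:5:1 class ratio of the toy labels is the same in all three subsets.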

Fig. 6
figure 6

Accuracy vs data proportion: prediction accuracy with different sizes of training set. X-axis = sub-sampled dataset size; y-axis = accuracy

Receiver operating characteristic (ROC) curves demonstrate the discriminative capability of a classifier by plotting the true positive rate against the false positive rate over a range of threshold values. Figure 7 shows that the ROC curves for all categories were far from the 45-degree diagonal and close to the upper left corner of the ROC space. The area under each of these six ROC curves (AUC) is shown in Table 6. The average value was around 96.7%, which represents excellent performance. Interestingly, from these results, we also observed that MM-RNN was very sensitive to Class 1 and Class 6. These are two important classes in ICU assessments. Class 6 represents critical conditions; Class 1 represents the lowest threshold for admission to the ICU.
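
For a single class in a one-versus-all setting, the AUC can be computed directly from its rank interpretation: the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. The sketch below uses that definition; the function name `roc_auc` and the toy scores are hypothetical and are not the paper's outputs.

```python
def roc_auc(scores, labels):
    """AUC as the probability that a random positive outscores a random negative.

    Ties count as one half; this equals the area under the ROC curve.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Toy scores for one class versus the rest (1 = this class, 0 = any other class).
scores = [0.9, 0.8, 0.35, 0.6, 0.2]
labels = [1, 1, 1, 0, 0]
auc = roc_auc(scores, labels)
```

The pairwise loop is quadratic in the number of cases, which is acceptable for illustration; a sort-based computation would be preferable for a dataset of this size.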

Fig. 7
figure 7

ROC analysis for MM-RNN, showing the discriminative capability of the classifier. X-axis = false positive rate; y-axis = true positive rate

Table 6 Comparison between the baselines and the proposed multi-task framework

Window size is another important parameter that impacts classification performance. Therefore, we evaluated all methods and baselines with respect to 1-hour, 3-hour, and 6-hour time windows. The results appear in Tables 7 and 8. These results are expressed in terms of precision and recall, and clearly show the superior performance of MM-RNN. Notably, however, performance dropped slightly as the time window grew longer. This may be because medical conditions can change dramatically, for better or worse, over a longer period of time. Also, MM-RNN outperformed the baseline (Table 7).

Fig. 8
figure 8

Variations in Precision with different \(\Delta \)s. The x-axis charts \(\Delta \), and Precision is measured on the y-axis

Table 7 Performance comparisons (Precision) with respect to different time-window lengths
Table 8 Performance comparisons (Recall) with respect to different time-window lengths

To investigate the impact of \(\Delta \), which controls the length of the time span for the fused memories, we fixed the length of the time window to 1 hour and tested the variation in performance with respect to different values of \(\Delta \). The curve in Figure 8 shows that performance did indeed vary with \(\Delta \). Precision peaked at 87.42% when \(\Delta = 4\). In other words, choosing an appropriate length of time is advisable to leverage asynchronous correlations. We see the same trends when testing different time windows (Table 8).

To examine the influence of the multi-task architecture and the memory fusion mechanism, we built three baseline models and tested their classification performance. The criteria for individual classes were calculated with a one-versus-all method. The results are reported in Table 6. We observed that MM-RNN consistently outperformed the other baselines in all measurements. In comparing MT-RNN, RNN-MFN, and a simple RNN, the simple RNN could not compete. This indicates that the temporal correlations and the multi-task structure do improve the predictions. The performance improvement with MM-RNN was consistent compared to MT-RNN and RNN-MFN in all evaluation measures. We also noticed that capturing important correlations between features over time by integrating both task-specific and cross-task interactions could significantly enhance classification performance, which suggests that including temporal information can bring many benefits to illness severity recognition tasks. Overall, MM-RNN achieved 89.98% accuracy.

Fig. 9
figure 9

Fused memories from all single-task LSTMs. The length of stay in hours is plotted on the x-axis; the y-axis denotes the coefficients in fused memory w.r.t. each organ system. The heatmaps indicate the severity of the organ system’s condition. The two panels show the different journeys of Patient A (ID: 80030) (top) and Patient B (ID: 45767) (bottom)

4.6 Explanation for illness severity prediction

To demonstrate the interpretability of MM-RNN, we conducted a small study on two patients. The results are shown in Figure 9. The output scores from the MFN show the attention variations for the two patients over time. A darker colour means a worse condition for the organ system. In the first 48 hours of Patient A’s stay, MM-RNN paid the most attention to malfunctions in the kidney and cardiovascular systems. This corresponds to a note by the doctor in the radiology report made 72 hours later (noteevents.row_id=1200211): “evaluate for obstruction causing acute renal failure”. With Patient B, MM-RNN paid the most attention to the respiratory system in the early days of admission. After 72 hours, the doctor diagnosed “respiratory failure” in the medical report (noteevents.row_id=105617). Attention to the respiratory system increased dramatically after 80 hours, which is confirmed in the medical report as “Large right pleural effusion is increasing” after 144 hours. This simple example shows MM-RNN’s potential as a highly accurate early warning system for clinicians. Beyond a prediction, the system provides an explanation and an indication of which organ systems are causing problems well before traditional forms of monitoring signal an issue. In this way, caregivers are given more insight into which patients are at risk. This comfort is not something most conventional deep models offer.

In summary, MM-RNN not only improves prediction performance but also offers clinicians visual explanations of the prediction results. A correct clinical decision may mean the difference between life and death. Therefore, practitioners must be able to understand how a prediction made by any diagnostic tool was reached if they are to trust the results.

5 Conclusion

MM-RNN outperformed the other methods on the benchmark MIMIC dataset, but this does not mean the model is necessarily robust. In the future, we will look to extend the model to consider data uncertainty and undertake a deeper examination of the issue of missing values. A starting point for missing value imputation may be to incorporate a mask of missing data to indicate the placement of imputed values or missing values. The model could then not only capture the long-term temporal dependencies of time-series observations but also utilize the missing patterns to further improve the prediction results.Footnote 4
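
As a rough illustration of the masking idea, the sketch below imputes a univariate series and returns a parallel mask marking which positions were observed. Forward filling, the function name `impute_with_mask`, the default value, and the toy heart-rate series are all assumptions made for this example; they do not describe the imputation strategy used in the experiments.

```python
def impute_with_mask(series, default=0.0):
    """Forward-fill missing values (None) and return (filled, mask).

    mask[t] is 1 where the value was observed and 0 where it was imputed.
    """
    filled, mask = [], []
    last = default  # used only when the series starts with missing values
    for value in series:
        if value is None:
            filled.append(last)
            mask.append(0)
        else:
            last = value
            filled.append(value)
            mask.append(1)
    return filled, mask

# Toy hourly heart-rate readings with gaps (hypothetical values).
heart_rate = [None, 82.0, None, None, 90.0]
filled, mask = impute_with_mask(heart_rate, default=80.0)
```

Feeding the mask to the model alongside the filled values would let it distinguish measured values from imputed ones, which is the kind of missing-pattern information the paragraph above refers to.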