Abstract
Predicting the severity of an illness is crucial in intensive care units (ICUs) if a patient's life is to be saved. Existing prediction methods often fail to provide sufficient evidence for the time-critical decisions required in dynamic and changing ICU environments. In this research, we developed a new method, MM-RNN (multi-task memory-fused recurrent neural network), to predict illness severity in ICUs. MM-RNN addresses this issue by not only predicting illness severity but also generating an evidence-based explanation of how each prediction was made. Its architecture consists of task-specific phased LSTMs and a delta memory network that captures asynchronous feature correlations within and between multiple organ systems. The multi-task nature of MM-RNN allows it to provide an evidence-based explanation of its predictions, along with illness severity scores and a heatmap of the patient's changing condition. Comparisons with state-of-the-art methods on real-world clinical data show that MM-RNN delivers more accurate predictions of illness severity with the added benefit of providing evidence-based justifications.
1 Introduction
The exponential growth of electronic health records (EHRs) has drawn significant interest from the machine learning and data mining communities. With a wealth of information from multiple sources and formats, these EHRs offer a vast dataset for developing evidence-based clinical decision-making tools. For example, the My Health Record SystemFootnote 1 stores over 41.9 million EHRs on over 6.4 million patients. Despite criticisms by some that EHRs are often vendor-specific and sometimes limited in scope [1], their sheer size and diversity make them a valuable resource for deep learning technology. This is especially true in the intensive care unit (ICU), where critical decisions are driven by forecasts of patient outcomes based on pathological and physiological values [1]. As such, deep learning technology has been broadly applied to advance research in ICU decision support, particularly mortality estimation [2] and phenotype analysis [3]. In general, clinical decisions in ICUs are time-critical and highly dependent on physiological data. However, making accurate and rapid decisions in these fast-changing environments without enough real-time information on the severity of a patient's illness can be very challenging for clinicians. As a result, numerous scoring systems have been developed and progressively refined to assist with rapid patient assessment. Examples include the sequential organ failure assessment (SOFA) score [4], APACHE II [5], and SAPS II [6]. The resulting scores reflect the current clinical condition of a patient based on a set of basic physiological indicators.
1.1 Motivation
These scoring systems serve as a simple calculation of a patient's vital signs at various times but do not provide real-time information for critical decision-making in the ICU. The longer the time between updates, the less opportunity there is to respond to a deteriorating patient, which is why continuous monitoring of key indicators such as heart rate is essential. According to Bouch and Thompson [7], an instantaneous scoring system that covers a wider range of indicators is urgently needed to support better decision-making in the ICU. To demonstrate the potential impact of such a system, we present an example of two ICU patients whose SOFA scores are charted over time to show changes in their condition (as illustrated in Figure 1). If linked to medical interventions, these high-frequency SOFA measures could provide insight into the effectiveness of each treatment. This example highlights both the potential and the need for continuous prediction of illness severity scores as a new tool for patient monitoring.
Over the years, recurrent neural networks (RNNs) and their variants [8,9,10,11,12] have been explored as deep models for handling time series data, and many have achieved significant results on clinical prediction tasks like mortality risk. Given a sequence of multivariate features, the typical outlook of mortality risk with today's prediction techniques is about 24 hours, barely enough time for clinicians to intervene. More importantly, short-term mortality risk predictions may have ethical implications. For example, the mortality risk to a patient over the next week may be, say, 80%, but the prediction for the next 24 hours may only be 5%. If faced with an unaffordable treatment, many patients and caregivers may choose not to continue with clinical services, unaware of the consequences of that decision beyond tomorrow. Thus, continuously predicting the medical trajectory not only offers more detailed information at a finer time granularity but could also help caregivers concentrate on planning effective treatments with better consideration of an illness's true severity.
Despite their solid results to date, these learning models have some deficiencies. For instance, they normally treat all multivariate time-series variables as a single input stream without considering the correlations between the physiological variables. However, human organs are highly correlated with each other and with a patient's deterioration. When one or two organs start to malfunction, others tend to follow over a short period. For example, systolic blood pressure is positively correlated with diastolic blood pressure and pulse pressure, whereas diastolic blood pressure is inversely correlated with pulse pressure. Also, a deterioration in the fraction of inspired oxygen can asynchronously affect cerebral blood flow. Thus, exploiting correlations between medical time-series variables can further improve classification performance for ICU prediction tasks. Yet few studies have investigated these feature correlations, which are very common and important in the healthcare domain.
1.2 Solution
In search of a better approach that does not suffer from these shortcomings, we designed a novel learning framework to monitor the severity of illness in ICU patients in real time. Within the framework, the medical time-series variables most frequently monitored in ICUs are assigned to different organ systems according to their physiological functions. For example, heart rate belongs to the cardiovascular system, while FiO2 is common to the heart and the lungs and is therefore treated as a shared feature across the cardio and pulmonary systems. Given that ICU multivariate time series data is often sparse, irregular, and noisy, the features of each organ system are learned by controlling memory updates to a standard LSTM cell as a way to yield better performance. The framework has a multi-task architecture to capture feature correlations that exist both within one organ system and across multiple systems. Once each individual system has been modelled, the memories of all the separate LSTMs are fused with an attention mechanism - hence, the name multi-task memory-fused RNN (MM-RNN). Most importantly, using a multi-task architecture to learn each organ system concurrently means that MM-RNN is able to visualize the evidence from all the single-task LSTMs upon which the model based its predictions. The interpretability this provides to clinicians is invaluable.
1.3 Contribution
In summary, the contributions of this research include:
-
a novel multi-task delta memory-fused RNN model (MM-RNN). The multi-task framework concurrently learns behaviours in individual human organ systems, while the delta memory captures asynchronous feature correlations within and across tasks over a longer temporal range.
-
a novel approach to identifying correlated temporal features in time-series data. The memories of all the separate LSTMs are fused to capture important correlations between features over time by integrating both task-specific and cross-task interactions in time series EHRs.
-
a method of providing practical and visual explanations for illness severity predictions based on outputs from the memory fusion network. An accurate prediction coupled with an interpretable explanation provides strong evidence for accurately forecasting problematic organ systems and a patient's likely future condition.
-
comparative experiments with a real-world dataset, which show that MM-RNN delivers state-of-the-art performance with respect to continuous illness severity predictions.
The remainder of this paper is organized as follows. Related works are reviewed in Section 2. The proposed method is discussed in detail in Section 3. The experimental evaluations and related discussions are provided in Section 4. Section 5 presents our conclusions and discusses future work.
2 Related work
As governments commit to national EHR systems, EHRs have been widely adopted in both hospital and healthcare settings [13, 14]. EHR systems provide a wealth of information on a patient's medical history, complications, and medicine use. Deep learning encompasses a wide variety of techniques. In this section, we review pertinent studies on both mining multivariate time series from EHRs and classifying medical data with deep learning techniques. EHRs in ICUs mainly consist of multivariate time series data that span a range of medical variables. In practice, time-series deep models have been widely applied to many analytics tasks, including mortality risk estimation, phenotyping analysis, and disease modelling.
Despite the diversity in applications, most models are evaluated with a standard set of metrics, such as AUC (risk event prediction [15, 16], diagnosis classification [17,18,19], disease risk identification [9, 10]), accuracy, and precision, recall, and F1 scores.
As the amount and availability of detailed clinical health records has exploded in recent years, there is great opportunity for revisiting and refining the definitions and boundaries of illnesses and diagnoses. Diseases are traditionally defined by a set of manual clinical descriptions, whereas computational phenotyping seeks to derive richer, data-driven descriptions of illnesses [17, 20]. With machine learning and data mining techniques, it is possible to discover natural clusters of clinical descriptors that lend themselves to a more fine-grained description of a disease. For example, detailed phenotypes represent a big step towards the eventual goal of personalised precision healthcare. Computational phenotypes can be seen as an archetypal clinical application of deep learning principles. Grounded in the philosophy of “letting the data speak”, the idea is to discover the latent relationships and hierarchical concepts in the raw data without any human supervision or prior bias. With the growing availability of huge amounts of clinical data, many researchers are using deep learning techniques to explore computational phenotypes.
2.1 Mortality prediction
As one type of clinical prediction task in the ICU [21], mortality prediction aims to identify a patient's risk of death over the next 24-48 hours from their available medical data. This prediction task can be formulated as a binary classification problem, where the outcome of the classifier is literally life or death for a patient. Further, mortality prediction tasks can be divided into more specific subtasks based on other factors, such as whether the patient is being treated in hospital, short versus long term outlooks, and so on. Many studies have explored estimating mortality risk by modelling medical time-series data with neural networks and logistic regression models [22, 23]; in many cases, with phenomenal performance that outstrips traditional models. Che et al. [11] used prior knowledge to regularize the parameters of a standard deep model and then trained a collection of neural networks for clinical prediction tasks with a scalable procedure. Aczon et al. [10] implemented a standard recurrent neural network and generated temporally-dynamic predictions for in-ICU mortality by feeding the network with multiple data sources, such as physiological observations, laboratory results, administered drugs, and interventions. The trained model was deployed in a pediatric ICU and achieved better performance than the conventional models. To address data quality issues, Che et al. [9] proposed a novel variant of the gated recurrent unit, called GRU-D, that incorporates patterns of masked or missing time interval data into an end-to-end framework.
The advances deep learning models have made in mortality prediction are undoubtedly outstanding, but there is more to be done. A prediction for the next 24-48 hours does not provide critical care experts with enough information to make timely treatment decisions in a dynamically changing environment. Additionally, the black-box nature of deep neural networks yields purely quantitative predictions: how and why the model arrived at those predictions is not explained, which undermines trust in the results. When a patient's life or death is at stake, practitioners must have full confidence that the reasoning behind the prediction is sound.
2.2 Multivariate time-series deep modeling
Multivariate time-series data contain more than one time-dependent variable, and many sophisticated models have been developed to tackle the associated tasks and challenges. Wang et al. [17] investigated disease-specific features embedded in descriptive data to analyze the phenotypes of patients. To improve model performance, Zhou et al. [24] and Nie et al. [25] leveraged consistencies between multiple modalities and task-specific features. These models are adequate for their purpose, but performance is not optimal because they neglect the temporal information embedded in the data. Multivariate time-series features properly reflect the severity of a patient's illness. Hence, several attempts have been made to combine temporal features with RNNs. Pham et al. [26] combined the temporal information embedded in time-series data with an RNN for estimating the severity of an ICU patient's illness. To improve performance, Che et al. [9], Lipton et al. [3], and Chen et al. [16] considered time intervals while integrating RNNs to investigate irregular EHR time-series data. The downside is that all these approaches process temporal features in a heuristic manner using a monotonically decreasing function. Therefore, they tend to suffer from over- and/or under-parameterization.
2.3 Multi-task learning
The aim of multi-task learning is to leverage useful information from similar learning tasks to improve model performance. It is most widely used in computer vision [27, 28]. Abdulnabi et al. [29] proposed a multi-task convolution neural network (CNN) for attribute prediction, in which the tasks from multiple CNNs are used to learn the features of image segments. To identify the most useful information, a common layer determines the spatial correlations between image features. Instead of a multi-task CNN, Chen et al. [30] proposed a multi-task RNN framework to exploit the correlations between sub-tasks by jointly learning EEG signals for intention recognition tasks. To analyze multivariate time-series data, Chen et al. [16] proposed a multi-task learning method that captures the temporal correlations embedded in the data. This is a key feature of a multi-task analytic framework - capturing the intrinsic relationships between different learning tasks and using the significance of the common features to improve performance [31]. Harutyunyan et al. [32] extended the success of heterogeneous models for time-series learning. However, similar to the previous set of deep models, performance suffers from over-parameterization.
2.4 Sequence modeling
Many models that do not rely solely on recurrent architectures have recently been applied to sequence modelling problems. Gehring et al. [33] incorporated an attention mechanism into a CNN, achieving competitive results on a sequence mining task. Extending the basic encoder-decoder architecture with attention mechanisms, Rocktäschel et al. [34] and Verga et al. [35] proposed a multi-head self-attention mechanism, which has demonstrated satisfactory performance in NLP tasks. Sharma et al. [36] applied a location-based softmax function to the hidden states of the LSTM layers to identify the more valuable elements in sequential inputs for action recognition. The success of attention mechanisms at identifying and selecting valuable features motivated our decision to incorporate attention into our memory fusion framework.
3 Methodology
This section outlines the MM-RNN framework as seen in Figure 2, designed to predict the severity of an ICU patient's illness. We begin by explaining the pre-processing procedure, which includes cohort selection, data extraction, data cleaning, and feature extraction. Details of the architecture follow, along with a description of how the model is able to both learn the distinctive features in different organ systems from time-series EHRs and exploit the temporal correlations between those features across all organ systems. The section concludes with a description of the memory fusion method, which identifies and selects the features for the embeddings to produce more descriptive representations, in turn improving the final results.
3.1 Data pre-processing
The pre-processing procedure is best explained through an example with real-world data. For this purpose, we selected MIMIC-III V1.3 [37], which contains records of over 60,000 de-identified adult patients treated at the Beth Israel Deaconess Medical Center from 2001 to 2012. Following the cohort selection procedure in [21], we excluded patients under the age of 15 and removed all ICU stays of less than 24 hours. Each ICU stay was then treated as an independent data observation. The resulting root cohort consisted of 45,321 records. Figure 3 illustrates the age distribution in the selected cohort.
From these data, we extracted 41 physiological variables and assigned them to six different organ systems according to professional suggestions. More details are provided in Tables 1 and 2. It is worth noting that some variables pertain to multiple organ systems. For example, the fraction of inspired O2 (FiO2, index = 16) is a commonly shared physiological feature belonging to the respiratory, cardiovascular, and coagulation systems. Some generic features, such as body temperature (index = 38), are shared by all organ systems. In addition to the above 41 physiological variables, we also extracted body weight and age as input features. We also followed the pre-processing methods in [21] to improve data quality, e.g., data cleaning and imputation. ICU data are usually low in quality due to missing values and irregular sampling routines. To address this issue, we applied a forward-fill imputation strategy as follows. Consider a missing value in the d-th variable at time step t:
-
If there is at least one valid earlier observation, let \(t^\prime \) be the most recent such time with \(t^\prime < t\); then \(x_{t,d}:=x_{t^\prime ,d}\).
-
If there are no previous observations, then the missing value is replaced by the median value of all measurements.
The intuition behind this strategy is that measurements are recorded at intervals proportional to the rate at which the values are believed or observed to change [16]. Before feeding the pre-processed data into the model, each time series event is converted into a matrix with a variable number of rows, as shown in Figure 4. D denotes the number of features, while n represents the number of ICU stay records. \(t_i\) denotes the maximum stay time for the i-th data sample, \(i=1,\cdots , n\). Thus, the data samples can be represented by \(\varvec{X}=\{\varvec{x}_1, \varvec{x}_2, \cdots , \varvec{x}_n\}\), \(\varvec{x}_i\in \mathbb {R}^{t_{i} \times D}\).
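As a concrete illustration, the two imputation rules above can be sketched as follows. This is a minimal sketch, not the authors' implementation; in particular, we assume the median fallback is computed per variable within a single stay (and defaults to 0.0 for an entirely missing variable), which may differ from a cohort-wide statistic.

```python
import numpy as np

def impute_forward_fill(x):
    """Impute missing values (NaN) in a (t_i x D) observation matrix.

    Rule 1: carry the most recent valid observation forward.
    Rule 2: if a variable has no earlier observation, fall back to the
    median of all its measurements (assumption: per-variable, per-stay).
    """
    t, d = x.shape
    out = x.copy()
    for j in range(d):
        col = out[:, j]                     # view: writes go into `out`
        valid = ~np.isnan(col)
        # Rule 2 fallback; 0.0 for an all-missing variable is an assumption
        fallback = np.nanmedian(col) if valid.any() else 0.0
        last = fallback
        for i in range(t):
            if np.isnan(col[i]):
                col[i] = last               # Rule 1: x_{t,d} := x_{t',d}
            else:
                last = col[i]
    return out

# One ICU stay: t_i = 4 time steps, D = 2 variables
x = np.array([[np.nan, 1.0],
              [5.0,    np.nan],
              [np.nan, 3.0],
              [7.0,    np.nan]])
print(impute_forward_fill(x))
```

Here the first missing value of variable 0 receives the variable's median, and every later gap receives the most recent observation.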
3.2 Model description
Recurrent neural networks (RNNs) [38] can process arbitrary sequential inputs by recursively applying a transition function to a hidden vector \(\textbf{h}_{t}\). The activation f of the current hidden state \(\textbf{h}_{t}\) at time step t is computed as follows:
where \(x_{t}\) is the current state input, and \(h_{t-1}\) is the previous hidden state. However, RNNs with a transition function of this form have difficulty learning long-range dependencies because the components of the gradient vector can vanish or explode exponentially over a long sequence. We therefore adopt an LSTM network [39], which addresses this vanishing gradient problem by incorporating gating functions. At each time step, an LSTM maintains a hidden vector h and a memory vector m to control state updates and outputs [40]. The LSTM unit at each time step t is defined as a collection of vectors in \( R^d\): i, f, o, c, and h, which are the input gate, forget gate, output gate, memory cell, and hidden state, respectively. The forget gate controls the amount of memory in each unit to be “forgotten”, the input gate governs the update of each unit, and the output gate controls the exposure of each internal memory state. The LSTM transition equations are defined as follows:
where \(x_{t}\) is the input at time t \((t\in \{1,2,\cdots ,T\})\), the \(\varvec{W}\)s (e.g., \(W_{xi}\) is the weight matrix for the hidden-input gate) are weight matrices, and the \(\varvec{b}\)s (e.g., \(b_i\) is the bias term for the input gate) are bias terms. \(\sigma (\cdot )\) denotes the logistic sigmoid function. As basic LSTMs are ill-suited to processing irregularly sampled data, we opt for a phased LSTM model [41], which extends the standard cell by adding a new time gate \(k_t\):
where \(\phi _{t} = \frac{(t-s)\, mod\, {T}}{T}\) and \(r_{on}\) is the ratio of the open period to the total period of the time gate \(k_t\). The other two parameters, s and T, represent the phase shift and the oscillation period of the time gate \(k_t\), respectively (this period is distinct from the sequence length T used elsewhere). Based on (3), the update equations for the state output (\(\widetilde{c}_{t}\)) and hidden output (\(\widetilde{h}_{t}\)) at time t in a phased LSTM cell can be rewritten as:
In our multi-task architecture, the medical time series data for each organ system is learned by one phased LSTM, and the learning procedures for each LSTM are independent. The complete update equations for j-th task can be written as:
Unlike the phased LSTM in [41], we further concatenate the hidden states of all the tasks as an intermediate hidden output of the MM-RNN:
where J is the total number of tasks, and T is the length of the data sequence. This intermediate hidden state contains information about all the organ systems, which helps capture the feature correlations within each organ system.
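The behaviour of the time gate described above can be sketched as follows, following the piecewise openness function of the phased LSTM [41]. The leak rate alpha and all parameter values are illustrative assumptions, and tau is used here for the gate's period to avoid clashing with the sequence length T.

```python
def time_gate(t, s, tau, r_on, alpha=0.001):
    """Openness k_t of the phased-LSTM time gate.

    phi_t = ((t - s) mod tau) / tau is the phase within one cycle of
    period tau with phase shift s; the gate opens linearly during the
    first half of the open ratio r_on, closes during the second half,
    and otherwise leaks with a small slope alpha (assumed value).
    """
    phi = ((t - s) % tau) / tau
    if phi < 0.5 * r_on:
        return 2.0 * phi / r_on            # opening half of the open period
    elif phi < r_on:
        return 2.0 - 2.0 * phi / r_on      # closing half
    else:
        return alpha * phi                 # gate closed: small leak

def phased_update(c_prev, h_prev, c_new, h_new, k):
    """Blend the proposed cell/hidden states with the previous ones by k,
    mirroring the phased update of c~_t and h~_t."""
    c_t = k * c_new + (1.0 - k) * c_prev
    h_t = k * h_new + (1.0 - k) * h_prev
    return c_t, h_t
```

In the multi-task architecture, one such gated cell would run per organ system, and the per-task hidden states are then concatenated into the intermediate hidden output.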
To exploit feature correlations that exist across multiple systems (cross-task), we need to fuse all the memories of the multi-task LSTMs in (5) and capture the cross-task interactions. Inspired by [42, 43], we first concatenate all the memories of the LSTMs with respect to each task at time t into a cross-task vector:
As some feature correlations between multiple organ systems are asynchronous (e.g., a deterioration in the fraction of inspired oxygen can asynchronously affect cerebral blood flow), the memory fusion mechanism contains a parameter \(\Delta \) that specifies the length of the time span to be fused. The greater the value of \(\Delta \), the more temporal information is considered. The new concatenation of memories across tasks is represented by
It is worth noting that the proposed memory concatenation has more flexibility than the one in DMAN [42], which considers only two consecutive time steps. Our memory fusion, \(\varvec{c}_{t\pm \Delta }\), has a variable length and thus can exploit asynchronous feature correlations across tasks over a longer temporal range. Consequently, \(\varvec{c}_{t\pm \Delta }\) is fed into a neural network \(\mathcal {N}_a : \mathbb {R}^{\Delta * d} \mapsto \mathbb {R}^{\Delta * d}\) (\(d=\sum \limits _{j=1}^J d_j\), where \(d_j\) is the dimension of the j-th task's memory) to highlight high-impact coefficients over the new concatenation in (8). Subsequently, the attended memories of the LSTMs can be calculated as:
where \(\odot \) denotes the element-wise product.
In this way, \(\mathcal {N}_a\) captures synchronous and asynchronous interactions between organ systems over a longer period. Using the same strategy of exploiting correlations in fused memories as [42], a neural network \(\mathcal {N}_u : \mathbb {R}^{\Delta * d} \mapsto \mathbb {R}^{d_{mem}}\) is used to transform \(\hat{\varvec{c}}_{t\pm \Delta }\) into one unified memory across the tasks at time t: \(\hat{u}_t = \mathcal {N}_u(\hat{\varvec{c}}_{t\pm \Delta })\). Then, the output of the multi-task gated memory is updated with the following rule:
where \(\eta _{1}\) and \(\eta _{2}\) control how much of the current state to retain and how much to forget in the update at time t given \(\hat{u}_{t}\).
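A minimal sketch of one memory fusion step is given below. It assumes, for illustration only, that \(\mathcal {N}_a\) is a single linear layer followed by a softmax, that \(\mathcal {N}_u\) is a linear layer with a tanh activation, and that \(\eta _{1}\) and \(\eta _{2}\) are scalar gates; the actual networks and gate parameterization in MM-RNN may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def fuse_memories(mem_window, W_a, W_u, u_prev, eta1, eta2):
    """One step of the delta memory fusion (sketch).

    mem_window : list of cross-task memory vectors, one per time step in
                 the window around t; each is the concatenation of all J
                 task memories (dimension d).
    W_a, W_u   : stand-ins for the attention network N_a and the
                 transformation network N_u (single linear layers here,
                 an illustrative assumption).
    """
    c_window = np.concatenate(mem_window)   # concatenated window c_{t +/- Delta}
    a = softmax(W_a @ c_window)             # attention coefficients from N_a
    c_hat = a * c_window                    # element-wise product: attended memories
    u_hat = np.tanh(W_u @ c_hat)            # unified cross-task memory
    # Gated update: eta1 controls retention, eta2 controls the write of u_hat
    return eta1 * u_prev + eta2 * u_hat

# Toy sizes: J = 2 tasks with d_j = 3, so d = 6; a 3-step window
# (Delta = 1); unified memory dimension d_mem = 4.
d, steps, d_mem = 6, 3, 4
mems = [rng.standard_normal(d) for _ in range(steps)]
u_t = fuse_memories(mems,
                    rng.standard_normal((steps * d, steps * d)),
                    rng.standard_normal((d_mem, steps * d)),
                    np.zeros(d_mem), eta1=0.5, eta2=0.5)
print(u_t.shape)  # (4,)
```

The attention coefficients `a` are the quantities that are later visualized to explain which task memories, at which offsets, drove the fused state.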
To exploit the feature correlations possible in an end-to-end multi-task deep framework, we derive the final output of MM-RNNs as a concatenation of \(\varvec{h}_{T}\) in (6) and \(\varvec{u}_T\) in (10) as follows:
The attention mechanism highlights influential coefficients, which can be visualized to help explain the prediction results. Figure 5 shows some synchronous and asynchronous correlations between organ systems.
The extensive use of attention mechanisms brings several advantages. The first is the ability to visualize the decomposition of organ malfunction as a series of steps instead of a single forward pass. Second, the attention over time series data provides additional insights whenever the model fails to estimate the outcome of an ICU patient. The advantages of visualization are two-fold:
-
it presents results in simple, easy-to-digest graphic forms;
-
it makes prediction results recognizable, summarizing huge amounts of data into insightful signs and communicating findings in a more appealing way.
3.3 Complexity
In this section, we discuss the asymptotic complexity of our framework and how it offers a higher degree of parallelism than frameworks that only use single-task LSTMs. Assume that all hidden dimensions are d, and T is the length of the input sequence. Since MM-RNN is implemented in parallel, we focus only on the most complex task, i.e., feature correlation with attention in a multi-task framework. According to [44], the time complexity of an LSTM is \(O(T*d^2)\). In addition, the complexity of a dot-product-based attention mechanism over a full sequence is \(O(T^2d)\) [45]; attention layers are therefore faster than recurrent layers when the sequence length T is smaller than the representation dimensionality d. In our case, the attention considers only the nearby time steps within \(\pm \Delta \), reducing its cost to \(O(T\Delta d)\). The overall complexity is thus \(O(Td^2+T\Delta d)\), which, since \(\Delta \) is a small constant, is dominated by the recurrent term \(O(Td^2)\) - that is, \(O(d^2)\) per time step. Therefore, the complexity of our model is identical to that of a basic LSTM.
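To make the dominance of the recurrent term concrete, a rough operation count with illustrative values (none of these numbers come from the experiments) might look like this:

```python
# Rough per-sequence operation counts. T = sequence length, d = hidden
# dimension, DELTA = fusion half-window; all values are illustrative.
T, d, DELTA = 48, 128, 3

lstm_ops = T * d * d                  # recurrent layers: O(T * d^2)
attn_ops = T * (2 * DELTA + 1) * d    # Delta-restricted attention: O(T * Delta * d)

# With d >> Delta, the recurrent term dominates the total cost.
print(lstm_ops, attn_ops)
```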
4 Experiments
4.1 Data description and experiment design
We conducted extensive experiments to evaluate the performance of MM-RNNs using the publicly available benchmark dataset MIMIC III.Footnote 2 Over several experiments, we compared our model to the state-of-the-art algorithms and several baselines. We also investigated the influence of using a memory fusion method with a multi-task framework.
4.2 Details of dataset and settings
All methods were tested on version 1.3 of the MIMIC-III dataset [37]. Only patients aged 16 or older who stayed in the ICU for more than 24 hours were included in the sample. For patients with multiple ICU admissions, we treated each as an independent data sample. We randomly selected 80% of the 45,321 ICU stays as the training data; the other 20% were used for testing and validation. To select the best parameters, we employed a 10-fold cross-validation scheme, and all experiments were repeated ten times.
All neural networks were implemented with the TensorFlow and Keras frameworks and trained on two Nvidia 1080 Ti GPUs from scratch in a fully-supervised manner. To minimize the cross-entropy loss, we employed stochastic gradient descent with the ADAM optimizer [46]. The network parameters were optimized with a learning rate of \(10^{-4}\). The keep probability of the dropout operation was set to 0.5. The number of neurons in the MM-RNN's input and output layers was fixed at 41, and \(\lambda \) was \(4 \times 10^{-4}\). All implementation code is available from GitHub.Footnote 3
4.3 Evaluation metrics
We designed a three-stage evaluation strategy to evaluate the proposed MM-RNNs.
In the first set of experiments, we evaluated the ability of the algorithms to identify the status of ICU patients, regardless of their disease category. This can be understood as testing an algorithm's ability to predict mortality risk. Cases predicted at a given severity level were regarded as positive cases, and those whose true severity matched were counted as true positives. We used measures that focus on positive predictions, namely precision, recall/sensitivity, and F1-score. Precision measures how correct the positive predictions of an algorithm are. Recall is the model's ability to identify positive cases. The F1-score is the harmonic mean of precision and recall.
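For clarity, the three measures can be computed per severity class in a one-versus-all fashion; the sketch below uses hypothetical labels, not data from the study.

```python
def precision_recall_f1(y_true, y_pred, positive):
    """Precision, recall, and F1 for one class treated as positive
    (one-versus-all)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if p == positive and t == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if p == positive and t != positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if p != positive and t == positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # F1 is the harmonic mean of precision and recall
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical example: 6 ICU stays, evaluating severity class 1
y_true = [1, 1, 2, 1, 3, 2]
y_pred = [1, 2, 2, 1, 1, 2]
print(precision_recall_f1(y_true, y_pred, positive=1))
```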
In the second set of experiments, we examined the influence of the different aspects of the framework, i.e., its multi-task architecture and the memory fusion mechanism. To do this, we compared different integrations of the proposed algorithms with several baseline models. The evaluation metrics were precision, recall, F1 score, and AUC. The criteria for individual classes were calculated with a one-versus-all method.
Lastly, we investigated the interpretability of the results by randomly selecting several cases and manually analyzing the visualizations against their corresponding medical reports.
4.4 Comparison methods
We chose the following state-of-the-art methods and baselines for comparison.
-
1.
GRU-ATT: Nguyen et al. [47], a GRU-based attention network for estimating mortality risk.
-
2.
HMT-RNN: Harutyunyan et al. [32], an RNN for predicting in-hospital mortality.
-
3.
pRNN: Aczon et al. [10], an RNN-based model for mortality prediction based on EHRs.
-
4.
RNN: A standard version of a single-task RNN with the hyperparameters suggested by [47].
-
5.
MT-RNN: A multi-task RNN without the MFN, with the parameter settings suggested in [16].
-
6.
RNN-MFN: A single-task phased RNN with an MFN.
Beyond these state-of-the-art approaches, we also compared MM-RNNs with some representative classification baselines: support vector machines (SVMs), decision tree (DT), linear discriminant analysis (LDA), random forest (RF), and XGBoost. All parameters were fine-tuned using a grid-search scheme, and the best results with the optimal parameters are reported.
We chose to use SOFA scores as the method for monitoring changes in the condition of patients; however, any scoring system could be used. The estimated mortality risk was based on the highest SOFA score during the patient's ICU stay, as shown in Table 3. The class settings were derived from categories developed by critical care experts [16].
4.5 Evaluation
Table 4 reports the classification results of all the methods in terms of accuracy. MM-RNN consistently outperformed the state-of-the-art methods and baseline models. Further, the MT-RNN and RNN-MFN baselines of our model also achieved competitive results. These results imply that a multi-task structure and temporal correlation information do improve the performance of the prediction model. MM-RNN's performance far surpassed GRU-ATT [47], HMT-RNN [32], and pRNN [10]. In addition, MM-RNN outperformed the baseline methods by nearly 28% in terms of accuracy. The differences in these results highlight the value of exploiting the temporal correlations between different organ systems through the memory fusion network.
Next, we evaluated the effectiveness of the data imputation strategy and the impact of modelling missingness on both raw data and processed data. A robust model should exploit informative missingness properly and avoid introducing non-existent relations between missingness and predictions. Table 5 reports the classification results of all the methods. It is clear that the imputation strategy improves data quality and, in turn, the performance of all the models. MM-RNN benefitted from using missingness, resulting in better performance as the correlations increased. In other words, MM-RNN yielded the best results, demonstrating that it handles missing values in multivariate time-series data effectively.
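One common way to combine imputation with informative missingness is to forward-fill gaps while recording where values were actually observed. The sketch below illustrates this idea; it is a simplified stand-in, not the paper's exact imputation strategy.

```python
import numpy as np

def impute_with_mask(series):
    """Forward-fill missing values (NaN) in a clinical time series and
    return a binary mask marking where values were observed.

    The mask lets a downstream model distinguish measured values from
    imputed ones, so missingness itself can carry information."""
    mask = ~np.isnan(series)
    filled = series.copy()
    last = 0.0  # fallback when the series starts with a gap (an assumption)
    for i, v in enumerate(filled):
        if np.isnan(v):
            filled[i] = last
        else:
            last = v
    return filled, mask.astype(int)

vals, mask = impute_with_mask(np.array([np.nan, 7.2, np.nan, 7.5]))
print(vals, mask)  # [0.  7.2 7.2 7.5] [0 1 0 1]
```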
Deep neural networks usually need to be well-trained to cope with large datasets. Therefore, we tested the performance of a selection of models when trained with 30%, 60%, and 90% of the dataset. Each training set was constructed via random sampling; however, we ensured the class distributions were the same across all three sets. The models compared were MM-RNN and all its baselines, plus GRU-ATT and HMT-RNN, which were the second and third best overall performers in classification on a one-hour time-window dataset. The results, shown in Figure 6, reveal that all methods achieved better performance with a greater number of training samples. Even at a very low training proportion (30%), MM-RNN achieved 68% accuracy, implying that our model is robust and less dependent on the size of the training data. However, the improvements in prediction accuracy and ROC for the baseline methods were limited compared to the deep learning methods. MM-RNN delivered the best performance over all the set sizes, and the performance gap between MM-RNN and the baselines grew as more data became available.
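Drawing a class-preserving training subset, as done for the 30%/60%/90% experiments, can be sketched with stratified sampling. The data below is synthetic; the real experiments used the clinical dataset.

```python
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic three-class dataset standing in for the clinical data.
X, y = make_classification(n_samples=1000, n_classes=3,
                           n_informative=5, random_state=0)

# Draw a 30% training subset while preserving the class distribution;
# repeating with train_size=0.6 and 0.9 gives the other two sets.
X_sub, _, y_sub, _ = train_test_split(X, y, train_size=0.3,
                                      stratify=y, random_state=0)
print(Counter(y_sub))  # class counts mirror the full dataset's proportions
```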
Receiver operating characteristic (ROC) curves demonstrate the discriminative capability of a classifier by plotting the true positive rate against the false positive rate across a range of threshold values. Figure 7 shows that the ROC curves for all categories were very far from the 45-degree diagonal and close to the upper-left corner of the ROC space. The area under each of these six ROC curves (AUC) is shown in Table 6. The average value was around 96.7%, which represents excellent performance. Interestingly, from these results, we also observed that MM-RNN was very sensitive to Class 1 and Class 6, two important classes in ICU assessments: Class 6 represents critical conditions, while Class 1 represents the lowest threshold for admission to the ICU.
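The per-class AUC evaluation follows a one-vs-rest scheme. The sketch below shows the computation on synthetic data with a generic classifier (three classes for brevity rather than the six used in Table 6).

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic multi-class problem; logistic regression stands in for
# MM-RNN's output probabilities (an assumption for illustration).
X, y = make_classification(n_samples=600, n_classes=3,
                           n_informative=6, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
probs = clf.predict_proba(X_te)

# One-vs-rest AUC averaged over classes; per-class AUCs (as in
# Table 6) come from average=None instead of "macro".
auc = roc_auc_score(y_te, probs, multi_class="ovr", average="macro")
print(f"macro one-vs-rest AUC: {auc:.3f}")
```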
Window size is another important parameter that impacts classification performance. Therefore, we evaluated all methods and baselines with respect to 1-hour, 3-hour, and 6-hour time windows. The results, expressed in terms of precision and recall, appear in Figures 7 and 8 and clearly show the superior performance of MM-RNN. Notably, however, performance dropped slightly as the time window grew longer. This may be because medical conditions can change dramatically, for better or worse, over a longer period of time. Also, MM-RNN outperformed the baselines (Table 7).
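Segmenting a physiological time series into fixed-length windows can be sketched as follows. The sampling rate and window lengths here are assumptions for illustration only.

```python
import numpy as np

def make_windows(series, window, step):
    """Cut a multivariate time series (time x features) into
    fixed-length windows. `window` and `step` are counted in time
    steps, so a 1-hour window at a hypothetical 10-minute sampling
    rate would be window=6."""
    return np.stack([series[i:i + window]
                     for i in range(0, len(series) - window + 1, step)])

ts = np.arange(24).reshape(12, 2)        # 12 time steps, 2 features
w = make_windows(ts, window=6, step=6)   # non-overlapping 6-step windows
print(w.shape)  # (2, 6, 2)
```

Longer windows (3-hour, 6-hour) simply use larger `window` values over the same series.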
To investigate the impact of \(\Delta \), which controls the length of the time span for the fused memories, we fixed the length of the time window to 1 hour and tested the variation in performance with respect to different values of \(\Delta \). The curve in Figure 8 shows that performance did indeed vary with \(\Delta \): precision peaked at 87.42% when \(\Delta \) = 4. In other words, choosing an appropriate length of time is advisable to leverage asynchronous correlations. We saw the same trends when testing different time windows (Table 8).
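Selecting \(\Delta \) reduces to a one-dimensional sweep over candidate spans, keeping the value with the best validation precision. The scoring function below is a hypothetical stand-in for retraining and evaluating the model at each \(\Delta \).

```python
def validation_precision(delta):
    # Placeholder for "retrain MM-RNN with this Delta and measure
    # validation precision"; shaped to peak at Delta = 4 purely for
    # illustration, echoing the trend in Figure 8.
    return 0.87 - 0.01 * abs(delta - 4)

candidates = range(1, 9)
best = max(candidates, key=validation_precision)
print(best, round(validation_precision(best), 2))  # 4 0.87
```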
To examine the influence of the multi-task architecture and the memory fusion mechanism, we built three baseline models and tested their classification performance. The criteria for individual classes were calculated with a one-versus-all method. The results are reported in Table 6. MM-RNN consistently outperformed the other baselines in all measurements. In comparing MT-RNN, RNN-MFN, and a simple RNN, the simple RNN could not compete, indicating that the temporal correlations and the multi-task structure do improve the predictions. MM-RNN's improvement over MT-RNN and RNN-MFN was consistent across all evaluation measures. We also noticed that capturing important correlations between features over time, by integrating both task-specific and cross-task interactions, can significantly enhance classification performance, which suggests that including temporal information brings many benefits to illness severity recognition tasks. Overall, MM-RNN achieved 89.98% accuracy.
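The one-versus-all per-class criteria can be computed as below; the labels are toy values, not experimental outputs.

```python
from sklearn.metrics import precision_recall_fscore_support

# Toy predictions over three severity classes, for illustration only.
y_true = [1, 2, 3, 1, 2, 3, 1, 1]
y_pred = [1, 2, 3, 1, 3, 3, 1, 2]

# Treating each class as positive against the rest yields per-class
# precision and recall, as reported for the individual classes.
prec, rec, _, _ = precision_recall_fscore_support(
    y_true, y_pred, labels=[1, 2, 3], zero_division=0)
print(prec.round(2), rec.round(2))
```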
4.6 Explanation for illness severity prediction
To demonstrate the interpretability of MM-RNN, we conducted a small study on two patients. The results are shown in Figure 9. The output scores from the MFN show the attention variations for the two patients over time. A darker colour means a worse condition for the system. In the first 48 hours of Patient A's stay, MM-RNN paid the most attention to malfunctions in the kidney and cardiovascular systems. This corresponds to a note by the doctor in the radiology report made 72 hours later (noteevents.row_id=1200211): "evaluate for obstruction causing acute renal failure". With Patient B, MM-RNN paid the most attention to the respiratory system in the early days of admission. After 72 hours, the doctor diagnosed "respiratory failure" in the medical report (noteevents.row_id=105617). Attention to the respiratory system dramatically increased after 80 hours, which is confirmed in the medical report as "Large right pleural effusion is increasing" after 144 hours. This simple example shows MM-RNN's potential as a highly accurate early warning system for clinicians. Beyond a prediction, the system provides an explanation and an indication of which organ systems are causing problems well before traditional forms of monitoring signal an issue. In this way, caregivers gain more insight into which patients are at risk, a level of reassurance most conventional deep models do not offer.
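A heatmap in the style of Figure 9 can be rendered as below; the attention matrix and organ-system labels here are synthetic placeholders, not MFN outputs.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np

# Rows are organ systems, columns are hours since admission; darker
# cells indicate a worse condition for that system. Values are random
# stand-ins for the MFN's attention scores.
systems = ["respiratory", "cardiovascular", "kidney", "liver"]
attn = np.random.default_rng(0).random((len(systems), 96))

fig, ax = plt.subplots(figsize=(8, 2))
im = ax.imshow(attn, aspect="auto", cmap="Reds")
ax.set_yticks(range(len(systems)))
ax.set_yticklabels(systems)
ax.set_xlabel("hours since admission")
fig.colorbar(im, ax=ax)
fig.savefig("attention_heatmap.png")
```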
In summary, the MM-RNN not only improves prediction performance, it also offers clinicians visual explanations of the prediction results. A correct clinical decision may mean the difference between life and death. Therefore, practitioners must be able to understand how a prediction made by any diagnostic tool was reached if they are to trust the results.
5 Conclusion
MM-RNN outperformed the other methods on the benchmark MIMIC dataset, but this does not mean the model is necessarily robust. In future work, we will look to extend the model to consider data uncertainty and undertake a deeper examination of the issue of missing values. A starting point for missing-value imputation may be to incorporate a mask of missing data to indicate the placement of imputed or missing values. The model could then not only capture the long-term temporal dependencies of time-series observations but also exploit the missing patterns to further improve the prediction results.
Availability of data and materials
The data used in this project is publicly available on PhysioNet.
References
Binder, H., Blettner, M.: Big data in medical science–a biostatistical view: Part 21 of a series on evaluation of scientific publications. Dtsch. Ärztebl. Int 112(9), 137 (2015)
Shann, F., Pearson, G., Slater, A., Wilkinson, K.: Paediatric index of mortality (PIM): a mortality prediction model for children in intensive care. Intensive Care Med 23(2), 201–207 (1997)
Lipton, Z.C., Kale, D.C., Wetzel, R.: Modeling missing data in clinical time series with RNNs. Mach Learn Healthcare (2016)
Vincent, J., et al.: The SOFA score to describe organ dysfunction/failure. On behalf of the Working Group on Sepsis-Related Problems of the European Society of Intensive Care Medicine. Intensive Care Med 22(7), 707–710 (1996)
Knaus, W.A., et al.: The APACHE III prognostic system: risk prediction of hospital mortality for critically ill hospitalized adults. Chest 100(6), 1619–1636 (1991)
Le Gall, J.-R., Lemeshow, S., Saulnier, F.: A new simplified acute physiology score (SAPS II) based on a European/North American multicenter study. JAMA 270(24), 2957–2963 (1993)
Bouch, D.C., Thompson, J.P.: Severity scoring systems in the critically ill. Continuing Education in Anaesth Crit Care Pain Med 8(5), 181–185 (2008)
Shen, S., Xu, M., Yue, L., Boots, R., Chen, W.: Death comes but why: An interpretable illness severity prediction in ICU. In: Li, B., et al. (eds.). Springer Nature Switzerland, Cham (2023)
Che, Z., Purushotham, S., Cho, K., Sontag, D., Liu, Y.: Recurrent neural networks for multivariate time series with missing values. Sci Rep 8(1), 6085 (2018)
Aczon, M., et al.: Dynamic mortality risk predictions in pediatric critical care using recurrent neural networks. arXiv:1701.06675
Che, Z., Kale, D., Li, W., Bahadori, M.T., Liu, Y.: Deep computational phenotyping, 507–516, (2015). ACM
Wu, X., Zhu, X., Wu, G.-Q., Ding, W.: Data mining with big data. IEEE Trans Knowl Data Eng 26(1), 97–107 (2013)
Morrison, Z., Robertson, A., Cresswell, K., Crowe, S., Sheikh, A.: Understanding contrasting approaches to nationwide implementations of electronic health record systems: England, the usa and australia. J Healthcare Eng 2(1), 25–41 (2011)
Baumann, L.A., Baker, J., Elshaug, A.G.: The impact of electronic health record systems on clinical documentation times: A systematic review. Health Policy 122(8), 827–836 (2018)
Choi, E., Bahadori, M.T., Schuetz, A., Stewart, W.F., Sun, J.: Doctor ai: Predicting clinical events via recurrent neural networks, 301–318 (2016)
Chen, W., et al.: Dynamic illness severity prediction via multi-task rnns for intensive care unit, 917–922 (2018). IEEE
Wang, S., et al.: Diagnosis code assignment using sparsity-based disease correlation embedding. IEEE Trans Knowl Data Eng 28(12), 3191–3202 (2016)
Chen, L., et al.: Mining health examination records—a graph-based approach. IEEE Trans Knowl Data Eng 28(9), 2423–2437 (2016)
Loo, C.K., Rao, M.: Accurate and reliable diagnosis and classification using probabilistic ensemble simplified fuzzy artmap. IEEE Trans Knowl Data eng 17(11), 1589–1593 (2005)
Li, Y., et al.: Ifflc: An integrated framework of feature learning and classification for multiple diagnosis codes assignment. IEEE Access 7, 36810–36818 (2019)
Purushotham, S., Meng, C., Che, Z., Liu, Y.: Benchmark of deep learning models on large healthcare MIMIC datasets (2017). arXiv:1710.08531
Caruana, R., Baluja, S., Mitchell, T.: Using the future to "sort out" the present: Rankprop and multitask learning for medical risk evaluation, 959–965 (1996)
Clermont, G., Angus, D.C., DiRusso, S.M., Griffin, M., Linde-Zwirble, W.T.: Predicting hospital mortality for patients in the intensive care unit: a comparison of artificial neural networks with logistic regression models. Critical Care Medicine 29(2), 291–296 (2001)
Zhou, J., Liu, J., Narayan, V. A., Ye, J.: Modeling disease progression via fused sparse group lasso 1095–1103 (2012). ACM
Nie, L., et al.: Beyond doctors: Future health prediction from multimedia and multimodal observations 591–600 (2015). ACM
Pham, T., Tran, T., Phung, D., Venkatesh, S.: Deepcare: A deep dynamic memory model for predictive medicine 30–41 (2016). Springer
Yim, J., et al.: Rotating your face using multi-task deep neural network 676–684 (2015)
Zhang, T., Ghanem, B., Liu, S., Ahuja, N.: Robust visual tracking via structured multi-task sparse learning. Int J Comput Vis 101(2), 367–383 (2013)
Abdulnabi, A.H., Wang, G., Lu, J., Jia, K.: Multi-task cnn model for attribute prediction. IEEE Trans Multimed 17(11), 1949–1959 (2015)
Chen, W., et al.: EEG-based motion intention recognition via multi-task RNNs, 279–287 (2018). SIAM
Zhou, J., Yuan, L., Liu, J., Ye, J.: A multi-task learning formulation for predicting disease progression 814–822 (2011). ACM
Harutyunyan, H., Khachatrian, H., Kale, D. C., Steeg, G. V., Galstyan, A.: Multitask learning and benchmarking with clinical time series data (2017). arXiv:1703.07771
Gehring, J., Auli, M., Grangier, D., Yarats, D., Dauphin, Y. N.: Convolutional sequence to sequence learning 1243–1252 (2017). JMLR. org
Rocktäschel, T., Grefenstette, E., Hermann, K. M., Kočiskỳ, T., Blunsom, P.: Reasoning about entailment with neural attention. arXiv:1509.06664 (2015)
Verga, P., Strubell, E., McCallum, A.: Simultaneously self-attending to all mentions for full-abstract biological relation extraction (2018). arXiv:1802.10569
Sharma, S., Kiros, R., Salakhutdinov, R.: Action recognition using visual attention (2015). arXiv:1511.04119
Johnson, A.E., et al.: MIMIC-III, a freely accessible critical care database. Sci Data 3, 160035 (2016)
Elman, J.L.: Finding structure in time. Cognit Sci 14(2), 179–211 (1990)
Hochreiter, S., Schmidhuber, J.: Long short-term memory. Neural Comput (1997)
Graves, A.: Generating sequences with recurrent neural networks (2013). arXiv:1308.0850
Neil, D., Pfeiffer, M., Liu, S.-C.: Phased LSTM: Accelerating recurrent network training for long or event-based sequences, 3882–3890 (2016)
Zadeh, A., et al.: Memory fusion network for multi-view sequential learning (2018)
Liu, F., Perez, J.: Gated end-to-end memory networks 1–10 (2017)
Parikh, A.P., Täckström, O., Das, D., Uszkoreit, J.: A decomposable attention model for natural language inference (2016). arXiv:1606.01933
Vaswani, A., et al.: Attention is all you need 5998–6008 (2017)
Kingma, D. P., Ba, J.: Adam: A method for stochastic optimization (2014). arXiv:1412.6980
Nguyen, P., Tran, T., Venkatesh, S.: Deep learning to attend to risk in icu (2017). arXiv:1707.05010
Funding
Open Access funding enabled and organized by CAUL and its Member Institutions. This project is financially supported by Chen Start-up (15131570) from The University of Adelaide, and the 2022 UQAI ECR Seed Fund and Cyber Security research grant (4018207-01-299-21-619007) from The University of Queensland.
Author information
Contributions
All authors contributed to the main manuscript text and the figure design using Overleaf. Weitong Chen conducted the experiments, and all authors reviewed the manuscript.
Ethics declarations
Ethical approval
This project was approved by the Royal Brisbane & Women’s Hospital Human Research Ethics Committee (RBWH HREC) and the Institutional Review Boards of Beth Israel Deaconess Medical Center (Boston, MA) and the Massachusetts Institute of Technology (Cambridge, MA) [37]. Informed consent was waived by the RBWH HREC as the study used de-identified patient data.
Competing interest
The authors declare no competing interests.
Additional information
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
This article belongs to the Topical Collection: APWeb-WAIM 2022
Guest editors: Calvanese Diego, Toshiyuki Amagasa and Bohan Li.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Chen, W., Zhang, W.E. & Yue, L. Death comes but why: A multi-task memory-fused prediction for accurate and explainable illness severity in ICUs. World Wide Web 26, 4025–4045 (2023). https://doi.org/10.1007/s11280-023-01211-w