Abstract
Arrhythmia is a potentially fatal cardiac condition that threatens the lives of millions every year. It comprises multiple classes with variable prevalence rates. Some rare arrhythmia classes are as critical as common ones, yet they are very hard to detect because of limited training samples. While several methods accurately detect multiple arrhythmia classes, their minority-class accuracy remains low and they are resource-intensive. Therefore, most existing detection systems either ignore minority classes or focus on binary classification. In this study, we introduce RL-ECGNet, a resource-efficient reinforcement learning-based optimization framework for multi-class arrhythmia detection from ECG signals that encompasses minority classes. RL-ECGNet processes raw ECG signals to extract temporal ECG features and uses Reinforcement Learning (RL) to optimize the training and network hyperparameters of Deep Learning (DL) models while reducing resource consumption. For evaluation, four DL models, namely MLP, CNN, LSTM, and GRU, are trained and optimized, with training time and memory usage minimized to reduce resource consumption. Across the four DL models, the proposed RL method achieved accuracies ranging from 88.45% to 96.41% on all 9 arrhythmia classes, including minority classes, and improved accuracy by a factor of 1.28 to 1.39. The optimized DL models also required less training time and memory: the proposed method reduced training time by a factor of 1.36 to 1.925 and memory usage by a factor of 1.179 to 1.815.
1 Introduction
Arrhythmia is a cardiac disorder in which abnormal electrical activity of the heart results in an irregular heart rhythm. Globally, the prevalence of arrhythmia is estimated at up to 5% of the general population [1]. According to the Centers for Disease Control and Prevention (CDC) [2], an estimated 12.1 million people in the United States are expected to develop atrial fibrillation, the most common type of arrhythmia, by 2030. Hence, there is an urgent need for early detection of arrhythmia as a first step toward prevention.
Electrocardiogram (ECG) signals are biomedical signals that reflect the electrical activity of the heart [3], and they have been widely used for arrhythmia detection and classification. Much research focuses on the binary classification of ECG signals to predict whether a signal is arrhythmic [4,5,6]. Although such classification provides an essential first step towards diagnosis, it remains limited. Hence, more recent research addresses Multi-Class Arrhythmia Detection (MCAD), considering varying numbers of arrhythmia classes, for instance, 5 classes [7,8,9,10], 7 classes [11], 9 classes [12, 13], 17 classes [14], or up to 29 classes [15].
Multi-class arrhythmia prediction is a challenging research problem due to data imbalance. The scarcity of rare arrhythmia samples degrades minority-class classification performance and makes the hyperparameter optimization needed to address the imbalance even more resource-intensive. While several training and network structure hyperparameters require fine-tuning, most recent studies optimize only a limited set of training parameters and tune the network architecture manually [7, 8, 11, 13,14,15], and only a few studies consider optimizing resource consumption [10, 11]. However, resource efficiency, encompassing training time, inference time, and memory usage, is as crucial as high classification performance. This is especially vital for real-time applications and resource-constrained devices such as those deployed in Internet of Medical Things (IoMT) systems. To this end, we propose a more comprehensive approach for resource-aware, optimized multi-class arrhythmia detection that considers minority classes. The contributions of the proposed RL-ECGNet framework can be summarized as follows:
- Deployment of Reinforcement Learning (RL) for optimizing both training and network architecture hyperparameters.
- Resource consumption optimization considering training time and memory usage while maintaining high classification performance.
- Evaluation and comparison of four Deep Learning (DL) models to achieve the highest classification performance, considering the arrhythmia minority classes.
This paper is structured as follows: Section 2 discusses the literature on deep learning models for arrhythmia prediction from ECG signals and on automated fine-tuning. Section 3 presents the proposed method for automated reinforcement learning-based arrhythmia prediction on an imbalanced ECG dataset. Section 4 describes the experimental setup and testing of the proposed framework and analyzes the obtained results. Finally, Section 5 discusses the findings.
2 Related work
Deep Learning (DL) models have been increasingly used for arrhythmia detection [16]. Despite the effectiveness of DL models in feature extraction and classification, optimized multi-class ECG-based arrhythmia detection remains an ongoing research problem.
Arrhythmia classification datasets are highly imbalanced, which makes classification difficult and lowers the classification performance on the minority classes. In addition, model optimization is critical for performance improvement, and data scarcity makes it more challenging. Some approaches optimize feature selection to improve classification performance. For instance, Monarch Butterfly [4, 17], Manta Ray [17, 18], Emperor Penguin [17], Bat-Rider [19], and Coyote Grey Wolf [20] optimization methods have been deployed with ML models [17, 18] or DL models such as ANN [17] and CNN [4, 19, 20]. Although feature selection improves classification performance, it does not fully address the data imbalance issue or the optimization of the network and training, which are vital for performance.
While some research drops the minority classes during evaluation [11], other methods focus on improving the classification performance of the minority classes. For instance, Luo et al. [12] proposed a Hybrid Convolutional Recurrent Neural Network (HCRNet) for automatic classification of 9 arrhythmia classes, including the imbalanced data samples. The extracted heartbeat time series are fed into hybrid CNN and RNN (GRU and LSTM) layers. The hybrid model consists of 4 blocks: a convolutional block for feature extraction, a separable convolutional block with a GRU layer for faster feature extraction and knowledge retention, a separable convolutional block with an LSTM layer, and a final convolutional block. The outputs of these blocks are concatenated and fed into fully connected layers that produce the final classification. The model's architecture was optimized manually, while the training optimization method was not stated.
For training optimization, the Adam, Stochastic Gradient Descent (SGD), and Stochastic Gradient Descent with Momentum (SGDM) optimizers are commonly used in the reviewed literature. For example, Kanani et al. [7] proposed a modified ResNet for automatic classification of 5 arrhythmia classes, using an augmented dataset to improve classification performance. This method uses the Adam optimizer to train the model; however, the model's architecture is optimized manually. Although this method achieved high classification accuracy, it is limited to inputs of certain amplitudes and frequencies.
To provide more comprehensive results, Yang et al. [13] proposed a 12-lead ECG classifier that considers the multi-directional activity of the heart. The method deploys a Cascaded Convolutional Neural Network (CCNN) multi-label arrhythmia classifier. Features are extracted separately from each lead using a 1-D CNN and are then cascaded based on the spatio-temporal correlation between the leads. A ResNet then classifies the resulting 2-D feature maps into 9 arrhythmia classes, producing a first classification probability. This probability is combined with the classification probabilities obtained from a random forest trained on expert features, which include morphological, HRV, statistical, frequency, and AF features extracted after denoising the ECG signals. Finally, this approach uses the Adam optimizer for automatic fine-tuning of the model weights. Despite the representativeness of the multi-lead data and the comprehensiveness of this approach, the reported classification accuracy was 86.5%, which leaves room for improvement.
Although multi-lead classification models provide more comprehensive and representative results, their high computational resource consumption is impractical for resource-limited settings such as Internet of Medical Things (IoMT) systems [11]. Therefore, Sepahvand et al. [11] proposed a teacher-student model for arrhythmia classification that uses single-lead ECG data with minimal performance loss. The teacher model is a CNN with an advanced architecture trained on multi-lead data, while the student model is a simpler, lightweight CNN trained on single-lead data under the supervision of the teacher network. Through this teacher-student knowledge distillation (KD) approach, the simpler, more compressed student model acquires the knowledge of the multi-lead signals by imitating the teacher model during training. Both the internal feature-map knowledge and the output-response knowledge are transferred to the student model, balancing data representativeness against resource consumption. Because of the imbalance of the arrhythmia ECG data, 4 minority classes (out of 11) were excluded during evaluation. Moreover, the Adam optimizer was used to fine-tune the weights of the models, but the network architecture was optimized manually.
Pal et al. [15] proposed CardioNet, a transfer learning-based model for automatic arrhythmia classification that addresses data imbalance. The method uses a pre-trained DenseNet to improve feature extraction and classification. Starting from the pre-trained DenseNet weights speeds up training while the network is fine-tuned on the ECG arrhythmia data. Moreover, transfer learning improves the classification performance of minority classes since it does not require balanced or identically distributed data. The model uses an augmented dataset of 2D images of ECG signals and classifies these samples into 29 arrhythmia classes. SGD was used to optimize the model's training.
Other approaches deploy Focal Loss to address sample scarcity. Focal Loss directs the model's attention towards minority samples that are difficult to classify by down-weighting the majority samples that are easily classified [8]. For instance, Li et al. [8] proposed an improved deep residual convolutional neural network (ResNet) with the focal loss function. The method applies overlapping segmentation to the MIT-BIH data to produce 5 s segments and re-labels them to increase the number of minority-class samples. The improved ResNet then classifies the ECG segments into 5 classes. Identity mapping within the ResNet adds no extra parameters, which helps limit the computational complexity. Furthermore, the method adopts Focal Loss, which extends the Cross-Entropy (CE) loss with a modulating factor that down-weights well-classified samples. The network's architecture was optimized manually, and training was optimized using SGD. Although this method generalizes better, the performance metrics of the minority classes remain evidently lower than those of the common classes.
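To make the down-weighting concrete, the sketch below implements focal loss for a single sample in its commonly used form, \(FL(p_t) = -\alpha (1-p_t)^{\gamma}\log(p_t)\), where \(p_t\) is the predicted probability of the true class. The function names, the default \(\gamma = 2\), and the scalar \(\alpha\) are illustrative assumptions, not the exact configuration used in [8].

```python
import math

def cross_entropy(probs, target):
    """Standard cross-entropy for one sample: -log(p_target)."""
    return -math.log(probs[target])

def focal_loss(probs, target, gamma=2.0, alpha=1.0):
    """Focal loss for one sample: -alpha * (1 - p_t)**gamma * log(p_t).

    The modulating factor (1 - p_t)**gamma shrinks the loss of samples that are
    already classified confidently (p_t close to 1), so hard samples, often from
    minority classes, dominate the training signal.
    """
    p_t = probs[target]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)

easy = [0.95, 0.03, 0.02]  # confidently correct prediction for class 0
hard = [0.30, 0.60, 0.10]  # poorly classified sample of class 0

# The focal/CE ratio equals (1 - p_t)**gamma.
print(focal_loss(easy, 0) / cross_entropy(easy, 0))  # ≈ 0.0025
print(focal_loss(hard, 0) / cross_entropy(hard, 0))  # ≈ 0.49
```

With \(\gamma = 0\) the modulating factor equals 1 and the loss reduces to plain cross-entropy.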
To improve the classification performance of imbalanced datasets with the minority classes, Lu et al. [14] proposed a Depthwise Separable with Focal Loss Convolutional Neural Network (DSC-FL-CNN). The proposed network automatically classifies the data samples into 17 arrhythmia classes. The converted 2D ECG images are fed into the network, whose architecture was manually optimized, and training was optimized using SGD. The proposed method demonstrated slightly improved performance compared to CNNs, FL-CNNs, and DSC-CNNs.
Mohonta et al. [9] proposed applying CNNs to spectrum images for automatic arrhythmia classification. The outputs of four spectral analysis techniques, namely FFT, STFT, CWT-Grayscale, and CWT-RGB, were fed as CNN inputs. The CNN's architecture was optimized manually, while training was optimized using SGDM. Although the CWT-RGB spectral analysis with the proposed CNN yielded high performance, such analysis adds overhead and computational complexity to the system.
While the aforementioned methods do not prioritize optimizing resource consumption and computational complexity, approaches designed to operate within IoMT or embedded systems require such optimization. For instance, the teacher-student architecture proposed in [11] optimized resource consumption. To this end, the computational complexity, model size (indicated by the number of parameters), and memory usage of the student network were analyzed against those of the teacher model. Although the student model performs worse than the teacher model, it consumes fewer resources and is 262.18 times more compressed. However, despite this reduction, the model's resource consumption is not optimal.
Falaschetti et al. [10] proposed an RNN LSTM-based architecture for automatic arrhythmia classification on embedded systems. Due to the limited resources available in embedded systems, this method focuses on optimizing memory, inference time, number of parameters, as well as accuracy. Throughout the evaluation of this approach, different RNN variants were tested. The RNN LSTM-based network’s performance was the best while optimizing the resources. However, the network’s architecture was still optimized manually. Moreover, although the results of this method outperform other architectures in terms of memory storage, number of parameters, and inference time, the reported accuracy needs further improvement.
To this end, in this research we aim to design and implement a comprehensive, resource-aware framework that uses Reinforcement Learning to optimize both training and network architecture hyperparameters.
3 Materials and methods
To address data scarcity and extreme class imbalance, we propose RL-ECGNet, a Reinforcement Learning (RL)-based architecture for classifying arrhythmia from raw ECG signals. Figure 1 depicts the proposed RL-ECGNet framework, which consists of two modules: ECG signal acquisition and preprocessing, and Reinforcement Learning. The latter consists of three main components: (a) the Q-Table, (b) the Agent, and (c) the Environment. Each component is explained in the following subsections.
3.1 Module 1: signal acquisition and processing
The raw ECG signals are fed to the signal acquisition and processing module for preprocessing, cleaning, and feature extraction. To obtain a digital signal from the raw analog ECG signal, the signal is first sampled. Next, several successive processing steps denoise the signal and attenuate artifacts. This step is crucial because ECG signals are susceptible to various types of noise that strongly affect the accuracy of the readings, including electromyogram noise, motion artifacts, powerline interference, and baseline wandering [21]. This implementation focuses on attenuating powerline interference and baseline wandering, since their distortions are the most impactful. The ECG wave features are then extracted. To analyze the ECG signal, the QRS complex, which represents key morphological features of the cardiac electrical activity, is detected. The signal is then segmented into individual heartbeats. Finally, for classification, temporal features of the ECG signal, including the P-, R-, and T-waves as well as the pre- and post-RR intervals, are extracted.
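The preprocessing chain described above can be sketched as follows. This is a minimal NumPy illustration, not the authors' implementation: it assumes the MIT-BIH sampling rate of 360 Hz and 60 Hz mains interference, and swaps in simple stand-ins (moving-average baseline subtraction, an FFT notch, and threshold-based R-peak detection) for whatever filters and QRS detector the framework actually uses. All function names are hypothetical.

```python
import numpy as np

FS = 360  # MIT-BIH sampling rate (samples per second per channel)

def remove_baseline(ecg, fs=FS, window_s=0.6):
    """Subtract a moving-average estimate of the baseline wander."""
    win = int(window_s * fs) | 1  # force an odd window length
    baseline = np.convolve(ecg, np.ones(win) / win, mode="same")
    return ecg - baseline

def remove_powerline(ecg, fs=FS, mains_hz=60.0, half_width_hz=1.0):
    """Crude FFT notch: zero the spectrum within +/- half_width_hz of the mains frequency."""
    spectrum = np.fft.rfft(ecg)
    freqs = np.fft.rfftfreq(len(ecg), d=1.0 / fs)
    spectrum[np.abs(freqs - mains_hz) <= half_width_hz] = 0.0
    return np.fft.irfft(spectrum, n=len(ecg))

def detect_r_peaks(ecg, fs=FS, threshold_ratio=0.5, refractory_s=0.25):
    """Local maxima above a fraction of the maximum amplitude, with a refractory period."""
    threshold = threshold_ratio * np.max(ecg)
    refractory = int(refractory_s * fs)
    peaks = []
    for i in range(1, len(ecg) - 1):
        if ecg[i] >= threshold and ecg[i] >= ecg[i - 1] and ecg[i] > ecg[i + 1]:
            if not peaks or i - peaks[-1] > refractory:
                peaks.append(i)
            elif ecg[i] > ecg[peaks[-1]]:
                peaks[-1] = i  # keep the taller peak inside the refractory window
    return np.array(peaks, dtype=int)

def rr_features(peaks, fs=FS):
    """(pre-RR, post-RR) intervals in seconds for every beat with two neighbors."""
    rr = np.diff(peaks) / fs
    return [(rr[k - 1], rr[k]) for k in range(1, len(rr))]

# Synthetic example: beats every 0.8 s plus baseline wander and 60 Hz interference.
t = np.arange(0, 10, 1 / FS)
beats = sum(np.exp(-((t - c) ** 2) / (2 * 0.01 ** 2)) for c in np.arange(0.5, 10, 0.8))
raw = beats + 0.5 * np.sin(2 * np.pi * 0.3 * t) + 0.1 * np.sin(2 * np.pi * 60 * t)
clean = remove_powerline(remove_baseline(raw))
peaks = detect_r_peaks(clean)
feats = rr_features(peaks)
```

A production pipeline would typically use a dedicated QRS detector (for example, Pan-Tompkins) and zero-phase IIR filters; the stand-ins here only show where each step sits in the chain.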
3.2 Module 2: reinforcement learning
The processed ECG signals are then passed to the Reinforcement Learning module. The proposed framework formulates the fine-tuning of the hyperparameters of the multiple DL predictive models as a Markov Decision Process (MDP), which is solved with Q-learning, an off-policy, value-based RL algorithm in which the agent updates the set of hyperparameter values based on the policy \((\pi )\). The targeted hyperparameters include the network architecture hyperparameters: activation function (af), hidden units (nhu), dropout (d), and weight initialization (nw), as well as the training hyperparameters: batch size (bsz), momentum (m), learning rate (lr), and epochs (nep). The Q-table is built over an action list that the agent uses to configure the parameters of the deep networks. Each action increases or decreases the value of a target hyperparameter within a predefined set, as listed in Table 1.
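The action list can be represented as one "increase" and one "decrease" action per hyperparameter, each moving the value one position along an ordered, predefined set. The value sets below are hypothetical placeholders for the ranges in Table 1, and the names `HYPERPARAMETER_SETS`, `apply_action`, and `decode` are illustrative.

```python
# Hypothetical ordered value sets; the actual ranges are those listed in Table 1.
HYPERPARAMETER_SETS = {
    "af":  ["sigmoid", "tanh", "relu"],                          # activation function
    "nhu": [16, 32, 64, 128, 256],                               # hidden units
    "d":   [0.0, 0.1, 0.2, 0.3, 0.5],                            # dropout
    "nw":  ["random_uniform", "glorot_uniform", "he_normal"],    # weight initialization
    "bsz": [16, 32, 64, 128, 256],                               # batch size
    "m":   [0.0, 0.5, 0.9, 0.99],                                # momentum
    "lr":  [1e-4, 1e-3, 1e-2, 1e-1],                             # learning rate
    "nep": [10, 20, 50, 100],                                    # epochs
}

# One "increase" (+1) and one "decrease" (-1) action per hyperparameter.
ACTIONS = [(name, move) for name in HYPERPARAMETER_SETS for move in (+1, -1)]

def apply_action(state, action):
    """Return a new index-based state with one hyperparameter moved one step (clamped)."""
    name, move = action
    new_state = dict(state)
    last = len(HYPERPARAMETER_SETS[name]) - 1
    new_state[name] = min(max(state[name] + move, 0), last)
    return new_state

def decode(state):
    """Map an index-based state to concrete hyperparameter values."""
    return {name: HYPERPARAMETER_SETS[name][idx] for name, idx in state.items()}

state = {name: 0 for name in HYPERPARAMETER_SETS}
state = apply_action(state, ("lr", +1))
print(decode(state)["lr"])  # 0.001
print(len(ACTIONS))         # 16
```

Working with indices rather than raw values keeps categorical hyperparameters (activation, initializer) and numeric ones under the same increase/decrease actions.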
The agent runs the MDP process, which iteratively updates the policy based on feedback from the environment. The MDP process also records the states, including the preceding, current, and succeeding states. The actions (\({a}_{t}\)) are applied to the environment, which contains the DL models that are trained to produce predictions. The proposed framework deploys multiple predictive deep networks: a Convolutional Neural Network (CNN), a Gated Recurrent Unit (GRU) network, a Long Short-Term Memory (LSTM) network, and a Multilayer Perceptron (MLP).
Throughout the training and evaluation of the DL models, the environment returns observations, namely the reward (\({R}_{t}\)) associated with the action, to the agent for feedback and policy update. The reward function in this architecture is a multi-objective fusion of MSE loss, accuracy, time, and memory. The networks are optimized, evaluated, and compared to obtain the best performance with minimum resource consumption. This process iterates until the models converge. An optimum model achieves the best trade-off among low loss, high accuracy, and reduced time and memory consumption.
3.2.1 Training and network architecture hyperparameter sets
The proposed RL-ECGNet framework optimizes the network at two levels, training and network architecture, where the agent configures different combinations of the hyperparameter sets. The MDP algorithm reduces the dimensionality of the hyperparameter search space by proceeding sequentially and updating the value of a single hyperparameter at a time. Table 1 lists the hyperparameters and their value ranges, which include well-known, validated values commonly used when optimizing these hyperparameters.
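As a rough illustration of why updating one hyperparameter at a time shrinks the search, the sketch compares the size of an exhaustive grid with the number of neighboring configurations the agent can reach in a single step. The per-hyperparameter value counts are hypothetical, and `search_space_sizes` is an illustrative name.

```python
from math import prod

def search_space_sizes(value_counts):
    """Compare an exhaustive grid with the neighborhood explored in one step."""
    full_grid = prod(value_counts)      # every combination of values
    per_step = 2 * len(value_counts)    # one increase and one decrease per hyperparameter
    return full_grid, per_step

# Hypothetical value counts for af, nhu, d, nw, bsz, m, lr, nep
print(search_space_sizes([3, 5, 5, 3, 5, 4, 4, 4]))  # (72000, 16)
```

The trade-off is that such a local, one-step search is not guaranteed to reach the global optimum of the full grid.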
3.2.2 Deep learning network optimization through MDP
The proposed framework deploys the MDP algorithm for automated reward-based optimization, in which multiple Deep Learning models are trained and optimized. Let \({D}_{\upmu }\) denote the Deep Learning algorithm with a set of training and network architecture hyperparameters \(\upmu\). The Deep Learning model \({D}_{\upmu }\) is trained with the set of hyperparameters configured by the agent to reach state \({S}_{t}={\mathrm{Decision}}_{\mathrm{t}-1}\left({\upmu }_{\mathrm{t}-1}\right),\) where \({\mathrm{Decision}}_{\mathrm{t}-1}\) is the decision of the agent at time (t − 1). The agent chooses an action from a list of predefined actions, as shown in Table 1. The agent's action \({A}_{t}\) follows a policy \(\uppi \left(\mathrm{a},\mathrm{ s}\right)\) that aims to maximize the expected reward associated with the action that led to the current state. The reward function R is measured and re-evaluated after the state transition under action \({A}_{t}\) at time t. Since this framework deploys a multi-objective MDP, the reward fuses several performance measures as a weighted sum of the performance metrics: loss, accuracy, time, and memory.
- The model's loss is determined using the Mean Squared Error (MSE) metric shown in Eq. (1):
$$MSE=\frac{\sum_{t=1}^{n}{\left({y}_{pt}-{y}_{t}\right)}^{2}}{n},$$(1)
where \({y}_{pt}\) is the predicted class at time t, \({y}_{t}\) is the actual value at time t, and n is the number of observations.
- The accuracy is evaluated as shown in Eq. (2):
$$Accuracy= \frac{true\;positives}{total\;observations}$$(2)
- The time of Deep Learning model training and testing is evaluated as shown in Eq. (3):
$$T= \sum \nolimits_{t=0}^{n}{\mathrm{e}}_{t}$$(3)
where \({e}_{t}\) is the execution time at time step \(t\).
- The memory consumption is the memory occupied during execution.
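The four objectives can be measured with the Python standard library alone, as in the sketch below. `mse` and `accuracy` follow Eqs. (1) and (2); `measure` is a hypothetical helper that times a call with `time.perf_counter` and records peak memory with `tracemalloc`. Note that `tracemalloc` only tracks allocations made through Python's allocator, so GPU or native framework memory would need a framework-specific tool.

```python
import time
import tracemalloc

def mse(predicted, actual):
    """Eq. (1): mean of squared differences between predicted and actual labels."""
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)

def accuracy(predicted, actual):
    """Eq. (2): correctly classified observations over all observations."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)

def measure(fn, *args):
    """Run fn(*args); return its result, wall-clock seconds, and peak traced bytes."""
    tracemalloc.start()
    start = time.perf_counter()
    result = fn(*args)
    elapsed = time.perf_counter() - start
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return result, elapsed, peak

y_true = [0, 1, 2, 2, 1]
y_pred = [0, 1, 1, 2, 1]
print(mse(y_pred, y_true))       # 0.2
print(accuracy(y_pred, y_true))  # 0.8
result, seconds, peak_bytes = measure(sorted, list(range(100_000)))
```

Summing the elapsed times of successive training and testing calls gives \(T\) in Eq. (3).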
The proposed optimization thus forms a multi-objective problem with four objectives. Because the ranges of the four objectives differ, each objective is normalized using the z-score normalization technique described in Eq. (4). Positive and negative values indicate improvement and decline compared to the mean, respectively.
where \(\mu\) and \(\upsigma\) are the mean and the standard deviation of each objective, respectively.
The reward Rt is the weighted and normalized value of the four objectives as shown in Eq. (5):
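Eqs. (4) and (5) are not reproduced here, so the following is only a sketch of z-score normalization followed by a weighted sum. Two assumptions are made: the z-scores of loss, time, and memory are negated so that, as stated above, positive values mean improvement for all four objectives, and the mean and standard deviation are taken over all logged states. The logged values and equal weights in the example are hypothetical.

```python
from statistics import mean, pstdev

# Assumption: lower is better for loss, time and memory, so their z-scores are
# negated; positive values then mean "better than the mean" for every objective.
MINIMIZE = {"loss": True, "accuracy": False, "time": True, "memory": True}

def z_scores(history, objective):
    """Sign-adjusted z-score of each logged value of one objective."""
    values = [entry[objective] for entry in history]
    mu, sigma = mean(values), pstdev(values)
    if sigma == 0:
        return [0.0] * len(values)  # all states identical on this objective
    sign = -1.0 if MINIMIZE[objective] else 1.0
    return [sign * (v - mu) / sigma for v in values]

def rewards(history, weights):
    """Weighted sum of the normalized objectives for every logged state."""
    normalized = {obj: z_scores(history, obj) for obj in weights}
    return [sum(weights[obj] * normalized[obj][i] for obj in weights)
            for i in range(len(history))]

history = [  # hypothetical logged states
    {"loss": 0.30, "accuracy": 0.80, "time": 120.0, "memory": 900.0},
    {"loss": 0.20, "accuracy": 0.85, "time": 100.0, "memory": 800.0},
    {"loss": 0.10, "accuracy": 0.90, "time": 80.0, "memory": 700.0},
]
weights = {"loss": 0.25, "accuracy": 0.25, "time": 0.25, "memory": 0.25}
print(rewards(history, weights))  # increasing: the last state is best on all four
```

Changing the weights shifts the trade-off, for example favoring accuracy over memory.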
For the reward update, the reward difference is considered, as stated in Eq. (6) [22]. To guide and balance the agent's decision process, a discount factor \(\upgamma \in [0,1]\) is introduced. This factor weights immediate against long-term rewards: values closer to 0 prioritize the immediate reward, and values closer to 1 prioritize long-term rewards.
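The role of \(\upgamma\) can be seen in the discounted return \(G=\sum_{k}\upgamma^{k}r_{k}\). The sketch below does not reproduce Eq. (6); it only shows how the same reward sequence is valued differently for small and large \(\upgamma\). The reward values and the function name are illustrative.

```python
def discounted_return(rewards, gamma):
    """G = r_0 + gamma*r_1 + gamma**2*r_2 + ... for a sequence of future rewards."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

future = [1.0, 0.0, 0.0, 10.0]  # small immediate reward, large delayed reward
print(discounted_return(future, 0.1))  # ≈ 1.01: the immediate reward dominates
print(discounted_return(future, 0.9))  # ≈ 8.29: the delayed reward dominates
```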
3.2.3 Training with reinforcement learning
We define the action-value function \(Q\) as the expected return at a state (\(S\)) after a transition by action (\(a\)) under the policy (\(\uppi\)), as shown in Eq. (7) [22]. It quantifies how favorable the new state of the model is.
The Q value is iteratively updated as described in Eq. (8) [22, 23]:
The agent selects the policy that maximizes the cumulative reward as follows: \({\uppi }^{*}\left(s\right)\in {\mathrm{argmax}}_{a}{Q}^{*}\left(s,a\right)\), where \({Q}^{*}\left(s, a\right)\) is the optimal state-action value.
In the proposed framework, the optimal policy is obtained by deploying Q-learning, an off-policy, value-based method whose optimal action values obey the Bellman Eq. (9):
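For reference, the textbook tabular Q-learning update moves \(Q(s,a)\) toward the Bellman target \(r + \upgamma \max_{a'} Q(s',a')\). The sketch below implements that standard form; the learning rate `alpha`, the default values, and the string-valued states and actions are assumptions for illustration (in RL-ECGNet a state would be a hyperparameter configuration), not parameters reported for the framework.

```python
from collections import defaultdict

def q_update(q_table, state, action, reward, next_state, actions, alpha=0.5, gamma=0.9):
    """One tabular Q-learning step toward the Bellman optimality target.

    target = reward + gamma * max_a' Q(next_state, a')
    Q(state, action) += alpha * (target - Q(state, action))
    """
    best_next = max(q_table[(next_state, a)] for a in actions)
    target = reward + gamma * best_next
    q_table[(state, action)] += alpha * (target - q_table[(state, action)])
    return q_table[(state, action)]

def greedy_action(q_table, state, actions):
    """pi*(s) in argmax_a Q(s, a): pick an action with the highest Q value."""
    return max(actions, key=lambda a: q_table[(state, a)])

q = defaultdict(float)  # unseen state-action pairs start at 0
acts = ["lr+", "lr-", "nep+", "nep-"]
q_update(q, "s0", "nep+", reward=1.0, next_state="s1", actions=acts)  # Q = 0.5
q_update(q, "s0", "nep+", reward=1.0, next_state="s1", actions=acts)  # Q = 0.75
print(greedy_action(q, "s0", acts))  # nep+
```

Because the target uses the maximum over next actions rather than the action actually taken next, the update is off-policy.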
To set action priorities in MDP Q-learning, a transition probability matrix is defined, which determines the probability of transitioning from one state to another when the agent selects an action. This matrix affects the agent's decision-making for both the initial and the subsequent action selections. In our implementation, all actions are assigned equal transition probabilities. This uniform setting encourages the agent to explore all possible actions without bias, which is intended to make the RL model more robust.
Algorithm 1 describes the procedure for optimizing the deep learning networks using Reinforcement Learning (i.e., Procedure OptimizeModel()). First, the agent selects an action from the predefined action list based on a specific policy. The agent then takes the action and trains the DL model with the updated hyperparameters. The model is evaluated to obtain the performance metrics, including the MSE loss and the accuracy, while the training time and memory usage are monitored. The reward is then calculated as the weighted sum of these metrics, and the action-value function \(Q\) is updated toward the value that maximizes the reward. The state is then updated, the agent selects the next action, and the hyperparameter set of the best state found so far is saved. This process iterates until it reaches a threshold (δ) that gives the desired performance; δ was determined experimentally by running multiple experiments with several iterations. At each step, the process updates the hyperparameters in the direction that maximizes the reward, and it returns the optimum set at the end. Each DL model is optimized sequentially and independently to reduce resource consumption and computational complexity. The optimized DL models are then evaluated and compared, and the results are tabulated and analyzed.
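A compact sketch of an OptimizeModel()-style loop follows. It is hypothetical and simplified: `toy_evaluate` stands in for training and scoring a DL model, the reward is the change in that score, Q is kept per action rather than per state-action pair, occasional ε-greedy exploration is added to avoid getting stuck, and δ is treated as a target score. None of these choices is taken from Algorithm 1, and the value sets are placeholders.

```python
import random

# Hypothetical ordered value sets standing in for Table 1.
SETS = {"nhu": [16, 32, 64, 128], "lr": [1e-4, 1e-3, 1e-2, 1e-1], "nep": [10, 20, 50, 100]}
ACTIONS = [(name, move) for name in SETS for move in (+1, -1)]

def apply_action(state, action):
    """Move one hyperparameter one position up or down its value set (clamped)."""
    name, move = action
    new_state = dict(state)
    new_state[name] = min(max(state[name] + move, 0), len(SETS[name]) - 1)
    return new_state

def decode(state):
    return {name: SETS[name][idx] for name, idx in state.items()}

def optimize_model(evaluate, delta, max_steps=100, alpha=0.5, epsilon=0.1, seed=0):
    """Hypothetical OptimizeModel()-style loop; returns (best config, best score)."""
    rng = random.Random(seed)
    q = {a: 0.0 for a in ACTIONS}
    initial_round = rng.sample(ACTIONS, len(ACTIONS))  # try every action once
    state = {name: 0 for name in SETS}
    score = evaluate(decode(state))
    best_state, best_score = state, score
    for t in range(max_steps):
        if t < len(initial_round):
            action = initial_round[t]
        elif rng.random() < epsilon:
            action = rng.choice(ACTIONS)  # occasional exploration
        else:
            top = max(q.values())
            action = rng.choice([a for a in ACTIONS if q[a] == top])
        next_state = apply_action(state, action)
        new_score = evaluate(decode(next_state))
        reward = new_score - score  # positive: moved toward a better configuration
        q[action] += alpha * (reward - q[action])
        state, score = next_state, new_score
        if score > best_score:
            best_state, best_score = state, score
        if best_score >= delta:  # threshold reached
            break
    return decode(best_state), best_score

TARGET = {"nhu": 64, "lr": 1e-3, "nep": 50}  # hypothetical optimum of the toy objective

def toy_evaluate(config):
    """Stand-in for training and evaluating a DL model: 0 at TARGET, negative elsewhere."""
    return -sum(abs(SETS[n].index(config[n]) - SETS[n].index(TARGET[n])) for n in SETS)

best_config, best_score = optimize_model(toy_evaluate, delta=0, max_steps=200)
```

With real models, `evaluate` would train \({D}_{\upmu }\) with the decoded configuration and return the weighted, normalized objectives.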
Throughout the optimization process, the hyperparameters and the four objectives (loss, accuracy, time, and memory) are logged for each state. The four objectives are normalized, with positive and negative values indicating improvement or decline compared to the mean, and the reward is a weighted sum of all the normalized objectives. During the process, the agent aims to select actions that increase the reward. Actions that move toward the optimum hyperparameter set are rewarded, while actions that increase the distance between the current and the optimal hyperparameter sets (i.e., actions in the wrong direction) are penalized.
Figure 2 illustrates a sample hyperparameter log of the LSTM model during 13 consecutive states of optimization. Figure 2 demonstrates the hyperparameter configurations, their respective normalized objectives, and the reward. The agent updates one of the hyperparameters at a time in each state (highlighted in gray) following the actions’ definition presented earlier in Table 1. The reward is calculated and logged; rewarded (highlighted in green) and penalized (highlighted in red) actions have positive and negative values, respectively.
These reward values are logged in the Q-table for the action taken in each state, as shown in Fig. 3. As demonstrated in Fig. 3 (a), the algorithm first runs an initial round in which it randomly selects a different action in every state, covering all the predefined actions. After this round, the most rewarding action is selected for the succeeding states, as can be seen in Fig. 3 (b). Thereafter, in every succeeding state the most rewarding action is selected and the Q-table is updated accordingly. For instance, increasing the number of epochs (i.e., Action #13) is the most rewarding action in the initial round; hence, it is reselected in succeeding states, as demonstrated in Fig. 3 (b), until a penalty on the reward is encountered. Then, another action is selected following the transition probability matrix. An action keeps being reselected as long as its reward continues to increase.
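The selection rule described above can be sketched as three small functions: an initial round that tries every action once in random order, a greedy pick of the most rewarding action, and a rule that keeps the current action while it is rewarded and otherwise switches to another action chosen uniformly at random (matching the equal transition probabilities). Interpreting "continues to increase" as "still receives a positive reward" is an assumption, as are the function names and action labels.

```python
import random

def initial_round(actions, rng):
    """Try every predefined action once, in random order."""
    return rng.sample(actions, len(actions))

def most_rewarding(q_values):
    """Action with the highest logged reward, e.g. after the initial round."""
    return max(q_values, key=q_values.get)

def next_action(prev_action, reward, actions, rng):
    """Reselect prev_action while it is rewarded; otherwise switch uniformly at random."""
    if reward > 0:
        return prev_action
    others = [a for a in actions if a != prev_action]
    return rng.choice(others)  # every other action is equally likely

rng = random.Random(7)
actions = ["lr+", "lr-", "nep+", "nep-", "d+", "d-"]
order = initial_round(actions, rng)
print(most_rewarding({"lr+": 0.1, "nep+": 0.9, "d-": -0.2}))  # nep+
print(next_action("nep+", 0.8, actions, rng))    # nep+ (still rewarded)
print(next_action("nep+", -0.3, actions, rng))   # some action other than nep+
```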
4 Experiments and results
4.1 Dataset
The proposed RL-ECGNet framework uses the MIT-BIH arrhythmia database [24], which includes 48 two-channel ambulatory ECG recordings obtained from 47 subjects studied by the BIH Arrhythmia Laboratory between 1975 and 1979. Each recording is 30 min long and sampled at 360 samples per second per channel. The data were annotated by two or more cardiologists, and disagreements were resolved. Finally, computer-readable reference annotations were obtained for each beat [25] in WFDB format. The database includes nine arrhythmia classes:
1. R: Right bundle branch block beats
2. L: Left bundle branch block beats
3. V: Premature ventricular contractions
4. A: Atrial premature beats
5. F: Fusion of ventricular and normal beats
6. N: Normal beats
7. Q: Paced beats
8. U (/): Unknown/unclassifiable beats
9. S: Supraventricular premature beats
Figure 4 shows the distribution of the Arrhythmia classes in the dataset. As can be seen in Fig. 4, the dataset is extremely imbalanced. This imbalance makes it very challenging to train the predictive models on the minority classes, which represent rare arrhythmia occurrences. Additionally, extensive fine-tuning is required to reach the optimal configuration of predictive models when working with this type of data.
To alleviate the challenges of manual fine-tuning and sample scarcity, four predictive models (i.e., MLP, GRU, LSTM, and CNN) are trained on all nine arrhythmia classes, and RL is used to automatically optimize their configurations while minimizing memory and time consumption.
4.2 Environment setup
The predictive models were trained and tested on Google Colab with a Colab Pro+ subscription and a High-RAM runtime, providing 51 GB of system RAM and 226 GB of disk space.
4.3 Baseline models
To evaluate the performance of the proposed RL-ECGNet model, several manually fine-tuned baseline models were used. These models have the same architectures as the models optimized by the proposed RL method. Table 2 summarizes the hyperparameter configurations of the baseline models.
4.4 Deep reinforcement learning with multi-models’ prediction results
In this section, the configurations of the MLP, GRU, LSTM, and CNN models were optimized using the proposed RL algorithm. A stratified hold-out evaluation approach was followed, with 80% of the dataset used for training and 20% for evaluation, so that the class distribution of each split is proportional to that of the full dataset. Table 3 shows the initial and final hyperparameter sets for the MLP, CNN, LSTM, and GRU models.
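A stratified 80/20 hold-out split can be sketched by splitting each class separately. This pure-Python version is illustrative; the toy labels and the function name are hypothetical, and in practice scikit-learn's `train_test_split` with its `stratify` argument performs the same kind of split.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_fraction=0.2, seed=42):
    """Return (train_idx, test_idx) with each class split in the same proportion."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    train, test = [], []
    for idx_list in by_class.values():
        rng.shuffle(idx_list)
        n_test = round(len(idx_list) * test_fraction)
        test.extend(idx_list[:n_test])
        train.extend(idx_list[n_test:])
    return sorted(train), sorted(test)

labels = ["N"] * 900 + ["V"] * 80 + ["F"] * 20  # hypothetical imbalanced labels
train_idx, test_idx = stratified_split(labels)
print(sum(labels[i] == "F" for i in test_idx))  # 4
```

Because of rounding, a class with only one or two samples may end up with no test samples at all, which is worth checking for the rarest arrhythmia classes.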
Equation (10) is used to report the performance improvement of the proposed RL optimized models against the baseline models.
where \({\overline{o} }_{i}\) is the performance metric of the RL optimized model, and \({o}_{i}\) is the performance metric of the manually optimized baseline model.
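Eq. (10) is not reproduced here. Because the results are reported as "times" factors both for accuracy gains and for reductions in loss, time, and memory, the sketch assumes a ratio whose orientation depends on whether higher or lower values are better. The function name and input numbers are illustrative, not results from Table 4.

```python
def improvement_factor(rl_value, baseline_value, higher_is_better=True):
    """Ratio-style improvement of an RL-optimized model over its baseline.

    Accuracy (higher is better): rl / baseline.
    Loss, time, memory (lower is better): baseline / rl.
    Either way, a value above 1 means the RL-optimized model is better.
    """
    if higher_is_better:
        return rl_value / baseline_value
    return baseline_value / rl_value

print(improvement_factor(0.90, 0.75))                           # ≈ 1.2 (accuracy)
print(improvement_factor(50.0, 100.0, higher_is_better=False))  # 2.0 (training time)
```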
Table 4 presents the overall prediction results of the four DL models optimized with the proposed RL method against the manually optimized baseline models, in terms of accuracy, MSE loss, training time, and memory usage. The proposed RL optimization method is also compared with state-of-the-art methods using a different number of arrhythmia classes [8, 10] and with a method using the same number of classes [13].
As can be seen in Table 4, the proposed RL optimization outperformed manual optimization for all DL models and reduced resource consumption for all of them except the LSTM. Among all DL models, LSTM with RL optimization achieved the highest prediction accuracy of 96.41%. CNN and GRU also achieved relatively high classification accuracies of 95.82% and 91.70%, respectively, with RL optimization. The deep learning architectures LSTM, CNN, and GRU are well-suited to processing sequential data such as ECG signals for arrhythmia classification. These models capture the temporal aspects and learn hierarchical features of the ECG data more effectively, which likely explains why they outperformed the simpler MLP architecture.
Considering accuracy, the proposed RL optimization method achieved a performance improvement of 1.28, 1.32, 1.39, and 1.33 times for MLP, CNN, LSTM, and GRU, respectively. In addition, it reduced the loss by 2.45, 5.46, 7.61, and 3.39 times for MLP, CNN, LSTM, and GRU, respectively. It also reduced training time by 1.925, 2.727, and 1.36 times and memory usage by 1.179, 1.815, and 1.359 times for MLP, CNN, and GRU, respectively. Hence, the proposed RL optimization method demonstrated efficient resource usage by reducing training time and memory. The only exception is the LSTM model, for which manual optimization consumed fewer resources than the proposed method; this is offset by its noticeably larger accuracy improvement.
Finally, the state-of-the-art studies on ECG arrhythmia classification used for comparison differ in the number of classes. Against the methods with a different number of arrhythmia classes [8, 10], the proposed method achieved higher accuracy than all of the compared models. For a comparison with the same number of classes, the proposed method is evaluated against a 12-lead ECG 2D cascaded convolutional neural network (CCNN) [13] that, like the proposed method, classifies arrhythmias into 9 classes. The LSTM optimized with the proposed RL method achieved higher accuracy than the 2D CCNN [13]. Although the two methods perform similarly, the computational complexity of the 2D CCNN is expected to be high because of the high dimensionality of its input data, its intricate network architecture, and its manual optimization.
5 Discussion
In this paper, RL was deployed to optimize Deep Learning models for automatic multi-class arrhythmia prediction under class imbalance. The proposed RL-ECGNet framework formulates the optimization as a Markov Decision Process (MDP) that tunes the deep networks on two levels: training and network architecture. The RL model uses a multi-objective reward signal composed of accuracy, loss, time, and memory to jointly optimize performance and resource consumption. The MDP agent iteratively selects rewarded actions that maximize this multi-objective function.
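How a single reward can combine accuracy, loss, time, and memory is sketched below. The exact reward formulation of RL-ECGNet is not given in this section, so this is a minimal sketch under stated assumptions: the reward is taken to be a weighted sum of terms measured relative to a baseline run, and the names RunMetrics and multi_objective_reward, as well as the weights, are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class RunMetrics:
    """Metrics from one training run of a DL model (hypothetical container)."""
    accuracy: float   # validation accuracy in [0, 1]
    loss: float       # validation MSE loss
    time_s: float     # training time in seconds
    memory_mb: float  # memory used during training, in MB


def multi_objective_reward(run, baseline, weights=(1.0, 0.5, 0.25, 0.25)):
    """Hypothetical weighted reward: accuracy gains add to the reward, while
    increases in loss, training time, and memory (relative to a baseline run)
    are penalized. The weights are illustrative, not values from the paper."""
    w_acc, w_loss, w_time, w_mem = weights
    return (
        w_acc * (run.accuracy - baseline.accuracy)
        - w_loss * (run.loss - baseline.loss) / baseline.loss
        - w_time * (run.time_s - baseline.time_s) / baseline.time_s
        - w_mem * (run.memory_mb - baseline.memory_mb) / baseline.memory_mb
    )
```

With a baseline of RunMetrics(0.70, 0.20, 100.0, 500.0), a run that is more accurate but slower and more memory-hungry, such as RunMetrics(0.95, 0.05, 150.0, 600.0), still earns a positive reward (about 0.45). This mirrors the LSTM case above, where a large accuracy gain outweighed higher resource use; the weights control how strongly such trade-offs are penalized.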
The proposed RL-ECGNet framework was evaluated on four DL models, namely MLP, CNN, LSTM, and GRU, which are commonly used in recent research and have proven effective. Hence, the experiments in this study focused on optimizing these DL models automatically while reducing their computational complexity and, consequently, their resource consumption.
The proposed RL optimization method achieved accuracies ranging from 88.45% to 96.41% for the four DL models, with accuracy improvements ranging from 1.28 to 1.39 times. Moreover, it reduced resource consumption by factors ranging from 1.36 to 1.925 for training time and from 1.179 to 1.815 for memory usage. The performance of the proposed RL-ECGNet framework is comparable to that of the state-of-the-art method evaluated on 9 arrhythmia classes [13], with slightly higher accuracy and efficient resource usage. Moreover, compared with methods evaluated on different numbers of arrhythmia classes, the proposed method still outperformed some of the reviewed studies on the classification of minority classes [8, 10].
6 Conclusion
Cardiac arrhythmia is a potentially fatal clinical condition characterized by abnormal heart rhythms. ECG signals have long been used to detect and classify such abnormalities for the early diagnosis of arrhythmia. Arrhythmia ECG datasets are highly imbalanced, with minority classes that represent rare arrhythmia conditions. Although existing automated multi-class arrhythmia detection frameworks achieve high overall detection and classification performance, their performance on the minority classes remains low.
To this end, we propose RL-ECGNet, a Reinforcement Learning-based optimization framework for multi-class arrhythmia detection using ECG signals. The proposed framework processes raw ECG signals to extract morphological features. RL is then used to automatically optimize the training and network hyperparameters of the Deep Learning models. Unlike recent work that fine-tunes a limited set of hyperparameters, the proposed framework automatically optimizes both the training and network architecture hyperparameters using a Markov Decision Process (MDP).
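The idea of an MDP agent searching over training and network hyperparameters can be illustrated with tabular Q-learning, the algorithm implemented by the cited mdptoolbox QLearning class [23]. This is a minimal, self-contained sketch under stated assumptions, not the RL-ECGNet implementation: the two-hyperparameter grid (learning rate and hidden units), the move-one-step action set, and the surrogate_score function that stands in for training and scoring a DL model are all hypothetical.

```python
import random
from itertools import product

# Hypothetical discrete search space (not the paper's actual hyperparameter grid).
LEARNING_RATES = [1e-4, 1e-3, 1e-2]
HIDDEN_UNITS = [32, 64, 128]
STATES = list(product(range(len(LEARNING_RATES)), range(len(HIDDEN_UNITS))))
# Action 0 keeps the current configuration; the others move one index by one step.
ACTIONS = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]


def surrogate_score(state):
    """Made-up stand-in for training a DL model and scoring it with a
    multi-objective reward. Peaks (score 0) at lr=1e-3 with 64 hidden units."""
    lr_idx, units_idx = state
    return -abs(lr_idx - 1) - 0.5 * abs(units_idx - 1)


def step(state, action):
    """Apply an action, clamping indices to the grid; return (next_state, reward)."""
    nxt = (
        min(max(state[0] + action[0], 0), len(LEARNING_RATES) - 1),
        min(max(state[1] + action[1], 0), len(HIDDEN_UNITS) - 1),
    )
    return nxt, surrogate_score(nxt)


def q_learning(episodes=300, steps=10, alpha=0.5, gamma=0.9, epsilon=0.2, seed=0):
    """Tabular epsilon-greedy Q-learning over the hyperparameter grid."""
    rng = random.Random(seed)
    q = {(s, a): 0.0 for s in STATES for a in range(len(ACTIONS))}
    for _ in range(episodes):
        state = rng.choice(STATES)
        for _ in range(steps):
            if rng.random() < epsilon:
                a = rng.randrange(len(ACTIONS))  # explore
            else:
                a = max(range(len(ACTIONS)), key=lambda i: q[(state, i)])  # exploit
            nxt, reward = step(state, ACTIONS[a])
            best_next = max(q[(nxt, i)] for i in range(len(ACTIONS)))
            # Standard Q-learning update toward the bootstrapped target.
            q[(state, a)] += alpha * (reward + gamma * best_next - q[(state, a)])
            state = nxt
    return q


def greedy_config(q, start, max_steps=5):
    """Follow the learned greedy policy from `start`; return the final hyperparameters."""
    state = start
    for _ in range(max_steps):
        a = max(range(len(ACTIONS)), key=lambda i: q[(state, i)])
        state, _ = step(state, ACTIONS[a])
    return LEARNING_RATES[state[0]], HIDDEN_UNITS[state[1]]


if __name__ == "__main__":
    q_table = q_learning()
    print(greedy_config(q_table, start=(0, 0)))
```

Because the reward is the score of the configuration the agent moves to, keeping the best configuration becomes the highest-value action once it is reached, so the greedy policy settles there. In a real setting, each call to surrogate_score would be replaced by an expensive training run, which is why the number of episodes and steps matters for resource consumption.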
The RL-ECGNet framework was validated on 9 arrhythmia classes from the well-known MIT-BIH Arrhythmia Database. For evaluation, the processed ECG signals are fed to four DL models, namely MLP, CNN, LSTM, and GRU, which are then trained and optimized. Not only are the training and network hyperparameters optimized, but resource consumption, specifically training time and memory usage, is also minimized. The performance of the optimized DL models was tabulated and compared; with the described environment setup, the LSTM model outperformed the other models, yielding an average accuracy of 96.41%. In addition, the proposed RL method yielded an average accuracy improvement of 1.33 times across the four models. Moreover, it achieved average reductions of 1.671 and 1.451 times in training time and memory usage, respectively, across the MLP, CNN, and GRU models. In conclusion, optimizing DL models with RL performs well on multi-class arrhythmia classification with imbalanced datasets while efficiently reducing resource consumption.
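As a quick arithmetic check, the ranges stated in the discussion and the averages stated in the conclusion can be recomputed from the per-model factors reported in the results. The sketch below does this for the accuracy and memory factors; the dictionary and function names are illustrative. It also makes explicit that the 1.451 memory average spans only the three models for which RL reduced memory (MLP, CNN, and GRU).

```python
from statistics import mean

# Per-model factors reported in the results (RL-optimized vs. manually optimized).
ACCURACY_IMPROVEMENT = {"MLP": 1.28, "CNN": 1.32, "LSTM": 1.39, "GRU": 1.33}
# Memory reduction is reported only for MLP, CNN, and GRU; the LSTM used more memory with RL.
MEMORY_REDUCTION = {"MLP": 1.179, "CNN": 1.815, "GRU": 1.359}


def summarize(factors):
    """Return (min, max, mean) of per-model factors, rounded to 3 decimals."""
    values = list(factors.values())
    return round(min(values), 3), round(max(values), 3), round(mean(values), 3)


if __name__ == "__main__":
    print(summarize(ACCURACY_IMPROVEMENT))  # (1.28, 1.39, 1.33)
    print(summarize(MEMORY_REDUCTION))      # (1.179, 1.815, 1.451)
```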
Data availability
All the datasets used in this manuscript are published and publicly available for research. References to data sources are provided in the manuscript.
References
[1] Crispi F, Martinez JM (2017) Arrhythmias. In: Obstetric Imaging: Fetal Diagnosis and Care, 2nd edn. StatPearls Publishing, pp 418–425.e1
[2] Centers for Disease Control and Prevention (2020) Atrial Fibrillation | cdc.gov. https://www.cdc.gov/heartdisease/atrial_fibrillation.htm. Accessed 10 Jul 2022
[3] Gacek A (2014) An introduction to ECG signal processing and analysis. In: ECG Signal Processing, Classification and Interpretation: A Comprehensive Framework of Computational Intelligence. Springer-Verlag London Ltd, pp 21–46
[4] Nainwal A, Kumar Y, Jha B (2022) Arrhythmia classification based on improved monarch butterfly optimization algorithm. J King Saud Univ-Comput Inf Sci 34(8):5100–5109. https://doi.org/10.1016/j.jksuci.2022.01.002
[5] Sowmya S, Jose D (2022) Contemplate on ECG signals and classification of arrhythmia signals using CNN-LSTM deep learning model. Meas Sensors 24:100558. https://doi.org/10.1016/j.measen.2022.100558
[6] Chumrit N, Weangwan C, Aunsri N (2020) ECG-based arrhythmia detection using average energy and zero-crossing features with support vector machine. In: InCIT 2020 - 5th International Conference on Information Technology, pp 282–287. https://doi.org/10.1109/InCIT50588.2020.9310931
[7] Kanani P, Padole M (2020) ECG heartbeat arrhythmia classification using time-series augmented signals and deep learning approach. Procedia Comput Sci 171:524–531. https://doi.org/10.1016/j.procs.2020.04.056
[8] Li Y, Qian R, Li K (2022) Inter-patient arrhythmia classification with improved deep residual convolutional neural network. Comput Methods Programs Biomed 214:106582. https://doi.org/10.1016/j.cmpb.2021.106582
[9] Mohonta SC, Motin MA, Kumar DK (2022) Electrocardiogram based arrhythmia classification using wavelet transform with deep learning model. Sens Bio-Sensing Res 37:100502. https://doi.org/10.1016/j.sbsr.2022.100502
[10] Falaschetti L, Alessandrini M, Biagetti G, Crippa P, Turchetti C (2022) ECG-based arrhythmia classification using recurrent neural networks in embedded systems. Procedia Comput Sci 207:3473–3481. https://doi.org/10.1016/j.procs.2022.09.406
[11] Sepahvand M, Abdali-Mohammadi F (2022) A novel method for reducing arrhythmia classification from 12-lead ECG signals to single-lead ECG with minimal loss of accuracy through teacher-student knowledge distillation. Inf Sci (NY) 593:64–77. https://doi.org/10.1016/j.ins.2022.01.030
[12] Luo X, Yang L, Cai H, Tang R, Chen Y, Li W (2021) Multi-classification of arrhythmias using a HCRNet on imbalanced ECG datasets. Comput Methods Programs Biomed 208:106258. https://doi.org/10.1016/j.cmpb.2021.106258
[13] Yang X, Zhang X, Yang M, Zhang L (2021) 12-Lead ECG arrhythmia classification using cascaded convolutional neural network and expert feature. J Electrocardiol 67:56–62. https://doi.org/10.1016/j.jelectrocard.2021.04.016
[14] Lu Y et al (2021) Automated arrhythmia classification using depthwise separable convolutional neural network with focal loss. Biomed Signal Process Control 69:102843. https://doi.org/10.1016/j.bspc.2021.102843
[15] Pal A, Srivastva R, Singh YN (2021) CardioNet: an efficient ECG arrhythmia classification system using transfer learning. Big Data Res 26:100271. https://doi.org/10.1016/j.bdr.2021.100271
[16] Ebrahimi Z, Loni M, Daneshtalab M, Gharehbaghi A (2020) A review on deep learning methods for ECG arrhythmia classification. Exp Syst Appl: X 7:100033. https://doi.org/10.1016/j.eswax.2020.100033
[17] Mian Qaisar S, Khan SI, Srinivasan K, Krichen M (2023) Arrhythmia classification using multirate processing metaheuristic optimization and variational mode decomposition. J King Saud Univ-Comput Inf Sci 35(1):26–37. https://doi.org/10.1016/j.jksuci.2022.05.009
[18] Houssein EH, Ibrahim IE, Neggaz N, Hassaballah M, Wazery YM (2021) An efficient ECG arrhythmia classification method based on Manta ray foraging optimization. Expert Syst Appl 181:115131. https://doi.org/10.1016/j.eswa.2021.115131
[19] Atal DK, Singh M (2020) Arrhythmia classification with ECG signals based on the optimization-enabled deep convolutional neural network. Comput Methods Programs Biomed 196:105607. https://doi.org/10.1016/j.cmpb.2020.105607
[20] Kumar A, Kumar SA, Dutt V, Dubey AK, García-Díaz V (2022) IoT-based ECG monitoring for arrhythmia classification using Coyote Grey Wolf optimization-based deep learning CNN classifier. Biomed Signal Process Control 76. https://doi.org/10.1016/j.bspc.2022.103638
[21] Madan P, Singh V, Singh DP, Diwakar M, Kishor A (2022) Denoising of ECG signals using weighted stationary wavelet total variation. Biomed Signal Process Control 73:103478. https://doi.org/10.1016/j.bspc.2021.103478
[22] Hu Z, Wan K, Gao X, Zhai Y (2019) A dynamic adjusting reward function method for deep reinforcement learning with adjustable parameters. Math Probl Eng 2019. https://doi.org/10.1155/2019/7619483
[23] Cordwell SAW. mdptoolbox.mdp — Python Markov Decision Process Toolbox 4.0-b4 documentation. https://pymdptoolbox.readthedocs.io/en/latest/_modules/mdptoolbox/mdp.html#QLearning. Accessed 26 Jul 2023
[24] Moody GB, Mark RG (2001) The impact of the MIT-BIH arrhythmia database. IEEE Eng Med Biol Mag 20(3):45–50. https://doi.org/10.1109/51.932724
[25] Goldberger AL et al (2000) PhysioBank, PhysioToolkit, and PhysioNet: components of a new research resource for complex physiologic signals. Circulation. https://doi.org/10.1161/01.cir.101.23.e215
Acknowledgements
This research is sponsored by the Office of Research and Sponsored Programs at Abu Dhabi University.
Author information
Contributions
Heba Ismail: Conceptualization, Methodology, Software, Validation, Formal analysis, Investigation, Resources, Data curation, Writing – original draft, Writing – review & editing, Visualization, Supervision, Project administration, Funding acquisition.
M. Adel Serhani: Conceptualization, Methodology, Formal analysis, Writing – review & editing.
Nada Hussein: Implementation, Writing – original draft, Writing – review & editing, Visualization.
Mourad Elhadef: Writing – review & editing.
Additional information
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Ismail, H., Serhani, M.A., Hussein, N.M. et al. RL-ECGNet: resource-aware multi-class detection of arrhythmia through reinforcement learning. Appl Intell 53, 30927–30939 (2023). https://doi.org/10.1007/s10489-023-05147-6
DOI: https://doi.org/10.1007/s10489-023-05147-6