1 Introduction

Arrhythmia is a cardiac disorder that results in an irregular heart rhythm due to abnormal electrical activity of the heart. Globally, the expected prevalence of arrhythmia reaches up to 5% of the general population [1]. According to the Centers for Disease Control and Prevention (CDC) [2], an estimated 12.1 million people in the United States are expected to have atrial fibrillation, the most common type of arrhythmia, in 2030. Hence, there is an urgent need for early detection of arrhythmia as a first step of prevention.

Electrocardiography (ECG) signals are biomedical signals that reflect the electrical activity of the heart [3]. ECG signals have been widely used for arrhythmia detection and classification. Many research works focus on the binary classification of ECG signals to predict whether a signal is arrhythmic or not [4,5,6]. Although such classification provides an essential first step towards diagnosis, it remains limited. Hence, more recent research addresses Multi-Class Arrhythmia Detection (MCAD). These classification methods consider varying numbers of arrhythmia classes, for instance, 5 classes [7,8,9,10], 7 classes [11], 9 classes [12, 13], 17 classes [14], or even up to 29 classes [15].

Multi-class arrhythmia prediction is a challenging research problem due to data imbalance. The scarcity of rare arrhythmia samples hampers classification performance on the minority classes and makes the already resource-intensive hyperparameter optimization harder. While several training and network structure hyperparameters require fine-tuning during optimization, the majority of recent studies optimize only a limited set of training parameters and tune the network architecture manually [7, 8, 11, 13,14,15], with only a few studies considering resource consumption [10, 11]. However, resource efficiency, which encompasses training time, inference time, and memory usage, is crucial while maintaining high performance. This is especially vital for real-time applications and resource-constrained devices such as those deployed in Internet of Medical Things (IoMT) systems. To this end, we propose a more comprehensive approach for resource-aware, optimized multi-class arrhythmia detection that considers minority classes. The contributions of the proposed RL-ECGNet framework can be summarized as follows:

  • Deployment of Reinforcement Learning (RL) for optimizing both training and network architecture hyperparameters.

  • Resource consumption optimization considering training time and memory usage while maintaining high classification performance.

  • Evaluation and comparison of four Deep Learning (DL) models to achieve the highest classification performance, considering the arrhythmia minority classes.

This paper is structured as follows: Section 2 reviews the literature related to deep learning models for arrhythmia prediction from ECG signals as well as automated fine-tuning. Section 3 presents the proposed method for automated reinforcement learning-based prediction of arrhythmia on an imbalanced ECG dataset. Section 4 describes the experimental setup and testing of the proposed framework and analyzes the obtained results. Finally, Section 5 discusses the findings.

2 Related work

Deep Learning (DL) models have been increasingly used for arrhythmia detection [16]. Despite the efficiency of DL models in feature extraction and classification problems, optimized multi-class arrhythmia detection from ECG signals remains an ongoing research problem.

Arrhythmia classification datasets are highly imbalanced, leading to classification difficulty and low classification performance on the minority classes. In addition, model optimization is a critical process for performance improvement. This process becomes more challenging due to data scarcity. Some approaches optimize the feature selection to improve the classification performance. For instance, Monarch Butterfly [4, 17], Manta Ray [17, 18], Emperor Penguin [17], Bat-Rider [19], as well as Coyote Grey Wolf [20] optimization methods have been deployed with ML models [17, 18], or DL models such as ANN [17] and CNN [4, 19, 20]. Although feature selection methods improve classification performance, this approach fully addresses neither the data imbalance issue nor the network and training optimization, which are vital for performance optimization.

While some research drops the minority classes during evaluation [11], other proposed methods focus on improving the classification performance of the minority classes. For instance, Luo et al. [12] proposed a Hybrid Convolutional Recurrent Neural Network (HCRNet) for automatic classification of 9 arrhythmia classes including the imbalanced data samples. The time-series extracted heartbeat signals are fed into the hybrid CNN and RNN layers, namely, GRU and LSTM layers. The hybrid model consists of 4 blocks, a convolutional block for feature extraction, a separable convolutional block with a GRU layer for faster feature extraction and knowledge retention, a separable convolutional block with an LSTM layer, and a final convolutional block. The outputs of these blocks are concatenated and fed into fully connected layers that produce the final classifications. The model’s architecture was manually optimized, while the training optimization method was not stated.

To address the training optimization, the Adam, Stochastic Gradient Descent (SGD), and Stochastic Gradient Descent with Momentum (SGDM) optimizers are commonly used in the reviewed literature. For example, Kanani et al. [7] proposed a modified ResNet for automatic classification of 5 arrhythmia classes. The proposed method uses an augmented dataset to improve the classification performance. This method deploys the Adam optimizer for the training of the model; however, the model’s architecture is optimized manually. Although this method achieved a high classification accuracy, it is limited to inputs of certain amplitudes and frequencies.

To provide more comprehensive results, Yang et al. [13] proposed a 12-lead ECG classifier considering the multi-directional heart activity. The proposed method deployed a Cascaded Convolutional Neural Network (CCNN) arrhythmia multi-label classifier. Features are extracted separately from the single leads using a 1-D CNN and are then cascaded based on the spatio-temporal correlation between the leads. A ResNet is then used to classify the 2-D feature maps into 9 arrhythmia classes, producing the first classification probability. This classification probability is combined with the classification probabilities obtained from training a random forest on expert features. The expert features include morphological, HRV, statistical, frequency, and AF features extracted after denoising the ECG signals. Finally, this approach deploys the Adam optimizer for automatic fine-tuning of the hyperparameters. Despite the representativeness of the multi-lead data and the comprehensiveness of this approach, the reported classification accuracy was 86.5%, which can be improved further.

Although multi-lead classification models provide more comprehensive and representative results, their high computational resource consumption is impractical for resource-limited systems such as those intended for use within Internet of Medical Things (IoMT) systems [11]. Therefore, Sepahvand et al. [11] proposed a teacher-student model for arrhythmia classification that uses single-lead ECG data with minimal performance loss. The teacher model is a CNN with an advanced architecture trained on multi-lead data, while the student model is a simpler, lightweight CNN trained on single-lead data under the supervision of the teacher network. By deploying this teacher-student knowledge distillation (KD) approach, the simpler and more compressed student model receives the knowledge of the multi-lead signals by imitating the teacher model during training. The internal feature map and output response knowledge are passed to the student model, balancing data representativeness against resource consumption. Due to the imbalance of the arrhythmia ECG data, the 4 minority classes (out of 11 classes) were excluded during evaluation. Moreover, the Adam optimizer was deployed for fine-tuning the weights of the models, but the network architecture was manually optimized.

Pal et al. [15] proposed CardioNet, a Transfer learning-based model for automatic arrhythmia classification addressing data imbalance. The proposed method deploys a pre-trained DenseNet to optimize the feature extraction and classification results. This transfer learning approach speeds up the training process while fine-tuning its learning on the ECG arrhythmia data through using the DenseNet weights. Moreover, using transfer learning improves the classification performance of minority classes since it does not require balanced or identically distributed data. The model uses an augmented dataset of 2D images of ECG signals, and the model classifies these data samples into 29 arrhythmia classes. This method deployed SGD for optimizing the model’s training.

Other proposed approaches deploy Focal Loss to address the sample scarcity. Focal Loss drives the model’s attention towards the minority samples that are difficult to classify by reducing the weight of the majority samples that are easily classified [8]. For instance, Li et al. [8] proposed an improved deep residual convolutional neural network (ResNet) with the focal loss function. The proposed method applies overlapping segmentation to the MIT-BIH data samples to produce 5 s segments and re-labels them to increase the number of data samples of the minority classes. The improved ResNet is then used to classify the ECG segments into 5 classes. The use of identity mapping within the ResNet does not add extra parameters and hence contributes to reducing the computational complexity. Furthermore, the proposed method deploys Focal Loss by introducing a modulating factor into the Cross-Entropy (CE) loss function. The network’s architecture was manually optimized, and training was optimized using SGD. Although this method generalizes better, the performance metrics of the minority classes are still evidently lower than those of the common classes.
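The down-weighting effect of Focal Loss described above can be illustrated with a minimal sketch. It uses the commonly cited form of the loss, in which the cross-entropy term is scaled by \((1-p_t)^{\gamma}\); the function name, the default values of `gamma` and `alpha`, and the example probabilities are assumptions for illustration and are not taken from [8].

```python
import math

def focal_loss(probs, target, gamma=2.0, alpha=1.0):
    """Focal loss for a single sample.

    probs  : predicted class probabilities (assumed to sum to 1)
    target : index of the true class
    gamma  : focusing parameter; gamma = 0 recovers plain cross-entropy
    alpha  : optional class weight (assumed to be 1.0 here)
    """
    p_t = max(probs[target], 1e-12)  # guard against log(0)
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)

# Hypothetical predictions: a confidently correct (easy) sample and a
# poorly classified (hard) sample, both with true class 0.
easy = focal_loss([0.95, 0.03, 0.02], target=0)
hard = focal_loss([0.20, 0.50, 0.30], target=0)
```

In this sketch the easy sample contributes a loss several orders of magnitude smaller than the hard one, which is the mechanism that shifts attention toward rare, difficult beats.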

To improve the classification performance of imbalanced datasets with the minority classes, Lu et al. [14] proposed a Depthwise Separable with Focal Loss Convolutional Neural Network (DSC-FL-CNN). The proposed network automatically classifies the data samples into 17 arrhythmia classes. The converted 2D ECG images are fed into the network, whose architecture was manually optimized, and training was optimized using SGD. The proposed method demonstrated slightly improved performance compared to CNNs, FL-CNNs, and DSC-CNNs.

Mohonta et al. [9] proposed deploying CNNs on spectrum images for automatic arrhythmia classification. Four spectral analysis techniques, namely, FFT, STFT, CWT-Grayscale, CWT-RGB, have been fed as the CNN inputs. The CNN’s architecture is manually optimized, while the training was optimized using SGDM. Although the CWT-RGB spectral analysis with the proposed CNN yielded high performance, such analysis adds overhead and computational complexity to the system.

While the aforementioned methods do not prioritize optimizing resource consumption and computational complexity, approaches designed for operation within IoMT or embedded systems require such optimization. For instance, the teacher-student architecture proposed in [11] optimized the resource consumption. To achieve this, the computational complexity, the model size (indicated by the number of parameters), and the memory usage of the student network are analyzed against the teacher model. Although the student model’s performance is lower than that of the teacher model, it consumes fewer resources and is 262.18 times more compressed. Nevertheless, despite this reduction, the model’s resource consumption is not optimal.

Falaschetti et al. [10] proposed an RNN LSTM-based architecture for automatic arrhythmia classification on embedded systems. Due to the limited resources available in embedded systems, this method focuses on optimizing memory, inference time, number of parameters, as well as accuracy. Throughout the evaluation of this approach, different RNN variants were tested. The RNN LSTM-based network’s performance was the best while optimizing the resources. However, the network’s architecture was still optimized manually. Moreover, although the results of this method outperform other architectures in terms of memory storage, number of parameters, and inference time, the reported accuracy needs further improvement.

To this end, in this research we aim to design and implement a comprehensive, resource-aware optimization framework for both network hyperparameters and architecture using Reinforcement Learning.

3 Materials and methods

To address the data scarcity and extreme class imbalance, we propose RL-ECGNet, a Reinforcement Learning (RL)-based architecture for classifying arrhythmia from raw ECG signals. Figure 1 depicts the proposed RL-ECGNet framework. The framework consists of two modules: an ECG signal acquisition and preprocessing module and a Reinforcement Learning module. The latter consists of three main components: (a) the Q-Table, (b) the Agent, and (c) the Environment. Each component is explained in detail in the subsequent sections.

Fig. 1 RL-ECGNet: Reinforcement learning-based arrhythmia classification framework

3.1 Module 1: signal acquisition and processing

The raw ECG signals are fed to the signal acquisition and processing module for preprocessing, cleaning, and feature extraction. To obtain the digital signal from the raw analog ECG signal, the signal is first sampled. Next, several successive processing tasks are applied to denoise the signal and attenuate the artifacts. This step is crucial due to the susceptibility of the ECG signal to various types of noises that highly impact the accuracy of the readings. These noises include electromyogram noise, motion artifacts, powerline interference, as well as baseline wandering [21]. This implementation focuses on attenuating powerline interference and baseline wandering since their distortions are the most impactful. The ECG signal wave features are then extracted during the signal processing. In order to diagnose the ECG signal, the QRS complex is detected, which represents key morphological features of the cardiac electrical activity. The heartbeat signal is then segmented, identifying the heartbeats in a series of samples. Finally, for classification purposes, temporal features of the ECG signal including the P-, R-, and T- waves as well as the pre- and post- RR intervals are extracted.
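The pre- and post-RR interval features mentioned above can be sketched as follows. This is a minimal illustration that assumes the R-peak positions have already been detected; the function name `rr_features` and the peak indices are hypothetical, and the 360 Hz sampling rate is the one quoted for the dataset in Section 4.1.

```python
FS = 360  # sampling rate in Hz (the rate quoted in Section 4.1)

def rr_features(r_peaks, fs=FS):
    """Return (pre_rr, post_rr) in seconds for each beat with two neighbours.

    r_peaks : sorted sample indices of detected R peaks
    """
    features = []
    for i in range(1, len(r_peaks) - 1):
        pre_rr = (r_peaks[i] - r_peaks[i - 1]) / fs   # time since previous beat
        post_rr = (r_peaks[i + 1] - r_peaks[i]) / fs  # time until next beat
        features.append((pre_rr, post_rr))
    return features

# Hypothetical R-peak sample indices (not taken from the dataset).
peaks = [100, 460, 820, 1000, 1540]
feats = rr_features(peaks)
```

A short pre-RR interval followed by a long post-RR interval, as for the fourth peak in this toy input, is the kind of pattern that distinguishes premature beats from normal ones.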

3.2 Module 2: reinforcement learning

The processed ECG signals are then passed to the Reinforcement Learning module. The proposed framework formulates the fine-tuning of the hyperparameters of the multiple DL predictive models as a Markov Decision Process (MDP), which is solved with an off-policy, value-based RL algorithm (Q-learning, see Section 3.2.3), where the agent updates the set of hyperparameter values based on the policy \((\pi )\). The targeted hyperparameters include the network architecture hyperparameters: activation function (af), hidden units (nhu), dropout (d), and weight initialization (nw), as well as the training hyperparameters: batch size (bsz), momentum (m), learning rate (lr), and epochs (nep). The Q-table consists of an action list that the agent deploys to configure and set the parameters of the deep networks. The action list includes increasing or decreasing the target hyperparameter value based on predefined sets, as illustrated in Table 1.

Table 1 Training and network architecture hyperparameter sets

The agent consists of the MDP process, which iteratively updates the policy based on the feedback obtained from the environment. The MDP process also keeps a record of the states, including the preceding, current, and the succeeding states. The actions (at) are deployed to the environment, which includes the DL models that are trained to produce predictions. The proposed framework deploys multiple different predictive deep networks. The predictive models include Convolutional Neural Network (CNN), Gated Recurrent Unit (GRU), Long Short-Term Memory (LSTM), and Multilayer Perceptron (MLP).

Throughout the training and evaluation processes of the DL models, the environment returns observations, namely, the reward (Rt) associated with the action to the agent for feedback and policy update. The reward function in this architecture is a multi-objective fusion of MSE loss, accuracy, as well as time and memory. These networks are optimized, evaluated and compared to get the best performance with minimum resource consumption. This process is iterative until the models reach convergence. An optimum model achieves the best fusion of low loss, high accuracy, reduced time and memory consumption.

3.2.1 Training and network architecture hyperparameter sets

The proposed RL-ECGNet framework optimizes the network at two levels, training and network architecture, where the agent configures different combinations of the hyperparameter sets. The MDP algorithm reduces the dimensionality of the hyperparameter search space through sequential processes by updating the value of a single hyperparameter at a time. Table 1 lists the hyperparameters and their value ranges. The value ranges include the well-known and validated values that are commonly used during the optimization of these hyperparameters.
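The single-hyperparameter, increase-or-decrease action scheme can be sketched as below. The value sets shown are placeholders chosen for illustration and are not the sets listed in Table 1; only four of the eight hyperparameters are included to keep the sketch short, and the names `ACTIONS`, `apply_action`, and `decode` are hypothetical.

```python
# Hypothetical value sets; the actual sets are those listed in Table 1.
HYPERPARAMETER_SETS = {
    "lr": [0.0001, 0.001, 0.01, 0.1],
    "bsz": [16, 32, 64, 128],
    "nhu": [32, 64, 128, 256],
    "d": [0.0, 0.2, 0.5],
}

# Two actions per hyperparameter: move one step up or down in its set.
ACTIONS = [(name, step) for name in HYPERPARAMETER_SETS for step in (+1, -1)]

def apply_action(state, action):
    """Return a new state (dict of indices into the value sets)."""
    name, step = action
    values = HYPERPARAMETER_SETS[name]
    new_state = dict(state)  # leave the previous state untouched
    # Clamp so the index never leaves the predefined set.
    new_state[name] = min(max(state[name] + step, 0), len(values) - 1)
    return new_state

def decode(state):
    """Translate index-based state into concrete hyperparameter values."""
    return {name: HYPERPARAMETER_SETS[name][i] for name, i in state.items()}

state = {"lr": 1, "bsz": 1, "nhu": 1, "d": 1}
next_state = apply_action(state, ("lr", +1))
```

Because each action changes exactly one index by one position, every state has at most two neighbours per hyperparameter, which is what keeps the search sequential rather than combinatorial.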

3.2.2 Deep learning network optimization through MDP

The proposed framework deploys the Markov Decision Process algorithm for automated reward-based optimization. Multiple Deep Learning models are trained and optimized throughout the process. Let \({D}_{\upmu }\) denote the Deep Learning algorithm with a set of training and network architecture hyperparameters \(\upmu\). The Deep Learning model \({D}_{\upmu }\) is trained with the set of hyperparameters configured by the agent to reach state \({S}_{t}={\mathrm{Decision}}_{\mathrm{t}-1}\left({\upmu }_{\mathrm{t}-1}\right),\) where \({\mathrm{Decision}}_{\mathrm{t}-1}\) is the decision of the agent at time \(t-1\). The agent chooses an action from a list of predefined actions as shown in Table 1. The agent’s action \({A}_{t}\) follows a policy \(\uppi \left(\mathrm{a},\mathrm{ s}\right)\) that aims to maximize the expected reward associated with the action that leads to the current state. The reward function \(R\) is measured and re-evaluated after the state transition under action \({A}_{t}\) at time \(t\). Since this framework deploys a multi-objective MDP, a fusion of performance measures is composed as the weighted sum of the performance metrics: loss, accuracy, time, and memory.

  • The model’s loss is determined using the Mean Square Error (MSE) metric shown in Eq. (1):

    $$MSE=\frac{\sum_{t=1}^{n}{\left({y}_{pt}-{y}_{t}\right)}^{2}}{n},$$
    (1)

    where \({y}_{pt}\) is the predicted class at time t, \({y}_{t}\) is the actual value at time t, and n is the number of observations.

  • The accuracy is evaluated as shown in Eq. (2):

    $$Accuracy= \frac{true\;positives}{total\;observations}$$
    (2)

  • The time of the Deep Learning model training and testing is evaluated as shown in Eq. (3):

    $$T= \sum \nolimits_{t=0}^{n}{\mathrm{e}}_{t}$$
    (3)

    where \({e}_{t}\) is the execution time at time step \(t\).

  • The memory resources consumption is the memory occupied during execution.
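The four objectives listed above can be measured with a short sketch such as the following. It is an illustration under stated assumptions rather than the measurement code used in the experiments: accuracy is computed as correctly classified samples over all samples, `measure` is a hypothetical helper, and `tracemalloc` only tracks memory allocated by the Python interpreter, so it is a stand-in for whatever memory monitor the environment provides.

```python
import time
import tracemalloc

def mse(y_pred, y_true):
    """Eq. (1): mean of squared differences between predictions and labels."""
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)

def accuracy(y_pred, y_true):
    """Eq. (2), read as correctly classified samples over all samples."""
    correct = sum(1 for p, t in zip(y_pred, y_true) if p == t)
    return correct / len(y_true)

def measure(fn, *args):
    """Run fn(*args) and return (result, elapsed_seconds, peak_bytes)."""
    tracemalloc.start()
    start = time.perf_counter()
    result = fn(*args)
    elapsed = time.perf_counter() - start
    _, peak = tracemalloc.get_traced_memory()  # (current, peak) in bytes
    tracemalloc.stop()
    return result, elapsed, peak

# Hypothetical class indices for four beats.
y_true = [0, 1, 2, 2]
y_pred = [0, 1, 1, 2]
loss = mse(y_pred, y_true)
acc = accuracy(y_pred, y_true)
result, elapsed, peak = measure(sum, [1, 2, 3])
```

In practice `fn` would be the training or evaluation routine of a DL model, so that a single call yields the time and memory terms that enter the reward alongside the loss and accuracy.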

The proposed optimization method describes a multi-objective problem with four objectives. The ranges of the four objectives differ; hence, each objective is normalized using the z-score normalization technique, described in Eq. (4). Positive and negative values indicate improvement and decline compared to the mean, respectively.

$${o}_{i}^{\prime}=\frac{{o}_{i}- \mu }{\upsigma }, \quad {o}_{i} \in \left\{MSE, Accuracy, Time, Memory\right\}$$
(4)

where \(\mu\) and \(\upsigma\) are the mean and the standard deviation of each objective, respectively.

The reward Rt is the weighted and normalized value of the four objectives as shown in Eq. (5):

$${R}_{t}= \sum_{i=1}^{n}{w}_{i}\, {o}_{i}^{\prime}, \quad \mathrm{s.t.}\ \sum_{j=1}^{n}{w}_{j}= 1$$
(5)
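Equations (4) and (5) can be combined in a small sketch. The sign handling is an assumption: so that positive normalized values indicate improvement for every objective, the z-scores of the three cost-type objectives (MSE, time, memory) are flipped here, which the text does not state explicitly. The logged values and the equal weights are hypothetical.

```python
import statistics

# Assumption: accuracy should rise and the other objectives should fall,
# so cost-type z-scores are sign-flipped to make "positive = improvement".
DIRECTION = {"accuracy": 1.0, "mse": -1.0, "time": -1.0, "memory": -1.0}

def z_scores(values):
    """Eq. (4): subtract the mean and divide by the standard deviation."""
    mu = statistics.mean(values)
    sigma = statistics.pstdev(values)
    if sigma == 0:  # a constant objective carries no information
        return [0.0] * len(values)
    return [(v - mu) / sigma for v in values]

def rewards(history, weights):
    """Eq. (5): weighted sum of normalized objectives, one reward per state."""
    if abs(sum(weights.values()) - 1.0) > 1e-9:
        raise ValueError("weights must sum to 1")
    normalized = {k: z_scores(v) for k, v in history.items()}
    n_states = len(next(iter(history.values())))
    return [
        sum(weights[k] * DIRECTION[k] * normalized[k][t] for k in history)
        for t in range(n_states)
    ]

# Hypothetical log of three visited states (not values from the experiments).
history = {
    "accuracy": [0.80, 0.90, 0.85],
    "mse": [0.30, 0.10, 0.20],
    "time": [10.0, 12.0, 11.0],
    "memory": [100.0, 100.0, 100.0],
}
weights = {"accuracy": 0.25, "mse": 0.25, "time": 0.25, "memory": 0.25}
r = rewards(history, weights)
```

In this toy log the second state has the best accuracy and loss but the longest training time, and it still receives the highest reward because two of the four equally weighted objectives improve.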

For the reward update calculation, the reward difference is considered as stated in Eq. (6) [22]. To guide and balance the agent’s decision process, a discount factor \(\upgamma\) is introduced, where \(\upgamma \in [0,1]\). This factor assigns weights to the instant and long-term rewards. Values closer to 0 prioritize the instant reward, and values closer to 1 prioritize long-term rewards.

$${R}_{t}^{\prime}={\upgamma }^{t^{\prime} - t }{R}_{t}$$
(6)
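As a worked illustration of Eq. (6), assume (hypothetically) a discount factor \(\upgamma = 0.9\), a reward \({R}_{t} = 1\), and a gap of two steps, \(t^{\prime} - t = 2\):

$${R}_{t}^{\prime} = {0.9}^{2} \times 1 = 0.81$$

With \(\upgamma = 0.5\), the same reward would shrink to \({0.5}^{2} \times 1 = 0.25\), which shows how smaller values of \(\upgamma\) favor the instant reward.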

3.2.3 Training with reinforcement learning

We define the state-action value function \(Q\) as the expected value at a certain state \(s\) after transition by action \(a\) under the policy \(\uppi\), as shown in Eq. (7) [22]. It provides insight into the quality of the new state of the model.

$${Q}_{\uppi }\left(s,a\right)= {E}_{\uppi }\left[{R}_{t} | {s}_{\mathrm{t}}=s, {a}_{\mathrm{t}}=a\right]$$
(7)

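The Q value is iteratively updated as described in Eq. (8) [22, 23], where \(\alpha\) is the learning rate of the update, \(r\) is the received reward, and \(s^{\prime}\) and \(a^{\prime}\) are the next state and action:

$$Q\left(s,a\right)\leftarrow Q\left(s,a\right) + \alpha \left[r+\upgamma \max_{a^{\prime}} Q\left(s^{\prime},a^{\prime}\right)-Q\left(s,a\right)\right]$$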
(8)
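The update in Eq. (8) corresponds to the standard tabular Q-learning step, sketched below. The dictionary-based Q-table, the state and action names, and the values of `alpha` and `gamma` are assumptions for illustration.

```python
from collections import defaultdict

def q_update(q, state, action, reward, next_state, actions,
             alpha=0.1, gamma=0.9):
    """Apply one Q-learning update to q[(state, action)] and return it."""
    # Best value reachable from the next state (the max over a' in Eq. (8)).
    best_next = max(q[(next_state, a)] for a in actions)
    td_target = reward + gamma * best_next
    q[(state, action)] += alpha * (td_target - q[(state, action)])
    return q[(state, action)]

# Hypothetical two-action table in which one future value is already known.
q = defaultdict(float)  # unseen (state, action) pairs default to 0.0
actions = ["inc_lr", "dec_lr"]
q[("s1", "inc_lr")] = 2.0
new_value = q_update(q, "s0", "inc_lr", reward=1.0,
                     next_state="s1", actions=actions)
```

Here the entry for `("s0", "inc_lr")` moves from 0 to roughly 0.28, that is, one tenth of the way toward the target \(1.0 + 0.9 \times 2.0 = 2.8\).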

The agent selects the policy that maximizes the cumulative reward as follows:

\({\uppi }^{*}\left(s\right)\in \arg\max_{a}{Q}^{*}\left(s,a\right)\), where \({Q}^{*}\left(s, a\right)\) is the optimum state-action value.

In the proposed framework, the optimal policy is obtained by deploying Q-learning as the off-policy, value-based method, which obeys the Bellman equation given in Eq. (9):

$${Q}^{*}\left(s,a\right)= {E}_{\uppi }\left[\sum \nolimits_{t=0}^{\infty }R+\upgamma \max\left({Q}^{*}\left({s}^{\prime},{a}^{\prime}\right)\right) \;\Big|\; {a}^{\prime}=\uppi \left({s}^{\prime}\right)\right]$$
(9)

For action priorities, in MDP Q-Learning, a transition probability matrix is defined, which determines the probability of transitioning from one state to another when the agent selects an action. This matrix affects the decision-making of the agent for the initial and the subsequent action selection process. In our proposed implementation, all actions are assigned equal transition probabilities. This uniform environment encourages the agent to explore all possible actions without bias, which ensures the robustness of the RL model.

Algorithm 1 describes the procedure carried out to optimize the deep learning networks using Reinforcement Learning (i.e., Procedure OptimizeModel()). First, the agent selects an action from the predefined action list based on a specific policy. The agent then takes the action and trains the DL model on the updated hyperparameters. The model is evaluated to obtain the performance metrics, including the MSE loss and the accuracy. The training time and memory usage are also monitored. The reward is then calculated as the weighted sum of the aforementioned metrics. The state-action value function \(Q\) is updated, where the optimum value maximizes the reward. The state is then updated, and the agent selects the next action. The hyperparameter list of the state closest to the optimum is saved. This process iterates until it reaches a threshold (δ), which was determined experimentally by running multiple experiments with several iterations until the value giving the best desired performance was found. At each step, the process updates the hyperparameters so as to maximize the reward, and it returns the optimum set at the end. Each DL model is optimized sequentially and independently to reduce the resource consumption and computational complexity. The optimized DL models are then evaluated and compared, and the results are tabulated and analyzed.

Algorithm 1 Multi-Objective Reinforcement Learning for Arrhythmia Classification
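The overall loop of the OptimizeModel() procedure can be approximated by the sketch below. It is not a transcription of Algorithm 1: the epsilon-greedy selection rule, the default parameter values, and the toy environment are assumptions, and `evaluate` is a hypothetical stand-in for training a DL model and computing its multi-objective reward.

```python
import random

def optimize_model(evaluate, actions, apply_action, initial_state,
                   delta=50, alpha=0.1, gamma=0.9, epsilon=0.2, seed=0):
    """Iterate for delta steps and return (best_state, best_reward)."""
    rng = random.Random(seed)
    q = {}  # Q-table: (state, action) -> value, missing entries count as 0
    state = initial_state
    best_state, best_reward = state, evaluate(state)
    for _ in range(delta):
        # Select an action (epsilon-greedy is an assumption of this sketch).
        if rng.random() < epsilon:
            action = rng.choice(actions)
        else:
            action = max(actions, key=lambda a: q.get((state, a), 0.0))
        # Take the action, then "train and evaluate" to obtain the reward.
        next_state = apply_action(state, action)
        reward = evaluate(next_state)
        # Q-learning update of the visited (state, action) entry.
        best_next = max(q.get((next_state, a), 0.0) for a in actions)
        old = q.get((state, action), 0.0)
        q[(state, action)] = old + alpha * (reward + gamma * best_next - old)
        # Save the configuration closest to the optimum seen so far.
        if reward > best_reward:
            best_state, best_reward = next_state, reward
        state = next_state
    return best_state, best_reward

# Toy environment: one hyperparameter index in 0..10 with its optimum at 7.
def toy_apply(state, action):
    return min(max(state + action, 0), 10)

def toy_evaluate(state):
    return -((state - 7) ** 2)

# epsilon=0.0 keeps this toy run fully deterministic.
best = optimize_model(toy_evaluate, [+1, -1], toy_apply,
                      initial_state=0, epsilon=0.0)
```

Running one such loop per DL model, each with its own Q-table, mirrors the sequential and independent optimization described above.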

Throughout the optimization process, the hyperparameters as well as the four objectives: loss, accuracy, time, and memory, are logged for each state. The four objectives are normalized, with positive and negative values indicating improvement or decline compared to the mean. The reward is a weighted sum of all the normalized objectives. During the process, the agent aims to select actions that increase the reward. Actions that move toward the optimum hyperparameter set are rewarded, while actions that maximize the distance between the current and the optimal hyperparameter sets, in other words, actions in the wrong direction, are penalized.

Figure 2 illustrates a sample hyperparameter log of the LSTM model during 13 consecutive states of optimization. Figure 2 demonstrates the hyperparameter configurations, their respective normalized objectives, and the reward. The agent updates one of the hyperparameters at a time in each state (highlighted in gray) following the actions’ definition presented earlier in Table 1. The reward is calculated and logged; rewarded (highlighted in green) and penalized (highlighted in red) actions have positive and negative values, respectively.

Fig. 2 LSTM sample hyperparameter and reward log. Configurations of new, rewarded, and penalized actions are highlighted in gray, green, and red, respectively

These reward values are logged in the Q-table in correspondence with the action taken in each state, as shown in Fig. 3. As demonstrated in Fig. 3 (a), the algorithm runs an initial round in which it randomly selects a different action in every state until all the predefined actions have been covered. After this round, the most rewarding action is selected for the succeeding states, as can be seen in Fig. 3 (b). Thereafter, in every succeeding state the most rewarding action is selected and the Q-table is updated accordingly. For instance, increasing the number of epochs (i.e., Action #13) is the most rewarding action in the initial round; hence, it is reselected in succeeding states, as demonstrated in Fig. 3 (b), for as long as its reward continues to increase. Once a penalty on the reward is encountered, another action is selected following the probability matrix.

Fig. 3 LSTM Sample Q-Table: initial and succeeding action selection rounds
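One possible reading of the selection rule described above is sketched here. The function name, the representation of a Q-table row as a dictionary, and the action labels are assumptions; a reward above zero is treated as rewarded and anything else as penalized, and the fallback choice is uniform to reflect the equal transition probabilities.

```python
import random

def select_action(q_row, last_action=None, last_reward=None, rng=random):
    """Pick the next action from one Q-table row.

    q_row       : dict mapping action -> logged reward (None if untried)
    last_action : action taken in the previous state, if any
    last_reward : reward that the previous action received
    """
    untried = [a for a, r in q_row.items() if r is None]
    if untried:
        # Initial round: cover every predefined action once, in random order.
        return rng.choice(untried)
    if last_action is None:
        # First step after the initial round: take the most rewarding action.
        return max(q_row, key=q_row.get)
    if last_reward is not None and last_reward > 0:
        # Rewarded action: keep reselecting it.
        return last_action
    # Penalized action: switch to another one with equal probability.
    others = [a for a in q_row if a != last_action]
    return rng.choice(others)

# Hypothetical row after the initial round.
row = {"inc_nep": 0.4, "dec_lr": -0.2, "inc_bsz": 0.1}
first_choice = select_action(row)
```

With this row, `inc_nep` is chosen first and stays selected until it receives a non-positive reward, at which point one of the remaining actions is drawn.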

4 Experiments and results

4.1 Dataset

The proposed RL-ECGNet framework uses the MIT-BIH arrhythmia database [24], which includes 48 two-channel ambulatory ECG recordings obtained from 47 subjects at the BIH Arrhythmia Laboratory between 1975 and 1979. Each instance in this dataset is a 30-min recording of the ECG signal sampled at a rate of 360 samples per second per channel. The data were annotated by two or more specialized cardiologists, and the disagreements were resolved. Finally, computer-readable reference annotations were obtained for each beat [25] in WFDB format. The database includes nine classes of arrhythmia:

  1. R: Right bundle branch block beats

  2. L: Left bundle branch block beats

  3. V: Premature ventricular contractions

  4. A: Atrial premature beats

  5. F: Fusion of ventricular and normal beats

  6. N: Normal beats

  7. Q: Paced beats

  8. U (/): Unknown/unclassifiable beats

  9. S: Supraventricular premature beats
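For orientation, the nominal size of the recordings and a simple integer encoding of the nine class symbols can be sketched as follows. The constants restate the figures quoted above; the index order simply follows the list and is an assumption, as is the helper name `encode`.

```python
FS_HZ = 360          # samples per second per channel
RECORD_MINUTES = 30  # nominal duration of each recording
N_RECORDS = 48
N_CHANNELS = 2

# Nominal number of samples implied by the figures above.
samples_per_channel = FS_HZ * RECORD_MINUTES * 60
total_samples = samples_per_channel * N_CHANNELS * N_RECORDS

# Class symbols in the order listed above (index order is an assumption).
CLASS_SYMBOLS = ["R", "L", "V", "A", "F", "N", "Q", "U", "S"]
LABEL_TO_INDEX = {symbol: i for i, symbol in enumerate(CLASS_SYMBOLS)}

def encode(beat_symbols):
    """Map a sequence of beat symbols to integer class indices."""
    return [LABEL_TO_INDEX[s] for s in beat_symbols]

encoded = encode(["N", "V", "N"])
```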

Figure 4 shows the distribution of the Arrhythmia classes in the dataset. As can be seen in Fig. 4, the dataset is extremely imbalanced. This imbalance makes it very challenging to train the predictive models on the minority classes, which represent rare arrhythmia occurrences. Additionally, extensive fine-tuning is required to reach the optimal configuration of predictive models when working with this type of data.

Fig. 4 Arrhythmia class distribution

To alleviate the challenges related to manual fine-tuning and sample scarcity, four predictive models (i.e., MLP, GRU, LSTM, and CNN) are trained on all nine arrhythmia classes, and RL is used to automatically optimize the configuration of these models to reach the optimal configuration while minimizing memory and time consumption.

4.2 Environment setup

The predictive models were trained and tested on Google Colab with a Colab Pro+ subscription, using a High-RAM runtime with 51 GB of system RAM and 226 GB of disk space.

4.3 Baseline models

To evaluate the performance of the proposed RL-ECGNet model, several manually fine-tuned baseline models were used. These models have the same architecture as the models used with the proposed RL model. Table 2 summarizes the hyperparameter configuration of the baseline models.

Table 2 Baseline hyperparameter configuration

4.4 Deep reinforcement learning with multi-models’ prediction results

In this section, the configurations of the MLP, GRU, LSTM, and CNN models are optimized using the proposed RL algorithm. A stratified hold-out evaluation approach was followed, with 20% of the dataset used for evaluation and 80% used for training. Each split of the dataset contains a sample distribution proportional to the class distribution in the full dataset. Table 3 shows the initial and final hyperparameter sets for the MLP, CNN, LSTM, and GRU models.
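The stratified 80/20 hold-out split described above can be sketched in plain Python. The function name, the fixed seed, the rounding rule for the per-class test size, and the toy label counts are assumptions made for illustration.

```python
import random
from collections import Counter, defaultdict

def stratified_holdout(labels, test_fraction=0.2, seed=0):
    """Split sample indices so each class keeps its share in both parts."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_fraction)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)

# Hypothetical imbalanced label list (not the real class counts).
labels = ["N"] * 80 + ["V"] * 15 + ["F"] * 5
train_idx, test_idx = stratified_holdout(labels)
test_counts = Counter(labels[i] for i in test_idx)
```

For very rare classes the rounding step can leave the evaluation split with no sample at all, so a minimum of one test sample per class may be worth enforcing in practice.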

Table 3 Initial vs. optimized hyperparameter set per model

Equation (10) is used to report the performance improvement of the proposed RL optimized models against the baseline models.

$$Performance\;Improvement=\frac{{\overline{o} }_{i}}{{o}_{i}}, \quad {\overline{o} }_{i}, {o}_{i} \in \left\{MSE, Accuracy, Time, Memory\right\},$$
(10)

where \({\overline{o} }_{i}\) is the performance metric of the RL optimized model, and \({o}_{i}\) is the performance metric of the manually optimized baseline model.
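Equation (10) is a plain ratio, as the sketch below shows with hypothetical numbers. One interpretive assumption is made explicit: for cost-type metrics (MSE, time, memory) a ratio below 1 indicates an improvement, so a statement such as "reduced by a factor of k" would correspond to the reciprocal of this ratio.

```python
def performance_improvement(rl_value, baseline_value):
    """Eq. (10): metric of the RL-optimized model over the baseline metric."""
    return rl_value / baseline_value

def reduction_factor(rl_value, baseline_value):
    """Reciprocal view for cost-type metrics (assumption of this sketch)."""
    return baseline_value / rl_value

# Hypothetical values, not results from Table 4.
acc_gain = performance_improvement(0.9, 0.6)   # accuracy: higher is better
time_cut = reduction_factor(50.0, 100.0)       # training time: lower is better
```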

Table 4 presents the overall prediction results for the four DL models with the proposed RL model against the manually optimized baseline models. The experimental results are reported in terms of accuracy, MSE loss, training time, and memory usage. The proposed RL optimization method is evaluated against manually optimized baseline DL models, as well as state-of-the-art methods with a different number of arrhythmia classes [8, 10] and methods with a similar number of classes [13].

Table 4 Deep reinforcement learning arrhythmia classification overall results

As can be seen in Table 4, the proposed RL model outperformed manual optimization for all DL models, with minimized resource consumption except for the LSTM. Among all DL models, LSTM with RL optimization achieved the highest prediction accuracy of 96.41%. GRU and CNN also achieved relatively high classification accuracies of 91.70% and 95.82%, respectively, with RL compared to manual optimization. The deep learning architectures, namely LSTM, CNN, and GRU, are well-suited for processing sequential data such as raw ECG data for arrhythmia classification. Moreover, these DL models capture the temporal aspects and learn the hierarchical features of the ECG data better; hence, they outperformed the simpler MLP architecture.

The proposed RL optimization method achieved a performance improvement (considering accuracy) of 1.28, 1.32, 1.39, and 1.33 times for MLP, CNN, LSTM, and GRU, respectively. In addition, the proposed RL method achieved a loss reduction of 2.45, 5.46, 7.61, and 3.39 times for MLP, CNN, LSTM, and GRU, respectively. Moreover, the proposed RL method achieved a reduction in training time of 1.925, 2.727, and 1.36 times for MLP, CNN, and GRU, respectively. Finally, the proposed RL method achieved a reduction in memory usage of 1.179, 1.815, and 1.359 times for MLP, CNN, and GRU, respectively. Hence, the proposed RL optimization method has proven efficient in resource usage by reducing the training time and memory. The only exception is the LSTM model, where the resource consumption of manual optimization is lower than that of the proposed method. This is compensated by the noticeably higher performance improvement.

Finally, the proposed method is evaluated against state-of-the-art research in ECG arrhythmia classification, in which the number of classes varies. Against methods with a different number of arrhythmia classes [8, 10], the proposed method outperformed all models, achieving a higher accuracy. For a comparison on the same number of classes, the proposed method is compared to a 12-lead ECG 2D CCNN [13] that classifies arrhythmias into 9 classes, as our proposed method does. The LSTM optimized with the proposed RL method outperformed the 2D CCNN [13] in terms of accuracy. Although the performance of the two methods is quite similar, the computational complexity of the 2D CCNN is deemed to be high due to the high dimensionality of the data used, the intricate network architecture, and the manual optimization of the network.

5 Discussion

In this paper, RL was deployed to optimize Deep Learning models for automatic multi-class arrhythmia prediction considering class imbalance. The proposed RL-ECGNet framework deploys an MDP to optimize the deep networks on two levels: training and network architecture. The RL model uses a multi-objective signal that consists of accuracy, loss, time, and memory to optimize both performance and resource consumption. The MDP agent aims to iteratively select rewarded actions that maximize the multi-objective function.
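To illustrate how the four terms can be combined into a single signal, the following is a minimal sketch of one possible scalarization: a weighted sum that rewards accuracy and penalizes loss, training time, and memory. The weighted-sum form, the weights, the budget normalization, and all names (`TrialMetrics`, `multi_objective_reward`) are assumptions made for illustration only and are not the exact reward used by RL-ECGNet.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrialMetrics:
    """Metrics observed after training one candidate configuration."""
    accuracy: float      # in [0, 1], higher is better
    loss: float          # e.g. MSE loss, lower is better
    train_time_s: float  # training time in seconds, lower is better
    memory_mb: float     # memory usage in MB, lower is better


def multi_objective_reward(
    m: TrialMetrics,
    *,
    time_budget_s: float,
    memory_budget_mb: float,
    weights: tuple[float, float, float, float] = (1.0, 1.0, 0.5, 0.5),
) -> float:
    """Weighted-sum scalarization (hypothetical weights and budgets)."""
    w_acc, w_loss, w_time, w_mem = weights
    # Normalize resource terms by a budget and cap them at 1 so that
    # no single term dominates the reward.
    time_cost = min(m.train_time_s / time_budget_s, 1.0)
    mem_cost = min(m.memory_mb / memory_budget_mb, 1.0)
    return (
        w_acc * m.accuracy
        - w_loss * m.loss
        - w_time * time_cost
        - w_mem * mem_cost
    )


if __name__ == "__main__":
    # Two hypothetical configurations with equal accuracy and loss.
    light = TrialMetrics(accuracy=0.9, loss=0.05, train_time_s=100, memory_mb=500)
    heavy = TrialMetrics(accuracy=0.9, loss=0.05, train_time_s=200, memory_mb=800)
    for name, trial in (("light", light), ("heavy", heavy)):
        r = multi_objective_reward(trial, time_budget_s=400, memory_budget_mb=1000)
        print(name, round(r, 3))
```

With such a signal, two configurations that reach the same accuracy and loss are ranked by their resource consumption, which is the behavior a resource-aware agent needs when selecting its next action.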

The proposed RL-ECGNet framework was evaluated on four DL models, namely, MLP, CNN, LSTM, and GRU, which are commonly used in recent research work and have proven to be efficient. The focus of the experiments carried out throughout this study was therefore to automatically optimize the DL models while reducing their computational complexity and, in turn, their resource consumption.

The proposed RL optimization method achieved accuracies that range from 88.45% to 96.41% for the four DL models. In addition, the proposed RL method yielded performance improvements that range from 1.28 to 1.39 times for accuracy. Moreover, the proposed method achieved resource consumption reductions that range from 1.36 to 1.925 times for training time and from 1.179 to 1.815 times for memory usage. The performance of the proposed RL-ECGNet framework is comparable to that of the state-of-the-art method evaluated on 9 arrhythmia classes [13], with slightly higher accuracy and efficient resource usage. Moreover, compared to methods that were evaluated on different numbers of arrhythmia classes, the proposed method still outperformed some of the reviewed research on the classification of minority classes [8, 10].

6 Conclusion

Cardiac arrhythmia is a potentially fatal clinical condition caused by abnormal rhythms of the heart. ECG signals have long been utilized for the detection and classification of such abnormalities for early diagnosis of arrhythmia. Arrhythmia ECG datasets are highly imbalanced, including minority classes that represent rare arrhythmia conditions. Although existing automated multi-class arrhythmia detection frameworks achieve high detection and classification performance, the classification performance of the minority classes remains low.

To this end, in this paper, we propose the RL-ECGNet framework, a Reinforcement Learning-based optimization approach for multi-class arrhythmia detection using ECG signals. The proposed framework processes raw ECG signals for morphological feature extraction. Then, RL is utilized to automatically optimize the training and network hyperparameters of the Deep Learning models. Unlike recent research work that fine-tunes a limited hyperparameter set, the proposed framework automatically optimizes the training and network architecture hyperparameters using a Markov Decision Process (MDP).

The RL-ECGNet framework was validated on 9 arrhythmia classes from the well-known MIT-BIH dataset. For evaluation, the processed ECG signals are fed to four DL models, namely, MLP, CNN, LSTM, and GRU, which are trained and optimized. Not only are the training and network hyperparameters optimized, but the resource consumption, specifically the training time and memory usage, is also minimized. The performance of the optimized DL models was tabulated and compared, concluding that with the described environment setup, the LSTM model outperformed the rest of the models, yielding an average accuracy of 96.41%. In addition, the proposed RL method yielded an average performance improvement of 1.33 times for accuracy. Moreover, the proposed method achieved an average resource consumption reduction of 1.671 and 1.451 times for training time and memory usage, respectively. In conclusion, optimizing DL models with RL has proven to perform well on multi-class arrhythmia classification with imbalanced datasets while efficiently reducing resource consumption.