1 Introduction

Continual learning is a rapidly evolving field of machine-learning research, with several applications such as lifelong learning for robotic vision [1] and multi-task model learning [2, 3].

Continual learning methods aim to mitigate catastrophic forgetting [4], which occurs when training data drawn from different distributions are incrementally presented to the model, essentially violating the i.i.d. assumption [5]. Its immediate effect is a decrease of model effectiveness on previously learned tasks as the model learns from the newly presented training data. The same issue is also known as the “stability-plasticity dilemma” [6] for machine learning models [7], which refers to the trade-off between adapting a model’s parameters to new information and maintaining its effectiveness on previously seen samples.

Continual learning remains a challenging problem for deep neural networks and has attracted significant research interest recently [7], with methods addressing several variations of learning scenarios, depending on the information available to the learning algorithm. Several methods have been proposed under the assumption that (i) the model learns a sequence of well-defined, disjoint tasks and (ii) the task boundaries are known [8]. This is usually simulated by splitting a dataset into smaller, disjoint subsets (e.g., corresponding to different classes, or different appearances of the same classes) [3]. Other methods aim to address the more challenging online continual learning problem [9], where training data arrive as a stream, without well-defined tasks or with unknown task boundaries.

Online continual learning introduces additional challenges. For a specific task, the model does not have access to all the task data before training and cannot plan the training process. In the case of a continuum of data, the types of tasks and the task boundaries are unknown. Furthermore, when operating in constrained resource setups (e.g. in edge computing and critical application environments) there is a need to keep the computational complexity low, which in turn imposes significant memory and computational constraints [10].

In this work, building on recent advances such as [11], we propose a set of training strategies for rehearsal-based online continual learning, aiming to improve computational efficiency without compromising model effectiveness. These strategies involve decisions on when to train, as stream training data become available, as well as on how to train, in terms of the number of required training iterations and the learning rate.

This work significantly expands our initial work in the field [11] by testing all baselines and method combinations on additional datasets, such as the permutations of the MNIST digits dataset [12]. The experiments are further extended by evaluating the CIFAR-10 dataset [13] with different buffer sizes. Additionally, extended testing of hyperparameters and heuristic rules is included in the ablation study (Section 6). These additions provide a more transparent view of each component and its contribution to the final results.

Experiments on image classification tasks demonstrate the importance of selecting an appropriate training strategy in online continual learning scenarios. Results show that the proposed methods achieve higher classification accuracy with lower computational complexity compared to simpler baseline rehearsal strategies, such as fixed-iteration training.

The structure of the paper is as follows: Section 2 describes the online continual learning setting and the implications that arise when samples arrive as a stream with unknown boundaries between tasks. It surveys the main rehearsal methods and their computational limitations and summarizes the contributions of our work. Section 3 describes the online continual learning setting used in this work and gives an algorithmic description of our suggested framework. Section 4 details the heuristics that help our method determine how to train the model when rehearsal is needed. Section 5 presents the experimental setup and our results, whereas Section 6 performs an ablation study demonstrating the value of the individual heuristics and strategies described in Section 4. Finally, Section 7 summarizes the main findings and discusses the future steps of our work.

2 Related work

Continual learning approaches can be broadly categorized into regularization, rehearsal (or replay), and parameter isolation methods [7]. Regularization methods use a custom loss function with a regularization term that helps avoid catastrophic forgetting when learning with new data. Replay methods store examples or generate synthetic ones in order to train the model using a mixture of old and new samples, or use them to constrain the optimization. Finally, parameter isolation methods dedicate different parts of the neural network to different tasks. Consequently, they try to support all tasks by either dynamically adjusting the architecture or by re-adjusting the per-task parameters [2].

Unlike other methods, such as Elastic Weight Consolidation (EWC) [14] or Learning-Without-Forgetting (LwF), rehearsal methods require memory for buffer storage. The work of [15] summarizes and tests the known rehearsal-free methods in this area. An issue with EWC is that recomputing the Fisher Information Matrix can become computationally demanding: calculating the Fisher diagonal requires all the model weights, which can become prohibitive in an online training setup. Another important issue is that the model distillation proposed in the baselines of [15] requires additional methods to provide decent results [7]. In most works, such as [16], rehearsal-free methods are best used as additional components for increasing classification accuracy, at the cost of additional computation and memory. In the case of model distillation, an extra copy of the model must be stored and used for additional model inferences in each training step.

Rehearsal methods are promising candidates for real-life incremental learning scenarios, such as online continual learning, since they have relatively simple implementations [7, 17, 18] which can be adapted in terms of computational and memory demands. Their effectiveness, however, is largely affected by two distinguishing factors: i) the size of the buffer that stores training samples from previous tasks, and ii) the strategy of mixing old and new task samples during rehearsal.

One of the most popular replay-based methods is iCaRL by Rebuffi et al. [19]. iCaRL is a class incremental method for image classification without forgetting. It is based on the use of a set of image samples that is dynamically updated to include the samples nearest to the class mean in the learned representation space. This set is used both for rehearsal and classification via the Nearest Class Mean (NCM) method. This dynamic buffer is updated via a combination of herding selection of new exemplars and priority-list-based removal of less representative exemplars. iCaRL has been created for solving the class incremental learning problem, where the task identities are unknown, but is not specifically designed for online learning.

Experience Replay (ER) [20] is another approach that is based on the concept of mixing old and new task samples in order to mitigate catastrophic forgetting. ER, which has been successfully employed in reinforcement learning and supervised learning tasks [21], maintains a dynamic buffer that is updated at every time step. Samples are randomly drawn from the buffer and mixed with the new samples that arrive in the stream to create a synthetic batch, which is then used for training in the next iteration. Experiments show that even a small number of samples from the buffer can make a significant difference in mitigating forgetting.

The Greedy Sampler and Dumb Learner (GDUMB) [17] method is the online equivalent of experience replay: it adapts a simple rehearsal strategy and buffer update to a stream-based setting, focuses on the greedy sampling of new knowledge, and demonstrates the effectiveness of simple memory-based approaches. The methods proposed in this paper follow a similar principle in terms of adapting to stream-based learning of new data.

Similarly, the Gradient Episodic Memory (GEM) method [22] regularizes the gradients employed for back-propagation with the use of buffer samples, which are called episodic memory. A common setting for testing rehearsal approaches is to stream samples in small batches, each one containing complete sets of samples for new tasks, with known tasks and clear task boundaries [7, 19]. This task-incremental setting is a simplification of the general online continual learning setting, in which the tasks and their boundaries are completely unknown.

Between these two extremes, the relaxation approach proposed in [23] employs a Bayesian method to infer the task context. Similarly, an algorithm that uses the Shannon entropy as a measure to select representative samples from previously seen classes, without being affected by the fact that task boundaries are unknown, is presented in [24]. However, the applicability of these solutions in resource-constrained applications may be limited by their computational complexity.

Another prominent approach for continual online learning is based on a combination of model distillation and iCaRL [9]. This method proposes the use of a custom distillation loss with an offline baseline model for retraining, together with a customization of the iCaRL method for online learning. This combination of proven methods shows promising results, although its implementation is multi-faceted and can be resource-intensive. At its core, the incremental learning online scenario of [9] adopts the iCaRL principles in terms of exemplar use and NCM classification, and LwF-based distillation. However, it improves some intermediate steps and splits the training itself into various phases, combining online and offline training.

In all the above-mentioned methods, little emphasis is placed on the evaluation of their performance in terms of memory requirements and timeliness (i.e. the time to embody new knowledge once it arrives) [25]. A rather small body of literature deals with these issues. A representative case is Latent Replay [10], which emphasizes storing and using the output of intermediate layers for optimizing computational complexity and memory usage. As in the previously reported cases, this approach also assumes a priori known task boundaries.

Another case that investigates the computational cost of the training procedure is [26]. In this work, the authors propose the replacement of the Softmax activation function with a Balanced Softmax, which reduces the need for an additional fine-tuning step during training. Although this approach results in a bias towards new tasks, it tackles the important, real-world, problem of memory size limitations.

Recent works on rehearsal-based methods such as [27] suggest a prototype-based selection of samples to update the buffer. The buffer update is based on the cosine distance of the prototypes to achieve a balance between samples that are easier to classify correctly before training and samples that are further from a model’s knowledge. This method is best used in the scenario where a model is updated frequently to incorporate new knowledge, which is not always optimal due to possible hardware restrictions. Also, on some occasions, the data stream might contain samples that the model can already classify correctly without training, making some training cycles unnecessary.

Another recent work is [28], where repeated data augmentation is used for rehearsal training to maintain variety and bias resistance during training. This method relies on data augmentation, which adds extra memory requirements for the new samples. Also, the constant repetition of this process adds significant memory and computational overhead, thus requiring more resources.

The work of [16] is a stronger alternative to experience replay, tested under task-incremental and domain-incremental experimental conditions. This alternative maintains two models for model distillation. The work is mostly focused on incremental learning in the same manner as the majority of the continual learning literature, excluding the stream-based, task-agnostic scenarios that often occur in real-world settings.

Fig. 1

Online continual learning scenario. Each batch \(B_1 \ldots B_t, B_{t+1}, \ldots \) belongs to a specific classification task \(T_1, T_2, \ldots \) and \(H_1 \ldots H_t, H_{t+1} \ldots \) is the sequence of produced models after each step

None of the aforementioned methods considers a model’s ability to classify without errors at any given time step. As a result, the model may be unnecessarily over-trained, adding to the time complexity, which can prove critical in a resource-constrained setup. Also, in practice, the changes between tasks and the distribution of data within the stream may be unknown. Most of the aforementioned methods assume prior knowledge of when a task change will occur. The absence of such knowledge (task agnosticism) can further degrade model performance, due to possible violations of the i.i.d. assumption.

This work is targeting the general online continual learning setting, where input samples are provided by a stream and correspond to tasks with unknown boundaries, while timeliness and memory efficiency are of equal importance to complexity, plasticity, scalability and accuracy. Those characteristics allow our work to move ahead of previous research and deliver the following important contributions:

  • The introduction of a Drift Decision Mechanism for determining when training is needed. This mechanism offers a significant advantage since it avoids unnecessary training and reduces the overall training time. Another advantage is that this mechanism is suitable for detecting any transitions between classification tasks (i.e. task boundaries), and therefore deals with the issue of task agnosticism.

  • The proposition of an Adaptive Rehearsal Tuning mechanism as a solution for customizing the rehearsal correction to the current model status, in a setup in which the model status is ever-changing in unforeseen ways. Our approach solves the problem of fine-tuning the rehearsal training at very precise moments in time, in an online data stream, where the distribution of data in terms of classes and tasks is unknown. Another important feature is the set of additional dynamic heuristics, which allow for higher-level adaptation to different datasets and resource constraints.

3 Online continual learning

3.1 Scenario

The working scenario considers a constant stream S of annotated data that is used for model training in an online fashion (Fig. 1). The stream of data is not i.i.d. (independent and identically distributed), but it is sampled in each period from different tasks. The task boundaries and the task identities are not known in advance. This adds difficulty compared to a scenario where batch training takes place with all data available from each task. We focus on classification problems and therefore each task is a sequence of annotated samples from a set of classes.

For simplicity, we assume that the data stream comes in batches of constant size (e.g., 32) comprising samples of any task. This setup is general enough to accommodate various scenarios, such as a stream composed of large and few tasks, or a stream with a fast succession of multiple small tasks. In all cases the stream is seen as a sequence of batches, each one comprising a set of samples \(B_t=(X_t, Y_t)\), where \(t \ge 1\) represents the batch at time step t. The task corresponds to a subset of batches that are observed sequentially. When the task changes, batches of samples that belong to new classes appear in the stream.

Our objective in this scenario is to continuously train a model, where the model snapshot \(H_{t+1}\) results from training \(H_{t}\) with the corresponding batch \(B_t\) of time step t. As the model learns using data from new tasks through this procedure, the final goal is to mitigate forgetting of old tasks, while at the same time avoiding unnecessary training iterations (i.e., rehearsals), to keep computational complexity minimal.

3.2 An online rehearsal framework

In this section, we describe a general rehearsal framework, which is shown in Fig. 2 and is more formally defined in Algorithm 1.

Assume we are at time step t. Let \(B_t = (X_t, Y_t)\) be the batch of newly acquired samples from the stream and \(H_t\) be the model from the previous time step. We denote as \(S_t\) a possible collection of local state variables that the algorithm needs to maintain during its execution (e.g. counters, etc.) in order to make informed decisions.

Fig. 2

Proposed strategy outline. If there is no change between tasks (Left), the drift detection mechanism stays idle and the new batch \(B_t\) is added to \(P_t, R_t\). When training is needed (Right) the next batch \(B_{t+1}\) is used together with both the new rehearsal buffer \(R_{t+1}\) and the new postponed buffer \(P_{t+1}\). All the aforementioned data are mixed for rehearsal training

During training, we maintain two buffers, a postponed buffer \(P_t\) that keeps samples that have arrived since the last rehearsal, and a rehearsal buffer \(R_t\) that stores a mix of samples used for rehearsal. The postponed buffer is initially empty and the rehearsal buffer initially contains a user-selected fixed number of q samples per class where \(q = \mid R_t \mid /n\) with n the number of classes. The samples of \(R_t\) are initially selected uniformly at random, while the size remains constant.

At each time step, the algorithm first applies a change detection mechanism to decide whether to train with the new batch of samples or not. If not, the new batch is simply appended to the postponed buffer \(P_t\) and the algorithm returns. Assume now that a change is detected and the algorithm decides to train the model at the current time step. The postponed buffer may be empty if training took place in the previous time step, or it may contain multiple batches in case the last training took place in previous time steps. The algorithm then employs a rehearsal strategy (see Section 4 for the supported options). Part of the strategy is to create training batches by combining the newly collected samples \(P_t \cup B_t\) with samples from the rehearsal buffer \(R_t\).

After each training step, the rehearsal buffer \(R_t\) is updated by deciding which samples to replace with samples from \(P_t\) or \(B_t\), while the whole postponed buffer is cleared. In the proposed framework the update of the rehearsal buffer is performed using the method proposed by He et al. [9]. This method is a simplified variant of the prioritized example selection algorithm [19] that is based on herding, also known as Herding Selection. The main difference between the two, which is in line with our performance requirements, is that the former maintains a running average estimate for each class, instead of recalculating the class mean at each time step.

Herding Selection updates a buffer based on a model’s representation average to ensure a more consistent set of representatives from each class. An extra feature is the ability to limit the number of the aforementioned representatives, thus controlling the buffer size whilst maintaining class balance during training cycles.

Based on the above, at each time t the rehearsal buffer contains a set of q samples \((E_1^y, \ldots , E_q^y)\), \(y \in [1, \ldots , n]\), where n is the number of classes seen so far. The set is updated at each time step to always contain \(q = |R|/n\) samples per class. When classes are encountered for the first time, the number of samples per class is dynamically updated to keep the total memory requirements constant.
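For illustration, a minimal Python sketch of such a class-balanced buffer is given below. It keeps a running feature mean per class and retains the q exemplars closest to that mean; the names, the feature interface and the selection details are illustrative assumptions, and the sketch is not the exact procedure of [9, 19].

```python
import numpy as np

class RehearsalBuffer:
    """Simplified sketch of a class-balanced rehearsal buffer: a running feature
    mean is kept per class and the q exemplars closest to that mean are retained.
    Illustrative only; not the exact procedure of [9]/[19]."""

    def __init__(self, capacity):
        self.capacity = capacity     # total number of stored samples |R|
        self.exemplars = {}          # class label -> list of (sample, feature) pairs
        self.means = {}              # class label -> running feature mean
        self.counts = {}             # class label -> number of samples seen so far

    def add(self, x, y, feature):
        """Add one sample with its (assumed precomputed) feature vector."""
        n = self.counts.get(y, 0) + 1
        self.counts[y] = n
        mean = self.means.get(y, np.zeros_like(feature))
        self.means[y] = mean + (feature - mean) / n          # update running class mean
        self.exemplars.setdefault(y, []).append((x, feature))
        self._rebalance()

    def _rebalance(self):
        """Keep q = |R| / n samples per class, closest to the running class mean."""
        q = max(1, self.capacity // max(1, len(self.exemplars)))
        for y, items in self.exemplars.items():
            items.sort(key=lambda it: np.linalg.norm(it[1] - self.means[y]))
            self.exemplars[y] = items[:q]
```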

Algorithm 1

Model and buffer update mechanism (time step t).

Algorithm 2 describes a generic training and buffer update mechanism that is applied when training takes place. The number of training iterations (method iter) and the learning rate (method lrs) are chosen using the methods described in Section 4. After these hyper-parameters are set, the model is trained with the required stochastic gradient descent steps (method SGD).

Algorithm 2

AdaptiveRehearsalTraining.

To make this abstract framework applicable in practice, we need methods that allow us to determine (a) whether additional training is currently required, (b) how samples from the postponed and rehearsal buffers should be mixed for training, (c) how many iterations to use for training the model at the current time step, and (d) what learning rate to use. We address these questions in the section that follows.

4 Rehearsal strategies

This section presents the alternative ways to decide when (line 5, Alg. 1) and how (line 13, Alg. 1) to perform the rehearsal during training. Algorithm 2 gives a high-level overview of the online training process. Note that all suggested methods and heuristics can be tuned and combined independently, offering a palette of different online learning setups and thus a more flexible plan for mitigating forgetting.

4.1 Continuous rehearsal and experience replay

The baseline strategy of our comparative study is a variant of Experience Replay [18]. This method is based on the online application of experience replay, similar to GDUMB [17]. The only difference is the use of herding selection instead of greedy sampling (as proposed by GDUMB). The herding selection buffer ensures class balance and is commonly used in the most well-established continual learning methods. Our baseline therefore incorporates recent works and additions, making the testing setup even more challenging for our methods.

This method assumes that training takes place in each time step. At each time step t, training happens for a fixed number of iterations (epochs) using the latest batch \(P_t \cup B_t = B_t\) (for this strategy the postponed buffer \(P_t\) is always empty) and batches from the rehearsal buffer \(R_t\). In each iteration, the latest batch \(B_t\) is combined with a different batch \(r_j \in R_t\) in order to produce two new batches containing 50% samples from \(B_t\) and 50% samples from \(r_j\) each. In order to make use of the entire rehearsal buffer, the position of the batch used in each time step is kept in a global variable and the rehearsal batches (i.e. \(R_t\)) are employed in a round-robin fashion. In the simplest case, only one iteration is used per batch [18]. In what follows we use ER-n to denote this baseline method with n iterations (ER-1 for a single iteration).
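A minimal sketch of the mixing step is shown below, assuming each batch is stored as a NumPy (X, y) pair of equal size; the variable names and the exact split of samples are illustrative assumptions.

```python
import numpy as np

def mix_batches(b_new, rehearsal_batches, cursor):
    """Combine the latest stream batch B_t with the next rehearsal batch r_j
    (chosen round-robin) into two half-and-half training batches."""
    r = rehearsal_batches[cursor % len(rehearsal_batches)]
    half = len(b_new[1]) // 2
    # first mixed batch: first half of B_t + first half of r_j
    batch_a = (np.concatenate([b_new[0][:half], r[0][:half]]),
               np.concatenate([b_new[1][:half], r[1][:half]]))
    # second mixed batch: remaining halves of B_t and r_j
    batch_b = (np.concatenate([b_new[0][half:], r[0][half:]]),
               np.concatenate([b_new[1][half:], r[1][half:]]))
    return [batch_a, batch_b], cursor + 1  # advance the round-robin pointer
```

The returned cursor is stored between time steps, so that successive calls traverse the rehearsal buffer in a round-robin fashion.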

4.2 Drift activated rehearsal

In the general online continual learning scenario defined in Section 3.1, the task boundaries are unknown. A solution for deciding when to train (instead of applying a constant training schedule) is to use a concept drift detector [29]. From the concept Drift Detectors that are available in the literature, we choose the ECDD detector [30] which uses exponentially weighted moving average charts (EWMA) as an indicator of divergence between samples. It is a single pass method with \(\mathcal {O}(1)\) update in each time step, which makes it suitable for performance-critical, streaming applications. The ECDD detector is a simple and popular solution [31] for drift detection, so in the following, we describe the internal mechanism of this particular Drift Detector and leave as future work the task of studying different and more sophisticated drift detection mechanisms.

Samples in the stream are sequentially presented to the classifier, and at each time step we examine whether the predicted class label was correct (i.e. \(X_t = 0\)) or incorrect (i.e. \(X_t = 1\)). The ECDD detector perceives \(\{ X_t \}\) as a sequence of observations from a Bernoulli distribution. Detecting concept drift becomes the problem of detecting an increase in the Bernoulli parameter p corresponding to the probability of misclassification. The ECDD detector maintains an estimate \(Z_t = (1-\lambda ) Z_{t-1} + \lambda X_t\) of the current mean \(\mu _t\), along with a second estimator \(\hat{p}_{0,t} = \frac{1}{t}\sum _{i=1}^{t} X_i\), which changes more slowly and better estimates the probability before the change (event). Given estimates of \(\hat{\sigma }_{X_t}\) and \(\hat{\sigma }_{Z_t}\), ECDD first computes a control limit \(L_t\) based on the methodology presented in [30]. If \(Z_t > \hat{p}_{0,t} + L_t \hat{\sigma }_{Z_t}\), then ECDD raises a concept drift flag.
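For concreteness, the following Python sketch implements the ECDD estimator updates. It is a minimal illustration rather than the exact implementation used in our experiments: the weight \(\lambda \) and the control-limit multiplier are fixed illustrative defaults, whereas [30] derives \(L_t\) from the estimated error probability.

```python
import math

class ECDD:
    """Minimal sketch of the ECDD drift detector [30], with a fixed
    control-limit multiplier L (assumed constant for illustration)."""

    def __init__(self, lam=0.2, L=2.0):
        self.lam = lam     # EWMA weight (lambda)
        self.L = L         # control-limit multiplier (assumption; [30] computes L_t)
        self.t = 0         # number of observations seen
        self.p0 = 0.0      # slow estimate of the pre-change error probability
        self.z = 0.0       # EWMA estimate Z_t of the current error rate

    def update(self, x):
        """x = 1 if the sample was misclassified, 0 otherwise.
        Returns True when a drift is flagged."""
        self.t += 1
        self.p0 += (x - self.p0) / self.t                 # running mean of X_t
        self.z = (1 - self.lam) * self.z + self.lam * x   # EWMA estimate Z_t
        # EWMA-chart variance of Z_t for Bernoulli observations
        var_x = self.p0 * (1 - self.p0)
        var_z = var_x * self.lam / (2 - self.lam) * (1 - (1 - self.lam) ** (2 * self.t))
        sigma_z = math.sqrt(max(var_z, 1e-12))
        return self.z > self.p0 + self.L * sigma_z        # drift flag
```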

Algorithm 3

detectChange.

During time step t all samples of batch \(B_t\) are given to the Drift Detector, which updates the estimators. Since new samples arrive in a batch \(B_t\), \(\{ X_t \}\) can be viewed as a sequence of per-batch observations that indicate how many samples from \(B_t\) were classified incorrectly. The following heuristic rules are used to decide whether to train: (i) \(Z_t > \hat{p}_{0,t} + L_t \hat{\sigma }_{Z_t}\), which is the original rule of ECDD, (ii) \(U_t > \hat{p}_{0,t} + 2 \hat{\sigma }_{U_t}\), i.e., the current batch error \(U_t\) is high (more than two standard deviations above the estimated mean), (iii) \(Z_t > \epsilon \), i.e., the running average of the error exceeds a user-defined limit, and (iv) no training has taken place during the last \(\mu \) time steps, where \(\epsilon \) and \(\mu \) are user-defined thresholds. Reasonable choices are \(\epsilon = 0.2\) for the error threshold, and \(\mu = 20\) for the no-training time steps threshold. Assuming a constant number of iterations for training, the main difference with the ER-n method is that, when it is triggered, the postponed buffer \(P_t\) will likely contain multiple batches. In each training iteration, we scan over all batches of \(P_t\). For each batch \(p_i \in P_t\) we read a batch \(r_j \in R_t\) and create two batches which contain 50% from each. The buffer \(R_t\) is used in a round-robin fashion using the updated position pointer after each time step. We use DRIFTA-n to name the subset of methods which perform a constant number of n iterations when the Drift Detector is triggered.
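The sketch below shows one possible way to combine these rules into a training decision. The exact combination is an illustrative assumption (rules (i)-(iii) treated as alternative triggers and rule (iv) as a cooldown), with the statistics assumed to come from a detector such as the ECDD sketch above.

```python
def should_train(z_t, u_t, p0_t, sigma_z, sigma_u, L_t,
                 eps=0.2, mu=20, steps_since_last_training=0):
    """Illustrative combination of the heuristic rules (i)-(iv) of Section 4.2.
    z_t: EWMA error estimate, u_t: current batch error, p0_t: pre-change error
    estimate, sigma_z / sigma_u: their standard deviation estimates."""
    rule_i = z_t > p0_t + L_t * sigma_z         # (i) original ECDD rule
    rule_ii = u_t > p0_t + 2.0 * sigma_u        # (ii) current batch error is high
    rule_iii = z_t > eps                        # (iii) running error exceeds the limit
    rule_iv = steps_since_last_training >= mu   # (iv) no training in the last mu steps
    return (rule_i or rule_ii or rule_iii) and rule_iv
```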

4.3 Drift detection with buffer samples (double drift detector strategy)

Depending on the size of the rehearsal buffer and the average length of the tasks, a possible adaptation of the previous strategy is to include a second Drift Detector. This additional detector monitors the classification failure rate on samples from the rehearsal buffer. Thus, the detection happens on an artificially created stream of samples, which are drawn uniformly at random from the rehearsal buffer at each time step. Training takes place when either of the two detectors is triggered, based on the previously mentioned heuristics. The only difference is that two versions of each estimate are kept, for example, \(Z_t\) for the new samples and \(\hat{Z}_t\) for the rehearsal samples. The number of random elements used for drift detection is a hyper-parameter. This mode is more sensitive in terms of corruption detection, at the cost of a slightly more computationally expensive implementation. We denote this strategy, which uses a constant number of n iterations for training, as 2DRIFTA-n.
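Reusing the ECDD sketch from Section 4.2, the double-detector decision could be sketched as follows; the per-step handling of the error indicators is an illustrative assumption.

```python
def double_drift_decision(det_new, det_buffer, new_errors, buffer_errors):
    """Sketch of the 2DRIFTA decision: det_new monitors errors on the incoming
    stream samples, det_buffer monitors errors on samples drawn at random from
    the rehearsal buffer. Training is triggered when either detector fires.
    new_errors / buffer_errors are 0/1 misclassification indicators."""
    drift_new = False
    for x in new_errors:
        drift_new = det_new.update(x) or drift_new
    drift_buf = False
    for x in buffer_errors:
        drift_buf = det_buffer.update(x) or drift_buf
    return drift_new or drift_buf
```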

4.4 Setting the number of iterations dynamically

Instead of a fixed number of training iterations in each step, we can dynamically adapt the number of iterations based on the misclassification rate \(Z_t\). We use the following heuristic rule where n is a hyperparameter.

$$\begin{aligned} n_t = \lceil 2 * n * \log _{2}(1 + Z_t) \rceil \end{aligned}$$
(1)

The heuristic rule behind dynamic iterations is a simple binary-logarithm equation that shifts between zero and \(2*n\) iterations based on the misclassification rate \(Z_t\). We used the binary logarithm to achieve smoother adaptation when the \(Z_t\) rate is close to zero or one. We denote this strategy, which associates the estimator \(Z_t\) values with the number of iterations, as DRIFTA-DYN-n. Similarly, 2DRIFTA-DYN-n does the same, taking the output of both Drift Detectors and using \(\max (Z_t,\hat{Z}_t)\) for the final decision.
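A direct transcription of (1) into Python is shown below; for example, with \(n = 50\) and \(Z_t = 0.3\) the rule yields \(\lceil 100 \log _2(1.3) \rceil = 38\) iterations.

```python
import math

def dynamic_iterations(n, z_t):
    """Number of training iterations n_t from (1): shifts smoothly between 0
    and 2*n as the misclassification estimate Z_t goes from 0 to 1."""
    return math.ceil(2 * n * math.log2(1 + z_t))
```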

4.5 Rehearsal based on convergence

A different approach is to determine the number of necessary rehearsal iterations based on the model convergence. In this approach, we monitor the model’s loss \(\mathcal {L}\) by keeping two exponential moving averages, one short where:

$$\begin{aligned} \mathcal {EMA}_{short} = (1-\alpha _{short}) \mathcal {EMA}_{short} + \alpha _{short} \mathcal {L} \end{aligned}$$
(2)

and one long, where:

$$\begin{aligned} \mathcal {EMA}_{long} = (1-\alpha _{long}) \mathcal {EMA}_{long} + \alpha _{long}\mathcal {L}, \end{aligned}$$
(3)

with \(\alpha _{short} =0.5\) and \(\alpha _{long} = 0.05\) respectively. We stop training when the two values converge, i.e. \(\Vert (\mathcal {EMA}_{long} -\mathcal {EMA}_{short})\Vert < \epsilon \) for some hyper-parameter value \(\epsilon \). Note that this approach can be used for any rehearsal strategy. We denote this strategy as ER-CONV, DRIFTA-CONV and 2DRIFTA-CONV for a constant, drift-activated and two-Drift Detector setup respectively.
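A minimal sketch of this convergence-based stopping rule follows; the iteration cap, the EMA initialization and the value of \(\epsilon \) are illustrative assumptions, and loss_fn stands for one SGD step on a mixed rehearsal batch.

```python
def train_until_convergence(loss_fn, max_iters=200,
                            alpha_short=0.5, alpha_long=0.05, eps=1e-2):
    """Stop rehearsal training once the short and long EMAs of the loss, as in
    (2)-(3), are within eps of each other. loss_fn(i) is assumed to perform one
    SGD iteration and return its loss value."""
    ema_short = ema_long = None
    iterations = 0
    for i in range(max_iters):
        loss = loss_fn(i)
        iterations += 1
        if ema_short is None:          # initialise both EMAs on the first loss (assumption)
            ema_short = ema_long = loss
        else:
            ema_short = (1 - alpha_short) * ema_short + alpha_short * loss
            ema_long = (1 - alpha_long) * ema_long + alpha_long * loss
            if abs(ema_long - ema_short) < eps:
                break                  # the two averages have converged
    return iterations                  # number of iterations performed
```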

4.6 Adjusting the learning rate

A last action that can affect the efficiency of the rehearsal strategy is a learning rate schedule across iterations based on the model’s state. This is more important in the drift-activated methods, where multiple batches are collected before initiating the training. The simplest approach is to keep the learning rate \(\eta \) constant; however, as in regular training with i.i.d. data, this is often suboptimal. Another approach is to initialize the learning rate with a user-defined value at the beginning of each time step and use a decay mechanism that modifies the learning rate through the training iterations. Finally, the drift-activated methods can also use \(Z_t\) and a predefined \(\eta \) value to dynamically adjust the initial learning rate. A combination of the aforementioned approaches can help to further improve the performance. Specifically, \(Z_t\) is used as an exponent in a variation of the exponential function, so that the learning rate adapts quickly when \(Z_t\) is high. For the initialization of the learning rate (LR) we use a simple rule that sets it to \(LR_{new} = LR_0 * \min (100, 5 * e^{3Z_t})\), where \(LR_0\) is a pre-defined learning rate. For the decay of the LR, a simple user-defined constant is enough, as shown in our experiments.
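The two rules can be sketched as follows; the decay constant is an illustrative choice, since only a user-defined constant is required.

```python
import math

def initial_learning_rate(lr0, z_t):
    """Initial learning rate at a training step (Section 4.6): scaled up quickly
    when the misclassification estimate Z_t is high, capped at 100 * lr0."""
    return lr0 * min(100.0, 5.0 * math.exp(3.0 * z_t))

def decayed_learning_rate(lr_init, iteration, decay=0.95):
    """Simple per-iteration decay; the decay constant 0.95 is an illustrative value."""
    return lr_init * (decay ** iteration)
```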

5 Experiments

In this section, we define a robust experimental benchmark for testing our propositions alongside standard rehearsal in an online setup based on Experience Replay. The purpose of the experiments is to provide insight into how ER-based methods can be improved and to test the limitations and benefits of each algorithmic setup. The methods are compared against an improved version of the Experience Replay algorithm based on GDUMB [17], which can be viewed as an online application of Experience Replay.

5.1 Experimental setup

5.1.1 Datasets

We use the CIFAR-10 image classification dataset [13] and the MNIST digits dataset [8, 32] to evaluate the online continual learning strategies of Section 4. CIFAR-10 consists of 60,000 images (50,000 training, 10,000 testing), sampled uniformly from ten classes \(\left\{ 0, \cdots , 9\right\} \) and MNIST consists of 70,000 images (60,000 training, 10,000 testing), also sampled uniformly from digits \(\left\{ 0, \cdots , 9\right\} \). Following the online continual learning scenario described in Section 3.2, we split the training data into five tasks, each one containing images from two classes, i.e.,

$$\begin{aligned} T_i = [(\textbf{x}, y), y \in [2i, 2i+1], i = 0, 1, 2, 3, 4] \end{aligned}$$
(4)

An online annotated image stream, S, is generated by first sampling multiple images from the first task, then the second task and so on

$$\begin{aligned} S = (T^{(0)}_{0}, T^{(0)}_{1}, T^{(0)}_{2}, T^{(0)}_{3}, T^{(0)}_{4}, T^{(1)}_{0}, T^{(1)}_{1}, \cdots ) \end{aligned}$$
(5)

Each set \(T^{(j)}_{i}\) has a fixed size of 3,200 images, grouped into batches of 32 images, for a total of 100 batches per task. Images in each batch are sampled uniformly at random from the classes of the current task, \(T_i\). For each experiment, each task appears three times (i.e., \(j = 0, 1, 2\)). All experiments use the same random seed for sampling and for model parameter initialization, to ensure that experiments are comparable.
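The stream construction of (4)-(5) can be sketched as follows; data loading and the exact sampling details (e.g., sampling without replacement within a batch) are illustrative assumptions.

```python
import numpy as np

def build_stream(x_train, y_train, batch_size=32, batches_per_task=100,
                 repetitions=3, classes_per_task=2, seed=0):
    """Sketch of the stream of Section 5.1.1: five two-class tasks, each
    appearing `repetitions` times, with images drawn uniformly at random from
    the classes of the current task."""
    rng = np.random.default_rng(seed)
    n_tasks = 10 // classes_per_task
    stream = []
    for _ in range(repetitions):                       # each task appears j = 0, 1, 2 times
        for i in range(n_tasks):                       # tasks T_0 ... T_4 in order
            task_classes = [classes_per_task * i + c for c in range(classes_per_task)]
            idx = np.flatnonzero(np.isin(y_train, task_classes))
            for _ in range(batches_per_task):          # 100 batches of 32 images per task
                picked = rng.choice(idx, size=batch_size, replace=False)
                stream.append((x_train[picked], y_train[picked]))
    return stream                                      # a list of (X_t, Y_t) batches
```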

5.1.2 Model and pre-training

The base model in our experiments is the adaptation of ResNet32 for the CIFAR dataset, as described in [33]. The model is first trained offline for 100 epochs with a subset of the CIFAR dataset consisting of 1500 images for each of the 10 classes. As a result, the model used in the experiments has gone through a "warm-up" training stage but has not been fully trained. It achieves accuracies of [0.805, 0.602, 0.732, 0.844, 0.854] for tasks \(T_0\)-\(T_4\) respectively. This model is used at the start of online training in all experiments. In most cases in the continual learning literature, a method keeps either one or two copies of the model weights (two, in the case of model distillation techniques). In this setup we keep and update only one copy of the model weights.

5.1.3 Metrics

Each training strategy is evaluated using the following metrics, computed at each time step t. We denote as \(H_{t}(x)\) the output of the model at time step t when the input is sample x. The ground truth is denoted as \(\hat{y}\).

  • Accuracy (\(A_t\)): Accuracy of the model evaluated on the held-out test set D, averaged across all tasks. This metric is measured at each time step. For time step t the accuracy is:

    $$ A_t = \frac{1}{\Vert D\Vert }\sum _{\forall x \in D}^{} {\left\{ \begin{array}{ll} 0,&{} H_{t}(x)\ne \hat{y}\\ 1,&{} H_{t}(x) = \hat{y}\\ \end{array}\right. } $$
  • Current task accuracy (\(C_t\)): Accuracy of the model evaluated in the held-out test set, but only for the images belonging to classes of the current task, \(T_t\). For time step t the Current accuracy for \(T_t\) is:

    $$ C_t = \frac{1}{\Vert D_{T_t}\Vert }\sum _{\forall x \in D_{T_t}}^{} {\left\{ \begin{array}{ll} 0,&{} H_{t}(x)\ne \hat{y}\\ 1,&{} H_{t}(x) = \hat{y}\\ \end{array}\right. } $$
  • Online accuracy (\(O_t\)): Accuracy on each batch \(B_t\) of stream S at time step t, evaluated right before it is used for training. This approach was also used in [9]. The online accuracy at time step t is:

    $$ O_t = \frac{1}{\Vert B_t\Vert }\sum _{\forall x \in B_{t}}^{} {\left\{ \begin{array}{ll} 0,&{} H_{t}(x)\ne \hat{y}\\ 1,&{} H_{t}(x) = \hat{y}\\ \end{array}\right. } $$
  • Training iterations (\(N_t\)): Cumulative number of iterations performed during stochastic gradient descent optimization up to time step t. Fewer training iterations indicate a more efficient approach, although classification accuracy needs to be taken into account as well. Given that the model and the batch size are the same across all experiments, this metric can be used to compare the computational complexity of different training strategies. The tables below report this metric as a ratio \(N_{t}/\max (N_{t})\), where \(\max (N_{t})\) is the training iteration count of the most computationally intensive method. Unlike the other metrics, lower values are better, since we aim to reduce the amount of training in all experiments.

Each metric evaluates a different aspect of the model’s effectiveness. \(C_t\) and \(O_t\) evaluate the effectiveness of the model in the current task and data, respectively, while \(A_t\) evaluates the overall effectiveness of the model (including previous tasks). In addition to metrics computed at each t, we also report averages across t, such as \(\bar{A}_t\).
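All three accuracy metrics reduce to the same computation applied to different sample sets, as in the following sketch; the assumed model interface (per-class scores) is illustrative.

```python
import numpy as np

def accuracy(model, x, y):
    """Fraction of correctly classified samples; used for A_t (full test set D),
    C_t (test samples of the current task) and O_t (the incoming batch B_t).
    `model(x)` is assumed to return per-class scores of shape (len(x), n_classes)."""
    preds = np.argmax(model(x), axis=1)
    return float(np.mean(preds == y))
```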

5.2 Experiment 1: motivation

This experiment demonstrates the benefits of an adaptive training strategy in an online environment, as discussed in Section 4. We compare three training strategies, which include (i) online training without any further action, (ii) experience replay with 1 training iteration [20] (ER-1), and (iii) experience replay with 50 training iterations (ER-50). In all cases, the learning rate was fixed to \(\eta = 0.01\). Results are visually illustrated in Fig. 3 and show the limitations of training with continuous rehearsal methods in an online setting.

Fig. 3

Results from the CIFAR-10 test on accuracy (\(A_t\)) and current task accuracy (\(C_t\)) with (a) 50 samples per class in the buffer, and (b) 500 samples per class. Continuous rehearsal with 1 iteration (ER-1) performs similarly to having no rehearsal at all. ER-50 iterations perform better for 500 samples per class but require 50 times higher computational cost. It also does not perform well for 50 samples per class. None of these methods seems to be satisfactory in a continuous online learning setting

It is clear that a static strategy for online learning is either subpar in terms of accuracy, when only a few iterations are used, or, in the case of ER-50, achieves its accuracy at a very high computational cost, rendering this implementation insufficient.

These results show that ER-1 is similar to not performing any training at all. Performing several iterations per time step (ER-50 in this case) leads to a model that performs well only for the current task, at the expense of the overall accuracy (due to forgetting previous tasks).

It is apparent that at task changes we often see sharp drops in the current task accuracy metric, as shown in Fig. 3. These fluctuations are caused by overfitting the model to a specific task. The more per-step iterations a method uses, the harder it is to recover between tasks; ER-50 is the most prone to this trend because it trains the model at every step with a constant 50 iterations per step, causing overfitting as a result.

All the experiments conducted introduce tasks that are occasionally radically different, in order to provide better insight into the limitations of each method.

5.3 Experiment 2: comparison of training strategies

The goal of the second experiment is to assess the improvement of the proposed training strategies in terms of both effectiveness and efficiency. Figures 4 and 5 illustrate the results for the MNIST and CIFAR-10 datasets respectively. One can observe that DRIFTA-DYN-50 manages to maintain high accuracy in the current task, \(C_t\), while requiring significantly lower computation compared to ER-50.

Fig. 4

Experiment on the MNIST dataset measuring: (a) average task accuracy \(A_t\), (b) current task accuracy \(C_t\) on the range of \(t = 1000,\cdots ,1500\), and (c) number of batches \(N_t\) for \(t = 0,\cdots ,1500\) of the stream for a rehearsal buffer with 50 samples per class. DRIFTA-DYN-50 is as effective as ER-50 with a fraction of the computational requirements

Fig. 5

Experiment on the CIFAR-10 dataset measuring: (a) average task accuracy \(A_t\), (b) current task accuracy \(C_t\) on the range of \(t = 1000,\cdots ,1500\), and (c) number of batches \(N_t\) for \(t = 0,\cdots ,1500\) of the stream for a rehearsal buffer with 500 samples per class. 2DRIFTA-CONV is consistently more effective and efficient in this experimental setup

Tables 1 and 2 show the average metrics and the 95% confidence intervals of various ER-n methods, for \(n=1,10,25,50\), the proposed methods as well as the computational requirements of each method (in terms of the number of iterations). All methods are compared in terms of the number of training iterations, \(N_t\), as well as in terms of the average accuracy metrics \(\bar{A}_t\), \(\bar{C}_t\) and \(\bar{O}_t\).

Results demonstrate the benefit of using the proposed drift-activated, dynamic and convergence-based rehearsal strategies especially 2DRIFTA-DYN-n and 2DRIFTA-CONV. In the MNIST dataset, DRIFTA-DYN-n and 2DRIFTA-DYN-n achieve approximately the same average accuracy metrics as the best experience replay methods, for a significantly smaller number of iterations (i.e., computational cost). Similar observations can be made for the CIFAR-10 dataset, and especially for convergence-based methods.

Table 1 Comparison of the different online training strategies for the MNIST digits dataset in terms of average values of the metrics across the entire stream (bold indicates the best and underline indicates the second best result for each metric)
Table 2 Comparison of the different online training strategies in terms of average values of the metrics across the entire stream (bold indicates the best and underline indicates the second best result for each metric)

5.4 Evaluation of the permuted MNIST digits problem

The previous experiments involved model training with a sequence of tasks containing disjoint classes. In this section, we evaluate our model on a different scenario, where each class may have different appearances which are presented to the model sequentially. In practice, we use the Permuted MNIST dataset [12]. Permuted MNIST extends MNIST by creating tasks via fixed pixel permutations. The first task is classifying all original digit images, and each new task is a pixel-wise permutation of the first task. Note that every permutation is constant for all images of a task. The model has to classify correctly each digit in the \(\left\{ 0, \cdots , 9\right\} \) classes without any known correspondence between images and tasks.

The ResNet32 model of previous experiments is also used in this case. The model is first trained offline for 50 epochs only with the original MNIST dataset, using only 500 samples per MNIST class. This model is then used for online training in all experiments. Note that the model is originally effective only for non-permuted digits. The buffer contains samples from the 10 classes and 50 images per class. Each task is created with 100 batches of 32 elements (i.e., not all permuted MNIST digits are presented to the model during a task).

The stream is built as:

$$\begin{aligned} S = (T_{0}, T_{1}, T_{2}, T_{3}, T_{4}, T_{5}, T_{6}, T_{7}, T_{8}, T_{9}, T_{0} ) \end{aligned}$$
(6)

with \(T_{0}\) corresponding to the classification of the original dataset and each \(T_{i}\), \(i = 1, \cdots , 9\), corresponding to a permutation (9 permutations of the original dataset in total).

Fig. 6

(a) Average accuracy \(A_t\), and (b) current task accuracy \(C_t\) for the MNIST digits dataset and nine permutations, for the ER-1, ER-50, DRIFTA-DYN-50 and 2DRIFTA-CONV rehearsal strategies

Figure 6 illustrates the results in terms of effectiveness, comparing the ER-1, ER-50, DRIFTA-DYN-50 and 2DRIFTA-CONV rehearsal strategies. Furthermore, Table 3 shows average accuracy values, as well as the number of iterations required by each method. Results indicate that for a stream constructed using the permuted MNIST scenario, the proposed methodology improves effectiveness in terms of \(A_t\), \(C_t\) as well as \(O_t\), while at the same time significantly reducing the required number of training iterations.

Table 3 Comparison of the different online training strategies in terms of average values of the metrics across the entire stream from the MNIST digits and their permutations
Table 4 Comparison of dynamic iteration strategies in terms of average values of the metrics across the entire stream

6 Ablation study

The proposed methodology addresses various aspects of online continual learning with rehearsal, including the use of a Drift Detector to determine when to train, as well as strategies for determining the number of training iterations and the learning rate. In this section, we progressively evaluate the added value of these different components of the proposed training framework. Experiments use the CIFAR-10 dataset and the ResNet32 pre-trained model, as described in Section 5.1.1.

6.1 Dynamic choice of iterations

In this experiment, we only use the dynamic iterations option, as described in Section 4.4, to assess the value of an adaptive versus a constant number of training iterations per batch. In all experiments, we use the same learning rate, and there is no learning rate decay. Results are shown in Table 4.

The dynamic choice of iterations reduces the required number of iterations and consequently decreases computational complexity. Based on the results, the convergence-based methods are the most suitable for this purpose.

6.2 Dynamic learning rate schedule

We further evaluate the added value of dynamically selecting the learning schedule based on the output of the Drift Detector (as described in Section 4), compared to a constant learning rate. The experimental results in Fig. 7 show the benefit of using a dynamic learning rate schedule.

Fig. 7

Current task accuracy \(C_t\) for (a) constant learning rate and (b) dynamic learning rate schedule in convergence rehearsal strategies

Table 5 shows the average performance of both the static and the dynamic learning rate schedules.

Table 5 Comparison of dynamic versus constant learning rate schedules in terms of average values of the metrics across the entire stream

Experiments show the value of adapting the learning rate using the strategies described in Section 4.6. They also demonstrate that convergence-based methods are more sensitive to changes in the learning rate. Furthermore, it seems that the 2DRIFTA-CONV and 2DRIFTA-DYN-50 methods require fewer iterations when the learning rate is tuned based on drift detector statistics. Note that there is both a drop in \(N_t\) and an increase in overall accuracy for both DRIFTA-CONV and 2DRIFTA-CONV methods.

7 Conclusions and future work

This study presented a framework for applying rehearsal-based methods to stream-based scenarios in order to facilitate online continual learning. Our primary objective is to tackle the constraints of limited computational and memory resources. The key considerations revolve around optimizing training cycles and adapting the training procedure at each step.

To address these challenges, we propose the incorporation of a drift detection mechanism, which initiates model training in response to changes in the data distribution. Additionally, we propose several strategies for determining the appropriate number of training iterations and for creating a dynamic learning rate schedule, based on model misclassification statistics.

Our proposed approach, according to our experimental settings, demonstrates comparable and even increased effectiveness compared to the conventional rehearsal strategies, while using fewer training iterations and balancing accuracy against overfitting. Baseline methods, such as GDUMB, are tested with various iterations per training step to fix the issues that arise with the original Experience Replay methods. Baseline methods with fewer iterations (such as ER-1 and ER-10, the original Experience Replay) do not overfit, but at the cost of decreased classification accuracy. More computationally intensive baseline methods such as ER-50 (50 iterations per batch) are more accurate per task, but tend to show decreased average accuracy and sharp changes between tasks (ER-50 shows clear signs of overfitting). These observations validate that our proposed methods keep the balance between per-task accuracy and overfitting. Also, our methods use only the minimum necessary amount of training to achieve these results. The final result shows a significant advantage in terms of computational complexity. Consequently, this approach emerges as a viable solution for continual learning in online applications dealing with stream-based data.

A reasonable future direction is the application of such methods in real-life scenarios where edge devices are utilized, such as automotive solutions, search & rescue operations and the health industry. In all the aforementioned domains, smart devices have access to a practically infinite amount of data whilst having a limited amount of memory and processing power. Our methods could provide a decent balance of computational efficiency and classification accuracy.

Our work does come with some limitations. We have yet to outperform some of the methods that work with known task boundaries, as well as some of the batch processing methods. The reduction in computation and memory usage does affect the classification accuracy. Another issue is that some edge devices cannot provide even a small amount of memory or support training computations.

Some practical hurdles might also occur when training with questionable ground truth, such as driver input in the case of recommendation systems in driving. Finally, safety-critical applications such as medical diagnosis do not allow the same margins of misclassification as our methods, since medical accuracy is more important than hardware usage optimizations.