Introduction

In recent years, numerous deep neural networks such as AlexNet [1], VGG [2], GoogleNet [3], ResNet [4], DenseNet [5] and Transformer [6] have achieved state-of-the-art performance in the field of computer vision. While deeper neural networks often yield higher accuracy, they also have more parameters and greater storage requirements, resulting in increased computational costs. To mitigate this challenge, Refs. [7,8,9] use network pruning to remove redundant parameters, Refs. [10,11,12,13,14] utilize knowledge distillation to compress model sizes, and Refs. [15,16,17] apply weight quantization to reduce parameter bit precision. However, these methods typically rely on static inference, where all samples must traverse the entire network to obtain prediction results. In contrast, adaptive inference [18,19,20,21,22,23,24] dynamically adjusts the computational resources based on the input samples, thereby significantly improving inference efficiency. Combining adaptive inference with these static compression techniques can further improve network efficiency.

The most common approach for implementing adaptive inference is to construct a dynamic multi-exit network, which adjusts its structure based on input sample difficulty [24]. A multi-exit network incorporates multiple intermediate classifiers (ICs) at varying depths within the backbone network, enabling a quick classification of simple samples by shallow classifiers and handling more difficult samples with deep classifiers. This adaptive strategy not only enhances the network’s inference performance but also boosts its overall efficiency during inference.

To improve the adaptive inference efficiency of multi-exit networks, the self-distillation technique is commonly employed to enhance the classification accuracy of shallow classifiers. Given that the final classifier typically possesses deeper layers, more parameters, and stronger feature fitting capabilities, methods such as IMTA [19] and BYOL [21] often treat it as the teacher, transferring valuable knowledge to the shallow classifiers. However, through experimentation, we have observed instances where certain samples are accurately identified by the intermediate classifiers but misclassified by the final classifier. For example, when testing a ResNet18-based multi-exit network on the CIFAR100 dataset, approximately 5% of the test samples are correctly classified by the first classifier but are misclassified by the final classifier. This experiment demonstrates that different classifiers can extract varied knowledge and features due to discrepancies in a network’s structure and parameters. Relying solely on the final classifier as a teacher to transfer knowledge risks overlooking the distinctive knowledge offered by the intermediate classifiers.

To comprehensively extract and transfer effective knowledge from all the classifiers of a multi-exit network, we propose a novel multi-level collaborative self-distillation learning strategy (MLCSD). Initially, MLCSD aggregates logit results and feature maps from all the classifiers into logit-based and feature-based pools. Subsequently, an attention mechanism computes weight coefficients matching each classifier, and the pools are multiplied by these coefficients to obtain logit-based and feature-based teachers corresponding to each classifier. In contrast to traditional self-distillation, which only utilizes the final classifier as the teacher, MLCSD explores distinctive knowledge from all the classifiers to construct more comprehensive and effective teachers. These new teachers then transfer the knowledge back to each classifier. This collaborative learning approach promotes knowledge transfer between classifiers and enhances the overall classification performance of the network.

By conducting experiments on various multi-exit networks and diverse datasets, we validate the effectiveness and generality of the MLCSD strategy. The results consistently show that MLCSD outperforms the traditional self-distillation strategy.

Our main contributions can be summarized as follows:

  1. We propose the MLCSD, which constructs more comprehensive and effective teachers by extracting knowledge from all the classifiers. It can improve the inference efficiency of the multi-exit network without increasing computational costs.

  2. We employ various backbone networks and intermediate classifiers to construct multi-exit networks, thereby validating the effectiveness and generality of the MLCSD strategy on these networks.

  3. We conduct experiments on three datasets, and the experimental results show the effectiveness of the MLCSD in two typical adaptive inference applications: anytime prediction and budgeted batch classification.

Related work

Computationally efficient deep networks

Lightweight network model

In general, networks with more parameters and higher computational costs tend to outperform those with fewer parameters and lower computational costs. However, the excessive computations associated with deep networks pose deployment challenges in practical applications, particularly in time-sensitive and resource-limited scenarios. A direct and effective way to improve network computing efficiency is through the design of lightweight models. SqueezeNet [25] reduces the necessary number of channels and parameters by using compression and expansion layers. MobileNet [26] combines depthwise and pointwise convolution to form a depthwise separable convolution to replace general convolutions, reducing the computational costs. ShuffleNet [27] further reduces the computational costs by proposing pointwise group convolution and channel shuffling techniques. EspNet [28] combines pointwise convolution with spatial pyramid dilated convolution, reducing the number of parameters and computations while increasing the receptive field. GhostNet [29] first employs traditional convolution to generate feature maps with fewer channels, then further uses a depthwise convolution to reduce the computational costs, and finally integrates two groups of feature maps for inference. These methods mainly enhance standard convolutions in terms of the channel numbers and sparse connections between the convolution channels, aiming to reduce the network parameter sizes and improve the network inference speeds without sacrificing network performance.

Model compression and acceleration

Compressing the number of parameters or the scale of an existing network is also an effective way to improve network computational efficiency. Reference [7] leverages network pruning to remove parameters that contribute less to the network, thereby reducing the total number of parameters and accelerating the inference speed. Reference [10] utilizes knowledge distillation technology to encourage student networks to mimic teacher networks, enabling the student networks to achieve superior generalization performance and higher inference accuracy. Reference [15] quantizes network weights and uses low-precision bits to store the weights and activation outputs, which can markedly compress the network. Reference [28] applies a low-rank decomposition to the network’s convolution kernels and rearranges the parameter order to reduce memory consumption. These methods compress existing models by exploiting the redundancies of neural networks in different aspects, effectively reducing network parameters and computational costs to improve computational efficiency.

Adaptive inference networks

Adaptive inference [30] offers an effective mechanism for dynamically balancing accuracies and computational costs. One intuitive approach involves cascading multiple models with different complexities [31, 32]. When the inference confidence of one model meets the preset threshold, the sample exits; otherwise, it proceeds to the subsequent, more complex network for inference. However, these cascade networks are independent and entail higher training times and storage costs. During testing, difficult samples undergo successive processing by several models, leading to increased computational costs.

A more efficient approach involves adding multiple output branches to one backbone network and dynamically adjusting the network structure through a width or depth adaptation, striking a balance between accuracy and computational costs. References [33, 34] achieve adaptive inference in the width direction by dynamically activating neurons with auxiliary branch structures. Moe [35] considers multiple branches of a network built in parallel as experts, and selectively activates them to complete a width adaptation. HydraNet [36] replaces the last-stage convolution blocks of a convolution network with multiple branches, and selectively executes these branches during testing to achieve an adaptive width inference. In addition to width adaptation, the MSDNET [18] adds multiple intermediate classifiers at different depths to DenseNet [5]. Based on the MSDNET, RANNET [20] incorporates the sample spatial resolution adaptation module to further improve network efficiency. These studies on adaptive inference primarily focus on designing more sophisticated network architectures to improve network inference efficiency.

Table 1 Preliminary experiments of the multi-exit network

Knowledge distillation for adaptive inference

Knowledge Distillation [37] serves as a crucial technique for training efficient network models, typically transferring knowledge from a well-trained large-scale teacher model to a small-scale student model. KD [10] regards the softening result of the teacher network prediction logits as the knowledge containing the understanding of the class distribution information among the data. FitNet [11] considers the intermediate hidden layer features of the network as knowledge and uses the Euclidean distance loss to minimize the distance between the teacher and student features. AT [38] extends the constraints of FitNet and uses the teacher network’s attention map as distillation knowledge, achieving better results than FitNet. FSP [39] defines the relationship between the network feature layers as distillation knowledge and takes the inner product between the two layers as the learning objective. In addition to learning the knowledge within the instances, CC [40], SP [41], and RKD [42] also take the correlation between instances as transferable knowledge, enhancing the student models by calculating the characteristic matrix between the instances. These methods improve the performance of the student models by defining and transferring different types of knowledge.

Multi-exit networks integrate these distillation techniques for self-distillation. BYOL [21] and IMTA [19] first add multiple intermediate classifiers to the network, and then use the deepest classifier to distill the shallow classifier to improve the network’s performance. Reference [43] trains a more efficient multi-exit network by encouraging shallow classifiers to simulate a deeper classifier. Reference [44] proposes a one-stage online distillation framework to enhance the target network using multi-exit network integration instead of a complex two-stage training program. However, these methods use only the deepest layer classifier in the multi-exit network as the teacher, neglecting the effective knowledge contained in the other classifiers. This paper addresses this limitation by aggregating effective knowledge from each classifier to construct collaborative teachers. This collaborative learning approach enhances the classification accuracy of each classifier, thereby improving the inference efficiency of multi-exit networks.

Method

Motivation

Traditional knowledge distillation has conventionally focused on transferring knowledge from deeper and larger networks to shallower and smaller networks. Researchers have assumed that only the deepest classifier in a multi-exit network can transfer knowledge to the shallow classifier [21]. However, our experiments reveal that the shallow classifiers in the multi-exit network also possess transferable features and knowledge. To quantify the differences in the feature expressions and the knowledge between the shallow and deep classifiers in a multi-exit network, we conduct comparative experiments on the CIFAR100 dataset.

Fig. 1
figure 1

The network architecture of MLCSD-Net. a We divide the backbone network into four feature extraction stages based on depth. b Add a feature reduction layer, maxpooling layer, and fully connected layer to each feature extraction stage to form a multi-exit network. c The weight module calculates the contribution weight coefficients of each classifier. d Logit-based teachers transfer logit-based knowledge to the logits of each classifier. e The feature-based teacher transfers feature-based knowledge to the last feature layer of each classifier

We use ResNet18 and ResNet34 as the backbone for the multi-exit network and add seven additional intermediate classifiers to the backbone network. In Table 1, “Ex1–Ex8” denotes eight exits, “SE” represents the original single-exit network, “1*1-conv” signifies a direct predicted intermediate classifier structure with a 1*1 convolution, and “Interpolate-conv” represents an intermediate classifier structure with multiple feature reduction layers. Table 1 shows the classification accuracy of each classifier in the multi-exit network, as well as the proportion of Ex1–Ex7 correctly classified but misclassified by Ex8 (the values in parentheses).

When employing the 1*1-conv module as the intermediate classifier, the classification accuracy of the final classifier, Ex8, is lower than that of the single-exit model. Conversely, utilizing the interpolate-conv module results in the classification accuracy of Ex8 exceeding that of the single-exit model. The results indicate that adding appropriate intermediate classifiers to the backbone network can help improve the predictive performance of the final classifier. Despite sharing a backbone network, each classifier extracts distinctive features due to differences in depth, parameter quantity, parameter weights, and feature scales. Consequently, the multi-exit network exhibits two key characteristics: (1) The shallow classifiers possess unique features and knowledge that are absent in the deepest classifier and can transfer knowledge to other classifiers. (2) An appropriate intermediate classifier structure can enhance the final classifier’s performance. Leveraging these insights, we propose a novel multi-level collaborative self-distillation strategy (MLCSD), which extracts effective knowledge from all the classifiers to construct teachers, thereby enhancing each classifier’s classification accuracy. Additionally, we design an appropriate intermediate classifier structure for MLCSD-Net, facilitating the construction of an efficient multi-exit network and improving the overall inference efficiency.

Overall framework

This paper proposes a novel multi-level collaborative self-distillation network (MLCSD-Net) based on a multi-exit network, as depicted in Fig. 1. MLCSD-Net consists of two main components: a multi-exit network module and a multi-level collaborative self-distillation module. The multi-exit network module is the basic module, consisting of a backbone network and eight classifiers. The backbone network is organized into four feature extraction stages, denoted as Stages 1–4, each equipped with two classifiers. Each classifier contains four feature extraction layers, one Maxpooling layer, and one FC layer. Each classifier outputs prediction logits, which are supervised with a cross-entropy loss \(L_{CLS}\). For more details, please refer to section “Multi-exit network architectures”.

The multi-level collaborative self-distillation module is the key component of the MLCSD-Net, designed to enhance classification performance. This module consists of a classifier weight encoding module (CWEM), a multi-level collaborative logit-based self-distillation module (MCLSM), and a multi-level collaborative feature-based self-distillation module (MCFSM). CWEM inputs the results from the Maxpooling layer of each classifier into the weight encoding module, generating attention weight coefficients that correspond to each classifier (as denoted by the weight module in Fig. 1). MCLSM initially aggregates the logits from all the classifiers to formulate a logit-based pool (depicted as the orange rectangular block in Fig. 1). The logit-based pool is then multiplied by the weight coefficients generated in the CWEM to construct logit-based teachers tailored to each classifier. Ultimately, the logit-based teachers transfer logit-based knowledge to the corresponding classifier after softmax softening. The loss function for the MCLSM is computed using the Kullback–Leibler (KL) divergence loss, denoted as \(L_{KD}\). Further details are available in section “Multi-level collaborative logit-based self-distillation”.

The MCFSM builds a feature-based knowledge pool by collecting all the classifiers’ last feature maps (the red feature block in Fig. 1). This feature-based knowledge pool is multiplied by the weight coefficients generated in the CWEM to dynamically construct feature-based teachers. The feature-based teacher guides the feature alignment of the classifier through the L2 (MSE) loss. The loss function for the MCFSM is denoted as \(L_{HT}\). Additional insights are provided in section “Multi-level collaborative feature-based self-distillation”. The comprehensive loss function of MLCSD-Net encompasses \(L_{CLS}\), \(L_{KD}\), and \(L_{HT}\). Please refer to section “Loss function” for more details.

Multi-exit network architectures

Standard convolutional networks make predictions only at the final stage, necessitating computation over the entire network for outputting the results. In contrast, multi-exit networks have several classifiers at different depths. When the predictions from a shallow classifier surpass the predefined threshold, the prediction process can be terminated, avoiding the need for further computations in subsequent deeper networks, and conserving the computational resources. Suppose we add \(K-1\) classifiers at different depths of the backbone network to form a multi-exit network with a total of \(K\) classifiers. The predictive logits \(z_{i}\) for the output of the \(i_{th}\) classifier can be defined as \(z_{i}=E_{i}\left( x, \theta _{i}\right) \), where \(x\) is the input image, \(\theta _{i}\) represents all the parameters contained in the \(i_{th}\) classifier, and \(E_{i}\) represents a fully connected layer. The final predicted probability value can be calculated by softmax:

$$\begin{aligned} \textbf{p}_{\textbf{i}}={\text {softmax}}\left( z_{i}\right) \end{aligned}$$
(1)

The total classification loss for all branch classifiers can be expressed as:

$$\begin{aligned} L_{CLS}=\sum _{i=1}^{K} \lambda _{i} L_{CE}\left( p_{i}, y\right) \end{aligned}$$
(2)

where \(L_{CE}\) is the cross-entropy loss, \(y\) is the classification label of input sample \(x\), \(\lambda _{i}\) is the weight coefficient of the \(i_{th}\) classifier. Consistent with the setting in Ref. [18], we set all the weight coefficients of the classifiers to 1.
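To make the multi-exit formulation concrete, the following PyTorch sketch shows one way to attach exit heads to shared backbone stages and compute the joint loss of Eq. (2); the names `MultiExitNet`, `backbone_stages`, and `exits` are illustrative placeholders rather than the paper's implementation.

```python
import torch.nn as nn
import torch.nn.functional as F

class MultiExitNet(nn.Module):
    """Sketch of a multi-exit wrapper: shared backbone stages, one exit per stage."""
    def __init__(self, backbone_stages, exits):
        super().__init__()
        self.stages = nn.ModuleList(backbone_stages)  # backbone split by depth
        self.exits = nn.ModuleList(exits)             # one classifier head per stage

    def forward(self, x):
        logits = []
        for stage, exit_head in zip(self.stages, self.exits):
            x = stage(x)                  # shared feature extraction
            logits.append(exit_head(x))   # z_i = E_i(x, theta_i)
        return logits                     # K logit tensors, shallow to deep

def classification_loss(logits_list, targets, lambdas=None):
    """L_CLS = sum_i lambda_i * CE(p_i, y); all lambda_i = 1, as in the paper."""
    if lambdas is None:
        lambdas = [1.0] * len(logits_list)
    return sum(w * F.cross_entropy(z, targets) for w, z in zip(lambdas, logits_list))
```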

Convolutional networks learn fine-grained and local features in the shallow layers, while gradually discerning coarse-grained and global features in the deeper layers through stacked convolution modules. Integrating a fully connected prediction layer directly into the shallow layers of the backbone network poses challenges in achieving high classification accuracy due to insufficient depth and receptive field. Moreover, the spatial noise and erroneous predictions generated during inference may propagate to the deep classifiers through the backbone network, adversely affecting the subsequent classifiers and compromising the overall network performance.

To compensate for the coarse-grained feature that shallow classifiers lack, we designed a simple and low-computational intermediate classifier (SLIC). SLIC comprises feature reduction layers and a fully connected layer. The feature reduction layer employs an improved bottleneck structure consisting of a 3*3 convolution with a stride of 2 and a 1*1 convolution. The 3*3 convolution is designed to extract features and reduce the size of the feature map, while the 1*1 convolution adjusts the number of channels. Compared to BYOL [21], the improved SLIC not only extracts sufficient coarse-grained features but also reduces the number of parameters, achieving a balance between model complexity and performance.

The number of feature reduction layers required by the intermediate classifiers at different depths varies based on the current classifier’s feature dimension \(F_{i}\) and the final classifier’s feature dimension \(F_{k}\). Given \(F_{i}/F_{k}=2^{n}\), \(n\) represents the number of reduction layers that the current classifier will stack. For example, in the ResNet18 network with the CIFAR100 dataset, the feature map size is 32 in the first stage and 4 in the final stage. Thus, the first-stage intermediate classifier needs to stack three feature reduction layers. In the second stage, only two feature reduction layers need to be stacked. All the classifiers undergo four downsampling operations, depicted by the four color blocks in Fig. 1. This approach ensures uniform feature dimensions across all classifier outputs, facilitating the extraction of the coarse-grained features and subsequent distillation operations. Due to its lightweight nature, SLIC allows the multi-exit network to extract sufficiently coarse-grained features at a minimal computational cost, thereby enhancing classification performance.
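Below is a minimal sketch of how a SLIC exit head could be assembled from the description above, with the number of reduction layers derived from the ratio of feature-map sizes; the channel widths, normalization layers, and the use of global max pooling are assumptions, not details confirmed by the paper.

```python
import math
import torch.nn as nn

def reduction_layer(in_ch, out_ch):
    """One SLIC feature reduction layer: a 3*3 stride-2 convolution halves the
    spatial size and a 1*1 convolution adjusts the channel count (the BN/ReLU
    placement is assumed)."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

def build_slic(in_ch, out_ch, feat_size, final_feat_size, num_classes):
    """Stack n = log2(F_i / F_k) reduction layers so that every exit reaches the
    same spatial resolution, then apply max pooling and a linear classifier."""
    n = int(math.log2(feat_size // final_feat_size))
    layers, ch = [], in_ch
    for _ in range(n):
        layers.append(reduction_layer(ch, out_ch))
        ch = out_ch
    return nn.Sequential(
        *layers,
        nn.AdaptiveMaxPool2d(1),    # the paper's Maxpooling layer; global pooling is assumed
        nn.Flatten(),
        nn.Linear(ch, num_classes)  # the fully connected prediction layer
    )
```

For the first-stage exit of ResNet18 on CIFAR100 (feature size 32, final size 4), `build_slic` would stack three reduction layers, matching the example above.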

Multi-level collaborative self-distillation module

The initial self-distillation network uses the deepest classifier to help the shallow classifiers improve their classification accuracy. Since the deepest classifier generally achieves the best classification accuracy, it is believed to contain the most transferable knowledge. We name this knowledge transfer mode “One for all”. However, while “One for all” can enhance the performance of shallow classifiers, its benefits to deeper classifiers are limited, and it can sometimes even degrade their performance. For example, as shown in Table 2, the accuracy of ResNet18 (One for all) decreases from 78.63% to 78.54% for Ex7 and from 78.70% to 78.53% for Ex8 after distillation. To address this issue, we propose a multi-level collaborative self-distillation (MLCSD) learning strategy that enables multiple classifiers to learn from each other. MLCSD extracts effective knowledge from all the classifiers to construct teacher models, which then transfer knowledge to each classifier to enhance its classification accuracy. We name this knowledge transfer method “All for all”. The MLCSD module comprises three components: the classifier weight encoding module, multi-level collaborative logit-based self-distillation, and multi-level collaborative feature-based self-distillation, which are detailed below.

Classifier weight encoding module

Since each classifier possesses distinctive knowledge, it may contribute differently during knowledge distillation. Therefore, we design a classifier weight encoding module (CWEM) to dynamically generate the importance weight coefficients for each classifier. The module consists of a 1*1 convolution layer, a Batch Normalization layer, and a softmax activation layer. The max-pooling results of each classifier are fed into this module, which outputs a weight matrix \(W_{ec}^{i}\) with a dimension of \(K*B*1\), where \(i\) denotes the \(i_{th}\) classifier.
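A possible realization of the CWEM is sketched below. The paper specifies a 1*1 convolution, batch normalization, and a softmax that yield a \(K*B*1\) weight matrix \(W_{ec}^{i}\) for each classifier; treating the exits as convolution channels and producing all \(K\) weight sets from a single projection are assumptions made for this sketch.

```python
import torch
import torch.nn as nn

class ClassifierWeightEncoder(nn.Module):
    """Sketch of CWEM: a 1*1 convolution, batch normalization, and a softmax
    producing, for every teacher i, a weight matrix W_ec^i of shape (K, B, 1)."""

    def __init__(self, num_exits):
        super().__init__()
        self.num_exits = num_exits
        # project the K pooled exit features to K*K per-sample scores (assumption)
        self.proj = nn.Conv1d(num_exits, num_exits * num_exits, kernel_size=1)
        self.bn = nn.BatchNorm1d(num_exits * num_exits)

    def forward(self, pooled_feats):
        # pooled_feats: (K, B, D) -- max-pooled features of the K classifiers
        K = self.num_exits
        x = pooled_feats.permute(1, 0, 2)              # (B, K, D)
        x = self.bn(self.proj(x))                      # (B, K*K, D)
        scores = x.mean(dim=2).view(-1, K, K)          # (B, K_teacher, K_source)
        weights = torch.softmax(scores, dim=2)         # normalize over source exits
        return weights.permute(1, 2, 0).unsqueeze(-1)  # (K_teacher, K_source, B, 1)
```

Emitting a separate weight set for every teacher mirrors the per-teacher coefficients reported later in Table 15.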

Multi-level collaborative logit-based self-distillation

In the multi-level collaborative logit-based self-distillation module (MCLSM), the logits \(p_{1}, \ldots , p_{K}\) of all the classifiers are fed into the logit-based pool \(L^{pool}\) to form a knowledge base \(L^{pool}\in {\mathbb {R}}^{K\times B\times C}\), where \(K\) is the number of classifiers, \(B\) is the batch size, and \(C\) is the number of categories of the dataset. The weight matrix \(W_{ec}^{i}\) is multiplied by the logit-based pool \(L^{pool}\) to obtain the logit-based teachers, each with a dimension of \(K\times B\times C\). This process is described by Formula (3):

$$\begin{aligned} \textrm{pt}_{i}=W_{e c}^{i} \odot L^{\text{ pool } } \end{aligned}$$
(3)

where \( W_{ec}^{i}\) is a learning matrix, \(\odot \) represents the multiplication of corresponding elements, \({pt}_{i}\) represents the logit-based teacher corresponding to the \({i}_{th}\) classifier. We mine the effective knowledge of each classifier and dynamically combine this knowledge through the contribution matrix to construct new logit-based teachers.

After constructing the logit-based teachers, each classifier’s logits can simulate the teacher by using \(KL\) divergence to learn the teacher’s understanding of the abstract relationship between the sample categories. The distillation loss of all the classifiers can be expressed as:

$$\begin{aligned} L_{K D}=\sum _{i=1}^{K} \theta _{i} \tau ^{2} K L\left( p_{i}^{\tau }, p t_{i}^{\tau }\right) \end{aligned}$$
(4)

where \(p_{i}^{\tau }={\text {softmax}}\left( \frac{z_{i}}{\tau }\right) \), \({z}_{i}\) is the logits result of the \({i}_{th}\) classifier, \({\tau }\) is the temperature coefficient in the knowledge distillation, and \({\theta }_{i}\) is the weight factor of the \({i}_{th}\) classifier. According to Ref. [18], we set the weight factors of all the classifiers to 1.
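The following sketch assembles the logit-based pool, the teachers of Eq. (3), and the distillation loss of Eq. (4). Reducing the weighted pool over the exit dimension to obtain a \(B\times C\) teacher per classifier, and detaching the teacher from the gradient graph, are our assumptions; the paper does not spell out these details.

```python
import torch
import torch.nn.functional as F

def logit_distillation_loss(logits_list, weights, tau=4.0, thetas=None):
    """Sketch of MCLSM (Eqs. 3-4). logits_list holds K tensors of shape (B, C);
    weights[i] is the (K, B, 1) matrix W_ec^i used to build the teacher of the
    i-th exit. Summing the weighted pool over the exit dimension to get a
    (B, C) teacher is an assumption of this sketch."""
    K = len(logits_list)
    if thetas is None:
        thetas = [1.0] * K                                   # theta_i = 1, as in the paper
    logit_pool = torch.stack(logits_list, dim=0)             # (K, B, C) logit-based pool
    loss = 0.0
    for i, z_i in enumerate(logits_list):
        teacher = (weights[i] * logit_pool).sum(dim=0)       # (B, C) logit-based teacher pt_i
        log_p_student = F.log_softmax(z_i / tau, dim=1)
        p_teacher = F.softmax(teacher.detach() / tau, dim=1) # detaching the teacher is an assumption
        loss = loss + thetas[i] * (tau ** 2) * F.kl_div(
            log_p_student, p_teacher, reduction="batchmean")
    return loss
```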

Multi-level collaborative feature-based self-distillation

The results of FitNet [11] emphasize that the hidden feature layers of the network also contain valuable knowledge, and each classifier can also benefit from feature-based knowledge distillation. However, our experiments reveal that distilling features from the final classifier to the shallow classifiers results in a performance decline rather than an improvement (as shown in Table 14). For instance, the average accuracy using final classifier logits is 76.35% (Interpolate-conv+FLSM), whereas incorporating feature-based self-distillation (Interpolate-conv+FLSM+FFSM) decreases the average accuracy to 76.19%. This decrease occurs because the number of convolutional layers stacked by each classifier varies, leading to differences in feature scales and effective receptive fields. Forcing alignment between the shallow and final classifier feature layers may transfer hidden uncertainties from the final layer to the shallow layer, adversely affecting the performance of all the classifiers.

To solve the above problem, we design a multi-level collaborative feature-based self-distillation module (MCFSM) that enhances the representation ability of the intermediate feature layer. In the MCFSM, the last feature maps before the fully connected layer of all the classifiers are fed into the feature-based pool \(F^{pool}\in {\mathbb {R}}^{K\times B\times H\times W}\), where \(K\) is the number of classifiers, \(B\) is the batch size, and \(H*W\) is the resolution of the feature map. We take the contribution coefficient matrix \(W_{ec}^{i}\) in the MCLSM as the weight, and multiply it by the feature-based knowledge pool \(F^{pool}\) to construct feature-based teachers. The feature-based teachers are represented as follows:

$$\begin{aligned} \textrm{ft}_{i}=W_{e c}^{i} \odot F^{\text{ pool } } \end{aligned}$$
(5)

where \(W_{e c}^{i}\) is the weight matrix calculated in the MCLSM, \(\odot \) is the multiplication of corresponding elements, \({ft}_{i}\) represents the feature-based teacher corresponding to the \({i}_{th}\) classifier, and the dimension of the feature-based teacher is \(K\times H\times W\). We mine the effective features from each classifier and dynamically combine these features through the contribution matrix to generate new feature-based teachers. The adaptive feature-based teacher of each classifier can help the student learn more structural knowledge and improve each classifier’s accuracy.

Table 2 Classification accuracy of MLCSD-Net on the CIFAR100 dataset

We utilize the L2 loss to constrain the distance between student and teacher features. The feature-based distillation loss for all the classifiers is shown as follows:

$$\begin{aligned} L_{H T}=\sum _{i=1}^{K} \sigma _{i} L_{2}\left( f_{i}, f t_{i}\right) \end{aligned}$$
(6)

where \(f_{i}\) is the feature map of the \({i}_{th}\) classifier and \(\sigma _{i}\) is the weight factor of the \({i}_{th}\) classifier. According to Ref. [18], we set the weight factors of all the classifiers to 1.
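Analogously to the logit branch, a sketch of the feature-based pool, the teachers of Eq. (5), and the loss of Eq. (6) is given below; flattening each exit's last feature map and summing the weighted pool over the exit dimension are again assumptions of this sketch.

```python
import torch
import torch.nn.functional as F

def feature_distillation_loss(feats_list, weights, sigmas=None):
    """Sketch of MCFSM (Eqs. 5-6). feats_list holds the K last feature maps
    (one per exit, all with the same per-sample shape thanks to SLIC);
    weights[i] is the same W_ec^i matrix used in the MCLSM."""
    K = len(feats_list)
    if sigmas is None:
        sigmas = [1.0] * K                                              # sigma_i = 1, as in the paper
    feat_pool = torch.stack([f.flatten(1) for f in feats_list], dim=0)  # (K, B, H*W)
    loss = 0.0
    for i, f_i in enumerate(feats_list):
        teacher = (weights[i] * feat_pool).sum(dim=0)                   # (B, H*W) feature-based teacher ft_i
        loss = loss + sigmas[i] * F.mse_loss(f_i.flatten(1), teacher.detach())  # L2 alignment
    return loss
```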

Loss function

The total training loss of the MLCSD-Net consists of the cross-entropy classification loss \({L}_{CLS}\) of the multi-exit module, the multi-level collaborative logit-based self-distillation loss \({L}_{KD}\) of the MCLSM, and the multi-level collaborative feature-based self-distillation loss \({L}_{HT}\) of the MCFSM, which can be expressed as:

$$\begin{aligned} L=\gamma L_{C L S}+\alpha L_{K D}+\beta L_{H T} \end{aligned}$$
(7)

where \(L\) is the total training loss, and \(\gamma \), \(\alpha \), and \(\beta \) are three hyper-parameters used to balance the three loss terms of MLCSD-Net.
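A short sketch of Eq. (7), composing the loss helpers sketched in the previous sections; the default coefficients follow the grid-search values reported later in the hyper-parameter study (\(\gamma =1\), \(\alpha =5\), \(\tau =4\), \(\beta =10\)).

```python
def total_loss(logits_list, feats_list, targets, weights,
               gamma=1.0, alpha=5.0, beta=10.0, tau=4.0):
    """Sketch of Eq. (7): L = gamma*L_CLS + alpha*L_KD + beta*L_HT, reusing the
    helper functions from the earlier sketches."""
    return (gamma * classification_loss(logits_list, targets)
            + alpha * logit_distillation_loss(logits_list, weights, tau=tau)
            + beta * feature_distillation_loss(feats_list, weights))
```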

Experiments

We select six backbone networks for MLCSD-Net: ResNet18, ResNet34, ResNet50, ResNet101, MobileNetV2, and ShuffleNetV2. These six networks include four commonly used residual networks and two mainstream lightweight networks. To evaluate the effectiveness of MLCSD-Net, we conduct extensive experiments on three classic classification datasets: CIFAR10, CIFAR100, and Tiny-ImageNet. All the experimental results are averaged over three runs.

Datasets

CIFAR10 and CIFAR100 datasets

The CIFAR10 and CIFAR100 [45] datasets comprise 50,000 training and 10,000 test images with a fixed spatial resolution of 32\(\times \)32. These images are equally distributed over 10 and 100 classes, respectively. Following the settings in Ref. [18], we select 5000 images from the training set as the validation set for confidence threshold selection in adaptive inference. We adopt the standard data-augmentation techniques in [4], which include randomly cropping the images to 32 \(\times \) 32 pixels after padding 4 pixels on each boundary, random horizontal flipping, and normalization using the channel mean and standard deviation.

Tiny-ImageNet dataset

The Tiny-ImageNet dataset is derived from ImageNet [46] and contains 200 image categories. It has 500 training examples, 50 validation examples, and 50 test examples per class. All the images in the dataset are preprocessed and downsampled to 64 \(\times \) 64 pixels, while their original size is 256 \(\times \) 256. The data augmentation method is the same as that used for the CIFAR dataset.

Table 3 Classification accuracy of MLCSD-Net on the CIFAR10 dataset
Table 4 Classification accuracy of MLCSD-Net on the Tiny-ImageNet dataset

Implementation details

In this paper, all the experiments are implemented in PyTorch 1.7.0 in the Python 3.7.0 environment and performed on one NVIDIA Tesla V100 GPU with 32 GB memory. The proposed MLCSD-Net and the compared methods are trained end-to-end using the stochastic gradient descent (SGD) optimizer, with an initial learning rate of 0.05 and a mini-batch size of 64. Each model is trained for 240 epochs, and the learning rate decays by a factor of 0.1 every 30 epochs after epoch 150. Top-1 accuracy is the performance evaluation metric for the model.
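For reference, a training-loop sketch matching the schedule above is shown below (MultiStepLR milestones at epochs 150, 180, and 210 correspond to a decay of 0.1 every 30 epochs after epoch 150); the momentum and weight-decay values are assumptions, since the text does not state them.

```python
import torch

def train_mlcsd(model, train_loader, loss_fn, epochs=240):
    """Training-loop sketch: SGD with an initial learning rate of 0.05, decayed
    by 0.1 at epochs 150, 180, and 210. Momentum and weight decay are assumed
    values not stated in the text."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05,
                                momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[150, 180, 210], gamma=0.1)
    for _ in range(epochs):
        for images, targets in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model, images, targets)  # e.g. the Eq. (7) total loss
            loss.backward()
            optimizer.step()
        scheduler.step()
```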

Adaptive inference setting

MSDNET [18] is the first paper to apply multi-exit networks to the field of adaptive inference. Consistent with the settings of MSDNET, this paper conducts adaptive inference experiments in two distinct scenarios, anytime prediction and budgeted batch classification.

In the anytime prediction scenario, the network can produce predictions at any given point in time, making a series of predictions between the initial and final classifications. Test samples pass through all the classifiers sequentially until the time budget is exhausted, the latest inference results are returned, or the output results meet a predefined threshold. When the time budget is limited, samples are mainly predicted by shallow classifiers. Conversely, when the time budget is sufficient, samples are predicted by a deeper classifier, which offers better performance.

In the budgeted batch classification scenario, samples are sequentially fed into each classifier for prediction. If the confidence of the predicted result of the \(i_{th}\) classifier exceeds its threshold, the current result is considered final and the sample does not proceed to subsequent classifiers. Given the computational budget, the prediction threshold for each classifier can be calculated from the validation set. During testing, the predictive confidence of each sample is compared to the threshold. If the predictive confidence exceeds the predefined threshold of the current classifier, the prediction is concluded. If none of the classifiers’ confidence surpasses the threshold, the final classifier provides the prediction. The average classification accuracy of all the test samples is then calculated to determine the inference accuracy under the current budget. For more detailed information about budgeted batch classification, please refer to Ref. [18].
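A sketch of the confidence-thresholded early-exit rule described above is shown below; it assumes single-sample inference for clarity and reuses the convention that the model returns a list of logits ordered from the shallowest to the deepest exit.

```python
import torch

@torch.no_grad()
def budgeted_batch_predict(model, x, thresholds):
    """Early-exit sketch: query exits from shallow to deep and stop at the first
    exit whose top-1 confidence reaches its threshold; otherwise fall back to
    the deepest classifier."""
    logits_list = model(x)                       # list of K logit tensors, shallow to deep
    for logits, thr in zip(logits_list, thresholds):
        probs = torch.softmax(logits, dim=1)
        conf, pred = probs.max(dim=1)
        if conf.item() >= thr:                   # assumes a batch of one sample
            return pred.item()
    return logits_list[-1].argmax(dim=1).item()  # deepest exit gives the final answer
```

In a real deployment the backbone stages would be executed incrementally so that deeper stages are skipped once an exit fires; running the full forward pass here only keeps the sketch short.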

Multi-level collaborative self-distillation network

To validate the effectiveness of MLCSD-Net, we perform experiments on three datasets: CIFAR10, CIFAR100, and Tiny-ImageNet. The experimental results are presented in Tables 2, 3 and 4. In these tables, “Backbone” refers to the selected backbone network, and “SE” is the result of training the corresponding backbone network as a single-exit model. “Method” indicates the training method used, including “MENM”, “One for all”, and “All for all”. “MENM” represents the basic multi-exit network module based on the SLIC structure proposed in this paper. “One for all” represents the distillation technology using the deepest classifier to distill the shallow classifiers, and “All for all” represents the distillation technology using the MLCSD proposed in this paper. “F(G)” represents computational costs. “Ex1–Ex8” represents the classification results of the eight classifiers of a multi-exit network. To facilitate network performance comparison, we use “Avg” to denote the average classification accuracy of the eight classifiers.

From Tables 2, 3 and 4, we can draw the following conclusions: (1) The accuracy of the final classifier exceeds that of the corresponding single-exit network after adding the proposed intermediate classifiers, which indicates the effectiveness of our design. (2) The “One for all” distillation technology exhibits a more obvious improvement on the shallow classifiers (Ex1–Ex4), while the deep classifiers (Ex5–Ex8) show no marked improvement or even a slight decline. (3) “All for all” enhances the classification performance of each classifier, including the deep classifiers. (4) Averaging the performance of the six networks on the CIFAR100 dataset, “All for all” achieves 1.36% higher accuracy than “MENM”, and “All for all” surpasses “One for all” by 0.87%.

Fig. 2
figure 2

Experimental results of anytime prediction mode on the CIFAR100 dataset (ResNet18 and ResNet34 as the backbone network)

Fig. 3
figure 3

Experimental results of the anytime prediction mode on the Tiny-ImageNet dataset (ResNet18 and ResNet34 as the backbone network)

We calculate and compare the time complexity of the different models. The self-distillation strategy is applied only during the training phase; the construction of teachers and the knowledge transfer process increase the network’s training cost. However, during the inference phase, the network does not generate additional parameters or incur extra computational costs. Consequently, both self-distillation networks and MLCSD-Nets have the same time complexity as the baseline multi-exit networks. Therefore, MLCSD can further improve the overall performance of adaptive inference networks without increasing computational costs. Detailed experimental results are shown in Table 2. We also calculate the p-values for MLCSD-Net (ResNet18) compared to the self-distillation method and the baseline multi-exit network using a two-tailed Welch’s t-test [47]. The p-values are \( 2.58 \times 10^{-5} \) and \( 8.31 \times 10^{-8} \), respectively. Both values are below the 0.05 threshold, indicating that the improvement from the multi-level self-distillation learning strategy is statistically significant.

Anytime prediction results

In the anytime prediction scenario, we perform experiments on the ResNet18 and ResNet34 networks using the CIFAR100 and Tiny-ImageNet datasets. The experimental results are illustrated in Figs. 2 and 3. The horizontal axis represents the network’s computational costs, and the vertical axis represents the classification accuracy. The blue line represents the basic multi-exit network, the black line represents the “One for all” distillation technology, and the red line represents the “All for all” distillation technology.

From the figures, it is evident that the accuracy of the “All for all” approach surpasses that of the “One for all” approach across all budgets, particularly at the lowest and highest budgets. For example, on the CIFAR100 dataset, when the budget is between 0.05 and 0.3 GFLOPs, the accuracy of ResNet18 (All for all) improves by 1.96–3.43% compared to that of ResNet18 (MENM). On the Tiny-ImageNet dataset, when the budget is 0.61 GFLOPs, ResNet34 (One for all) exhibits a 0.90% increase in accuracy compared to ResNet34 (MENM), while ResNet34 (All for all) achieves an even larger improvement of 1.32%. We also perform experiments on the ResNet50, ResNet101, MobileNetV2, and ShuffleNetV2 networks, as well as on the CIFAR10 dataset. The detailed experimental results are shown in the appendix.

Budgeted batch classification results

In the budgeted batch classification scenario, we first train the best network model on the training set, and then use the given budget to select the best performance classification result on the validation set. This enables us to determine the threshold for each branch classifier based on the minimum predictive confidence of the correctly classified samples. For instance, when evaluating the “All for all” network based on ResNet18 with 0.1 GFLOPs, if Ex1’s output threshold \(\theta _{1}=0.93\), only samples with predicted results greater than 0.93 are retained for an accuracy calculation during testing. The remaining samples continue to be processed by subsequent classifiers.
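The threshold rule described above can be sketched as follows: for every exit, the threshold is the minimum top-1 confidence among its correctly classified validation samples; the budget-dependent selection of exits is omitted and left as an assumption of this sketch.

```python
import torch

@torch.no_grad()
def calibrate_thresholds(model, val_loader):
    """Per-exit threshold = minimum top-1 confidence among the validation
    samples that the exit classifies correctly (the rule described above)."""
    per_exit_conf = None
    for images, targets in val_loader:
        logits_list = model(images)
        if per_exit_conf is None:
            per_exit_conf = [[] for _ in logits_list]
        for i, logits in enumerate(logits_list):
            probs = torch.softmax(logits, dim=1)
            conf, pred = probs.max(dim=1)
            per_exit_conf[i].extend(conf[pred.eq(targets)].tolist())
    # an exit with no correct validation sample keeps a threshold of 1.0
    return [min(c) if c else 1.0 for c in per_exit_conf]
```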

Fig. 4
figure 4

Experimental results of the budgeted batch model of CIFAR100 and Tiny-ImageNet datasets (ResNet18 and ResNet34 as the backbone network). The left figure is the result of CIFAR100 dataset, the right figure is the result of Tiny-ImageNet dataset

Figure 4 illustrates the budgeted batch classification experiment results on the CIFAR100 and Tiny-ImageNet datasets with ResNet18 and ResNet34 as the backbone networks. In the figure, the blue and orange solid lines represent the results of the “One for all” and “All for all” networks based on ResNet18, while the red and black dotted lines represent the results of the “One for all” and “All for all” networks based on ResNet34. As shown in Fig. 4, the “All for all” method achieves the highest accuracy under identical budget conditions. Specifically, on the CIFAR100 dataset, when the budget is 0.3 GFLOPs, the classification accuracy of ResNet18 (All for all) is 2.3% higher than that of ResNet18 (One for all). When the budget is 0.5 GFLOPs, the classification accuracy of ResNet18 (All for all) is 2.8% higher than that of ResNet18 (One for all). Similar results are observed on the Tiny-ImageNet dataset.

Ablation study

We evaluate the performance of MLCSD-Net’s MENM, MCLSM, and MCFSM on the CIFAR100 and Tiny-ImageNet datasets using ResNet18 as the backbone network. The results are shown in Tables 5 and 6. In the tables, “MENM” represents the multi-exit network module that uses the SLIC structure, “MCLSM” represents the multi-level collaborative logit-based self-distillation module, and “MCFSM” represents the multi-level collaborative feature-based self-distillation module.

Table 5 Results of the ablation experiment on the CIFAR100 dataset (ResNet18 as the backbone network)
Table 6 Results of ablation experiment on Tiny-ImageNet dataset (ResNet18 as the backbone network)

We take the multi-exit network with a 1*1 convolution as its intermediate classifier as the baseline. Compared with the baseline, MENM enhances the classification accuracy by 5.7% and 5.8% on the two datasets, respectively. Adding MCLSM to MENM further enhances the accuracy by 1.75% and 2.1%, respectively. Incorporating MCFSM on top of MENM and MCLSM leads to additional increases of 0.25% and 0.5%, respectively. These results demonstrate that MLCSD-Net improves the average accuracy by 8.07% and 7.18% over the baseline network on the two datasets. The ablation experiments confirm the effectiveness of each module in enhancing the classification accuracy of the multi-exit networks.

Comparative experiments of different intermediate classifier structures

To investigate the impact of different intermediate classifiers on network performance, we conduct several comparative experiments on CIFAR100 using ResNet18 as the backbone network. The experimental results are shown in Table 7. In addition to the classification accuracy, we also focus on the changes in computational costs and the depth of the classifiers. In Table 7, “1*1-conv” is the 1*1 convolutional direct prediction structure, “Interpolate-conv” is the SLIC structure designed in this paper, and “Bottleneck-conv” is a multilayer bottleneck structure [21]. “F(G)” represents the computational costs, “P(M)” represents the number of parameters, and “Depth” represents the average depth of all the branch networks.

Table 7 Results of intermediate classifiers with different structures
Table 8 Results of different teacher integration strategies

When using the 1*1-conv structure, the classification accuracy of Ex1–Ex4 is suboptimal, and the accuracy of Ex8 decreases by 1.0% compared to that of the original ResNet18. When using the Interpolate-conv structure proposed in this paper, the classification accuracy of the shallow layers is markedly improved. Ex1 and Ex2 show an average increase of 16.28%, reducing the performance gap with the deep classifiers. Although the FLOPs and depth of the network increase by 25% and 14%, respectively, the average accuracy of all the classifiers improves by 6.06%. When using the more complex Bottleneck structure for prediction, FLOPs and depth increase by 48% and 39%, respectively, but the average accuracy of all the classifiers increases by only 6.71%. Moreover, the Bottleneck structure results in all the classifiers having the same depth as the backbone network, preventing the formation of a dynamic-depth neural network and requiring more storage space and higher computational costs.

To balance computational efficiency and classification accuracy, we choose the SLIC structure as the intermediate classifier. Additionally, we combine the three intermediate classifier structures with MLCSD, which improves the average accuracy of the network by 1.73%, 2.01%, and 2.21%, respectively. These experimental results reflect the generality and effectiveness of MLCSD.

Comparison with the other teacher integration strategy

To verify the effectiveness of the MLCSD, we perform experiments with different teacher integration strategies on the CIFAR100 using ResNet18 as the backbone network. The results are presented in Table 8. In this table, “MENM” represents the multi-exit network employing the SLIC structure, which can be regarded as the baseline method in this comparative experiment. “FLSM” represents the traditional self-distillation strategy, where the deepest classifier acts as the teacher. “Minimum” represents selecting the classifier with the minimum loss among all the classifiers as the teacher. Compared to the baseline, “FLSM” enhances the average classification accuracy by 0.96%, and “Minimum” enhances it by 0.19%. The above methods attempt to select a single classifier from among all the available classifiers to transfer knowledge and improve network performance.

In the field of knowledge distillation, combining knowledge from multiple networks to achieve multi-teacher distillation often yields better performance. When knowledge distillation is performed on a multi-exit network, the distilled classifier can treat the remaining classifiers as teachers. Therefore, we apply the multi-teacher integration strategy [48,49,50] to multi-exit networks. “Average” represents averaging the results of all the classifiers [48], “Deeper” represents that all the distilled classifiers only learn from the classifiers that are deeper than themselves, “Attention” represents that each teacher’s weight is calculated using the entropy value of the predicted results [49], “Confidence” represents that the weight of each teacher is calculated based on the magnitude of the loss [50], and “MCLSM” represents the multi-level collaborative logit-based self-distillation proposed in this paper.

For a fair and objective comparison, we only distill the logits values of the classifiers. Compared to the baseline, “Average” enhances the average classification accuracy by 1.10%, “Deeper” enhances it by 0.81%, “Attention” reduces it by 1.22%, “Confidence” increases it by 1.01%, and “MCLSM” increases it by 1.75%. The experimental results demonstrate that the MLCSD is more effective for multi-exit network distillation than the traditional multi-teacher integration strategies.

Results of multi-level collaborative self-distillation strategy on other adaptive inference networks

To further validate the effectiveness of the MLCSD, we apply it to the classical adaptive inference networks MSDNET and RANNET. For both networks, the structure is set to base = 2, step = 3, block = 8. Experiments are performed on the CIFAR100 dataset. As shown in Table 9, the average accuracies of MSDNET and RANNET improve by 2.0% and 1.67%, respectively, without increasing the network inference cost. These results demonstrate the effectiveness of the MLCSD, highlighting its potential for deployment in other adaptive inference networks.

Table 9 Results of the integration of multi-level collaborative self-distillation learning strategy on other adaptive inference networks
Table 10 Results of the multi-level collaborative self-distillation learning strategy with other distillation methods

Results of multi-level collaborative self-distillation strategy on other distillation method

With the development of distillation techniques, many new distillation methods have been designed, such as the correlation consistent distillation method CC [40], the contrastive representation distillation method CRD [51], and the decoupled knowledge distillation method DKD [13]. We replace the logit-based distillation with these new knowledge distillation methods. Experiments are conducted on the CIFAR100 dataset using ResNet18 as the backbone network, and the results are shown in Table 10. In this table, MENM represents a multi-exit network without using a distillation technique, while CC, CRD, and DKD represent self-distillation experiments based on the three new distillation methods. CC-MLCSD, CRD-MLCSD, and DKD-MLCSD denote experiments integrating MLCSD with the three new distillation methods.

The experimental results show that compared to MENM, the CC, CRD, and DKD methods achieve average accuracy improvements of 0.52%, 0.73%, and 1.06%, respectively. When applying MLCSD to these three new distillation methods, the average classification accuracies of CC-MLCSD, CRD-MLCSD, and DKD-MLCSD increase by 0.27%, 0.31%, and 0.83%, respectively. These results demonstrate that MLCSD can be applied to other distillation techniques with good versatility.

Table 11 Results of MLCSD-Net with different number of intermediate classifiers
Table 12 Results of MLCSD-Net (ResNet-18) with different hyper-parameters

Results of different number of intermediate classifiers

In the design of MLCSD-Net, we adhere to the configurations established in MSDNET [18] by incorporating seven intermediate classifiers into the backbone network. Consequently, each stage of the backbone network contains two classifiers. A multi-exit network with eight classifiers provides more exit options, enhancing the network’s flexibility and making it better suited for adaptive inference tasks such as anytime prediction and budgeted batch classification. To verify the impact of the number of intermediate classifiers on network performance, we add three intermediate classifiers to the backbone network, forming a four-exit network. The classifiers Ex1, Ex2, Ex3, and Ex4 in the four-exit network have the same network structure and size as Ex2, Ex4, Ex6, and Ex8 in the eight-exit network. We conduct experiments on CIFAR100 using ResNet18 and ResNet50 as the backbone networks, and the experimental results are shown in Table 11. The results indicate that, despite varying numbers of intermediate classifiers within the same backbone network, the classification accuracy of the classifiers at the corresponding positions remains consistent. Furthermore, the improvements observed after applying the distillation techniques are comparable. The experimental results of ResNet18 (All for all) and ResNet50 (All for all) confirm that MLCSD improves the inference efficiency of the network.

Results of different hyper-parameters

We also evaluate the hyperparameters in MLCSD-Net (ResNet-18) on the CIFAR-100 dataset, and the results are shown in Table 12. Here, “r” is the coefficient of the classification loss, “a” is the weight of the logit-based self-distillation loss, “T” is the temperature coefficient of distillation, and “b” is the coefficient of the feature-based self-distillation loss. By comparing and analyzing the results, we can see that the coefficient “r” greatly influences network performance. When \(r=0.1\), the average accuracy of MLCSD-Net is 71.99%. When \(r=1\), the network is fully trained, and the average accuracy increases to 75.39%, a 3.40% improvement over \(r=0.1\). When “a” ranges from 0.2 to 10, the average accuracy is between 76.25% and 77.14%. Notably, when \(a = 5\), the network achieves the best performance. With fixed values of \(r=1\) and \(a=5\), MLCSD-Net achieves the best accuracy of 77.40% when \(T=4\) and \(b=10\). These results show that hyperparameter selection substantially affects network performance. For a fair comparison, all experiments are performed under consistent settings, and the most appropriate hyper-parameters are obtained via grid search [52].
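A generic grid-search sketch over these hyper-parameters is given below; `train_and_eval` is a hypothetical callable that trains MLCSD-Net with one configuration and returns its average validation accuracy, and the example grid values are illustrative.

```python
from itertools import product

def grid_search(train_and_eval, grid):
    """Generic grid search over the hyper-parameters r (gamma), a (alpha),
    T (tau), and b (beta); `train_and_eval` is a user-supplied callable that
    returns the average validation accuracy for one configuration."""
    best_cfg, best_acc = None, -1.0
    for r, a, T, b in product(grid["r"], grid["a"], grid["T"], grid["b"]):
        acc = train_and_eval(gamma=r, alpha=a, tau=T, beta=b)
        if acc > best_acc:
            best_cfg, best_acc = {"r": r, "a": a, "T": T, "b": b}, acc
    return best_cfg, best_acc

# Illustrative grid (example values only, not the paper's full search space):
# grid = {"r": [0.1, 1], "a": [0.2, 1, 5, 10], "T": [2, 4], "b": [1, 10]}
```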

Discussion

Traditional self-distillation and multi-level collaborative self-distillation

Knowledge distillation typically involves selecting a large-scale, pretrained model as the teacher to impart its knowledge to a student network without further training the teacher. In self-distillation, the deepest classifier generally serves as the teacher to provide knowledge, while the shallow classifier serves as the student to receive knowledge. However, since all the classifiers share a common backbone network, the gradients and the errors generated during training will be propagated through the backbone network, affecting each classifier’s classification accuracy. This observation motivates an exploration of the impact of the multi-exit network structure on distillation.

We perform two groups of self-distillation comparison experiments on different intermediate classifier structures using ResNet18 as the backbone network. The experimental results are presented in Table 13, where “1*1-conv” represents direct prediction using a 1*1 convolution, “Interpolate-conv” represents prediction using the SLIC structure designed in this paper, and the coefficient “a” represents the weight coefficient of the logit-based self-distillation loss. The results reveal that while the accuracy of the shallow classifiers improves markedly with the 1*1-conv intermediate classifier structure, the accuracy of Ex8 decreases. This is because the deepest classifier encourages the shallow classifier to learn more coarse-grained features through distillation, markedly enhancing the shallow classifier’s performance. However, the backbone part of the shallow classifier is mainly responsible for extracting fine-grained features and transferring them forward to the deeper convolution layer. After distillation, the backbone network struggles to extract both the coarse-grained and fine-grained features simultaneously, causing feature conflicts that propagate to the deeper layers and leading to a performance decline in Ex8.

Table 13 Results of self-distillation with different intermediate classifier structures
Table 14 Results of various self-distillation strategies

When the weight of the distillation loss is greater, the accuracy of Ex8 decreases more significantly. Although the middle layer classifiers are negatively affected by this conflict, they still benefit from knowledge distillation, resulting in minor performance changes. With the SLIC structure, the backbone network of the shallow classifier continues to extract fine-grained features, while its feature reduction layer extracts coarse-grained features. In this scenario, self-distillation does not disrupt the backbone network’s feature extraction, enhancing the accuracy of both Ex8 and the shallow classifiers. These results suggest that self-distillation is effective when conducted within the same hierarchical level but has limitations otherwise. Thus, in designing the SLIC, we use sufficient downsampling operations to balance the predictive performance and efficiency.

In contrast, MLCSD dynamically constructs suitable teachers for all the classifiers, even if the classifiers are not at the same level, which can improve the classification performance of each classifier. Applying MLCSD to multi-exit networks with three different intermediate classifier structures shows that it can enhance the classification performance of each classifier. The experimental results in Table 7 demonstrate the superior applicability of the MLCSD.

Ensemble self-distillation and multi-level collaborative self-distillation

In self-distillation experiments, the outputs of all the classifiers are generally averaged as a teacher to enhance the network’s performance through ensemble self-distillation [21]. To compare the performance differences between ensemble self-distillation and multi-level collaborative self-distillation (MLCSD), we conduct experiments on the CIFAR100 dataset using ResNet18 as the backbone network. The results are shown in Table 14. In the table, “IC” represents the structure of the intermediate classifier, “1*1-conv” is a direct prediction using a 1*1 convolution, and “Interpolate-conv” is prediction using the SLIC structure designed in this paper. “KD” represents the distillation method, “FLSM” is the traditional logit-based self-distillation, “Ensemble” is the self-distillation that integrates the logits of all the classifiers, and “MCLSM” is the multi-level collaborative logit-based self-distillation proposed in this paper.

The experimental results show that ensemble self-distillation achieves high performances only when the intermediate classifier structure is “Interpolate-conv”, whereas collaborative self-distillation is not limited by this constraint. For example, compared with the 1*1-conv (FLSM) mode, the average accuracy of the 1*1-conv (Ensemble) mode decreases by 1.13%, and the average accuracy of the 1*1-conv (MCLSM) mode increases by 0.44%. In contrast, compared with Interpolate-conv (FLSM) mode, the average accuracy of Interpolate-conv (Ensemble) improves by 0.22%, and that of Interpolate-conv (MCLSM) improves by 0.79%. This discrepancy is attributed to the shallow classifiers under the 1*1-conv structure lacking coarse-grained features and not achieving a good classification performance. When all the classifiers are combined, the performance of the ensemble teacher diminishes. Unlike ensemble self-distillation, collaborative self-distillation does not simply average the output of all the classifiers but dynamically calculates more appropriate teachers by mining each classifier’s effective features and knowledge. Compared to ensemble distillation, collaborative distillation can construct more effective teachers and has broader applicability.

Table 15 Contribution coefficients of each classifier in the construction of teachers with MLCSD

Table 15 shows the weight matrix values for constructing teachers with the multi-level collaborative self-distillation learning strategy. “Teacher-Ex1” to “Teacher-Ex8” represent the corresponding teachers for the eight classifiers, and “Ex1–Ex8” represent the average contribution coefficients of the eight classifiers. In contrast to ensemble self-distillation, which averages all the classifier results (a uniform weight of 0.125), the table shows that each classifier contributes a different coefficient when constructing the corresponding teachers.

Effect of feature-based self-distillation

To verify the impact of feature-based self-distillation on network performance, we conduct comparative experiments on the CIFAR100 dataset using ResNet18 as the backbone network. The results are presented in Table 14. In this table, “FLSM+FFSM” denotes the combination of final-classifier logit-based self-distillation and final-classifier feature-based self-distillation, while “MCLSM+MCFSM” represents the combination of multi-level collaborative logit-based self-distillation and multi-level collaborative feature-based self-distillation proposed in this paper. For example, compared with the 1*1-conv (FLSM) mode, the average accuracy of the 1*1-conv (FLSM + FFSM) mode is reduced by 0.22%. In contrast, compared with the 1*1-conv (MCLSM) mode, the average accuracy of the 1*1-conv (MCLSM + MCFSM) mode is improved by 0.22%.

The experimental results indicate that the addition of feature-based self-distillation does not enhance performance when combined with traditional logit-based methods. However, it enhances accuracy when integrated within the MLCSD strategy. This underscores the efficacy of our proposed multi-level collaborative approach in leveraging feature-based self-distillation to boost classifier performance across the network.

Limitations

This paper extends the self-distillation strategy to improve the classification accuracy of all the classifiers from the perspective of network training. However, this method does not truly establish connections between the branch networks, nor does it enable deep classifiers to correct the errors of shallow classifiers. Thus, developing a multi-exit network that can gradually correct the errors of shallow classifiers is a future research direction. Furthermore, the current implementation has focused solely on image classification tasks. To establish the method’s versatility and robustness, future work will involve extending our approach to other computer vision tasks.

Fig. 5
figure 5

Experimental results of the anytime prediction mode on the CIFAR100 dataset (ResNet50 and ResNet101 as the backbone network)

Fig. 6
figure 6

Experimental results of the anytime prediction mode on the Tiny-ImageNet dataset ( MobileNetV2 and ShuffleNetV2 as the backbone network)

Fig. 7
figure 7

Experimental results of the anytime prediction mode on the CIFAR10 dataset (ResNet18 and ResNet34 as the backbone network)

Conclusion

This paper proposes an adaptive inference network called MLCSD-Net based on a multi-exit network structure. MLCSD-Net achieved excellent performance in both anytime prediction and budgeted batch classification. To build an efficient network architecture, we designed an intermediate classifier using an improved bottleneck structure to balance the model’s accuracy with its computational cost. To train MLCSD-Net efficiently, we employed multi-level collaborative self-distillation (MLCSD) to mine and transfer effective knowledge among the classifiers, thereby enhancing the classification performance of each classifier. MLCSD exhibited high generality and adaptability across the multi-exit networks with diverse intermediate classifier structures. Additionally, we applied MLCSD to classic adaptive inference networks such as MSDNET and RANNET, further improving their classification performance. On the CIFAR10, CIFAR100, and Tiny-ImageNet datasets, compared with traditional self-distillation, MLCSD-Net based on ResNet18 as the backbone network achieved improvements in classification accuracy of 0.29%, 1.18%, and 0.84%, respectively.