Introduction

With the rapid growth in the number of intelligent edge devices and the scale of data generated by the Internet of Things [1, 2], distributed training that uploads edge-generated data to a data center for centralized learning is limited by communication resources and delays. At the same time, the privacy of edge device data and the increasing complexity of deep learning models make it difficult for traditional training to make rational use of the large amount of edge data [3]. Edge training based on federated learning [4] provides a feasible way to overcome the limitations of centralized learning. Federated learning coordinates individual edge devices, which train the same model on their local data, and aggregates the parameters of each device to update the global model [5, 6].

A federated learning system typically assumes that clients are equipped with fast processors and sufficient computation capability to perform local training and update model parameters. However, most edge devices, such as mobile phones, wearables, and sensors, have limited computing and memory resources, which makes it difficult to support deep learning model training. With the large-scale production of image, video, voice, and other data on edge devices, low-capacity models can no longer cope with the data volume. The high performance and precision of deep neural networks always come at the cost of a larger model size and more computation, and transmitting CNN models with millions of parameters also places great strain on network communication. How to train a model efficiently in a federated manner on edge devices with weak computation capability while communicating efficiently has therefore become a research difficulty.

Model pruning has been widely studied as an important solution for edge inference deployment [7, 8], which indicates that there exists a subnet that can represent the performance of the entire model after training. Whereas the above research prunes pre-trained models in centralized training, our goal is to design pruning strategies during federated training to meet the computational constraints of resource-limited devices. There has been considerable research on reducing communication cost by pruning client models before uploading. However, these works only consider the communication cost and do not change the model structure to reduce computation on the clients. Recently, some studies attempt to reduce the computational requirements on the client by pruning the model during training. Liu et al. [9] and AdaptCL [10] improve learning efficiency by changing the size of the local training model on the client, but pruning on the client incurs additional computation. PruneFL [11] applies unstructured pruning during training and supports the corresponding sparse matrix computation by extending the deep learning framework, which is difficult to generalize. Federated Pruning [12] performs structured pruning on the client model, which is similar to our proposed work, but does not provide a corresponding parameter aggregation scheme for lossy pruning.

This paper proposes an efficient federated training framework based on model structure pruning to overcome the above challenges. Unlike pruning on the clients, we prune the global model on the computationally powerful server with negligible latency and reduced upstream and downstream communication. We perform structured pruning to change the structure of the delivered model without relying on dedicated acceleration. Most current federated systems adopt a synchronous scheme, in which the global model will not converge if the pruned models are directly aggregated. Therefore, we develop a new parameter aggregation and update scheme that reduces the error caused by pruning through model reconstruction and parameter reuse. Instead of changing the size of the final global model, the goal is to train a larger model on the server with smaller models on the clients. Specifically, the framework dynamically selects neurons or convolution kernels for the model, prunes the current optimal subnet before releasing the global model, and then distributes the compressed model to each client for training. We change the model structure by directly removing redundant convolution kernels or neurons to form a compact model. To summarize, the specific contributions of this article are as follows:

  • We propose an efficient federated training framework based on model structure pruning, which greatly reduces the demand for client computing and memory resources by dynamically selecting the optimal sub-model of the current global model for delivery.

  • We develop a new parameter aggregation update scheme, which provides training opportunities for global model parameters and maintains the complete model structure through model reconstruction and parameter reuse, reducing the error caused by pruning.

  • We conduct a large number of experiments on different data sets and data distributions to verify the effectiveness of the proposed framework, which reduces upstream and downstream communication while maintaining the accuracy of the global model and reducing client computing costs.

Related work

Edge inference/training based on model pruning

For resource-limited edge devices, model pruning is proposed to reduce the complexity of a neural network before it is deployed. The earliest research on model pruning dates back to 1988, when Hanson and Pratt [13] proposed an amplitude-based pruning method for shallow fully connected networks. In recent years, [14] combined methods such as pruning, quantization, and Huffman coding to compress CNNs. Li et al. [15] proposed summing the absolute values of a convolution kernel as a criterion to measure its importance; this was the first time that the convolution kernel was used as the pruning unit to achieve network compression by changing the model structure. The studies [8, 15,16,17,18,19,20,21,22] put forward various criteria to determine the importance of convolution kernels for structured pruning, and the redundant convolution kernels are deleted or set to zero. Regarding pruning methods, iterative pruning [23], soft pruning [24] and dynamic pruning [25] were proposed to identify redundant parameters during training. Thus far, how to judge the effectiveness of the parameters in a model while minimizing the loss of accuracy remains an open problem.

Efficient federated learning

To address the bottleneck of communication delay in federated learning, much research has been carried out on gradient compression, including gradient quantization and sparsification. Quantization compresses parameters by reducing the number of bits of a floating-point value [26]. Bernstein et al. [27] proposed SIGNSGD, which only transmits the signs of each mini-batch of stochastic gradients and uses majority voting to aggregate the gradient signs of each client. Sattler et al. [28] and Xu et al. [29] proposed sparse ternary compression (STC) and ternary quantization (FTTQ), respectively, which express the model parameters as {−1, 0, 1} and greatly reduce the communication overhead. Other variants of the quantized gradient scheme include three-value quantization [30], variance-reduced quantization [31], error compensation [32] and gradient difference quantization [33, 34].

Sparsification is equivalent to the client discarding part of the parameters before communication. Strom [35] set a fixed threshold, and a parameter was uploaded only when its gradient exceeded the threshold. Dryden et al. [36] sparsified the gradient using a fixed ratio, while [37] simplified the gradient update with a single threshold on the absolute value, removing the R% of gradients with the smallest absolute values. AdaComp [38] is based on localized selection of gradient residues and automatically tunes the compression rate depending on local activity. The framework proposed in this paper is completely orthogonal to the above research, greatly reducing edge-side computation while compressing both upstream and downstream communication. In addition, the combination of knowledge distillation and federated learning is increasingly used to reduce the amount of computation on the client side. Xing et al. [39] proposed an efficient federated distillation learning system for multitask time series classification. If there is a large capacity gap between the large teacher model and the small student model, the student model may struggle to learn well, so the server and client models must be designed carefully; the studies [40,41,42] designed the models by solving an optimization problem.

Approach

As described in [6, 43], in a federated learning system, each edge device trains its local model \(W_{k}\) on local data and sends the trained model to the server:

$$\begin{aligned} W_{k}(t)=W_{k}(t-1)-\eta \nabla f_{k}\left( W_{k}(t-1)\right) . \end{aligned}$$
(1)

The server updates the global model \(W_{g}\) by aggregating the model parameters of each client,

$$\begin{aligned} {W}_{g}=\sum _{k=1}^{K} p_{k} {W}_{k}, \end{aligned}$$
(2)

where \(p_{k} \ge 0, \sum _{k=1}^{K} p_{k}=1\). Therefore, there are two models in the FL system: the local model, maintained by each edge device, and the global model, updated by the central server. If pruning is applied to a federated learning system, the first thing to consider is whether to prune the local model on the client or the global model on the server. In addition, how should the model be pruned, and how should the parameters of the pruned model be updated?
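As a minimal illustration of Eqs. (1) and (2), the following PyTorch-style sketch performs one round of local SGD on a client followed by weighted averaging on the server. The helper names (`local_update`, `aggregate`) and hyperparameters are illustrative assumptions, not the authors' released code.

```python
import copy
import torch

def local_update(global_model, loader, lr=0.1, epochs=1):
    """Eq. (1): a client trains a copy of the global model on its local data."""
    model = copy.deepcopy(global_model)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model.state_dict()

def aggregate(client_states, weights):
    """Eq. (2): W_g = sum_k p_k * W_k with p_k >= 0 and sum_k p_k = 1.
    Sketch for floating-point parameters only; integer buffers such as
    BatchNorm counters would need separate handling."""
    agg = copy.deepcopy(client_states[0])
    for key in agg:
        agg[key] = sum(p * s[key].float() for p, s in zip(weights, client_states))
    return agg
```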

Fig. 1
figure 1

Comparison of model pruning on the server and client. a The typical model pruning process. b Local pruning, where each client uploads its pruned model after pruning and the models are aggregated on the server. c Server pruning, in which pruning is performed before the model is issued, only the subnet is trained locally, all trained parameters are uploaded, and the model is aggregated and updated on the server

Where to prune?

Typical model pruning includes three stages [44]: (1) training a large, over-parameterized model (sometimes a pre-trained model), (2) pruning the trained large model according to certain criteria, and (3) fine-tuning the pruned model to restore the lost performance, as shown in Fig. 1a. The object of edge inference based on network pruning is a trained model, whereas in federated learning the object of pruning is a model still in training. Whether to prune the local model on the client or the global model on the server, and how to ensure the convergence of the final model while preserving the effectiveness of pruning, are the first considerations.

Without changing the structure of the global model, two models are maintained on the server: the global model and the pruned model. Figure 1b shows pruning on the client: (1) the global model \(W_{g}\) is issued, and the client trains for several epochs on local data, (2) the pruned model \(W_{k}^{P}\) is obtained by pruning the trained model \(W_{k}\), and (3) each client uploads its pruned model \(W_{k}^{P}\), and the models are aggregated on the server side. Figure 1c shows pruning on the server: (1) the server prunes the global model \(W_{g}\) and sends the pruned model \(W_{g}^{P}\), (2) the client trains the pruned model for several epochs, and (3) the client uploads the trained model, and the global model is aggregated and updated on the server.

All previous sparsity-based studies train the full model and then select the larger or more important gradient updates to upload. Based on current research, local pruning, in which parameters that do not need to be uploaded are directly pruned instead of sparsely encoded, is the most readily evaluated federated learning pruning scheme. However, we mainly focus on the problem of insufficient computing resources on edge devices, not just the communication problem. We discuss the pruning location from the aspects of fine-tuning, computation, communication, and model structure:

  1. Fine-tuning method: Fine-tuning is an essential step to restore network performance after pruning. In federated learning, the data reside only on the clients, so fine-tuning can only be performed on the clients. As seen in Fig. 1, in local pruning, the client models are aggregated right after pruning, and the aggregated model (not the pruned model) is then trained. If fine-tuning is done locally, the amount of computation increases greatly; if not, part of the training information is lost, introducing error into federated learning. In server pruning, however, pruning is followed by model training: local training on the client is equivalent to the fine-tuning step in model pruning, that is, the fine-tuning of the original pruning pipeline is transferred to the client.

  2. Computation: Model pruning is meant to solve the problem of insufficient local computation capability. Current pruning methods are mostly data-driven and need to traverse all the parameters of a model, so pruning on the client instead increases the local computation requirement.

  3. Communication: Server pruning can reduce upstream and downstream communication at the same time, while local pruning can only reduce upstream communication without changing the global model.

  4. Model structure: Because the local data distributions differ across clients, the trained models differ, and the models obtained after pruning on each client also differ (even when the pruning rate is identical). For example, a convolution kernel of client k may be pruned while the corresponding kernel of client j is not. This leads to heterogeneous models at aggregation time, which increases the difficulty of updating the parameters of the global model and introduces noise.

In summary, we choose to perform pruning on the server to reduce upstream and downstream communications without changing the original pruning process, reduce local computation, and avoid the problem of model heterogeneity.

Model pruning and mask

The basic idea of model pruning is to remove the unimportant parts. Assuming a pretrained CNN model has a set of L convolutional layers, the parameters of the \(l_\textrm{th}\) layer can be represented as a set of 3D filters \({W}_{l}=\left\{ W_{1}^{l}, W_{2}^{l}, \ldots , W_{N_{l}}^{l}\right\} \in {\mathbb {R}}^{N_{l} \times N_{l-1} \times k_{l} \times k_{l}}\), where \(N_{l}\) represents the number of filters in the \(l_\textrm{th}\) layer and \(k_{l}\) denotes the kernel size. In model pruning, \({W}_{l}\) is split into two groups, i.e., a subset \(S_{l}\) to be kept and a less important subset \(P_{l}\) to be pruned, where \(m_{l}\) and \(n_{l}\) are the numbers of important and unimportant filters, respectively. Determining which filters to prune is a combinatorial optimization problem that can be expressed as follows:

Fig. 2
figure 2

A round of the model structure pruning process in federated learning. Convolutional kernels with small importance and their corresponding feature maps are directly removed

$$\begin{aligned} \begin{aligned} \max _{\mathcal {M}_{li}} \sum _{l=1}^{L} \sum _{i=1}^{N_{l}} \mathcal {M}_{l i} \mathcal {L}\left( W_{i}^{l}\right) \\ \mathrm{s.t.} \quad \sum _{i=1}^{N_{l}} \mathcal {M}_{l i}=m_{l}, \end{aligned} \end{aligned}$$
(3)

where \(\mathcal {M}_{li} \in \{0,1\}\) is the mask of the filter or neuron, which is 1 if \({W}_{i}^{l}\) is grouped in \(S_{l}\) and 0 if \({W}_{i}^{l}\) is grouped in \(P_{l}\), and \(\mathcal {L}(\bullet )\) is the criterion for judging the importance of a filter. For current large network structures, finding a subnet that can be pruned without performance degradation is an NP-hard problem that cannot be solved exactly by searching all possible subsets. The currently popular approach is to rank the parameters by an importance criterion and delete those with low importance. For example, using the sum of the absolute values of a filter as the criterion, the importance \(\mathcal {L}\left( W_{i}^{l}\right) \) is

$$\begin{aligned} \mathcal {L}(W_{i}^{l})=\sum _{j=1}^{N_{l-1}} \sum |K_{j}|, \quad K_{j} \in {\mathbb {R}}^{k_{l} \times k_{l}}. \end{aligned}$$
(4)

Then, \(\mathcal {L}\left( W_{i}^{l}\right) \) is sorted; the \(\textrm{Top}(m_{l})\) filters with the highest importance are reserved, with corresponding \(\mathcal {M}_{li}=1\) and \(W_{i}^{l} \in S_{l}\), while the remaining filters \(P_{l}\) are pruned and their corresponding feature maps are removed at the same time, as shown in Fig. 2.
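The selection in Eqs. (3) and (4) can be sketched for a single convolutional layer as follows (PyTorch; the helper names are illustrative, and this is a simplified sketch rather than the exact implementation used in the paper).

```python
import torch

def l1_filter_importance(conv_weight):
    """Eq. (4): sum of absolute kernel values of each output filter.
    conv_weight has shape [N_l, N_{l-1}, k_l, k_l]."""
    return conv_weight.abs().sum(dim=(1, 2, 3))

def keep_mask(conv_weight, keep_ratio=0.5):
    """Eq. (3): M_li = 1 for the Top(m_l) most important filters (S_l),
    M_li = 0 for the filters to be pruned (P_l)."""
    scores = l1_filter_importance(conv_weight)
    m_l = max(1, int(keep_ratio * scores.numel()))
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[torch.topk(scores, m_l).indices] = True
    return mask

def prune_conv_pair(w_curr, w_next, mask):
    """Structural pruning: physically drop the pruned filters of this layer
    and the matching input channels of the next convolutional layer,
    rather than zeroing them out."""
    return w_curr[mask], w_next[:, mask]
```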

Model aggregation and updating

The process of parameter aggregation and updating in the federated learning edge training framework based on model pruning is shown in Fig. 3. Each circle in Fig. 3 represents a convolution kernel in a convolutional neural network or a neuron in a fully connected neural network. The server maintains two models: the global model \(W_{g}\) and the pruned model \(W_{P}\). The parameter aggregation and update of the \(\tau _\textrm{th}\) round proceeds as follows:

  1. Before issuing the global model, the server first judges the importance of the model parameters, identifies the important neurons, and updates the corresponding mask \(\mathcal {M}^{\tau }\). Neurons whose entry in \(\mathcal {M}^{\tau }\) is 0 are deleted together with their weights and biases to obtain a compressed small model \(W_{P}^{\tau }\).

  2. The pruned model \(W_{P}^{\tau }\) is sent to each client, and each client uploads all of its model parameters \(W_{k}^{\tau }\) after training.

  3. The server first aggregates the parameters uploaded by each client to obtain \(W_\textrm{FL}^{\tau }\), which is then fused with the parameters of the pre-pruning model \(W_{g}^{\tau -1}\). In \(W_{g}^{\tau }\), the previously pruned neuron parameters remain the same as in \(W_{g}^{\tau -1}\), while the unpruned neuron parameters are replaced and updated with the parameters in \(W_\textrm{FL}^{\tau }\) according to the index order of \(\mathcal {M}^{\tau -1}=1\); a code sketch of this fusion step follows Eq. (6). The specific process is shown in Fig. 3, and the entire framework of federated learning based on model pruning is shown in Algorithm 1.

    $$\begin{aligned}{} & {} W_\textrm{F L}^{\tau } \leftarrow \sum _{k=1}^{K} p_{k} W_{k}^{\tau }, \end{aligned}$$
    (5)
    $$\begin{aligned}{} & {} W_{g}^{\tau }(l, i)=\left\{ \begin{array}{ll} W_\textrm{F L}^{\tau }(l, j), &{} \textrm{if }\quad W_{g}^{\tau -1}(l, i) \in S_{l}^{\tau -1} \\ W_{g}^{\tau -1}(l, i), &{} \textrm{if} \quad W_{g}^{\tau -1}(l, i) \in P_{l}^{\tau -1}. \end{array}\right. \end{aligned}$$
    (6)
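A minimal sketch of the fusion step in Eqs. (5) and (6) for one layer is given below (PyTorch; `fuse_layer` and the toy shapes are illustrative assumptions, not the authors' code).

```python
import torch

def fuse_layer(w_prev, w_fl, kept_idx):
    """Eqs. (5)-(6) for one layer: start from the pre-pruning global weights
    W_g^{tau-1}; positions that were kept (S_l) are overwritten in index
    order with the aggregated client update W_FL^{tau}, while pruned
    positions (P_l) keep their previous-round values."""
    w_new = w_prev.clone()
    w_new[kept_idx] = w_fl  # rows of the aggregated sub-model
    return w_new

# Toy example: a layer with 6 filters, of which filters {0, 2, 5} were kept.
w_prev = torch.randn(6, 3, 3, 3)            # W_g^{tau-1}
w_fl = torch.randn(3, 3, 3, 3)              # aggregated pruned model, Eq. (5)
kept_idx = torch.tensor([0, 2, 5])
w_g = fuse_layer(w_prev, w_fl, kept_idx)    # W_g^{tau}, Eq. (6)
```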
Fig. 3
figure 3

The \(\tau \)-th and \((\tau +1)\)-th rounds of the client model and global model aggregation and update process

Convergence analysis

We assume that the pruned network parameters obtained from a given importance criterion can represent the performance of the original network, i.e., the parameters retained after pruning contribute the most to the network. Denote the gradient of the model by \(g\left( W^{\tau }\right) \) and let \(p_{i}\) be the probability that neuron i is retained in a round. The unbiased estimator of the gradient after pruning can then be written as:

$$\begin{aligned} P(g)=\left[ \mathcal {M}_{1} \frac{g_{1}}{p_{1}}, \mathcal {M}_{2} \frac{g_{2}}{p_{2}}, \ldots , \mathcal {M}_{d} \frac{g_{d}}{p_{d}}\right] , \end{aligned}$$
(7)

then the variance of \(P\left( g\left( W^{\tau }\right) \right) \) can be reformulated as:

$$\begin{aligned} \begin{aligned} \mathbb {E} \sum _{i=1}^{N}\left[ P(g)_{i}^{2}\right]&=\sum _{i=1}^{N}\left[ \frac{g_{i}^{2}}{p_{i}^{2}} \times p_{i}+0 \times \left( 1-p_{i}\right) \right] \\&=\sum _{i=1}^{N} \frac{g_{i}^{2}}{p_{i}}. \end{aligned} \end{aligned}$$
(8)

Therefore, the trade-off between \(p_{i}\) and the gradient variance can be formulated as the following optimization problem:

$$\begin{aligned} \min _{p} \sum _{i=1}^{N} p_{i} \quad \mathrm{s.t.} \quad \sum _{i=1}^{N} \frac{g_{i}^{2}}{p_{i}} \le (1+\epsilon ) \sum _{i=1}^{N} g_{i}^{2}, \end{aligned}$$
(9)

where \(0<p_{i} \le 1\) and \(\epsilon \) can control the variance increase of g. We can get the solution of Eq. (9) by introducing Lagrange multipliers \(\lambda \) and \(\mu _{i}\), as the following objective:

$$\begin{aligned} \begin{aligned}&\min _{p} \max _{\lambda } \max _{\mu } L\left( p_{i}, \lambda , \mu _{i}\right) \\&=\sum _{i=1}^{N} p_{i} +\lambda ^{2}\left( \sum _{i=1}^{N} \frac{g_{i}^{2}}{p_{i}}-(1+\epsilon ) \sum _{i=1}^{N} g_{i}^{2}\right) \\&\quad +\sum _{i=1}^{N} \mu _{i}\left( p_{i}-1\right) . \end{aligned} \end{aligned}$$
(10)

Considering the KKT conditions of the above formulation, we have:

$$\begin{aligned} 1-\lambda ^{2} \frac{g_{i}^{2}}{p_{i}^{2}}+\mu _{i}=0. \end{aligned}$$
(11)

Combined with the complementary slackness condition \(\mu _{i}\left( p_{i}-1\right) =0\), we obtain:

$$\begin{aligned} p_{i}= {\left\{ \begin{array}{ll}1, &{} \textrm{if}\quad \mu _{i} \ne 0 \\ \lambda |g_{i}|, &{} \textrm{if}\quad \mu _{i}=0\end{array}\right. }. \end{aligned}$$
(12)

It can be seen that if \(|g_{i}|\ge |g_{j}|\), then \(p_{i}\ge p_{j}\). Therefore, there is a set S with \(p_{j}=1, \forall j \in S\), whose corresponding \(|g_{j}|\) have the largest magnitudes. Assuming that the size of this set is \(k\,(0 \le k \le N)\) and the elements are ordered by magnitude, denoted \(g_{(1)}, g_{(2)}, \ldots , g_{(N)}\), we have

$$\begin{aligned} p_{(i)}= {\left\{ \begin{array}{ll}1, &{} \textrm{if}\quad i \le k \\ \lambda |g_{(i)}|, &{} \textrm{if}\quad i>k\end{array}\right. }, \end{aligned}$$
(13)

which further implies

$$\begin{aligned} \lambda =\frac{\sum _{i=k+1}^{N}|g_{i}|}{\epsilon \sum _{i=1}^{N} g_{i}^{2}+\sum _{i=k+1}^{N} g_{i}^{2}}. \end{aligned}$$
(14)

And the probability vector p is

$$\begin{aligned} p_{i}= {\left\{ \begin{array}{ll}1, &{} \textrm{if}\,\, i \in S \\ \frac{|g_{i}|\left( \sum _{j=k+1}^{N}|g_{j}|\right) }{\epsilon \sum _{j=1}^{N} g_{j}^{2}+\sum _{j=k+1}^{N} g_{j}^{2}}, &{} \textrm{if}\,\, i \notin S\end{array}\right. }. \end{aligned}$$
(15)

Assume that the gradient is \((\rho , s)\)-approximately sparse [3], i.e., there exists a subset S with \(|S|=s\) such that

$$\begin{aligned} \left\| g_{S^{c}}\right\| _{1} \le \rho \left\| g_{S}\right\| _{1}, \end{aligned}$$
(16)

where \(S^{c}\) is the complement of S. Then the expected number of nonzero entries of P(g), \(\mathbb {E}\left[ \Vert P(g)\Vert _{0}\right] \), can be bounded by

$$\begin{aligned} \begin{aligned} \mathbb {E}\left[ \Vert P(g)\Vert _{0}\right]&=\sum _{i=1}^{N} p_{i}=\sum _{i \in S} p_{i}+\sum _{i \notin S} p_{i} \\&=s+\sum _{i \notin S} \frac{|g_{i}|\left( \sum _{j=k+1}^{N}|g_{j}|\right) }{\epsilon \sum _{j=1}^{k} g_{j}^{2}+(1+\epsilon ) \sum _{j=k+1}^{N} g_{j}^{2}}\\&=s+\frac{\left\| g_{S^{c}}\right\| _{1}^{2}}{\rho \left\| g_{S} \right\| _{2}^{2}+(1+\rho )\left\| g_{S^{c}}\right\| _{2}^{2}}\\&\le s+\frac{\rho ^{2} s\left\| g_{S}\right\| _{2}^{2}}{\rho \left\| g_{S}\right\| _{2}^{2} +(1+\rho )\left\| g_{S^{c}}\right\| _{2}^{2}}\\&\le (1+\rho ) s. \end{aligned} \end{aligned}$$
(17)

Convolution kernels or neurons are dynamically selected in each round, and the gradient of the model is bounded by the above formula, thereby ensuring convergence of the model.
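For concreteness, the retention probabilities of Eqs. (13)–(15) can be computed as in the following NumPy sketch, assuming the set size k and the variance-control parameter ε are given (the function name and example values are illustrative):

```python
import numpy as np

def retention_probs(g, k, eps):
    """Eqs. (13)-(15): keep the k largest-|g| coordinates with p = 1 and set
    p_i = lambda * |g_i| elsewhere, so the variance increase of the
    estimator P(g) stays within a factor (1 + eps)."""
    g = np.asarray(g, dtype=float)
    order = np.argsort(-np.abs(g))            # indices sorted by |g|, descending
    tail = order[k:]                          # complement of the top-k set S
    lam = np.abs(g[tail]).sum() / (eps * (g ** 2).sum() + (g[tail] ** 2).sum())
    p = np.minimum(lam * np.abs(g), 1.0)      # Eq. (13), clipped at 1
    p[order[:k]] = 1.0                        # p_i = 1 for i in S
    return p

# Toy example with eps = 0.5, keeping the top-2 coordinates.
print(retention_probs([3.0, -0.2, 1.5, 0.1, -0.05], k=2, eps=0.5))
```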

In each round, we directly delete all the unimportant neurons and their corresponding weights instead of setting them to 0 as in soft training. The client only trains a subnet composed of important neurons, and only the selected neuron parameters are updated during that round. In model pruning for edge inference, neurons are permanently removed. In federated learning, however, fixing a subnet structure before the model reaches its ideal performance would seriously harm convergence. The proposed framework guarantees convergence in two ways: (1) the server maintains two models, and when the global parameters are updated, the pruned neurons retain their parameters from the previous round instead of discarding previously trained parameters; (2) in each round, the neurons that contribute the most to network performance are always selected for update, and the client uploads all trained parameters without losing any learned information, maximizing the use of local training updates and thus ensuring the convergence of training.

Algorithm 1
Fig. 4
figure 4

Distribution among classes is represented with different colors. The populations in the figure are generated from Dirichlet distributions with \(\alpha =1\) and \(\alpha =0.001\), respectively, with 30 random clients each

Experiments

Performance indicators

Training computation: In model compression, the computation required for the forward propagation of a model is used to evaluate its complexity. Model training includes forward propagation and back-propagation: the forward computation is dominated by multiplications between feature maps and weight matrices, and the backward computation by the reverse gradient computation; the MAC operations required by the two are the same. Therefore, the computation required for training the \({l}_\textrm{th}\) convolutional layer in a convolutional neural network is expressed as:

$$\begin{aligned} \textrm{FLOPs}_{l}=2 \times 2 \times N_{b} m_{b} \times C_\textrm{in} \times K^{2} \times H \times W \times C_\textrm{out}, \end{aligned}$$
(18)

where \(m_{b}\) is the minibatch size and \(N_{b}\) is the total number of minibatches. K, H and W are the size of the convolution kernel and the height and width of the feature map, respectively, and \(C_\textrm{in}, C_\textrm{out}\) are the numbers of input and output channels of the convolution layer, respectively. The computation required for training the \({l}_\textrm{th}\) layer in a fully connected neural network is:

$$\begin{aligned} \textrm{FLOPs}_{l}^\textrm{fc}=2 \times N_{b} m_{b} \times (2 \times I-1) \times O, \end{aligned}$$
(19)

where I is the input neuron number and O is the output neuron number. Therefore, the computation of a single training cycle of the network model can be expressed as: \(\textrm{FLOPs}=\sum _{l=1}^{L} \textrm{FLOPs}_{l}\).

Training memory usage: We approximate the memory required for training as the memory needed to store the weights, gradients, and the feature maps generated by the network in a single batch (activations for a fully connected network). The gradient matrix has the same size as the weight matrix. Therefore, the memory required for the \({l}_\textrm{th}\) convolutional layer is:

$$\begin{aligned} \textrm{Mem}_{l}=2 \times \left( B_{f} K^{2} \times C_\textrm{out} \times C_\textrm{in}\right) +B_{a} H \times W \times C_\textrm{out}. \end{aligned}$$
(20)

The memory required for training the \({l}_\textrm{th}\) layer in a fully connected neural network is:

$$\begin{aligned} \textrm{Mem}_{l}^\textrm{fc}=2 \times B_{f}(I \times O+1)+2 \times B_{a} O, \end{aligned}$$
(21)

where \(B_{f}\) and \(B_{a}\) are the bit widths of the weights and activations, usually 32 on an edge device. Therefore, the memory usage of a single training cycle of the model can be expressed as: \(\textrm{Mem}=\sum _{l=1}^{L} \textrm{Mem}_{l}\).

Model parameters: Since the communication time is affected by the bandwidth, we take the number of model parameters as the index to evaluate communication efficiency. In a convolutional neural network, the parameter count of the \({l}_\textrm{th}\) convolutional layer is:

$$\begin{aligned} \textrm{Param}_{l}=K^{2} \times C_\textrm{in} \times C_\textrm{out}. \end{aligned}$$
(22)

The parameter count of the \({l}_\textrm{th}\) layer in a fully connected neural network is:

$$\begin{aligned} \textrm{Param}_{l}^\textrm{f c}=I \times (O+1). \end{aligned}$$
(23)

The parameters of the entire model are: \(\textrm{Param} =\sum _{l=1}^{L}\textrm{Param}_{l}\).
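The per-layer metrics of Eqs. (18)–(23) translate directly into code; the sketch below is an illustrative Python rendering of these formulas (function and argument names are ours):

```python
def conv_train_flops(n_b, m_b, c_in, c_out, k, h, w):
    """Eq. (18): forward + backward MACs of one convolutional layer per epoch."""
    return 2 * 2 * n_b * m_b * c_in * k ** 2 * h * w * c_out

def fc_train_flops(n_b, m_b, i, o):
    """Eq. (19): training FLOPs of one fully connected layer per epoch."""
    return 2 * n_b * m_b * (2 * i - 1) * o

def conv_train_mem(c_in, c_out, k, h, w, b_f=32, b_a=32):
    """Eq. (20): weights + gradients + one batch of feature maps, in bits."""
    return 2 * (b_f * k ** 2 * c_out * c_in) + b_a * h * w * c_out

def fc_train_mem(i, o, b_f=32, b_a=32):
    """Eq. (21): training memory of one fully connected layer, in bits."""
    return 2 * b_f * (i * o + 1) + 2 * b_a * o

def conv_params(c_in, c_out, k):
    """Eq. (22): parameter count of one convolutional layer."""
    return k ** 2 * c_in * c_out

def fc_params(i, o):
    """Eq. (23): parameter count of one fully connected layer."""
    return i * (o + 1)

# Model-level totals are sums over layers, e.g.
# flops = sum(conv_train_flops(...) for each convolutional layer).
```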

Models and datasets

We evaluate the effectiveness of the proposed framework on two classification tasks: (1) CIFAR10 on VGG16 and (2) MNIST on a 5-layer fully connected network in which the numbers of neurons per layer are [784, 512, 512, 256, 100, 10]. These two models are among the most widely used at present and represent typical FL tasks. Both models are verified under IID and Non-IID data distribution scenarios. The CIFAR-10 dataset contains 60,000 images (50,000 for training, 10,000 for testing) from 10 classes, and the MNIST dataset contains 60,000 training and 10,000 test greyscale images of handwritten digits of size \(28 \times 28\). We assume that the dataset of each client follows a distribution over N classes parameterized by a vector q \(\left( q_{i} \ge 0, i \in [1, N]\right. \) and \(\left. \Vert \varvec{q}\Vert _{1}=1\right) \). To obtain a set of clients with different data distributions, we generate \(\varvec{q} \sim {\text {Dir}}(\alpha )\) from a Dirichlet distribution, where \(\alpha >0\) is a concentration parameter controlling the similarity among clients. For every client, given \(\alpha \), we sample q and assign the client the corresponding number of images from the 10 classes. Figure 4 illustrates populations drawn from the Dirichlet distribution with \(\alpha =1\) and \(\alpha =0.001\), corresponding to the IID and Non-IID data distribution scenarios in this paper, respectively.
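A possible way to generate such Dirichlet-based client partitions is sketched below (NumPy; the function name, the `per_client` budget, and the pool-reuse simplification are our assumptions, not the exact sampling code used in the experiments):

```python
import numpy as np

def dirichlet_partition(labels, num_clients=30, alpha=1.0, per_client=2000, seed=0):
    """For each client, sample q ~ Dir(alpha) over the classes and draw
    `per_client` examples with those class proportions. alpha = 1 gives
    near-IID clients, alpha = 0.001 gives highly skewed (Non-IID) clients.
    Sketch only: class pools are reused, so clients may share examples."""
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    pools = {c: np.flatnonzero(labels == c) for c in classes}
    client_indices = []
    for _ in range(num_clients):
        q = rng.dirichlet(alpha * np.ones(len(classes)))
        counts = rng.multinomial(per_client, q)
        idx = np.concatenate([rng.choice(pools[c], n, replace=False)
                              for c, n in zip(classes, counts) if n > 0])
        client_indices.append(idx)
    return client_indices
```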

On the CIFAR10 dataset, we set the number of clients to 15, and all clients participate in training. The batch size is 128, the learning rate is 0.1, SGD is used for training, and the weight decay is 5e-4. On the MNIST dataset, the number of clients is 30, the batch size is 64, and SGD is used to train one epoch per round with a learning rate of 0.01. We use the global model test accuracy obtained by FedAvg [6] and the average loss of each client as the baseline.

Different pruning rates

We first evaluate the effectiveness of the proposed framework at different pruning rates and their impact on the convergence of the global model. We set the pruning ratio of the number of neurons in each layer from 30 to 80%, which decreases the computation of the corresponding model by 40–80%. For the MNIST dataset, under both the IID and Non-IID data distributions, when the pruning rate is within 70%, increasing the pruning rate has little effect on the convergence speed of the global model, and the final accuracy is still comparable to, or even higher than, full model training, as shown in Fig. 5a–d. However, when the pruning rate reaches 80%, the convergence of pruned training is noticeably slower and oscillates more, and more rounds are required to achieve the same accuracy as full model training. When the pruning rate exceeds 80%, the global model begins to diverge, which indicates that the compressed model is too small to fully fit the data.

In VGG, the pruning granularity is the convolution kernel, which differs from the fully connected neural network. We set the pruning ratio of the number of convolution kernels in each layer from 20 to 50%, which decreases the computation of the corresponding model by 35–75%, as shown in Fig. 5e–h. The figure shows that the convergence of the global model slows down as the pruning rate increases, but the final accuracy is still better than full model training. However, when the computation reduction reaches 75%, convergence becomes very slow and the model performance degrades. In conclusion, our proposed pruning training framework is effective, and the more complex the network, the more sensitive it is to pruning during training.

Fig. 5
figure 5

a and b are the average loss and test accuracy of MNIST/IID under different pruning rates, c and d are those of MNIST/Non-IID, e and f are those of CIFAR10/Non-IID

Different parameter selection criteria

In model pruning, the criterion used to judge parameter redundancy is key to the pruning performance. In the proposed framework, the criterion is also an important factor in determining whether the subnetwork represents the current global model. We evaluate the effectiveness of the proposed framework under different criteria, briefly summarized below (a code sketch of these criteria follows the list).

  • Random. Parameters are randomly discarded.

  • L1 [15]. Using the sum of the absolute values of a filter as a criterion: \(\mathcal {L}\left( W_{i}^{l}\right) =\sum \vert \mathcal {W}(i,:,:,:)\vert \).

  • L2 [15]. \(\mathcal {L}\left( W_{i}^{l}\right) =\sum \Vert \mathcal {W}(i,:,:,:)\Vert _{2}\).

  • BN mask [45]. The \(\gamma \) of \(\hat{z}=\frac{z_\textrm{in}-\mu _{\mathcal {B}}}{\sqrt{\sigma _{\mathcal {B}}^{2}+\epsilon }}; z_\textrm{out}=\gamma \hat{z}+\beta \) in a BN layer is taken as the corresponding filter's importance score, where \(z_\textrm{in}\) and \(z_\textrm{out}\) are the input and output, and \(\mu _{\mathcal {B}}\) and \(\sigma _{\mathcal {B}}\) are the mean and standard deviation of the input activations over the current minibatch \(\mathcal {B}\).

  • Similarity. Compare the similarity between filters and remove one of each similar pair: \(D^{(l)}=\textrm{dist}\left( W_{j}^{l}, W_{k}^{l}\right) , 0 \le j \le N_{l}, j \le k \le N_{l}\).
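The sketch below illustrates how these criteria could be computed for one convolutional layer (PyTorch; function names are ours, and the BN and similarity variants are simplified):

```python
import torch

def l2_importance(conv_weight):
    """L2 criterion [15]: per-filter L2 norm of W(i, :, :, :)."""
    return conv_weight.flatten(1).norm(p=2, dim=1)

def bn_importance(bn_gamma):
    """BN-mask criterion [45]: |gamma| of the BN layer following the
    convolution, one scaling factor per output channel."""
    return bn_gamma.abs()

def similarity_matrix(conv_weight):
    """Similarity criterion: pairwise distances between filters of one layer;
    one filter of each very close pair is treated as redundant."""
    flat = conv_weight.flatten(1)
    return torch.cdist(flat, flat)  # D[j, k] = dist(W_j^l, W_k^l)

def random_importance(num_filters):
    """Random criterion: an arbitrary score per filter."""
    return torch.rand(num_filters)
```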

Fig. 6
figure 6

a and e Are the average loss and test accuracy of different pruning criteria under a pruning rate of 40% on MNIST/IID, b and f are at a pruning rate of 62% on MNIST/Non-IID, c and g are at a pruning rate of 50% on CIFAR10/IID, d and h are at a pruning rate of 50% on CIFAR10/Non-IID

For the MNIST dataset, Fig. 6a and e show the average loss and test accuracy of different pruning criteria at a pruning rate of 40% on the IID data distribution; Fig. 6b and f show the average loss and test accuracy on the Non-IID data distribution at a pruning rate of 62%. As we can see, different pruning criteria have significant impacts on the convergence speed and the performance of the final global model. Pruned training converges faster than full model training, and the global model obtained in the same number of rounds has higher accuracy. Among the criteria, random pruning converges quickly in the proposed framework under both IID and Non-IID data distributions, far exceeding the other heuristic pruning criteria. Since the above pruning algorithms were proposed for CNNs, the small number of weights of a single neuron in a fully connected network can easily lead to partial neuron inactivation (discussed in detail later).

For the CIFAR10 dataset, Fig. 6c and g show the average loss and test accuracy of different pruning methods at 50% pruning of FLOPs and parameters on the IID data distribution; Fig. 6d and h show the average loss and test accuracy when the FLOPs and parameter pruning rate is 50% on the Non-IID data distribution. The pruning granularity in VGG is the convolution kernel. Except for BN mask and random, the different pruning criteria have little effect on the convergence speed and performance, showing that the proposed framework is effective for large network structures. Unlike the fully connected neural network, the global model obtained by randomly selected convolution kernels performs poorly and even diverges as the pruning rate increases. To further explore the relationship between convergence speed and subnet selection, we use the above pruning methods to obtain different subnets simultaneously in each round and select the best-performing subnet for delivery. Experiments show that this adaptive selection further accelerates convergence, which suggests a new direction for accelerating convergence further.

Table 1 Communication cost under different pruning rates
Fig. 7
figure 7

The computation, memory usage and communication of the client in a single round under different pruning rates. The line shows the accuracy change of the global model. The baseline accuracy of VGG16 trained with FedAvg on CIFAR10 is \(72.83\%\), and that of FC/MNIST is \(93.39\%\)

The efficiency of computation and communication

Our proposed framework reduces the computational and memory requirements of edge devices while improving communication efficiency. Finally, we compare the amount of computation and memory on the client under different pruning rates and the communication required to achieve a specific target accuracy. To stay closer to real scenarios, we again evaluate on the Non-IID data distribution. The computation, memory usage and communication of the client in a single round are shown in Fig. 7, and the total communication required for different pruning rates to reach the same target accuracy is shown in Table 1. Our proposed framework dramatically reduces the computation and memory usage on the client at an accuracy exceeding that of full model training, and it reduces the total amount of communication at the same time. The larger network structure is more sensitive to pruning training, and its computational reduction requires more communication rounds to compensate. The smaller network structure still maintains efficient communication while reducing computation by 80%.

Table 2 Comparison of the time taken by different algorithms to achieve the same accuracy under the same pruning rate

In addition, we compare with current state-of-the-art federated pruning algorithms, namely Federated Dropout [46], PruneFL [11], Federated Pruning [12], and AdaptCL [10]. Because the implementation methods and pruning strategies of these algorithms differ (AdaptCL prunes on the clients, while PruneFL relies on an extended framework to support sparse matrix acceleration), their relative merits cannot be measured by single-round communication time and client training time alone. We therefore compare the time taken by the different algorithms to reach the same accuracy at the same pruning rate; the results are shown in Table 2. The table shows that although some methods reduce the resource requirements of the client, they increase the overall training time. Federated Dropout has the slowest convergence because it prunes the network randomly. PruneFL requires extended library support for fine-grained pruning and has limited acceleration. Federated Pruning fails to converge rapidly because it lacks a reasonable parameter aggregation scheme. AdaptCL prunes on the clients, which accelerates training considerably but brings additional computation to the client and increases the delay. The proposed framework significantly reduces the client training time and the uplink and downlink communication time, and the proposed parameter aggregation scheme ensures stable convergence of the model and the best performance.

Ablation study

Fig. 8
figure 8

Visualization of the convolution kernel in the pruned model, where white indicates pruned and blue indicates reserved. The top shows the pruning of the third layer on FC/MNIST. The bottom shows the pruning of the fifth layer on VGG/CIFAR10

Since neurons are selected according to their importance, similar network structures tend to be selected in consecutive rounds, causing the neurons that are never selected to become inactive. In the fully connected network experiment, each neuron has only a small number of weights, which leads to errors when judging neuron importance with a data-driven pruning algorithm. Moreover, each round of local training is only one epoch, so the parameters change very little, which makes the pruned network structure similar over several rounds.

To further explore the reasons for neuron inactivation, we visualize the neuron indices pruned in each round at a pruning rate of 50% on MNIST/FC, as shown in Fig. 8. The experiment in Fig. 6 shows that the first 10 rounds converge the fastest, so we analyze the pruning of the third fully connected layer (256 neurons) of the pruned model during the first 10 rounds. White indicates pruned neurons, and blue indicates reserved neurons. The figure shows that the model structures obtained when using the L1 norm of each neuron's weight matrix as the importance criterion are very similar across rounds. The error of the pruning criterion makes federated learning fall into a local solution that trains only one subnet, leaving the other neurons in the network inactive. Random pruning, however, jumps out of this local solution and therefore shows better final performance.

We also visually analyze the pruning of the convolution kernels of the fifth layer (256 kernels) in rounds 30 through 40 of VGG16/CIFAR10. From the experiment in Fig. 6, the L1-based pruning algorithm has the fastest convergence rate and higher accuracy, achieving the expected effect of pruning by constantly seeking the optimal subnet while adjusting the subnet structure. Random pruning can eventually converge after more rounds of training, but its convergence is significantly slower. From the visual comparison of the pruning patterns, we can see that the pruning criterion has a great impact on the performance of federated learning, and training easily falls into the local solution of a single structure when the pruning criterion is not effective. Therefore, identifying the effectiveness of neurons during training will be the focus of future research.

Conclusion

To address the problems of insufficient edge client computing resources and limited communication resources in practical deployments of federated learning, this paper proposes an efficient federated training framework based on model pruning. We first discuss the question of where to prune, then analyze the convergence of pruning-based federated learning, and finally describe the detailed process of parameter aggregation and update in the entire framework. This paper applies model pruning to federated learning and proposes a corresponding parameter update scheme that ensures complete training of the model while preserving the learning information of each client. The framework greatly reduces the computational and memory requirements of local training while compressing uplink and downlink communication. Extensive experiments have verified the effectiveness of the framework.

Federated learning is still in its infancy and many challenges remain. Although we greatly reduce the computing and memory requirements for resource-constrained devices in the federated system, in practical applications the resource heterogeneity of devices, devices dropping out of training at any time, and dynamic, unknown network environments can still introduce additional delays and prevent training from converging. In future work, we will combine model pruning with reinforcement learning to select reliable participating devices in dynamic unknown network environments and to customize personalized models for devices with heterogeneous resources.