Introduction

People’s lives are increasingly dependent on smart devices such as mobile phones, which generate a large amount of data every second. As the scale of such devices continues to expand, uploading all data to the cloud for processing causes longer delays and greater pressure on the cloud. To better protect users’ data privacy and alleviate traffic congestion on the backbone network, it is reasonable to provide services close to the user or to process data at the network edge.

Edge devices are closer to the users and can collect huge amounts of data. Meanwhile, machine learning has developed rapidly, and many research breakthroughs have been made. The theory and technology of machine learning can be applied in edge computing to exploit the potential of data and improve the quality of service (QoS) for users. The combination of edge computing and machine learning has given rise to a new research direction: edge intelligence [1]. In particular, federated learning, an important branch of edge intelligence, can be deployed on the edge network so that clients and the server jointly train a neural network model while keeping data secure. Clients collect considerable data, and federated learning extracts useful information from these data. As shown in Fig. 1, each client uses local data to train the model downloaded from the server, and uploads the trained local model to the server for aggregation.

Fig. 1
figure 1

Federated learning framework

The neural network model is generally large in scale and requires multiple iterations during training. However, the memory and computing capacity of edge devices are very limited and fail to meet these requirements. This contradiction makes large models difficult to deploy on resource-constrained edge devices. Therefore, pruning strategies are usually used to lighten large models. Nevertheless, many pruning strategies in existing research are heuristic and do not consider the dynamics of weights during training, resulting in a large loss of accuracy. Additional retraining is thus required to recover accuracy, which increases the computational burden on clients [2].

In addition, the data generated by each client according to users’ habits are often non-independent and identically distributed; for example, different clients contain data with different labels. Hence, the trained models are heterogeneous, and model performance may degrade after aggregation in the server, which increases the number of communication rounds. In this case, a unified global model does not meet the needs of different users. Personalized federated learning is an approach to handle heterogeneous data, but many related works require additional data representations for personalized training, which increases the cost of training models.

Taking the aforementioned problems into account, this study trains lightweight and personalized models for clients. We adopt a dynamic weight pruning method to obtain lightweight models: the loss function is formulated as a constrained optimization problem, which is solved by the alternating direction method of multipliers (ADMM). Compared with previous personalized federated learning that relies only on data representations, this study uses masks on top of pruning to reduce the training cost. The dependence of the model on hardware is thus reduced, and the response time and communication cost between the server and clients are reduced at the same time. Personalization is realized with masks during federated training: after each round of local training, the client generates a personalized mask that records how the model was pruned, and the mask is transmitted to the server. The aggregated personalized model is returned to the clients, which makes the personalized model deployed on each client more accurate than a universal global model. Our main contributions are as follows:

(1) Weight pruning based on ADMM is used to address the problem of limited client resources. The local models are trained under sparsity constraints and then pruned to make them lightweight. The relationship between model accuracy and communication cost under different target pruning rates is investigated through experiments, so that an appropriate target pruning rate can be selected according to the performance of the edge equipment.

(2) Masking is used to build a personalized model for each client. The server dispatches subnetworks of the global model to the clients and aggregates only the unpruned parts, reducing the impact of data heterogeneity on model performance.

(3) Taking a flower identification system as an example, models are deployed in Android Studio to illustrate the practicality of the proposed method.

The rest of this study is organized as follows: Section 2 reviews the related work. Section 3 analyzes the system architecture and problem formulation of this study. In Sect. 4, the working process of lightweight and personalized federated learning is described in detail. Section 5 presents the experiments and performance evaluation. Section 6 describes a case study of flower identification. In Sect. 7, the study is summarized and some suggestions for future work are given.

Related work

In this section, we review the related work from two perspectives: lightweight models and personalization.

Regarding lightweight models, the authors of [3,4,5,6] have theoretically demonstrated their feasibility. Bau et al. concluded, by visualizing all neurons, that most key features of the data could be detected with only some of the neurons [3]. Denil et al. showed that a small number of weights could predict the remaining weights [4]. Furthermore, Shi et al. studied the energy consumption model of Graphics Processing Unit (GPU) computing and wireless transmission [5]. Wongso et al. illustrated the working principle of lightweight deep neural networks [6]. The authors of [7, 8] proposed methods to implement lightweight models. Liang et al. compressed the model to reduce communication through local representation learning [7]. Abdellatif et al. reduced communication overhead by reducing the number of communication rounds [8]. Notably, these studies ignore the fact that data do not always follow the independent and identically distributed assumption among clients.

The data among clients are heterogeneous, and different users generally have different demands; therefore, it is difficult to train a universal global model for all devices. Zhao et al. conducted a large number of experiments and used weight divergence to show that the heterogeneity of data degrades model performance [9]. Personalized federated learning is the desired solution for data heterogeneity. Common methods to construct personalized federated learning include meta-learning, transfer learning, and adaptive adjustment. Fallah et al. proposed a personalized federated learning model based on meta-learning, a variant of federated averaging (FedAvg), which aims to find an initial model that can quickly adapt to the local data of each client during training [10]. Jiang et al. proposed a two-step personalized federated learning model [11]. Schneider et al. and Zhuang et al. employed transfer learning to achieve personalized federated learning [12, 13]. Duan et al. proposed the self-balancing federated learning (FL) framework Astraea for unbalanced datasets [14]. Liao et al. addressed the challenges of heterogeneity in edge computing systems by determining appropriate local updating frequencies and network topologies [15]. However, these methods do not take into account the limited computing and communication capabilities of clients, which restricts the application scenarios of FL, especially in the edge computing paradigm.

From the aforementioned studies, few strike a balance between lightweight and personalized models. This study uses weight pruning based on ADMM to address these two issues simultaneously.

System architecture and problem presentation

System architecture

The proposed system architecture is shown in Fig. 2. The left part includes a server and N clients. The server acts as a transfer station for models, transmitting subnetworks and aggregating models, while the clients are responsible for training the local models. In the right part, each client uses the loss function L to train the model and prunes it under the iterative pruning rate \(p\%\). By cutting the unimportant elements, the lightweight property of the model is ensured. In addition, by using the mask generated by pruning, each client can obtain from the server a subnetwork that better represents the local data characteristics, and then train the personalized model.

Fig. 2
figure 2

System architecture

Problem presentation

In this study, the key operation to achieve the lightweight and personalization characteristics is pruning. To obtain better accuracy and convergence time, the loss function is formulated as an optimization problem with a sparsity requirement, expressed as Eq. (1).

$$\begin{aligned} \begin{array}{l} \mathop {\min }\limits _{{W_i}} {L_i} = f({W_i}),\\ s.t.{W_i} \in {S_i},i = 1,2,\ldots ,N. \end{array} \end{aligned}$$
(1)

The above equation indicates that the ith client aims to minimize the loss function \({L_i}\) under a sparsity constraint on the number of model weights. \({W_i}\) represents the neural network weights of the ith client, obtained by \({W_g} \odot {M_i}\), where \({W_g}\) represents the weights of the global model in the server and \({M_i}\) represents the mask of the client. \({S_i} = \left\{ {{W_i} \vert card\left( {{W_i}} \right) \le {N_i}} \right\} \) is a non-convex set requiring that the number of retained non-zero elements be less than or equal to \({N_i}\); \(card({W_i})\) denotes the number of non-zero elements of the client's neural network, and \({N_i}\) is the expected number of non-zero elements.

During model training, a regularization term is introduced into the original general loss function \(f({W_i})\). Therefore, the overall loss function is the sum of the general loss function and the regularization term, as expressed in Eq. (2).

$$\begin{aligned} \begin{array}{l} {L_i} = f({W_i}) + \mathrm{{regularization}}\\ \quad = - \frac{1}{K}\sum \limits _k {\sum \limits _{r = 1}^R {{y_{kr}}\log ({p_{kr}})} } + \lambda \vert \vert {W_i} \vert \vert _2^2. \end{array} \end{aligned}$$
(2)

The first term \(f({W_i})\) can be the cross-entropy or negative log-likelihood loss; in this study, cross-entropy loss is used. K and R represent the number of samples and the number of label categories, respectively. \({y_{kr}}\) equals 1 if the true category of sample k is r, and 0 otherwise. \({p_{kr}}\) is the predicted probability that sample k belongs to class r. \(\lambda \) in the second term is the \(L_{2}\) regularization coefficient. The details are presented in Sect. 4.
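For concreteness, the overall loss can be written as a short PyTorch sketch; the regularization coefficient lam and the function name are illustrative assumptions rather than values from the paper.

```python
import torch
import torch.nn.functional as F

def overall_loss(model, x, y, lam=1e-4):
    """Eq. (2): cross-entropy loss plus L2 regularization of all weights."""
    ce = F.cross_entropy(model(x), y)              # first term of Eq. (2)
    l2 = sum(torch.norm(w) ** 2 for w in model.parameters())
    return ce + lam * l2
```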

The traditional pruning approach does not consider constraints during training and therefore requires more iterative training to achieve high accuracy. Our aim is to train a lightweight model that can be used on edge devices. Therefore, we impose sparsity conditions, which are determined by the pruning rate and evaluated in Sect. 5. Greater penalties are applied to smaller weights so that they are pushed toward zero, which reduces the loss of model accuracy after pruning as well as the training time, thereby speeding up convergence.

In FL, clients train models locally and do not share the original data with the server, providing a stronger security protection mechanism [16]. However, the optimization goals of different clients are often inconsistent, and it is obviously unreasonable to use a unified model to meet all user needs. Against this background, personalized FL has been developed.

In comparison with general FL, personalized FL pays more attention to the needs of clients and trains a different model for each of them. In this study, every value of each mask is initialized to 1. If a parameter in the client's neural network model is set to 0, the value at the corresponding position of the mask is also set to 0. After the local models of the clients are uploaded to the server, only the weights that have not been pruned are aggregated, so as to retain the personalization of the clients as much as possible. The server then dispatches subnetworks of the global model to the clients, so that each client owns a personalized model.

The main components of the proposed model are shown in Fig. 3.

Fig. 3
figure 3

System flowchart

In Fig. 3, \({W_\mathrm{{g}}}\) represents the weights of the global model, and \({C_i}\) represents the ith client in the set \(\{ {C_1},{C_2},\ldots ,{C_n}\}\). \(M_i^t\) and \(W_i^t\) represent the mask and subnetwork weights of client \({C_i}\) in round t, respectively. The communication process between the server and client \({C_i}\) in round t consists of the following steps:

\(\textcircled {1}\) The weights of the global model \(W_g^t\) and client mask \(M_i^t\) are matched to obtain subnetwork weights \(W_i^t\);

\(\textcircled {2}\) The subnetwork weights \(W_i^t\) are dispatched to client \({C_i}\);

\(\textcircled {3}\) Client \({C_i}\) uses local data to perform ADMM weight pruning on \(W_i^t\);

\(\textcircled {4}\)-\(\textcircled {5}\) The new mask \(M_i^{t + 1}\) and local model weights \(W_i^{t + 1}\) of client \({C_i}\) obtained after the pruning in step \(\textcircled {6}\) are uploaded to the server and used in the following round of communication (the steps are carried out simultaneously);

\(\textcircled {7}\) After the server obtains all new local models, only the non-zero values are aggregated to obtain the new global model;

\(\textcircled {8}\) The global model is updated.

LPFed: Lightweight and personalized FL

The method proposed in this study consists of three main parts: (1) ADMM weight pruning, (2) aggregation of local models, and (3) derivation of subnetworks.

Weight pruning

The neural network model has obvious redundancy: many parameters contribute little or nothing to the result. Removing the unimportant parameters reduces the size of the model. Pruning is thus a very cost-effective operation that accelerates training and inference without greatly affecting model performance. In this study, ADMM weight pruning is used. It is divided into two steps: (1) locally sparse training and (2) pruning.

The loss function is expressed as a non-convex optimization problem under sparsity constraints, and ADMM is used to solve the optimization objective in Eq. (1). ADMM combines the advantages of dual ascent and the augmented Lagrangian method, decomposes the original problem into several subproblems, and solves the variables alternately. It has been widely used in large-scale distributed learning with good results [17]. Applying the ADMM algorithm to FL is equivalent to adding a dynamic regularization to the loss function, giving greater penalties to smaller weights and removing the weights below the threshold from the network. Compared with one-step pruning without sparse training [18], higher accuracy and fewer iterations are achieved.

To minimize the objective function in Eq. (1), the indicator function g(.) of \({S_i}\) is first introduced.

$$\begin{aligned} g\left( {{W_i}} \right) = \left\{ \begin{array}{l} 0 \qquad if \; \; card\left( {{W_i}} \right) \le {N_i}\\ \infty \qquad otherwise. \end{array} \right. \end{aligned}$$
(3)

We add the indicator function to Eq. (1) and convert the original inequality constraint into the equality constraint in Eq. (4).

$$\begin{aligned} \begin{array}{l} \mathop {\min }\limits _{{W_i},{Z_i}} f({W_i}) + g({Z_i})\\ s.t. \, \,{W_i} = {Z_i}. \end{array} \end{aligned}$$
(4)

Equation (4) is consistent with the standard form of ADMM, consisting of an equality constraint and two optimization variables \(W_i\) and \(Z_i\). Using the augmented Lagrangian, the above optimization problem can be decomposed into two subproblems, one in \(W_i\) and one in \(Z_i\). The \(W_i\) subproblem can be solved using stochastic gradient descent, and the \(Z_i\) subproblem can be solved analytically [19]. The augmented Lagrangian function is:

$$\begin{aligned} {L_\rho }({W_i},{Z_i},{Y_i}) = f({W_i}) + g({Z_i}) + Y_i^T({W_i} - {Z_i}) + \frac{\rho }{2} \vert \vert {W_i} - {Z_i} \vert \vert _2^2, \end{aligned}$$
(5)

where \({Y_i}\) is the Lagrange multiplier enforcing the constraint \({W_i} = {Z_i}\), and \(\rho \) is the penalty parameter.

To facilitate the solution, the scaling parameter \(U_i = \frac{Y_i}{\rho }\) is defined. Equation (5) is simplified into Eq. (6) by combining the linear and quadratic terms of the function.

$$\begin{aligned} {L_\rho }({W_i},{Z_i},{Y_i}) = f({W_i}) + g({Z_i}) + \frac{\rho }{2} \vert \vert {W_i} - {Z_i} + {U_i} \vert \vert _2^2 - \frac{\rho }{2}\vert \vert {U_i} \vert \vert _2^2. \end{aligned}$$
(6)

We apply ADMM to Eq. (6) and decompose it into three subproblems: Eqs. (7), (8), and (9). Following the ADMM algorithm, \(W_i^{k + 1}\) and \(Z_i^{k + 1}\) in Eqs. (7) and (8) are solved alternately and iteratively; at each step, only one variable is updated while the other two are kept fixed. In Eq. (7), the original parameter \(W_i^{k + 1}\) is minimized.

$$\begin{aligned} W_i^{k + 1} = \mathop {\arg \min }\limits _{{W_i}} f({W_i}) + \frac{\rho }{2} \vert \vert {W_i} - Z_i^k + U_i^k \vert \vert _2^2. \end{aligned}$$
(7)

Equation (7), in which the sparsity constraint of Eq. (6) has been incorporated into the optimization problem, gives the update for the parameter \(W_i^{k + 1}\). Focusing on the auxiliary variable \({Z_i}\) in Eq. (6), \({Z_i}\) is minimized as in Eq. (8).

$$\begin{aligned} Z_i^{k + 1} = \mathop {\arg \min }\limits _{{Z_i}} g({Z_i}) + \frac{\rho }{2} \vert \vert W_i^{k + 1}- {Z_i} + U_i^k \vert \vert _2^2. \end{aligned}$$
(8)

According to the rules of ADMM, the scaling parameter \(U_i^{k + 1}\) is updated by adding the residual \((W_i^{k + 1} - Z_i^{k + 1})\) between the weights and their sparse projection to the value \(U_i^{k}\) of the previous iteration.

$$\begin{aligned} U_i^{k + 1} = U_i^k + W_i^{k + 1} - Z_i^{k + 1}. \end{aligned}$$
(9)

In Eq. (7), the first term is the loss function and the second term is a quadratic (L2) regularization term; both are differentiable. The stochastic gradient descent method is thus used to solve for \({W_i}\):

$$\begin{aligned} \frac{{\partial {L_\rho }({W_i},Z_i^k,U_i^k)}}{{\partial {W_i}}} = \frac{{\partial f({W_i})}}{{\partial {W_i}}} + \rho ({W_i} - Z_i^k + U_i^k). \end{aligned}$$
(10)

In Eq. (3), g(.) is the indicator function of \(S_i\), whose minimum value is 0. \({S_i}\) is a non-convex set, which makes \(Z_i^{k + 1}\) difficult to compute in general. It is pointed out in [19] that \(Z_i^{k + 1}\) can be computed analytically in the special case \({S_i} = \left\{ {{W_i} \vert card\left( {{W_i}} \right) \le {N_i}} \right\} \). Moreover, it can be shown that the analytical solution of \(Z_i^{k + 1}\) is the Euclidean projection of \((W_i^{k + 1} + U_i^k)\) onto \({S_i}\). Here, \({N_i}\) is the number of weights expected to be retained after pruning, and \(Z_i^{k + 1}\) represents the retained part of the weights. The Euclidean projection keeps the \({N_i}\) elements of \((W_i^{k + 1} + U_i^k)\) with the largest magnitudes and sets the rest to 0. Therefore, \(Z_i^{k + 1}\) can be obtained by Euclidean projection, and the solution of Eq. (8) can be written as Eq. (11), where \(\prod \nolimits _{{S_i}} {(.)}\) denotes the Euclidean projection onto \({S_i}\).

$$\begin{aligned} Z_i^{k + 1} = \prod \nolimits _{{S_i}} {(W_i^{k + 1} + U_i^k)}, \end{aligned}$$
(11)

\({U_i}\) is then updated with Eq. (9); this completes one iteration of ADMM.
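The following PyTorch sketch illustrates one such ADMM iteration for a single client. The per-layer retention budget n_keep, the learning rate, and the external initialization of the Z and U dictionaries (e.g., Z from the initial weights and U as zeros) are illustrative assumptions rather than settings prescribed by the paper.

```python
import torch

def euclidean_projection(v, n_keep):
    """Projection onto {x : card(x) <= n_keep}: keep the n_keep largest-magnitude
    entries of v and set the rest to 0 (ties may keep a few extra elements)."""
    if n_keep <= 0:
        return torch.zeros_like(v)
    flat = v.flatten()
    if n_keep >= flat.numel():
        return v.clone()
    threshold = flat.abs().topk(n_keep).values.min()
    return torch.where(v.abs() >= threshold, v, torch.zeros_like(v))

def admm_iteration(model, loss_fn, data_loader, Z, U, rho, n_keep, lr=0.01):
    """One ADMM iteration on a client: W-update by SGD on Eq. (7), Z-update by the
    Euclidean projection of Eq. (11), and U-update by Eq. (9)."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for x, y in data_loader:
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        # proximal term coupling W to the sparse auxiliary variable Z (Eq. (7))
        for name, w in model.named_parameters():
            loss = loss + (rho / 2) * torch.norm(w - Z[name] + U[name]) ** 2
        loss.backward()
        opt.step()
    with torch.no_grad():
        for name, w in model.named_parameters():
            Z[name] = euclidean_projection(w + U[name], n_keep[name])  # Eq. (11)
            U[name] = U[name] + w - Z[name]                            # Eq. (9)
    return Z, U
```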

Algorithm 1
figure a

Distributed ADMM

As shown in Algorithm 1, W, Z, and U are dynamically updated during the sparse training of each client; W is the weight of the model, while Z and U correspond to the retained and pruned parts of the model, respectively.

Boyd et al. pointed out that the ADMM algorithm can solve the non-convex optimization problems of deep learning; however, it may not converge to a global optimum, and in many cases a suboptimal solution meets the requirements [19]. In the sparse training of the first step, a large number of iterations would be required to fully satisfy our sparsity requirement, which conflicts with the lightweight objective of this study. Therefore, training stops once each client reaches the predetermined number of iterations. At this point, the weights to be pruned may not be exactly zero, but many of them are close to zero. To address this, we prune them in the second step: only the \(N_i\) elements with the largest magnitudes are retained and the remaining elements are pruned (set to 0), i.e., the threshold index \(N_i\) is obtained according to the pre-defined iterative pruning rate, and the parameters below the threshold are set to 0. In later federated training, when the server selects a previously pruned client model again, the client no longer updates the pruned weights in local training. Because these weights are already very small after the sparse training of the first step, the proposed method has little impact on the accuracy of the model compared with methods without sparse training.

Algorithm 2
figure b

Client Pruning

Algorithm 2 describes the weight pruning process. First, according to Algorithm 1, ADMM sparse training is performed on each subnetwork to obtain the local model (line 1). Then, the accuracy of the local model obtained by Algorithm 1 is tested (line 2). The pruned weights in the model are frozen and are not updated in the next training. It is necessary to check whether the current accuracy and pruning rate meet the requirements (line 3) to prevent permanent damage to the model accuracy. If the conditions are met, pruning is performed and the new mask and local model are obtained. Specifically, the pruning threshold is calculated, and the mask is set to 0 wherever the corresponding parameter is below the threshold. The index \(N_i\) of the pruning threshold is calculated by Eq. (12); note that the weights need to be sorted in descending order of magnitude before the threshold index is calculated (line 4). A new local model is obtained by matching the mask with the model (line 5). Finally, the new mask and local model are uploaded to the server (line 10).

$$\begin{aligned} \mathrm{{threshold\ index}} \leftarrow \mathrm{{total\ number\ of\ weights}} \times (1 - \mathrm{{iterative\ pruning\ rate}}). \end{aligned}$$
(12)
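A minimal sketch of this thresholding for a single weight tensor is shown below. Applying it per tensor and exposing a keep_fraction argument (1 minus the pruning rate applied in the current round) are illustrative choices; how the schedule is stepped toward the target pruning rate across rounds is left to the surrounding training loop.

```python
import torch

def prune_with_mask(weights, keep_fraction):
    """Threshold of Eq. (12): sort the magnitudes in descending order, take the value
    at index (total number of weights * keep_fraction) as the threshold, and zero
    every weight below it; the returned mask records the surviving positions."""
    mags = weights.abs().flatten()
    n_keep = max(int(mags.numel() * keep_fraction), 1)      # threshold index N_i
    threshold = torch.sort(mags, descending=True).values[n_keep - 1]
    mask = (weights.abs() >= threshold).float()
    return weights * mask, mask
```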

Aggregation of local models

Fig. 4
figure 4

Aggregate local models

In each round of communication, after each client carries out ADMM weight pruning, it uploads a model with personalized features to the server. Models trained on datasets with similar labels also have similarities. In this study, the server aggregates only the unpruned parts, i.e., the parameters at positions where models overlap are averaged. If a parameter is pruned in all clients except one, it does not participate in aggregation, which reduces the impact of heterogeneous data on accuracy. The aggregation method in this article improves model performance by finding "partners" for each client, preserving the personalization embedded in the models. As shown in Fig. 4, the weights of the neural network and the corresponding masks are plotted in the form "weight: mask". In the three client models, the weights from the first neuron \({A_1}\) in the first layer to the first neuron \({B_1}\) in the second layer are 3, 0, and 2, and the corresponding masks are 1, 0, and 1. Thus, in the global model, the value obtained after aggregation is (3 + 2)/2 = 5/2.
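A minimal sketch of this masked aggregation for one weight tensor follows. The treatment of positions retained by fewer than two clients is not fully specified in the text; keeping the previous global value at such positions is an assumption made here for illustration.

```python
import torch

def aggregate_unpruned(global_w, client_ws, client_ms):
    """Average each position only over the clients whose mask is 1 there (Fig. 4).
    Positions retained by fewer than two clients do not participate in aggregation
    and keep the previous global value (assumption)."""
    w = torch.stack(client_ws)        # shape: [n_clients, ...]
    m = torch.stack(client_ms)
    counts = m.sum(dim=0)
    sums = (w * m).sum(dim=0)
    return torch.where(counts >= 2, sums / counts.clamp(min=1), global_w)
```

With the Fig. 4 example (weights 3, 0, and 2 with masks 1, 0, and 1 at one position), counts equals 2 at that position and the aggregated value is (3 + 2)/2 = 5/2.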

Derivation of subnetworks

After local training, each client uses the mask to record the pruning state of its model and uploads the mask to the server. The mask embeds information about the local data and is used to distinguish the clients' models on the server. As shown in step \(\textcircled {1}\) of Fig. 3, the server multiplies the corresponding positions of the global model and the masks to obtain the subnetworks; the subnetworks are then dispatched to the clients. Because clients may have different masks, their subnetworks may also differ, thus realizing personalized FL. In short, deriving subnetworks means that the server selects part of the network from the global model according to the characteristics of each client's local data.
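Assuming the global model and the masks are stored as dictionaries keyed by layer name, the derivation can be sketched in one line:

```python
def derive_subnetwork(global_weights, client_mask):
    """Step (1) in Fig. 3: W_i = W_g * M_i, the element-wise product of the global
    weights and the client's mask."""
    return {name: w * client_mask[name] for name, w in global_weights.items()}
```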

As shown in Algorithm 3, after the client set S is randomly selected, the server downloads different subnetworks to clients in the set.

Algorithm 3
figure c

Server Derives Subnetworks

Fig. 5
figure 5

Obtain a subnetwork

Figure 5 shows an example of a subnetwork derived by the server. For the connection from the first neuron \({A_1}\) in the first layer to the first neuron \({B_1}\) in the second layer, the global weight and the mask value are 5/2 and 1, respectively, so the corresponding weight in the subnetwork calculated by \({W_g} \odot {M_i}\) is 5/2. Similarly, the weight from \({B_1}\) to \({C_1}\) in the subnetwork is 0.

Theoretical analysis

Now, we show through the following two theorems that our method achieves the lightweight and personalized features.

Theorem 1

The communication cost of the proposed method is less than that of the traditional FL model when the pruning rate satisfies \(p\% > B/A\).

Proof

Let \({V_1}\) and \({V_2}\) be the communication costs between the server and clients of the traditional method and of our proposed method, respectively. Model parameters are transmitted in the original federated learning, whereas pruned model parameters and masks are transmitted in the method proposed in this study. Suppose A and B are the numbers of bits of a parameter and of a mask element, respectively, and \(\sum \nolimits _{i = 1}^N {\left| {{W_i}} \right| } \) is the total number of parameters in the neural networks. We obtain the following equations.

$$\begin{aligned} {V_1} = A*\sum \limits _{i = 1}^N {\left| {{W_i}} \right| }, \end{aligned}$$
(13)
$$\begin{aligned} {V_2} = A*\sum \limits _{i = 1}^N {\left| {{W_i}} \right| } *(1 - p\% ) + B*\sum \limits _{i = 1}^N {\left| {{W_i}} \right| }, \end{aligned}$$
(14)
$$\begin{aligned} {V_1}> {V_2}&\Leftrightarrow A*\sum \limits _{i = 1}^N {\left| {{W_i}} \right| }> A*\sum \limits _{i = 1}^N {\left| {{W_i}} \right| } *(1 - p\% ) + B*\sum \limits _{i = 1}^N {\left| {{W_i}} \right| } \\&\Leftrightarrow A> A*(1 - p\% ) + B \Leftrightarrow p\% > B/A. \end{aligned}$$
(15)

Obviously, when the pruning rate satisfies \(p\% > B/A\), we have \({V_1} > {V_2}\). That is to say, as long as the pruning rate of the model is greater than B/A, a lightweight model with lower communication cost is obtained. In this study, each parameter of the neural network occupies 32 bits and each mask element occupies 1 bit; the lightweight property is thus realized whenever the pruning rate \(p\% > 1/32\). \(\square \)
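A small numeric check of Eqs. (13)–(15) is given below; the parameter count is illustrative.

```python
def communication_costs(num_params, pruning_rate, param_bits=32, mask_bits=1):
    """V1 (traditional FL) and V2 (proposed method) per Eqs. (13) and (14)."""
    v1 = param_bits * num_params
    v2 = param_bits * num_params * (1 - pruning_rate) + mask_bits * num_params
    return v1, v2

# With 32-bit parameters and 1-bit masks, V2 < V1 whenever p% > 1/32 (about 3.1%).
v1, v2 = communication_costs(num_params=1_000_000, pruning_rate=0.30)
print(v1, v2, v2 < v1)   # 32000000 23400000.0 True
```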

Theorem 2

The proposed model assures that each client has different models if they have different masks.

Proof

For an arbitrary i-th client, let \({W_i}\) be its model obtained from the server. From the aforementioned analysis, we have the following equation.

$$\begin{aligned} \begin{array}{l} \forall i \in n,{W_i} = {W_g} \odot {M_i}, \end{array} \end{aligned}$$
(16)

where \({W_g}\) is the global model, and \({M_i}\) is the mask of the i-th client.

Similarly, for the j-th client, we have,

$$\begin{aligned} \begin{array}{l} \forall j \in n,{W_j} = {W_g} \odot {M_j}. \end{array} \end{aligned}$$
(17)

Note that \({M_i} \ne {M_j}\), so we get \({W_i} \ne {W_j}\), thus realizing the personalized FL.

\(\square \)

Experiments and performance evaluation

Experimental setup

Dataset and set of parameters

Two datasets are used in this study. The first is CIFAR-10, a color image dataset with 10 categories and 60,000 images in total, of which 50,000 are training images and 10,000 are test images. The second is MNIST, a grayscale dataset of handwritten digits 0–9, consisting of 60,000 training images and 10,000 test images.

In addition, PyTorch is used for the simulation experiments. The total number of clients is 200, and frac denotes the proportion of clients participating in each round of communication, i.e., \(200*frac\) clients are randomly selected for communication in each round. The total numbers of communication rounds for CIFAR-10 and MNIST are 500 and 300, respectively.
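For illustration only, the per-round client sampling might look as follows; the frac value shown is just an example.

```python
import numpy as np

num_clients, frac = 200, 0.05      # 200 * 0.05 = 10 clients per round
selected = np.random.choice(num_clients, int(num_clients * frac), replace=False)
print(sorted(selected.tolist()))
```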

Heterogeneity of data

In the simulation, each client uses local data to train the personalized model with the help of the central server. It is necessary to allocate non-independent and identically distributed data to each client before the simulation. We assign data with \(n\_class\) different labels to each client. The data partitioning is performed as follows. First, the dataset is sorted according to the labels; then, the dataset is sliced into fragments, each containing \(num\_samples\) samples. Finally, each client is assigned \(n\_class\) fragments, i.e., data with \(n\_class\) different labels. In this way, the data between the clients are heterogeneous.
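A sketch of this sort-and-shard partition is given below; the random assignment of fragments and the seed are illustrative choices (assuming there are at least enough fragments for all clients), and the scheme usually, though not strictly, yields \(n\_class\) distinct labels per client.

```python
import numpy as np

def non_iid_partition(labels, n_clients, n_class, num_samples, seed=0):
    """Sort sample indices by label, cut them into fragments of num_samples,
    and give each client n_class fragments (hence data from n_class labels)."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels)                       # sort the dataset by label
    shards = [order[i:i + num_samples] for i in range(0, len(order), num_samples)]
    shard_ids = rng.permutation(len(shards))         # random fragment assignment
    return {c: np.concatenate([shards[s] for s in shard_ids[c * n_class:(c + 1) * n_class]])
            for c in range(n_clients)}
```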

Neural network model

The neural network for CIFAR-10 has six layers: five hidden layers and one output layer. The five hidden layers consist of two convolutional layers and three fully connected layers. First, there are two \(5\times 5\) convolutional layers with 6 and 16 channels, respectively, each followed by a \(2\times 2\) max pooling layer. Next, there are three fully connected layers with 120, 84 and 10 channels, respectively.

The neural network for MNIST has five layers: four hidden layers and one output layer. The four hidden layers consist of two convolutional layers and two fully connected layers. First, there are two \(5 \times 5\) convolutional layers with 10 and 20 channels, respectively, each followed by a \(2 \times 2\) max pooling layer. Next, there are two fully connected layers with 50 and 10 channels, respectively.
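A PyTorch sketch of the two networks as described above is shown below; the ReLU activations and the flattened feature sizes (16*5*5 for 32×32 CIFAR-10 inputs and 20*4*4 for 28×28 MNIST inputs) are inferred rather than stated in the text.

```python
import torch.nn as nn
import torch.nn.functional as F

class CifarNet(nn.Module):        # two convolutional + three fully connected layers
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class MnistNet(nn.Module):        # two convolutional + two fully connected layers
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(20 * 4 * 4, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
```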

Experimental results

Accuracy under different target pruning rates

In the experiments, the iterative pruning rate refers to the ratio of the amount pruned in each round of federated training to the total number of parameters, and the target pruning rate is the cumulative ratio over all rounds. In the training of the t-th round, the parameters pruned in the \((t-1)\)-th round remain zero, because parameters are frozen after being pruned and no longer participate in the update process. During the communication between the server and clients, 10% of the parameters are pruned iteratively, and the target pruning rate is set to 10%, 20%, 30%, 40%, 50%, 60%, 70%, 80%, and 90%, respectively. We randomly select 10 clients for communication in each round, and in each client every label contains 50 samples.

To investigate how accuracy changes with the target pruning rate, we conduct ablation experiments on six models. The detailed neural network architectures of the six models are as follows.

\(\textcircled {1}\):

An original six-layer model which is the same as used in CIFAR-10;

\(\textcircled {2}\):

A five-layer model obtained by removing a fully connected layer with 120 channels from the original network structure of CIFAR-10;

\(\textcircled {3}\):

A four-layer model obtained by removing two fully connected layers with 120 and 84 channels from the original network structure of CIFAR-10;

\(\textcircled {4}\):

An original five-layer model which is the same as used in MNIST;

\(\textcircled {5}\):

A four-layer model obtained by removing a fully connected layer with 50 channels from the original network structure of MNIST;

\(\textcircled {6}\):

A three-layer model obtained by removing a \(5\times 5\) convolutional layer with 20 channels and a fully connected layer with 50 channels from the original network structure of MNIST.

Fig. 6
figure 6

Accuracies of different datasets and original complete models under target pruning rates

Fig. 7
figure 7

Accuracies of CIFAR-10 and MNIST in shallowed models under target pruning rates (Left: CIFAR-10, Right: MNIST)

Figures 6 and 7 show the accuracies of the two datasets, CIFAR-10 and MNIST, in different models under different target pruning rates. In Fig. 6, the models used are the original complete six-layer (\(\textcircled {1}\)) and five-layer (\(\textcircled {4}\)) models. The accuracy of CIFAR-10 is between 0.84 and 0.9, and the accuracy of MNIST is between 0.994 and 1.0; because the accuracy differences on MNIST are relatively small and difficult to read, the MNIST accuracies are enlarged in the subgraph. Figure 6 shows that the accuracy trend is the same for the complete models on the same dataset. Figure 7 shows the accuracies of CIFAR-10 and MNIST for the models obtained by removing one layer (\(\textcircled {2}\) and \(\textcircled {5}\)) and two layers (\(\textcircled {3}\) and \(\textcircled {6}\)) from the original complete models; the accuracy trend is again the same for the shallowed models on the same dataset. From Figs. 6 and 7, the trend of accuracy remains consistent within the same dataset, regardless of whether the model is complete or shallowed. The accuracies of CIFAR-10 increase to a maximum at a target pruning rate of 30% and then decrease; on the contrary, the accuracies of MNIST keep dropping from the target pruning rate of 10%. Therefore, it can be speculated that the variation in accuracy with the target pruning rate is not related to the model.

To further examine how accuracy changes with the target pruning rate, we introduce another dataset, EMNIST, and measure the variation of accuracy with the target pruning rate in the original complete six-layer and five-layer models, which are the same as those used for CIFAR-10 and MNIST. The complexity of EMNIST is higher than that of MNIST and lower than that of CIFAR-10. As shown in Fig. 8, the accuracies increase to a maximum when the target pruning rate is 20% and then decrease.

For CIFAR-10, the accuracy is highest when the target pruning rate is 30%; for EMNIST, when it is 20%; and for MNIST, when it is 10%. After reaching the highest accuracy, the accuracy shows a downward trend. Therefore, it can be inferred that the variation of accuracy with the pruning rate may be closely related to the dataset: the more complex the dataset, the higher the target pruning rate required to achieve the highest accuracy. As the pruning rate increases, the model retains more personalization, but an excessively high pruning rate removes important parts of the model, reducing accuracy. More complex datasets are harder to identify and require more personalization, so a higher target pruning rate is required to achieve the highest accuracy.

The communication costs of the model differ under different target pruning rates. For CIFAR-10, the accuracy is higher than 90% when the target pruning rate is 30%, but correspondingly more communication cost is needed; the same holds for EMNIST and MNIST. Pruning offers a high degree of freedom in the number of model parameters, and different target pruning rates can be selected according to the actual situation.

Model parameter distribution under different target pruning rates

As the target pruning rate increases, more parameters become 0, the parameter matrices become sparse, and the model becomes sparser. To compare the sparsity changes in the models more intuitively, we visualize the model parameter distributions of CIFAR-10 and MNIST in the unpruned and pruned (at the pruning rate with the highest accuracy) conditions. In Figs. 9 and 10, the horizontal coordinate is the weight value, and the vertical coordinate is the frequency of the weight value (how many weights take each value).

Fig. 8
figure 8

Accuracies of EMNIST in original complete models under different target pruning rates

Fig. 9
figure 9

Parameter distribution of CIFAR-10 under unpruned and target pruning rate of 30% (Left: unpruned. Right: target pruning rate of 30%. From top to bottom are convolutional layer 1, convolutional layer 2, fully connected layer 1, fully connected layer 2, and fully connected layer 3 of neural network, respectively)

Fig. 10
figure 10

Parameter distribution of MNIST under unpruned and target pruning rate of 10% (Left: unpruned. Right: target pruning rate of 10%. From top to bottom are convolutional layer 1, convolutional layer 2, fully connected layer 1 and fully connected layer 2 of neural network respectively)

Figure 9 shows the parameter changes of CIFAR-10 in five hidden layers under the unpruned condition and target pruning rate of 30%.

Figure 10 shows the parameter changes of the four hidden layers of MNIST under the unpruned condition and target pruning rate of 10%.

Accuracy and communication costs under different algorithms

Six experiments are conducted on CIFAR-10 and MNIST, including Standalone, FedAvg, local global federated averaging (LG-FedAvg) [7], and our lightweight personalized federated learning (LPFed) under three different target pruning rates. The models are described as follows.

(1) Standalone: the client uses the data to train the model on the local device without server involvement.

(2) FedAvg: this is the classic federated averaging algorithm. The server transfers the global model to each client, and the clients upload the models to the server for aggregation after local training.

(3) LG-FedAvg: a useful high-dimensional representation is learned from the local data of each client. The global model acts on this high-dimensional representation, thus reducing the communication costs and making the models lightweight.

(4) Proposed lightweight personalized federated learning (LPFed): ADMM weight pruning is carried out in each client under different target pruning rates. According to the experimental results in Fig. 6, the target pruning rates of CIFAR-10 are set to 30%, 50% and 70%; the target pruning rates of MNIST are set to 10%, 50% and 90%.

Next, we explore the impact of the number of participating clients and the amount of data in each class on model performance.

Fig. 11
figure 11

Accuracies of CIFAR-10 and MNIST under different number of participating clients

Fig. 12
figure 12

Communication costs of CIFAR-10 and MNIST under different number of participating clients

Fig. 13
figure 13

Accuracies of CIFAR-10 and MNIST under different amount of each class

Fig. 14
figure 14

Communication costs of CIFAR-10 and MNIST under different amount of each class

We randomly select 10, 15, and 20 clients in each round of federated training, respectively; in each client, each class contains 50 pictures. To explore the impact of the amount of data in each class, we randomly select 10 clients in each round of federated training, and in each client, each class contains 30, 40, and 50 pictures, respectively. The experimental results show that the proposed model outperforms Standalone, FedAvg, and LG-FedAvg in terms of both accuracy and communication cost.

Figures 11 and 12 show the accuracy and communication cost when different numbers of participating clients are selected. Figures 13 and 14 show the accuracy and communication cost when each client contains different amounts of data. Standalone uses only the client to train the model, without the participation of the server; thus, there is no communication cost. For CIFAR-10, the highest accuracy is 90.14% and the communication cost is 775.95 MB when the target pruning rate is 30% (taking the case of 10 randomly selected clients with 50 pictures per class as an example). The accuracy of LG-FedAvg is 80.78%, and its communication cost is 1127.89 MB. In comparison with LG-FedAvg, our method improves the accuracy by 9.36%, while the communication cost is reduced by a factor of 1.45. For MNIST, the highest accuracy is 99.96% and the communication cost is 212.45 MB when the target pruning rate is 10%; the accuracy of LG-FedAvg is 97.03% and its communication cost is 244.82 MB, so the proposed model increases the accuracy by 2.93% and reduces the communication cost by a factor of 1.15.

In Fig. 12, the communication cost increases as the number of participating clients increases, because more clients communicate with the server in each round. However, as shown in Fig. 14, the communication cost actually decreases as the amount of data held by each client increases: more local data speeds up the convergence of the client, and the client prunes more weights during local training.

Case study

This section illustrates the practicality of the proposed model by implementing a flower identification application on mobile phones. Flower identification is a common application in users’ everyday life; the flower data on each client reflect the user’s preferences, and these data are often heterogeneous. This study uses the flower dataset created by Oxford University and MobileNetV2 [20] for training. The dataset includes 17 categories of flowers, each with 80 pictures, for a total of 1360 pictures. MobileNetV2 is a lightweight convolutional neural network based on an inverted residual structure, proposed by the Google team in 2018.

Fig. 15
figure 15

System platform (Left: main interface. Right: flower identification function interface)

The number of communication rounds is set to 300, and we randomly select 5 out of 40 clients to participate in training in each round. Each client contains two types of data with 15 samples per type, and the iterative and target pruning rates are 20% and 50%, respectively. Finally, the accuracy of the model is 92.47%, and the communication cost of each client is 0.8783 GB.

Fig. 16
figure 16

Left: select the model. Right: select the photo

Fig. 17
figure 17

Server model vs. client model image recognition accuracy(Left: server model. Right: client model)

The trained model is deployed to the local identification system based on [21] through the three steps described below; the experimental platform is Android Studio.

(1) The .pth model file obtained in PyCharm is converted to the .pt file supported in Android Studio (a conversion sketch follows this list).

(2) The PyTorch dependency is added to Android Studio’s build.gradle, and the converted model file and tag information are placed into Android Studio.

(3) The model is invoked for inference.
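A sketch of the conversion in step (1), assuming the trained MobileNetV2 weights are stored in a hypothetical flower_model.pth file:

```python
import torch
import torchvision

# Load the trained weights, trace the model with a dummy input, and save a
# TorchScript .pt file that the PyTorch Android library can load.
model = torchvision.models.mobilenet_v2(num_classes=17)
model.load_state_dict(torch.load("flower_model.pth", map_location="cpu"))
model.eval()
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
traced.save("flower_model.pt")
```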

As shown in Fig. 15, the left figure is the main interface of the system. Clicking "PT IDENTIFICATION" opens the flower identification interface shown in the right figure.

In this study, the global model of the server and the personalized model of a client whose data include Daisy and Bluebell are deployed to the system. As shown in the left view of Fig. 16, either the server model or the client model can be chosen. Clicking "PHOTO" selects the photo to be recognized on the mobile device, as shown in the right view of Fig. 16.

Figure 17 shows the flower recognition results of the server and client models on the same picture. The personalized model of the client is clearly more accurate than the global model of the server.

Conclusion

This study proposes an ADMM-based weight pruning approach to address the challenges of high model cost and heterogeneous client data in federated learning. An appropriate target pruning rate can be selected according to the performance of the edge equipment. In each round of communication, the server builds personalized models through masks, while each client carries out ADMM weight pruning on its subnetwork using local data. Finally, the proposed approach is applied to a flower recognition system, and the proposed algorithm is shown to be effective. However, some specific issues still warrant further discussion; for example, to enhance the efficiency of model training in future work, client selection and structured pruning deserve serious consideration.