Abstract
As a new distributed machine learning paradigm, federated learning has gained increasing attention in industry and the research community. However, federated learning is challenging to implement on edge devices with limited resources and heterogeneous data. This study aims to realize lightweight and personalized models through pruning and masking under insufficient resources and heterogeneous data. Specifically, the server first sends a subnetwork to each client according to the client's mask, and the client prunes the subnetwork with the alternating direction method of multipliers (ADMM) to remove unimportant parameters and reduce the cost of training and communication. At the same time, a mask is used to record how the model has been pruned. Then, the unpruned parts of the local models and their masks are transmitted to the server for aggregation. The experimental results show that, compared with state-of-the-art models, the proposed model improves accuracy by 9.36% and reduces the communication cost by a factor of 1.45. Finally, we deploy flower identification models in Android Studio to illustrate the practicality of the proposed method.
Introduction
People’s lives increasingly depend on smart devices such as mobile phones, which generate large amounts of data every second. As the number of devices keeps growing, uploading all of this data to the cloud for processing causes longer delays and places greater pressure on the cloud. To better protect users’ data privacy and relieve traffic congestion on the backbone network, it is reasonable to provide services close to users or to process data at the network edge.
Edge devices are close to users and can collect huge amounts of data. Meanwhile, machine learning has developed rapidly, with many research breakthroughs. The theory and techniques of machine learning can be applied in edge computing to exploit the potential of data and improve the quality of service (QoS) for users. The combination of edge computing and machine learning has given rise to a new research direction: edge intelligence [1]. As one of its important branches, federated learning deployed on the edge network enables clients and a server to jointly train a neural network model while keeping the data secure. Clients collect considerable data, and federated learning can extract useful information from it. As shown in Fig. 1, each client trains the model downloaded from the server on its local data and uploads the trained local model to the server for aggregation.
Neural network models are generally large and require many training iterations. However, the memory and computing capacity of edge devices are very limited and cannot meet these requirements. This contradiction makes large models difficult to deploy on resource-constrained edge devices. Therefore, pruning strategies are commonly used to make large models lightweight. Nevertheless, many existing pruning strategies are heuristic and do not consider how the weights evolve during training, which results in a large loss of accuracy. Additional retraining is then required to recover the accuracy, which increases the computational burden on clients [2].
In addition, the data each client generates according to its user’s habits are often non-independent and identically distributed (non-IID); for example, different clients hold data with different labels, so the trained models are heterogeneous. The model’s performance may degrade after aggregation at the server, which increases the number of communication rounds. In this case, a single global model does not meet the needs of different users. Personalized federated learning is one approach to data heterogeneity, but many related works require additional data representations for personalized training, which increases the cost of training models.
Taking the aforementioned problems into account, this study trains lightweight and personalized models for clients. We adopt a dynamic weight pruning method to obtain lightweight models. Loss minimization is formulated as a constrained optimization problem, which is solved by the alternating direction method of multipliers (ADMM). Compared with previous personalized federated learning approaches, which only consider data representations, this study uses masks together with pruning to reduce the training cost. This reduces the model’s dependence on hardware, as well as the response time and communication cost between the server and clients. Personalization is realized with masks during federated training. After each round of local training, the client generates a personalized mask that records how the model has been pruned and transmits the mask to the server. The aggregated personalized models are returned to the clients, so the personalized model deployed on each client is more accurate than a universal global model. Our main contributions are as follows:
(1) Weight pruning based on ADMM is used to address limited client resources. The local models are trained under sparsity constraints and then pruned to make them lightweight. The relationship between model accuracy and communication cost under different target pruning rates is investigated experimentally, so that an appropriate target pruning rate can be selected according to the capability of the edge devices.
(2) Masking is used to build a personalized model for each client. The server sends subnetworks of the global model to the clients and aggregates only the unpruned parts, reducing the impact of data heterogeneity on model performance.
(3) Taking a flower identification system as an example, models are deployed in Android Studio to illustrate the practicality of the proposed method.
The rest of this study is organized as follows: Section 2 reviews related work. Section 3 presents the system architecture and problem formulation. Section 4 describes the working process of lightweight and personalized federated learning in detail. Section 5 presents the experiments and performance evaluation. Section 6 describes a case study on flower identification. Section 7 concludes the study and gives suggestions for future work.
Related work
In this section, we review related work from two perspectives: lightweight models and personalization.
Regarding lightweight models, the authors of [3,4,5,6] have shown theoretically that lightweight models are feasible. Bau et al. visualized all neurons and concluded that most key features of the data could be detected by only a subset of neurons [3]. Denil et al. proposed that a small number of weights could predict the remaining weights [4]. Furthermore, Shi et al. studied energy consumption models of Graphics Processing Unit (GPU) computing and wireless transmission [5]. Wongso et al. illustrated the working principle of lightweight designs in deep neural networks [6]. The authors of [7, 8] proposed methods to implement lightweight models. Liang et al. compressed the model through local representation learning to reduce communication [7]. Abdellatif et al. reduced communication overhead by reducing the number of communication rounds [8]. Notably, these studies ignore the fact that data among clients are not always independent and identically distributed.
The data of different clients are heterogeneous, and different users generally have different demands; therefore, it is difficult to train a single global model that suits all devices. Zhao et al. conducted extensive experiments and used weight divergence to show that data heterogeneity degrades model performance [9]. Personalized federated learning is a promising solution to data heterogeneity. Common methods for constructing personalized federated learning include meta-learning, transfer learning, and adaptive adjustment. Fallah et al. proposed a personalized federated learning model based on meta-learning, a variant of federated averaging (FedAvg), which aims to find an initial model that can quickly adapt to the local data of each client [10]. Jiang et al. proposed a two-step personalized federated learning model [11]. Schneider et al. and Zhuang et al. employed transfer learning to build personalized federated learning [12, 13]. Duan et al. proposed Astraea, a self-balancing federated learning (FL) framework for unbalanced datasets [14]. Liao et al. addressed heterogeneity in edge computing systems by determining appropriate local update frequencies and network topology [15]. However, these methods do not take into account the limited computing and communication capabilities of clients, which restricts the application of FL, especially in the edge computing paradigm.
Few of the aforementioned studies strike a balance between lightweight and personalized models. This study uses weight pruning based on ADMM to address both issues simultaneously.
System architecture and problem presentation
System architecture
The proposed system architecture is shown in Fig. 2. The left part includes a server and N clients. The server acts as a relay for models: it transmits subnetworks and aggregates models. Clients are responsible for training local models. In the right part, each client trains its model with the loss function L and prunes the model at the iterative pruning rate \(p\%\). Removing unimportant elements keeps the model lightweight. In addition, using the mask generated by pruning, each client obtains from the server a subnetwork that better represents its local data characteristics, and then trains a personalized model.
Problem presentation
In this study, the key operation for achieving lightweight and personalized models is pruning. To improve accuracy and shorten convergence time, training is formulated as an optimization problem with sparsity requirements, as expressed in Eq. (1).
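Eq. (1) itself does not survive in this copy. Reconstructed from the definitions in the next paragraph (our reading, not the authors' typeset equation), it is the sparsity-constrained problem

\[
\min_{W_i}\; L_i(W_i) \quad \text{s.t.}\quad W_i \in S_i, \qquad S_i = \left\{ W_i \mid card(W_i) \le N_i \right\}. \tag{1}
\]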
The above equation indicates that the ith client aims to minimize the loss function \({L_i}\) under a sparsity constraint on the number of model weights. \({W_i}\) represents the neural network weights of the ith client, obtained as \({W_g} \odot {M_i}\), where \({W_g}\) represents the weights of the global model on the server and \({M_i}\) represents the mask of the client. \({S_i} = \left\{ {{W_i} \vert card\left( {{W_i}} \right) \le {N_i}} \right\} \) is a non-convex set requiring that the number of retained non-zero elements be at most \({N_i}\); \(card({W_i})\) is the number of non-zero elements in the client’s neural network, and \({N_i}\) is the expected number of non-zero elements.
During model training, a regularization term is added to the original general loss function \(f({W_i})\). The overall loss function is therefore the sum of the general loss function and the regularization function, as expressed in Eq. (2).
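Eq. (2) is likewise missing. Assuming the regularization term is the squared \(L_2\) norm weighted by \(\lambda\) and that the cross entropy is averaged over the K samples (both assumptions based on the description that follows), it reads

\[
L_i(W_i) = f(W_i) + \lambda \lVert W_i \rVert_2^2, \qquad f(W_i) = -\frac{1}{K}\sum_{k=1}^{K}\sum_{r=1}^{R} y_{kr}\log p_{kr}. \tag{2}
\]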
The first term \(f({W_i})\) can be the cross entropy or the negative log-likelihood loss; in this study, cross entropy is used. K and R represent the number of samples and the number of label categories, respectively. \({y_{kr}}\) is 1 if the true category of sample k is r, and 0 otherwise. \({p_{kr}}\) is the predicted probability that sample k belongs to class r. \(\lambda \) in the second term is the \(L_{2}\) regularization coefficient. The details are presented in Sect. 4.
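To make the two terms concrete, here is a minimal sketch that evaluates the assumed form of Eq. (2) on toy predictions. The function name `regularized_loss`, the averaging over K, and the squared-norm penalty are assumptions for illustration, not the authors' code.

```python
import math


def regularized_loss(probs, labels, weights, lam):
    """Assumed Eq. (2): mean cross entropy over K samples + lam * ||W||_2^2.

    probs[k][r] is the predicted probability p_kr, labels[k] is the true
    class of sample k (so y_kr = 1 only for r == labels[k]), and weights
    is the flattened weight vector W_i.
    """
    K = len(probs)
    # Only the term with y_kr = 1 survives the inner sum over r.
    cross_entropy = -sum(math.log(probs[k][labels[k]]) for k in range(K)) / K
    l2_penalty = sum(w * w for w in weights)
    return cross_entropy + lam * l2_penalty


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]   # K = 2 samples, R = 3 classes
labels = [0, 1]
loss = regularized_loss(probs, labels, weights=[0.5, -0.5, 0.0], lam=0.01)
```

In a PyTorch implementation the same quantity would typically come from a cross-entropy loss plus a weight-decay term; the plain-Python version only spells out the arithmetic.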
Traditional pruning does not impose constraints during training, so more training iterations are needed to reach high accuracy. Our aim is to train a lightweight model that can be used on edge devices. Therefore, we impose sparsity conditions, which are determined by the pruning rate and evaluated in Sect. 5. Smaller weights receive greater penalties and are driven close to zero, which reduces the accuracy loss after pruning and speeds up convergence, shortening training time.
In FL, clients train models locally and do not share their raw data with the server, which provides stronger privacy protection [16]. However, the optimization goals of clients are often inconsistent, and it is unreasonable to use a single model to meet all users’ needs. Against this background, personalized FL has been developed.
Compared with general FL, personalized FL pays more attention to the needs of individual clients and trains a different model for each of them. In this study, every entry of each mask is initialized to 1. If a parameter in the client’s neural network model is set to 0, the entry at the corresponding position in the mask is also set to 0. After the clients upload their local models to the server, only the weights that have not been pruned are aggregated, to retain the personalization of clients as much as possible. The server then dispatches subnetworks of the global model to the clients, so that each client owns a personalized model.
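The mask rule above can be sketched in a few lines of plain Python, treating a model as a flat list of weights (a simplification; `update_mask` is a hypothetical helper, not part of the paper).

```python
def update_mask(weights, mask):
    """Zero the mask wherever the corresponding weight has been pruned to 0.

    Entries that are already 0 stay 0, because pruned weights are frozen.
    """
    return [0 if w == 0 else m for w, m in zip(weights, mask)]


weights = [0.4, 0.0, -1.2, 0.0]
mask = [1] * len(weights)           # every mask entry starts at 1
mask = update_mask(weights, mask)   # -> [1, 0, 1, 0]
```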
The main components of the proposed model are shown in Fig. 3.
In Fig. 3, \({W_\mathrm{{g}}}\) represents the weights of the global model, and \({C_i}\) represents the ith client in the set \(\{ {C_1},{C_2},\ldots ,{C_n}\}\). \(M_i^t\) and \(W_i^t\) represent the mask and subnetwork weights of client \({C_i}\) in round t, respectively. The communication process between the server and client \({C_i}\) in round t consists of the following steps:
\(\textcircled {1}\) The weights of the global model \(W_g^t\) and client mask \(M_i^t\) are matched to obtain subnetwork weights \(W_i^t\);
\(\textcircled {2}\) The subnetwork weights \(W_i^t\) are dispatched to client \({C_i}\);
\(\textcircled {3}\) Client \({C_i}\) uses local data to perform ADMM weight pruning on \(W_i^t\);
\(\textcircled {4}\)-\(\textcircled {5}\) The new mask \(M_i^{t + 1}\) and local model weights \(W_i^{t + 1}\) of client \({C_i}\) obtained after the pruning in step \(\textcircled {6}\) are uploaded to the server and used in the following round of communication (the steps are carried out simultaneously);
\(\textcircled {7}\) After the server obtains all new local models, only the non-zero values are aggregated to obtain the new global model;
\(\textcircled {8}\) The global model is updated.
LPFed: Lightweight and personalized FL
The proposed method consists of three parts: (1) ADMM weight pruning, (2) aggregation of local models, and (3) derivation of subnetworks.
Weight pruning
Neural network models have obvious redundancy: many parameters contribute little or nothing to the result. Removing unimportant parameters reduces the size of the model. Pruning is a cost-effective operation that, when applied carefully, has little effect on model performance while accelerating training and inference. In this study, ADMM weight pruning is used. It consists of two steps: (1) local sparse training and (2) pruning.
The training objective is expressed as a non-convex optimization problem under sparsity constraints, and ADMM is used to solve the optimization objective in Eq. (1). ADMM combines the decomposability of dual ascent with the convergence properties of the augmented Lagrangian method: it decomposes the original problem into several subproblems and solves for the variables alternately. It has been widely used in large-scale distributed learning with good results [17]. Applying ADMM in FL is equivalent to adding a dynamic regularization term to the loss function that gives greater penalties to smaller weights; the weights below a threshold are then removed from the network. Compared with one-step pruning without sparse training [18], this achieves higher accuracy with fewer iterations.
To minimize the objective function in Eq. (1), the indicator function g(.) of \({S_i}\) is first introduced.
We add the indicator function to Eq. (1), which converts the original inequality constraint into an equality constraint in Eq. (4).
Equation (4) matches the standard form of ADMM, which consists of an equality constraint and two optimization variables, \(W_i\) and \(Z_i\). Using the augmented Lagrangian, the problem can be decomposed into two subproblems, one over \(W_i\) and one over \(Z_i\). The \(W_i\) subproblem can be solved by stochastic gradient descent, and the \(Z_i\) subproblem can be solved analytically [19]. The augmented Lagrangian function is given in Eq. (5), where \(Y\) is the Lagrange multiplier associated with the constraint \({W_i}={Z_i}\) and \(\rho \) is the penalty parameter.
To facilitate the solution, the scaling parameter \(U = \frac{Y}{\rho }\) is defined. Equation (5) is simplified into Eq. (6) by combining the linear and quadratic terms in the function.
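Equations (3) to (6) are missing from this copy. Assuming the paper follows the standard ADMM formulation for cardinality-constrained problems described by Boyd et al. [19], they take the following form (a reconstruction; the trace and Frobenius-norm notation are our choices):

\[
g(Z_i) = \begin{cases} 0, & Z_i \in S_i,\\ +\infty, & \text{otherwise}, \end{cases} \tag{3}
\]
\[
\min_{W_i,\,Z_i}\; L_i(W_i) + g(Z_i) \quad \text{s.t.}\quad W_i = Z_i, \tag{4}
\]
\[
\mathcal{L}_\rho(W_i, Z_i, Y) = L_i(W_i) + g(Z_i) + \operatorname{tr}\!\left(Y^{T}(W_i - Z_i)\right) + \frac{\rho}{2}\lVert W_i - Z_i\rVert_F^2, \tag{5}
\]
\[
\mathcal{L}_\rho(W_i, Z_i, U_i) = L_i(W_i) + g(Z_i) + \frac{\rho}{2}\lVert W_i - Z_i + U_i\rVert_F^2 - \frac{\rho}{2}\lVert U_i\rVert_F^2, \qquad U_i = \frac{Y}{\rho}. \tag{6}
\]

Expanding the square in (6) recovers (5), because \(\frac{\rho}{2}\lVert W_i - Z_i + U_i\rVert_F^2 = \frac{\rho}{2}\lVert W_i - Z_i\rVert_F^2 + \operatorname{tr}\!\left(Y^{T}(W_i - Z_i)\right) + \frac{\rho}{2}\lVert U_i\rVert_F^2\).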
We apply ADMM to Eq. (6) and decompose it into three subproblems, Eqs. (7), (8), and (9). In Eqs. (7) and (8), \(W_i^{k + 1}\) and \(Z_i^{k + 1}\) are solved alternately and iteratively; at each step, only one variable is updated while the other two are held fixed. Eq. (7) minimizes over the original parameter \(W_i\) to obtain \(W_i^{k + 1}\).
In Eq. (6), sparse constraints are incorporated into the optimization problem. Equation (7) is the solution formula for the parameter \(W_i^{k + 1}\) that needs to be optimized. Focusing on the auxiliary parameter \({Z_i}\) in Eq. (6), \({Z_i}\) can be minimized as Eq. (8).
According to the ADMM update rule, the scaling parameter \(U_i^{k + 1}\) is obtained by adding the residual \((W_i^{k + 1} - Z_i^{k + 1})\), i.e., the part of the weights that violates the sparsity constraint, to its previous value \(U_i^{k}\).
In Eq. (7), the first term is the loss function and the second term is a quadratic (squared \(L_2\)) penalty. Both are differentiable, so stochastic gradient descent is used to solve for \({W_i}\), as given in Eq. (10).
In Eq. (3), g(.) is the indicator function of \(S_i\), whose minimum value is 0. Because \({S_i}\) is a non-convex set, \(Z_i^{k + 1}\) is difficult to compute in general. However, [19] points out that \(Z_i^{k + 1}\) can be computed analytically in the special case \({S_i} = \left\{ {{W_i} \vert card\left( {{W_i}} \right) \le {N_i}} \right\} \): its analytical solution is the Euclidean projection of \((W_i^{k + 1} + U_i^k)\) onto \({S_i}\). Here, \({N_i}\) is the number of weights expected to be retained after pruning, and \(Z_i^{k + 1}\) represents the retained part of the weights. The Euclidean projection keeps the \({N_i}\) elements of \((W_i^{k + 1} + U_i^k)\) with the largest magnitudes and sets the rest to 0. The solution of Eq. (8) can thus be written as Eq. (11), where \(\prod \nolimits _{{S_i}} {(.)}\) denotes the Euclidean projection onto \({S_i}\).
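The update equations (7) to (11) are also missing. Under the same assumption that the standard scaled-form ADMM of [19] is used, and writing \(\eta\) for the SGD learning rate (our notation), they are:

\[
W_i^{k+1} = \arg\min_{W_i}\; L_i(W_i) + \frac{\rho}{2}\lVert W_i - Z_i^{k} + U_i^{k}\rVert_F^2, \tag{7}
\]
\[
Z_i^{k+1} = \arg\min_{Z_i}\; g(Z_i) + \frac{\rho}{2}\lVert W_i^{k+1} - Z_i + U_i^{k}\rVert_F^2, \tag{8}
\]
\[
U_i^{k+1} = U_i^{k} + W_i^{k+1} - Z_i^{k+1}, \tag{9}
\]
\[
W_i \leftarrow W_i - \eta\left(\nabla L_i(W_i) + \rho\,(W_i - Z_i^{k} + U_i^{k})\right), \tag{10}
\]
\[
Z_i^{k+1} = \prod\nolimits_{S_i}\!\left(W_i^{k+1} + U_i^{k}\right). \tag{11}
\]

Terms of (6) that do not depend on the variable being minimized are dropped in (7) and (8), and (10) is one gradient step on the objective of (7).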
\({U_i}\) is updated with Eq. (9); this completes an iteration of the ADMM.
As shown in Algorithm 1, W, Z, and U are dynamically updated during the sparse training of each client; W is the weight of the model, Z represents the retained part of the model, and U accumulates the pruned part (the residual W − Z).
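Algorithm 1 is not reproduced in this copy. The sketch below runs the three ADMM updates (W-step, Z-step, U-step) on a flat list of weights with a toy quadratic loss. The toy loss, the hyperparameter values, and the helper names `project_topk` and `admm_sparse_training` are assumptions for illustration, and the W-step uses a few full-gradient steps in place of minibatch SGD.

```python
def project_topk(v, n_keep):
    """Euclidean projection onto {x : card(x) <= n_keep}: keep the n_keep
    largest-magnitude entries of v and zero the rest."""
    ranked = sorted(range(len(v)), key=lambda j: abs(v[j]), reverse=True)
    keep = set(ranked[:n_keep])
    return [x if j in keep else 0.0 for j, x in enumerate(v)]


def admm_sparse_training(w, grad_fn, n_keep, rho=1.0, lr=0.1,
                         admm_iters=20, sgd_steps=10):
    """Sparse training on a flat weight list in the spirit of Algorithm 1.

    grad_fn(w) returns the gradient of the client loss L_i at w.
    rho, lr, admm_iters and sgd_steps are illustrative values only.
    """
    z = project_topk(w, n_keep)
    u = [0.0] * len(w)
    for _ in range(admm_iters):
        # W-step: gradient steps on L_i(W) + rho/2 * ||W - Z + U||^2
        for _ in range(sgd_steps):
            g = grad_fn(w)
            w = [wj - lr * (gj + rho * (wj - zj + uj))
                 for wj, gj, zj, uj in zip(w, g, z, u)]
        # Z-step: project W + U onto the sparsity set S_i
        z = project_topk([wj + uj for wj, uj in zip(w, u)], n_keep)
        # U-step: add the residual W - Z to the scaled dual variable
        u = [uj + wj - zj for uj, wj, zj in zip(u, w, z)]
    return w, z, u


# Toy loss L_i(W) = 0.5 * ||W - target||^2, keeping N_i = 2 of 4 weights.
target = [2.0, -0.1, 1.5, 0.05]
w, z, u = admm_sparse_training(list(target),
                               lambda v: [vj - tj for vj, tj in zip(v, target)],
                               n_keep=2)
```

On this toy problem the two small weights are driven toward zero while the two large ones are left untouched, which is the behavior the dynamic regularization is meant to produce before the hard-pruning step.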
Boyd et al. pointed out that ADMM can be applied to non-convex optimization problems such as those in deep learning, although it may not converge to a globally optimal solution; in many cases, a suboptimal solution meets the requirements [19]. In the first step, sparse training, fully meeting our sparsity requirements would take a large number of iterations, which conflicts with the lightweight objective of this study. Therefore, each client stops training when it reaches a predetermined number of iterations. At this point, the weights to be pruned may not be exactly zero, although many of them are close to zero. To handle this, we prune them in the second step: only the \(N_i\) elements with the largest magnitudes are retained and the remaining elements are pruned (set to 0). In other words, the index \(N_i\) of the parameter threshold is obtained from the predefined iterative pruning rate, and the parameters below the threshold are set to 0. In later rounds of federated training, when the server selects a previously pruned client model again, the client no longer updates the pruned weights during local training. Because these weights are already very small after the sparse training in the first step, the proposed method has little impact on model accuracy compared with pruning without sparse training.
Algorithm 2 describes the weight pruning process. First, ADMM sparse training is performed on each subnetwork according to Algorithm 1 to obtain the local model (line 1). Then, the accuracy of this local model is tested (line 2). Because pruned weights are frozen and are not updated in later training, it is necessary to check whether the current accuracy and pruning rate meet the requirements (line 3), to prevent permanent damage to model accuracy. If the conditions are met, the model is pruned and a new mask and local model are obtained. Specifically, the pruning threshold is calculated, and a mask entry is set to 0 if its corresponding parameter is below the threshold. The index \(N_i\) of the pruning threshold is calculated by Eq. (12); note that the weights must be sorted in descending order before this index is calculated (line 4). A new local model is obtained by applying the mask to the model (line 5). Finally, the new mask and local model are uploaded to the server (line 10).
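Algorithm 2 and Eq. (12) are not shown in this copy, so the following is only a hypothetical sketch of the two-step idea described above: a hard-pruning step that keeps the largest-magnitude unpruned weights, and a masked update that keeps pruned weights frozen at zero. The rule for the threshold index, the accuracy check, and the names `prune_step` and `masked_sgd_step` are assumptions.

```python
def masked_sgd_step(weights, grads, mask, lr):
    """Frozen update: weights whose mask entry is 0 receive no gradient."""
    return [w - lr * g * m for w, g, m in zip(weights, grads, mask)]


def prune_step(weights, mask, iter_rate, target_rate, acc, min_acc):
    """One hard-pruning step in the spirit of Algorithm 2 (hypothetical form).

    Assumptions (not from the paper): each call prunes iter_rate of the total
    number of weights, never beyond target_rate, and is skipped when the
    tested accuracy acc is below min_acc.
    """
    total = len(weights)
    alive = [j for j in range(total) if mask[j] == 1]
    max_pruned = round(target_rate * total)
    # Line 3: check the accuracy and pruning-rate conditions first.
    if acc < min_acc or total - len(alive) >= max_pruned:
        return weights, mask
    # Line 4: sort the unpruned weights by magnitude in descending order;
    # n_keep plays the role of the threshold index N_i.
    n_keep = max(len(alive) - round(iter_rate * total), total - max_pruned)
    ranked = sorted(alive, key=lambda j: abs(weights[j]), reverse=True)
    keep = set(ranked[:n_keep])
    new_mask = [1 if j in keep else 0 for j in range(total)]
    # Line 5: apply the new mask to obtain the new local model.
    return [w * m for w, m in zip(weights, new_mask)], new_mask


w = [0.9, -0.05, 0.4, 0.01, -0.7, 0.2, 0.03, -0.3, 0.6, 0.08]
m = [1] * len(w)
for _ in range(5):   # prunes 10% per call and stops at the 30% target
    w, m = prune_step(w, m, iter_rate=0.1, target_rate=0.3, acc=0.9, min_acc=0.8)
```

Multiplying the gradient by the mask is one simple way to freeze pruned weights; a PyTorch implementation would more commonly register the mask on the layer or zero the masked gradients after `backward()`.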
Aggregation of local models
In each round of communication, after performing ADMM weight pruning, each client uploads its model with personalized features to the server. Models trained on datasets with similar labels tend to be similar. In this study, the server aggregates only the unpruned parts, i.e., only the parameters that overlap across some or all models are averaged. If a parameter is pruned in all other clients except one, it does not participate in aggregation, which reduces the impact of heterogeneous data on accuracy. The aggregation method in this study improves model performance by finding "partners" for each client, preserving the personalization properties embedded in the model. As shown in Fig. 4, we label each weight of the neural network with its mask in the form "weight: mask". In the three client models, the weights from the first neuron \({A_1}\) in the first layer to the first neuron \({B_1}\) in the second layer are 3, 0, and 2, and the corresponding mask entries are 1, 0, and 1. Thus, only the two unpruned weights are averaged, and the aggregated value in the global model is (3 + 2)/2 = 5/2.
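A minimal sketch of this masked averaging, reproducing the Fig. 4 numbers. Models are flattened into lists, and the name `aggregate_unpruned` and the fallback for positions pruned by every client are assumptions.

```python
def aggregate_unpruned(client_weights, client_masks, prev_global):
    """Average each position only over clients whose mask entry is 1.

    If every selected client pruned a position, this sketch keeps the
    previous global value (an assumption; the paper does not say).
    """
    new_global = []
    for j, old in enumerate(prev_global):
        kept = [w[j] for w, m in zip(client_weights, client_masks) if m[j] == 1]
        new_global.append(sum(kept) / len(kept) if kept else old)
    return new_global


# Fig. 4, connection A1 -> B1 in three clients: weights 3, 0, 2; masks 1, 0, 1.
print(aggregate_unpruned([[3.0], [0.0], [2.0]], [[1], [0], [1]], prev_global=[0.0]))
# [2.5]
```

Plain FedAvg over the same three clients would give (3 + 0 + 2)/3 = 5/3, letting the pruned client drag the shared weight toward zero; masking avoids that.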
Derivation of subnetworks
After local training, each client uses its mask to record how the model has been pruned and uploads the mask to the server. The mask embeds information about the local data and is used by the server to distinguish the clients’ models. As shown in step \(\textcircled {1}\) in Fig. 3, the server multiplies the global model element-wise by each mask to obtain the subnetworks, which are then sent to the clients. Because clients may have different masks, their subnetworks may also differ, which realizes personalized FL. In short, deriving a subnetwork means that the server selects the part of the global model that matches the characteristics of a client’s local data.
As shown in Algorithm 3, after the client set S is randomly selected, the server sends a different subnetwork to each client in the set.
Fig. 5 shows an example of a subnetwork derived by the server. For the connection from the first neuron \({A_1}\) in the first layer to the first neuron \({B_1}\) in the second layer, the global weight is 5/2 and the client’s mask entry is 1, so the corresponding weight in the subnetwork, calculated by \({W_g} \odot {M_i}\), is 5/2. Similarly, the weight from \({B_1}\) to \({C_1}\) in the subnetwork is 0.
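The subnetwork derivation and the random client selection of Algorithm 3 can be sketched as follows. The helper names, the flat-list representation, and the use of `frac` to size the selected set (taken from the experimental setup) are assumptions; the B1 to C1 global value below is hypothetical because Fig. 5 is not reproduced here.

```python
import random


def derive_subnetwork(global_weights, mask):
    """W_i = W_g ⊙ M_i: element-wise product of global weights and client mask."""
    return [w * m for w, m in zip(global_weights, mask)]


def dispatch_subnetworks(global_weights, masks, frac, rng):
    """Sketch of Algorithm 3: sample frac * n clients and derive their subnetworks."""
    n_selected = max(1, int(frac * len(masks)))
    selected = rng.sample(range(len(masks)), n_selected)
    return {i: derive_subnetwork(global_weights, masks[i]) for i in selected}


# Fig. 5 style example: A1->B1 has global weight 5/2 and mask entry 1; the
# B1->C1 global value 1.5 is hypothetical and is masked out.
print(derive_subnetwork([2.5, 1.5], [1, 0]))   # [2.5, 0.0]
```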
Theoretical analysis
We now show through two theorems that the proposed method yields lightweight and personalized models.
Theorem 1
The communication cost of the proposed method is lower than that of the traditional FL model when the pruning rate satisfies \(p\% > B/A\).
Proof
Let \({V_1}\) and \({V_2}\) be the communication costs between the server and clients of the traditional method and of our proposed method, respectively. In the original federated learning, the full model parameters are transmitted; in the proposed method, the pruned model parameters and the masks are transmitted. Suppose A and B are the numbers of bits per parameter and per mask entry, respectively, and \(\sum \nolimits _{i = 1}^N {\left| {{W_i}} \right| } \) is the total number of parameters in the neural networks. We get the following equations:
\[V_1 = A\sum \nolimits _{i = 1}^N {\left| {{W_i}} \right| }, \qquad V_2 = (1 - p\%)\,A\sum \nolimits _{i = 1}^N {\left| {{W_i}} \right| } + B\sum \nolimits _{i = 1}^N {\left| {{W_i}} \right| }.\]
Hence \(V_1 - V_2 = (A\,p\% - B)\sum \nolimits _{i = 1}^N {\left| {{W_i}} \right| }\), so when the pruning rate satisfies \(p\% > B/A\), we have \({V_1} > {V_2}\). That is to say, as long as the pruning rate of the model is greater than B/A, a lightweight model with lower communication cost is obtained. In this study, each neural network parameter is 32 bits and each mask entry is 1 bit, so the lightweight property holds whenever the pruning rate \(p\% > 1/32\). \(\square \)
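The bit counts in the proof can be checked numerically. This sketch uses a hypothetical total of 1,000 parameters and the 32-bit parameter and 1-bit mask sizes stated above; `communication_costs` is an illustrative helper.

```python
def communication_costs(n_params, p, a_bits=32, b_bits=1):
    """Bits sent per round: V1 for dense FL, V2 for pruned weights plus masks.

    n_params is the total parameter count sum_i |W_i|; p is the pruning rate
    as a fraction (p% / 100).
    """
    v1 = a_bits * n_params
    v2 = a_bits * (1 - p) * n_params + b_bits * n_params
    return v1, v2


for p in (0.01, 1 / 32, 0.3, 0.9):
    v1, v2 = communication_costs(1000, p)
    print(f"p = {p:.4f}: V1 = {v1}, V2 = {v2:.0f}, saving = {v1 / v2:.2f}x")
```

At exactly \(p\% = 1/32\) the two costs coincide; below it the mask overhead outweighs the savings, and above it the pruned scheme sends fewer bits.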
Theorem 2
The proposed model ensures that clients with different masks obtain different models, provided the global weights at the positions where their masks differ are non-zero.
Proof
For an arbitrary ith client, let \({W_i}\) be the model it obtains from the server. From the aforementioned analysis, we have \({W_i} = {W_g} \odot {M_i}\),
where \({W_g}\) is the global model and \({M_i}\) is the mask of the ith client.
Similarly, for the jth client, we have \({W_j} = {W_g} \odot {M_j}\).
Since \({M_i} \ne {M_j}\), there is at least one position where one mask is 1 and the other is 0. As long as the global weight at that position is non-zero, \({W_i} \ne {W_j}\), which realizes personalized FL.
\(\square \)
Experiments and performance evaluation
Experimental setup
Dataset and set of parameters
Two datasets are used in this study. The first is CIFAR-10, a color image dataset with 10 categories and a total of 60,000 images, of which 50,000 are training images and 10,000 are test images. The second is MNIST, a grayscale dataset of handwritten digits 0–9, consisting of 60,000 training images and 10,000 test images.
In addition, PyTorch is used for the simulation experiments. The total number of clients is 200, and frac denotes the proportion of clients participating in each round of communication, i.e., \(200*frac\) clients are randomly selected in each round. The total numbers of communication rounds for CIFAR-10 and MNIST are 500 and 300, respectively.
Heterogeneity of data
In simulation, each client uses local data to train the personalized model with the help of the central server. It is necessary to allocate non-independent and identically distributed data for each client before simulating. We assign \(n\_class\) different labels of data to each client. The specific operation of data partitioning is described as follows. First, the dataset is sorted according to labels; then, the dataset is sliced, and each fragment contains \(num\_samples\) data. Finally, each client is assigned \(n\_class\) fragments, i.e., \(n\_class\) different labels of data. In this way, the data between the clients are heterogeneous.
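The partitioning procedure above corresponds to a label-sorted shard split. The sketch below is one way to write it under stated assumptions: `shard_partition` is a hypothetical helper, shards are shuffled before assignment, and the toy labels stand in for a real dataset.

```python
import random


def shard_partition(labels, n_clients, n_class, num_samples, rng):
    """Label-sorted shard partition: each client receives n_class shards.

    When every label count is a multiple of num_samples, each shard holds a
    single label, so a client sees at most n_class labels; random assignment
    may occasionally give a client two shards of the same label.
    """
    order = sorted(range(len(labels)), key=lambda idx: labels[idx])
    shards = [order[s:s + num_samples] for s in range(0, len(order), num_samples)]
    if n_clients * n_class > len(shards):
        raise ValueError("not enough shards for the requested split")
    rng.shuffle(shards)
    return [[idx for shard in shards[c * n_class:(c + 1) * n_class] for idx in shard]
            for c in range(n_clients)]


labels = [i % 10 for i in range(1000)]      # toy dataset: 100 samples per label
parts = shard_partition(labels, n_clients=10, n_class=2, num_samples=50,
                        rng=random.Random(0))
```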
Neural network model
The neural network for CIFAR-10 has six layers: five hidden layers and one output layer. The five hidden layers can be divided into two convolutional layers and three fully connected layers. First, there are two \(5\times 5\) convolutional layers with 6 and 16 channels, respectively, each followed by a \(2\times 2\) max pooling layer. Next, there are three fully connected layers with 120, 84, and 10 units, respectively.
The neural network for MNIST has five layers: four hidden layers and one output layer. The four hidden layers can be divided into two convolutional layers and two fully connected layers. First, there are two \(5 \times 5\) convolutional layers with 10 and 20 channels, respectively, each followed by a \(2 \times 2\) max pooling layer. Next, there are two fully connected layers with 50 and 10 units, respectively.
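To relate these architectures to the communication analysis, the sketch below counts their parameters. It assumes 32×32×3 CIFAR-10 and 28×28×1 MNIST inputs, stride-1 convolutions without padding, biases in every layer, and that the listed fully connected widths include the 10-unit output layer; none of these details are stated explicitly above.

```python
def conv_pool_size(size, kernel, pool=2):
    """Spatial size after a stride-1, no-padding conv and a 2x2 max pool."""
    return (size - kernel + 1) // pool


def count_params(in_channels, in_size, conv_channels, fc_units, kernel=5):
    total, channels, size = 0, in_channels, in_size
    for out_channels in conv_channels:
        total += channels * out_channels * kernel * kernel + out_channels  # weights + biases
        channels, size = out_channels, conv_pool_size(size, kernel)
    features = channels * size * size          # flattened input to the first FC layer
    for units in fc_units:
        total += features * units + units      # weights + biases
        features = units
    return total


cifar_params = count_params(3, 32, [6, 16], [120, 84, 10])   # 62006
mnist_params = count_params(1, 28, [10, 20], [50, 10])       # 21840
```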
Experimental results
Accuracy under different target pruning rates
In the experiments, the iterative pruning rate is the ratio of the amount pruned to the total amount in each round of federated training, and the target pruning rate is the corresponding ratio over all rounds. Parameters pruned in round \((t-1)\) remain zero during training in round t, because pruned parameters are frozen and do not participate in updates. During communication between the server and clients, 10% of the parameters are pruned per iteration, and the target pruning rate is set to 10%, 20%, 30%, 40%, 50%, 60%, 70%, 80%, and 90%, respectively. In each round, 10 clients are randomly selected for communication. On each client, each label has 50 samples.
To verify how accuracy changes with the target pruning rate, we conduct ablation experiments on six models with the following architectures.
- \(\textcircled {1}\): An original six-layer model, the same as that used for CIFAR-10;
- \(\textcircled {2}\): A five-layer model obtained by removing a fully connected layer with 120 channels from the original network structure of CIFAR-10;
- \(\textcircled {3}\): A four-layer model obtained by removing two fully connected layers with 120 and 84 channels from the original network structure of CIFAR-10;
- \(\textcircled {4}\): An original five-layer model, the same as that used for MNIST;
- \(\textcircled {5}\): A four-layer model obtained by removing a fully connected layer with 50 channels from the original network structure of MNIST;
- \(\textcircled {6}\): A three-layer model obtained by removing a \(5\times 5\) convolutional layer with 20 channels and a fully connected layer with 50 channels from the original network structure of MNIST.
Figures 6 and 7 show the accuracies on CIFAR-10 and MNIST for different models under different target pruning rates. In Fig. 6, the models are the original complete six-layer model (\(\textcircled {1}\)) and five-layer model (\(\textcircled {4}\)). The accuracy on CIFAR-10 is between 0.84 and 0.9, and the accuracy on MNIST is between 0.994 and 1.0. Because the differences in MNIST accuracy are small and hard to read, the MNIST accuracy is enlarged in a subgraph. Figure 6 shows that the complete models follow the same accuracy trend on the same dataset. In Fig. 7, the models are obtained by removing one layer (\(\textcircled {2}\) and \(\textcircled {5}\)) or two layers (\(\textcircled {3}\) and \(\textcircled {6}\)) from the original complete models, and the shallower models likewise follow the same accuracy trend on the same dataset. From Figs. 6 and 7, the trend of accuracy is consistent within a dataset, regardless of whether the model is complete or shallower. The accuracy on CIFAR-10 increases to its maximum at 30% and then decreases, whereas the accuracy on MNIST keeps dropping from a target pruning rate of 10%. This suggests that how accuracy varies with the target pruning rate does not depend on the model.
To further verify the accuracy with the change of p-values, we introduce another dataset EMNIST to obtain the variation of accuracy with target pruning rate in the original complete six-layer and five-layer models, where the six-layer and five-layer models are the same as those used in CIFAR-10 and MNIST. The complexity of the EMNIST is higher than that of MNIST and lower than that of CIFAR-10. As shown in Fig. 8, the accuracies increase to a maximum when the target pruning rate is 20% and then decrease.
The accuracy is highest at a target pruning rate of 30% on CIFAR-10, 20% on EMNIST and 10% on MNIST, after which it declines. The way accuracy changes with the pruning rate may therefore be closely related to the dataset: the more complex the dataset, the higher the target pruning rate required to reach the highest accuracy. Increasing the pruning rate lets the model retain more personalization, but an excessively high pruning rate removes important parts of the model and reduces accuracy. More complex datasets are harder to recognize and require more personalization, so a higher target pruning rate is needed to achieve the highest accuracy.
The communication cost of the model also differs across target pruning rates. On CIFAR-10, the accuracy exceeds 90% at a target pruning rate of 30%, but this setting requires correspondingly more communication than higher pruning rates; the same holds for EMNIST and MNIST. Pruning thus offers flexible control over the number of model parameters, and the target pruning rate can be chosen to suit the actual situation.
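As a rough illustration of why lower target pruning rates cost more communication, the following Python sketch estimates the size of a single client upload. It assumes (our assumption, not a detail given in the paper) that the unpruned weights are sent as dense float32 values and that the mask is sent as one bit per parameter; the helper `upload_bytes` is hypothetical.

```python
import math


def upload_bytes(num_params: int, target_pruning_rate: float,
                 bytes_per_weight: int = 4) -> int:
    """Estimate one client upload: the unpruned weights plus a pruning mask.

    Illustrative assumptions: each unpruned weight is a float32 (4 bytes),
    and the mask stores one bit per parameter, packed into whole bytes.
    """
    if not 0.0 <= target_pruning_rate <= 1.0:
        raise ValueError("target_pruning_rate must lie in [0, 1]")
    kept = round(num_params * (1.0 - target_pruning_rate))  # weights that survive pruning
    mask = math.ceil(num_params / 8)                       # 1 bit per parameter
    return kept * bytes_per_weight + mask


# Payload of a hypothetical 1M-parameter model at several target pruning rates
for p in (0.1, 0.3, 0.5, 0.7, 0.9):
    print(f"p={p:.0%}: {upload_bytes(1_000_000, p) / 1e6:.3f} MB")
```

Under these assumptions the mask adds a fixed overhead of 1/32 of the dense float32 model, so the payload shrinks roughly linearly as the target pruning rate grows.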
Model parameter distribution under different target pruning rates
As the target pruning rate increases, more parameters become zero, and the parameter matrices, and hence the model, become sparser. To compare these changes in sparsity more intuitively, we visualize the parameter distributions of the CIFAR-10 and MNIST models without pruning and with pruning (at the target pruning rate that gives the highest accuracy). In Figs. 9 and 10, the horizontal axis is the weight value and the vertical axis is its frequency, i.e., the number of weights that take that value.
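The quantities on the axes of Figs. 9 and 10 can be computed directly from a layer's flattened weights. The plain-Python sketch below counts exactly-zero weights and bins the weight values; the helper names `sparsity` and `weight_histogram` and the 0.05 bin width are illustrative assumptions. After pruning, the zeroed weights appear as a tall spike in the bin that starts at 0.

```python
import math
from collections import Counter


def sparsity(weights):
    """Fraction of weights that are exactly zero (i.e. pruned)."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


def weight_histogram(weights, bin_width=0.05):
    """Map the left edge of each bin to the number of weights falling in it.

    Bins are half-open intervals [k * bin_width, (k + 1) * bin_width), so a
    pruned (zero) weight is counted in the bin whose left edge is 0.0.
    """
    counts = Counter(math.floor(w / bin_width) for w in weights)
    return {round(k * bin_width, 10): n for k, n in sorted(counts.items())}


# Toy layer in which half of the weights have been pruned to zero
layer = [0.0, 0.0, 0.12, -0.03, 0.0, 0.07]
print(sparsity(layer))          # 0.5
print(weight_histogram(layer))  # {-0.05: 1, 0.0: 3, 0.05: 1, 0.1: 1}
```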
Figure 9 compares the parameter distributions of the five hidden layers of the CIFAR-10 model without pruning and with a target pruning rate of 30%.
Figure 10 compares the parameter distributions of the four hidden layers of the MNIST model without pruning and with a target pruning rate of 10%.
Accuracy and communication costs under different algorithms
Six experiments are conducted on CIFAR-10 and MNIST: Standalone, FedAvg, local-global federated averaging (LG-FedAvg) [7], and our lightweight personalized federated learning (LPFed) under three different target pruning rates. The methods are described as follows.
(1) Standalone: each client trains a model on its own device using its local data, without server involvement.
(2) FedAvg: the classic federated averaging algorithm. The server sends the global model to each client, and after local training the clients upload their models to the server for aggregation.
(3) LG-FedAvg: each client learns a useful representation of its local data, and the global model operates on these representations, which reduces communication costs and keeps the models lightweight.
(4) Proposed lightweight personalized federated learning (LPFed): each client performs ADMM weight pruning under different target pruning rates. Based on the experimental results in Fig. 6, the target pruning rates are set to 30%, 50% and 70% for CIFAR-10 and to 10%, 50% and 90% for MNIST.
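For readers unfamiliar with ADMM-based pruning, the sketch below shows the standard three-step ADMM loop (see Boyd et al. [19]) for minimizing a loss \(f(W)\) subject to a target pruning rate: a W-step that takes a few gradient steps on the augmented loss, a Z-step that projects onto the sparsity constraint by keeping the largest-magnitude weights, and a U-step that updates the scaled dual variable. It is a plain-Python illustration on a toy quadratic loss, not the LPFed implementation; `rho`, `lr`, the iteration counts and the helper names are assumptions.

```python
def topk_mask(values, pruning_rate):
    """0/1 mask keeping the (1 - pruning_rate) fraction of largest-magnitude entries."""
    n = len(values)
    keep = n - round(n * pruning_rate)
    order = sorted(range(n), key=lambda i: abs(values[i]), reverse=True)
    mask = [0] * n
    for i in order[:keep]:
        mask[i] = 1
    return mask


def admm_prune(w, grad_f, pruning_rate, rho=1.0, lr=0.1, admm_iters=20, sgd_steps=10):
    """Approximately solve min f(W) s.t. at most (1 - pruning_rate) * n weights are nonzero.

    Returns the hard-pruned weights and the mask marking which weights survive.
    """
    w = list(w)
    mask = topk_mask(w, pruning_rate)
    z = [wi * m for wi, m in zip(w, mask)]    # auxiliary sparse copy of W
    u = [0.0] * len(w)                        # scaled dual variable
    for _ in range(admm_iters):
        # W-step: gradient descent on f(W) + rho/2 * ||W - Z + U||^2
        for _ in range(sgd_steps):
            g = grad_f(w)
            w = [wi - lr * (gi + rho * (wi - zi + ui))
                 for wi, gi, zi, ui in zip(w, g, z, u)]
        # Z-step: Euclidean projection of W + U onto the sparsity constraint
        v = [wi + ui for wi, ui in zip(w, u)]
        mask = topk_mask(v, pruning_rate)
        z = [vi * m for vi, m in zip(v, mask)]
        # U-step: accumulate the remaining constraint violation W - Z
        u = [ui + wi - zi for ui, wi, zi in zip(u, w, z)]
    # Hard pruning: zero every weight outside the final support of Z
    return [wi * m for wi, m in zip(w, mask)], mask


# Toy local loss f(W) = 0.5 * ||W - A||^2, standing in for a client's training loss
A = [0.9, -0.05, 0.4, 0.01, -0.7, 0.02, 0.3, -0.03, 0.5, 0.04]
pruned, mask = admm_prune(A, lambda w: [wi - ai for wi, ai in zip(w, A)], pruning_rate=0.5)
print(mask)  # [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
```

The returned `mask` plays the role described in the Abstract: it marks which weights were pruned, so that only the unpruned weights and the mask need to be transmitted to the server.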
Next, we explore how the number of participating clients and the number of samples per class affect model performance.
To study the number of clients, we randomly select 10, 15 and 20 clients in each round of federated training, with 50 pictures per class on each client. To study the amount of data per class, we randomly select 10 clients in each round, with 30, 40 and 50 pictures per class on each client. The experimental results show that the proposed model achieves higher accuracy than Standalone, FedAvg and LG-FedAvg, and lower communication cost than FedAvg and LG-FedAvg.
Figures 11 and 12 show the accuracy and communication cost for different numbers of participating clients, and Figs. 13 and 14 show them for different amounts of data per client. Standalone trains only on the clients, without the server, so it incurs no communication cost. Taking 10 randomly selected clients with 50 pictures per class as an example, on CIFAR-10 LPFed reaches its highest accuracy of 90.14% with a communication cost of 775.95 MB at a target pruning rate of 30%, whereas LG-FedAvg reaches 80.78% with a communication cost of 1127.89 MB. Compared with LG-FedAvg, our method therefore improves accuracy by 9.36 percentage points and reduces communication cost by a factor of 1.45. On MNIST, LPFed reaches its highest accuracy of 99.96% with a communication cost of 212.45 MB at a target pruning rate of 10%, whereas LG-FedAvg reaches 97.03% with 244.82 MB; the proposed model improves accuracy by 2.93 percentage points and reduces communication cost by a factor of 1.15.
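The reported gains can be recomputed from the numbers in the paragraph above. The short Python check below (the helper `compare` is ours) treats the accuracy gain as a difference in percentage points and the communication saving as the ratio of LG-FedAvg's cost to LPFed's cost.

```python
def compare(ours_acc, base_acc, ours_mb, base_mb):
    """Return (accuracy gain in percentage points, baseline-to-ours cost ratio)."""
    return round(ours_acc - base_acc, 2), round(base_mb / ours_mb, 2)


print(compare(90.14, 80.78, 775.95, 1127.89))  # CIFAR-10 vs LG-FedAvg: (9.36, 1.45)
print(compare(99.96, 97.03, 212.45, 244.82))   # MNIST vs LG-FedAvg: (2.93, 1.15)
```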
As Fig. 12 shows, the communication cost increases with the number of participating clients, because more clients communicate with the server in each round. In contrast, Fig. 14 shows that the communication cost decreases as each client holds more data: more local data speeds up client convergence, and the client prunes more weights during local training, so fewer parameters need to be transmitted.
Case study
This section illustrates the practicality of the proposed model by implementing a flower identification application on mobile phones. Flower identification is a common application in users’ everyday lives. The flower data on each client reveal the user’s preferences, and these data are often heterogeneous. This study trains MobileNetV2 [20] on the flower dataset created by the University of Oxford, which contains 17 categories of flowers with 80 pictures per category, 1360 pictures in total. MobileNetV2 is a lightweight convolutional neural network based on an inverted residual structure, proposed by a team at Google in 2018.
The number of communication rounds is set to 300, and in each round 5 of the 40 clients are randomly selected to participate in training. Each client holds two classes of data with 15 samples per class, and the iterative and target pruning rates are 20% and 50%, respectively. The final model accuracy is 92.47%, and the communication cost per client is 0.8783 GB.
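The case-study configuration can be reproduced in outline with the Python sketch below. It splits the 17-class, 80-images-per-class dataset into 40 non-IID clients with two classes and 15 images per class each, then samples 5 clients per round for 300 rounds. The paper does not state how classes are assigned to clients, so the round-robin assignment, the image-ID scheme and all helper names are assumptions; round-robin guarantees that no class is requested more than 5 times (75 of its 80 images).

```python
import random

NUM_CLASSES = 17        # flower categories
IMAGES_PER_CLASS = 80   # 17 x 80 = 1360 images in total
NUM_CLIENTS = 40
CLASSES_PER_CLIENT = 2
SAMPLES_PER_CLASS = 15
CLIENTS_PER_ROUND = 5
NUM_ROUNDS = 300


def partition(seed=0):
    """Return one list of (image_id, label) pairs per client; no image is shared."""
    rng = random.Random(seed)
    pools = []
    for c in range(NUM_CLASSES):
        # Image j of class c gets the (hypothetical) id c * IMAGES_PER_CLASS + j
        ids = [c * IMAGES_PER_CLASS + j for j in range(IMAGES_PER_CLASS)]
        rng.shuffle(ids)
        pools.append(ids)
    clients = []
    for k in range(NUM_CLIENTS):
        # Round-robin: client k gets classes 2k and 2k + 1 (mod 17)
        labels = [(CLASSES_PER_CLIENT * k + i) % NUM_CLASSES
                  for i in range(CLASSES_PER_CLIENT)]
        clients.append([(pools[c].pop(), c)
                        for c in labels for _ in range(SAMPLES_PER_CLASS)])
    return clients


def select_clients(rng):
    """Sample the clients that participate in one communication round."""
    return rng.sample(range(NUM_CLIENTS), CLIENTS_PER_ROUND)


clients = partition()
rng = random.Random(1)
for rnd in range(NUM_ROUNDS):
    selected = select_clients(rng)
    # Each selected client would run local training with ADMM pruning here,
    # followed by aggregation of the unpruned weights and masks on the server.
```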
The trained model is deployed to a local identification system based on [21] in the following three steps, using Android Studio as the experimental platform.
(1) The .pth model file obtained in PyCharm is converted to the .pt format supported in Android Studio.
(2) The PyTorch dependency is added to the project’s build.gradle file, and the converted model file and the label information are placed in the Android Studio project.
(3) The model is invoked for inference.
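Step (1) can be done with TorchScript tracing. The sketch below assumes that the `.pth` file holds a `state_dict` (not a pickled model object) and that the caller can rebuild the same architecture, for example torchvision’s MobileNetV2 with 17 output classes; `convert_checkpoint` and the file names are hypothetical. The resulting `.pt` file is the converted model that step (2) places in the Android Studio project.

```python
import torch


def convert_checkpoint(build_model, pth_path, pt_path, input_shape=(1, 3, 224, 224)):
    """Convert a state-dict checkpoint (.pth) into a TorchScript file (.pt).

    build_model: callable returning a fresh instance of the trained architecture.
    """
    model = build_model()
    state_dict = torch.load(pth_path, map_location="cpu")  # weights only, on CPU
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: fixes dropout and batch-norm behaviour
    example = torch.rand(*input_shape)  # dummy input used only for tracing
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    torch.jit.save(traced, pt_path)
    return traced
```

For the case study this might be called as `convert_checkpoint(lambda: torchvision.models.mobilenet_v2(num_classes=17), "flower.pth", "flower.pt")`. Tracing records the operations executed for one example input, which suffices for a feed-forward classifier such as MobileNetV2; a model with data-dependent control flow would need `torch.jit.script` instead.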
In Fig. 15, the left panel shows the main interface of the system; tapping "PT IDENTIFICATION" opens the flower identification interface shown in the right panel.
In this study, the global model of the server and the personalized model of one client, whose data include Daisy and Bluebell, are both deployed to the system. As shown in the left view of Fig. 16, the user can choose between the server and client models; tapping "PHOTO" selects the photo on the mobile device to be recognized, as shown in the right view of Fig. 16.
Figure 17 shows the recognition results of the server and client models on the same picture. In this example, the client’s personalized model identifies the flower more accurately than the server’s global model.
Conclusion
This study proposes an ADMM-based weight pruning approach to address the high model cost and heterogeneous client data in federated learning. An appropriate target pruning rate can be selected according to the performance of the edge devices. In each round of communication, the server processes the personalized models through masks, while each client performs ADMM weight pruning on its subnetwork using local data. Finally, the proposed approach is applied to a flower recognition system. The experimental results demonstrate the effectiveness of the proposed algorithm. Some issues nevertheless warrant further discussion; for example, future work should consider client selection and structured pruning to improve training efficiency.
Data availability
Not applicable.
Code availability
Not applicable.
References
Ben Mansour A, Carenini G, Duplessis A (2023) Tackling computational heterogeneity in FL: a few theoretical insights. arXiv e-prints, 2307
Choudhary T, Mishra V, Goswami A, Sarangapani J (2022) Heuristic-based automatic pruning of deep neural networks. Neural Comput Appl 34(6):4889–4903
Bau D, Zhu J-Y, Strobelt H, Lapedriza A, Zhou B, Torralba A (2020) Understanding the role of individual units in a deep neural network. Proc Natl Acad Sci 117(48):30071–30078
Denil M, Shakibi B, Dinh L, Ranzato M, de Freitas N (2013) Predicting parameters in deep learning. In: Proceedings of the 26th International Conference on Neural Information Processing Systems - Volume 2. NIPS’13, pp. 2148–2156. Curran Associates Inc., Red Hook, NY, USA
Shi D, Li L, Chen R, Prakash P, Pan M, Fang Y (2022) Toward energy-efficient federated learning over 5G+ mobile devices. IEEE Wirel Commun 29(5):44–51
Wongso S, Ghosh R, Motani M (2022) Understanding deep neural networks using sliced mutual information. In: 2022 IEEE International Symposium on Information Theory (ISIT), pp. 133–138. IEEE
Liang PP, Liu T, Ziyin L, Salakhutdinov R, Morency L-P (2020) Think locally, act globally: federated learning with local and global representations. arXiv preprint arXiv:2001.01523
Abdellatif AA, Mhaisen N, Mohamed A, Erbad A, Guizani M, Dawy Z, Nasreddine W (2022) Communication-efficient hierarchical federated learning for IoT heterogeneous systems with imbalanced data. Future Gener Comput Syst 128(C):406–419
Zhao Y, Li M, Lai L, Suda N, Civin D, Chandra V (2018) Federated learning with non-IID data. arXiv preprint arXiv:1806.00582
Fallah A, Mokhtari A, Ozdaglar A (2020) Personalized federated learning: A meta-learning approach. arXiv preprint arXiv:2002.07948
Jiang Y, Konečný J, Rush K, Kannan S (2019) Improving federated learning personalization via model agnostic meta learning. arXiv preprint arXiv:1909.12488
Schneider J, Vlachos M (2019) Mass personalization of deep learning. arXiv preprint arXiv:1909.02803
Zhuang F, Qi Z, Duan K, Xi D, Zhu Y, Zhu H, Xiong H, He Q (2021) A comprehensive survey on transfer learning. Proc IEEE 109(1):43–76. https://doi.org/10.1109/JPROC.2020.3004555
Duan M, Liu D, Chen X, Liu R, Tan Y, Liang L (2020) Self-balancing federated learning with global imbalanced data in mobile systems. IEEE Trans Parallel Distrib Syst 32(1):59–71
Liao Y, Xu Y, Xu H, Wang L, Qian C (2022) Adaptive configuration for heterogeneous participants in decentralized federated learning. arXiv preprint arXiv:2212.02136
McMahan B, Moore E, Ramage D, Hampson S, y Arcas BA (2017) Communication-efficient learning of deep networks from decentralized data. In: Artificial Intelligence and Statistics, pp. 1273–1282. PMLR
Wu L, Wang Y, Shi T (2021) A flexible stochastic multi-agent ADMM method for large-scale distributed optimization. IEEE Access 10:19045–19059
Han S, Pool J, Tran J, Dally W (2015) Learning both weights and connections for efficient neural network. Adv Neural Inf Process Syst 28
Boyd S, Parikh N, Chu E, Peleato B, Eckstein J et al (2011) Distributed optimization and statistical learning via the alternating direction method of multipliers. Foundations and Trends® in Machine Learning 3(1):1–122
Sandler M, Howard A, Zhu M, Zhmoginov A, Chen L-C (2018) MobileNetV2: inverted residuals and linear bottlenecks. In: 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4510–4520. https://doi.org/10.1109/CVPR.2018.00474
Yuan P, Huang R, Zhang J, Zhang E, Zhao X (2022) Accuracy rate maximization in edge federated learning with delay and energy constraints. IEEE Syst J, 1–12. https://doi.org/10.1109/JSYST.2022.3203727
Funding
This work was supported by the National Natural Science Foundation of China (Grant Nos. 62072159 and 61902112) and the Henan Provincial Science and Technology Research Project (Grant Nos. 222102210011 and 232102211061).
Author information
Contributions
Not applicable
Ethics declarations
Conflict of interest
On behalf of all authors, the corresponding author states that there is no conflict of interest.
Ethics approval
Not applicable.
Consent to participate
Not applicable.
Consent for publication
Not applicable.
Additional information
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Yuan, P., Shi, L., Zhao, X. et al. A lightweight and personalized edge federated learning model. Complex Intell. Syst. 10, 3577–3592 (2024). https://doi.org/10.1007/s40747-023-01332-9
DOI: https://doi.org/10.1007/s40747-023-01332-9