Abstract
Federated learning (FL) is a new distributed learning framework that differs from traditional distributed machine learning in three respects: (1) devices differ in communication, computing, and storage capabilities (device heterogeneity); (2) clients differ in data distribution and data volume (data heterogeneity); and (3) communication costs are high. Under heterogeneous conditions, client data distributions vary greatly, which slows the convergence of the training model and can prevent it from converging to the global optimum. In this work, we propose an FL algorithm based on stratified sampling and regularization (FedSSAR). FedSSAR uses a density-based clustering method to divide the client population into clusters and then samples available clients from each cluster in proportion to its size; this yields unbiased sampling of the client population and reduces the variance of the clients' aggregation weights. In addition, a regularization term in each client's local loss function constrains the direction of the model update, so that heterogeneous clients are optimized toward the global optimum. We establish the convergence of FedSSAR theoretically, verify it experimentally, and demonstrate its advantages over other FL algorithms on public datasets.
Introduction
Federated learning is a new distributed machine learning paradigm [1, 2] that allows multiple devices (called clients) to collaboratively train a global model without uploading their local data [3, 4]. Compared with traditional distributed machine learning, the main differences are as follows: (1) clients have independent control over their local devices and data; (2) clients are often unreliable (edge nodes frequently disconnect because of device and communication problems); (3) communication costs exceed computation costs; (4) data across clients are not independent and identically distributed (non-IID); and (5) local data are unevenly distributed across clients [5]. These features challenge the design and analysis of FL algorithms.
One of the major challenges is client heterogeneity, which includes both data and device heterogeneity. Client heterogeneity is pervasive in real-world FL [6], for example: (1) heterogeneous data distributions: data on each client are generated locally, so the sample-generating mechanism may differ across clients (e.g., across countries or regions); (2) covariate shift: in handwriting recognition, for example, different people write the same word differently; (3) label distribution skew (prior probability shift): for example, Chinese characters are used mainly in China and much less elsewhere; and (4) quantity skew or imbalance. In practice, many situations can give rise to non-IID data. Traditional machine learning assumes IID data, whereas in FL, where data are not centralized, the data on different nodes are non-IID.
Consider a realistic scenario: using FL to train a mobile input-method model [7, 8]. Different mobile phones differ in processing speed, local data, network conditions, and so on. The latest phones compute and transmit faster than older ones, and phones in locations with good signal coverage, such as towns, transmit more stably than phones in rural areas or in areas with signal interference. During training, older phones train slowly and often fail to finish their training tasks on time, and phones with poor network connectivity are more likely to lose the connection while transmitting their models. As a result, the data distribution seen by the parameter server differs from the actual distribution. Because of client heterogeneity, some classes of data participate in training more frequently than others, which biases the training data.
To reduce the impact of client heterogeneity, we propose the FedSSAR algorithm, which consists of two parts: client selection and regularization. Instead of randomly selecting the clients that participate in each training round, as traditional methods do, we divide the client population into multiple clusters according to the similarity of their local data, without accessing that data, and then sample a certain number of available clients from each cluster for aggregation. In addition, we add a regularizer anchored at the global model to each client's local objective, which discourages local updates from moving far from the global model. With appropriately chosen parameters, we prove that the algorithm converges at a rate of \(O(1/\sqrt T )\). We performed experiments on benchmark datasets, including MNIST, EMNIST, Cifar-10, and Cifar-100, to explore the performance of the algorithms under different levels of data heterogeneity, and compared FedSSAR with FedAvg [2], FedProx [6], and SCAFFOLD [9]. The experiments show that the proposed algorithm converges faster and reaches higher training accuracy. The main contributions of this paper are as follows:
-
We demonstrate that, for a convex objective function, traditional FL algorithms cannot converge to the global optimum under client heterogeneity, even when exact (rather than stochastic) gradient descent is used. In particular, when data are highly heterogeneous, traditional FL algorithms may diverge;
-
We propose the FedSSAR algorithm. Its main idea is to use stratified sampling for client selection and a regularization term to limit model updates. Clients sampled by cluster better represent the overall data distribution and reduce the variance of the aggregated model. At the same time, regularization changes each device's optimization objective, bringing it closer to that of the global loss function;
-
We theoretically identify the reason why traditional FL algorithms diverge and prove convergence results for FedSSAR; our analysis shows that FedSSAR has an \(O(1/\sqrt T )\) convergence rate;
-
We use the public MNIST, EMNIST-L, Cifar-10 and Cifar-100 datasets to evaluate FedSSAR and compare it with FedAvg, FedProx and SCAFFOLD. The results confirm the advantages of FedSSAR in terms of a smoother training process, faster convergence, and lower training loss.
The remainder of this paper is organized as follows. In section "Related works", we introduce the background research of FL and an overview of related studies on heterogeneous federated learning; in section "Algorithm design", we theoretically analyse the effect of client heterogeneity on model convergence and propose the framework for FedSSAR; in section "Convergence analysis", we discuss the improvements provided by stratified sampling and regularization and give a theoretical proof of the convergence of FedSSAR; in section "Experiments", we comprehensively evaluate FedSSAR using different neural network models on multiple datasets.
Related works
With the improvement of the storage capacity and computing capability of devices, there is an ever-increasing trend toward distributed optimization in the big-data regime [10,11,12,13]. Traditional distributed machine learning gathers the training data on one machine or in one data center and then updates the model. However, with the growing local computing power of mobile phones, smart wearable devices, sensors, and other devices, as well as recent restrictions on user data privacy [14], it is more efficient for distributed clients to train models locally and upload the model parameters to a parameter server for aggregation than to transfer client data directly. This paradigm is known as FL, and it must address challenges such as large-scale training data, privacy protection, and heterogeneity of data and devices [1, 15,16,17].
McMahan et al. [2] proposed a distributed learning method based on iteratively averaging the trained models and developed the Federated Averaging (FedAvg) algorithm. The learning task is solved jointly by the participating devices and coordinated by a central server, an organization resembling a loose federation. Compared with data-centric distributed machine learning, one of the main advantages of FL is that it decouples model training from direct access to raw data, which is very important when data privacy requirements are strict or when data are difficult to share centrally. Moreover, FedAvg improves learning efficiency by performing multiple local iterations between communication rounds, which helps reduce communication costs.
Kairouz et al. [18] discussed open issues and challenges in FL: (1) data heterogeneity (non-IID data) in federated learning; (2) client data privacy; (3) communication constraints; (4) algorithm robustness; and (5) fairness and efficiency. One of the most fundamental challenges is the heterogeneity of client data.
To address differing data distributions in FL, Zhao et al. [19] improved the FedAvg algorithm. They showed that when client data are heterogeneous, FedAvg suffers a large accuracy loss, which can be explained by weight divergence, and proposed creating a small subset of data that is globally shared among all devices. Although this method can reduce the impact of data skew, it amounts to artificially introducing errors; moreover, such data sharing essentially violates the data privacy principle of federated learning. Yan et al. [20] considered the intermittent availability of clients. They argue that, under heterogeneous conditions, clients participate in training at different frequencies, which skews the trained model toward the data of frequently participating clients. Therefore, their client selection favors clients with fewer training sessions so that each client participates as often as possible. However, the assumptions of that work are too strong, and enforcing an equal average number of training sessions per client is susceptible to slow nodes, resulting in much longer training time.
Li et al. [6] proposed FedProx, which starts from the objective function and reduces the impact of data heterogeneity by adding a proximal term to the local objective, so that each client does not deviate too far from the global model when updating on local data; this alleviates heterogeneity at the cost of an acceptable increase in client computation. In a similar spirit, Acar et al. [21] proposed FedDyn, which adds to the local objective a regularizer based on the global model and the local models of previous communication rounds. This penalty dynamically modifies each device's objective so that, when the model parameters converge, they converge to a stationary point of the global empirical loss. SCAFFOLD [9] uses variance reduction to deal with data heterogeneity: it maintains control variates on the server and on each client to estimate the update directions of the server model and of each client, and uses the difference between the two directions to approximate the drift of local training. Adding such an optimizer to the client and global update rules can significantly improve training performance, but these methods are very sensitive to hyperparameters, and tuning them for different datasets is difficult.
Huang et al. [22] take a different approach: they argue that data heterogeneity prevents a single global model from reaching sufficient accuracy, so they personalize the model by further training the global model on local data to obtain higher-quality personalized models. Along similar lines, Ghosh et al. [23] and Sattler et al. [24] proposed dividing clients into clusters and training a separate global model in each cluster, using different clustering methods to group clients by their local empirical loss functions or their gradients. Because clients in the same cluster are highly similar, the models trained this way are more accurate. Unlike traditional federated learning algorithms, however, these methods train multiple global models, each of which is accurate on its own cluster but inaccurate on the others, resulting in poor generalization. Nevertheless, they suggest the idea of clustering clients with similar data distributions. Fraboni et al. [25] demonstrated that clustered sampling represents clients better and proposed two client aggregation schemes based on sample size and on model similarity. In addition, different sampling methods have been used to construct a homogeneous distribution of client data for model training, greatly improving model performance on heterogeneous datasets. However, the computational complexity of that approach is too high, and sampling clients' local datasets appears to violate the FL principle that local data must not leave the client.
Algorithm design
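In FL, we consider the following standard optimization model:
\[\min_{\mathbf{w}} F({\mathbf{w}}) = \sum\nolimits_{k \in S_{t}} {\rho_{k} F_{k} ({\mathbf{w}})},\]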
where \(S_{t} \subseteq \{ 1,2, \ldots ,N\}\) denotes the set of clients participating in model aggregation in round t, N is the total number of devices, and \(\rho_{k}\) is the weight of the k-th device, with \(\rho_{k} \ge 0\) and \(\sum\nolimits_{{k \in S_{t} }} {\rho_{k} } = 1\). Commonly, \(\rho_{k} = n_{k} /M\), where \(n_{k}\) is the number of samples on client k and \(M = \sum\nolimits_{{k \in S_{t} }} {n_{k} }\) is the total number of samples held by the clients participating in round t. Suppose the k-th device holds the training data \(D^{k}\), and let \(\xi_{t}^{k} \sim D^{k}\) be a sample drawn uniformly from the local data. The loss function of client k is \(F_{k} ({\mathbf{w}}) = {\rm E}_{\xi^{k} \sim D^{k}} [\ell_{k} ({\mathbf{w}},\xi^{k} )]\). We now describe the standard FedAvg algorithm. First, the central server broadcasts the latest model \({\mathbf{w}}_{t}\) to all devices. Then, every device sets \({\mathbf{w}}_{t}^{k} = {\mathbf{w}}_{t}\) and performs local updates:
\[{\mathbf{w}}_{t + 1}^{k} = {\mathbf{w}}_{t}^{k} - \eta_{t} \nabla \ell_{k} ({\mathbf{w}}_{t}^{k} ,\xi_{t}^{k} ),\]
where \(\eta_{t}\) is the learning rate. Assuming that K clients (\(1 \le K \le N\)) are selected to participate in round t, the server aggregates their local models as follows:
\[{\mathbf{w}}_{t + 1} = \sum\nolimits_{k \in S_{t}} {\rho_{k} {\mathbf{w}}_{t + 1}^{k} }.\]
The global data are a mixture of all local data: \(D = \sum\nolimits_{k = 1}^{N} {\rho_{k} D_{k} }\). When client data are independent and identically distributed, \(D_{k} = D\) for all \(k \in \{ 1,2, \ldots ,N\}\). In practice, however, data distributions differ across clients, so our theoretical analysis assumes non-IID data.
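The FedAvg round described above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: it assumes scalar models and the per-sample loss \((w - \xi)^2/2\), uses full-batch local gradient descent instead of sampling a single \(\xi_t^k\), and the names `local_update` and `fedavg` are hypothetical. Each selected client starts from the broadcast model \(\mathbf{w}_t\), runs a few local steps, and the server averages the results with weights \(\rho_k = n_k/M\).

```python
import random


def local_update(w, data, lr, local_steps):
    """Full-batch gradient descent on F_k(w) = mean((w - x)^2) / 2 over the client's data."""
    for _ in range(local_steps):
        grad = sum(w - x for x in data) / len(data)
        w -= lr * grad
    return w


def fedavg(client_data, rounds=50, clients_per_round=None, lr=0.1, local_steps=5, seed=0):
    """Toy FedAvg with scalar models; returns the final global model."""
    rng = random.Random(seed)
    n = len(client_data)
    k = clients_per_round or n
    w = 0.0  # initial global model w_0
    for _ in range(rounds):
        selected = rng.sample(range(n), k)                    # S_t
        m_total = sum(len(client_data[i]) for i in selected)  # M
        # every selected client starts from the broadcast model w_t
        local = {i: local_update(w, client_data[i], lr, local_steps) for i in selected}
        # aggregation with rho_k = n_k / M
        w = sum(len(client_data[i]) / m_total * local[i] for i in selected)
    return w


clients = [[0.0, 1.0], [4.0], [10.0, 11.0, 12.0]]  # non-IID toy data
print(fedavg(clients))  # close to the pooled mean 38/6, about 6.333
```

With full participation the toy model converges to the pooled data mean \(\sum_k \rho_k e_k\); with `clients_per_round=1` the global model instead jumps toward whichever client was sampled, which illustrates why the choice of \(S_t\) matters.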
Effects of data heterogeneity
Example 1.
Consider a distributed optimization problem with N clients and convex objective functions, where the goal is to learn the mean of the clients' one-dimensional data. Client k draws local samples \(\xi^{k} \sim D^{k}\) with mean \(e_{k} = {\rm E}[\xi^{k} ]\). We can express this learning problem as minimizing the mean squared error (MSE):
\[\min_{x} F(x) = \sum\nolimits_{k = 1}^{N} {\rho_{k} {\rm E}_{\xi^{k} \sim D^{k}} [(x - \xi^{k} )^{2} ]}.\]
For ease of calculation, we assume that each client holds the same amount of data; the optimal solution is then
\[x^{*} = \frac{1}{N}\sum\nolimits_{k = 1}^{N} {e_{k} }.\]
Let \(\tau_{k}\) be the weight offset caused by communication loss, device differences, and similar factors, and let \(\hat{x}^{*}\) denote the value actually reached. The objective function then converges to
\[\hat{x}^{*} = \sum\nolimits_{k = 1}^{N} {\left( {\frac{1}{N} + \tau_{k} } \right)e_{k} }.\]
Proof of Example 1
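Since the objective function is convex, gradient descent with a sufficiently small learning rate converges to the optimal solution. Setting the derivative of the objective function to zero and using \(\sum\nolimits_{k} {\rho_{k} } = 1\) gives
\[F'(x) = 2\sum\nolimits_{k = 1}^{N} {\rho_{k} (x - e_{k} )} = 0 \quad \Rightarrow \quad x = \sum\nolimits_{k = 1}^{N} {\rho_{k} e_{k} }.\]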
Under the assumption that each client holds the same amount of data, \(\rho_{k} = 1/N\) for every \(k \in \{ 1,2, \ldots ,N\}\); therefore,
\[x^{*} = \frac{1}{N}\sum\nolimits_{k = 1}^{N} {e_{k} }.\]
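Under real conditions, \(\rho_{k} = 1/N + \tau_{k}\), where \(\tau_{k}\) is the weight offset for client k; because the weights still sum to one, \(\sum\nolimits_{k} {\tau_{k} } = 0\). We then obtain
\[\hat{x}^{*} = \sum\nolimits_{k = 1}^{N} {\left( {\frac{1}{N} + \tau_{k} } \right)e_{k} } = x^{*} + \sum\nolimits_{k = 1}^{N} {\tau_{k} e_{k} }.\]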
Hence \(\hat{x}^{*} = x^{*}\) when \(e_{1} = e_{2} = \cdots = e_{N}\) (IID data) or when \(\tau_{k} = 0\) for all \(k \in \{ 1,2, \ldots ,N\}\); otherwise the two generally differ by \(\sum\nolimits_{k} {\tau_{k} e_{k} }\). Therefore, traditional federated learning algorithms can perform poorly when clients are heterogeneous.
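The effect of the weight offsets in Example 1 can be checked numerically. The sketch below assumes hypothetical client means \(e_k\) and offsets \(\tau_k\) that sum to zero; it runs exact gradient descent on \(F(x) = \sum_k \rho_k \mathrm{E}[(x-\xi^k)^2]\) with the perturbed weights \(\rho_k = 1/N + \tau_k\) and compares the limit with \(x^*\). The helper names are illustrative only.

```python
def weighted_optimum(weights, means):
    """Minimizer of sum_k w_k * E[(x - xi_k)^2]: the weighted mean of the e_k."""
    return sum(w * e for w, e in zip(weights, means)) / sum(weights)


def gradient_descent(weights, means, lr=0.1, steps=500, x0=0.0):
    """Exact gradient descent; dF/dx = 2 * sum_k w_k * (x - e_k)."""
    x = x0
    for _ in range(steps):
        x -= lr * 2 * sum(w * (x - e) for w, e in zip(weights, means))
    return x


means = [0.0, 1.0, 2.0, 3.0]        # hypothetical client means e_k
n = len(means)
tau = [0.1, -0.1, 0.05, -0.05]      # hypothetical offsets, summing to zero
uniform = [1 / n] * n               # rho_k = 1/N
skewed = [1 / n + t for t in tau]   # rho_k = 1/N + tau_k

x_star = weighted_optimum(uniform, means)
x_hat = gradient_descent(skewed, means)
print(x_star, round(x_hat, 6))  # 1.5 1.35, i.e. x* + sum_k tau_k e_k
```

Replacing `means` with identical values (the IID case) makes both results coincide, matching the condition stated in the proof.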
FedSSAR architecture
As shown in section "Effects of data heterogeneity", data heterogeneity and device heterogeneity significantly degrade the performance of traditional FL algorithms. In FL, the overall data distribution is a mixture of the local data distributions of all clients, and in FedAvg the mixture (aggregation) weight is the sample weight. This setting accounts only for differences in data volume across clients, not for differences in hardware and communication. Consider again the classic FL example of training a mobile input-method model: the latest phones compute and transmit faster than older ones, and phones in locations with good signal coverage, such as cities and towns, transmit more stably than phones in rural areas or in areas with signal interference. As a result, the data distribution received by the parameter server differs from the actual data distribution. Because of client heterogeneity, some classes of data participate in training more frequently than others, which biases the training data. To alleviate this problem, we aim to use all types of data in every training round, so that each type of data participates in training with approximately the same probability and the training data distribution is an unbiased mixture of the clients' sample distributions. At the same time, a regularization term is added to limit model updates. In this way, we can eliminate bias in the training data and establish a convergence result.
Recall the training steps of federated learning: the parameter server first initializes the global model and broadcasts the model \({\mathbf{w}}_{0}\) to all clients. Each client then trains the received model on a sample of local data \(\xi_{1}^{k}\) to obtain the local model parameters \({\mathbf{w}}_{1}^{k}\):
\[{\mathbf{w}}_{1}^{k} = {\mathbf{w}}_{0} - \eta_{1} \nabla \ell_{k} ({\mathbf{w}}_{0} ,\xi_{1}^{k} ).\]
When computing \({\mathbf{w}}_{1}^{k}\), the quantities \({\mathbf{w}}_{0}\) and \(\eta_{1}\) are the same for all clients, so \({\mathbf{w}}_{1}^{k}\) depends only on \(\xi_{1}^{k}\); that is, \({\mathbf{w}}_{1}^{k}\) carries information about the client's data distribution. This indicates that we can cluster the local model parameters to group clients by the similarity of their local data.
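A one-dimensional toy case makes this argument concrete. Assuming the per-sample loss \((w-\xi)^2/2\) and a full-batch first step, \(w_1^k = (1-\eta_1)w_0 + \eta_1\bar{\xi}^k\), where \(\bar{\xi}^k\) is client k's local mean, so clients with the same local mean end up with identical parameters. This is a hypothetical illustration (the function name is not from the paper); in a neural network the parameters reflect richer distributional information than a single mean.

```python
def one_step_local_model(w0, local_data, lr):
    """First local step w_1^k = w_0 - lr * grad for the loss mean((w - xi)^2) / 2.

    The full-batch gradient at w_0 is w_0 - mean(local_data), so
    w_1^k = (1 - lr) * w_0 + lr * mean(local_data).
    """
    grad = w0 - sum(local_data) / len(local_data)
    return w0 - lr * grad


w0, lr = 0.0, 0.5  # shared by all clients
clients = {"a": [1.0, 3.0], "b": [2.0, 2.0], "c": [9.0, 11.0]}
w1 = {name: one_step_local_model(w0, data, lr) for name, data in clients.items()}
print(w1)  # {'a': 1.0, 'b': 1.0, 'c': 5.0}: a and b share a local mean
```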
![Algorithm 1: FedSSAR](http://media.springernature.com/lw685/springer-static/image/art%3A10.1007%2Fs40747-022-00895-3/MediaObjects/40747_2022_895_Figa_HTML.png)
Algorithm 1 gives the detailed procedure of FedSSAR. Its client selection principle is to select available clients from different clusters (lines 2–8). After training, the parameter server collects the local model parameters of each client and divides the clients into groups using the OPTICS clustering method [26] (Ordering Points to Identify the Clustering Structure). In each training round, available clients are drawn from each cluster in proportion to its size, which ensures that all types of data participate in every round and reduces the impact of client heterogeneity. After a client receives the latest global model parameters from the parameter server, a regularization term is added to its local loss function when minimizing it (line 12), so that the new local model parameters do not deviate far from the previous global model parameters; the strength of this correction is controlled by the parameter \(\alpha\). The updated parameters are then sent back to the parameter server, which computes their weighted average (lines 10–15).
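The proportional, per-cluster selection step of Algorithm 1 might look as follows. This is a sketch under stated assumptions, not the paper's code: cluster labels are taken as given (for example, from the `labels_` attribute of scikit-learn's `OPTICS` estimator fitted on flattened local parameters), the OPTICS noise label -1 is treated as one more stratum, each selected client receives the weight \(n_{c_i}/(k_i N)\) of the \(n_{c_i}/k_i\) cluster members it stands in for, and weights are renormalized if a cluster has no available clients. The function name `stratified_sample` is hypothetical.

```python
import random
from collections import defaultdict


def stratified_sample(cluster_labels, available, fraction, seed=0):
    """Select available clients per cluster, proportionally to cluster size.

    cluster_labels: dict mapping client id -> cluster label.
    available: set of client ids that can participate this round.
    fraction: target participation rate, e.g. 0.1.
    Returns a dict mapping each selected client id -> aggregation weight.
    """
    rng = random.Random(seed)
    clusters = defaultdict(list)
    for cid, label in cluster_labels.items():
        clusters[label].append(cid)
    n_total = len(cluster_labels)

    weights = {}
    for label in sorted(clusters):
        members = clusters[label]
        pool = sorted(c for c in members if c in available)
        if not pool:
            continue  # nobody from this cluster is reachable this round
        k = min(len(pool), max(1, round(fraction * len(members))))
        for c in rng.sample(pool, k):
            # each chosen client represents len(members) / k clients of its cluster
            weights[c] = len(members) / (k * n_total)

    total = sum(weights.values())
    return {c: w / total for c, w in weights.items()}


labels = {cid: 0 for cid in range(10)}
labels.update({cid: 1 for cid in range(10, 30)})
labels.update({cid: 2 for cid in range(30, 40)})
selected = stratified_sample(labels, available=set(labels), fraction=0.1)
print(sorted(selected.values()))  # four clients, each with weight 0.25
```

Weighting each selected client by its cluster's share keeps every cluster's total contribution equal to its share of the population, regardless of how many of its members were drawn.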
Convergence analysis
Notation and assumptions
Assumption 1.
\(F_{1} ,F_{2} , \ldots ,F_{N}\) are all L-smooth: for every k and all \({\mathbf{v}}\) and \({\mathbf{w}}\), \(F_{k} ({\mathbf{v}}) \le F_{k} ({\mathbf{w}}) + ({\mathbf{v}} - {\mathbf{w}})^{T} \nabla F_{k} ({\mathbf{w}}) + L/2 \cdot \left\| {{\mathbf{v}} - {\mathbf{w}}} \right\|^{2}\).
Assumption 2.
Bounded dissimilarity:
We say that the local functions \(F_{k}\) are B-locally dissimilar at \({\mathbf{w}}\) if \({\rm E}_{k} [\left\| {\nabla F_{k} ({\mathbf{w}})} \right\|^{2} ] \le \left\| {\nabla F({\mathbf{w}})} \right\|^{2} \cdot B^{2}\). Accordingly, for \(\left\| {\nabla F({\mathbf{w}})} \right\| \ne 0\), we define \(B({\mathbf{w}}) = \sqrt {\tfrac{{{\rm E}_{k} [\left\| {\nabla F_{k} ({\mathbf{w}})} \right\|^{2} ]}}{{\left\| {\nabla F({\mathbf{w}})} \right\|^{2} }}}\). This quantity measures the difference between the local loss functions and the global loss function. We assume that, for some \(\varepsilon > 0\), there exists \(B_{\varepsilon }\) such that \(B({\mathbf{w}}) \le B_{\varepsilon }\) for all \({\mathbf{w}} \in S_{\varepsilon }^{c} = \{ {\mathbf{w}}|\left\| {\nabla F({\mathbf{w}})} \right\| > \varepsilon \}\).
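To see what \(B(\mathbf{w})\) measures, take scalar quadratic local losses \(F_k(w) = (w-e_k)^2/2\) with equal weights (a hypothetical example). Then \(\nabla F_k(w) = w - e_k\), \(\nabla F(w) = w - \bar e\), and \(B(w)^2 = 1 + \mathrm{Var}(e)/(w-\bar e)^2\): \(B = 1\) when all \(e_k\) coincide (IID data), and \(B\) grows without bound as \(w\) approaches the global optimum \(\bar e\), which is why the assumption only bounds \(B\) on \(S_\varepsilon^c\). The helper names below are illustrative.

```python
import math


def dissimilarity(local_grads, global_grad):
    """B(w) = sqrt(E_k[grad F_k(w)^2] / grad F(w)^2) for scalar gradients, uniform over k."""
    if global_grad == 0:
        raise ValueError("B(w) is undefined where the global gradient vanishes")
    mean_sq = sum(g * g for g in local_grads) / len(local_grads)
    return math.sqrt(mean_sq / global_grad**2)


def quadratic_case(w, means):
    """Local losses F_k(w) = (w - e_k)^2 / 2 with equal weights."""
    e_bar = sum(means) / len(means)
    return dissimilarity([w - e for e in means], w - e_bar)


print(quadratic_case(2.0, [-1.0, 1.0]))  # sqrt(1.25), about 1.118
print(quadratic_case(5.0, [3.0] * 4))    # 1.0: identical local optima (IID)
```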
Assumption 3.
Bounded variance:
There exists \(\sigma\) such that, for all \({\mathbf{w}}\), \({\rm E}_{k} [\left\| {\nabla F_{k} ({\mathbf{w}}) - \nabla F({\mathbf{w}})} \right\|^{2} ] \le \sigma^{2}\).
Improvements provided by stratified sampling and regularization
Stratified sampling
As shown above, the model parameters carry information about the client's data distribution. We assume that the overall data consist of m different distributions and, for ease of analysis, that each client holds the same amount of local data. After clustering, clients in the same cluster share the same data distribution, so their trained neural network models carry the same parameter information, namely:
Following Eq. 5, we use \({\mathbf{w}}_{t}^{{c_{i} }}\) to denote the model parameters in cluster i after t rounds of training, and \(n_{{c_{i} }}\) to denote the number of clients in cluster i. We have:
We first prove that the sampling method of FedSSAR is unbiased, that is:
where \(w_{k} (S_{t} )\) represents the aggregate weight of client k in subset \(S_{t}\).
Proof.
We can sum the model parameters of each round of training according to different clusters,
From Eq. (7), we have
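The unbiasedness claim can be checked exactly on a small hypothetical example. The sketch below assumes equal local data sizes, draws one client uniformly from each cluster, gives the drawn client of cluster i the weight \(n_{c_i}/N\), and averages the aggregate over every possible \(S_t\) using exact fractions; the result equals the plain average over all N clients. The check does not even require parameters to be identical within a cluster. The function name is illustrative.

```python
from fractions import Fraction
from itertools import product


def expected_stratified_aggregate(clusters):
    """Exact expectation of the stratified aggregate over every possible S_t.

    clusters: list of clusters, each a list of client parameters (scalars here).
    One client is drawn uniformly per cluster and weighted by n_ci / N.
    """
    n_total = sum(len(c) for c in clusters)
    draws = list(product(*clusters))  # all S_t; each has probability prod_i 1/n_ci
    total = Fraction(0)
    for draw in draws:
        total += sum(Fraction(len(c), n_total) * x for c, x in zip(clusters, draw))
    return total / len(draws)


clusters = [[1, 2], [10, 20, 30], [100]]  # hypothetical scalar parameters
plain_average = Fraction(sum(map(sum, clusters)), sum(map(len, clusters)))
print(expected_stratified_aggregate(clusters), plain_average)  # 163/6 163/6
```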
We have shown that stratified sampling yields unbiased sampling of all data. We now show that it also reduces the variance of the clients' aggregation weights, making model updates more stable. In traditional random sampling, the aggregation weight of each client equals its sample weight, \(\rho_{k} \leftarrow n_{k} /M\), where \(n_{k}\) is the number of samples on client k and M is the total number of samples. Under this section's assumption that all clients hold the same amount of local data, this simplifies to \(\rho_{k} \leftarrow 1/N\). Since the overall data consist of m different distributions, we select m clients to participate in each training round. In random sampling, the m clients are selected at random according to the Bernoulli distribution \(B(\rho_{k} )\). For client k, the aggregation weight variance is:
In stratified sampling, we first divide the client population into m clusters and then select one client from each cluster to participate in training. Let \({\text{Var}}_{c} ({\mathbf{w}}_{k} (S_{t} ))\) denote the aggregation weight variance of client k under stratified sampling. It can be seen that:
Therefore, \({\text{Var}}({\mathbf{w}}_{k} (S_{t} )) \ge {\text{Var}}_{c} ({\mathbf{w}}_{k} (S_{t} ))\), with equality only when \(m = 1\) (i.e., the data distributions of all clients are essentially the same).
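The inequality can be checked exactly under one concrete, assumed sampling model, in the spirit of the clustered-sampling analysis of Fraboni et al. [25]: random sampling makes m draws with replacement, each picking client k with probability 1/N and carrying weight 1/m; stratified sampling splits the N clients into m equal clusters and draws one client per cluster, also with weight 1/m. Under these assumptions \(\mathrm{Var} = \frac{1}{m}\cdot\frac{1}{N}\left(1-\frac{1}{N}\right)\), \(\mathrm{Var}_c = \frac{1}{mN} - \frac{1}{N^2}\), and their difference is \(\frac{m-1}{mN^2} \ge 0\), vanishing only at m = 1. The code computes both variances exactly with fractions; the function names are illustrative.

```python
from fractions import Fraction
from math import comb


def var_random(n_clients, m):
    """Mean and variance of client k's weight when m clients are drawn with
    replacement (each draw picks k with probability 1/N) and each draw has weight 1/m."""
    p = Fraction(1, n_clients)
    mean = second = Fraction(0)
    for count in range(m + 1):  # count = times client k is drawn ~ Binomial(m, p)
        prob = comb(m, count) * p**count * (1 - p) ** (m - count)
        weight = Fraction(count, m)
        mean += prob * weight
        second += prob * weight**2
    return mean, second - mean**2


def var_stratified(n_clients, m):
    """Mean and variance of client k's weight when N clients form m equal clusters
    and one client per cluster is drawn uniformly, with weight 1/m."""
    assert n_clients % m == 0, "equal-size clusters assumed"
    p_selected = Fraction(m, n_clients)  # 1 / cluster size
    weight = Fraction(1, m)
    mean = p_selected * weight
    return mean, p_selected * weight**2 - mean**2


print(var_random(100, 10)[1], var_stratified(100, 10)[1])  # 99/100000 9/10000
```

Both schemes give client k the same expected weight 1/N; only the spread differs.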
Regularization
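In general, under heterogeneous conditions, a client's local optimum in federated learning does not coincide with the global optimum. The global optimum \({\mathbf{w}}^{*}\) satisfies
\[\nabla F({\mathbf{w}}^{*} ) = \sum\nolimits_{k = 1}^{N} {\rho_{k} \nabla F_{k} ({\mathbf{w}}^{*} )} = 0.\]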
The local optimum \({\mathbf{w}}_{k}^{*}\) of each client satisfies \(\nabla F_{k} ({\mathbf{w}}_{k}^{*} ) = 0\). Because of client heterogeneity, the local and global data distributions differ, so the local and global optima generally differ as well, which means that \(\nabla F_{k} ({\mathbf{w}}^{*} ) \ne 0\). Consequently, updating the model by optimizing only the local empirical loss cannot make it converge to the global optimal solution.
As shown in Algorithm 1, we add a regularization term to each client's local loss function so that the optimum of the client's local model moves toward that of the global model. Specifically, when computing the client's local empirical loss, we add the penalty \(\tfrac{\alpha }{2}\left\| {{\mathbf{w}} - {\mathbf{w}}_{t} } \right\|^{2}\), which equals 0 when \({\mathbf{w}} = {\mathbf{w}}_{t}\) and grows as the local model moves away from the current global model \({\mathbf{w}}_{t}\). The client therefore minimizes \(F_{k} ({\mathbf{w}}) + \tfrac{\alpha }{2}\left\| {{\mathbf{w}} - {\mathbf{w}}_{t} } \right\|^{2}\), which pulls each client's optimum toward the global model and thus closer to the global optimum.
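For a scalar quadratic local loss \(F_k(w) = (w - e_k)^2/2\) (a hypothetical example), the regularized local objective \(F_k(w) + \frac{\alpha}{2}(w - w_t)^2\) has the closed-form minimizer \((e_k + \alpha w_t)/(1 + \alpha)\), so its distance from the global model shrinks by a factor of \(1+\alpha\). The sketch runs gradient descent on the regularized objective, whose gradient is \(\nabla F_k(w) + \alpha (w - w_t)\), and compares the result with that closed form; the names are illustrative.

```python
def regularized_local_update(w_global, local_mean, alpha, lr=0.1, steps=1000):
    """Gradient descent on (w - e_k)^2 / 2 + alpha / 2 * (w - w_t)^2, starting from w_t.

    The gradient is (w - e_k) + alpha * (w - w_t); lr must satisfy lr * (1 + alpha) < 2.
    """
    w = w_global
    for _ in range(steps):
        w -= lr * ((w - local_mean) + alpha * (w - w_global))
    return w


def closed_form(w_global, local_mean, alpha):
    """Exact minimizer of the regularized objective: (e_k + alpha * w_t) / (1 + alpha)."""
    return (local_mean + alpha * w_global) / (1 + alpha)


for alpha in (0.0, 1.0, 4.0):
    # the local optimum moves from e_k = 10 toward w_t = 0 as alpha grows
    print(alpha, round(regularized_local_update(0.0, 10.0, alpha), 6), closed_form(0.0, 10.0, alpha))
```

With \(\alpha = 0\) the client recovers its purely local optimum \(e_k\); larger \(\alpha\) trades local fit for agreement with the global model.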
Convergence result
In this section, we first discuss the convergence of FedSSAR when all devices participate in the aggregation step. Assume that FedSSAR terminates after T rounds and returns \({\mathbf{w}}_{T}\) as the output. Our analysis establishes the convergence of FedSSAR in the non-convex setting.
Theorem 1
(Non-convex FedSSAR convergence: All devices participate).
Let Assumptions 1 to 3 hold, and let \(L,\alpha ,\sigma ,\varepsilon\) be as defined above. Assume that the functions \(F_{k}\) are non-convex and that there exists \(\overline{L} > 0\) such that \(\nabla^{2} F_{k} \succeq - \overline{L} I\), with \(\overline{\alpha } = \alpha - \overline{L} > 0\). Then we have
after T iterations.
With full client participation, FL is seriously affected by the "straggler effect" (all nodes must wait for the slowest one), so FL with partial client participation is more useful in practice. Suppose that the overall dataset consists of m different data distributions, and let \(S_{t} \subseteq \{ 1,2, \ldots ,N\}\) with \(|S_{t} | = m\) be the set of client indices selected in round t, one drawn at random from each cluster. We then obtain the convergence of FedSSAR with partial client participation.
Theorem 2
(Non-convex FedSSAR convergence: Partial devices participate).
Let Assumptions 1 to 3 hold, and let \(L,\alpha ,\sigma ,\varepsilon\) be as defined above. Assume that the functions \(F_{k}\) are non-convex and that there exists \(\overline{L} > 0\) such that \(\nabla^{2} F_{k} \succeq - \overline{L} I\), with \(\overline{\alpha } = \alpha - \overline{L} > 0\). Then FedSSAR with partial device participation satisfies:
where
Experiments
In this section, we evaluate the performance of FedSSAR across various datasets, models, and availability settings, and compare it with the FedAvg, FedProx and SCAFFOLD algorithms.
Experimental details
Datasets
We experimented with several public benchmark datasets used in previous related work (MNIST, Cifar-10, Cifar-100 and EMNIST-L). Taking MNIST as an example, we distribute the 60,000 training samples across 100 clients, 600 samples (1%) per client. The IID split divides the data independently and at random, so the data distribution on each client is essentially the same, as shown in Fig. 1a. To simulate different degrees of data heterogeneity, the non-IID and non-IID2 splits use biased sampling: the data are sorted by label and divided into slices so that each slice contains data of only one label, and the slices are then randomly assigned to clients. The slice size determines how many classes each client holds: each client in the non-IID split contains data of approximately two classes (Fig. 1b), and each client in the non-IID2 split contains data of approximately one class (Fig. 1c), which simulates extreme data heterogeneity. Figure 1 shows the data distribution of the first 20 clients in the MNIST dataset.
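The label-sorted slicing used for the non-IID and non-IID2 splits can be sketched as follows. This is an illustrative reconstruction, not the authors' script: the function name, the random shuffling of slices, and the perfectly balanced synthetic labels are assumptions. With 100 clients, assigning two 300-sample slices per client gives at most two classes per client (non-IID), and one 600-sample slice gives one class (non-IID2). Class sizes in MNIST are not exactly equal, so with real labels some slices straddle two classes, which is consistent with the "approximately" in the text.

```python
import random


def shard_partition(labels, num_clients, shards_per_client, seed=0):
    """Sort sample indices by label, cut them into equal slices, and give each
    client `shards_per_client` randomly chosen slices."""
    rng = random.Random(seed)
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_clients * shards_per_client
    shard_size = len(labels) // num_shards
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    rng.shuffle(shards)
    return [
        [i for shard in shards[c * shards_per_client:(c + 1) * shards_per_client] for i in shard]
        for c in range(num_clients)
    ]


labels = [i % 10 for i in range(60000)]  # synthetic, perfectly balanced 10-class labels
non_iid = shard_partition(labels, num_clients=100, shards_per_client=2)   # 2 x 300 samples
non_iid2 = shard_partition(labels, num_clients=100, shards_per_client=1)  # 1 x 600 samples
print(max(len({labels[i] for i in c}) for c in non_iid))   # at most 2 classes per client
print(max(len({labels[i] for i in c}) for c in non_iid2))  # 1 class per client
```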
We use a similar division method for the Cifar-10 and EMNIST-L datasets (details are given in "Appendix B.1"). In addition, we conducted large-scale experiments with 1000 clients on the MNIST, EMNIST-L, Cifar-10 and Cifar-100 datasets. Figure 2 shows the data distribution of the first 20 clients in the Cifar-100 dataset.
Implementation
We select the FedAvg, FedProx and SCAFFOLD algorithms as baselines. The value of the parameter \(\alpha\) in FedSSAR is selected from the set {1, 0.5, 0.1, 0.05, 0.01, 0.005, 0.001}; "Appendix C.1" reports the sensitivity to \(\alpha\). By default, 10% of the clients are selected per round; the initial learning rate is 0.1 and decays as \(\eta_{t} = 0.1/(1 + t)\); the weight decay is set to \(10^{ - 3}\); for FedProx we set \(\mu = 10^{ - 3}\); and for SCAFFOLD we set K to 12 and 1 in the moderate and massive device-number settings, respectively. To ensure that the clients selected in each round are an unbiased sample of the client population, we randomly draw the target number of clients from the population in each round (for FedSSAR, from each cluster).
Models
For MNIST and EMNIST-L we use a fully connected neural network with 2 hidden layers of 200 and 100 neurons, respectively. For the Cifar-10 and Cifar-100 datasets, we use a CNN model with 2 convolutional layers of 64 5 × 5 filters, followed by 2 fully connected hidden layers of 394 and 192 neurons and a softmax layer.
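If the full local parameter vectors are what OPTICS clusters (an assumption; Algorithm 1 only states that local model parameters are clustered), their dimension follows from the architecture. A quick count for the fully connected networks, assuming 28 × 28 inputs flattened to 784 features, 10 (MNIST) or 26 (EMNIST-L) output classes, and one bias per neuron; the helper name is illustrative:

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


print(mlp_param_count([784, 200, 100, 10]))  # 178110 (MNIST)
print(mlp_param_count([784, 200, 100, 26]))  # 179726 (EMNIST-L, 26 letters)
```

Each client therefore contributes a point with roughly 1.8 × 10^5 coordinates to the clustering step.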
Clustering results of model parameters
Following Algorithm 1 and taking the MNIST dataset (100 clients) as an example, Fig. 3 shows the clustering results of the model parameters under different data heterogeneity conditions.
As shown in Fig. 3a, under the IID split all clients belong to the same cluster: the clients' local data distributions are highly similar, so each client's local distribution can be regarded as the same as the overall distribution of the dataset. Under the heterogeneous settings (Fig. 3b, c), clients are divided into different clusters. As shown in Fig. 3b, under the first heterogeneous setting (non-IID), the clients are divided into 29 clusters, and the model parameters within each cluster are highly similar. Figure 3c shows the results for the second heterogeneous setting (non-IID2): the clients are divided into 10 clusters, and compared with the first setting, the model parameters differ much more across clusters. See "Appendix B.2" for the clustering results on the other datasets.
Experimental results
We first compare FedSSAR with the baselines above (FedAvg, FedProx and SCAFFOLD) on different datasets under different conditions. The validation performance of each task is shown in Figs. 4 and 5. Tables 1 and 2 summarize the validation performance on MNIST and EMNIST-L after 50 rounds of training and on Cifar-10 and Cifar-100 after 500 rounds of training. Due to space constraints, plots of the training loss are given in "Appendix C.2".
Non-IID split. The non-IID split simulates a moderate degree of data heterogeneity. As Table 1 shows, FedSSAR achieves the best results within the specified number of training rounds in all task settings. In terms of dataset structure, MNIST and EMNIST consist of handwritten characters, whereas Cifar-10 and Cifar-100 consist of color images of real-world objects; the latter therefore show larger distributional differences across classes and higher heterogeneity overall, and experiments on the Cifar datasets require more training rounds to converge than those on MNIST and EMNIST. Moreover, FedSSAR's advantage over competing methods is larger on more heterogeneous datasets and with more devices, which better reflects large-scale scenarios, such as Cifar-10 and Cifar-100 with a large number of devices.
Besides training faster and reaching higher accuracy, FedSSAR also trains more stably, as shown in Fig. 4. As discussed in section "Improvements provided by stratified sampling and regularization", stratified sampling keeps client sampling unbiased and reduces the variance of the clients' aggregation weights, while the regularization term in the client objective constrains the direction of the local model updates, which makes the model updates more stable. On highly heterogeneous datasets such as Cifar, the training curves of the baseline algorithms fluctuate sharply, whereas the training curve of FedSSAR rises steadily.
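A minimal sketch of a regularized local update follows. It assumes the regularization term is the proximal penalty \((\alpha/2)\|\mathbf{w}-\mathbf{w}_t\|^2\), which pulls the local model toward the current global model \(\mathbf{w}_t\). This form is consistent with the analysis in Appendix A but is an assumption here, as is the use of plain full-batch gradient descent; `local_update`, `lr` and `steps` are hypothetical names.

```python
import numpy as np

def local_update(w_global, grad_fk, alpha, lr=0.1, steps=50):
    """Gradient descent on h_k(w) = F_k(w) + (alpha / 2) * ||w - w_global||^2.

    w_global: current global model (1-D array).
    grad_fk:  callable returning the gradient of the client's local loss F_k at w.
    alpha:    regularization strength (assumed proximal form).
    """
    w = np.array(w_global, dtype=float)  # start from a float copy of the global model
    for _ in range(steps):
        # Gradient of the regularized objective: local gradient plus a pull toward w_global.
        g = grad_fk(w) + alpha * (w - w_global)
        w -= lr * g
    return w

# Toy client loss 0.5 * ||w - c_k||^2 with c_k = (2, 2).
w_t = np.zeros(2)
grad = lambda w: w - np.array([2.0, 2.0])
print(local_update(w_t, grad, alpha=1.0, lr=0.1, steps=200))  # approximately [1, 1]
```

For this quadratic loss the regularized minimizer is \((c_k + \alpha \mathbf{w}_t)/(1+\alpha)\), so a larger \(\alpha\) keeps the client's update closer to the global model.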
Non-IID2 split. The non-IID2 split simulates extreme heterogeneity of client data: as shown in Fig. 1 and "Appendix B.1", different types of clients have completely different local data distributions. According to the clustering results ("Appendix B.2"), clients assigned to different clusters have very different trained model parameters. When these parameters are aggregated, the variance of the aggregated model parameters increases, which greatly reduces training efficiency. As Table 2 shows, the model performance on all training tasks drops to varying degrees compared with the non-IID split. Even under these extreme heterogeneity conditions, however, FedSSAR performs better than the baselines, and its improvement over traditional federated learning algorithms grows as the degree of data heterogeneity increases.
Conclusions
In this work, we propose FedSSAR, an FL algorithm optimized by stratified sampling and regularization. FedSSAR uses a density-based clustering method to divide heterogeneous clients into clusters, applies stratified sampling to select clients from each cluster to participate in training, and uses a regularization term to constrain the local model updates of heterogeneous clients. Under non-convex conditions, we prove that FedSSAR converges at a rate of \(O(1/\sqrt T )\). We evaluated FedSSAR on different training tasks and parameter settings, and our experiments show that it substantially outperforms FedAvg, FedProx and SCAFFOLD under heterogeneous conditions.
Notes
OPTICS is a density-based clustering algorithm. It defines a cluster as a maximal set of density-connected points and groups regions of sufficient density into clusters. Compared with K-means and BIRCH, OPTICS can detect clusters of arbitrary shape in noisy spatial data. Compared with DBSCAN, OPTICS is less sensitive to its input parameters, which improves cluster stability. OPTICS therefore has several advantages over other clustering methods: (1) it does not need to know the number of clusters in advance; (2) it can find clusters of arbitrary shape; (3) it can detect noise points, which helps eliminate the influence of malicious nodes; (4) it is relatively insensitive to its input parameters.
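A minimal sketch of clustering clients with OPTICS follows, using scikit-learn's `OPTICS` class. It assumes each client is represented by a flattened vector of its model parameters, as in the clustering results of Fig. 3. To keep this toy example deterministic, it uses the DBSCAN-style extraction (`cluster_method="dbscan"` with a fixed `eps`) rather than the default xi-based extraction. `cluster_clients`, the synthetic data and the parameter values are hypothetical.

```python
import numpy as np
from sklearn.cluster import OPTICS

def cluster_clients(client_params, min_samples=5, eps=1.0):
    """Group clients by the similarity of their flattened model parameters.

    client_params: array of shape (n_clients, n_params).
    Returns one cluster label per client; -1 marks noise (outlier clients).
    """
    model = OPTICS(min_samples=min_samples, cluster_method="dbscan", eps=eps)
    return model.fit_predict(np.asarray(client_params))

# Synthetic "model parameters": 3 hypothetical data-distribution modes, 30 clients each.
centers = 10.0 * np.eye(3, 5)
rng = np.random.default_rng(0)
params = np.vstack([c + 0.05 * rng.standard_normal((30, 5)) for c in centers])
labels = cluster_clients(params)
print(sorted(set(labels.tolist())))
```

The number of clusters is not passed in. It emerges from the density structure, which matches advantage (1) above.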
References
Li T, Sahu AK, Talwalkar A, Smith V (2020) Federated learning: challenges, methods, and future directions. IEEE Signal Process Mag 37(3):50–60
McMahan HB, Moore E, Ramage D, Hampson S, Agüera y Arcas B (2016) Communication-efficient learning of deep networks from decentralized data. arXiv:1602.05629
McMahan HB, Ramage D, Talwar K, Zhang L (2017) Learning differentially private recurrent language models. arXiv:1710.06963
Yang Q, Liu Y, Chen T, Tong Y (2019) Federated machine learning: concept and applications. ACM Trans Intell Syst Technol 10(2):1–19
Hsieh K, Phanishayee A, Mutlu O, Gibbons PB (2019) The non-IID data quagmire of decentralized machine learning. arXiv:1910.00189
Li T, Sahu AK, Zaheer M, Sanjabi M, Talwalkar A, Smith V (2018) Federated optimization in heterogeneous networks. arXiv:1812.06127
Hard A, Rao K, Mathews R, Beaufays F, Ramage D (2018) Federated learning for mobile keyboard prediction. arXiv:1811.03604
Yang T, Andrew G, Eichner H, Sun H, Li W, Kong N, Ramage D, Beaufays F (2018) Applied federated learning: improving Google keyboard query suggestions. arXiv:1812.02903
Karimireddy SP, Kale S, Mohri M et al (2020) SCAFFOLD: stochastic controlled averaging for federated learning. In: International Conference on Machine Learning. PMLR, pp 5132–5143
Boyd S, Parikh N, Chu E, Peleato B, Eckstein J (2010) Distributed optimization and statistical learning via the alternating direction method of multipliers. Found Trends Machine Learn 3(1):1–122
Dekel O, Gilad-Bachrach R, Shamir O, Xiao L (2012) Optimal distributed online prediction using mini-batches. J Machine Learn Res 13(1):165–202
Richtárik P, Takáč M (2016) Distributed coordinate descent method for learning with big data. J Machine Learn Res 17(1):2657–2681
Zhang S, Choromanska A, LeCun Y (2014) Deep learning with elastic averaging SGD. MIT Press, Cambridge
Bonawitz K, Ivanov V, Kreuter B et al (2017) Practical secure aggregation for privacy-preserving machine learning. In: Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security, pp 1175–1191
Bonawitz K et al (2019) Towards federated learning at scale: system design. arXiv:1902.01046
Mohri M, Sivek G, Suresh AT (2019) Agnostic federated learning. arXiv:1902.00146
Hu H, Wang D, Wu C (2020) Distributed machine learning through heterogeneous edge systems. In: Proceedings of the AAAI Conference on Artificial Intelligence, Vol 34, no. 05, pp 7179–7186
Kairouz P, McMahan HB et al (2021) Advances and open problems in federated learning. arXiv:1912.04977
Zhao Y, Li M, Lai L, Suda N, Civin D, Chandra V (2018) Federated learning with non-IID data. arXiv:1806.00582
Yan Y, Niu C, Ding Y, Zheng Z, Wu F, Chen G, Tang S, Wu Z (2020) Distributed non-convex optimization with sublinear speedup under intermittent client availability. arXiv:2002.07399
Acar DAE, Zhao Y, Navarro RM et al (2021) Federated learning based on dynamic regularization. arXiv:2111.04263
Huang Y, Chu L, Zhou Z et al (2020) Personalized cross-silo federated learning on non-IID data. arXiv:2007.03797
Ghosh A, Chung J, Dong Y, Ramchandran K (2020) An efficient framework for clustered federated learning. arXiv:2006.04088
Sattler F, Müller KR, Samek W (2021) Clustered federated learning: model-agnostic distributed multitask optimization under privacy constraints. IEEE Trans Neural Netw Learn Syst 32(8):3710–3722
Fraboni Y, Vidal R, Kameni L et al (2021) Clustered sampling: low-variance and improved representativity for clients selection in federated learning. In: International Conference on Machine Learning. PMLR, pp 3407–3416
Ankerst M, Breunig MM, Kriegel HP, Sander J (1999) OPTICS: ordering points to identify the clustering structure. SIGMOD Rec 28(2):49–60
Funding
This work was supported by the General Program of National Natural Science Foundation of China (61871388).
Ethics declarations
Conflict of interest
The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.
Additional information
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Appendices
A: Complete proofs
A.1 Proof of Theorem 1
Let \({\mathbf{w}}_{t}^{k}\) represent the model parameters of the k-th client in training round t, the local update of clients can be expressed as:
Then, we have
We define \(\overline{{\mathbf{w}}}_{t + 1} = {\rm E}_{k} [{\mathbf{w}}_{t + 1}^{k} ] = \sum\nolimits_{k = 1}^{N} {\rho_{k} {\mathbf{w}}_{t + 1}^{k} }\), such that:
By Assumption 1, each \(F_{k}\) is L-smooth and non-convex, so there exists \(\overline{L} > 0\) such that \(\nabla^{2} F_{k} \succeq - \overline{L}\,\mathbf{I}\). With \(\overline{\alpha } = \alpha - \overline{L} > 0\), \(h_{k}\) is \(\overline{\alpha }\)-strongly convex. We know that:
From Assumptions 2 and 3, we have
We define \(M_{t + 1} = {\rm E}_{k} [\nabla F_{k} ({\mathbf{w}}_{t + 1}^{k} ) - \nabla F_{k} ({\mathbf{w}}_{t} )]\), so that \(\overline{{\mathbf{w}}}_{t + 1} - {\mathbf{w}}_{t} = - \tfrac{1}{\alpha }(\nabla F({\mathbf{w}}_{t} ) + M_{t + 1} )\). We can bound \(\left\| {M_{t + 1} } \right\|\):
Based on Assumption 1, we have
Let \(\rho = \tfrac{1}{\alpha } - \tfrac{L\sqrt {1 + \sigma^{2} /\varepsilon } }{\alpha \overline{\alpha } } - \tfrac{L\sqrt {1 + \sigma^{2} /\varepsilon } }{2\overline{\alpha }^{2} } > 0\); then we have
Then
We have
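The local update equation referenced at the start of this proof is not reproduced in this version of the text. As a hedged reconstruction, assume that each client exactly minimizes the proximal objective \(h_k(\mathbf{w};\mathbf{w}_t) = F_k(\mathbf{w}) + \tfrac{\alpha}{2}\|\mathbf{w}-\mathbf{w}_t\|^2\), that the weights \(\rho_k\) sum to one, and that \(F=\sum_k \rho_k F_k\). The proximal form is inferred from the relation for \(\overline{\mathbf{w}}_{t+1}-\mathbf{w}_t\) and the \(\overline{\alpha}\)-strong convexity of \(h_k\) used above; it is not stated explicitly here. Under these assumptions, both facts follow from first-order optimality:

```latex
\begin{aligned}
&\mathbf{w}_{t+1}^{k} = \arg\min_{\mathbf{w}} h_k(\mathbf{w};\mathbf{w}_t)
\;\Rightarrow\;
\nabla F_k(\mathbf{w}_{t+1}^{k}) + \alpha(\mathbf{w}_{t+1}^{k}-\mathbf{w}_t) = 0
\;\Rightarrow\;
\mathbf{w}_{t+1}^{k} - \mathbf{w}_t = -\tfrac{1}{\alpha}\nabla F_k(\mathbf{w}_{t+1}^{k}),\\
&\overline{\mathbf{w}}_{t+1} - \mathbf{w}_t
= -\tfrac{1}{\alpha}\sum_{k=1}^{N}\rho_k \nabla F_k(\mathbf{w}_{t+1}^{k})
= -\tfrac{1}{\alpha}\Bigl(\underbrace{\textstyle\sum_{k=1}^{N}\rho_k\nabla F_k(\mathbf{w}_t)}_{\nabla F(\mathbf{w}_t)} + M_{t+1}\Bigr),\\
&\nabla^{2} h_k = \nabla^{2}F_k + \alpha\mathbf{I} \succeq (\alpha-\overline{L})\,\mathbf{I} = \overline{\alpha}\,\mathbf{I}.
\end{aligned}
```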
A.2 Proof of Theorem 2
In FedSSAR, not all devices participate in each training round. Based on the clustering results, we assume that there are m data distribution modes among all devices, and we select m clients to participate in each round. We denote the set of clients selected in round t by \(S_{t}\) and define the aggregated model \({\mathbf{w}}_{t + 1} = \sum\nolimits_{{k \in S_{t} }} {\rho_{k} {\mathbf{w}}_{t + 1}^{k} }\). In general, \({\rm E}_{{S_{t} }} [F({\mathbf{w}}_{t + 1} )] \ne F({\overline{\mathbf{w}}}_{t + 1} )\), so we need to bound \({\rm E}_{{S_{t} }} [F({\mathbf{w}}_{t + 1} )]\). According to Assumption 1, we have
Taking the expectation over the subset \(S_{t}\), we obtain
We have that
Combining Eqs. (27), (28) and (29), we get
From Eq. (21), we have
Let \(\rho = \tfrac{1}{\alpha } - \tfrac{L\sqrt {1 + \sigma^{2} /\varepsilon } }{\alpha \overline{\alpha } } - \tfrac{L\sqrt {1 + \sigma^{2} /\varepsilon } }{2\overline{\alpha }^{2} } - \tfrac{\sqrt 2 }{\sqrt m \,\overline{\alpha } }\sqrt {1 + \sigma^{2} /\varepsilon } - \tfrac{L}{m\,\overline{\alpha }^{2} }(1 + \sigma^{2} /\varepsilon ) > 0\). Following the same argument as in the proof of Theorem 1, we obtain
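The positivity conditions on \(\rho\) in Theorems 1 and 2 can be checked numerically for given constants. The helper below is a minimal sketch that transcribes the two expressions for \(\rho\) directly. The function names and the example constants are hypothetical.

```python
import math

def rho_theorem1(alpha, L, L_bar, sigma2, eps):
    """rho from Theorem 1; the convergence bound requires rho > 0."""
    a_bar = alpha - L_bar  # strong-convexity constant of h_k
    if a_bar <= 0:
        raise ValueError("need alpha > L_bar")
    s = math.sqrt(1 + sigma2 / eps)
    return 1 / alpha - L * s / (alpha * a_bar) - L * s / (2 * a_bar ** 2)

def rho_theorem2(alpha, L, L_bar, sigma2, eps, m):
    """rho from Theorem 2: Theorem 1's rho minus two terms due to sampling m clients."""
    a_bar = alpha - L_bar
    s = math.sqrt(1 + sigma2 / eps)
    return (rho_theorem1(alpha, L, L_bar, sigma2, eps)
            - math.sqrt(2) * s / (math.sqrt(m) * a_bar)
            - L * s ** 2 / (m * a_bar ** 2))

# Hypothetical constants: the condition holds with m = 10 sampled clients but fails with m = 1.
print(rho_theorem2(100.0, 1.0, 0.0, 0.0, 1.0, m=10) > 0, rho_theorem2(100.0, 1.0, 0.0, 0.0, 1.0, m=1) > 0)
```

The two extra terms shrink as m grows, so sampling more clients per round makes the Theorem 2 condition easier to satisfy.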
B: Datasets and clustering results
B.1 Supplementary data distribution map
In the MNIST dataset, for the massive number of devices, we distributed 60000 training samples to 1000 clients, with approximately 60 samples (0.1%) per client. In the EMNIST-L dataset, for the moderate number of devices, we distributed 48000 training samples to 100 clients, with 480 samples (1%) per client; for the massive number of devices, we distributed 48000 training samples to 1000 clients, with 48 samples (0.1%) per client. In the Cifar-10 dataset, for the moderate number of devices, we distributed 50000 training samples to 100 clients, with 500 samples (1%) per client; for the massive number of devices, we distributed 50000 training samples to 1000 clients, with 50 samples (0.1%) per client. In the Cifar-100 dataset, for the moderate number of devices, we distributed 50000 training samples to 100 clients, with 500 samples (1%) per client; for the massive number of devices, we distributed 50000 training samples to 1000 clients, with 50 samples (0.1%) per client (Figs. 6, 7 and 8).
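The per-client sample counts above correspond to splitting each training set evenly across clients. A minimal sketch of such an even (IID-style) split follows. It does not reproduce the non-IID or non-IID2 splits, and `iid_partition` is a hypothetical helper.

```python
import numpy as np

def iid_partition(n_samples, n_clients, seed=0):
    """Shuffle sample indices and split them into n_clients equal shards."""
    if n_samples % n_clients:
        raise ValueError("n_samples must be divisible by n_clients")
    idx = np.random.default_rng(seed).permutation(n_samples)
    return np.split(idx, n_clients)  # list of n_clients index arrays

# Settings listed above: (training samples, clients).
for n, k in [(60000, 1000), (48000, 100), (48000, 1000), (50000, 100), (50000, 1000)]:
    shards = iid_partition(n, k)
    print(n, k, len(shards[0]))  # 60, 480, 48, 500, 50 samples per client
```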
B.2 Supplementary clustering results
See Figs. 9 and 10.
Fig. 9 Clustering results of EMNIST-L and Cifar-10 under a moderate client scale and different degrees of data heterogeneity. For the IID split, all clients belong to the same cluster in each dataset. For the non-IID split, clients are divided into 33 clusters in both EMNIST-L and Cifar-10. For the non-IID2 split, clients are divided into 10 clusters in both EMNIST-L and Cifar-10.
Fig. 10 Clustering results of MNIST, EMNIST-L, Cifar-10 and Cifar-100 under a massive client scale and different degrees of data heterogeneity. For the IID split, all clients have the same data distribution and belong to the same cluster. For the non-IID split, clients are divided into 131 clusters in MNIST, 133 in EMNIST-L, 120 in Cifar-10 and 342 in Cifar-100. For the non-IID2 split, the overall data have 10 different distributions in MNIST, EMNIST-L and Cifar-10 (100 in Cifar-100), and all clients are accurately partitioned into 10 (100 for Cifar-100) different clusters.
C: Additional experimental results
C.1 \(\alpha\) sensitivity analysis of FedSSAR
In FedSSAR, \(\alpha\) is an important parameter. It weights the regularization term, which adjusts the update direction of the local model so that heterogeneous clients update toward the global optimum. As Theorems 1 and 2 show, the value of \(\alpha\) affects the convergence speed of the model, so it is necessary to examine how sensitive FedSSAR is to \(\alpha\).
To examine this sensitivity, we consider MNIST and EMNIST-L under the non-IID2 split with 100 clients and 10% participation. Figure 11 shows the best test accuracy achieved under different values of \(\alpha\), with all other parameters held fixed. The best test accuracy is obtained when \(\alpha = 10^{ - 3}\).
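The sensitivity study amounts to a one-dimensional sweep over \(\alpha\) with all other settings fixed. Below is a minimal sketch. It assumes a hypothetical `evaluate(alpha)` callable that trains FedSSAR with the given \(\alpha\) and returns the best test accuracy reached. The grid of powers of ten is also an assumption.

```python
def sweep_alpha(evaluate, exponents=range(-5, 2)):
    """Evaluate alpha = 10**e for each exponent e and return the best alpha.

    evaluate: hypothetical callable mapping alpha -> best test accuracy
              (all other training settings held fixed).
    Returns (best_alpha, {alpha: accuracy}).
    """
    results = {10.0 ** e: evaluate(10.0 ** e) for e in exponents}
    best_alpha = max(results, key=results.get)  # alpha with the highest accuracy
    return best_alpha, results

# Toy stand-in for training: accuracy peaks at alpha = 1e-3.
best, results = sweep_alpha(lambda a: -abs(a - 1e-3))
print(best, len(results))
```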
C.2 Experimental training loss
See Figs. 12 and 13.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Lu, C., Ma, W., Wang, R. et al. Federated learning based on stratified sampling and regularization. Complex Intell. Syst. 9, 2081–2099 (2023). https://doi.org/10.1007/s40747-022-00895-3
DOI: https://doi.org/10.1007/s40747-022-00895-3