Introduction

Federated learning (FL) is a new distributed machine learning paradigm [1, 2] that allows multiple devices (called clients) to collaboratively train a global model without uploading their local data [3, 4]. Compared with traditional distributed machine learning, the main differences are as follows: (1) clients have independent control over their local devices and data; (2) clients are often unreliable (edge nodes are frequently disconnected because of equipment and communication problems); (3) communication cost is higher than computation cost; (4) the data in FL are not independent and identically distributed (non-IID); (5) the distribution of local data across clients is uneven [5]. These features challenge the design and analysis of FL algorithms.

One of the major challenges is client heterogeneity, which includes both data heterogeneity and device heterogeneity. Client heterogeneity is widespread in real-world FL [6]. Examples include: (1) heterogeneous data distributions: the data on each client are generated locally, so the sample generation mechanism may differ among clients (for example, across countries or regions); (2) covariate shift: for example, in handwriting recognition, different people write the same word differently; (3) label distribution skew (prior probability shift): for example, Chinese is used mainly by people in China and much less by people elsewhere; (4) quantity skew or imbalance. In practice, many situations can produce non-IID data. Traditional machine learning assumes IID data, but FL differs from centralized machine learning in that, without centralized data, the data on each node are non-IID.

Consider a real scenario in which FL is used to train a model for a mobile phone input method [7, 8]. Different mobile phones differ in processing speed, internal data, network conditions, and so on. The latest mobile phones compute and transmit faster than old ones. Mobile phones in locations with better signal, such as towns, transmit more stably than those in rural areas or in areas with signal interference. During model training, old mobile phones train slowly and often do not complete the training task on time. Mobile phones with poor network conditions are also more prone to signal loss when transmitting the model, so the data distribution received by the parameter server differs from the actual distribution. Because of client heterogeneity, some classes of data participate more frequently in the training process, which introduces errors into the training data.

To reduce the impact of client heterogeneity, we propose the FedSSAR algorithm, which consists of two parts: client selection and regularization. In contrast to the traditional method of randomly selecting the clients that participate in each round of training, we divide the clients into multiple clusters according to the similarity of their local data, without requiring access to that data, and then draw a certain number of available clients from the different clusters for aggregation. At the same time, we add a regularizer based on the global model to each client's objective, which reduces local updates that move far from the global model. By setting the appropriate parameters, we prove an \(O(1/\sqrt T )\) convergence rate for the algorithm. We performed experiments on benchmark datasets, including MNIST, EMNIST, Cifar-10, and Cifar-100, to explore the performance of the algorithms under different levels of data heterogeneity. We compared FedSSAR with FedAvg [2], FedProx [6], and SCAFFOLD [9]. The experiments show that our proposed algorithm converges faster and reaches a higher training accuracy. The main contributions of this paper are as follows:

  • We demonstrate that, for a convex objective function, the traditional FL algorithm cannot converge to the global optimum under client heterogeneity even if exact (rather than stochastic) gradient descent is used. In particular, when data are highly heterogeneous, the traditional FL algorithm may diverge;

  • We propose the FedSSAR algorithm. The main idea of FedSSAR is to use stratified sampling for client selection and a regularization term to limit the update of the model. Clients sampled by clustering better represent the overall data distribution and reduce the variance of the aggregated model. At the same time, the regularization changes the optimization objective of each device, bringing it close to the optimization objective of the global loss function;

  • We theoretically identify the reason for the divergence of the traditional FL algorithm and prove convergence results for FedSSAR. By our theoretical analysis, FedSSAR has an \(O(1/\sqrt T )\) convergence rate;

  • We used the public MNIST, EMNIST-L, Cifar-10 and Cifar-100 datasets to evaluate FedSSAR and compared it with FedAvg, FedProx and SCAFFOLD. The evaluation results validate the superiority of FedSSAR in terms of a smoother training process, faster convergence, and lower training loss.

The remainder of this paper is organized as follows. In section "Related works", we introduce the background research of FL and an overview of related studies on heterogeneous federated learning; in section "Algorithm design", we theoretically analyse the effect of client heterogeneity on model convergence and propose the framework for FedSSAR; in section "Convergence analysis", we discuss the improvements provided by stratified sampling and regularization and give a theoretical proof of the convergence of FedSSAR; in section "Experiments", we comprehensively evaluate FedSSAR using different neural network models on multiple datasets.

Related works

With the improvement of the storage capacity and computing capability of devices, distributed optimization is increasingly used in the big data regime [10,11,12,13]. Traditional distributed machine learning concentrates training data in one machine or one data center and then updates the model. However, with the increasing local computing power of mobile phones, smart wearable devices, sensors, and other devices, as well as recent restrictions on user data for privacy protection [14], it is more efficient for distributed clients to train models locally and then upload model parameters to parameter servers for aggregation than to transfer client data directly. This line of research is known as FL, and it needs to address challenges such as large-scale training data, privacy protection, and heterogeneous data and devices [1, 15,16,17].

McMahan et al. [2] proposed a distributed learning method based on iterative averaging of the trained models and developed the Federated Averaging (FedAvg) algorithm. The learning task is solved jointly by the participating devices and coordinated by a central server, so the organization resembles a loose federation. Compared with data-centric distributed machine learning, one of the main advantages of FL is that it decouples model training from the need for direct access to raw data, which is very important when data privacy is strictly required or when it is difficult to share data centrally. Meanwhile, FedAvg improves learning efficiency by using multiple local iterations, which is very helpful in reducing communication costs.

Kairouz et al. [18] discussed open issues and challenges in FL: (1) data heterogeneity (non-IID data) in federated learning; (2) data privacy of the client; (3) communication restrictions; (4) robustness of the algorithm; (5) fairness and efficiency. One of the most fundamental challenges is the heterogeneity of client data.

To address the problem of differing data distributions in FL, Zhao et al. [19] improved the FedAvg algorithm. They showed that when client data are heterogeneously distributed, FedAvg suffers a large loss of accuracy, which can be explained by weight divergence. They proposed a strategy that creates a small subset of data shared globally among all devices. Although this method can reduce the impact of data skew, it is equivalent to adding errors artificially. In addition, this data-sharing method essentially violates the data privacy principle of federated learning. Yan et al. [20] considered the intermittent availability of clients. They argue that under heterogeneous client conditions, different clients participate in training a different number of times, which skews the trained model towards the data of the clients that train more often. Therefore, when selecting clients, those that have trained fewer times are preferred, to ensure that each client participates in training as many times as possible. However, the assumptions in that article are too strong, and enforcing an even number of training sessions across clients is susceptible to slow nodes, resulting in a much longer training time.

Li et al. [6] proposed FedProx, which starts from the objective function and reduces the impact of data heterogeneity by adding a restriction to the objective function of the local model, so that each client does not deviate too much from the global model when it updates with local data. This method can alleviate the impact of heterogeneity at the price of an acceptable increase in computing overhead on the client. Following a similar idea, Acar et al. [21] proposed FedDyn, which adds regularization terms based on the global model and on the models from previous communication rounds when updating the local model. It uses a penalty term to dynamically modify the device objectives so that, when the model parameters converge, they converge to a stationary point of the global empirical loss. SCAFFOLD [9] introduces a variance reduction technique to deal with data heterogeneity. In its implementation, it introduces control variables for the server and for each local client to estimate the update directions of the server model and of each client. The difference between the two update directions is then used to approximate the offset of local training. Adding an optimizer to the update functions of the client and global models can significantly improve training performance, but this method is very sensitive to parameters, and tuning them is difficult when faced with different datasets.

Huang et al. [22] start from another direction. They argue that data heterogeneity prevents a single global model from reaching sufficient accuracy, so they personalize the model: local data are used to further train the global model and obtain a higher-quality personalized model. Following a similar idea, Ghosh et al. [23] and Sattler et al. [24] proposed dividing clients into different clusters and then training a single global model in each cluster, using different clustering methods to group clients by their local empirical loss functions or by their gradients. Because the clients in each cluster are highly similar, the global model trained in this way has a higher accuracy. Unlike the traditional federated learning algorithm, these methods train multiple global models. Each global model has high accuracy in one cluster but low accuracy in the others, which results in poor generalization. However, these methods suggest the idea of clustering clients with similar data distributions. Fraboni et al. [25] demonstrated that cluster sampling better represents the clients and proposed two client aggregation methods based on sample size and model similarity. In addition, different sampling methods are used to construct a homogeneous distribution of client data before model training is carried out, which greatly improves model performance on heterogeneous datasets. However, the computational complexity in that article is too high, and sampling client-side local datasets seems to violate the rule in federated learning that local data cannot leave the client.

Algorithm design

In FL, we consider the following standard optimization model:

$$ \mathop {\min }\limits_{{\varvec{w}}} \left\{ {F({\varvec{w}}) = \sum_{k \in S_t } {\rho _k F_k ({\varvec{w}})} } \right\} $$
(1)

where \(S_{t} \subseteq \{ 1,2, \ldots ,N\}\) is the set of clients participating in model parameter aggregation in round t, N is the total number of devices, and \(\rho_{k}\) is the weight of the k-th device, with \(\rho_{k} \ge 0\) and \(\sum\nolimits_{{k \in S_{t} }} {\rho_{k} } = 1\). Commonly, we define \(\rho_{k} = n_{k} /M\), where \(n_{k}\) is the number of samples on client k and \(M = \sum\nolimits_{{k \in S_{t} }} {n_{k} }\) is the total number of samples on the clients participating in round t. Suppose the k-th device holds the training data \(D_{k}\), and \(\xi_{t}^{k} \sim D_{k}\) is a sample drawn uniformly from that local data. \(F_{k} ({\mathbf{w}})\) is the loss function of client k: \(F_{k} ({\mathbf{w}}) = \ell_{k} ({\mathbf{w}},\xi^{k} )\). Here, we describe the standard FedAvg algorithm. First, the central server broadcasts the latest model \({\mathbf{w}}_{t}\) to all the devices. Then, every device sets \({\mathbf{w}}_{t}^{k} = {\mathbf{w}}_{t}\) and performs local updates:

$$ {\mathbf{w}}_{t + 1}^{k} = {\mathbf{w}}_{t}^{k} - \eta_{t} \nabla F_{k} ({\mathbf{w}}_{t}^{k} ) $$
(2)

where \(\eta_{t}\) is the learning rate. Assuming that K clients (\(1 \le K \le N\)) are selected to participate in the training process, the aggregation step is:

$$ {\mathbf{w}}_{t + 1} \leftarrow \sum\limits_{{k \in S_{t} }}^{{}} {\tfrac{{n_{k} }}{M}{\mathbf{w}}_{t + 1}^{k} } $$
(3)
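
The local update in Eq. (2) and the aggregation in Eq. (3) can be illustrated with a minimal sketch. It assumes that a model is a plain list of floats and that the local gradients are already given; the function names `local_update` and `aggregate`, and all numbers in the toy round, are hypothetical and chosen only for illustration.

```python
from typing import Dict, List

Vector = List[float]


def local_update(w: Vector, grad: Vector, lr: float) -> Vector:
    """One local gradient step as in Eq. (2): w - lr * grad."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


def aggregate(local_models: Dict[int, Vector], n_samples: Dict[int, int]) -> Vector:
    """Sample-weighted average of the selected clients' models, as in Eq. (3)."""
    total = sum(n_samples[k] for k in local_models)  # M in the text
    dim = len(next(iter(local_models.values())))
    w_new = [0.0] * dim
    for k, w_k in local_models.items():
        weight = n_samples[k] / total  # rho_k = n_k / M
        for i in range(dim):
            w_new[i] += weight * w_k[i]
    return w_new


# Hypothetical toy round: two selected clients and a 2-dimensional model.
w_t = [1.0, 1.0]
grads = {0: [1.0, 0.0], 1: [0.0, 2.0]}  # made-up local gradients
sizes = {0: 100, 1: 300}                # made-up sample counts n_k
local_models = {k: local_update(w_t, g, lr=0.5) for k, g in grads.items()}
w_next = aggregate(local_models, sizes)
print(w_next)
```

In this toy round the client with 300 samples receives weight 0.75, so the aggregated model lies closer to that client's local model.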

The global data are a mixture of all local data: \(D = \sum\nolimits_{k = 1}^{N} {\rho_{k} D_{k} }\). When the client data are independent and identically distributed, \(D_{k} = D\) for all \(k \in \{ 1,2, \ldots ,N\}\). However, in practice, the data distributions of different clients differ, so our theoretical analysis is based on the assumption that the data are not independent and identically distributed.

Effects of data heterogeneity

Example 1.

Consider a distributed optimization problem with N clients and convex objective functions. Our goal is to learn the mean of these clients' one-dimensional data. A sample of the local data on client k is \(\xi_{k} \sim D_{k}\), with mean \(e_{k} = E[\xi_{k} ]\). We can express this learning problem as minimizing the mean square error (MSE):

$$ \begin{gathered} f(x) = \sum\limits_{k = 1}^{N} {\rho_{k} f_{k} (x) = \sum\limits_{k = 1}^{N} {\rho_{k} E_{{\xi_{k} \sim D_{k} }} [(x - \xi_{k} )^{2} ]} } \hfill \\ { = }\sum\limits_{k = 1}^{N} {\rho_{k} (x - e_{k} )^{2} } + \sum\limits_{k = 1}^{N} {\rho_{k} E_{{\xi_{k} \sim D_{k} }} [(\xi_{k} - e_{k} )^{2} ]} \hfill \\ \end{gathered} $$

For the convenience of calculation, we assume that each client contains the same amount of data, and we can get the optimal solution as follows:

$$ x^{*} = 1/N\left( {\sum\limits_{k = 1}^{N} {e_{k} } } \right) $$

Let \(\tau_{k}\) denote the weight offset of client k caused by communication loss, differences between client devices, and so on, and let \(\hat{x}^{*}\) denote the value to which training converges in practice. Then:

$$ \hat{x}^{*} = 1/N\left( {\sum\limits_{k = 1}^{N} {e_{k} } } \right) + \sum\limits_{k = 1}^{N} {\tau_{k} e_{k} } $$

Proof of Example 1

Since the objective function is convex, gradient descent with a sufficiently small learning rate converges to the optimal solution. Take the derivative of the objective function and set it to zero:

$$ \nabla f(x) = 2\sum\limits_{k = 1}^{N} {\rho_{k} (x - e_{k} )} $$
$$ x = \sum\limits_{k = 1}^{N} {\rho_{k} e_{k} } $$

According to the assumption that each client contains the same amount of data, \(\rho_{k} = 1/N\) for every \(k \in \{ 1,2, \ldots ,N\}\), therefore:

$$ x^{*} = 1/N\left( {\sum\limits_{k = 1}^{N} {e_{k} } } \right) $$

Under real conditions, the effective weights are \(\rho_{k} = 1/N + \tau_{k}\) (\(\tau_{k}\) is the weight offset for client k), so the convergence value becomes:

$$ \hat{x}^{*} = 1/N\left( {\sum\limits_{k = 1}^{N} {e_{k} } } \right) + \sum\limits_{k = 1}^{N} {\tau_{k} e_{k} } $$

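Hence \(\hat{x}^{*} - x^{*} = \sum\nolimits_{k = 1}^{N} {\tau_{k} e_{k} }\). Because the weights \(\rho_{k}\) sum to one, \(\sum\nolimits_{k = 1}^{N} {\tau_{k} } = 0\), so this difference vanishes when \(e_{1} = e_{2} = \cdots = e_{N}\) (the data distributions are IID) or when \(\tau_{k} = 0\) for all \(k \in \{ 1,2, \ldots ,N\}\). In general it does not vanish, so traditional federated learning algorithms can perform poorly in the case of heterogeneous clients.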
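
Example 1 can be checked numerically. The sketch below assumes three clients with made-up means \(e_{k}\) and made-up offsets \(\tau_{k}\) that sum to zero; it runs exact gradient descent on the weighted MSE objective and compares the limit with the ideal optimum. The helper names are hypothetical.

```python
def weighted_mean(means, weights):
    """Minimizer of sum_k rho_k (x - e_k)^2 when the weights sum to one."""
    return sum(w * e for w, e in zip(weights, means))


def gradient_descent(means, weights, x0=0.0, lr=0.1, steps=200):
    """Exact gradient descent on f(x) = sum_k rho_k (x - e_k)^2."""
    x = x0
    for _ in range(steps):
        grad = 2.0 * sum(w * (x - e) for w, e in zip(weights, means))
        x -= lr * grad
    return x


means = [0.0, 1.0, 5.0]   # hypothetical client means e_k
tau = [0.1, 0.0, -0.1]    # hypothetical weight offsets, summing to zero
n = len(means)

ideal_weights = [1.0 / n] * n
real_weights = [1.0 / n + t for t in tau]

x_star = weighted_mean(means, ideal_weights)   # ideal optimum x*
x_hat = gradient_descent(means, real_weights)  # what training converges to
bias = sum(t * e for t, e in zip(tau, means))  # sum_k tau_k e_k

print(x_star, x_hat, bias)
```

With these numbers the ideal optimum is 2, while gradient descent with the offset weights settles near 1.5; the gap matches \(\sum\nolimits_{k} {\tau_{k} e_{k} } = -0.5\).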

FedSSAR architecture

As shown in section "Effects of data heterogeneity", data heterogeneity and device heterogeneity significantly reduce the performance of the traditional FL algorithm. In FL, the overall data distribution is a mixture of the local data distributions of the clients, and in the FedAvg algorithm the aggregation weight is the sample weight. This setting considers only the differences in data volume among clients, not their hardware and communication differences. Take the classic example of federated learning, training a mobile input method model: the latest mobile phones run and transmit faster than older ones, and mobile phones in locations with better signal, such as cities and towns, transmit more stably than mobile phones in rural areas and areas with signal interference. As a result, the distribution of the data received by the parameter server differs from the actual distribution of the data. Because of client heterogeneity, some classes of data participate more frequently in the training process, which introduces errors into the training data. To alleviate this problem, we use all types of data in each training round and ensure that each type of data participates in training with essentially the same probability, so that the training data distribution is an unbiased mixture of the sample distributions of the clients. At the same time, regularization terms are added to limit the updating of the model. In this way, we can eliminate bias in the training data and establish a convergence result.

Let us recall the training steps of federated learning. The parameter server first initializes the global model and then broadcasts the model \({\mathbf{w}}_{0}\) to all clients. Each client trains on a sample of local data \(\xi^{k}\), starting from the received model, to obtain the local model parameter \({\mathbf{w}}_{1}^{k}\):

$$ {\mathbf{w}}_{1}^{k} \leftarrow {\mathbf{w}}_{0} - \eta_{1} \nabla \ell_{k} ({\mathbf{w}}_{0} ,\xi_{1}^{k} ) $$
(4)

When computing \({\mathbf{w}}_{1}^{k}\), the parameters \({\mathbf{w}}_{0}\) and \(\eta_{1}\) are the same for all clients, so \({\mathbf{w}}_{1}^{k}\) depends only on \(\xi_{1}^{k}\); that is, \({\mathbf{w}}_{1}^{k}\) carries information about the data distribution of client k. This indicates that we can divide clients by the similarity of their local data by clustering their model parameters.

Algorithm 1 FedSSAR (figure a)

The detailed process of FedSSAR is given in Algorithm 1. The client selection principle of FedSSAR is to select available clients from different clusters (lines 2–8). After training, the parameter server collects the local model parameters of each client and divides the clients into groups using the OPTICS (Ordering Points To Identify the Clustering Structure) clustering method [26]. In each training round, available clients are drawn from each cluster proportionally to participate in training, which ensures that all types of data participate in each round and reduces the impact of client heterogeneity. After a client receives the latest global model parameters from the parameter server, as shown in line 12, a regularization term is added to the local loss function that it minimizes, so that the new local model parameters do not deviate far from the previous global model parameters; the degree of correction is controlled by the parameter \(\alpha\). The latest parameters are then sent back to the parameter server, which computes their weighted average (lines 10–15).
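
The proportional, per-cluster client selection described above can be sketched as follows. This is not Algorithm 1 itself: it assumes that cluster labels are already available (for example, from a clustering step on the local model parameters), and the function name `stratified_select` and the rounding rule (at least one client per non-empty cluster) are assumptions made for illustration.

```python
import random
from collections import defaultdict


def stratified_select(cluster_of, available, fraction, rng):
    """Draw roughly `fraction` of the available clients from every cluster.

    cluster_of: dict mapping client id -> cluster label (assumed given).
    available:  set of client ids that can take part in this round.
    The rule "at least one client per non-empty cluster" is an assumption.
    """
    by_cluster = defaultdict(list)
    for client in sorted(available):
        by_cluster[cluster_of[client]].append(client)

    selected = []
    for cluster in sorted(by_cluster):
        members = by_cluster[cluster]
        k = max(1, round(fraction * len(members)))
        selected.extend(rng.sample(members, k))
    return selected


# Hypothetical setup: 40 clients in 4 clusters, three clients offline.
cluster_of = {client: client % 4 for client in range(40)}
available = set(range(40)) - {0, 4, 8}
picked = stratified_select(cluster_of, available, fraction=0.1, rng=random.Random(0))
print(sorted(picked))
```

Because every non-empty cluster contributes at least one client, each data type is represented in every round even when some clients are unavailable.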

Convergence analysis

Notation and assumptions

Assumption 1.

\(F_{1} ,F_{2} , \ldots ,F_{N}\) are all L-smooth: for all \({\mathbf{v}}\) and \({\mathbf{w}}\), \(F_{k} ({\mathbf{v}}) \le F_{k} ({\mathbf{w}}) + ({\mathbf{v}} - {\mathbf{w}})^{T} \nabla F_{k} ({\mathbf{w}}) + \tfrac{L}{2}\left\| {{\mathbf{v}} - {\mathbf{w}}} \right\|^{2}\).

Assumption 2.

Bounded dissimilarity:

We say that the local functions \(F_{k}\) are B-locally dissimilar at \({\mathbf{w}}\) if \({\rm E}_{k} [\left\| {\nabla F_{k} ({\mathbf{w}})} \right\|^{2} ] \le \left\| {\nabla F({\mathbf{w}})} \right\|^{2} \cdot B^{2}\). For all \({\mathbf{w}}\) with \(\left\| {\nabla F({\mathbf{w}})} \right\| \ne 0\), we further define \(B({\mathbf{w}}) = \sqrt {\tfrac{{{\rm E}_{k} [\left\| {\nabla F_{k} ({\mathbf{w}})} \right\|^{2} ]}}{{\left\| {\nabla F({\mathbf{w}})} \right\|^{2} }}}\). This definition quantifies the difference between the local loss functions and the global loss function. We assume that, for some \(\varepsilon > 0\), there exists a \(B_{\varepsilon }\) such that \(B({\mathbf{w}}) \le B_{\varepsilon }\) for all \({\mathbf{w}} \in S_{\varepsilon }^{c} = \{ {\mathbf{w}}\,|\,\left\| {\nabla F({\mathbf{w}})} \right\| > \varepsilon \}\).
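
As an illustration of \(B({\mathbf{w}})\), the sketch below evaluates it for the one-dimensional losses \(F_{k} (x) = (x - e_{k} )^{2}\) of Example 1 with equal client weights. The function name `dissimilarity` and the client means are hypothetical.

```python
import math


def dissimilarity(x, means):
    """B(x) for F_k(x) = (x - e_k)^2 with equal client weights 1/N."""
    n = len(means)
    local_grads = [2.0 * (x - e) for e in means]
    mean_sq_local = sum(g * g for g in local_grads) / n  # E_k ||grad F_k(x)||^2
    global_grad = sum(local_grads) / n                   # grad F(x)
    if global_grad == 0.0:
        raise ValueError("B(x) is undefined where the global gradient is zero")
    return math.sqrt(mean_sq_local / global_grad ** 2)


print(dissimilarity(0.0, [1.0, 1.0, 1.0]))  # identical clients
print(dissimilarity(3.0, [0.0, 2.0]))       # heterogeneous clients, far from the optimum
print(dissimilarity(1.1, [0.0, 2.0]))       # heterogeneous clients, near the optimum
```

For identical clients \(B\) equals 1. For heterogeneous clients it exceeds 1 and grows as \(x\) approaches the global optimum, which is why the assumption only bounds \(B({\mathbf{w}})\) on the set where \(\left\| {\nabla F({\mathbf{w}})} \right\| > \varepsilon\).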

Assumption 3.

Bounded variance:

There exists a \(\sigma\) such that \({\rm E}_{k} [\left\| {\nabla F_{k} ({\mathbf{w}}) - \nabla F({\mathbf{w}})} \right\|^{2} ] \le \sigma^{2}\).

Improvements provided by stratified sampling and regularization

Stratified sampling

As discussed above, the model parameters contain information about the data distribution of the client. We assume that there are m different data distributions in the overall data. As before, for convenience of analysis, we assume that each client has the same amount of local data. After clustering, the clients in each cluster share the same data distribution, so we assume that their trained neural network models have the same parameters, namely:

$$ \forall i \in [1,m],\forall j,k \in c_{i} ,{\mathbf{w}}^{j} = {\mathbf{w}}^{k} $$
(5)

Following Eq. (5), we use \({\mathbf{w}}_{t}^{{c_{i} }}\) to denote the model parameters in cluster i after t rounds of training, and \(n_{{c_{i} }}\) to denote the number of clients in cluster i. We have:

$$ \forall i \in [1,m],\sum\limits_{{k \in c_{i} }} {{\mathbf{w}}_{t}^{k} } = n_{{c_{i} }} \cdot {\mathbf{w}}_{t}^{{c_{i} }} $$
(6)

We first prove that the sampling method of FedSSAR is unbiased, that is:

$$ {\rm E}_{{S_{t} }} ({\mathbf{w}}_{t} ) = {\rm E}_{{S_{t} }} \left( {\sum\limits_{{k \in S_{t} }} {w_{k} (S_{t} )} \cdot {\mathbf{w}}_{t}^{k} } \right) = \sum\limits_{k = 1}^{N} {\rho_{k} } {\mathbf{w}}_{t}^{k} $$
(7)

where \(w_{k} (S_{t} )\) represents the aggregate weight of client k in subset \(S_{t}\).

Proof.

We can sum the model parameters of each round of training according to different clusters,

$$ {\rm E}_{{S_{t} }} ({\mathbf{w}}_{t} ) = {\rm E}_{{S_{t} }} \left( {\sum\limits_{{k \in S_{t} }} {w_{k} (S_{t} )} \cdot {\mathbf{w}}_{t}^{k} } \right) = \sum\limits_{i = 1}^{m} {\tfrac{{n_{{c_{i} }} }}{N}} {\mathbf{w}}_{t}^{{c_{i} }} $$
(8)

From Eq. (6), we have

$$ \sum\limits_{i = 1}^{m} {\tfrac{{n_{{c_{i} }} }}{N}} {\mathbf{w}}_{t}^{{c_{i} }} = \tfrac{1}{N}\sum\limits_{i = 1}^{m} {n_{{c_{i} }} } {\mathbf{w}}_{t}^{{c_{i} }} = \tfrac{1}{N}\sum\limits_{k = 1}^{N} {{\mathbf{w}}_{t}^{k} } = \sum\limits_{k = 1}^{N} {\rho_{k} } {\mathbf{w}}_{t}^{k} $$
(9)
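
The unbiasedness argument can be checked on a toy instance. The sketch assumes, as in Eq. (5), that all clients in a cluster hold identical parameters, uses scalar parameters for brevity, and picks one representative per cluster weighted by \(n_{{c_{i} }} /N\). The cluster layout and parameter values are made up.

```python
import random


def full_average(params):
    """(1/N) * sum_k w^k over all clients (equal sample sizes assumed)."""
    return sum(params.values()) / len(params)


def stratified_aggregate(params, clusters, rng):
    """Pick one client per cluster and weight it by n_ci / N."""
    n_total = sum(len(members) for members in clusters)
    aggregated = 0.0
    for members in clusters:
        representative = rng.choice(members)
        aggregated += len(members) / n_total * params[representative]
    return aggregated


# Hypothetical layout: 6 clients in 3 clusters of sizes 3, 2 and 1.
clusters = [[0, 1, 2], [3, 4], [5]]
cluster_value = [1.0, 4.0, 10.0]  # shared parameter within each cluster
params = {k: cluster_value[i] for i, members in enumerate(clusters) for k in members}

rng = random.Random(0)
draws = [stratified_aggregate(params, clusters, rng) for _ in range(5)]
print(full_average(params), draws)
```

Under the identical-parameter assumption every draw reproduces the full average exactly, regardless of which representative is chosen, which is the content of Eqs. (8) and (9).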

Equations (8) and (9) show that stratified sampling yields an unbiased sample of all data. Next, we show that stratified sampling reduces the variance of the clients' aggregation weights and makes the model update more stable. For the traditional random sampling method, the aggregation weight of each client is equal to its sample weight, \(\rho_{k} \leftarrow n_{k} /M\), where \(n_{k}\) is the number of samples on client k and M is the overall number of samples. Under the assumption in this section that every client has the same amount of local data, this simplifies to \(\rho_{k} \leftarrow 1/N\). Since the overall data are assumed to follow m different distributions, we set the number of clients participating in each round of training to m. In the random sampling method, the m clients are selected randomly according to the Bernoulli distribution \(B(\rho_{k} )\). For client k, the aggregation weight variance is:

$$ {\text{Var}}({\mathbf{w}}_{k} (S_{t} )) = \tfrac{1}{{m^{2} }}m{\text{Var}}(B(\rho_{k} )) = \tfrac{1}{m}\rho_{k} (1 - \rho_{k} ) = \tfrac{1}{mN}\left(1 - \tfrac{1}{N}\right) $$
(10)

In stratified sampling, we first divide the clients into m clusters and then select one client from each cluster to participate in training. We use \({\text{Var}}_{c} ({\mathbf{w}}_{k} (S_{t} ))\) to denote the aggregation weight variance of client k in stratified sampling. It can be seen that:

$$ {\text{Var}}_{c} ({\mathbf{w}}_{k} (S_{t} )) = \tfrac{1}{{m^{2} }}{\text{Var}}(B(r_{t}^{k} )) = \tfrac{1}{{m^{2} }}m\rho_{k} (1 - m\rho_{k} ) = \tfrac{1}{mN}(1 - \tfrac{m}{N}) $$
(11)

We have \({\text{Var}}({\mathbf{w}}_{k} (S_{t} )) \ge {\text{Var}}_{c} ({\mathbf{w}}_{k} (S_{t} ))\), with equality only when \(m = 1\) (the data distribution on every client is essentially the same).
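
The closed forms in Eqs. (10) and (11) are easy to compare numerically. The sketch below simply evaluates both expressions with \(\rho_{k} = 1/N\); the function names and the choice N = 100 are illustrative.

```python
def var_random(n_clients, m):
    """Eq. (10): aggregation weight variance under random sampling."""
    rho = 1.0 / n_clients
    return rho * (1.0 - rho) / m


def var_stratified(n_clients, m):
    """Eq. (11): aggregation weight variance under stratified sampling."""
    rho = 1.0 / n_clients
    return m * rho * (1.0 - m * rho) / m ** 2


n_clients = 100  # illustrative population size
for m in (1, 10, 29, 100):
    print(m, var_random(n_clients, m), var_stratified(n_clients, m))
```

For N = 100 and m = 10 the two variances are 0.00099 and 0.0009; the gap widens as m grows, and the stratified variance reaches zero when every client forms its own cluster (m = N).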

Regularization

In general, for federated learning under heterogeneous conditions, the client's local optimum does not coincide with the global optimum. In fact, the global optimum \({\mathbf{w}}^{*}\) should satisfy:

$$ \nabla F({\mathbf{w}}^{*} ) = \sum\limits_{{k \in S_{t} }} {\rho_{k} } \nabla F_{k} ({\mathbf{w}}^{*} ) = \sum\limits_{{k \in S_{t} }} {{\rm E}_{{S_{t} }} } (\nabla \ell_{k} ({\mathbf{w}}^{*} ,\xi^{k} )) = 0 $$
(12)

The local optimum \({\mathbf{w}}_{k}^{*}\) of each client satisfies \(\nabla F_{k} ({\mathbf{w}}_{k}^{*} ) = 0\). However, due to client heterogeneity, the local and global data distributions are not the same, so the global and local optima often differ, which means that \(\nabla F_{k} ({\mathbf{w}}^{*} ) \ne 0\) in general. Consequently, updating the model by optimizing only the local empirical loss cannot make the model converge to the global optimal solution.

As shown in Algorithm 1, we add a regularization term when optimizing the client's local loss function, so that the optimal solution of the client's local model approaches that of the global model. More specifically, when calculating the local empirical loss of the client, we add the penalty term \(\tfrac{\alpha }{2}\left\| {{\mathbf{w}} - {\mathbf{w}}_{t} } \right\|^{2}\), which equals zero when \({\mathbf{w}} = {\mathbf{w}}_{t}\). In this way, we modify the optimal point of each client to bring it close to the global optimal point.
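
The effect of the penalty can be seen on the one-dimensional loss \(F_{k} (x) = (x - e_{k} )^{2}\) from Example 1. The regularized local objective \((x - e_{k} )^{2} + \tfrac{\alpha }{2}(x - w_{t} )^{2}\) has the minimizer \((2e_{k} + \alpha w_{t} )/(2 + \alpha )\). The sketch below compares this closed form with plain gradient descent; the step size, iteration count, and numeric values are assumptions.

```python
def regularized_grad(x, e_k, w_t, alpha):
    """Gradient of (x - e_k)^2 + (alpha / 2) * (x - w_t)^2."""
    return 2.0 * (x - e_k) + alpha * (x - w_t)


def local_solve(e_k, w_t, alpha, lr=0.1, steps=500):
    """Gradient descent on the regularized local objective, started at w_t."""
    x = w_t
    for _ in range(steps):
        x -= lr * regularized_grad(x, e_k, w_t, alpha)
    return x


def closed_form(e_k, w_t, alpha):
    """Exact minimizer of the regularized local objective."""
    return (2.0 * e_k + alpha * w_t) / (2.0 + alpha)


e_k, w_t = 5.0, 1.0  # hypothetical local optimum and current global model
for alpha in (0.0, 2.0, 8.0):
    print(alpha, closed_form(e_k, w_t, alpha), local_solve(e_k, w_t, alpha))
```

With \(\alpha = 0\) the client moves all the way to its own optimum 5; with \(\alpha = 2\) it stops at 3, and with \(\alpha = 8\) at 1.8, so larger values keep the local solution closer to the global model.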

Convergence result

In this section, we first discuss the convergence of FedSSAR when all devices participate in the aggregation step. Assume that FedSSAR terminates after T rounds and returns \({\mathbf{w}}_{T}\) as the output. The following theoretical analysis establishes the convergence of FedSSAR in the non-convex setting.

Theorem 1

(Non-convex FedSSAR convergence: All devices participate).

Let Assumptions 1 to 3 hold and \(L,\alpha ,\sigma ,\varepsilon\) be defined as above. Assume the functions \(F_{k}\) are non-convex and there exists an \(\overline{L} > 0\) such that \(\nabla^{2} F_{k} \succeq - \overline{L} I\), with \(\overline{\alpha } = \alpha - \overline{L} > 0\). Suppose that

$$ \rho = \left[ {\tfrac{1}{\alpha } - \tfrac{{L\sqrt {1 + \tfrac{{\sigma^{2} }}{\varepsilon }} }}{{\alpha \overline{\alpha } }} - \tfrac{{L\sqrt {1 + \tfrac{{\sigma^{2} }}{\varepsilon }} }}{{2\overline{\alpha }^{2} }}} \right] > 0 $$

Then, after T iterations,

$$ \left\| {\nabla F({\mathbf{w}}_{T} )} \right\| \le \sqrt {\tfrac{{F({\mathbf{w}}_{0} ) - F^{*} }}{(T + 1) \cdot \rho }} = O\left( {\tfrac{1}{\sqrt T }} \right) $$

In full client participation mode, FL is seriously affected by the straggler effect (all nodes wait for the slowest node), so FL with partial client participation has more practical applications. Suppose that there are m different data distributions in the overall dataset, and that \(S_{t} \subseteq \{ 1,2, \ldots ,N\}\) is a subset containing m indices in the t-th iteration, whose members are selected randomly from each cluster. We can then obtain the convergence of FedSSAR with partial client participation.

Theorem 2

(Non-convex FedSSAR convergence: Partial devices participate).

Let Assumptions 1 to 3 hold and \(L,\alpha ,\sigma ,\varepsilon\) be defined as above. Assume the functions \(F_{k}\) are non-convex and there exists an \(\overline{L} > 0\) such that \(\nabla^{2} F_{k} \succeq - \overline{L} I\), with \(\overline{\alpha } = \alpha - \overline{L} > 0\). Then FedSSAR with partial device participation satisfies:

$$ \left\| {\nabla F({\mathbf{w}}_{T} )} \right\| \le \sqrt {\tfrac{{F({\mathbf{w}}_{0} ) - F^{*} }}{(T + 1) \cdot \rho }} = O\left( {\tfrac{1}{\sqrt T }} \right) $$

where

$$ \rho = \tfrac{1}{\alpha } - \tfrac{{L\sqrt {1 + \tfrac{{\sigma^{2} }}{\varepsilon }} }}{{\alpha \overline{\alpha } }} - \tfrac{{L\sqrt {1 + \tfrac{{\sigma^{2} }}{\varepsilon }} }}{{2\overline{\alpha }^{2} }} - \tfrac{\sqrt 2 }{{\sqrt m \cdot \overline{\alpha } }}\sqrt {1 + \tfrac{{\sigma^{2} }}{\varepsilon }} - \tfrac{L}{{m \cdot \overline{\alpha }^{2} }}\left(1 + \tfrac{{\sigma^{2} }}{\varepsilon }\right) > 0 $$
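
To get a feel for the constants in Theorems 1 and 2, the following sketch evaluates \(\rho\) and the resulting gradient-norm bound exactly as the formulas are written above. The function names and all numeric inputs are hypothetical; they are chosen only so that \(\rho > 0\).

```python
import math


def rho_full(L, alpha, L_bar, sigma, eps):
    """rho from Theorem 1 (all devices participate)."""
    alpha_bar = alpha - L_bar
    if alpha_bar <= 0:
        raise ValueError("the theorem requires alpha - L_bar > 0")
    b = math.sqrt(1.0 + sigma ** 2 / eps)
    return 1.0 / alpha - L * b / (alpha * alpha_bar) - L * b / (2.0 * alpha_bar ** 2)


def rho_partial(L, alpha, L_bar, sigma, eps, m):
    """rho from Theorem 2 (m devices participate per round)."""
    alpha_bar = alpha - L_bar
    b = math.sqrt(1.0 + sigma ** 2 / eps)
    extra = math.sqrt(2.0) * b / (math.sqrt(m) * alpha_bar) + L * b ** 2 / (m * alpha_bar ** 2)
    return rho_full(L, alpha, L_bar, sigma, eps) - extra


def gradient_bound(initial_gap, T, rho):
    """Upper bound sqrt((F(w_0) - F*) / ((T + 1) * rho)) on the gradient norm."""
    if rho <= 0:
        raise ValueError("the bound requires rho > 0")
    return math.sqrt(initial_gap / ((T + 1) * rho))


# Hypothetical constants, not taken from the experiments.
full = rho_full(L=1.0, alpha=10.0, L_bar=1.0, sigma=1.0, eps=1.0)
partial = rho_partial(L=1.0, alpha=10.0, L_bar=1.0, sigma=1.0, eps=1.0, m=100)
print(full, partial)
print(gradient_bound(1.0, 99, partial), gradient_bound(1.0, 399, partial))
```

Partial participation subtracts two extra terms that shrink as m grows, so \(\rho\) is smaller and the bound looser than under full participation. Quadrupling the number of rounds (from T + 1 = 100 to 400) halves the bound, in line with the \(O(1/\sqrt T )\) rate.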

Experiments

In this section, we evaluate the performance of FedSSAR across various datasets, models, and availability settings, and compare it with the FedAvg, FedProx and SCAFFOLD algorithms.

Experimental details

Datasets

We experimented with several public datasets that serve as benchmarks in previous related work (MNIST, Cifar-10, Cifar-100 and EMNIST-L). Taking the MNIST dataset as an example, we distributed the 60,000 training samples to 100 clients, with 600 samples (1%) on each client. The IID split uses independent random division, so the data distribution on each client is essentially the same (Fig. 1a). To simulate different degrees of data heterogeneity, the non-IID and non-IID2 splits use biased sampling: the overall data are sorted by label and divided into slices so that each slice contains only one label, and the slices are then randomly assigned to the clients. The size of the slice determines the number of data types on a client: each client in non-IID contains approximately two types of data (Fig. 1b), and each client in non-IID2 contains approximately only one type of data (Fig. 1c), which simulates extreme data heterogeneity. Figure 1 shows the data distribution of the top 20 clients in the MNIST dataset.
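
The IID and label-sorted splits described above can be sketched with plain index lists. This is an illustrative reconstruction, not the authors' preprocessing code: the function names, the synthetic labels, and the assumption that the slice size divides each class evenly are all hypothetical.

```python
import random


def iid_split(labels, num_clients, rng):
    """Shuffle all sample indices and deal them out evenly."""
    indices = list(range(len(labels)))
    rng.shuffle(indices)
    size = len(indices) // num_clients
    return [indices[c * size:(c + 1) * size] for c in range(num_clients)]


def shard_split(labels, num_clients, shards_per_client, rng):
    """Sort indices by label, cut them into slices, and assign slices randomly."""
    indices = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_clients * shards_per_client
    size = len(indices) // num_shards
    shards = [indices[s * size:(s + 1) * size] for s in range(num_shards)]
    rng.shuffle(shards)
    return [
        [i for shard in shards[c * shards_per_client:(c + 1) * shards_per_client] for i in shard]
        for c in range(num_clients)
    ]


# Synthetic stand-in for a labelled dataset: 10 classes with 60 samples each.
labels = [i % 10 for i in range(600)]
non_iid = shard_split(labels, num_clients=20, shards_per_client=2, rng=random.Random(0))
non_iid2 = shard_split(labels, num_clients=20, shards_per_client=1, rng=random.Random(0))
iid = iid_split(labels, num_clients=20, rng=random.Random(0))
print([sorted({labels[i] for i in part}) for part in non_iid[:5]])
```

With two slices per client each client sees at most two labels, and with one slice per client exactly one label, which mirrors the non-IID and non-IID2 settings.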

Fig. 1 Data distribution of the top 20 clients in the MNIST dataset

We use a similar division method for the Cifar-10 and EMNIST-L datasets (details are given in "Appendix B.1"). In addition, we conducted a large-scale experiment with 1000 clients on the MNIST, EMNIST, Cifar-10 and Cifar-100 datasets. Figure 2 shows the data distribution of the top 20 clients in the Cifar-100 dataset.

Fig. 2 Data distribution of the top 20 clients in the Cifar-100 dataset

Implementation

We select the FedAvg, FedProx and SCAFFOLD algorithms as baselines. The value of the parameter \(\alpha\) in FedSSAR is selected from the set {1, 0.5, 0.1, 0.05, 0.01, 0.005, 0.001}; we test the sensitivity to \(\alpha\) in "Appendix C.1". The original proportion of clients selected per round is 10%. The initial learning rate is 0.1, with learning rate decay \(\eta_{t} = 0.1/(1 + t)\). We set the weight decay to \(10^{ - 3}\) and \(\mu = 10^{ - 3}\) for FedProx. For SCAFFOLD, K is set to 12 and 1 under the moderate and massive device-number settings, respectively. To ensure that the samples drawn in each round are unbiased estimates of the population, we randomly draw the target number of samples from the population in each round (for FedSSAR, from each cluster).
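
For reference, the learning rate schedule and the per-round client budget stated above can be written as two small helpers. The function names are hypothetical, and starting the round index at t = 0 is an assumption that matches the stated initial learning rate of 0.1.

```python
def learning_rate(t, eta0=0.1):
    """Decayed learning rate eta_t = eta0 / (1 + t) for round t = 0, 1, 2, ..."""
    return eta0 / (1 + t)


def clients_per_round(num_clients, fraction=0.1):
    """Number of clients drawn per round at the given participation fraction."""
    return max(1, round(num_clients * fraction))


print([learning_rate(t) for t in range(4)])
print(clients_per_round(100), clients_per_round(1000))
```

Under this schedule the learning rate falls to one tenth of its initial value by round 9, and a 10% participation rate corresponds to 10 clients per round for 100 clients and 100 for 1000 clients.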

Models

We use a fully connected multilayer neural network for MNIST and EMNIST-L with 2 hidden layers (200 and 100 neurons, respectively). For the Cifar-10 and Cifar-100 datasets, we use a CNN model (2 convolutional layers with 64 5 × 5 filters, followed by 2 fully connected hidden layers with 394 and 192 neurons and a softmax layer).
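To give a sense of the size of the parameter vectors that are later clustered, the short calculation below counts the trainable parameters of the fully connected network for MNIST. It assumes 28 × 28 = 784 input pixels, a 10-way output layer and a bias on every layer; the helper name `mlp_param_count` is hypothetical.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

# 784 inputs -> 200 hidden -> 100 hidden -> 10 classes
mnist_params = mlp_param_count([784, 200, 100, 10])
print(mnist_params)  # 178110
```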

Clustering results of model parameters

Following Algorithm 1 and taking the MNIST dataset (100 clients) as an example, Fig. 3 shows the clustering results of the model parameters under the different data heterogeneity settings.

Fig. 3

Clustering results of model parameters under different heterogeneous settings on the MNIST dataset (moderate device number)

As shown in Fig. 3a, for the IID split all clients belong to the same cluster. The local data distributions of the clients are highly similar, so the local data distribution of each client can be considered the same as the overall data distribution of the dataset. Under the heterogeneous settings (Fig. 3b, c), clients are divided into different clusters. In the first heterogeneous setting (non-IID, Fig. 3b), the clients are divided into 29 clusters, and the similarity of the model parameters within each cluster is relatively high. In the second heterogeneous setting (non-IID2, Fig. 3c), the clients are divided into 10 clusters, and the model parameters differ much more between clusters than in the first setting. See "Appendix B.2" for the clustering results on the other datasets.
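For readers unfamiliar with density-based clustering, the following is a minimal DBSCAN-style sketch applied to toy "parameter vectors". It is not Algorithm 1 of this paper: the Euclidean distance, the thresholds `eps` and `min_pts`, and the two-dimensional toy points are all assumptions chosen for illustration. In practice each point would be a client's flattened model parameters.

```python
import math

def dbscan(points, eps, min_pts):
    """Minimal DBSCAN: returns one cluster id per point, -1 for noise."""
    n = len(points)
    labels = [None] * n  # None means "not visited yet"

    def neighbors(i):
        # Indices of all points within eps of point i (including i itself).
        return [j for j in range(n) if math.dist(points[i], points[j]) <= eps]

    cluster = -1
    for i in range(n):
        if labels[i] is not None:
            continue
        seeds = neighbors(i)
        if len(seeds) < min_pts:
            labels[i] = -1  # provisional noise; may become a border point later
            continue
        cluster += 1
        labels[i] = cluster
        queue = [j for j in seeds if j != i]
        while queue:
            j = queue.pop()
            if labels[j] == -1:
                labels[j] = cluster  # noise reachable from a core point joins as border
            if labels[j] is not None:
                continue
            labels[j] = cluster
            nbrs = neighbors(j)
            if len(nbrs) >= min_pts:  # j is a core point, so keep expanding
                queue.extend(nbrs)
    return labels

# Toy client "parameter vectors": two tight groups and one outlier.
client_params = [(0, 0), (0, 0.1), (0.1, 0), (5, 5), (5, 5.1), (5.1, 5), (20, 20)]
cluster_ids = dbscan(client_params, eps=0.5, min_pts=2)
print(cluster_ids)  # [0, 0, 0, 1, 1, 1, -1]
```

A density-based method does not need the number of clusters in advance, which fits the observation above that the cluster count changes with the heterogeneity setting.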

Experimental results

We first compared the performance of FedSSAR on different datasets against the above baselines (FedAvg, FedProx and SCAFFOLD) under different conditions. The validation performance on each task is shown in Figs. 4 and 5. Tables 1 and 2 summarize the validation performance on MNIST and EMNIST-L after 50 rounds of training and on Cifar-10 and Cifar-100 after 500 rounds of training. Due to space constraints, plots of the model training losses are given in "Appendix C.2".

Fig. 4

Validation accuracy of FedSSAR and other methods, using decreasing learning rate \(\eta_{t}\) and weight decay: \(10^{ - 3}\) to achieve the best training performance during the last 50 training rounds for MNIST and EMNIST-L, 500 training rounds for Cifar-10 and Cifar-100

Fig. 5

Validation accuracy of FedSSAR and other methods, using decreasing learning rate \(\eta_{t}\) and weight decay: \(10^{ - 3}\) to achieve the best training performance during the last 50 training rounds for MNIST and EMNIST-L, 500 training rounds for Cifar-10 and Cifar-100

Table 1 Average validation performance over the last 50 rounds for MNIST and EMNIST-L, 500 rounds for Cifar-10 and Cifar-100; % accuracy; on non-IID split
Table 2 Average validation performance over the last 50 rounds for MNIST and EMNIST-L, 500 rounds for Cifar-10 and Cifar-100; % accuracy; on non-IID2 split

Non-IID split. The non-IID split simulates a medium degree of data heterogeneity. Table 1 shows that FedSSAR achieves the best results within the specified training rounds in all task settings. In terms of dataset structure, MNIST and EMNIST-L are handwritten character recognition tasks, whereas Cifar-10 and Cifar-100 are recognition tasks on color images of real-world objects. The latter therefore show larger differences in data distribution between classes and higher heterogeneity across the whole dataset, and experiments on the Cifar datasets require more training rounds to converge than those on MNIST and EMNIST-L. At the same time, the advantage of FedSSAR over competing methods is larger on more heterogeneous datasets and with larger numbers of devices, for example on Cifar-10 and Cifar-100 with a massive number of devices.

In addition to faster training and higher accuracy, the training process of FedSSAR is more stable, as shown in Fig. 4. As shown in section "Improvements provided by stratified sampling and regularization", stratified sampling reduces the bias in the clients' aggregation weights, while the regularization term in the client objective function limits the direction of the model parameter updates, making the updates more stable. On highly heterogeneous datasets in particular, such as the Cifar datasets, the training curves of the traditional algorithms fluctuate sharply, while FedSSAR maintains a steadily growing curve.

Non-IID2 split. The non-IID2 split simulates extreme heterogeneity of client data: as shown in Fig. 1 and "Appendix B.1", different types of clients have completely different local data distributions. According to the clustering results ("Appendix B.2"), the trained model parameters of clients in different clusters differ greatly. When the model parameters are aggregated, the aggregation variance of the participating model parameters therefore increases, which greatly reduces the training efficiency of the model. As can be seen from Table 2, model performance on all training tasks drops to varying degrees compared with the non-IID split. Even under these extreme heterogeneity conditions, however, FedSSAR performs better than the baselines, and its improvement over traditional federated learning algorithms grows as the degree of data heterogeneity increases.
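The effect of client selection on aggregation variance can be illustrated with a toy simulation. This is not one of the paper's experiments: it assumes 10 equally sized clusters, scalar "updates" that are identical within a cluster, and a plain unweighted average as the aggregate. Under these assumptions, drawing one client per cluster reproduces the population mean in every round, whereas uniform sampling of the same number of clients does not.

```python
import random
import statistics

def simulate(trials=200, seed=0):
    """Variance of the aggregated update under uniform vs. per-cluster sampling."""
    rng = random.Random(seed)
    # 10 clusters of 10 clients; every client in cluster c holds the toy update value c.
    clusters = {c: [float(c)] * 10 for c in range(10)}
    population = [u for members in clusters.values() for u in members]
    uniform, stratified = [], []
    for _ in range(trials):
        # Uniform: 10 clients drawn from all 100, ignoring cluster membership.
        uniform.append(statistics.fmean(rng.sample(population, 10)))
        # Stratified: exactly one client drawn from each of the 10 clusters.
        stratified.append(statistics.fmean(rng.choice(m) for m in clusters.values()))
    return statistics.pvariance(uniform), statistics.pvariance(stratified)

var_uniform, var_stratified = simulate()
print(var_uniform > var_stratified)  # True
```

Real client updates also vary within a cluster, so the stratified variance would be small rather than exactly zero, but the between-cluster component is removed in the same way.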

Conclusions

In this work, we propose FedSSAR, an FL algorithm optimized by stratified sampling and regularization. FedSSAR uses a density-based clustering method to divide the heterogeneous clients into clusters, uses stratified sampling to select clients from each cluster to participate in training, and uses a regularization term to restrict the local model updates of heterogeneous clients. Under non-convex conditions, we provide a convergence proof for FedSSAR and show that it has a convergence rate of \(O(1/\sqrt T )\). We evaluate FedSSAR on different training tasks and parameter settings, and our experiments show that it improves substantially on FedAvg, FedProx and SCAFFOLD under heterogeneous conditions.