Introduction

Machine learning techniques have been successfully applied to many real-world problems, such as speech recognition [1], autonomous driving [2], and disease detection [3]. Traditional machine learning normally requires gathering all data on a centralized server to extract the useful information hidden in the data. However, a great amount of data is generated by distributed edge devices or organizations, and collecting it from multiple entities on a centralized server is difficult because of privacy protection requirements and limited communication resources. The emergence of federated learning (FL) [4, 5] makes it possible to collaboratively train a centralized model without sharing local data with the centralized server [6].

It has been shown that federated learning can reduce resource consumption on the server and simultaneously ensure the privacy protection of client data [7]. Thus, federated learning has become a hot research topic in artificial intelligence since it was first proposed [8]. Common federated learning approaches share the parameters or gradients of local models with the centralized server after each round of training on local data. The central server then aggregates the parameters of the local models using aggregation techniques such as federated averaging (FedAvg) [6] and its variants [9,10,11,12]. The aggregated global model is then distributed to the clients for the next round of model optimization.

To date, researchers have mainly focused on three concerns to improve federated learning: model performance, communication traffic, and data privacy and security [13]. Many FedAvg variants have been proposed to improve the performance of global models on datasets that are either independent and identically distributed (IID) or non-IID. Li et al. [14] proposed a broader framework, called FedProx, in which a proximal term is added to the client cost functions to stabilize the method. In [15], Mohri et al. proposed an agnostic federated learning approach in which the centralized model is optimized for any possible target distribution formed by a mixture of client distributions. For non-IID problems, Zhao et al. [16] proposed adding a regularization term to the global loss function. Acar et al. [17] improved the regularization term and proposed an algorithm called FedDyn. Some works try to optimize convergence to improve the performance of federated learning [18]. One line of work combines traditional optimization algorithms, such as momentum [19] and Adam [20], with federated learning. Liu et al. [21] proposed momentum federated learning to accelerate convergence. Reddi et al. [22] proposed a general framework to incorporate adaptivity into federated learning, in which clients run multiple epochs to minimize the loss on local data and the server updates its global model by applying a gradient-based server optimizer to the average of the clients’ model updates. Nguyen et al. [23] proposed a fast-convergent federated learning algorithm, in which an accurate and communication-efficient approximation of a near-optimal distribution of device selection is employed to accelerate convergence. In [24], Briggs et al. presented federated learning with hierarchical clustering, in which clients are clustered by similarity based on their updates to the central server's model, to speed up the convergence of model training. Some methods construct a personalized model for each client to solve problems with non-IID data. In [25], Chen et al. constructed a two-loss, two-predictor personal federated learning framework by adding a predictor to the training model. Huang et al. [26] proposed a personal federated learning algorithm with personalized aggregation, called FedAMP, in which every client executes aggregation independently based on the similarity to other models distributed by the server. Some researchers have also proposed federated ensemble learning approaches to improve generalization. In [27], Liu et al. treated federated ensemble learning as a kind of evolutionary algorithm and defined crossover and mutation operators for neural networks. In [28], Thonglek et al. proposed using a weighted average ensemble to combine the outputs of the individual models, with the ensemble weights optimized by black-box optimization methods. Stephanie et al. [29] applied federated ensemble learning to blockchain, using secure multiparty computation-based ensemble federated learning with blockchain that enables heterogeneous models to be trained in the federated environment.

Communication efficiency is one of the critical concerns in federated learning. Normally, communication is saved by reducing the total number of communication rounds or the size of the message transmitted at each round [30]. Sattler et al. [31] proposed a sparse ternary compression to compress both the upstream and downstream communication by sparsification, ternarization, error accumulation, and optimal Golomb encoding. Reisizadeh et al. [32] proposed a communication-efficient federated learning algorithm with periodic averaging and quantization to address the communication and scalability challenges of federated learning. In [11], Rothchild et al. proposed using a data structure called a count sketch to compress the gradient of each client before it is sent to the central server. Yang et al. [33] applied binary neural networks to improve communication efficiency in federated learning. Besides compressing the transmitted information, there are also other ways to improve communication efficiency. Chen et al. [10] suggested an asynchronous model update strategy and a temporally weighted aggregation method to reduce communication costs and improve the learning performance of federated learning. In [34], Lin et al. proved that delaying communication can improve communication efficiency to some extent. Shi et al. [35] explored this direction further and proposed a joint device scheduling and resource allocation scheme to maximize model accuracy in a restricted environment with a fixed total training delay and communication budget.

Federated learning was proposed to prevent data leakage and thus preserves data privacy and security to some extent. However, private information can still be extracted by attacks against the federated architecture. Thus, researchers have presented strategies for privacy and security protection in federated learning. Truex et al. [36] proposed using both differential privacy and secure multiparty computation to balance privacy preservation and model performance. In [37], Xu et al. presented an approach for privacy-preserving federated learning that employs a secure multiparty computation protocol based on functional encryption. Xiong et al. [38] proposed a dual differentially private federated learning scheme to defend against privacy inference attacks. Zhang et al. [39] suggested tackling the encryption and communication bottlenecks created by homomorphic encryption with a simple batch encryption technique. In [40], Ma et al. proposed a multi-key homomorphic encryption protocol for designing a novel privacy-preserving federated learning scheme.

Although federated learning has developed rapidly, it is still in its early stages. In this paper, a composition–decomposition based federated learning approach, denoted CD-FL, is proposed. In CD-FL, each client randomly chooses one sub-model to train, which reduces the communication traffic on the one hand. On the other hand, the model in the central server is composed by clustering the sub-models of the previous and current rounds, which improves the generalization capability. The main contributions of CD-FL are summarized as follows:

  1. The global model in the central server is composed of a number of sub-models with the same architecture. The global model is sent to all clients, while each client randomly selects only one sub-model and updates its parameters.

  2. The communication traffic of federated learning is reduced, because each client only needs to upload the parameters of the selected sub-model.

  3. All sub-models uploaded at the current round, together with those of the previous round, are clustered using the K-means approach. The center of each cluster is used to compose the global model. Thus, different characteristics implied in the client datasets can be kept, accordingly improving the generalization of the global model.

The rest of the paper is organized as follows. "Related works" gives a brief overview of convolutional neural networks and federated learning, which are related to the proposed method. In "The CD-FL approach", details of the proposed CD-FL method are presented. The experimental results and discussions are given in "Experimental results and discussions". Finally, "Conclusions and future work" gives the conclusions and future work.

Related works

Convolutional neural network

In the federated learning framework, any machine learning technique can be used to train a model on local data. In our method, the convolutional neural network (CNN) is adopted. Generally, a convolutional neural network is composed of an input layer, convolution layers, pooling layers, and fully connected layers. Each convolution layer performs a dot product of the convolution kernel with the input matrix of the current layer; the Frobenius inner product is normally adopted as the dot product, and the ReLU function is often used as the activation function. The convolution operation generates a feature map, which serves as the input of the next layer, which may be a pooling layer, a fully connected layer, or a normalization layer. Pooling layers reduce the dimensions of the data by extracting dominant features that are rotation and position invariant, thereby keeping model training efficient. Max pooling and average pooling, which return the maximum and mean value, respectively, of the region covered by the filter, are the two main types of pooling. The fully connected layer is the same as in the traditional multilayer perceptron neural network, connecting every neuron in one layer to every neuron in the next layer. In our method, a convolutional neural network consisting of eight convolutional layers and three fully connected layers is adopted.
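For concreteness, the following is a minimal Keras sketch of such a sub-model with eight convolutional layers and three fully connected layers. The filter counts, kernel sizes, pooling positions, and input shape are illustrative assumptions, since the paper does not specify them.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_sub_model(input_shape=(32, 32, 3), num_classes=10):
    """A CNN with eight convolutional layers and three fully connected
    layers; all hyperparameters here are illustrative."""
    inputs = tf.keras.Input(shape=input_shape)
    x = inputs
    # Eight convolutional layers (two per block), each block followed by pooling.
    for filters in (64, 128, 256, 512):
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.MaxPooling2D(pool_size=2)(x)
    x = layers.Flatten()(x)
    # Three fully connected layers; the last one outputs class probabilities.
    x = layers.Dense(512, activation="relu")(x)
    x = layers.Dense(256, activation="relu")(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)
```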

After the convolutional neural network architecture is determined, the weight vector of the CNN needs to be optimized to achieve the best possible performance of the model. Many algorithms can be used to minimize the loss function. In our method, a batch gradient descent (BGD) approach is used for training the model: the training data are divided into a number of batches, and each batch is used to calculate the gradient of the CNN.
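As an illustration of this batch-wise training, the sketch below splits the local data into batches and performs one gradient update per batch with tf.GradientTape; the loss function and optimizer are placeholder assumptions.

```python
import tensorflow as tf

def train_one_epoch(model, x_train, y_train, optimizer, batch_size=32):
    """One pass over the local data: split it into batches and apply one
    gradient update per batch. Labels are assumed to be integer class ids."""
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)
            loss = loss_fn(y_batch, predictions)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
```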

Federated learning

Federated learning (FL) is a new field that focuses on solving data island problems in machine learning. Algorithm 1 gives the pseudocode of FedAvg [6], a basic framework of federated learning.

Algorithm 1 FedAvg

Fig. 1 An example to show the overall framework of our proposed method

In Algorithm 1, a global model \({\textbf{W}}^t\) is first initialized by generating parameters randomly in the range [0, 1]. Then, at each round \(t, t=1,2,\ldots ,T\), the global model is broadcast to all clients, and each client updates the model on its own dataset. The updated model \({\textbf{w}}^{t}_i, i=1,2,\ldots ,N\), is uploaded to the central server together with the total number of data samples in the client, \(n_i, i=1,2,\ldots ,N\). The global model \({\textbf{W}}\) is then updated using the following equation:

$$\begin{aligned} {\textbf{W}}^{t+1}=\frac{\sum _{k=1}^{N}n_k {\textbf{w}}^{t}_k}{\sum _{k=1}^{N}n_k}, \end{aligned}$$
(1)

where \(n_k\) represents the number of data samples in the k-th client, i.e., \(n_k = |{\textbf{D}}_k|, k=1,2,\ldots ,N\). The model obtained after training for T rounds, \({\textbf{W}}^{T+1}\), is output and used as the global model.
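A minimal NumPy sketch of the aggregation rule in Eq. (1) is given below; the function and variable names are illustrative, and each client model is assumed to be represented by the list of per-layer weight arrays returned by model.get_weights().

```python
import numpy as np

def fedavg_aggregate(client_weights, client_sizes):
    """Eq. (1): each layer of the global model is the data-size-weighted
    average of the corresponding client layers.

    client_weights: list of per-client weight lists (one array per layer).
    client_sizes:   list of n_k, the number of samples on each client.
    """
    total = float(sum(client_sizes))
    num_layers = len(client_weights[0])
    aggregated = []
    for layer in range(num_layers):
        weighted_sum = sum(n_k * w[layer]
                           for w, n_k in zip(client_weights, client_sizes))
        aggregated.append(weighted_sum / total)
    return aggregated
```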

The CD-FL approach

In this section, the overall framework of our proposed CD-FL approach is given first. Then the details of the clustering and aggregation approaches are described. Finally, a brief analysis of security and privacy, as well as of the communication traffic, is given in the last part of this section.

The CD-FL framework

Figure 1 gives a simple example to illustrate the idea of our proposed CD-FL method. In Fig. 1, suppose there are three sub-models with the same architecture and five clients participating in the model training. There are seven main steps in CD-FL. The global model is decomposed into K sub-models with the same architecture (Step 1); in this example, the global model is decomposed into three sub-models. All sub-models are broadcast to each client (Step 2). Each client randomly chooses one sub-model to update using its own dataset (Step 3). Note that different clients may choose the same sub-model to update, and some sub-models may not be updated at all in the current round. After each client finishes updating its sub-model, the sub-model is uploaded to the central server. All sub-models, including the updated sub-models and the sub-models decomposed on the server, are clustered into K clusters (Step 4). The centers of the clusters are composed into a global model (Step 5). Note that the clustering of all sub-models can also be conducted at a trusted third party to further improve security.

Algorithm 2 gives the pseudocode of our proposed CD-FL method. The notation used in the description of the proposed method is given in Table 1. In Algorithm 2, a global model \({\textbf{W}}^t\), composed of K sub-models with the same architecture, is initialized first. Then, \({\textbf{W}}^t\) is decomposed into K sub-models (lines 3–4), which are saved to an archive \({\textbf{S}}\) (line 5). After that, the following procedure is repeated for T rounds. All sub-models in \({\textbf{S}}\) are broadcast to all clients. When client i receives the sub-models, it randomly selects one sub-model and updates it using its own data (lines 9–10). The process of model_updating is given in detail in Algorithm 3. Note that at the first round, the sub-model is updated for \(T_0\) epochs to speed up the training. The updated sub-model of each client, and only this sub-model, is uploaded to the central server and saved to the archive \({\textbf{S}}\) (lines 11–12), thus saving communication time. After all clients have uploaded their updated sub-models to \({\textbf{S}}\), all sub-models in \({\textbf{S}}\) are clustered into K clusters using the K-means approach. The center of each cluster composes the global model that is broadcast to the clients in the next round (lines 14–20). After the maximum number of communication rounds T is reached, the global model is output; it is composed of all current sub-models in parallel, except that the output layer applies the softmax function to the outputs of the previous layer of all sub-models.
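The following sketch outlines one CD-FL communication round on the server side, following the steps above; the models are assumed to be represented as flattened parameter vectors, client.local_update is a hypothetical client API, and, for simplicity, the raw cluster centers are taken as the new sub-models (the refined per-cluster aggregation is described and sketched in "Clustering and aggregation").

```python
import numpy as np
from sklearn.cluster import KMeans

def cdfl_round(global_sub_models, clients, K):
    """One CD-FL communication round on the server.
    global_sub_models: K flattened parameter vectors from the previous round.
    clients: objects with a local_update(weights) method returning the
    updated flat parameter vector of the selected sub-model."""
    # The archive S starts with the previous-round sub-models (lines 3-5).
    archive = [w.copy() for w in global_sub_models]
    for client in clients:
        # Each client receives all sub-models, randomly selects one, updates
        # it locally, and uploads only that sub-model (lines 9-12).
        k = np.random.randint(K)
        archive.append(client.local_update(global_sub_models[k]))
    # Cluster all sub-models in S into K clusters, seeded with the
    # previous-round sub-models; the cluster centers compose the next
    # global model (lines 14-20).
    kmeans = KMeans(n_clusters=K, init=np.stack(global_sub_models), n_init=1)
    kmeans.fit(np.stack(archive))
    return list(kmeans.cluster_centers_)
```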

Table 1 Description of notations
Algorithm 2 The composition–decomposition based federated learning

Algorithm 3 \(model\_updating(\textbf{w}_k^t,\textbf{D}_i)\)

Clustering and aggregation

The parameters of each sub-model are regarded as a point in the solution space. All points are clustered into K sets based on a distance measure, such as Euclidean distance or cosine similarity. Any clustering technique can be used in the proposed method; in our method, the K-means approach is adopted. Note that in the proposed CD-FL method, the sub-models decomposed from the global model at the previous round are used as the initial centers of the K-means clustering, in order to reduce the influence of outliers (e.g., a poisoned sub-model) as much as possible. After clustering, there are two cases for obtaining a sub-model from each cluster. When all sub-models in a cluster are from the previous round, the mean of their parameters is taken as a sub-model for the next round. When at least one sub-model in a cluster comes from a client, all sub-models in the cluster except those from the previous round are aggregated using Eq. (1) to obtain a sub-model of the global model for the next round.
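A sketch of this clustering and aggregation step is given below, assuming each sub-model is represented as a flattened parameter vector and that the server also knows the data size n_k of each uploading client. The use of scikit-learn's K-means and all names are illustrative choices, not the authors' implementation.

```python
import numpy as np
from sklearn.cluster import KMeans

def cluster_and_aggregate(prev_sub_models, client_models, client_sizes, K):
    """prev_sub_models: K flat vectors decomposed from the previous round.
    client_models / client_sizes: uploaded flat vectors and their n_k."""
    points = np.stack(prev_sub_models + client_models)
    # Previous-round sub-models serve as the initial cluster centers.
    labels = KMeans(n_clusters=K, init=np.stack(prev_sub_models),
                    n_init=1).fit_predict(points)
    prev_count = len(prev_sub_models)
    new_sub_models = []
    for k in range(K):
        members = np.where(labels == k)[0]
        uploads = [i for i in members if i >= prev_count]
        if uploads:
            # Case 2: at least one client upload in the cluster -> weighted
            # average of the uploads only, following Eq. (1).
            sizes = np.array([client_sizes[i - prev_count] for i in uploads],
                             dtype=float)
            stacked = np.stack([points[i] for i in uploads])
            new_sub_models.append((sizes[:, None] * stacked).sum(axis=0) / sizes.sum())
        else:
            # Case 1: only previous-round sub-models -> plain mean.
            new_sub_models.append(points[members].mean(axis=0))
    return new_sub_models
```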

Fig. 2 The performance of the method with different numbers of sub-models when there are 20 and 50 clients

Table 2 The accuracy (%) of CD-FL with different numbers of sub-models

Fig. 3 Relationship between the hyperparameter \(T_0\) and the final performance on different datasets and numbers of clients

Performance analysis

In this section, we briefly analyse two aspects of our proposed method: security and privacy, and communication traffic.

Security and privacy

Data privacy and security have been critical concerns in recent years. Federated learning, as a new machine learning methodology, has therefore received more and more attention for addressing these concerns, owing to its privacy-preserving advantages compared to traditional machine learning techniques. However, the parameters of a local model still contain sensitive information, because some features of the data samples are inherently encoded into local models. Thus, information can be stealthily extracted from local training parameters; such attacks are called inference attacks. In this paper, the sub-models are clustered into a number of clusters and then composed into a global model in the central server. Thus, when the global model is broadcast to all clients, it is difficult to infer the data hidden in the model, ensuring security to some extent.

Poisoning attacks are another kind of attack, which aim to break the robustness of the system by injecting poisonous data during local data collection or by degrading models during local training. In our proposed method, the global model is composed of a number of cluster centers, which are obtained by clustering the parameters of the sub-models. Thus, there are two situations for poisoning attacks. In the first, the attacker's sub-model is far away from all other sub-models; there will then be a cluster containing only the attacker, which becomes one of the sub-components of the global model. Since the global model is composed of a number of sub-components, this single poisoned component will not dominate the overall performance of the global model. In the second situation, the attacker's sub-model is close to some other sub-models and is clustered together with them, so its influence is diluted by aggregation and it will hardly affect the performance of the global model.

Communication traffic

To give a fair comparison between the FedAvg method and our proposed CD-FL algorithm in terms of total communication traffic, we suppose that the global model in FedAvg is also composed of K sub-models with the same architecture. Then the total amount of bits in the FedAvg algorithm, including the bits uploaded from all clients and those downloaded to all clients, can be calculated as follows:

$$\begin{aligned} b^{\text {up}}_{\text {FedAvg}}=b^{\text {down}}_{\text {FedAvg}} \in {\mathcal {O}}( T_{\text {FedAvg}} \times f \times N \times K \times |w| \times \eta ), \end{aligned}$$
(2)

where \(b^{\text {up}}_{\text {FedAvg}}\) and \(b^{\text {down}}_{\text {FedAvg}}\) represent the total amount of bits uploaded from and downloaded to all clients, respectively, \(T_{\text {FedAvg}}\) and f represent the number of communication rounds and the frequency of communication, respectively, \(|w|\) is the size of each sub-model, and \(\eta \) represents the efficiency of the encoding.

In our proposed CD-FL method, the amount of download communication is the same as that of FedAvg, that is,

$$\begin{aligned} b^{\text {down}}_{\text {CD-FL}} \in {\mathcal {O}}( T_{\text {CD-FL}} \times f \times N \times K \times |w| \times \eta ), \end{aligned}$$
(3)

where \(T_{\text {CD-FL}}=T_{\text {FedAvg}}\) represents the number of communication rounds for downloading the global model. However, the amount of upload communication differs from that of FedAvg and is given as follows:

$$\begin{aligned} b^{\text {up}}_{\text {CD-FL}} \in {\mathcal {O}}( T_{\text {CD-FL}} \times f \times N \times |w| \times \eta ). \end{aligned}$$
(4)

Compared to FedAvg, in our proposed CD-FL approach each client only needs to upload the parameters of one sub-model of the global architecture to the server. Thus, the upload communication traffic is about 1/K of that of FedAvg, which saves total communication traffic.
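The back-of-envelope calculation below illustrates this saving with arbitrary example values for N, K, and the sub-model size \(|w|\), assuming 32-bit parameters and ignoring the communication frequency f and the encoding factor \(\eta\).

```python
# Illustrative per-round upload traffic (example values, 32-bit floats).
N, K, w_size, bytes_per_param = 50, 5, 1_000_000, 4

fedavg_upload = N * K * w_size * bytes_per_param  # each client uploads all K sub-models
cdfl_upload = N * 1 * w_size * bytes_per_param    # each client uploads one sub-model

print(f"FedAvg upload per round: {fedavg_upload / 1e9:.1f} GB")  # 1.0 GB
print(f"CD-FL  upload per round: {cdfl_upload / 1e9:.1f} GB")    # 0.2 GB, i.e., 1/K
```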

Experimental results and discussions

To evaluate the performance of our proposed method, we conduct a number of experiments on the Fashion-MNIST, CIFAR-10, EMNIST and Tiny-ImageNet benchmark datasets. The sensitivity of the hyperparameters is analyzed. Finally, the performance of our proposed method is validated by comparing its results with those obtained by global federated learning methods and personal federated learning methods.

Experimental setup

Datasets and the model architecture

To evaluate the performance of our proposed method, all experiments are conducted on IID and non-IID versions of the Fashion-MNIST [41], CIFAR-10 [42], EMNIST [43], and Tiny-ImageNet [44] datasets. In all experiments, the training data are split evenly among the clients according to a Dirichlet distribution, following the procedure outlined in [45]. To be specific, in the IID cases we take \(\alpha =100.0\), while in the non-IID cases \(\alpha =0.5\).
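A common way to realize such a Dirichlet-based split is sketched below: per-client proportions of each class are drawn from a Dirichlet distribution with concentration \(\alpha\), so \(\alpha =100\) yields nearly uniform (IID-like) splits and \(\alpha =0.5\) yields skewed (non-IID) splits. This is an approximation of the procedure in [45], not the authors' exact code, and client sizes may differ slightly.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Partition sample indices across clients with a Dirichlet prior over
    per-class proportions. `labels` is a 1-D integer NumPy array."""
    rng = np.random.default_rng(seed)
    num_classes = int(labels.max()) + 1
    client_indices = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        # Draw per-client proportions of class c and split its samples accordingly.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        split_points = (np.cumsum(proportions) * len(idx)).astype(int)[:-1]
        for client, part in enumerate(np.split(idx, split_points)):
            client_indices[client].extend(part.tolist())
    return client_indices
```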

Each sub-model is a standard convolutional neural network consisting of eight convolutional layers and three fully connected layers. The global model trained in our experiment can be seen as an ensemble model of several sub-models.

Table 3 Classification accuracy (%) and convergence speed (# of epochs) obtained by CD-FL with different \(T_0\)

Comparison methods

FedAvg [6], FedAdam [22], FedProx [14], Moon [46], FedAvg with TPE [28], FedAMP [26], and FedRod [25] are adopted for empirical comparison with our proposed CD-FL approach. Among the comparison approaches, FedAvg, FedAdam, FedProx, Moon, and FedAvg with TPE are global federated learning methods, while FedAMP and FedRod are personal federated learning approaches. FedAdam follows the FedAvg framework but uses the Adam optimizer on the server. In FedProx, a regularization term is added to the loss function of FedAvg to achieve faster and better convergence. In Moon, contrastive learning is added to local training to utilize the similarity between model representations. In [28], Thonglek et al. propose using a weighted average ensemble to combine the outputs of the individual models, with the ensemble weights optimized by black-box optimization using the TPE optimizer [47] on top of the FedAvg algorithm; this method is referred to as FedAvg with TPE in the experiments. In FedAMP, every client executes aggregation independently based on the similarity between models. In FedRod, a predictor head and its loss term are added to the personal model. For the personal federated learning (PFL) methods, the accuracy of the ensemble of all personal models is recorded, while for the global FL methods, the accuracy of the global model is recorded.

Fig. 4 Classification accuracy on CIFAR-10 under 5, 20 and 50 clients, respectively

Fig. 5 Classification accuracy on Fashion-MNIST under 5, 20 and 50 clients, respectively

Fig. 6 Classification accuracy on EMNIST under 5, 20 and 50 clients, respectively

Fig. 7 Classification accuracy on Tiny-ImageNet under 5, 20 and 50 clients, respectively

Fig. 8 Classification accuracy on non-IID datasets under 20 and 50 clients, respectively

Running environment

All experiments are implemented in the TensorFlow 2.3.0 framework and run on a server with 6 GTX 1080 Ti GPUs and 16 CPUs at 2.50 GHz. The server has 8.0 GB of video memory and 128.0 GB of RAM.

Hyperparameters

The number of cluster centers, K, is set to 5 in our experiments by default. The number of epochs \(T_0\) is set to 20 at the first communication round and 1 at other rounds by default.

Optimization

In CD-FL, two strategies are used for updating the sub-models on the clients. In the first communication round of CD-FL, the SGD optimizer with a fixed learning rate of \(\eta = 0.01\) is used to update the sub-model in each client, and 20 epochs are conducted. In the remaining communication rounds of CD-FL, the Adam optimizer [20] with a fixed learning rate of \(\eta = 0.001\) is used to update the sub-model in each client. The batch size is fixed to 32 when optimizing the models on the clients.
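This two-stage schedule can be summarized by the sketch below; the helper name is hypothetical, and only the optimizer types, learning rates, epoch counts, and batch size are taken from the description above.

```python
import tensorflow as tf

def client_training_config(round_t, T0=20):
    """Return the local optimizer and number of epochs for a given round:
    SGD (lr=0.01) for T0 epochs in the first round, Adam (lr=0.001) for
    one epoch afterwards. The batch size is fixed to 32 for all clients."""
    if round_t == 1:
        return tf.keras.optimizers.SGD(learning_rate=0.01), T0
    return tf.keras.optimizers.Adam(learning_rate=0.001), 1

BATCH_SIZE = 32
optimizer, epochs = client_training_config(round_t=1)
```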

Stopping criterion

All experimental results are recorded as top-1 accuracy. Note that the results are obtained when there is no change for ten consecutive rounds.

Sensitivity analysis on hyperparameters

There are two important hyperparameters in CD-FL, i.e., the number of sub-models K and the number of epochs \(T_0\) used for training the local models at the first communication round. In this section, empirical studies are performed to determine the values of these hyperparameters.

Table 4 Classification accuracy (%) obtained by CD-FL and the comparison approaches when there are 20 and 50 workers
Table 5 Communication traffic (GB) obtained by CD-FL and the comparison approaches when there are 20 and 50 workers
Table 6 Classification accuracy (%) obtained by CD-FL and the comparison approaches when there are 20 and 50 workers under poisoning attack

The number of sub-models K

The global model is composed of a number of sub-models with the same structure. In this section, we explore the influence of the number of sub-models, the hyperparameter K, on the performance of the proposed method. Figure 2 and Table 2 give the performance of the method with different numbers of sub-models when there are 20 and 50 clients. From Fig. 2 and Table 2, we can see that when the global model is composed of five sub-models, i.e., \(K=5\), the proposed method achieves the best performance on the CIFAR-10, Fashion-MNIST, EMNIST, and Tiny-ImageNet datasets for both 20 and 50 clients.

The number of epochs \(T_0\) used for training the local models at the first communication round

To balance convergence speed and accuracy when training the model, \(T_0>1\) epochs are conducted at the first communication round, and \(T_0 = 1\) is used in the other rounds. Figure 3 and Table 3 show the performance of the method with 1, 20, 40, 60, 80 and 100 epochs at the first communication round. From Fig. 3 and Table 3, we can see that the proposed CD-FL achieves better performance when \(T_0\) is set to 20 for the first communication round.

Performance comparison

Figures 4, 5, 6 and 7 plot the accuracy achieved by the seven comparison algorithms and CD-FL on the CIFAR-10, Fashion-MNIST, EMNIST and Tiny-ImageNet datasets, respectively, for different numbers of clients. From Fig. 4, we can see that our proposed CD-FL approach achieves better performance on the CIFAR-10 dataset for all numbers of participating clients. The performance of CD-FL is not better than the other approaches under five clients on the Fashion-MNIST, EMNIST, and Tiny-ImageNet datasets. However, as the number of clients increases, the performance of CD-FL improves and surpasses those of the others. Figure 8 presents the trends of CD-FL and the seven other approaches on the non-IID datasets. From Fig. 8, we can see that, except on the CIFAR-10 dataset, our proposed CD-FL method achieves results competitive with the other seven methods on the non-IID datasets.

The experimental results obtained by the seven compared methods and CD-FL are also summarized in Table 4. Note that the results are obtained when there is no change for ten consecutive rounds. From Table 4, we can see that our proposed method achieves the best performance compared to the other seven approaches. Furthermore, the corresponding communication traffic of each approach is given in Table 5. From Table 5, we can clearly see that the communication traffic of CD-FL is much lower than that of the other approaches.

Analysis on the security

In this section, we validate the security of CD-FL against poisoning attacks. Experiments are conducted for the cases in which 10%, 50%, and 90% of the clients launch poisoning attacks. The experimental results are shown in Table 6. From Table 6, we can see that when there are 90% attackers, no algorithm is able to converge. When there are 10% and 50% attackers, the performance of our proposed CD-FL method is much better than that of the other algorithms on almost all datasets.

Conclusions and future work

A composition–decomposition based federated learning approach is proposed in this paper to train a good global model with less communication traffic. The privacy of client data is protected implicitly, because the parameters of the global model are not directly updated from the models uploaded by the clients. The experimental results on the Fashion-MNIST, CIFAR-10, EMNIST, and Tiny-ImageNet datasets show the good performance of our proposed method. However, the model performance on non-IID datasets needs to be further improved. Furthermore, heterogeneous datasets are not considered in this paper. Therefore, vertical federated learning will be studied in the future.