Federated learning is an emerging approach that is becoming increasingly important because it solves several problems that many machine learning applications face today. Most of these applications require a centralized dataset, which is usually built by sending data created on a client to a remote server. This is critical with regard to data privacy as well as data handling, since a big data infrastructure is needed to process the huge amounts of data used to train the models. In addition, centralizing the data creates the risk of data breaches, which might compromise users’ privacy.
Federated learning takes a different approach and can be used in conjunction with traditional AI applications such as image classification, recommender systems, or natural language processing. It brings the model to the data instead of sending the data to a centralized model. For example, it can help improve the image recognition of a smartphone app that identifies animals. This works as follows:
1. Each smartphone trains a model on its local dataset.
2. The model (its parameters, not the data) is sent to the server.
3. The server creates a global model by aggregating all local models.
4. The new global model is sent back to all smartphones.
5. Each smartphone receives the updated global model and uses it as the starting point for the next round of local training.
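The five steps can be sketched as a tiny simulation. This is a toy under stated assumptions, not Flower code: each "smartphone" is a hypothetical client holding private samples of y = 3x plus noise, the "model" is a single weight fitted by gradient descent, and the server aggregates by averaging the local weights weighted by local dataset size (the Federated Averaging rule). The helper names `make_client_data`, `local_train`, and `run_rounds` are illustrative only.

```python
import numpy as np

def make_client_data(rng, n):
    """Hypothetical private data on one device: samples of y = 3x + noise."""
    x = rng.uniform(-1.0, 1.0, size=n)
    y = 3.0 * x + rng.normal(0.0, 0.1, size=n)
    return x, y

def local_train(w, x, y, lr=0.5, epochs=5):
    """Step 1: train locally, starting from the global weight received from the server."""
    for _ in range(epochs):
        grad = 2.0 * np.mean((w * x - y) * x)  # derivative of the mean squared error w.r.t. w
        w = w - lr * grad
    return float(w)

def run_rounds(clients, num_rounds=10, w_global=0.0):
    for _ in range(num_rounds):
        # Steps 1-2: each client trains on its own data and sends back only its weight
        local_ws = [local_train(w_global, x, y) for x, y in clients]
        sizes = [len(x) for x, _ in clients]
        # Step 3: the server aggregates, weighting each client by its dataset size
        w_global = float(np.average(local_ws, weights=sizes))
        # Steps 4-5: the new global weight is the starting point of the next round
    return w_global

rng = np.random.default_rng(0)
clients = [make_client_data(rng, n) for n in (50, 200, 120)]
print(run_rounds(clients))  # a value close to the true slope 3.0
```

Note that the raw `(x, y)` pairs never leave `local_train`; only one float per client reaches the aggregation step.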
The updated global model typically performs better because it incorporates what was learned on the other smartphones. As a result, the image recognition app improves as the number of users grows, while users’ privacy is maintained. Federated learning provides data privacy by design: the raw data is never shared with the server or other devices. It stays on the phone and never leaves it for the purpose of training a model.
Federated learning also does not require a big data infrastructure, since the data is never collected in a single place or on a single server. Only the model parameters, i.e. its weights and biases, are shared. For a compact model, such as a small Keras Sequential model, the amount of data sent to the server is usually small, which keeps latency and bandwidth requirements low. These advantages make federated learning better suited for many industrial AI use cases than traditional approaches in which all data is collected in a single place, such as a cloud.
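To put the bandwidth argument into numbers, the uncompressed upload per round is simply the combined size of all weight arrays. The sketch below assumes float32 parameters, no compression, and ignores protocol overhead; the 784-128-10 dense architecture is a hypothetical example, not the MobileNetV2 model used later. For a Keras model, the same hypothetical `payload_bytes` helper can be applied to the list of NumPy arrays returned by `model.get_weights()`.

```python
import numpy as np

def payload_bytes(weights):
    """Bytes needed to send a list of weight arrays uncompressed (sum of array.nbytes)."""
    return sum(w.nbytes for w in weights)

# Hypothetical small dense model: 784 inputs -> 128 hidden units -> 10 classes (float32)
weights = [
    np.zeros((784, 128), dtype=np.float32),  # kernel of layer 1
    np.zeros(128, dtype=np.float32),         # bias of layer 1
    np.zeros((128, 10), dtype=np.float32),   # kernel of layer 2
    np.zeros(10, dtype=np.float32),          # bias of layer 2
]
num_params = sum(w.size for w in weights)
print(num_params)              # 101770
print(payload_bytes(weights))  # 407080 bytes, i.e. about 0.4 MB per upload
```

Larger architectures scale this number linearly with their parameter count, which is why model size matters for the bandwidth budget.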
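However, federated learning is a complex topic, and the infrastructure is not trivial to build. Compared to traditional ML training, it introduces additional hyperparameters, such as the number of training rounds, the number of clients that participate in each round, and the number of local epochs per round, all of which need to be selected carefully.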
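These additional hyperparameters can be made concrete with a small configuration object. This is a hypothetical sketch: `FederatedConfig` and its methods are illustrative names, loosely modelled on the `fraction_fit` and `min_fit_clients` parameters of Flower’s FedAvg strategy, and the default values are assumptions rather than recommendations.

```python
from dataclasses import dataclass

@dataclass
class FederatedConfig:
    """Hypothetical container for hyperparameters that plain ML training does not have."""
    num_rounds: int = 3        # how many aggregate-and-redistribute cycles to run
    fraction_fit: float = 0.1  # share of available clients sampled for training per round
    min_fit_clients: int = 2   # lower bound on clients trained per round
    local_epochs: int = 1      # epochs each client trains before reporting back

    def clients_per_round(self, num_available: int) -> int:
        """Number of clients to sample for training in one round."""
        sampled = max(int(num_available * self.fraction_fit), self.min_fit_clients)
        return min(sampled, num_available)  # cannot sample more clients than exist

    def max_local_epochs(self) -> int:
        """Upper bound on local epochs per client (if it is sampled in every round)."""
        return self.num_rounds * self.local_epochs

cfg = FederatedConfig()
print(cfg.clients_per_round(100))  # 10
print(cfg.clients_per_round(2))    # 2
```

With only two clients, as in the example later in this article, the minimum dominates and both clients train in every round.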
At the time of writing, only a few federated learning frameworks are available. The best known are Flower, PySyft, TensorFlow Federated, and PaddleFL.
To show how easily a federated learning system can be built, we will use the federated learning framework Flower. It is one of the more popular frameworks in this field and takes a very straightforward approach. The final setup consists of a single server and two clients that do the training.
Let’s start coding. This example uses TensorFlow to compile a MobileNetV2 model and trains it on the CIFAR-10 dataset. Before we start, we need to install the required libraries with the following command: `pip install flwr==0.15.0 tensorflow==2.4.1`.
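To confirm that the pinned versions are actually active in the current environment, a small standard-library check can help. This is a hypothetical helper (`check_versions` is not part of Flower or TensorFlow); it assumes Python 3.8 or newer for `importlib.metadata`, and the keys are pip distribution names.

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

def check_versions(pinned):
    """Return {package: installed_version_or_None} for every pin that is not satisfied."""
    mismatches = {}
    for package, wanted in pinned.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[package] = installed
    return mismatches

problems = check_versions({"flwr": "0.15.0", "tensorflow": "2.4.1"})
print(problems or "all pinned versions installed")
```

An empty result means both pins match; otherwise the dictionary shows what is installed instead (or `None` if the package is missing).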
We start by writing the client-specific code. Create a new file called `client.py`. First, import the required libraries `flwr` and `tensorflow`. Then create and compile the model, followed by the code that loads the data:
```python
import flwr as fl
import tensorflow as tf

# Load and compile Keras model
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
```
This should look familiar to anyone with prior experience in TensorFlow or Keras. Next, we build a Flower client called `CifarClient`, which is derived from Flower’s convenience class `NumPyClient`. The abstract base class `NumPyClient` defines three methods that clients need to override: `get_parameters`, `fit`, and `evaluate`. These methods allow Flower to trigger training and evaluation of the previously defined Keras model:
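```python
# Define Flower client
class CifarClient(fl.client.NumPyClient):
    def get_parameters(self):
        # Return the current local model weights as a list of NumPy arrays
        return model.get_weights()

    def fit(self, parameters, config):
        # Load the weights received from the server, train on the local data,
        # and send back the updated weights and the number of training examples
        model.set_weights(parameters)
        model.fit(x_train, y_train, epochs=1, batch_size=32, steps_per_epoch=3)
        return model.get_weights(), len(x_train), {}

    def evaluate(self, parameters, config):
        # Evaluate the weights received from the server on the local test data
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(x_test, y_test)
        return loss, len(x_test), {"accuracy": accuracy}
```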
Flower’s `NumPyClient.fit` method receives weights from the server, updates the model with those weights, trains the model on the locally held dataset (`x_train`, `y_train`), and then returns the updated weights (via `model.get_weights`), the number of training examples, and an empty dictionary of metrics. Note that you can do a quick “dry run” by passing `steps_per_epoch=3` to `model.fit`; this processes only three batches per epoch instead of the entire dataset. Remove `steps_per_epoch=3` to train on the full dataset (this will take longer).
The `evaluate` method works similarly, but it uses the provided weights to evaluate the model on the locally held test set (`x_test`, `y_test`) and returns the loss, the number of test examples, and the accuracy. The last step is to create an instance of `CifarClient` and run it:
```python
# Start Flower client
fl.client.start_numpy_client(server_address="[::]:8080", client=CifarClient())
```
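To make the contract of the three methods concrete without TensorFlow or a running server, here is a hypothetical stand-in. `ToyLinearClient` is a plain Python class (it does not subclass Flower’s `NumPyClient`) whose methods have the same shape as `CifarClient`: `get_parameters` returns a list of NumPy arrays, `fit` returns the updated arrays, the number of training examples, and a metrics dict, and `evaluate` returns a loss, the number of test examples, and a metrics dict. The "model" is a single linear weight trained with gradient descent, chosen purely for illustration.

```python
import numpy as np

class ToyLinearClient:
    """Hypothetical stand-in mirroring CifarClient's method signatures and return shapes."""

    def __init__(self, x_train, y_train, x_test, y_test, lr=0.5, steps=50):
        self.w = np.zeros(1)  # the "model": a single weight, y ~ w * x
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test
        self.lr, self.steps = lr, steps

    def get_parameters(self):
        return [self.w.copy()]

    def fit(self, parameters, config):
        self.w = parameters[0].copy()  # load the global weights sent by the server
        for _ in range(self.steps):    # local gradient descent on the mean squared error
            err = self.w[0] * self.x_train - self.y_train
            self.w[0] -= self.lr * 2.0 * np.mean(err * self.x_train)
        return self.get_parameters(), len(self.x_train), {}  # config is ignored here

    def evaluate(self, parameters, config):
        w = parameters[0][0]  # evaluate the weights provided by the server
        loss = float(np.mean((w * self.x_test - self.y_test) ** 2))
        return loss, len(self.x_test), {"mse": loss}

rng = np.random.default_rng(1)
x = rng.uniform(-1.0, 1.0, size=100)
y = 2.0 * x  # noise-free data with true weight 2.0
client = ToyLinearClient(x[:80], y[:80], x[80:], y[80:])
params, num_train, _ = client.fit(client.get_parameters(), config={})
loss, num_test, metrics = client.evaluate(params, config={})
print(num_train, num_test)  # 80 20
```

Calling `fit` and `evaluate` by hand like this mimics what the server does remotely in each round.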
That’s it for the client. We create the model, load the data, implement a subclass of `NumPyClient`, and start the client. Let’s build the server script next.
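One caveat about the client: as written, every client loads the complete CIFAR-10 dataset, so both clients train on identical data. To mimic devices that hold distinct local data, each client could keep only its own shard. The helper below is a hypothetical sketch (`partition`, `client_id`, and `num_clients` are illustrative names, not Flower API) using plain NumPy slicing; in `client.py`, one might call it right after `load_data()` with a client ID passed on the command line.

```python
import numpy as np

def partition(x, y, client_id, num_clients):
    """Return the client_id-th of num_clients contiguous, (nearly) equal shards of (x, y)."""
    if not 0 <= client_id < num_clients:
        raise ValueError("client_id must be in [0, num_clients)")
    bounds = np.linspace(0, len(x), num_clients + 1).astype(int)
    start, end = bounds[client_id], bounds[client_id + 1]
    return x[start:end], y[start:end]

# Stand-in arrays with CIFAR-10-like shapes; in client.py you would pass x_train, y_train
x = np.zeros((10, 32, 32, 3), dtype=np.uint8)
y = np.arange(10).reshape(-1, 1)
x0, y0 = partition(x, y, client_id=0, num_clients=2)
print(x0.shape, y0.ravel())  # (5, 32, 32, 3) [0 1 2 3 4]
```

Because `fit` reports `len(x_train)`, the server would then weight each client by the size of its shard.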
In a new script called `server.py`, we add the following two lines to start a Flower server that performs three rounds of Federated Averaging (FedAvg), which computes a weighted average of the model parameters received from the clients:
```python
import flwr as fl

# Start a Flower server that runs three rounds of federated learning
fl.server.start_server(config={"num_rounds": 3})
```
That’s it!
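For reference, the Federated Averaging rule can be written out. Here w^k_{t+1} denotes the weights that client k returns from `fit` in round t, n_k the number of training examples it reports alongside them (`len(x_train)` in `CifarClient`), and K the number of participating clients; this notation is ours, not Flower’s.

```latex
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w^{k}_{t+1}, \qquad n = \sum_{k=1}^{K} n_k

% This tutorial: both clients load all 50,000 CIFAR-10 training images, so n_1 = n_2
w_{t+1} = \tfrac{1}{2}\left(w^{1}_{t+1} + w^{2}_{t+1}\right)

% Unequal data, one scalar weight: n_1 = 30{,}000,\ w^{1} = 0.2;\ n_2 = 10{,}000,\ w^{2} = 0.6
w = \frac{30{,}000 \cdot 0.2 + 10{,}000 \cdot 0.6}{40{,}000} = \frac{12{,}000}{40{,}000} = 0.3
```

The client with more data pulls the average toward its own weight, whereas an unweighted mean would give 0.4.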
First, we start the server:
```shell
$ python server.py
```
Next, we open a new terminal and start the first client:
```shell
$ python client.py
```
Finally, we open another new terminal and start the second client:
```shell
$ python client.py
```
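Opening three terminals by hand works, but the same start order can also be scripted. The following is a hypothetical convenience launcher (`launch` is an illustrative name) that uses only Python’s standard `subprocess` and `time` modules: it starts the server command, waits a fixed delay so the server can bind its port, starts the client commands, and waits for all of them to exit. The fixed delay is an assumed heuristic, not a real readiness check.

```python
import subprocess
import sys
import time

def launch(server_cmd, client_cmds, startup_delay=3.0):
    """Start the server, wait, start the clients, and return (server_code, [client_codes])."""
    server = subprocess.Popen(server_cmd)
    time.sleep(startup_delay)  # crude: assume the server is listening after this delay
    clients = [subprocess.Popen(cmd) for cmd in client_cmds]
    client_codes = [c.wait() for c in clients]
    if any(code != 0 for code in client_codes):
        server.terminate()  # don't leave the server waiting for clients that crashed
    server_code = server.wait()  # otherwise the server exits after its configured rounds
    return server_code, client_codes
```

For this article’s setup, the call would look like `launch([sys.executable, "server.py"], [[sys.executable, "client.py"]] * 2)`; using `sys.executable` ensures all three processes run in the same Python environment.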
Terminal 2 or 3 (either of the ones running `client.py`) should then show the training and evaluation output. With three rounds of federated learning, the accuracy improves to about 46% on the training set and 28% on the test set (when training on the full dataset, i.e. without `steps_per_epoch=3`). There is clearly a lot of room for improvement, for example, by running more rounds of federated learning and by tuning hyperparameters.
Congratulations, you have built a working federated learning system in fewer than 20 lines of code!
The full source code can be found here.
Federated learning is becoming increasingly relevant as machine learning is adopted in industry. It provides data privacy by design and offers various other advantages, such as low bandwidth requirements, high availability, and reduced use of cloud resources. Several frameworks are already available for implementing federated learning workloads. As shown above, a simple federated learning setup can be built with Flower and Keras and is manageable even for beginners.
The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.