Federated Learning using IBM FL
A Python framework for federated learning (FL) in an enterprise environment.
Modern mobile devices have access to a wealth of data suitable for learning models, which in turn can greatly improve the user experience on the device. For example, language models can improve speech recognition and text entry, and image models can automatically select good photos. Standard machine learning approaches require centralizing the training data on one machine or in a datacenter. However, this rich data is often privacy-sensitive, large in quantity, or both, which makes centralizing it difficult.
Federated Learning [1] enables devices to collaboratively learn a shared prediction model while keeping all the training data on the device, which removes the need to centralize the data. Because the raw data never leaves the device, Federated Learning allows for smarter models, lower latency, and less power consumption while improving privacy.
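How devices "collaboratively learn a shared model" can be made concrete with the Federated Averaging (FedAvg) algorithm from [1]. In round t, each participating party k starts from the current global weights w_t, trains locally on its own n_k samples to obtain w_{t+1}^k, and sends only those weights back. The server combines them as a sample-weighted average (the paper also samples a fraction of clients per round; here all K parties are assumed to participate). The simple average fusion used later in this article gives every party equal weight instead, which coincides with FedAvg when all parties hold the same number of samples:

```latex
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^{k}, \qquad n = \sum_{k=1}^{K} n_k
\qquad \text{(FedAvg, sample-weighted)}

w_{t+1} = \frac{1}{K} \sum_{k=1}^{K} w_{t+1}^{k}
\qquad \text{(simple average)}
```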
Federated Learning is best suited to problems with the following properties:
- Training on real-world data from mobile devices provides a distinct advantage over training on proxy data that is generally available in the data center.
- The data is privacy-sensitive or large in size (compared to the size of the model), so it is preferable not to log it to the data center purely for model training.
- For supervised tasks, labels on the data can be inferred naturally from user interaction.
IBM Federated Learning
IBM Federated Learning [2] is a Python framework for federated learning (FL) in an enterprise environment, where each participant node (or party) retains its data locally and interacts with the other participants via a learning protocol. IBM FL provides a basic fabric for FL to which advanced features can be added. It does not depend on any specific machine learning framework and supports different learning topologies (e.g., a shared aggregator) and protocols. It supports Deep Neural Networks (DNNs) as well as classic machine learning techniques such as linear regression and k-means [3], covering supervised and unsupervised approaches as well as reinforcement learning.
Aggregator and Party Configuration
IBM FL is built on two abstractions, an Aggregator and Parties (Fig2), which let a group of parties collaboratively learn a machine learning model (usually referred to as a global model) without revealing their raw data to each other. For most federated learning algorithms, only model updates, e.g., model weights, are shared with the aggregator. The aggregator then learns a global model by aggregating the information provided by the parties. We now introduce these two abstractions in IBM FL.
- The aggregator is in charge of running the fusion algorithm. A fusion algorithm queries the registered parties to carry out the federated learning training process; the queries sent vary according to the model or algorithm type. In return, parties send their replies as model update objects, which are then aggregated by the fusion algorithm, specified via a Fusion Handler class.
- Each party holds its own dataset, which it keeps to itself and uses to answer queries received from the aggregator. Because parties may store data in different formats, IBM FL offers an abstraction called Data Handler, which allows custom implementations for retrieving the data at each participating party. A Local Training Handler sits at each party to control the local training performed on the party side.
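The two roles can be illustrated with a toy, in-process sketch. All class names below are hypothetical stand-ins chosen to mirror the roles just described (data handler, local training handler, fusion handler); they are not IBM FL's actual classes, there is no networking, and the "model" is a single weight w in y = w * x. The aggregator queries each registered party, each party trains on data only it can see and replies with its updated weight, and the fusion handler takes a simple average of the replies.

```python
"""In-process toy of the aggregator/party split; every class name is hypothetical."""
from typing import List, Tuple

Sample = Tuple[float, float]  # an (x, y) pair


class ToyDataHandler:
    """Hides how a party stores its data (here: an in-memory list)."""

    def __init__(self, samples: List[Sample]) -> None:
        self._samples = samples

    def get_data(self) -> List[Sample]:
        return self._samples


class ToyLocalTrainingHandler:
    """Runs local gradient descent for the one-parameter model y = w * x."""

    def __init__(self, data_handler: ToyDataHandler, lr: float = 0.01) -> None:
        self.data_handler = data_handler
        self.lr = lr

    def train(self, w: float, epochs: int) -> float:
        data = self.data_handler.get_data()
        for _ in range(epochs):
            # Gradient of the mean squared error of (w * x - y) with respect to w.
            grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
            w -= self.lr * grad
        return w  # only the updated weight leaves the party, never the data


class ToyParty:
    def __init__(self, training_handler: ToyLocalTrainingHandler) -> None:
        self.training_handler = training_handler

    def handle_train_query(self, w: float, epochs: int) -> float:
        """Answer the aggregator's training query with a model update."""
        return self.training_handler.train(w, epochs)


class ToyAvgFusionHandler:
    """Simple (unweighted) average of the parties' model updates."""

    def fuse(self, updates: List[float]) -> float:
        return sum(updates) / len(updates)


class ToyAggregator:
    def __init__(self, fusion: ToyAvgFusionHandler) -> None:
        self.fusion = fusion
        self.parties: List[ToyParty] = []

    def register(self, party: ToyParty) -> None:
        self.parties.append(party)

    def run(self, rounds: int, local_epochs: int, w: float = 0.0) -> float:
        for _ in range(rounds):
            # Query every registered party, then fuse the replies into the global model.
            updates = [p.handle_train_query(w, local_epochs) for p in self.parties]
            w = self.fusion.fuse(updates)
        return w


def make_party(samples: List[Sample]) -> ToyParty:
    return ToyParty(ToyLocalTrainingHandler(ToyDataHandler(samples)))


# Usage: two parties whose private data both follow y = 3x.
aggregator = ToyAggregator(ToyAvgFusionHandler())
aggregator.register(make_party([(1.0, 3.0), (2.0, 6.0)]))
aggregator.register(make_party([(3.0, 9.0), (4.0, 12.0)]))
w = aggregator.run(rounds=3, local_epochs=3)  # moves from 0.0 toward 3.0
```

With three global rounds of three local epochs, w moves from 0.0 to roughly 2.2, and more rounds bring it closer to 3.0. Replacing ToyAvgFusionHandler with any class that has the same fuse method changes the aggregation without touching the parties, which is the kind of modularity the handler abstractions are meant to provide.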
The aggregator and parties communicate through Flask-based servers. All messages and communication are handled by the protocol handler and connection modules. IBM FL is designed so that modules can be swapped in and out without interfering with the functionality of other components. Users configure these components through configuration files (.yml files) for the aggregator and for each party. The following building blocks may appear in the configuration files.
- Connection: Connection type for communication between the aggregator and each registered party. It includes the information needed to initiate the connection, in particular the Flask server information, the Flask connection location, and a synchronization-mode flag (sync) for the training phase.
- Data: Data preparation and pre-processing. It includes the information needed to instantiate a data handler class, in particular a data path and the data handler's class name and location.
- Fusion: Fusion algorithm to be performed at the aggregator. It includes the information needed to instantiate the fusion algorithm on the aggregator side, in particular a fusion handler class name and location.
- Local_training: Local training algorithm run on the party side. It includes the information needed to instantiate a local training handler on the party side, in particular a local training handler class name and location.
- Hyperparams: Global training and local training hyperparameters (for example, the number of global rounds and local epochs).
- Model: Definition of the machine learning model to be trained. It includes the information needed to instantiate a machine learning model and its corresponding functionality, such as train, predict, evaluate, and save. It also includes a model class name and location, and a model specification path.
- Protocol_handler: Handling of the protocol for message exchange. It includes the information needed to instantiate a protocol handler, in particular a protocol handler class name and location.
- Aggregator: Aggregator server information (only in the parties' configuration files).
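Several of these blocks identify a component by a class name and a location. The sketch below shows how such a pair could be resolved to a class at runtime with Python's standard importlib. The "name"/"path" keys are an assumption about the configuration layout, and load_component is a hypothetical helper, not IBM FL's own loading code. Resolving classes by name is what lets a configuration file swap one handler for another without code changes.

```python
import importlib
from typing import Any, Mapping


def load_component(spec: Mapping[str, Any]) -> Any:
    """Resolve a {"name": ..., "path": ...} config entry to the named class.

    "path" is a dotted module path and "name" an attribute of that module
    (an assumed layout, used here for illustration only).
    """
    module = importlib.import_module(spec["path"])
    try:
        return getattr(module, spec["name"])
    except AttributeError as err:
        raise ImportError(
            f"module {spec['path']!r} has no attribute {spec['name']!r}"
        ) from err


# Usage with a standard-library class standing in for a handler class.
cls = load_component({"name": "OrderedDict", "path": "collections"})
```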
We will use IBM FL to have multiple parties train a classifier [4] that recognizes handwritten digits from the tens of thousands of images in the MNIST dataset [5].
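Splitting the data among the parties, the first of the steps below, can be sketched independently of IBM FL. The function below is a hypothetical helper (not part of IBM FL) that shuffles sample indices with a fixed seed and deals them out round-robin, so each party gets a disjoint shard and shard sizes differ by at most one. For MNIST, each element would be an (image, label) pair, and each shard would then be saved wherever that party's data handler expects to find it.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_among_parties(samples: Sequence[T], num_parties: int, seed: int = 0) -> List[List[T]]:
    """Shuffle the sample order once, then deal samples round-robin.

    Each party receives a disjoint shard; shard sizes differ by at most one.
    """
    if num_parties < 1:
        raise ValueError("num_parties must be at least 1")
    order = list(range(len(samples)))
    random.Random(seed).shuffle(order)  # fixed seed keeps the split reproducible
    shards: List[List[T]] = [[] for _ in range(num_parties)]
    for position, index in enumerate(order):
        shards[position % num_parties].append(samples[index])
    return shards


# Usage: ten stand-in samples split between two parties.
shards = split_among_parties(list(range(10)), num_parties=2)
print([len(s) for s in shards])  # [5, 5]
```

An even random split like this gives every party a similar label distribution; real federated data is usually partitioned by ownership and can be far less balanced.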
Steps involved
1. Set the number of parties that will participate in the federated learning run, and split the data among them.
2. Prepare a dataset for each participating party. We assume that the parties and the aggregator run on the same machine; to run them on different machines, assign samples to each party locally.
3. Define the model specification and create configuration files for the aggregator and the parties.
The YAML file in Fig3 is an example of the aggregator's configuration file. In this example, Flask is used as the connection type, and the simple average fusion algorithm (IterAvgFusionHandler in the fusion section) is used to train the model. The number of global training rounds is set to 3, and each global round triggers the parties to perform 3 local training epochs. The aggregator uses the default protocol. Note that the aggregator does not specify a data file (data section) or maintain a global model (model section); during the federated learning process, it only keeps track of the current model parameters, i.e., the current weights of the neural network. However, the aggregator can also hold data for testing purposes and maintain a global model; in that case, add data and model sections to its configuration file.
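A configuration with the settings just described could look like the sketch below, written as a Python dictionary that mirrors the YAML structure (dumping it with a YAML library would give a file of the same shape). The settings named in the text (Flask connection, IterAvgFusionHandler, 3 global rounds, 3 local epochs, 2 parties, no data or model section) come from this article. The nested key names, module paths, IP address, port, and sync value are assumptions modeled on IBM FL's example configurations and should be checked against the repository [2].

```python
# Sketch of an aggregator configuration mirroring the YAML layout described above.
# Section names come from the building blocks listed earlier; nested keys, module
# paths, the IP address, the port, and the sync value are assumptions.
agg_config = {
    "connection": {
        "name": "FlaskConnection",                    # assumed class name
        "path": "ibmfl.connection.flask_connection",  # assumed module path
        "info": {"ip": "127.0.0.1", "port": 5000},    # hypothetical address
        "sync": False,                                # assumed value
    },
    "fusion": {
        "name": "IterAvgFusionHandler",               # simple average fusion
        "path": "ibmfl.aggregator.fusion.iter_avg_fusion_handler",  # assumed
    },
    "hyperparams": {
        "global": {"rounds": 3, "num_parties": 2},    # 3 global training rounds
        "local": {"training": {"epochs": 3}},         # 3 local epochs per round
    },
    "protocol_handler": {
        "name": "ProtoHandler",                       # assumed default protocol handler
        "path": "ibmfl.aggregator.protohandler.proto_handler",  # assumed
    },
    # No "data" or "model" section: this aggregator only tracks the current weights.
}
```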
The YAML file in Fig4 is an example of a party's configuration file. Because a Flask connection is selected, the aggregator's server information is provided in the aggregator section. The party also specifies its data in the data section, and the model section gives the Keras model definition as an .h5 file. For the simple average fusion algorithm, select LocalTrainingHandler in the local_training section. The parties also use the default protocol.
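A party configuration matching the one described above could be sketched the same way, again as a Python dictionary mirroring the YAML layout. The Flask connection, the aggregator section, the data section, the .h5 Keras model definition, and LocalTrainingHandler come from the description; every other class name, module path, file name, address, and port is an assumption to be checked against the IBM FL repository [2].

```python
# Sketch of one party's configuration; section names follow the building blocks
# listed earlier. Nested keys, module paths, file names, addresses, and ports
# are assumptions.
party_config = {
    "aggregator": {"ip": "127.0.0.1", "port": 5000},  # where the aggregator listens (hypothetical)
    "connection": {
        "name": "FlaskConnection",                     # assumed class name
        "path": "ibmfl.connection.flask_connection",   # assumed module path
        "info": {"ip": "127.0.0.1", "port": 8085},     # this party's own server (hypothetical)
        "sync": False,                                 # assumed value
    },
    "data": {
        "name": "MnistKerasDataHandler",               # assumed data handler class
        "path": "ibmfl.util.data_handlers.mnist_keras_data_handler",  # assumed
        "info": {"npz_file": "data/party0.npz"},       # hypothetical shard file
    },
    "model": {
        "name": "KerasFLModel",                        # assumed model wrapper class
        "path": "ibmfl.model.keras_fl_model",          # assumed
        "spec": {"model_definition": "model.h5"},      # hypothetical Keras .h5 file
    },
    "local_training": {
        "name": "LocalTrainingHandler",                # for simple average fusion
        "path": "ibmfl.party.training.local_training_handler",  # assumed
    },
    "protocol_handler": {
        "name": "PartyProtocolHandler",                # assumed default party protocol
        "path": "ibmfl.party.party_protocol_handler",  # assumed
    },
}
```

Each party gets its own copy with a different server port and data file; on a single machine, a second party might use port 8086 and data/party1.npz (hypothetical values).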
4. Pass the aggregator's configuration to instantiate the Aggregator object, then start the aggregator.
5. Start and register the parties.
6. Initiate training from the aggregator. In this example, 2 parties join the training, which runs 3 global rounds with 3 local epochs each.
7. Terminate the aggregator and party processes, then exit.
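Steps 4 to 7 can be collected in one function that guarantees step 7 runs even if training fails. This is a sketch under assumptions: the method names start, register_party, start_training, and stop are modeled on the IBM FL example notebook [4] and should be verified against your installed version, and run_session is a hypothetical helper. It accepts any objects exposing those methods, so it can be tried with stand-ins. In practice the aggregator and each party usually run as separate processes, as step 7 implies, so treat this as the order of operations rather than a deployment recipe.

```python
from typing import List, Protocol, Sequence


class AggregatorLike(Protocol):
    """Assumed aggregator interface (method names modeled on IBM FL's examples)."""

    def start(self) -> None: ...
    def start_training(self) -> None: ...
    def stop(self) -> None: ...


class PartyLike(Protocol):
    """Assumed party interface (method names modeled on IBM FL's examples)."""

    def start(self) -> None: ...
    def register_party(self) -> None: ...
    def stop(self) -> None: ...


def run_session(aggregator: AggregatorLike, parties: Sequence[PartyLike]) -> None:
    """Run steps 4-7 in order and always shut down whatever was started."""
    aggregator.start()                          # step 4: start the aggregator
    started: List[PartyLike] = []
    try:
        for party in parties:                   # step 5: start and register each party
            party.start()
            started.append(party)
            party.register_party()
        aggregator.start_training()             # step 6: the global rounds run here
    finally:
        for party in started:                   # step 7: stop the parties, then the
            party.stop()                        # aggregator, even if a step above raised
        aggregator.stop()
```

For the two-party run in step 6, this would be called as run_session(aggregator, [party0, party1]).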
Summary
Applying Federated Learning requires machine learning practitioners to adopt new tools and a new way of thinking: model development, training, and evaluation with no direct access to or labeling of raw data, with communication cost as a limiting factor. The system needs to communicate and aggregate the model updates in a secure, efficient, scalable, and fault-tolerant way. It’s only the combination of research with this infrastructure that makes the benefits of Federated Learning possible.
This article was co-authored by Saichandra Pandraju and Sakthi Ganesh.
References
[1] “Communication-Efficient Learning of Deep Networks from Decentralized Data”, https://arxiv.org/pdf/1602.05629.pdf
[2] “IBM Federated Learning”, https://github.com/IBM/federated-learning-lib
[3] “IBM FL supported functionalities”, https://github.com/IBM/federated-learning-lib/blob/main/README.md#supported-functionality
[4] “MNIST classifier in IBM FL”, https://github.com/IBM/federated-learning-lib/tree/main/Notebooks/keras_classifier
[5] “THE MNIST DATABASE of handwritten digits”, http://yann.lecun.com/exdb/mnist/