shashankag14 / Federated-Learning

A PyTorch implementation of Federated Learning from scratch based on the paper "Communication-Efficient Learning of Deep Networks from Decentralized Data"

Federated Learning - Your Data Stays With You!

A PyTorch implementation of Federated Learning from scratch, partially based on the paper Communication-Efficient Learning of Deep Networks from Decentralized Data. It has been implemented using the MNIST dataset.

Federated learning (FL) is an approach in which each device downloads the current global model and computes an updated model locally (as in edge computing) using its own data. These locally trained models are then sent from the devices back to the central server, where they are aggregated (e.g. by averaging their weights) into a single consolidated and improved global model, which is then sent back to the devices.
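
The weight-averaging step can be sketched in a few lines of PyTorch. This is only an illustrative sketch and does not reproduce this repository's actual code; the function and variable names below are hypothetical:

import copy
import torch

def average_state_dicts(client_states):
    """FedAvg-style aggregation: element-wise average of the clients' model weights."""
    avg_state = copy.deepcopy(client_states[0])
    for key in avg_state:
        for state in client_states[1:]:
            avg_state[key] = avg_state[key] + state[key]
        avg_state[key] = avg_state[key] / len(client_states)
    return avg_state

# The server would then load the averaged weights back into the global model, e.g.:
# global_model.load_state_dict(average_state_dicts([m.state_dict() for m in client_models]))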

Getting started

  • To install the required libraries, run the following script:

sh requirements.sh

  • Run the following command to train using Federated Learning:
python3 run_federated.py [-h] [--data_dir DATA_DIR] [--batch_size BATCH_SIZE]
                        [--epoch EPOCH] [--global_epoch GLOBAL_EPOCH]
                        [--local_epoch LOCAL_EPOCH] [--init_lr INIT_LR]
                        [--num_clients NUM_CLIENTS]
                        [--num_select_clients NUM_SELECT_CLIENTS]
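
For example, a federated run using the values from the hyperparameter table below could be started like this (the flag values here are illustrative, not defaults read from the code):
python3 run_federated.py --batch_size 100 --global_epoch 5 --local_epoch 5 \
                         --init_lr 5e-5 --num_clients 8 --num_select_clients 4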

  • Run the following command to train without Federated Learning (for reference):
python3 run_baseline.py [-h] [--data_dir DATA_DIR] [--batch_size BATCH_SIZE]
                        [--epoch EPOCH] [--global_epoch GLOBAL_EPOCH]
                        [--local_epoch LOCAL_EPOCH] [--init_lr INIT_LR]
                        [--num_clients NUM_CLIENTS]
                        [--num_select_clients NUM_SELECT_CLIENTS]

Note: The test phase starts automatically as soon as training finishes.

Hyperparams

Parameter | Description | Value used
--epoch | Number of epochs for baseline training | 15
--batch_size | Batch size | 100
--global_epoch | [ONLY FOR FED_LEARNING] Number of global epochs (updates to the server) | 5
--local_epoch | [ONLY FOR FED_LEARNING] Number of epochs each client trains per global epoch | 5
--init_lr | Initial learning rate | 5e-5
--num_clients | Total number of clients | 8
--num_select_clients | Number of randomly selected clients for local training | 4
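
To make these flags concrete, the federated loop can be pictured roughly as below. This is a simplified sketch rather than the repository's actual code; local_train is a hypothetical helper and average_state_dicts refers to the averaging sketch shown earlier:

import copy
import random
import torch

def local_train(model, data_loader, epochs, lr):
    """One client's local update: a few epochs of SGD on its own data (--local_epoch, --init_lr)."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for images, labels in data_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
    return model.state_dict()

def federated_training(global_model, client_loaders, global_epoch, local_epoch,
                       init_lr, num_select_clients):
    """Rough outline of the federated schedule driven by the flags above."""
    for _ in range(global_epoch):                                     # --global_epoch server rounds
        selected = random.sample(client_loaders, num_select_clients)  # subset of --num_clients
        client_states = [
            local_train(copy.deepcopy(global_model), loader, local_epoch, init_lr)
            for loader in selected
        ]
        # Aggregate the locally trained weights into the global model
        global_model.load_state_dict(average_state_dicts(client_states))
    return global_model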

Results of Federated Learning

  • Test Accuracy = 98.5%
  • Test Loss = 0.048
