Kthyeon / FedSup

[Official] ICML 2022 workshop: Dynamic Neural Networks, Oral Paper, "SUPERNET TRAINING FOR FEDERATED IMAGE CLASSIFICATION UNDER SYSTEM HETEROGENEITY"

[Official] Supernet Training for Federated Image Classification under System Heterogeneity

This repository is the official implementation of "Supernet Training for Federated Image Classification under System Heterogeneity".

  • Paper: ICML 2022 Workshop on Dynamic Neural Networks, Oral Paper (link)

This repository contains the code for the federated supernet training described in the paper (link).

Reference Codes

We refer to some official implementation codes

💻 Overview

This paper proposes the federation of supernet training (FedSup) framework to tackle data heterogeneity and system heterogeneity simultaneously: clients send and receive a supernet that contains all possible architectures sampled from itself. The approach is inspired by the observation that averaging parameters during model aggregation in FL is similar to weight sharing in supernet training. Thus, FedSup combines the weight-sharing approach widely used for training single-shot models with FL averaging (FedAvg). Furthermore, we develop an efficient algorithm (E-FedSup) that sends a sub-model, rather than the full supernet, to clients at the broadcast stage to reduce communication costs and training overhead, together with several strategies that enhance supernet training in the FL environment. We verify the proposed approach with extensive empirical evaluations. The resulting framework is also robust to data and model heterogeneity on several standard benchmarks.

[Figure: overview of (a) Standard FL, (b) Standard Supernet Training, (c) FedSup, and (d) E-FedSup]

  • (a) Standard FL: Despite FL's popularity, significant issues remain in delivering compact models specialized for edge devices with widely diverse hardware platforms and efficiency constraints.
  • (b) Standard Supernet Training: This approach has rarely been considered under data heterogeneity, which can provoke training instability.
  • (c) FedSup: It uses weight sharing in supernet training to forward the supernet to each local client, and it ensembles the training of sub-models sampled from that supernet at each client.
  • (d) E-FedSup: As an efficient version of FedSup, it broadcasts sub-models to local clients in lieu of full supernets.
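
The link between FedAvg and weight sharing can be illustrated with a small, dependency-free sketch. It is not the repository's aggregation code. The function name `aggregate_prefix` is hypothetical, and the sketch assumes that each sub-model's weights are a leading slice of the supernet's weights, so that overlapping positions are averaged across the clients that trained them.

```python
from typing import List


def aggregate_prefix(supernet: List[float], client_updates: List[List[float]]) -> List[float]:
    """Average client sub-model weights into a shared supernet vector.

    Assumption (for illustration only): each client trains the first
    len(update) entries of the supernet weight vector.
    """
    merged = list(supernet)
    for i in range(len(supernet)):
        # Collect every client value that covers position i.
        covering = [u[i] for u in client_updates if i < len(u)]
        if covering:
            # FedAvg-style unweighted mean over the clients that trained it.
            merged[i] = sum(covering) / len(covering)
        # Positions that no client trained keep their previous value.
    return merged


supernet = [0.0, 0.0, 0.0, 0.0]
updates = [
    [1.0, 1.0],       # small sub-model: first 2 entries
    [3.0, 3.0, 3.0],  # medium sub-model: first 3 entries
]
merged = aggregate_prefix(supernet, updates)
print(merged)  # [2.0, 2.0, 3.0, 0.0]
```

Positions shared by both sub-models are averaged, which is where parameter averaging and weight sharing coincide in this toy setting.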

🤔 Getting Started

Requirements

  • This codebase is written for Python 3 (Python 3.7.6 was used during development).
  • To install necessary python packages, run pip install -r requirements.txt.

⚠️ Setup Local Package (IMPORTANT)

Register the project directory as a Python package to allow for absolute imports.

python3 setup.py develop

Dataset

CIFAR10, CIFAR100, PathMNIST, FashionMNIST

You do not need to prepare these datasets manually; the code includes options to download them.

⭐️ Supernet Training

We provide the MobileNet-style neural network used in the paper.

👊 Dynamic Neural Network with MobileNet Style

The search space follows previous NAS and FL approaches (Cai et al., 2019; Oh et al., 2021). Our network architecture comprises a stack of MobileNet V1 blocks (Howard et al., 2017), and the detailed search space is summarized in the Appendix of the paper. Arbitrary numbers of layers, channels, and kernel sizes can be sampled from the proposed network. Following previous settings (Yu et al., 2018; 2020), lower-index layers in each network stage are always kept. Both kernel size and channel numbers can be adjusted in a layer-wise manner.

You can adjust the search space in the config file ./config/efficient_fedsup.yml. In this file, supernet_config defines the search space as follows:

supernet_config:
    resolutions: [32]
    first_conv: 
        c: [32]
        act_func: 'relu'
        s: 2
    mb1:
        c: [32,64]
        d: [1]
        k: [3,5,7]
        t: [1]
        s: 1
        act_func: 'relu'
        se: False
    mb2:
        c: [64,128]
        d: [1,2]
        k: [3,5,7]
        t: [1]
        s: 2
        act_func: 'relu'
        se: False
    mb3:
        c: [128,256] 
        d: [1,2]
        k: [3,5,7]
        t: [1]
        s: 2
        act_func: 'relu'
        se: False
    mb4:
        c: [512,1024] 
        d: [1,2]
        k: [3,5,7]
        t: [1]
        s: 2
        act_func: 'relu'
        se: False
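
As a rough illustration of how a sub-network could be drawn from this search space, the sketch below copies the per-stage choices of c, d, and k into a plain dictionary and samples from them. It does not use the repository's own sampler: `sample_subnet` and `count_subnets` are hypothetical names, and the sketch assumes that every kept layer picks its channel count and kernel size independently. The fixed entries (resolutions, first_conv, t, s) are omitted because they offer a single choice.

```python
import random

# Per-stage choices copied from supernet_config (stage name -> options).
SEARCH_SPACE = {
    "mb1": {"c": [32, 64], "d": [1], "k": [3, 5, 7]},
    "mb2": {"c": [64, 128], "d": [1, 2], "k": [3, 5, 7]},
    "mb3": {"c": [128, 256], "d": [1, 2], "k": [3, 5, 7]},
    "mb4": {"c": [512, 1024], "d": [1, 2], "k": [3, 5, 7]},
}


def sample_subnet(space, rng):
    """Pick a depth per stage, then a channel count and kernel size per layer."""
    subnet = {}
    for stage, choices in space.items():
        depth = rng.choice(choices["d"])
        # Layers 0..depth-1 are kept, so the lower-index layers always survive.
        subnet[stage] = [
            {"c": rng.choice(choices["c"]), "k": rng.choice(choices["k"])}
            for _ in range(depth)
        ]
    return subnet


def count_subnets(space):
    """Number of distinct sub-networks under the layer-wise assumption."""
    total = 1
    for choices in space.values():
        per_layer = len(choices["c"]) * len(choices["k"])
        total *= sum(per_layer ** depth for depth in choices["d"])
    return total


rng = random.Random(0)
subnet = sample_subnet(SEARCH_SPACE, rng)
for stage, layers in subnet.items():
    print(stage, layers)
print("search space size:", count_subnets(SEARCH_SPACE))
```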

🎓 Parametric Normalization

[Figure: PN, with panels (a), (b), and (c)]

Although batch normalization (BN) is a ubiquitous technique in most recent deep learning approaches, its running statistics can disturb sub-model alignment in supernet optimization: representations of different lengths can have heterogeneous statistics, and hence their moving averages diverge for a shared BN layer (Yu & Huang, 2019; Yu et al., 2018) (Figure (a)). Furthermore, it is well known that these BN statistics can violate data privacy in FL (Li et al., 2021c). To alleviate these issues, we develop a new normalization technique, termed parametric normalization (PN), as follows:

[Equation: PN formula]

where x is an input batch, RMS is the root mean square operator, γ and β are the learnable parameters used in general batch normalization techniques, and α is a learnable parameter for approximating the batch mean and variance. The key difference between PN and previous normalization techniques is the existence of batch statistics parameters. In previous works (Yu et al., 2018; Diao et al., 2021), running statistics are not tracked, because the running statistics of batch normalization layers cannot be accumulated during local training in FL owing to the violation of data privacy as well as the different model sizes (Huang et al., 2021). An ad hoc solution employs static batch normalization (Diao et al., 2021) for model heterogeneity in FL, with the running statistics updated as test data are sequentially queried for evaluation (Figure (b)). However, performance then depends heavily on the query data and can easily be degraded by the running statistics. In contrast, the proposed PN method simply uses an RMS norm and the learnable parameter α rather than batch-wise mean and variance statistics, and thus does not require any running statistics. Hence, architectures that are more robust to the batch data can be trained (Figure (c)). In a nutshell, PN not only eliminates the privacy concerns raised by running statistics but also enables a slimmable normalization technique that is robust to query data.
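
For readers without access to the formula image, the expression below is a reconstruction read off the forward pass of the PNorm snippet in this section. It is not copied from the paper, whose notation may differ. It assumes a 4D input of shape (N, C, H, W), and it assumes that the snippet's weight corresponds to γ, the first row of its bias to α, and the second row to β.

```latex
\[
\mathrm{PN}(x)_{n,c,h,w}
  = \gamma_c \,
    \frac{x_{n,c,h,w} - \alpha_c}{\mathrm{RMS}_c(x - \alpha) + \epsilon}
  + \beta_c,
\qquad
\mathrm{RMS}_c(z) = \sqrt{\frac{1}{NHW} \sum_{n,h,w} z_{n,c,h,w}^{2}}
\]
```

Under this reading, every quantity is either learned or computed from the current batch, so nothing has to be accumulated across batches.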

Here is a snippet of the parametric normalization.

import torch
import torch.nn as nn


class PNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias_switch=True):
        """
        Parametric Normalization
        :param d: maximum number of channels (size of the weight vector)
        :param p: partial PNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps: epsilon value, default 1e-8
        :param bias_switch: whether to use the learnable shift terms
            (row 0: shift applied before normalization, row 1: output bias),
            enabled by default
        """
        super(PNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias_switch = bias_switch

        # Assigning an nn.Parameter registers it; register_parameter is not needed.
        self.weight = nn.Parameter(torch.ones(d))

        if self.bias_switch:
            self.bias = nn.Parameter(torch.zeros(2, d))

    def forward(self, x, feature_dim=None):
        # x has shape (N, C, H, W); feature_dim selects the first C channels.
        if feature_dim is None:
            feature_dim = self.d
        if 0. <= self.p <= 1.:
            # The original partial branch passed a tuple as `dim` to torch.split,
            # which only accepts an int, so that branch is rejected explicitly.
            raise NotImplementedError("partial PNorm is not supported for 4D inputs")

        def per_channel(t):
            # Slice to the active channels and broadcast over N, H, W.
            return t[:feature_dim].reshape(1, -1, 1, 1)

        # Learnable shift; skipped when bias_switch is False (no bias exists then).
        centered = x - per_channel(self.bias[0]) if self.bias_switch else x

        # Per-channel RMS over the batch and spatial dimensions.
        norm_x = centered.norm(2, dim=(-4, -2, -1), keepdim=True)
        rms_x = norm_x * ((x.shape[0] * x.shape[-1] * x.shape[-2]) ** (-1 / 2))
        x_normed = centered / (rms_x + self.eps)

        out = per_channel(self.weight) * x_normed
        if self.bias_switch:
            out = out + per_channel(self.bias[1])
        return out
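
To make the arithmetic of the snippet concrete, here is a dependency-free sketch of the same computation for a single channel. The function name `pnorm_channel` is hypothetical, and mapping alpha, gamma, and beta to the snippet's bias[0], weight, and bias[1] is our reading of the code rather than something the README states.

```python
import math


def pnorm_channel(values, alpha=0.0, gamma=1.0, beta=0.0, eps=1e-8):
    """Compute gamma * (x - alpha) / (RMS(x - alpha) + eps) + beta for one channel."""
    shifted = [v - alpha for v in values]
    # Root mean square over every activation of this channel in the batch.
    rms = math.sqrt(sum(s * s for s in shifted) / len(shifted))
    return [gamma * s / (rms + eps) + beta for s in shifted]


batch = [1.0, 2.0, 3.0, 4.0]  # all activations of one channel in one batch
out = pnorm_channel(batch, alpha=2.5)

# With gamma=1 and beta=0 the output has (almost exactly) unit RMS, and no
# running mean or variance is stored between calls.
out_rms = math.sqrt(sum(o * o for o in out) / len(out))
print([round(o, 4) for o in out], round(out_rms, 4))
```

Because the function keeps no state, evaluating on a new batch never depends on statistics gathered from earlier batches.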

💻 Code Implementation

You can get the results of E-FedSup with the following commands:

I. Shards per user

python main.py --shard_per_user 10 --dataset cifar100 --label_smoothing 0.0 \
               --gpu 0 --frac 0.1 --local_ep 5 --momentum 0.9 \
               --inplace_distill False --num_arch_training 1 \
               --drop_connect 0.0 --dropout 0.0 --results_save "vanilla" \
               --diri "False" --config-file config/efficient_fedsup.yml
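
For intuition about what a shard-based split looks like, the sketch below implements one common scheme: sort samples by label, cut them into equal shards, and deal a fixed number of shards to each user. Whether the repository's data loader follows exactly this procedure for --shard_per_user is an assumption, and `shard_partition` is a hypothetical name.

```python
import random


def shard_partition(labels, num_users, shards_per_user, seed=0):
    """Sort sample indices by label, cut them into shards, deal shards to users."""
    rng = random.Random(seed)
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_users * shards_per_user
    shard_size = len(order) // num_shards
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    rng.shuffle(shards)
    parts = {}
    for u in range(num_users):
        own = shards[u * shards_per_user:(u + 1) * shards_per_user]
        parts[u] = [i for shard in own for i in shard]
    return parts


# Toy data: 10 classes with 20 samples each; 10 users with 2 shards each.
labels = [c for c in range(10) for _ in range(20)]
parts = shard_partition(labels, num_users=10, shards_per_user=2)
classes_per_user = {u: sorted({labels[i] for i in idx}) for u, idx in parts.items()}
print(classes_per_user)
```

In this toy setting each shard holds a single class, so every user sees at most two classes, which is the kind of label skew a shard-based split is meant to create.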

II. Dirichlet distribution

python main.py --beta 0.1 --dataset cifar100 --label_smoothing 0.0 \
               --gpu 0 --frac 0.1 --local_ep 5 --momentum 0.9 \
               --inplace_distill False --num_arch_training 1 \
               --drop_connect 0.0 --dropout 0.0 --results_save "vanilla" \
               --diri "True" --config-file config/efficient_fedsup.yml
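
A Dirichlet split can be sketched with the standard library alone. The version below draws, for every class, user proportions from a symmetric Dirichlet distribution with concentration beta and divides that class's samples accordingly. It is a generic illustration, not the repository's loader, and `dirichlet_partition` is a hypothetical name.

```python
import random


def dirichlet_partition(labels, num_users, beta, seed=0):
    """Split each class across users with proportions drawn from Dir(beta)."""
    rng = random.Random(seed)
    parts = {u: [] for u in range(num_users)}
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        # A Dirichlet draw is a vector of Gamma(beta, 1) draws, normalized.
        gammas = [rng.gammavariate(beta, 1.0) for _ in range(num_users)]
        total = sum(gammas)
        start, cum = 0, 0.0
        for u in range(num_users):
            cum += gammas[u] / total
            # The last user takes the remainder so that no sample is dropped.
            end = len(idx) if u == num_users - 1 else round(cum * len(idx))
            parts[u].extend(idx[start:end])
            start = max(start, end)
    return parts


# Toy data: 10 classes with 100 samples each, split across 5 users.
labels = [c for c in range(10) for _ in range(100)]
skewed = dirichlet_partition(labels, num_users=5, beta=0.1)
balanced = dirichlet_partition(labels, num_users=5, beta=100.0)
print("beta=0.1 sizes:", [len(skewed[u]) for u in range(5)])
print("beta=100 sizes:", [len(balanced[u]) for u in range(5)])
```

Small values of beta, such as the 0.1 used in the command above, concentrate each class on a few users, while large values approach an even split.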

Script for running

For easier use, refer to the scripts in the ./scripts folder.

Results

You can reproduce all results in the paper with our code. All results are described in the paper, including the Appendix. They are too numerous to post here; however, by modifying the hyperparameter values in the .sh files and rerunning the experiments, you can reproduce all of our analyses.

About

License: MIT License
