zhouzaida / channel-distillation

PyTorch implementation for Channel Distillation


Channel Distillation

PyTorch implementation of "Channel Distillation: Channel-Wise Attention for Knowledge Distillation". The method has three components:


  1. Channel Distillation (CD)
  2. Guided Knowledge Distillation (GKD)
  3. Early Decay Teacher (EDT)
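This is a minimal sketch of the channel-wise attention idea behind CD. It is not the code in losses/cd_loss.py. It assumes two things: each channel's attention weight is the global average of its activation map, and the loss is the mean squared difference between the student's and the teacher's attention vectors. It also assumes the student's feature maps were already projected (for example by a 1x1 convolution) to the teacher's channel count. The names `channel_attention` and `cd_loss` are hypothetical.

```python
import torch
import torch.nn.functional as F


def channel_attention(feature_map: torch.Tensor) -> torch.Tensor:
    """Global-average-pool an (N, C, H, W) feature map into per-channel weights (N, C)."""
    return feature_map.mean(dim=(2, 3))


def cd_loss(student_feats, teacher_feats) -> torch.Tensor:
    """Hypothetical CD-style loss, summed over the distilled stages.

    Assumes every student feature map was already projected (e.g. by a 1x1
    convolution) to the same channel count as the matching teacher map.
    """
    per_stage = [
        # The teacher is frozen, so no gradient flows into its features.
        F.mse_loss(channel_attention(s), channel_attention(t.detach()))
        for s, t in zip(student_feats, teacher_feats)
    ]
    return torch.stack(per_stage).sum()


if __name__ == "__main__":
    torch.manual_seed(0)
    student = [torch.randn(2, 64, 8, 8, requires_grad=True), torch.randn(2, 128, 4, 4, requires_grad=True)]
    teacher = [torch.randn(2, 64, 8, 8), torch.randn(2, 128, 4, 4)]
    loss = cd_loss(student, teacher)
    loss.backward()  # gradients reach only the student features
    print(loss.item(), student[0].grad.shape)
```

Pooling each channel to a single scalar makes the loss compare which channels are active, not where they are active.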


In our code, kdv2 refers to GKD and lrdv2 refers to EDT.
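GKD is configured below as KDLossv2 with a temperature T. This is a hypothetical sketch, not the repository's KDLossv2. It assumes GKD is standard temperature-scaled KD (KL divergence between softened distributions, scaled by T²) applied only to samples the teacher classifies correctly. Whether the loss is averaged over the kept samples or over the whole batch is also an assumption. The name `gkd_loss` is illustrative.

```python
import torch
import torch.nn.functional as F


def gkd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
             targets: torch.Tensor, T: float = 1.0) -> torch.Tensor:
    """Hypothetical GKD-style loss: temperature-scaled KD computed only on
    samples whose teacher prediction (argmax) matches the ground-truth label."""
    teacher_logits = teacher_logits.detach()
    correct = teacher_logits.argmax(dim=1) == targets  # bool mask, shape (N,)
    if not correct.any():
        # No correct teacher predictions in this batch: contribute zero loss.
        return student_logits.sum() * 0.0
    log_p_s = F.log_softmax(student_logits[correct] / T, dim=1)
    p_t = F.softmax(teacher_logits[correct] / T, dim=1)
    # "batchmean" averages the KL divergence over the kept samples; the T*T
    # factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_p_s, p_t, reduction="batchmean") * T * T
```

The mask means a wrong teacher prediction never pulls the student toward an incorrect class.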

Structure of Repository

├── cifar_config.py  # Hyperparameters
├── cifar_train.py
├── data
│   └── directory_of_data.md
├── imagenet_config.py  # Hyperparameters
├── imagenet_train.py
├── losses
│   ├── cd_loss.py  # CD Loss
│   ├── ce_loss.py
│   ├── __init__.py
│   └── kd_loss.py  # GKD Loss
├── models
│   ├── channel_distillation.py  # Distillation Network
│   ├── __init__.py
│   └── resnet.py
├── pretrain
│   └── path_of_teacher_checkpoint.md
├── README.md
└── utils
    ├── average_meter.py
    ├── data_prefetcher.py
    ├── __init__.py
    ├── logutil.py
    ├── metric.py
    └── util.py  # Early Decay Teacher


Requirements

python >= 3.7
torch >= 1.4.0
torchvision >= 0.5.0
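Version strings must be compared numerically, not as text. For example, "1.10.0" sorts before "1.4.0" lexicographically but is the newer release. This is a simplified, stdlib-only sketch: it keeps only the first three numeric components and ignores pre-release ordering. For strict PEP 440 comparison, `packaging.version.Version` is the usual tool. The helper names are hypothetical.

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn "1.4.0" or "2.1.0+cu118" into (1, 4, 0); missing parts count as 0."""
    release = version.split("+")[0]  # drop local build tags such as "+cu118"
    parts = []
    for piece in release.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep the leading digits, e.g. "0a0" -> 0
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    """True if the installed version is at least the required one."""
    return version_tuple(installed) >= version_tuple(required)
```

For example, `meets_minimum(torch.__version__, "1.4.0")` and `meets_minimum(torchvision.__version__, "0.5.0")` check the two library requirements, and `sys.version_info >= (3, 7)` covers Python.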



ImageNet

Prepare Dataset

Images should be arranged as follows:




The teacher checkpoint will be downloaded automatically.

Run the following command to launch the experiment:

CUDA_VISIBLE_DEVICES=0 python3 ./imagenet_train.py

To run other experiments, replace loss_list in imagenet_config.py with one of the following configurations:

  • s_resnet18.t_resnet34.cd.ce
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 1, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
  • s_resnet18.t_resnet34.cd.ce.kdv2
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
  • s_resnet18.t_resnet34.cd.kdv2.lrdv2
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]
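EDT is selected by setting loss_rate_decay to "lrdv2", and it is implemented in utils/util.py. This README does not spell out the schedule, so the following is a hypothetical sketch only. It assumes `factor` is a multiplicative decay applied to `loss_rate` once every `step` epochs, with `step=30` chosen purely for illustration. Under that assumption the CD weight (factor 0.9) decays while the CE and KD weights (factor 1) stay fixed. That fits the idea in the name "Early Decay Teacher". `decayed_weight` is not a repository function.

```python
def decayed_weight(entry: dict, epoch: int, step: int = 30) -> float:
    """Hypothetical EDT-style schedule: multiply loss_rate by factor once every
    `step` epochs. `step=30` is an illustrative assumption, not a repo value."""
    return entry["loss_rate"] * entry["factor"] ** (epoch // step)


# The s_resnet18.t_resnet34.cd.kdv2.lrdv2 configuration from above.
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]

if __name__ == "__main__":
    for epoch in (0, 30, 60, 90):
        weights = {e["loss_name"]: round(decayed_weight(e, epoch), 4) for e in loss_list}
        print(epoch, weights)
```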


  • Ablation Study
Method              Model                Top-1 error (%)   Top-5 error (%)
teacher             ResNet34             26.73             8.74
student             ResNet18             30.43             10.76
KD                  ResNet34-ResNet18    29.50             9.52
CD (ours)           ResNet34-ResNet18    28.53             9.56
CD+GKD (ours)       ResNet34-ResNet18    28.26             9.41
CD+GKD+EDT (ours)   ResNet34-ResNet18    27.61             9.2
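The ablation rows can be read as a sequence of additions, each building on the previous row. This short sketch uses the Top-1 errors copied from the ImageNet ablation table. It computes how many points of error each component removes, relative to the row before it.

```python
# Top-1 error (%) on ImageNet, copied from the ablation table.
top1 = {
    "student": 30.43,
    "CD": 28.53,
    "CD+GKD": 28.26,
    "CD+GKD+EDT": 27.61,
}

steps = ["student", "CD", "CD+GKD", "CD+GKD+EDT"]
# Error reduction contributed by each row relative to the previous row.
gains = {
    later: round(top1[earlier] - top1[later], 2)
    for earlier, later in zip(steps, steps[1:])
}
total = round(top1["student"] - top1["CD+GKD+EDT"], 2)

if __name__ == "__main__":
    for name, gain in gains.items():
        print(f"{name}: -{gain:.2f} points")
    print(f"total: -{total:.2f} points")
```

By this arithmetic, CD alone accounts for 1.90 of the 2.82-point Top-1 reduction. GKD adds 0.27 and EDT adds 0.65.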
  • Comparison with other methods
Method              Model                Top-1 error (%)   Top-5 error (%)
teacher             ResNet34             26.73             8.74
student             ResNet18             30.43             10.76
KD                  ResNet34-ResNet18    29.50             9.52
FitNets             ResNet34-ResNet18    29.34             10.77
AT                  ResNet34-ResNet18    29.30             10.00
RKD                 ResNet34-ResNet18    28.46             9.74
CD+GKD+EDT (ours)   ResNet34-ResNet18    27.61             9.2
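One way to compare distillation methods is the fraction of the teacher-student Top-1 gap (30.43 - 26.73 = 3.70 points) that each one recovers. The sketch below computes this from the comparison table. `gap_closed` is an illustrative helper.

```python
# Top-1 error (%) on ImageNet, copied from the comparison table.
TEACHER, STUDENT = 26.73, 30.43
methods = {
    "KD": 29.50,
    "FitNets": 29.34,
    "AT": 29.30,
    "RKD": 28.46,
    "CD+GKD+EDT": 27.61,
}


def gap_closed(error: float, teacher: float = TEACHER, student: float = STUDENT) -> float:
    """Fraction of the teacher-student Top-1 gap recovered by a distilled student."""
    return (student - error) / (student - teacher)


if __name__ == "__main__":
    # Sort from lowest to highest error.
    for name, err in sorted(methods.items(), key=lambda kv: kv[1]):
        print(f"{name:<12} {100 * gap_closed(err):5.1f}% of the gap closed")
```

On these numbers, CD+GKD+EDT recovers about 76% of the gap and RKD about 53%. KD, FitNets and AT each recover between 25% and 31%.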


CIFAR100

Prepare Dataset

The CIFAR100 dataset will be downloaded automatically.



Download the teacher checkpoint from here, then put it in the pretrain directory.
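This README states neither the checkpoint's file name nor its internal format. The sketch below is a hypothetical loader that accepts either a plain `state_dict` or a dict wrapping it under a `"state_dict"` key. It also strips the `module.` prefix that `nn.DataParallel` adds to parameter names. The function name, the key name and the path handling are all assumptions.

```python
import torch


def load_teacher_weights(model: torch.nn.Module, path: str) -> torch.nn.Module:
    """Hypothetical loader for a teacher checkpoint placed in pretrain/.

    Assumes the file holds either a plain state_dict or a dict wrapping it
    under a "state_dict" key; the real checkpoint format is not documented here.
    """
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    state_dict = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()  # the teacher is only used for inference
    return model
```

Loading with `map_location="cpu"` lets a checkpoint saved on GPU be opened on any machine. The model can be moved to the GPU afterwards.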

Run the following command to launch the experiment:

CUDA_VISIBLE_DEVICES=0 python3 ./cifar_train.py

To run other experiments, replace loss_list in cifar_config.py with one of the following configurations:

  • s_resnet18.t_resnet34.cd.ce
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 1, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
  • s_resnet18.t_resnet34.cd.ce.kdv2
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 0.1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv1"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
  • s_resnet18.t_resnet34.cd.kdv2.lrdv2
loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 0.1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]
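These CIFAR100 configurations differ from the ImageNet ones only in the KDLossv2 loss_rate, which is 0.1 here instead of 1. The sketch below combines the per-loss values into a single training objective. It assumes each computed loss is scaled by its `loss_rate` and the results are summed. Any decay controlled by `factor` and `loss_rate_decay` is deliberately left out. `total_loss` is a hypothetical helper, not the training script's code.

```python
def total_loss(loss_list, loss_values):
    """Hypothetical: weight each computed loss by its loss_rate and sum.

    `loss_values` maps loss_name to an already computed loss (a float here;
    scalar tensors work the same way). Decay via `factor` is not modeled.
    """
    return sum(entry["loss_rate"] * loss_values[entry["loss_name"]] for entry in loss_list)


# The s_resnet18.t_resnet34.cd.kdv2.lrdv2 configuration for CIFAR100.
cifar_loss_list = [
    {"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "KDLossv2", "T": 1, "loss_rate": 0.1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
    {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]
```

For example, with made-up loss values `{"CELoss": 2.0, "KDLossv2": 1.5, "CDLoss": 0.05}`, the weighted sum is 2.0 + 0.15 + 0.3 = 2.45.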


  • Ablation Study
Method              Model                Top-1 error (%)   Top-5 error (%)
teacher             ResNet152            19.09             4.45
student             ResNet50             22.02             5.74
KD                  ResNet152-ResNet50   20.36             4.94
CD (ours)           ResNet152-ResNet50   20.08             4.78
CD+GKD (ours)       ResNet152-ResNet50   19.49             4.85
CD+GKD+EDT (ours)   ResNet152-ResNet50   18.63             4.29
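In this CIFAR100 table, the full CD+GKD+EDT student (ResNet50) reports lower Top-1 and Top-5 error than its ResNet152 teacher. The sketch below checks that claim using the numbers copied from the table. `margin_over_teacher` is an illustrative helper.

```python
# CIFAR100 (Top-1, Top-5) error in %, copied from the ablation table.
teacher = (19.09, 4.45)
rows = {
    "student": (22.02, 5.74),
    "KD": (20.36, 4.94),
    "CD": (20.08, 4.78),
    "CD+GKD": (19.49, 4.85),
    "CD+GKD+EDT": (18.63, 4.29),
}


def margin_over_teacher(errors, reference=teacher):
    """Teacher error minus student error per metric; positive means the student is better."""
    return tuple(round(ref - err, 2) for err, ref in zip(errors, reference))


# Rows whose error is below the teacher's on both Top-1 and Top-5.
beats_teacher = [
    name for name, errs in rows.items()
    if all(m > 0 for m in margin_over_teacher(errs))
]

if __name__ == "__main__":
    for name, errs in rows.items():
        print(name, margin_over_teacher(errs))
    print("beats teacher on both metrics:", beats_teacher)
```

Only CD+GKD+EDT is below the teacher on both metrics, by 0.46 Top-1 points and 0.16 Top-5 points. CD+GKD comes within 0.40 points on both.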

