PyTorch implementation of various Knowledge Distillation (KD) methods.
Name | Method | Paper Link |
---|---|---|
Baseline | basic model with softmax loss | — |
ST | soft target | paper |
Fitnet | hints for thin deep nets | paper |
RKD | relational knowledge distillation | paper |
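As a concrete illustration of the simplest entry above, here is a minimal sketch of a soft-target (ST) distillation loss in PyTorch. It follows the standard temperature-scaled KL formulation; the function name and the `temperature`/`lambda_kd` defaults are illustrative assumptions, not taken from this repo's code.

```python
import torch
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, targets,
                     temperature=4.0, lambda_kd=0.9):
    # Illustrative sketch of the soft-target KD loss, not this repo's exact code.
    # Soften both distributions with the temperature before comparing them.
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # KL divergence between the softened distributions, scaled by T^2.
    kd_loss = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
    # Standard cross-entropy on the hard labels.
    ce_loss = F.cross_entropy(student_logits, targets)
    # Weighted combination controlled by the trade-off parameter.
    return (1.0 - lambda_kd) * ce_loss + lambda_kd * kd_loss
```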
- CIFAR100
- ResNet-20
- ResNet-110
The networks are the same as those in Table 6 of the paper.
- Create a `./dataset` directory and download CIFAR10 into it.
- Use the script `example_train_script.sh` to train the various KD methods. You can simply specify the hyper-parameters listed in `train_xxx.py` or change them manually.
- Some Notes:
- We assume the sizes (C, H, W) of the features from teacher and student are the same. If not, you could employ a 1×1 conv, a linear layer, or pooling to rectify them (see the sketch after these notes).
- The trained baseline models are used as teachers. For a fair comparison, all student nets have the same initialization as the baseline models.
- The initial models, trained models and training logs are uploaded here.
- The trade-off parameter `--lambda_kd` and other hyper-parameters are not chosen carefully. Thus the following results do not reflect which method is better than the others.
Teacher | Student | Name | CIFAR10 |
---|---|---|---|
resnet-110 | resnet-20 | Baseline | 93% |
resnet-110 | resnet-20 | ST | 85.34% |
resnet-110 | resnet-20 | Fitnet | 85.57% |
resnet-110 | resnet-20 | RKD | 87% |
- python 3.7
- pytorch 1.3.1
- torchvision 0.4.2