This is the codebase of our TaT on ImageNet. Refer to TaT-seg for the semantic segmentation experiments.
Executable code can be found in examples/image_classification.py. TaT is implemented as AttnEmbed. The loss function MaskedFM is decoupled from the model.
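To illustrate what "decoupled from the model" means, here is a minimal sketch of a masked feature-matching loss that only consumes extracted features and never touches model internals. The class name and formulation below are hypothetical, not the repo's actual MaskedFM; see examples/image_classification.py for the real implementation.

```python
import torch
import torch.nn as nn

class MaskedFeatureLoss(nn.Module):
    """Illustrative masked feature-matching loss (hypothetical sketch)."""

    def forward(self, student_feat, teacher_feat, mask=None):
        # Element-wise squared error between student and teacher features.
        diff = (student_feat - teacher_feat) ** 2
        if mask is not None:
            diff = diff * mask  # zero out masked positions
            return diff.sum() / mask.sum().clamp(min=1)
        return diff.mean()

# The loss is a standalone module: it takes features as plain tensors,
# so it can be paired with any model that exposes them.
s = torch.randn(2, 64, 7, 7)
t = torch.randn(2, 64, 7, 7)
loss = MaskedFeatureLoss()(s, t)
```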
- This codebase currently does not support resuming training. However, it allows you to load a pre-trained model for specific purposes, e.g., distilling a contrastive learning model.
- The classification model is wrapped together with the learnable KD parameters. Please be careful about which model parameters you save.
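Since the wrapper's state dict mixes backbone weights with KD parameters, one way to save a clean checkpoint is to filter the keys by prefix. The wrapper below is a hypothetical stand-in for the repo's actual wrapping code, just to show the filtering pattern:

```python
import io
import torch
import torch.nn as nn

class KDWrapper(nn.Module):
    """Hypothetical wrapper: backbone plus learnable KD parameters."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.kd_proj = nn.Linear(16, 16)  # example learnable KD parameter

    def forward(self, x):
        return self.backbone(x)

wrapped = KDWrapper(nn.Linear(16, 10))

# Keep only the backbone weights and strip the "backbone." prefix,
# so the checkpoint loads directly into the unwrapped model.
clean_state = {k[len("backbone."):]: v
               for k, v in wrapped.state_dict().items()
               if k.startswith("backbone.")}

buffer = io.BytesIO()
torch.save(clean_state, buffer)  # save only the clean backbone weights
```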
If you would like to customize your own model, please put all the learnable parameters here, and set up the loss function calculation here.
We use forward hooks to extract the intermediate representations. Simply modify the yaml file to access the model layers of your interest. This example notebook will give you a better idea of the usage; you may also refer to our config.
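The forward-hook mechanism can be sketched as follows. In this repo the layer names come from the yaml config rather than being hard-coded as below; the model and hook names here are only for illustration:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
)
features = {}

def save_output(name):
    def hook(module, inputs, output):
        features[name] = output  # cache the intermediate representation
    return hook

# Attach a hook to the layer of interest; the forward pass then
# populates `features` as a side effect.
handle = model[0].register_forward_hook(save_output("conv1"))

_ = model(torch.randn(1, 3, 32, 32))
print(features["conv1"].shape)  # torch.Size([1, 8, 32, 32])
handle.remove()
```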
- Python 3.7
- PyTorch 1.5
- einops
- ml-collections
Please modify the ImageNet path in the config.
We use 8 GPUs with 256 images per GPU.
sh ./train_local.sh
sh ./test_local.sh
Feel free to open an issue if you have a question, or email me ( sihao.lin@student.rmit.edu.au ).
This repo is built upon torchdistill. Thanks to Yoshitomo.