Saliency-detection-using-knowledge-distillation

Exploring the concept of knowledge distillation on the task of Saliency detection.


A project on the really interesting topic of knowledge distillation. To learn about the concept of knowledge distillation, read this article: Article.

About the project

Model

Teacher model

I applied knowledge distillation to the task of saliency detection. The MSI-Net presented in the paper Contextual encoder–decoder network for visual saliency prediction is used as the teacher model. The teacher model has approximately 25 million parameters. This GitHub repository provides the implementation of the model in TensorFlow, the dataset links, and related files.

Student Model

The student model simply consists of a ResNet-18 backbone; the output from the last residual block is passed through a series of UpScale2D and Convolution2D layers. This smaller model has approximately 12 million parameters (about half the size of the teacher model). To compare simple training with teacher–student training, two variants of this model were implemented (see the sketch after the list below).

Model A: a simple model with an unbranched decoder that produces just one output: a saliency map.
Model B: the student model with the last layer of the decoder branched to produce two outputs: a saliency map and a feature map for correspondence with the teacher output.
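
The following is a minimal sketch of what such a student could look like, assuming a PyTorch-style implementation; the channel sizes, the number of upsample/conv blocks, and the head names are illustrative and not taken from the repository.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

class StudentSaliencyNet(nn.Module):
    """Illustrative student: ResNet-18 backbone + upsample/conv decoder.

    branched=False roughly corresponds to Model A (single saliency map),
    branched=True to Model B (saliency map plus a second map matched to
    the teacher output). Channel sizes are assumptions, not the repo's values.
    """

    def __init__(self, branched: bool = False):
        super().__init__()
        backbone = resnet18(weights=None)
        # Keep everything up to and including the last residual block.
        self.encoder = nn.Sequential(*list(backbone.children())[:-2])  # [B, 512, H/32, W/32]

        def up_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        self.decoder = nn.Sequential(
            up_block(512, 256),
            up_block(256, 128),
            up_block(128, 64),
            up_block(64, 32),
        )
        self.branched = branched
        self.saliency_head = nn.Conv2d(32, 1, kernel_size=1)
        if branched:
            # Extra head whose output is compared against the teacher's map.
            self.teacher_head = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x):
        features = self.decoder(self.encoder(x))
        saliency = self.saliency_head(features)
        if self.branched:
            return saliency, self.teacher_head(features)
        return saliency
```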

Data

The dataset used is the SALICON dataset Link.
Along with the original dataset, the output from the teacher model is also generated for both the training and validation images (stimuli) and stored.
For augmentation, inputs were randomly flipped vertically and horizontally.

python ./saliency/main.py test -p ./data/salicon/stimuli/train
python ./saliency/main.py test -p ./data/salicon/stimuli/val
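
One possible way to apply the random flips described above is sketched here, assuming the stimulus, ground-truth saliency map, and stored teacher map are handled as PyTorch tensors; the function name and tensor layout are illustrative, and the key point is that all three tensors must be flipped together so they stay aligned.

```python
import random
import torch

def random_flip(stimulus: torch.Tensor,
                saliency: torch.Tensor,
                teacher_map: torch.Tensor,
                p: float = 0.5):
    """Apply the same random horizontal/vertical flip to image and targets.

    Tensors are assumed to be [C, H, W]; flipping them jointly keeps the
    saliency and teacher maps spatially aligned with the stimulus.
    """
    if random.random() < p:  # horizontal flip
        stimulus = torch.flip(stimulus, dims=[-1])
        saliency = torch.flip(saliency, dims=[-1])
        teacher_map = torch.flip(teacher_map, dims=[-1])
    if random.random() < p:  # vertical flip
        stimulus = torch.flip(stimulus, dims=[-2])
        saliency = torch.flip(saliency, dims=[-2])
        teacher_map = torch.flip(teacher_map, dims=[-2])
    return stimulus, saliency, teacher_map
```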

Training

A smoothing factor of 3 is used for the student model training. KL divergence is used both as the loss function and as the evaluation metric to measure the performance of the model. The learning rate was set to 1e-3 and the Adam optimizer was used. Both models were trained for 10 epochs.
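
One way to read the smoothing factor is as a temperature applied before normalising the maps into spatial probability distributions; the sketch below implements that interpretation in PyTorch, but it is an assumption rather than the repository's exact formulation.

```python
import torch
import torch.nn.functional as F

def saliency_kl_loss(student_map: torch.Tensor,
                     target_map: torch.Tensor,
                     smoothing: float = 3.0) -> torch.Tensor:
    """KL divergence between spatial distributions derived from two maps.

    Each map ([B, 1, H, W]) is divided by the smoothing factor and
    softmax-normalised over its pixels; this temperature reading of the
    smoothing factor is an assumption.
    """
    b = student_map.size(0)
    log_p = F.log_softmax(student_map.view(b, -1) / smoothing, dim=1)  # student (log-probs)
    q = F.softmax(target_map.view(b, -1) / smoothing, dim=1)           # teacher or ground truth
    return F.kl_div(log_p, q, reduction="batchmean")

# Training setup from the description: Adam with lr = 1e-3, e.g.
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```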

Note: The number of epochs is not optimal; it was chosen just to show the difference between the two training techniques.

Results

KL Divergence graph for Model A

KL Divergence graph for Model B

Conclusion

Although the comparison is not exhaustive and the hyperparameters can be tuned further, the results show that, despite having the same hyperparameters and an almost identical architecture and number of parameters, the model fails to converge with simple training. With knowledge distillation, a smaller model, in this case one half the size of the original model, can be effectively trained to mimic the behaviour of the teacher and can give comparable performance even when the task is complex.
