abhayraw1 / transfer_learning

Transfer Learning

Transfer learning

Transfer learning is a technique for leveraging the knowledge a model has learned on one task to help it with another. In a way, this mirrors how humans learn: we rarely start a new task from scratch, and instead use existing knowledge to learn it more easily.

Because deep learning models learn different feature representations at different layers, we can reuse the feature-extraction capabilities of a pre-trained model. In both approaches used here, the model's last classification layer is replaced with a newly initialized one. The model is then trained in one of two ways: by updating all of its parameters (fine-tuning), or by updating only the new last layer while the pre-trained layers stay frozen.

To run

First, download the ants and bees dataset from here, or the cats and dogs dataset from here. Make sure the data directory has the following structure:

data/
  |--- dataset1/
  |       |---train/
  |       |     |---class1/
  |       |     |---class2/
  |       |---val/
  |             |---class1/
  |             |---class2/
  |
  |--- dataset2/
  |       |---train/
  |       |     |---class1/
  |       |     |---class2/
  |       |---val/
  |             |---class1/
  |             |---class2/

Then run:

$ python main.py <path-to-dataset-dir> --freeze --num-epochs 10 --exp-name <experiment-name>

Omit the --freeze flag to fine-tune all the parameters; with --freeze, only the new last layer is trained.
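For reference, the flags in the command above could be parsed with `argparse` as sketched below. This is a hypothetical reconstruction for illustration, not the contents of main.py; the positional name `data_dir` and the default values are assumptions.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical reconstruction of the CLI; names and defaults are assumptions.
    parser = argparse.ArgumentParser(
        description="Transfer learning with a pre-trained ResNet-18")
    parser.add_argument("data_dir",
                        help="dataset directory containing train/ and val/")
    parser.add_argument("--freeze", action="store_true",
                        help="train only the new last layer "
                             "(otherwise fine-tune all parameters)")
    parser.add_argument("--num-epochs", type=int, default=10,
                        help="number of training epochs")
    parser.add_argument("--exp-name", default="experiment",
                        help="name used to label this run")
    # argparse exposes --num-epochs and --exp-name as num_epochs and exp_name.
    return parser.parse_args(argv)
```

Because `--freeze` uses `store_true`, leaving it out yields `args.freeze == False`, which selects full fine-tuning.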

Results

A pre-trained ResNet-18 model was used in these experiments (link to pretrained model). The ants and bees dataset contains 244 training images and 153 validation images. The cats and dogs dataset contains 500 images for both training and validation. A linear learning rate scheduler was also used to decay the learning rate every few epochs.
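The exact schedule is not stated here. The sketch below assumes step decay with PyTorch's `StepLR`, using SGD with `lr=0.001`, `momentum=0.9`, `step_size=7` and `gamma=0.1` as in the PyTorch transfer learning tutorial; these values and the helper name `make_optimizer` are assumptions. Only trainable parameters are passed to the optimizer, so a frozen backbone is never updated.

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR


def make_optimizer(model: nn.Module, lr: float = 0.001,
                   step_size: int = 7, gamma: float = 0.1):
    """Hypothetical helper; hyperparameters follow the PyTorch tutorial."""
    # Skip frozen parameters (requires_grad=False).
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = SGD(params, lr=lr, momentum=0.9)
    # Multiply the learning rate by `gamma` every `step_size` epochs.
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    return optimizer, scheduler
```

`scheduler.step()` is called once per epoch after that epoch's `optimizer.step()` calls, so with these values the learning rate drops from 0.001 to 0.0001 after epoch 7.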

Ants vs Bees dataset

Cats vs Dogs dataset

References

  1. ResNet Model
  2. PyTorch Tutorial: Transfer Learning
  3. A Comprehensive Hands-on Guide to Transfer Learning with Real-World Applications in Deep Learning by Dipanjan (DJ) Sarkar

Languages

Python 100.0%