
frutify-torch


A PyTorch rewrite of the Fruit Image Classification project.

A little note

Much of the original dataset was lost to a hardware failure, which has made it difficult to reproduce the results reported in our paper on the project. The models here are trained on the recovered data, so results may vary. In that sense, this repo is not an exact reproduction of the original work.

Pre-trained models in use

Model source: Torchvision

  • inception-v3
  • resnet101

All layers of the pretrained models were frozen, and the output of the final layer was passed through a new linear classification layer, which was then fine-tuned (see the sketch below).
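
A rough sketch of that setup for the ResNet101 backbone, assuming the standard torchvision API (the exact layer names and head used in trainer.py may differ):

import torch
import torch.nn as nn
from torchvision import models

def build_resnet101_classifier(num_classes: int = 8) -> nn.Module:
    # Load the ImageNet-pretrained backbone from torchvision.
    model = models.resnet101(pretrained=True)

    # Freeze every pretrained layer so only the new head is tuned.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final fully connected layer with a fresh linear layer
    # mapping the backbone features to the 8 fruit classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = build_resnet101_classifier()
# Only the new head is optimised (Adam, lr = 0.001 as in the Training Setup section).
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)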

Dataset Information

  • 2233 images
  • 8 classes / labels
  • fresh_apple, rotten_apple, fresh_orange, rotten_orange, fresh_banana, rotten_banana, fresh_mango, rotten_mango
  • Link to dataset: Dataset
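
Assuming the extracted archive keeps one subdirectory per class (fresh_apple/, rotten_apple/, and so on; this is an assumption, so check the archive layout), the images can be loaded with torchvision's ImageFolder, for example:

from torchvision import datasets, transforms

# Assumed layout: dataset/fresh_apple/, dataset/rotten_apple/, ... (one folder per class)
transform = transforms.Compose([
    transforms.Resize((299, 299)),  # Inception-V3 expects 299x299 inputs; ResNet101 also accepts this size
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder("dataset/", transform=transform)
print(len(dataset), dataset.classes)  # expected: 2233 images, 8 classes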

Training Setup

| Pretrained model name | Epochs | Batch Size | Split Ratio (Train:Valid) | Learning Rate | Optimizer |
| --- | --- | --- | --- | --- | --- |
| Inception-V3 | 5 | 16 | 0.8 | 0.001 | Adam |
| Resnet101 | 5 | 16 | 0.8 | 0.001 | Adam |

Evaluation

| Pretrained model name | Accuracy | F1 (macro) | F1 (weighted) | Precision (macro) | Precision (weighted) | Recall (macro) | Recall (weighted) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Inception-V3 | 0.83 | 0.83 | 0.83 | 0.83 | 0.84 | 0.83 | 0.83 |
| Resnet101 | 0.93 | 0.93 | 0.93 | 0.93 | 0.93 | 0.93 | 0.93 |

Detailed results

Graphs (from comet.ml)

[Training metrics graph]

The interactive version of this graph can be found here.

Metrics

*************************
inception-v3


               precision    recall  f1-score   support

  fresh_apple       0.80      0.86      0.83        43
 rotten_apple       0.72      0.68      0.70        19
 fresh_orange       0.76      0.81      0.79        27
rotten_orange       0.82      0.77      0.79        30
 fresh_banana       0.94      0.91      0.92        33
rotten_banana       0.85      1.00      0.92        17
  fresh_mango       0.91      0.88      0.90        34
 rotten_mango       0.83      0.71      0.77        21

     accuracy                           0.83       224
    macro avg       0.83      0.83      0.83       224
 weighted avg       0.84      0.83      0.83       224

*************************

*************************
resnet101

               precision    recall  f1-score   support

  fresh_apple       0.95      0.98      0.97        43
 rotten_apple       0.89      0.84      0.86        19
 fresh_orange       0.87      0.96      0.91        27
rotten_orange       0.93      0.83      0.88        30
 fresh_banana       1.00      1.00      1.00        33
rotten_banana       1.00      1.00      1.00        17
  fresh_mango       0.94      0.91      0.93        34
 rotten_mango       0.86      0.90      0.88        21

     accuracy                           0.93       224
    macro avg       0.93      0.93      0.93       224
 weighted avg       0.93      0.93      0.93       224

*************************
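
The per-class tables above match the format of scikit-learn's classification_report. A minimal sketch of producing such a report, assuming y_true and y_pred are the integer class indices collected from the test split:

from sklearn.metrics import classification_report

class_names = [
    "fresh_apple", "rotten_apple", "fresh_orange", "rotten_orange",
    "fresh_banana", "rotten_banana", "fresh_mango", "rotten_mango",
]

# In the real evaluation these come from iterating the test DataLoader;
# tiny dummy values are used here so the snippet runs on its own.
y_true = [0, 1, 2, 3, 4, 5, 6, 7]
y_pred = [0, 1, 2, 3, 4, 5, 6, 6]

print(classification_report(y_true, y_pred, target_names=class_names))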

ENV setup

# Using Anaconda / Miniconda
conda env create -f fruit.yml

conda activate frutify

# Afterwards, clone the repository and move into it
cd frutify-torch
# create a directory for saving models if you're training
mkdir saved_models/ 

Saved models

Already-trained models that you can use to run inference can be found here:

| Model | Link |
| --- | --- |
| InceptionV3 | Link |
| Resnet101 | Link |

Testing

Download the saved models and store them in the saved_models directory, or save them elsewhere and pass that path to the script.

python test.py --inception_path <path> --resnet_path <path> --split <ratio> --batch_size <size> --saved_path <path>
# example
python test.py --inception_path "./saved_models/inception-v3_5_16_0.001_1626740279.94928.ckpt" --resnet_path "./saved_models/resnet101_5_16_0.001_1626743414.734695.ckpt" --batch_size 16 --split 0.8  

# note: if you have a powerful multicore CPU, you may want to use the --num_workers option to speed up
# data loading; pass the number of cores you want to use.
python test.py --inception_path "./saved_models/inception-v3_5_16_0.001_1626740279.94928.ckpt" --resnet_path "./saved_models/resnet101_5_16_0.001_1626743414.734695.ckpt" --num_workers 2 --batch_size 16 --split 0.8  

For more command-line options, check test.py.

Training

Download the dataset from the provided link and copy the dataset directory from the zip archive into the project directory. Note: use a GPU (>= 8 GB VRAM) unless you are willing to wait 20+ minutes for each epoch to finish.

python trainer.py --model <model name> --device <device> --split <ratio> --batch_size <size> --epochs <epochs> --lr <learning rate>
# example
python trainer.py --model resnet101 --device gpu --split 0.8 --batch_size 16 --epochs 5 --lr 1e-3

# note: if you have a powerful multicore (> 4 cores) CPU, you may want to use the --num_workers option to speed up
# data loading; pass the number of cores you want to use.
python trainer.py --model resnet101 --device gpu --num_workers 12 --split 0.8 --batch_size 16 --epochs 5 --lr 1e-3

For more command-line options, check trainer.py.

Running with comet.ml logger

The comet.ml logger doesn't work with multiple data-loading workers, which is a known issue. So, if you want to use comet.ml for training visualization, don't pass the --num_workers option. (It will be slower, but sadly this is the only way!)

# example for using comet ml logger
python trainer.py --model inception-v3 --device gpu --split 0.8 --batch_size 16 --epochs 5 --lr 1e-3 --comet True

Check the comet.ml website for how to get an API key and how to get started with PyTorch and pytorch-lightning.
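
If trainer.py wires the logger through pytorch-lightning, the setup would look roughly like the sketch below (the project name is a placeholder, not necessarily what trainer.py uses):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger

# API key from your comet.ml account; the project name here is just a placeholder.
comet_logger = CometLogger(
    api_key="YOUR_COMET_API_KEY",
    project_name="frutify-torch",
)

# Keep num_workers at its default in the DataLoaders, per the note above.
trainer = Trainer(max_epochs=5, logger=comet_logger)
# trainer.fit(model, train_loader, val_loader)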

Other implementations

One of the co-authors, Md Abdul Ahad Chowdhury, implemented a C#-based rewrite of the project, which you can find here.

Citing the original paper

@article{ashraf2019fruit,
  title={Fruit Image Classification Using Convolutional Neural Networks},
  author={Ashraf, Shawon and Kadery, Ivan and Chowdhury, Md Abdul Ahad and Mahbub, Tahsin Zahin and Rahman, Rashedur M},
  journal={International Journal of Software Innovation (IJSI)},
  volume={7},
  number={4},
  pages={51--70},
  year={2019},
  publisher={IGI Global}
}