A PyTorch rewrite of the Fruit Image Classification project.
Much of the original dataset was lost to a hardware failure, which has made it difficult to reproduce the results in our paper on the project. The models here were trained on the recovered data, so results may vary. In that sense, this repo is not an exact reproduction of the original work.
Model source: Torchvision
- inception-v3
- resnet101
All layers of the pretrained models were frozen. The output of each model's final layer was passed through a new linear layer, and that linear layer was tuned.
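The freeze-and-tune setup described above can be sketched roughly as follows. This is a minimal illustration, not the code in trainer.py: the helper name `freeze_and_add_head` is hypothetical, and a tiny stand-in backbone is used so the snippet runs without downloading weights. It assumes the backbone emits 1000 values, as Torchvision's ImageNet classifiers do, and maps them to the 8 fruit classes.

```python
import torch
from torch import nn


def freeze_and_add_head(backbone: nn.Module, feature_dim: int, num_classes: int) -> nn.Module:
    """Freeze every backbone parameter and stack a trainable linear layer on top.

    Hypothetical helper for illustration only.
    """
    for param in backbone.parameters():
        param.requires_grad = False  # frozen: no gradients are computed for these
    head = nn.Linear(feature_dim, num_classes)  # the only trainable part
    return nn.Sequential(backbone, head)


# Stand-in backbone (assumption) so the sketch runs offline.
# In practice a pretrained Torchvision model would take its place.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1000))
model = freeze_and_add_head(backbone, feature_dim=1000, num_classes=8)

# Only parameters that still require gradients are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

# Forward pass on a batch of 16 dummy images.
logits = model(torch.randn(16, 3, 8, 8))
```

Passing only the trainable parameters to the optimizer keeps the frozen weights untouched and makes it easy to check that just the linear layer's weight and bias are being updated.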
- 2233 images
- 8 classes / labels: fresh_apple, rotten_apple, fresh_orange, rotten_orange, fresh_banana, rotten_banana, fresh_mango, rotten_mango
- Link to dataset: Dataset
Pretrained model name | Epochs | Batch Size | Split Ratio (Train:Valid) | Learning Rate | Optimizer |
---|---|---|---|---|---|
Inception-V3 | 5 | 16 | 0.8 | 0.001 | Adam |
Resnet101 | 5 | 16 | 0.8 | 0.001 | Adam |
Pretrained model name | Accuracy | F1(macro) | F1(weighted) | Precision(macro) | Precision(weighted) | Recall(macro) | Recall(weighted) |
---|---|---|---|---|---|---|---|
Inception-V3 | 0.83 | 0.83 | 0.83 | 0.83 | 0.84 | 0.83 | 0.83 |
Resnet101 | 0.93 | 0.93 | 0.93 | 0.93 | 0.93 | 0.93 | 0.93 |
The interactive version of this graph can be found here.
inception-v3

Class | Precision | Recall | F1-score | Support |
---|---|---|---|---|
fresh_apple | 0.80 | 0.86 | 0.83 | 43 |
rotten_apple | 0.72 | 0.68 | 0.70 | 19 |
fresh_orange | 0.76 | 0.81 | 0.79 | 27 |
rotten_orange | 0.82 | 0.77 | 0.79 | 30 |
fresh_banana | 0.94 | 0.91 | 0.92 | 33 |
rotten_banana | 0.85 | 1.00 | 0.92 | 17 |
fresh_mango | 0.91 | 0.88 | 0.90 | 34 |
rotten_mango | 0.83 | 0.71 | 0.77 | 21 |
accuracy | | | 0.83 | 224 |
macro avg | 0.83 | 0.83 | 0.83 | 224 |
weighted avg | 0.84 | 0.83 | 0.83 | 224 |

resnet101

Class | Precision | Recall | F1-score | Support |
---|---|---|---|---|
fresh_apple | 0.95 | 0.98 | 0.97 | 43 |
rotten_apple | 0.89 | 0.84 | 0.86 | 19 |
fresh_orange | 0.87 | 0.96 | 0.91 | 27 |
rotten_orange | 0.93 | 0.83 | 0.88 | 30 |
fresh_banana | 1.00 | 1.00 | 1.00 | 33 |
rotten_banana | 1.00 | 1.00 | 1.00 | 17 |
fresh_mango | 0.94 | 0.91 | 0.93 | 34 |
rotten_mango | 0.86 | 0.90 | 0.88 | 21 |
accuracy | | | 0.93 | 224 |
macro avg | 0.93 | 0.93 | 0.93 | 224 |
weighted avg | 0.93 | 0.93 | 0.93 | 224 |
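To make the difference between the "macro" and "weighted" rows concrete, here is a small sketch that recomputes both averages for inception-v3 precision from the per-class values in the report above. The function names are illustrative and are not part of this repo.

```python
# Per-class (precision, support) for inception-v3, copied from the report above.
report = {
    "fresh_apple": (0.80, 43),
    "rotten_apple": (0.72, 19),
    "fresh_orange": (0.76, 27),
    "rotten_orange": (0.82, 30),
    "fresh_banana": (0.94, 33),
    "rotten_banana": (0.85, 17),
    "fresh_mango": (0.91, 34),
    "rotten_mango": (0.83, 21),
}


def macro_average(rows):
    """Unweighted mean: every class counts equally, regardless of size."""
    return sum(score for score, _ in rows.values()) / len(rows)


def weighted_average(rows):
    """Mean weighted by support: classes with more samples count more."""
    total = sum(support for _, support in rows.values())
    return sum(score * support for score, support in rows.values()) / total


macro = macro_average(report)
weighted = weighted_average(report)
print(f"macro precision:    {macro:.4f}")
print(f"weighted precision: {weighted:.4f}")
```

Because the inputs here are the already-rounded per-class values, the recomputed averages can differ from the reported ones in the last digit.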
# Using Anaconda / Miniconda
conda env create -f fruit.yml
conda activate frutify
# After cloning the repository, change into its directory
cd frutify-torch
# create a directory for saving models if you're training
mkdir saved_models/
Already-trained models, which you can use to run inference, can be found here:
Model | Link |
---|---|
InceptionV3 | Link |
Resnet101 | Link |
Download the saved models and store them in the saved_models directory, or save them elsewhere and pass the paths to the script.
python test.py --inception_path --resnet_path --split 0.8 --batch_size --saved_path
# example
python test.py --inception_path "./saved_models/inception-v3_5_16_0.001_1626740279.94928.ckpt" --resnet_path "./saved_models/resnet101_5_16_0.001_1626743414.734695.ckpt" --batch_size 16 --split 0.8
# Note: if you have a powerful multicore CPU, the --num_workers option can speed up
# data loading; pass the number of cores you want to use.
python test.py --inception_path "./saved_models/inception-v3_5_16_0.001_1626740279.94928.ckpt" --resnet_path "./saved_models/resnet101_5_16_0.001_1626743414.734695.ckpt" --num_workers 2 --batch_size 16 --split 0.8
For more command-line options, check test.py.
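The checkpoint file names in the examples above appear to follow the pattern `<model>_<epochs>_<batch_size>_<lr>_<timestamp>.ckpt`. That pattern is inferred from the two example names and is not documented here, so treat the following parser as a hedged sketch; `parse_checkpoint_name` and `CheckpointInfo` are hypothetical names, not part of this repo.

```python
from pathlib import Path
from typing import NamedTuple


class CheckpointInfo(NamedTuple):
    model: str
    epochs: int
    batch_size: int
    learning_rate: float
    timestamp: float  # looks like a Unix timestamp (assumption)


def parse_checkpoint_name(path: str) -> CheckpointInfo:
    """Split a checkpoint file name into the fields it appears to encode."""
    name = Path(path).name
    if not name.endswith(".ckpt"):
        raise ValueError(f"not a .ckpt file: {name}")
    stem = name[: -len(".ckpt")]
    # Split from the right so underscores in the model name would not break parsing.
    parts = stem.rsplit("_", 4)
    if len(parts) != 5:
        raise ValueError(f"unexpected checkpoint name: {name}")
    model, epochs, batch_size, lr, timestamp = parts
    return CheckpointInfo(model, int(epochs), int(batch_size), float(lr), float(timestamp))


info = parse_checkpoint_name(
    "./saved_models/inception-v3_5_16_0.001_1626740279.94928.ckpt"
)
print(info)
```

Reading the fields this way can help confirm that the `--batch_size` passed to test.py matches the one a checkpoint was trained with.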
Download the dataset from the provided link and copy the dataset directory from the zip archive into the project directory.
Note: Use a GPU (>= 8 GB VRAM), unless you are willing to wait 20+ minutes for each epoch to finish.
python trainer.py --model --device --split --batch_size --epochs --lr
# example
python trainer.py --model resnet101 --device gpu --split 0.8 --batch_size 16 --epochs 5 --lr 1e-3
# Note: if you have a powerful CPU with more than 4 cores, the --num_workers option can speed up
# data loading; pass the number of cores you want to use.
python trainer.py --model resnet101 --device gpu --num_workers 12 --split 0.8 --batch_size 16 --epochs 5 --lr 1e-3
For more command-line options, check trainer.py.
The comet.ml logger doesn't work with multiple workers, which is a known issue. If you want to use comet.ml to visualize model training, don't use the --num_workers option. (It'll be slower, but this is the only way, sadly!)
# example for using comet ml logger
python trainer.py --model inception-v3 --device gpu --split 0.8 --batch_size 16 --epochs 5 --lr 1e-3 --comet True
Check the comet.ml website for how to get an API key and get started with PyTorch and pytorch-lightning.
One of the co-authors, Md Abdul Ahad Chowdhury, implemented a C#-based rewrite of the project, which you can find here.
@article{ashraf2019fruit,
title={Fruit Image Classification Using Convolutional Neural Networks},
author={Ashraf, Shawon and Kadery, Ivan and Chowdhury, Md Abdul Ahad and Mahbub, Tahsin Zahin and Rahman, Rashedur M},
journal={International Journal of Software Innovation (IJSI)},
volume={7},
number={4},
pages={51--70},
year={2019},
publisher={IGI Global}
}