Classifying fruits and vegetables via transfer learning
Today, we will explore an existing image classifier, the ResNet-18 model, and apply transfer learning to it to classify 131 types of fruits and vegetables.
Here is a brief overview of this article:
- We will be using the PyTorch framework to train our model.
- We will also be using Kaggle to write our code and train the model on a GPU. You will need a Kaggle account.
- The dataset is the Fruit-Images-Dataset.
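- Total number of images is 90,483, including 103 images with more than one fruit per image that are not part of the training or test sets used here.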
- Total number of training images is 67,692 (one fruit or vegetable per image).
- Total number of testing images is 22,688 (one fruit or vegetable per image).
- You can get all the code and training steps from Jovian.
- I will only show the function calls here to avoid cluttering the page.
You may refer to the Jovian notebook for the helper function definitions.
To begin, we will explore the dataset to better understand it.
In the notebook, let's first clone the dataset from GitHub.
!git clone https://github.com/Horea94/Fruit-Images-Dataset
Import all the required libraries.
import os
import torch
import matplotlib.pyplot as plt
from torch import nn
from torch import optim
from os.path import join
from torchvision import models
from tqdm.notebook import tqdm
from torchvision import transforms
from torch.nn import functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader, random_split
Let's also check which device we are running on.
Make sure the GPU accelerator is enabled in the Kaggle notebook settings.
device = get_default_device()
print(device)
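get_default_device, to_device and DeviceDataLoader are helpers defined in the Jovian notebook and are not shown in this article. The sketch below shows what helpers with these names commonly look like in PyTorch notebooks; treat it as an assumption about their behavior rather than the author's exact code.

```python
import torch


def get_default_device():
    """Pick the GPU if one is available, otherwise fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def to_device(data, device):
    """Move a tensor or model (or a list/tuple of them) to the given device."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)


class DeviceDataLoader:
    """Wrap a DataLoader so every batch it yields is already on `device`."""

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        # Move each (images, labels) batch lazily, one batch at a time
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)
```

Moving batches lazily inside `__iter__` keeps the whole dataset in CPU memory and only transfers the batch that is currently being used.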
Next, we will load the datasets into PyTorch.
Since the dataset only provides training and test folders, we split the existing test set in half: 50% for validation and 50% for testing.
# Dataset folders and loaders
DATA_DIR = join(os.getcwd(), 'Fruit-Images-Dataset')
TRAIN_DIR = join(DATA_DIR, 'Training')
TEST_DIR = join(DATA_DIR, 'Test')
batch_size = 32
validation_pct_of_test = 0.5  # half of the original test set becomes the validation set
data_transformer = transforms.Compose([transforms.ToTensor()])
train_dataset = ImageFolder(TRAIN_DIR, transform=data_transformer)
test_dataset_split = ImageFolder(TEST_DIR, transform=data_transformer)
n = len(test_dataset_split)
n_validation = int(validation_pct_of_test * n)
# ImageFolder stores samples sorted by class, so split randomly (fixed, arbitrary seed)
# instead of by index range; otherwise validation and test would hold different classes
validation_dataset, test_dataset = random_split(
    test_dataset_split, [n_validation, n - n_validation],
    generator=torch.Generator().manual_seed(42))
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=2, pin_memory=True)
# Evaluation loaders need no shuffling; random_split already mixed the classes
validation_dataloader = DataLoader(validation_dataset, batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=2, pin_memory=True)
# Wrap the loaders so each batch is moved to the selected device
train_dataloader = DeviceDataLoader(train_dataloader, device)
validation_dataloader = DeviceDataLoader(validation_dataloader, device)
test_dataloader = DeviceDataLoader(test_dataloader, device)
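Why the split method matters: ImageFolder sorts its samples by class folder, so splitting the test set by index range (for example with Subset and range) gives two halves that mostly contain different classes, whereas torch.utils.data.random_split shuffles the indices before splitting. The toy example below uses made-up labels in place of a real ImageFolder, and the seed of 0 is arbitrary.

```python
import torch
from torch.utils.data import Subset, random_split

# Toy stand-in for ImageFolder's targets: 4 classes, 5 samples each, sorted by class
targets = [cls for cls in range(4) for _ in range(5)]
n = len(targets)
n_validation = n // 2


def classes_in(subset):
    """Sorted list of the distinct labels that appear in a subset."""
    return sorted({subset[i] for i in range(len(subset))})


# Contiguous split by index range: each half only sees some of the classes
val_range = Subset(targets, range(n_validation))
test_range = Subset(targets, range(n_validation, n))
print(classes_in(val_range), classes_in(test_range))  # [0, 1] [2, 3]

# Random split with a fixed seed: indices are shuffled before splitting
val_rand, test_rand = random_split(
    targets, [n_validation, n - n_validation],
    generator=torch.Generator().manual_seed(0),
)
print(classes_in(val_rand), classes_in(test_rand))
```

A random split does not strictly guarantee that every class appears in both halves, but with roughly 80 to 170 test images per class in this dataset, it is very unlikely that a class would be missing.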
Here is an overview of the training, validation and test datasets.
# Information on the dataset
print('Number of training images:', len(train_dataset))
print('Number of validation images:', len(validation_dataset))
print('Number of testing images:', len(test_dataset))
print('Number of classes:', len(train_dataset.classes))
for idx, cls in enumerate(train_dataset.classes):
    print(f'{idx}: {cls}')
Let's take a peek at some of the training images.
# Information on a single image
images, labels = next(iter(train_dataloader))
plot_img(images, labels)
print('Image shape:', images[0].shape)
Next, set up the hyperparameters, initialize the model, and start training.
# Set hyperparams, initialize model, then start training
num_epochs = 25
batch_size = 32
learning_rate = 1e-4
momentum = .9
opt_func = torch.optim.SGD
model = to_device(FruitModel(), device)
history = fit(num_epochs, learning_rate, model, train_dataloader, validation_dataloader, opt_func)
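We train our ResNet-18 model for 25 epochs with a batch size of 32 on the fruit and vegetable images. As the epochs progress, the loss gradually decreases and the accuracy increases. Validation accuracy keeps improving overall, with only small fluctuations, which suggests that the model has not yet started to overfit within the 25 epochs. Note that the logged train_acc values are all multiples of 1/12 (25.00%, 83.33%, 91.67%, 100.00%); since 67,692 = 2,115 × 32 + 12, they most likely reflect only the final, partial batch of 12 images in each epoch, which would explain why they fluctuate.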
Epoch [0], train_loss: 4.6832, val_loss: 4.5394, train_acc: 25.00%, val_acc: 45.75%
Epoch [1], train_loss: 4.4252, val_loss: 4.3475, train_acc: 83.33%, val_acc: 80.10%
Epoch [2], train_loss: 4.2717, val_loss: 4.2281, train_acc: 83.33%, val_acc: 88.16%
Epoch [3], train_loss: 4.1752, val_loss: 4.1462, train_acc: 100.00%, val_acc: 91.95%
Epoch [4], train_loss: 4.1144, val_loss: 4.0996, train_acc: 100.00%, val_acc: 92.58%
Epoch [5], train_loss: 4.0740, val_loss: 4.0661, train_acc: 83.33%, val_acc: 94.09%
Epoch [6], train_loss: 4.0457, val_loss: 4.0409, train_acc: 100.00%, val_acc: 94.41%
Epoch [7], train_loss: 4.0248, val_loss: 4.0232, train_acc: 100.00%, val_acc: 95.17%
Epoch [8], train_loss: 4.0088, val_loss: 4.0109, train_acc: 91.67%, val_acc: 95.69%
Epoch [9], train_loss: 3.9958, val_loss: 4.0013, train_acc: 91.67%, val_acc: 95.76%
Epoch [10], train_loss: 3.9855, val_loss: 3.9911, train_acc: 100.00%, val_acc: 95.98%
Epoch [11], train_loss: 3.9771, val_loss: 3.9843, train_acc: 91.67%, val_acc: 95.81%
Epoch [12], train_loss: 3.9700, val_loss: 3.9790, train_acc: 91.67%, val_acc: 96.63%
Epoch [13], train_loss: 3.9641, val_loss: 3.9713, train_acc: 100.00%, val_acc: 96.59%
Epoch [14], train_loss: 3.9586, val_loss: 3.9679, train_acc: 100.00%, val_acc: 96.94%
Epoch [15], train_loss: 3.9540, val_loss: 3.9634, train_acc: 100.00%, val_acc: 96.99%
Epoch [16], train_loss: 3.9500, val_loss: 3.9594, train_acc: 100.00%, val_acc: 96.88%
Epoch [17], train_loss: 3.9466, val_loss: 3.9566, train_acc: 100.00%, val_acc: 97.05%
Epoch [18], train_loss: 3.9433, val_loss: 3.9536, train_acc: 100.00%, val_acc: 97.40%
Epoch [19], train_loss: 3.9404, val_loss: 3.9503, train_acc: 100.00%, val_acc: 97.56%
Epoch [20], train_loss: 3.9378, val_loss: 3.9476, train_acc: 91.67%, val_acc: 97.62%
Epoch [21], train_loss: 3.9354, val_loss: 3.9464, train_acc: 100.00%, val_acc: 97.86%
Epoch [22], train_loss: 3.9334, val_loss: 3.9441, train_acc: 100.00%, val_acc: 97.67%
Epoch [23], train_loss: 3.9314, val_loss: 3.9431, train_acc: 100.00%, val_acc: 97.77%
Epoch [24], train_loss: 3.9297, val_loss: 3.9404, train_acc: 91.67%, val_acc: 97.93%
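One detail in this log is worth a closer look. With 131 classes, a uniform guess has a cross-entropy loss of ln(131) ≈ 4.875, yet the loss here levels off around 3.93 even at roughly 98% accuracy, while a confident classifier normally drives cross-entropy well below 1. One possible explanation, which we cannot confirm without the FruitModel and fit definitions, is that a softmax is applied to the outputs before `F.cross_entropy` (which applies its own log-softmax). The inputs would then lie in [0, 1], and the lowest achievable loss is ln(e + 130) − 1 ≈ 3.888. The plain-Python snippet below checks both numbers.

```python
import math

NUM_CLASSES = 131


def cross_entropy(logits, target):
    """Cross-entropy of one example: -log(softmax(logits)[target])."""
    log_sum_exp = math.log(sum(math.exp(z) for z in logits))
    return log_sum_exp - logits[target]


# Uniform guess: all logits equal, so the loss is ln(131)
uniform = cross_entropy([0.0] * NUM_CLASSES, 0)

# Best case if probabilities are fed in as logits: 1.0 for the true class, 0.0 elsewhere
probs_as_logits = [1.0] + [0.0] * (NUM_CLASSES - 1)
floor = cross_entropy(probs_as_logits, 0)

print(f'uniform guess: {uniform:.3f}')              # ≈ 4.875
print(f'floor with double softmax: {floor:.3f}')    # ≈ 3.888
```

If that is the cause, returning raw logits from the model would let the loss approach 0 and give larger gradients, which may also speed up training.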
Let's visualize these results in a graph.
The training and validation accuracies are not far off from each other, and the validation accuracy reaches about 98% by the final epoch.
# Plot train-validation accuracy
train_acc = [i['train_correct'] / i['train_total'] for i in history]
val_acc = [i['val_correct'] / i['val_total'] for i in history]
plot_chart('Train-Validation Accuracy', ['train', 'validation'], [train_acc, val_acc], 'number of epochs', 'accuracy')
The training and validation losses also decrease steadily and track each other closely.
# Plot training-validation loss
train_loss = [i['train_loss'] for i in history]
val_loss = [i['val_loss'] for i in history]
plot_chart('Train-Validation Loss', ['train', 'validation'], [train_loss, val_loss], 'number of epochs', 'loss')
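The plotting code reads train_loss, val_loss, train_correct, train_total, val_correct and val_total from each entry of history, so fit must record those keys. fit itself lives in the notebook; the hypothetical fit_sketch below shows one way to write such a loop. It assumes, like the call to fit in this article, that opt_func is called as `opt_func(model.parameters(), lr)`, and it accumulates accuracy over the whole epoch rather than a single batch.

```python
import torch
import torch.nn.functional as F


def fit_sketch(num_epochs, lr, model, train_dl, val_dl, opt_func=torch.optim.SGD):
    """Hypothetical training loop that records the keys the plotting code reads."""
    optimizer = opt_func(model.parameters(), lr)
    history = []
    for epoch in range(num_epochs):
        model.train()
        train_loss_sum, train_correct, train_total = 0.0, 0, 0
        for images, labels in train_dl:
            logits = model(images)
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate over the whole epoch, not just the last batch
            train_loss_sum += loss.item() * len(labels)
            train_correct += (logits.argmax(dim=1) == labels).sum().item()
            train_total += len(labels)

        model.eval()
        val_loss_sum, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in val_dl:
                logits = model(images)
                val_loss_sum += F.cross_entropy(logits, labels, reduction='sum').item()
                val_correct += (logits.argmax(dim=1) == labels).sum().item()
                val_total += len(labels)

        history.append({
            'train_loss': train_loss_sum / train_total,
            'val_loss': val_loss_sum / val_total,
            'train_correct': train_correct, 'train_total': train_total,
            'val_correct': val_correct, 'val_total': val_total,
        })
        print(f"Epoch [{epoch}], train_loss: {history[-1]['train_loss']:.4f}, "
              f"val_loss: {history[-1]['val_loss']:.4f}, "
              f"train_acc: {train_correct / train_total:.2%}, "
              f"val_acc: {val_correct / val_total:.2%}")
    return history
```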
Let's run the model on a sample batch of test images.
# Prediction on sample testing data
images, labels = next(iter(test_dataloader))
predict_and_plot(images, labels)
Finally, let's run the entire testing dataset on the model.
# Prediction on testing data
test_acc = predict_dl(test_dataloader, model)
print(f'Accuracy on test data: {test_acc:.2%}')
Accuracy on test data: 98.46%
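predict_dl also comes from the notebook. Judging by how its result is printed with `:.2%`, it returns the fraction of correctly classified images. A minimal sketch of such an evaluation helper, under that assumption (the name accuracy_on_loader is hypothetical):

```python
import torch


@torch.no_grad()
def accuracy_on_loader(dataloader, model):
    """Hypothetical stand-in for predict_dl: fraction of correct predictions."""
    model.eval()  # disable dropout and use running BatchNorm statistics
    correct, total = 0, 0
    for images, labels in dataloader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total
```

Calling `model.eval()` matters for ResNet-18 because its BatchNorm layers otherwise normalize with per-batch statistics during evaluation.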
Conclusion and closing thoughts
The model's accuracy on the test data (98.46%) was slightly higher than its final validation accuracy (97.93%). Since the test images were never used during training, this small gap suggests that the model generalizes well and is not overfitting to its training data.
It may be possible to further improve the model's accuracy by feeding it more varied images via data augmentation. There is also room for improvement in the hyperparameters, since they have not been tuned exhaustively.
That concludes our short exploration of retraining an existing image classifier to predict fruit and vegetable images. I hope you have learned something from this article, thanks for reading!