weechien / fruits-360

A fruits and vegetables image classifier

Classifying fruits and vegetables via transfer learning



Today, we will take an existing image classifier, the ResNet-18 model, and apply transfer learning to classify 131 types of fruits and vegetables.
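In practice that means starting from a ResNet-18 pretrained on ImageNet and replacing its final fully connected layer with one that outputs our 131 classes. Here is a minimal sketch of that idea; the class name FruitModelSketch is illustrative, and the article's actual FruitModel lives in the notebook linked below.

# Minimal transfer-learning sketch (assumed; see the linked notebook for the real FruitModel)
import torch.nn as nn
from torchvision import models

class FruitModelSketch(nn.Module):
    def __init__(self, num_classes=131):
        super().__init__()
        # Load ResNet-18 with ImageNet weights (newer torchvision uses weights=... instead of pretrained=True)
        self.network = models.resnet18(pretrained=True)
        # Swap the final fully connected layer for one sized to our 131 classes
        self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)

    def forward(self, x):
        return self.network(x)

The pretrained convolutional layers give the model a head start on general image features, and only the new output layer starts from random weights.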



Let's cover this article briefly:

  1. We will be using the PyTorch framework to train our model.
  2. We will also be using Kaggle to write our code and train the model using a GPU.
    You will need a Kaggle account.
  3. The dataset will be from the Fruit-Images-Dataset.
    • Total number of images is 90,483.
    • Total number of training images is 67,692 (one fruit or vegetable per image).
    • Total number of testing images is 22,688 (one fruit or vegetable per image).
  4. You can get all the code and training steps from Jovian.
  5. I will only show the function calls here to avoid cluttering the page.
    You may refer to the link above for the function definitions.

Let's start by exploring the dataset to get a better understanding of it.

In the notebook, let's first clone the dataset from GitHub.

!git clone https://github.com/Horea94/Fruit-Images-Dataset


Import all the required libraries.

import os
import torch
import matplotlib.pyplot as plt

from torch import nn
from torch import optim
from os.path import join
from torchvision import models
from tqdm.notebook import tqdm
from torchvision import transforms
from torch.nn import functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader, random_split

Let's also check which device we are running on.
Make sure to enable the GPU accelerator on Kaggle.
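get_default_device and to_device are small helpers defined in the notebook; an implementation along these lines is typical (this is an assumed sketch, not the notebook's exact code).

# Assumed device helpers
def get_default_device():
    # Prefer the GPU when one is available, otherwise fall back to the CPU
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def to_device(data, device):
    # Move a tensor, or a list/tuple of tensors, onto the chosen device
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)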

device = get_default_device()
print(device)


Next, we will load the datasets into PyTorch.
We will split the existing test dataset 50/50 into a validation set and a test set.

# Dataset folders and loaders

DATA_DIR = join(os.getcwd(), 'Fruit-Images-Dataset')
TRAIN_DIR = join(DATA_DIR, 'Training')
TEST_DIR = join(DATA_DIR, 'Test')

data_transformer = transforms.Compose([transforms.ToTensor()])

batch_size = 32                # matches the hyperparameter section below
validation_pct_of_test = 0.5   # split the original test set 50/50 into validation and test

test_dataset_split = ImageFolder(TEST_DIR, transform=data_transformer)
n = len(test_dataset_split)
n_validation = int(validation_pct_of_test * n)

train_dataset = ImageFolder(TRAIN_DIR, transform=data_transformer)
validation_dataset = Subset(test_dataset_split, range(n_validation))
test_dataset = Subset(test_dataset_split, range(n_validation, n))

train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=2, pin_memory=True)
validation_dataloader = DataLoader(validation_dataset, batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True, num_workers=2, pin_memory=True)

train_dataloader = DeviceDataLoader(train_dataloader, device)
validation_dataloader = DeviceDataLoader(validation_dataloader, device)
test_dataloader = DeviceDataLoader(test_dataloader, device)
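DeviceDataLoader is a thin wrapper from the notebook that moves each batch onto the chosen device as it is fetched. Assuming the to_device helper sketched above, it usually looks something like this.

# Assumed wrapper; the notebook's DeviceDataLoader follows the same idea
class DeviceDataLoader:
    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def __iter__(self):
        # Yield each batch after moving it to the target device
        for batch in self.dataloader:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dataloader)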

An overview of the datasets and classes is shown below.

# Information on the dataset

print('Number of training images:', len(train_dataset))
print('Number of validation images:', len(validation_dataset))
print('Number of test images:', len(test_dataset))
print('Number of classes:', len(train_dataset.classes))
for idx, cls in enumerate(train_dataset.classes):
    print(f'{idx}: {cls}')


Let's take a peek at some of the training images.

# Information on a single image

images, labels = next(iter(train_dataloader))
plot_img(images, labels)
print('Image shape:', images[0].shape)
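plot_img is another notebook helper; a possible sketch using torchvision's make_grid is shown below (names and layout here are assumptions, not the notebook's exact code).

# Assumed plotting helper
from torchvision.utils import make_grid

def plot_img(images, labels, n=16):
    # Arrange the first n images of the batch into a grid and display it
    grid = make_grid(images[:n].cpu(), nrow=8)
    plt.figure(figsize=(12, 6))
    plt.imshow(grid.permute(1, 2, 0).numpy())  # channels-first -> channels-last for matplotlib
    plt.axis('off')
    plt.show()
    # Print the corresponding class names for reference
    print([train_dataset.classes[i] for i in labels[:n].tolist()])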


Set up the hyperparameters, initialize the model, then start training.

# Set hyperparams, initialize model, then start training

num_epochs = 25
batch_size = 32
learning_rate = 1e-4
momentum = .9
opt_func = torch.optim.SGD

model = to_device(FruitModel(), device)

history = fit(num_epochs, learning_rate, model, train_dataloader, validation_dataloader, opt_func)
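The fit function drives a standard training loop: cross-entropy loss on each training batch, an optimiser step, and a validation pass after every epoch. The exact implementation is in the notebook; the sketch below is an assumed version that records the history keys consumed by the plotting cells further down.

# Assumed training loop; the notebook's fit may differ in its bookkeeping
def fit(num_epochs, learning_rate, model, train_dl, val_dl, opt_func):
    history = []
    # The momentum value set above is assumed to be picked up here (opt_func is SGD)
    optimizer = opt_func(model.parameters(), lr=learning_rate, momentum=momentum)
    for epoch in range(num_epochs):
        # Training pass
        model.train()
        train_losses, train_correct, train_total = [], 0, 0
        for images, labels in tqdm(train_dl):
            out = model(images)
            loss = F.cross_entropy(out, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_losses.append(loss.item())
            train_correct += (out.argmax(dim=1) == labels).sum().item()
            train_total += labels.size(0)
        # Validation pass with gradients disabled
        model.eval()
        val_losses, val_correct, val_total = [], 0, 0
        with torch.no_grad():
            for images, labels in val_dl:
                out = model(images)
                val_losses.append(F.cross_entropy(out, labels).item())
                val_correct += (out.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)
        stats = {
            'train_loss': sum(train_losses) / len(train_losses),
            'val_loss': sum(val_losses) / len(val_losses),
            'train_correct': train_correct, 'train_total': train_total,
            'val_correct': val_correct, 'val_total': val_total,
        }
        history.append(stats)
        print(f"Epoch [{epoch}], train_loss: {stats['train_loss']:.4f}, "
              f"val_loss: {stats['val_loss']:.4f}, "
              f"train_acc: {train_correct / train_total:.2%}, "
              f"val_acc: {val_correct / val_total:.2%}")
    return history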

We are training our ResNet-18 model for 25 epochs with a batch size of 32, feeding it the fruit and vegetable images. As the epochs progress, the losses gradually go down and the accuracies go up. The validation loss tracks the training loss closely, which indicates that the model has yet to overfit over the 25 epochs.

Epoch [0], train_loss: 4.6832, val_loss: 4.5394, train_acc: 25.00%, val_acc: 45.75%
Epoch [1], train_loss: 4.4252, val_loss: 4.3475, train_acc: 83.33%, val_acc: 80.10%
Epoch [2], train_loss: 4.2717, val_loss: 4.2281, train_acc: 83.33%, val_acc: 88.16%
Epoch [3], train_loss: 4.1752, val_loss: 4.1462, train_acc: 100.00%, val_acc: 91.95%
Epoch [4], train_loss: 4.1144, val_loss: 4.0996, train_acc: 100.00%, val_acc: 92.58%
Epoch [5], train_loss: 4.0740, val_loss: 4.0661, train_acc: 83.33%, val_acc: 94.09%
Epoch [6], train_loss: 4.0457, val_loss: 4.0409, train_acc: 100.00%, val_acc: 94.41%
Epoch [7], train_loss: 4.0248, val_loss: 4.0232, train_acc: 100.00%, val_acc: 95.17%
Epoch [8], train_loss: 4.0088, val_loss: 4.0109, train_acc: 91.67%, val_acc: 95.69%
Epoch [9], train_loss: 3.9958, val_loss: 4.0013, train_acc: 91.67%, val_acc: 95.76%
Epoch [10], train_loss: 3.9855, val_loss: 3.9911, train_acc: 100.00%, val_acc: 95.98%
Epoch [11], train_loss: 3.9771, val_loss: 3.9843, train_acc: 91.67%, val_acc: 95.81%
Epoch [12], train_loss: 3.9700, val_loss: 3.9790, train_acc: 91.67%, val_acc: 96.63%
Epoch [13], train_loss: 3.9641, val_loss: 3.9713, train_acc: 100.00%, val_acc: 96.59%
Epoch [14], train_loss: 3.9586, val_loss: 3.9679, train_acc: 100.00%, val_acc: 96.94%
Epoch [15], train_loss: 3.9540, val_loss: 3.9634, train_acc: 100.00%, val_acc: 96.99%
Epoch [16], train_loss: 3.9500, val_loss: 3.9594, train_acc: 100.00%, val_acc: 96.88%
Epoch [17], train_loss: 3.9466, val_loss: 3.9566, train_acc: 100.00%, val_acc: 97.05%
Epoch [18], train_loss: 3.9433, val_loss: 3.9536, train_acc: 100.00%, val_acc: 97.40%
Epoch [19], train_loss: 3.9404, val_loss: 3.9503, train_acc: 100.00%, val_acc: 97.56%
Epoch [20], train_loss: 3.9378, val_loss: 3.9476, train_acc: 91.67%, val_acc: 97.62%
Epoch [21], train_loss: 3.9354, val_loss: 3.9464, train_acc: 100.00%, val_acc: 97.86%
Epoch [22], train_loss: 3.9334, val_loss: 3.9441, train_acc: 100.00%, val_acc: 97.67%
Epoch [23], train_loss: 3.9314, val_loss: 3.9431, train_acc: 100.00%, val_acc: 97.77%
Epoch [24], train_loss: 3.9297, val_loss: 3.9404, train_acc: 91.67%, val_acc: 97.93%

Let's visualize the data above in a chart.
The training and validation accuracies stay close to each other, with the validation accuracy ending just under 98%.

# Plot train-validation accuracy

train_acc = [i['train_correct'] / i['train_total'] for i in history]
val_acc = [i['val_correct'] / i['val_total'] for i in history]

plot_chart('Train-Validation Accuracy', ['train', 'validation'], [train_acc, val_acc], 'number of epochs', 'accuracy')
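plot_chart is a small matplotlib helper defined in the notebook; an assumed sketch that matches the call above:

# Assumed charting helper; the notebook's plot_chart may differ in styling
def plot_chart(title, legend, series, xlabel, ylabel):
    plt.figure()
    for values in series:
        plt.plot(values)  # one line per series, indexed by epoch
    plt.title(title)
    plt.legend(legend)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()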


The training and validation losses also decrease steadily and stay close together.

# Plot training-validation loss

train_loss = [i['train_loss'] for i in history]
val_loss = [i['val_loss'] for i in history]

plot_chart('Train-Validation Loss', ['train', 'validation'], [train_loss, val_loss], 'number of epochs', 'loss')


Let's run the model on a sample of the test data.

# Prediction on sample testing data

images, labels = next(iter(test_dataloader))

predict_and_plot(images, labels)
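predict_and_plot runs a forward pass on the batch and shows the images with their predicted classes. A rough sketch, reusing the plot_img sketch from earlier and assuming access to the model and train_dataset defined above:

# Assumed helper: predict labels for a batch and show them next to the images
def predict_and_plot(images, labels, n=8):
    model.eval()
    with torch.no_grad():
        preds = model(images).argmax(dim=1)
    plot_img(images[:n], labels[:n])
    classes = train_dataset.classes
    for true_idx, pred_idx in zip(labels[:n].tolist(), preds[:n].tolist()):
        print(f'actual: {classes[true_idx]:<25} predicted: {classes[pred_idx]}')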


Finally, let's run the model on the entire test dataset.

# Prediction on testing data

test_preds = predict_dl(test_dataloader, model)
print(f'Accuracy on test data: {test_preds:.2%}')

Accuracy on test data: 98.46%
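predict_dl simply measures accuracy over an entire (device-wrapped) data loader; an assumed implementation is sketched below.

# Assumed implementation: fraction of correct predictions over a data loader
def predict_dl(dataloader, model):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total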

Conclusion and closing thoughts

At 98.46%, the model's accuracy on the test data turned out even slightly better than on the validation data. This suggests that the model is not overfitting to its training data.

It's possible to further improve the model's accuracy by feeding it more varied images via data augmentation. There is also room for improvement in the hyperparameters, as they have not been tuned exhaustively.
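For instance, a lightly augmented training transform could be swapped in when building train_dataset; the specific augmentations below are only suggestions, not what the notebook uses.

# Assumed augmentation pipeline for the training set (keep plain ToTensor for validation/test)
augmented_transformer = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),                         # small random rotations
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # mild lighting changes
    transforms.ToTensor(),
])

# train_dataset = ImageFolder(TRAIN_DIR, transform=augmented_transformer)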

That concludes our short exploration of retraining an existing image classifier to recognize fruit and vegetable images. I hope you have learned something from this article. Thanks for reading!
