This repository implements a mushroom type classifier in PyTorch, comparing several models to improve performance. The project also includes an analysis of model behavior using Gradient-weighted Class Activation Mapping (Grad-CAM) visualization.
The dataset was obtained from Kaggle, specifically from the "LOVE OF A LIFETIME" collection. It contains nine mushroom classes, which were downloaded from Kaggle and split into train (65%), validation (20%), and test (15%) sets. The split was stratified, applying the same ratios to every class.
The download and split steps are implemented in Download and split.ipynb
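The notebook's exact code is not reproduced here, but a minimal sketch of a per-class (stratified) split with those ratios could look like this (folder layout, function name, and the fixed seed are assumptions):

```python
import random
import shutil
from pathlib import Path

def split_class(class_dir: Path, out_root: Path, ratios=(0.65, 0.20, 0.15), seed=42):
    """Split one class folder into train/val/test folders with the given ratios."""
    images = sorted(class_dir.glob("*"))
    random.Random(seed).shuffle(images)
    n_train = int(len(images) * ratios[0])
    n_val = int(len(images) * ratios[1])
    splits = {
        "train": images[:n_train],
        "val": images[n_train:n_train + n_val],
        "test": images[n_train + n_val:],
    }
    for split_name, files in splits.items():
        dest = out_root / split_name / class_dir.name
        dest.mkdir(parents=True, exist_ok=True)
        for f in files:
            shutil.copy(f, dest / f.name)

# Applying the same ratios to every class folder keeps the split stratified:
# for class_dir in Path("mushrooms").iterdir():
#     split_class(class_dir, Path("data"))
```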
- Resize to 299×299
- Normalize channels using the PyTorch-recommended ImageNet statistics
- mean = [0.485, 0.456, 0.406]
- std = [0.229, 0.224, 0.225]
- RandomHorizontalFlip
- RandomVerticalFlip
- RandomRotation with a maximum of 15 degrees
The data is loaded and augmented in Dataset_Generator.py
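A sketch of this augmentation pipeline using standard torchvision transforms (the exact composition in Dataset_Generator.py may differ, e.g. in how the validation/test transforms are handled):

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((299, 299)),                    # resize to 299x299
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=15),            # rotate up to ±15 degrees
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # PyTorch/ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

eval_transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```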
- ResNet50
- ConvNeXt
- Linear(out_features, out_features // 2)
- BatchNorm(out_features // 2)
- ReLU()
- Dropout(0.3)
- Linear(out_features // 2, out_features // 4)
- BatchNorm(out_features // 4)
- ReLU()
- Dropout(0.2)
- Linear(out_features // 4, number of classes)
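The list above describes the custom classification head placed on top of the pretrained backbone. A sketch of attaching it to a ResNet50 (here `out_features` is the backbone's feature dimension and `num_classes` is 9; for the ConvNeXt variant the `classifier` module would be replaced instead of `fc`):

```python
import torch.nn as nn
from torchvision import models

num_classes = 9
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
out_features = model.fc.in_features          # 2048 for ResNet50

model.fc = nn.Sequential(
    nn.Linear(out_features, out_features // 2),
    nn.BatchNorm1d(out_features // 2),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(out_features // 2, out_features // 4),
    nn.BatchNorm1d(out_features // 4),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(out_features // 4, num_classes),
)
```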
Grad-CAM is a visualization technique that allows us to understand what the network focuses on when making decisions based on an image. It combines the concepts of a saliency map and a class activation map. Grad-CAM works by computing the gradients of the target class score with respect to the feature maps of a convolutional layer, which reveals the parts of the image that contribute most to the network assigning a high probability to that class.
By utilizing Grad-CAM, we can generate informative heatmaps that highlight the regions in the input image that are most influential in the network's decision-making process. These heatmaps help us interpret and analyze the model's behavior by visualizing the areas that the network pays the most attention to when classifying mushroom types.
- https://arxiv.org/pdf/1610.02391.pdf
- https://medium.com/@ninads79shukla/gradcam-73a752d368be
- https://towardsdatascience.com/understand-your-algorithm-with-grad-cam-d3b62fce353
Grad-CAM is implemented in Grad_cam.py
Grad_cam_utils.py
This file contains functions to generate heatmaps using Grad-CAM and plot them.
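A minimal Grad-CAM sketch using forward and backward hooks on a convolutional layer (the implementation in Grad_cam.py and Grad_cam_utils.py may be organized differently):

```python
import torch
import torch.nn.functional as F

def grad_cam(model, image, target_layer, class_idx=None):
    """Compute a Grad-CAM heatmap for one image tensor of shape [1, 3, H, W]."""
    activations, gradients = [], []

    # Capture the layer's feature maps and the gradient of the score w.r.t. them
    fwd = target_layer.register_forward_hook(lambda m, i, o: activations.append(o))
    bwd = target_layer.register_full_backward_hook(lambda m, gi, go: gradients.append(go[0]))

    model.eval()
    logits = model(image)
    if class_idx is None:
        class_idx = logits.argmax(dim=1).item()

    model.zero_grad()
    logits[0, class_idx].backward()                   # gradients of the target class score

    fwd.remove()
    bwd.remove()

    acts, grads = activations[0], gradients[0]
    weights = grads.mean(dim=(2, 3), keepdim=True)    # global-average-pool the gradients
    cam = F.relu((weights * acts).sum(dim=1))         # weighted sum of feature maps
    cam = F.interpolate(cam.unsqueeze(1), size=image.shape[2:],
                        mode="bilinear", align_corners=False)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam.squeeze().detach().cpu().numpy(), class_idx
```

For a ResNet50 the `target_layer` would typically be the last convolutional block (e.g. `model.layer4[-1]`); the equivalent last stage would be used for ConvNeXt.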
- Accuracy
- Recall
- Precision
- F1-score
The metrics are implemented in Metrics.py
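A sketch of computing these metrics from predictions and ground-truth labels (shown here with scikit-learn; Metrics.py may compute them differently, e.g. directly in PyTorch):

```python
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def classification_metrics(y_true, y_pred):
    """Return accuracy plus macro-averaged precision, recall and F1."""
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average="macro", zero_division=0),
        "recall": recall_score(y_true, y_pred, average="macro", zero_division=0),
        "f1": f1_score(y_true, y_pred, average="macro", zero_division=0),
    }
```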
- show_batch : shows random images from each class
- show_aug_batch : shows 9 random images after the augmentation transforms
- plot_results : plots the same metric for both the training and validation sets (see the sketch below)
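A sketch of how such a results-plot helper might look (matplotlib; the actual signature in the repository may differ):

```python
import matplotlib.pyplot as plt

def plot_results(train_values, val_values, metric_name):
    """Plot one metric over epochs for the train and validation sets."""
    epochs = range(1, len(train_values) + 1)
    plt.plot(epochs, train_values, label=f"train {metric_name}")
    plt.plot(epochs, val_values, label=f"val {metric_name}")
    plt.xlabel("epoch")
    plt.ylabel(metric_name)
    plt.legend()
    plt.show()
```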
- Learning rate : 1e-3 with cosine annealing scheduler
- Optimizer : Adam
- Epochs : 100
- Loss : Cross entropy
- Freeze the weights of the backbone
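A sketch of this training configuration, assuming `model` is the backbone with the custom head from the earlier sketch and `train_loader` (name assumed) comes from Dataset_Generator.py:

```python
import torch
import torch.nn as nn

# Freeze every parameter except the new classifier head (for ResNet50 the head is `fc`)
for name, param in model.named_parameters():
    if not name.startswith("fc"):
        param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()                      # cosine annealing step once per epoch
```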
After examining the ResNet notebook, it appears that the model is unable to effectively handle this particular dataset.
Data | Loss | Accuracy |
---|---|---|
Train | 0.207 | 93.3% |
Val | 0.724 | 77.5% |
Test | 0.656 | 77.8% |
Upon analyzing the dataset, it becomes apparent that mushrooms of the same class exhibit diverse shapes and colors. This variation poses a challenging task for humans to accurately classify the different types of mushrooms.
To address this complexity, a larger and more powerful model is used. A larger model can capture a broader range of features and patterns present in the mushroom images, improving its ability to differentiate between the various types of mushrooms.
According to the results, the ConvNeXt model outperforms the ResNet model, despite the limited number of training samples and some confusion within the dataset. The ConvNeXt model performs well even under these challenging conditions.
Data | Loss | Accuracy |
---|---|---|
Train | 0.16 | 99.8% |
Val | 0.419 | 91.4% |
Test | 0.403 | 99.2% |
The model was trained for 80 epochs using standard cross-entropy loss. It was then trained for an additional 20 epochs using class-weighted cross-entropy loss with label smoothing of 0.1.
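A sketch of the loss used for those last 20 epochs; here the class weights are assumed to come from inverse training-label frequencies, though the exact weighting scheme in the notebook may differ:

```python
import torch
import torch.nn as nn

def weighted_smoothed_ce(train_labels: torch.Tensor, num_classes: int = 9) -> nn.Module:
    """Cross-entropy with inverse-frequency class weights and label smoothing of 0.1."""
    counts = torch.bincount(train_labels, minlength=num_classes).float()
    weights = counts.sum() / (num_classes * counts)   # rarer classes get larger weights
    return nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
```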
Loss plot in the first 80 epochs
Loss plot in the last 20 epochs
Accuracy plot
Report
Confusion matrix
For results on the train and validation sets, see model_analysis.ipynb
In the heatmaps generated by the network, we can observe instances where the network correctly classifies the mushroom type, and the heatmap aligns with the expected regions of importance. However, there are also cases where the network predicts the correct class, but the heatmap may appear misleading or less aligned with the expected regions.
This discrepancy between the heatmap and the expected regions can occur due to various factors. One possible reason is that the network might be relying on features or patterns that are not visually apparent or easily interpretable to humans. Neural networks are capable of learning complex representations and can identify distinguishing characteristics that may not be immediately obvious to us.
The above image presents an example of a misleading heatmap. In this case, the heatmap indicates that the most important region for classification is the fallen tree leaves, which are separate from the mushrooms themselves. It is possible that a majority of images in this class contain such leaves, leading the model to associate the class with their presence.
The following are examples of sensible heatmaps that align with the expected regions of importance for the corresponding mushroom class.