This repo showcases a variety of traditional computer vision methods, unsupervised learning methods and deep learning models for coral skeletal image segmentation.
This is an extension and PyTorch implementation of the Coral Density Estimation Project.
The dataset is downloaded from the original repo by Ainsley Rutterford.
Traditional CV/histogram_based.py: Multi-level histogram thresholding based segmentation
Traditional CV/equalised_otsu.py: Binary Otsu's method
Traditional CV/canny.py: Canny's edge detection
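As a rough illustration of the idea behind equalised_otsu.py, here is a NumPy-only sketch: equalise the grey-level histogram of an 8-bit image, then choose the threshold that maximises Otsu's between-class variance. The function names and the 8-bit single-channel input are assumptions, not the script's actual code. scikit-image offers skimage.exposure.equalize_hist and skimage.filters.threshold_otsu for the same two steps.

```python
import numpy as np

def equalise_hist(img: np.ndarray) -> np.ndarray:
    """Histogram-equalise an 8-bit greyscale image (uint8 in, uint8 out)."""
    hist = np.bincount(img.ravel(), minlength=256)
    cdf = hist.cumsum()
    cdf_min = cdf[cdf > 0][0]
    total = img.size
    if total == cdf_min:  # constant image: nothing to equalise
        return img.copy()
    # Map every grey level through the normalised cumulative histogram.
    lut = np.round((cdf - cdf_min) / (total - cdf_min) * 255)
    lut = np.clip(lut, 0, 255).astype(np.uint8)
    return lut[img]

def otsu_threshold(img: np.ndarray) -> int:
    """Grey level t maximising between-class variance of {<= t} vs {> t}."""
    hist = np.bincount(img.ravel(), minlength=256).astype(np.float64)
    prob = hist / hist.sum()
    levels = np.arange(256)
    omega = np.cumsum(prob)          # probability of the dark class for each t
    mu = np.cumsum(prob * levels)    # cumulative mean of the dark class
    mu_total = mu[-1]
    denom = omega * (1.0 - omega)
    with np.errstate(divide="ignore", invalid="ignore"):
        sigma_b = (mu_total * omega - mu) ** 2 / denom
    sigma_b[denom == 0] = 0.0        # empty class: no valid split
    return int(np.argmax(sigma_b))

def equalised_otsu(img: np.ndarray) -> np.ndarray:
    """Binary mask: True where the equalised image is above the Otsu threshold."""
    eq = equalise_hist(img)
    return eq > otsu_threshold(eq)
```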
Unsupervised Learning/random_walker.py: Random Walker Segmentaion
Unsupervised Learning/K-means-clustering.py: K-means clustering
Unsupervised Learning/Gaussian_Mixture_Model.py: Gaussian Mixture Model
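For the clustering approach, a minimal NumPy sketch of K-means (Lloyd's algorithm) on pixel intensities is shown below. The function name, the deterministic quantile initialisation and the iteration cap are assumptions, not taken from K-means-clustering.py. Relabelling the clusters by ascending mean intensity keeps the output stable from run to run.

```python
import numpy as np

def kmeans_segment(img: np.ndarray, k: int = 2, n_iter: int = 50) -> np.ndarray:
    """Cluster pixel intensities into k groups with Lloyd's algorithm.

    Returns an integer label map with the same shape as img, where label 0
    is the darkest cluster and label k-1 the brightest.
    """
    x = img.astype(np.float64).ravel()
    # Deterministic initialisation: spread the centres over intensity quantiles.
    centres = np.quantile(x, np.linspace(0.0, 1.0, k + 2)[1:-1])
    for _ in range(n_iter):
        # Assignment step: nearest centre for every pixel.
        labels = np.argmin(np.abs(x[:, None] - centres[None, :]), axis=1)
        # Update step: move each centre to the mean of its pixels
        # (an empty cluster keeps its previous centre).
        new_centres = np.array([
            x[labels == j].mean() if np.any(labels == j) else centres[j]
            for j in range(k)
        ])
        if np.allclose(new_centres, centres):
            break
        centres = new_centres
    # Relabel so that cluster ids are ordered by mean intensity.
    order = np.argsort(centres)
    rank = np.empty(k, dtype=int)
    rank[order] = np.arange(k)
    labels = np.argmin(np.abs(x[:, None] - centres[None, :]), axis=1)
    return rank[labels].reshape(img.shape)
```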
This implementation is mostly based on my Fetal_Segmentation repo.
Note: this UNet implementation is a fairly vanilla model; no BatchNorm or Dropout is used. Following the original paper strictly would cause a mismatch between input and output sizes (572 in, 388 out). To avoid a mismatch between labels and predictions in this implementation, a resize function is applied after every up-convolution in the expansive path and at the final output layer.
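The 572 to 388 shrinkage comes from the original paper's unpadded 3x3 convolutions: each one trims 2 pixels from the height and width, each 2x2 max-pool halves them and each 2x2 up-convolution doubles them. The short pure-Python calculation below reproduces the numbers, assuming the paper's layout of four down-sampling and four up-sampling stages with two convolutions per stage; the function name and its parameters are illustrative only. Because the expansive-path maps end up smaller than the matching contracting-path maps, the sizes have to be reconciled somehow: the paper crops, this repo resizes.

```python
def unet_output_size(size: int, depth: int = 4, convs_per_stage: int = 2,
                     padding: int = 0) -> int:
    """Spatial output size of a U-Net for a square input of side `size`.

    Assumes 3x3 convolutions with the given padding, 2x2 stride-2 max-pooling
    on the way down and 2x2 stride-2 up-convolutions on the way up.
    """
    shrink = 2 - 2 * padding  # pixels lost per 3x3 convolution
    # Contracting path: convolutions, then pooling, at each level.
    for _ in range(depth):
        size -= convs_per_stage * shrink
        if size % 2:
            raise ValueError("odd feature map before pooling; a row would be dropped")
        size //= 2
    # Bottleneck convolutions.
    size -= convs_per_stage * shrink
    # Expansive path: up-convolution doubles the size, then convolutions trim it.
    for _ in range(depth):
        size = size * 2 - convs_per_stage * shrink
    return size

print(unet_output_size(572))  # 388, as in the original paper
```

With padding=1 the convolutions preserve size, so an input whose side is divisible by 16 comes back unchanged (for example, 512 in and 512 out).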
data/train_set: Stores all the unsplit training data
data/test_set: Stores all data for prediction
[model]: Stores the best model weights produced during training
[runs]: Stores the log files for TensorBoard
Image Augmentation.py: A script that generates more images using various transformations
main.py: The main script; imports the dataset, trainer and loss functions and runs the model
dataset.py: A customised dataset that processes the training images
model.py: Constructs the SegNet and UNet models
train.py: The trainer that runs the epochs
loss_functions.py: Defines the combined Dice loss + BCE-with-logits loss function
predict.py: Script to predict on unlabelled images
torch == 1.8.0
torchvision
torchsummary
numpy
scipy
scikit-image (skimage)
matplotlib
Pillow (PIL)
Image augmentation is recommended to be performed in a local directory for best performance. Use Image Augmentation.py to generate images and the corresponding labels.
main.py workflow
main.py is the only file that needs to be run; the other utilities are imported into it.
Set Parameters
Set the parameters for the paths to train/image, train/label, test/image, test/label and save_model; also set h and w for the input image, which model to use, the number of epochs to run, batch_size, the learning rate and the learning rate scheduler's drop rate for the optimizer.
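For illustration, the parameters listed above could be gathered in one place, as in the sketch below. Every field name, path and default value here is a hypothetical placeholder, not the repo's actual setting; only the set of parameters comes from this README.

```python
from dataclasses import dataclass

@dataclass
class Params:
    # Hypothetical paths and placeholder values; main.py may use other names.
    train_image_dir: str = "data/train_set/image"
    train_label_dir: str = "data/train_set/label"
    test_image_dir: str = "data/test_set/image"
    test_label_dir: str = "data/test_set/label"
    save_model: str = "model/best_model.pth"
    h: int = 256                 # input image height (placeholder)
    w: int = 256                 # input image width (placeholder)
    model: str = "unet"          # "unet" or "segnet"
    epochs: int = 50
    batch_size: int = 4
    lr: float = 1e-3
    lr_drop_factor: float = 0.1  # multiplier applied when the scheduler drops the LR

    def __post_init__(self) -> None:
        # Fail early on settings the rest of the pipeline could not use.
        if self.model not in {"unet", "segnet"}:
            raise ValueError(f"unknown model: {self.model}")
        if not 0.0 < self.lr_drop_factor < 1.0:
            raise ValueError("lr_drop_factor must be in (0, 1)")
```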
Data Augmentation
Define a series of data transformations that can be applied during data loading.
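For segmentation, any geometric transform has to be applied identically to the image and its label mask, otherwise the pair no longer lines up. Below is a minimal NumPy sketch of such a paired random transform using random horizontal/vertical flips and 90-degree rotations. The function name and the choice of transforms are assumptions, not the repo's exact pipeline.

```python
import numpy as np

def random_flip_rotate(image: np.ndarray, mask: np.ndarray, rng: np.random.Generator):
    """Apply the same random flip/rotation to an (H, W[, C]) image and its (H, W) mask."""
    if image.shape[:2] != mask.shape[:2]:
        raise ValueError("image and mask must share height and width")
    if rng.random() < 0.5:
        image, mask = image[:, ::-1], mask[:, ::-1]   # horizontal flip
    if rng.random() < 0.5:
        image, mask = image[::-1, :], mask[::-1, :]   # vertical flip
    k = int(rng.integers(0, 4))                       # 0 to 3 quarter turns
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    # Copy to drop the negative strides left by slicing and rot90.
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)
```

Drawing one random decision and applying it to both arrays is what keeps labels aligned; sampling the image and mask transforms independently would silently corrupt the training targets.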
Dataset Loader
Call the customised dataset in dataset.py so that the data meets the PyTorch DataLoader standard.
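PyTorch's DataLoader works with map-style datasets, that is, objects implementing __len__ and __getitem__ (normally by subclassing torch.utils.data.Dataset). The sketch below shows that interface for in-memory image/label pairs. The class name, the [0, 1] scaling, the binarised mask and the channel-first layout are assumptions; a file-based version would read from the train/image and train/label paths instead. Plain NumPy is used so the sketch runs without torch.

```python
import numpy as np

class CoralSegDataset:
    """Map-style dataset over in-memory (image, mask) pairs.

    Hypothetical stand-in for dataset.py; in a PyTorch project this would
    subclass torch.utils.data.Dataset, which relies on __len__ and __getitem__.
    """

    def __init__(self, images, masks, transform=None):
        if len(images) != len(masks):
            raise ValueError("images and masks must have the same length")
        self.images = images
        self.masks = masks
        self.transform = transform  # optional callable(image, mask) -> (image, mask)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        image = np.asarray(self.images[idx], dtype=np.float32) / 255.0  # scale to [0, 1]
        mask = (np.asarray(self.masks[idx]) > 0).astype(np.float32)     # binary target
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        if image.ndim == 2:
            image = image[None]                # (H, W) -> (1, H, W)
        elif image.ndim == 3:
            image = image.transpose(2, 0, 1)   # (H, W, C) -> (C, H, W)
        return image, mask[None]               # mask as (1, H, W)
```

Passed to torch.utils.data.DataLoader(dataset, batch_size=..., shuffle=True), the default collate function stacks these NumPy arrays into batched tensors.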
Load Model
Load one of the pre-built models, selected by the model parameter.
Load Loss Function
Import the pre-defined loss_function; the loss is the sum of the binary cross-entropy (BCE) loss and the Dice loss, and the metric is the Dice coefficient.
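Written out, the combined objective is BCE(logits, y) + (1 - Dice(sigmoid(logits), y)), with Dice(p, y) = (2 Σ p·y + ε) / (Σ p + Σ y + ε). The NumPy sketch below evaluates it for a single logit map. The smoothing constant ε = 1 and the unweighted sum of the two terms are assumptions; the repo's loss_functions.py operates on PyTorch tensors rather than NumPy arrays.

```python
import numpy as np

def dice_coefficient(probs: np.ndarray, target: np.ndarray, eps: float = 1.0) -> float:
    """Soft Dice coefficient between predicted probabilities and a binary target."""
    intersection = np.sum(probs * target)
    return float((2.0 * intersection + eps) / (np.sum(probs) + np.sum(target) + eps))

def bce_with_logits(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean binary cross-entropy computed stably from raw logits."""
    # max(x, 0) - x*y + log(1 + exp(-|x|)) never exponentiates a large positive number.
    loss = np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits)))
    return float(loss.mean())

def bce_dice_loss(logits: np.ndarray, target: np.ndarray) -> float:
    """BCE on logits plus Dice loss (1 - Dice on sigmoid probabilities)."""
    probs = 0.5 * (1.0 + np.tanh(0.5 * logits))  # overflow-free sigmoid
    return bce_with_logits(logits, target) + 1.0 - dice_coefficient(probs, target)
```

The BCE term gives per-pixel gradients, while the Dice term directly rewards overlap and is less sensitive to the foreground/background imbalance common in segmentation masks.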
Define Optimizer and Scheduler
An Adam optimiser is used, with a learning rate scheduler that lowers the learning rate when the loss plateaus.
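The natural PyTorch counterpart of this setup is torch.optim.Adam combined with torch.optim.lr_scheduler.ReduceLROnPlateau, which multiplies the learning rate by a factor once the monitored loss has stopped improving for more than a set number of epochs. To make that rule concrete without depending on torch, the pure-Python sketch below implements a simplified version (no threshold, cooldown or minimum learning rate); the class name and defaults are assumptions.

```python
class PlateauLR:
    """Simplified reduce-on-plateau rule: scale the LR by `factor` once more
    than `patience` consecutive epochs pass without a new best (lowest) loss."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 3):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss: float) -> float:
        """Record one epoch's loss and return the learning rate to use next."""
        if loss < self.best:
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr
```

In the torch version, the equivalent would be ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2) followed by scheduler.step(val_loss) once per epoch.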
Load Trainer
Load the pre-defined trainer function with the parameters set previously.
Plots
Plot the graphs of Loss vs Epochs and Accuracy vs Epochs.
Prediction demo 1:
Prediction demo 2:
Accuracy plot:
Loss plot: