SimonZeng7108 / Coral_Analysis

Coral Analysis via various methods

This repo showcases a variety of traditional computer vision methods, unsupervised learning methods, and deep learning models for coral skeletal image segmentation.
It is an extension and PyTorch implementation of the Coral Density Estimation Project.
The dataset is downloaded from the original repo by Ainsley Rutterford.

Traditional Computer Vision methods:

Traditional CV/histogram_based.py: Multi-level histogram-thresholding-based segmentation

Traditional CV/equalised_otsu.py: Otsu's binary thresholding on a histogram-equalised image (see the sketch after this list)

Traditional CV/canny.py: Canny edge detection
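
A minimal sketch of the equalised-Otsu approach, assuming a greyscale coral image on disk; the file name and pre-processing steps are illustrative, not the repo's actual code.

```python
# Equalise the histogram, then binarise with a global Otsu threshold.
import matplotlib.pyplot as plt
from skimage import exposure, io
from skimage.filters import threshold_otsu

image = io.imread("coral_slice.png", as_gray=True)  # hypothetical file name
equalised = exposure.equalize_hist(image)           # histogram equalisation
thresh = threshold_otsu(equalised)                  # global Otsu threshold
binary = equalised > thresh                         # binary segmentation mask

plt.imshow(binary, cmap="gray")
plt.show()
```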

Unsupervised learning methods:

Unsupervised Learning/random_walker.py: Random Walker segmentation

Unsupervised Learning/K-means-clustering.py: K-means clustering (see the sketch after this list)

Unsupervised Learning/Gaussian_Mixture_Model.py: Gaussian Mixture Model
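
A minimal sketch of K-means segmentation on raw pixel intensities; the file name and number of clusters are assumptions, not taken from the repo.

```python
# Cluster pixel intensities into k groups and reshape labels into a mask.
from skimage import io
from sklearn.cluster import KMeans

image = io.imread("coral_slice.png", as_gray=True)  # hypothetical file name
pixels = image.reshape(-1, 1)                       # one feature per pixel

kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(pixels)
segmented = kmeans.labels_.reshape(image.shape)     # per-pixel cluster labels
```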

Deep Learning Models:

This implementation is mostly based on my Fetal_Segmentation repo.

Note: this UNet implementation is a fairly vanilla model; no BatchNorm or Dropout is used. If the original paper is followed strictly, there is a conflict between the input and output sizes (572 vs. 388). To avoid a mismatch between labels and predictions, this implementation applies a resize function after every up-convolution in the expansive path and at the final output layer, as sketched below.
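
A hedged sketch of that resize trick: after each up-convolution, the decoder features are interpolated to the skip connection's spatial size, so the output matches the label resolution. The function and tensor names are illustrative, not the repo's actual code.

```python
import torch
import torch.nn.functional as F

def up_and_concat(up, skip):
    # Resize the upsampled features to the skip connection's spatial size,
    # avoiding the 572 -> 388 shrink of the original unpadded U-Net.
    up = F.interpolate(up, size=skip.shape[2:], mode="bilinear", align_corners=False)
    return torch.cat([skip, up], dim=1)
```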

Repository overview

data/train set: Stores all the unsplit training data
data/test_set: Stores all the data used for prediction
[model]: Stores the best model weights generated after training
[runs]: Stores the log files for TensorBoard
Image Augmentation.py: A script that generates more images using various transformations
main.py: The main script that imports the dataset, trainer, and loss functions to run the model
dataset.py: Customises a dataset to process the training images
model.py: Constructs the SegNet and UNet models
train.py: The trainer that runs the epochs
loss_functions.py: Defines the Dice loss + BCE-with-logits loss function
predict.py: Script to run predictions on unlabelled images

Requirements

  • torch == 1.8.0
  • torchvision
  • torchsummary
  • numpy
  • scipy
  • skimage
  • matplotlib
  • PIL

Image Augmentation

Image augmentation is recommended to be performed in a local directory for best performance. Use Image Augmentation.py to generate the images and corresponding labels.
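
A minimal sketch of such offline augmentation, assuming PNG images under data/train set; the paths and transform choices are illustrative. For segmentation, the same spatial transform must also be applied to the matching label, which is omitted here for brevity.

```python
# Read each training image, apply random transforms, and save the result.
from pathlib import Path
from PIL import Image
from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
])

src = Path("data/train set")  # unsplit training data (see overview above)
for img_path in src.glob("*.png"):
    augmented = augment(Image.open(img_path))
    augmented.save(img_path.with_name(img_path.stem + "_aug.png"))
```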

main.py workflow

main.py is the only file that needs to be run; all other utilities are imported into it.

Set Parameters

Set the parameters for the paths to train/image, train/label, test/image, test/label, and save_model. Also set the h and w of the input images, which model to use, the number of epochs to run, the batch size, the learning rate, and the drop factor of the learning rate scheduler for the optimizer.
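
An illustrative parameter block; the variable names and values are assumptions and may differ from those in main.py.

```python
train_image_path = "data/train set/image"   # hypothetical paths
train_label_path = "data/train set/label"
test_image_path = "data/test_set/image"
test_label_path = "data/test_set/label"
save_model_path = "model/best_weights.pth"

h, w = 256, 256          # input image height and width
model_name = "unet"      # "unet" or "segnet"
num_epochs = 100
batch_size = 4
learning_rate = 1e-3
lr_drop_factor = 0.1     # drop factor for the LR scheduler
```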

Data Augmentation

Defines a series of data transformations that can be applied when the data is loaded.
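
For example, a torchvision-style transform pipeline (an assumption about the repo's exact transform list):

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((h, w)),           # h, w from the parameters above
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
```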

Dataset Loader

Calls the customised dataset.py so the data meets the PyTorch DataLoader standard.
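
A minimal custom Dataset in the spirit of dataset.py, assuming paired image/label PNGs sorted by name; the class name and directory layout are illustrative.

```python
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class CoralDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.images = sorted(Path(image_dir).glob("*.png"))
        self.labels = sorted(Path(label_dir).glob("*.png"))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("L")  # greyscale image
        label = Image.open(self.labels[idx]).convert("L")  # greyscale mask
        if self.transform:
            image = self.transform(image)
            label = self.transform(label)
        return image, label
```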

Load Model

Loads one of the pre-built models, selected via the parameters.

Load Loss Function

Imports the pre-defined loss_functions. The loss is the sum of BCE loss and Dice loss; the metric is the Dice coefficient.
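
A sketch of that combined loss: BCEWithLogitsLoss plus a soft Dice term, with the Dice coefficient doubling as the metric. The smoothing constant and reduction are assumptions, not the repo's exact code.

```python
import torch
import torch.nn as nn

def dice_coefficient(logits, targets, smooth=1e-6):
    probs = torch.sigmoid(logits)
    intersection = (probs * targets).sum()
    return (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        # Sum of BCE-with-logits loss and (1 - Dice coefficient).
        return self.bce(logits, targets) + (1.0 - dice_coefficient(logits, targets))
```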

Define Optimizer and Scheduler

An Adam optimiser is used, together with a learning rate scheduler that drops the learning rate when the loss plateaus.
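
A hedged setup matching that description, reusing the hypothetical parameters from above (model, learning_rate, lr_drop_factor):

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=lr_drop_factor, patience=10
)

# Inside the training loop, step the scheduler with the epoch's loss:
# scheduler.step(epoch_loss)
```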

Load Trainer

Loads the pre-defined trainer function with the parameters set previously.

Plots

Plots the graphs of Loss vs. Epochs and Accuracy vs. Epochs.
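
A minimal plotting sketch, assuming per-epoch values were collected into lists during training (the list names are hypothetical):

```python
import matplotlib.pyplot as plt

plt.plot(train_losses, label="train loss")            # hypothetical lists
plt.plot(validation_losses, label="validation loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()
```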

Results

Prediction demo 1:

Prediction demo 2:

Accuracy plot:

Loss plot:
