AI-Unipi / Image3DGenerator

A Python class compatible with TensorFlow to perform data augmentation on 3D objects during CNN training.



Introduction 📯

This repository contains a class to perform data augmentation on 3D objects (e.g. 3D medical images). It is a 3D (well, 4D if you count the channels 😅) version of the 2D "tf.keras.preprocessing.image.ImageDataGenerator". We also provide two examples as a simple guideline.

Data augmentation is a regularization technique that has proven extremely useful when training CNNs. It prevents the model from seeing the original training and validation data during training. Instead, it applies random transformations to the original data (or batches) and lets the model see the transformed versions. Data augmentation is a means to reduce overfitting and build a more robust model.

Team Members 👥

Inspiration 💡

While working on my thesis this summer, I (Stefanos) realized that TensorFlow resources for training models on 3D images are limited (almost no pretrained models, few regularization techniques, etc.). 3D data, especially medical data, can be hard to find, and datasets are often small. Hence, we believe that data augmentation can have a big impact on preventing overfitting.

This repository is a quarantine project (yes, we were bored 😌) created to help the few other crazies working on similar projects.

Usage 📋

The DataGenerator class (in Image3DGenerator.py), despite its name, is actually a "tf.keras.utils.Sequence" object, i.e. a base object for fitting to a sequence of data, such as a dataset. Sequences are a safer way to do multiprocessing, because this structure guarantees that the network trains only once on each sample per epoch, which is not the case with generators.

This class applies random transformations to the original training and validation data, and these transformations change at every epoch.
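The Sequence protocol described above can be sketched as follows. This is a minimal, hypothetical class (`RandomNoiseSequence`), not the repository's DataGenerator. To stay short, it keeps the volumes in memory, whereas the repository's class reads .npy files from `path`. It implements `__len__` (batches per epoch), `__getitem__` (one batch) and `on_epoch_end` (reshuffling). Each call to `__getitem__` draws fresh Gaussian noise, so the model sees a different version of every sample in each epoch. If TensorFlow is not installed, the sketch falls back to `object` as its base class so it still runs, but it then cannot be passed to `model.fit`.

```python
import numpy as np

try:
    # Use the real Keras base class when TensorFlow is installed.
    from tensorflow.keras.utils import Sequence
except ImportError:
    # Fallback so the sketch runs without TensorFlow (not usable with model.fit then).
    Sequence = object


class RandomNoiseSequence(Sequence):
    """Hypothetical minimal Sequence: batches of in-memory volumes with fresh Gaussian noise."""

    def __init__(self, volumes, labels, batch_size=2, noise_std=0.01, shuffle=True, seed=None):
        super().__init__()
        self.volumes = volumes    # shape (n, length, height, width, channels)
        self.labels = labels      # integer classes, shape (n,)
        self.batch_size = batch_size
        self.noise_std = noise_std
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indexes = np.arange(len(volumes))
        self.on_epoch_end()

    def __len__(self):
        # Batches per epoch: every sample lands in exactly one batch.
        return int(np.ceil(len(self.volumes) / self.batch_size))

    def __getitem__(self, idx):
        batch_idx = self.indexes[idx * self.batch_size:(idx + 1) * self.batch_size]
        x = self.volumes[batch_idx]
        # New noise on every call, so the augmented data differ from epoch to epoch.
        x = x + self.rng.normal(0.0, self.noise_std, size=x.shape)
        return x, self.labels[batch_idx]

    def on_epoch_end(self):
        # Keras calls this after each epoch; reshuffle the sample order.
        if self.shuffle:
            self.rng.shuffle(self.indexes)
```

For example, `RandomNoiseSequence(np.zeros((5, 4, 4, 4, 1)), np.arange(5), batch_size=2)` yields 3 batches per epoch, and the last batch holds a single sample.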

The options we currently provide are the following:

  • Generation of batches without any transformation
  • Rotation: randomly rotates the whole object by an angle drawn from a normal distribution with mean 0 and a standard deviation specified by the user (rotate_std).
  • Gaussian noise: adds random noise to the 3D objects, drawn from a normal distribution with a mean and standard deviation specified by the user (noise_mean, noise_std).
  • Normalization: applies a min-max scaling transformation that bounds the voxel values between 0 and 1 (using min_bound and max_bound).
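The three transformations listed above can be sketched with NumPy and SciPy as follows. This is an illustration under stated assumptions, not the repository's implementation. The assumptions are:

- rotation happens in the height-width plane (axes 1 and 2 of a `(length, height, width, channels)` array);
- `rotate_std` is in degrees, because `scipy.ndimage.rotate` takes degrees;
- normalisation clips values that fall outside `[min_bound, max_bound]`.

```python
import numpy as np
from scipy import ndimage


def random_rotation(volume, rotate_std, rng):
    """Rotate by an angle ~ N(0, rotate_std**2), assumed in degrees.

    Assumption: the rotation plane is height-width (axes 1 and 2); voxels rotated
    in from outside the volume are filled with 0.
    """
    angle = rng.normal(0.0, rotate_std)
    return ndimage.rotate(volume, angle, axes=(1, 2), reshape=False,
                          order=1, mode="constant", cval=0.0)


def gaussian_noise(volume, noise_mean, noise_std, rng):
    """Add voxel-wise noise drawn from N(noise_mean, noise_std**2)."""
    return volume + rng.normal(noise_mean, noise_std, size=volume.shape)


def min_max_normalise(volume, min_bound, max_bound):
    """Scale to [0, 1] with fixed bounds; values outside the bounds are clipped (assumption)."""
    scaled = (volume - min_bound) / (max_bound - min_bound)
    return np.clip(scaled, 0.0, 1.0)
```

The sketch uses fixed `min_bound`/`max_bound` values, mirroring the parameters in the usage example, rather than per-object minima and maxima. This maps the same raw voxel value to the same normalised value in every object.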

Data/File Formats 📁

In order to use this class your data and folders should be structured as follows:

----data-folder/data.npy
--Image3DGenerator.py
--your_python_script

Notes:
⚡ The data folder should contain each 3D object separately, each stored as a NumPy array (.npy).
⚡ Each 3D object should have the following dimension order: (object_length, object_height, object_width, number_of_channels). The channel axis can be omitted for grayscale objects.
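The file layout described in these notes can be prepared with a short script like the one below. The helper `save_volumes` and its `object-<i>` naming scheme are hypothetical. The only requirements taken from the notes are one .npy file per 3D object and the `(length, height, width[, channels])` dimension order. The demo writes to a temporary folder so it leaves nothing behind.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_volumes(volumes, folder, prefix="object"):
    """Save every 3D object to its own .npy file; return the IDs (file names without .npy).

    The `prefix`-based naming is hypothetical; any unique ID per object works.
    """
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    ids = []
    for i, volume in enumerate(volumes):
        # Expected order: (length, height, width[, channels]); channels may be omitted for grayscale.
        if volume.ndim not in (3, 4):
            raise ValueError(f"object {i} has shape {volume.shape}, expected 3 or 4 dimensions")
        obj_id = f"{prefix}-{i}"
        np.save(folder / f"{obj_id}.npy", volume)
        ids.append(obj_id)
    return ids


# Demo: two random volumes written to a temporary "data" folder.
with tempfile.TemporaryDirectory() as tmp:
    data_dir = Path(tmp) / "data"
    demo_ids = save_volumes(
        [np.random.rand(16, 32, 32).astype(np.float32),      # grayscale, channel axis omitted
         np.random.rand(16, 32, 32, 3).astype(np.float32)],  # RGB
        data_dir,
    )
    demo_files = sorted(p.name for p in data_dir.iterdir())
    demo_shape = np.load(data_dir / f"{demo_ids[1]}.npy").shape
```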

To use the class, follow these steps:

  • Create a dictionary containing the IDs of the training examples (and of the validation examples, if applicable).
  • Create a dictionary mapping every training (and validation) ID to its class. The classes should be integers starting from 0.
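The two dictionaries from the steps above can be built along these lines. The names `dictionary` and `labels` match the usage example. The helper `build_dictionaries`, its arguments and the class names are hypothetical. Two further assumptions: each ID is the .npy file name without its extension, and the split is simply random, not stratified by class. Class names are mapped to integers starting at 0 in sorted order.

```python
import random


def build_dictionaries(id_to_class_name, val_fraction=0.25, seed=0):
    """Build `dictionary` (train/validation ID lists) and `labels` (ID -> integer class).

    `id_to_class_name` maps each object ID (assumed: the .npy file name without the
    extension) to a human-readable class name. All names here are hypothetical.
    """
    # Integer classes starting at 0, assigned in sorted order of the class names.
    class_to_int = {name: i for i, name in enumerate(sorted(set(id_to_class_name.values())))}
    labels = {obj_id: class_to_int[name] for obj_id, name in id_to_class_name.items()}

    # Simple random (non-stratified) split of the IDs.
    ids = sorted(id_to_class_name)
    random.Random(seed).shuffle(ids)
    n_val = round(len(ids) * val_fraction)
    dictionary = {"train": ids[n_val:], "validation": ids[:n_val]}
    return dictionary, labels, class_to_int


dictionary, labels, class_to_int = build_dictionaries(
    {"scan-0": "healthy", "scan-1": "tumour", "scan-2": "healthy", "scan-3": "tumour"}
)
```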

Once all the prerequisites are ready, simply type the following:

from Image3DGenerator import DataGenerator

# Example values: replace them with the ones matching your data.
params = {
          'dim': (64, 64, 64),       # your objects' dimensions (length, height, width)
          'batch_size': 8,           # chosen batch size
          'n_classes': 2,            # number of your classes
          'n_channels': 1,           # 1 if grayscale, 3 if RGB
          'rotation': True,          # True to apply random rotation during training
          'normalisation': True,     # True to apply min-max scaling to [0, 1]
          'min_bound': 0.0,          # if normalisation is True: minimum voxel value of your objects
          'max_bound': 255.0,        # if normalisation is True: maximum voxel value of your objects
          'gaussian_noise': True,    # True to add Gaussian noise
          'noise_mean': 0,
          'noise_std': 0.01,
          'shuffle': True,
          'rotate_std': 45,          # standard deviation of the random rotation angle
          'path': './data',          # path of the folder containing the data
          'display_ID': False}

# Generators (dictionary and labels are the two dictionaries described above)
training_generator = DataGenerator(dictionary['train'], labels, **params)
validation_generator = DataGenerator(dictionary['validation'], labels, **params)

# After creating and compiling your tf model
no_epochs = 50  # example number of training epochs
model.fit(x=training_generator,
          epochs=no_epochs,
          validation_data=validation_generator)

Examples 👀

Two examples, with code and outputs, are available in the examples folder. Below are visual examples intended to help users understand how the class transforms the data during training.

Visual Examples

1) Grayscale CT scan
Transformations applied: Rotation, Noise, Normalisation

(Figures: original vs. transformed scan, shown as a GIF and as a 3D rendering)

2) RGB GIF
Transformations applied: Rotation, Noise

(Figure: original vs. transformed)

3) RGB GIF
Transformations applied: Rotation

(Figure: original vs. transformed)

4) RGB GIF
Transformations applied: Noise

(Figure: original vs. transformed)

Contribution 🤘

This is a quarantine project developed by two recent data science graduates, so there is undeniably room for improvement. Pull requests are more than welcome, and we would be glad to hear your feedback and have a chat. For major changes, please open an issue first to discuss what you would like to change.

License

MIT
