ZDstandup / self-supervised-3d-tasks

3D Self-Supervised Methods for Medical Imaging

Keras implementation of multiple self-supervised methods for 3D and 2D applications. This repository implements all the methods in this paper: 3D Self-Supervised Methods for Medical Imaging

If you find this repository useful, please consider citing our paper in your work:

@misc{taleb20203d,
    title={3D Self-Supervised Methods for Medical Imaging},
    author={Aiham Taleb and Winfried Loetzsch and Noel Danz and Julius Severin and Thomas Gaertner and Benjamin Bergner and Christoph Lippert},
    year={2020},
    eprint={2006.03829},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

Overview

This codebase contains implementations of five self-supervised representation learning techniques, along with utility code for running training and evaluation loops.

Usage instructions

In this codebase we provide configurations for training/evaluation of our models.

For debugging or running small experiments we support training and evaluation using a single GPU device.

Preparing data

Our implementations of the algorithms require square images for 2D data and cubic volumes for 3D data.
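
As a minimal sketch of one way to meet this requirement, the helper below computes center-crop slices that reduce an image shape to a square (2D) or a cube (3D) whose side equals the shortest axis. The function name `center_crop_slices` is hypothetical and not part of this repository. Cropping is only one option; padding or resampling may suit your data better.

```python
from typing import Sequence, Tuple


def center_crop_slices(shape: Sequence[int]) -> Tuple[slice, ...]:
    """Return one slice per spatial axis that center-crops to the shortest side.

    `shape` holds only the spatial axes, e.g. (H, W) or (D, H, W);
    a channel axis, if present, has to be handled by the caller.
    """
    if not shape or any(s <= 0 for s in shape):
        raise ValueError("shape must contain positive axis lengths")
    side = min(shape)
    slices = []
    for length in shape:
        start = (length - side) // 2  # drop roughly equal margins on both ends
        slices.append(slice(start, start + side))
    return tuple(slices)


# Example: slices that crop a (120, 160, 128) volume to 120 x 120 x 120.
crop = center_crop_slices((120, 160, 128))
print(crop)
```

With a NumPy array `volume` that has only spatial axes, the result can be applied as `volume[crop]`.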

Clone the repository and install dependencies

Make sure you have Anaconda installed.

Then run the following commands from your desired workspace directory:

git clone https://gitlab.com/statistical-genomics/self-supervised-3d-tasks.git
cd self-supervised-3d-tasks
conda env create -f env_all_platforms.yml
conda activate conda-env
pip install -e .

Running the experiments

To train any of the self-supervised tasks with a specific algorithm, run

python train.py configs/train/{algorithm}_{dimension}.json

To run the downstream task and initialize the weights from a pretrained checkpoint, run

python finetune.py configs/finetune/{algorithm}_{dimension}.json
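
The following sketch shows how the config path pattern above could be filled in programmatically. The helper `build_command` is hypothetical, and the dimension suffixes `"2d"` and `"3d"` are an assumption; check the `configs/` directory for the file names that actually exist.

```python
import sys
from pathlib import Path
from typing import List

# Algorithm names as listed in the config documentation of this README.
ALGORITHMS = {"cpc", "rotation", "rpl", "jigsaw", "exemplar"}


def build_command(script: str, algorithm: str, dimension: str) -> List[str]:
    """Build the argument list for train.py or finetune.py.

    Assumption: `dimension` is a suffix such as "2d" or "3d".
    """
    if script not in ("train", "finetune"):
        raise ValueError(f"unknown script: {script!r}")
    if algorithm not in ALGORITHMS:
        raise ValueError(f"unknown algorithm: {algorithm!r}")
    config = Path("configs") / script / f"{algorithm}_{dimension}.json"
    # sys.executable is the Python interpreter of the active environment.
    return [sys.executable, f"{script}.py", config.as_posix()]


cmd = build_command("train", "rotation", "3d")
print(" ".join(cmd[1:]))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root.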

Setting the configs

In the two example configs below, the respective parameters for training and testing configs are explained.

Training:

{
  "algorithm": "String. ('cpc'|'rotation'|'rpl'|'jigsaw'|'exemplar')",
  "batch_size": "Integer. Batch size.",
  "lr": "Float. Learning rate.",
  "epochs": "Integer. Number of epochs.",

  "encoder_architecture": "String. Name of the encoder architecture. ('DenseNet121'|'InceptionV3'|'ResNet50'|'ResNet50V2'|'ResNet101'|'ResNet101V2'|'ResNet152'|'InceptionResNetV2')",
  "top_architecture": "String. Name of the top level architecture. ('big_fully'|'simple_multiclass'|'unet_3d_upconv'|'unet_3d_upconv_patches')",

  "dataset_name": "String. Name of the dataset, only used for labeling the log data.",
  "data_is_3D": "Boolean. Whether the dataset is 3D.",
  "data_dir": "String. Path to the data directory.",
  "data_dim": "Integer. Dimension of image.",
  "number_channels": "Integer. The number of channels of the image.",

  "patch_jitter": "Integer. CPC, RPL, Jigsaw specific. Number of pixels of jitter applied to every patch.",
  "patches_per_side": "Integer. CPC, RPL specific. Number of patches per dimension. 2 patches per side result in 4 patches for a 2D and 8 patches for a 3D image.",
  "crop_size": "Integer. CPC specific. For CPC the whole image can be randomly cropped to a smaller size to make the self-supervised task harder.",
  "code_size": "Integer. CPC, Exemplar specific. Dimension of the latent space.",

  "train_data_generator_args": {
    "suffix":  "String. ('.png'|'.jpeg')",
    "multilabel": "Boolean. Whether to transform the data to a multilabel representation. (0 => [0, 0], 1 => [1, 0], 2 => [1, 1])",
    "augment": "Boolean. Include additional augmentations while loading the data. 2D augmentations: zooming, rotating. 3D augmentations: flipping, color distortion, rotation.",
    "augment_zoom_only": "Boolean. 2D specific augmentations without rotating the image.",
    "shuffle": "Boolean. Shuffle the data after each epoch."
  },
  "val_data_generator_args": {
    "suffix":  "String. ('.png'|'.jpeg')",
    "multilabel": "Boolean. Whether to transform the data to a multilabel representation. (0 => [0, 0], 1 => [1, 0], 2 => [1, 1])",
    "augment": "Boolean. Include additional augmentations while loading the data. 2D augmentations: zooming, rotating. 3D augmentations: flipping, color distortion, rotation.",
    "augment_zoom_only": "Boolean. 2D specific augmentations without rotating the image.",
    "shuffle": "Boolean. Shuffle the data after each epoch."
  },

  "save_checkpoint_every_n_epochs": "Integer. Save a checkpoint every n epochs, even without improvement.",
  "val_split": "Float between 0 and 1. Fraction of images held out for validation. None for no validation set.",
  "pooling": "String. (None|'avg'|'max')",
  "enc_filters": "Integer. Number of filters used for the encoder model."
}
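
The template above describes each parameter instead of giving values. As a rough sketch, a filled-in training config could be produced and sanity-checked as follows. All values, the dataset name, and the data path are illustrative placeholders, and `check_config` is a hypothetical helper rather than part of this repository. Whether this exact combination of settings runs depends on the repository's code.

```python
import json

# Algorithm names as listed in the config documentation above.
ALGORITHMS = {"cpc", "rotation", "rpl", "jigsaw", "exemplar"}

# Illustrative placeholder values, not recommended settings.
train_config = {
    "algorithm": "rotation",
    "batch_size": 8,
    "lr": 1e-3,
    "epochs": 100,
    "encoder_architecture": "DenseNet121",
    "top_architecture": "big_fully",
    "dataset_name": "my_dataset",  # hypothetical name
    "data_is_3D": False,
    "data_dir": "/path/to/data",  # hypothetical path
    "data_dim": 128,
    "number_channels": 3,
    "train_data_generator_args": {"suffix": ".png", "augment": True, "shuffle": True},
    "val_data_generator_args": {"suffix": ".png", "augment": False, "shuffle": False},
    "save_checkpoint_every_n_epochs": 5,
    "val_split": 0.1,
    "pooling": "avg",
}


def check_config(config: dict) -> None:
    """Run a few sanity checks derived from the parameter descriptions."""
    if config["algorithm"] not in ALGORITHMS:
        raise ValueError(f"unknown algorithm: {config['algorithm']!r}")
    split = config.get("val_split")
    if split is not None and not 0 <= split <= 1:
        raise ValueError("val_split must be None or lie between 0 and 1")


check_config(train_config)
config_text = json.dumps(train_config, indent=2)
print(config_text)
```

Writing `config_text` to a file under `configs/train/` would give a config that can be passed to `train.py`.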

Finetuning:

{
  "algorithm": "String. ('cpc'|'rotation'|'rpl'|'jigsaw'|'exemplar')",
  "lr": "Float. Learning rate.",
  "batch_size": "Integer. Batch size.",
  "val_split": "Float between 0 and 1. Fraction of images held out for validation. None for no validation set.",
  "epochs_warmup": "Integer. Number of warmup epochs with frozen weights.",
  "epochs": "Integer. Number of epochs.",
  "repetitions": "Integer. Repetitions of the experiment.",
  "exp_splits": "Array<Integer>. Percentages of training data that should be used for the experiments. ([100,10,1,50,25])",

  "encoder_architecture": "String. Name of the encoder architecture. ('DenseNet121'|'InceptionV3'|'ResNet50'|'ResNet50V2'|'ResNet101'|'ResNet101V2'|'ResNet152'|'InceptionResNetV2')",
  "top_architecture": "String. Name of the top level architecture. ('big_fully'|'simple_multiclass'|'unet_3d_upconv'|'unet_3d_upconv_patches')",
  "prediction_architecture": "String. ('big_fully'|'simple_multiclass'|'unet_3d_upconv')",
  "pooling": "String. (None|'avg'|'max')",

  "dataset_name": "String. Name of the dataset, only used for labeling the log data.",
  "data_is_3D": "Boolean. Whether the dataset is 3D.",
  "data_dim": "Integer. Dimension of image.",
  "number_channels": "Integer. The number of channels of the image.",
  "data_dir": "String. Path to the data directory the model was trained on.",
  "data_dir_train": "String. Path to the data directory containing the finetuning train data.",
  "data_dir_test": "String. Path to the data directory containing the finetuning test data.",
  "csv_file_train": "String. Path to the csv file containing the finetuning train data.",
  "csv_file_test": "String. Path to the csv file containing the finetuning test data.",
  "train_data_generator_args": {
    "suffix":  "String. ('.png'|'.jpeg')",
    "multilabel": "Boolean. Whether to transform the data to a multilabel representation. (0 => [0, 0], 1 => [1, 0], 2 => [1, 1])",
    "augment": "Boolean. Include additional augmentations while loading the data. 2D augmentations: zooming, rotating. 3D augmentations: flipping, color distortion, rotation.",
    "augment_zoom_only": "Boolean. 2D specific augmentations without rotating the image.",
    "shuffle": "Boolean. Shuffle the data after each epoch."
  },
  "val_data_generator_args": {
    "suffix":  "String. ('.png'|'.jpeg')",
    "multilabel": "Boolean. Whether to transform the data to a multilabel representation. (0 => [0, 0], 1 => [1, 0], 2 => [1, 1])",
    "augment": "Boolean. Include additional augmentations while loading the data. 2D augmentations: zooming, rotating. 3D augmentations: flipping, color distortion, rotation.",
    "augment_zoom_only": "Boolean. 2D specific augmentations without rotating the image.",
    "shuffle": "Boolean. Shuffle the data after each epoch."
  },
  "test_data_generator_args": {
    "suffix":  "String. ('.png'|'.jpeg')",
    "multilabel": "Boolean. Whether to transform the data to a multilabel representation. (0 => [0, 0], 1 => [1, 0], 2 => [1, 1])",
    "augment": "Boolean. Include additional augmentations while loading the data. 2D augmentations: zooming, rotating. 3D augmentations: flipping, color distortion, rotation.",
    "augment_zoom_only": "Boolean. 2D specific augmentations without rotating the image.",
    "shuffle": "Boolean. Shuffle the data after each epoch."
  },

  "metrics": "Array<String>. Metrics to be used. ('accuracy'|'mse')",
  "loss": "String. Loss to be used. ('binary_crossentropy'|'weighted_dice_loss'|'weighted_sum_loss'|'weighted_categorical_crossentropy'|'jaccard_distance')",
  "scores": "Array<String>. Scores to be used. ('qw_kappa'|'qw_kappa_kaggle'|'cat_accuracy'|'cat_acc_kaggle'|'dice'|'jaccard')",
  "clipnorm": "Float. Gradients will be clipped when their L2 norm exceeds this value.",
  "clipvalue": "Float. Gradients will be clipped when their absolute value exceeds this value.",

  "embed_dim": "Integer. Size of the embedding vector of the model.",

  "load_weights": "Boolean. Whether to load weights from the model checkpoint.",
  "model_checkpoint": "String. Path to model checkpoint.",

  "patches_per_side": "Integer. CPC, RPL specific. Number of patches per dimension. 2 patches per side result in 4 patches for a 2D and 8 patches for a 3D image.",
  "enc_filters": "Integer. Number of filters used for the encoder model."
}
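
Two of the parameters above, `multilabel` and `exp_splits`, can be illustrated with a short sketch. The function names are hypothetical, and the rounding rule in `subset_sizes` is an assumption; the repository may encode labels and select training subsets differently.

```python
from typing import Dict, List


def to_multilabel(label: int, num_classes: int) -> List[int]:
    """Encode an ordinal class label as num_classes - 1 cumulative bits.

    Reproduces the mapping from the config description:
    0 => [0, 0], 1 => [1, 0], 2 => [1, 1] for three classes.
    """
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    return [1 if i < label else 0 for i in range(num_classes - 1)]


def subset_sizes(num_samples: int, exp_splits: List[int]) -> Dict[int, int]:
    """Map each percentage in exp_splits to a training sample count.

    Assumption: round down, but keep at least one sample per split.
    """
    return {pct: max(1, num_samples * pct // 100) for pct in exp_splits}


for label in range(3):
    print(label, "=>", to_multilabel(label, num_classes=3))

print(subset_sizes(1000, [100, 10, 1, 50, 25]))
```

Under these assumptions, each entry of `exp_splits` yields one experiment per repetition, so five splits with `repetitions` set to 3 would give 15 finetuning runs.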

About

License:Apache License 2.0

