Code for Pseudo label based contrastive learning joint training approach

pseudo_label_contrastive_training

Local contrastive loss with pseudo-label based self-training for semi-supervised medical image segmentation

This code accompanies the article "Local contrastive loss with pseudo-label based self-training for semi-supervised medical image segmentation" (under review). With the proposed joint-training method using a contrastive loss, we obtain segmentation performance competitive with the upper bound and the compared methods using just 2 labeled training volumes.
https://arxiv.org/abs/2112.09645

Authors:
Krishna Chaitanya,
Ertunc Erdil,
Neerav Karani,
Ender Konukoglu.

Requirements:
Python 3.6.1,
Tensorflow 1.12.0,
The remaining requirements are listed in the "requirements.txt" file.

I) Clone the git repository.
git clone https://github.com/krishnabits001/pseudo_label_contrastive_training.git

II) Install Python, the required packages and TensorFlow.
Install the required Python packages with the command below, or install the packages listed in the file manually.
pip install -r requirements.txt

To install TensorFlow:
pip install tensorflow-gpu==1.12.0

III) Dataset download.
To download the ACDC cardiac dataset, see the website:
https://www.creatis.insa-lyon.fr/Challenge/acdc

To download the Medical Segmentation Decathlon prostate dataset, see the website:
http://medicaldecathlon.com/

To download the MMWHS cardiac dataset, see the website:
http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/

All images were bias-corrected using the N4 algorithm with a threshold value of 0.001. For more details, refer to the "N4_bias_correction.py" file in the scripts directory.
Image and label pairs are resampled to a chosen target resolution and then cropped or zero-padded to a fixed size using the "create_cropped_imgs.py" file.
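The resampling and crop/zero-pad step can be illustrated with the NumPy sketch below. It is an assumption for illustration, not the code in create_cropped_imgs.py: it handles a single 2D slice, uses nearest-neighbour interpolation (appropriate for label maps; images are usually resampled with a smoother interpolator), and the function names resample_nearest and crop_or_pad are hypothetical.

```python
import numpy as np


def resample_nearest(slice_2d, pixel_size, target_resolution):
    """Nearest-neighbour resampling of a 2D slice (hypothetical helper).

    pixel_size and target_resolution are (row, col) spacings in the same unit (e.g. mm).
    """
    scale = np.asarray(pixel_size, dtype=float) / np.asarray(target_resolution, dtype=float)
    new_shape = np.maximum(np.round(np.array(slice_2d.shape) * scale), 1).astype(int)
    # For every output pixel centre, pick the input pixel that contains it.
    idx = [
        np.minimum(np.floor((np.arange(n) + 0.5) / s).astype(int), size - 1)
        for n, s, size in zip(new_shape, scale, slice_2d.shape)
    ]
    return slice_2d[np.ix_(idx[0], idx[1])]


def crop_or_pad(slice_2d, size_x, size_y):
    """Centre-crop and/or zero-pad a 2D slice to exactly (size_x, size_y) (hypothetical helper)."""
    out = np.zeros((size_x, size_y), dtype=slice_2d.dtype)
    x, y = slice_2d.shape
    # Source offsets (used when cropping) and destination offsets (used when padding).
    sx, dx = max((x - size_x) // 2, 0), max((size_x - x) // 2, 0)
    sy, dy = max((y - size_y) // 2, 0), max((size_y - y) // 2, 0)
    cx, cy = min(x, size_x), min(y, size_y)
    out[dx:dx + cx, dy:dy + cy] = slice_2d[sx:sx + cx, sy:sy + cy]
    return out
```

For example, a 100x80 slice with 1 mm pixels resampled to 0.5 mm becomes 200x160, and crop_or_pad(..., 224, 224) then zero-pads it symmetrically to 224x224.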

IV) Train the model.
To run the joint training, use the script "pseudo_lbl_rand_init.sh" in the train_model directory.
For instance, to train on the ACDC dataset with 2 labeled training volumes and configuration c1, run:
bash pseudo_lbl_rand_init.sh tr2 c1 acdc

The above command executes the following two training steps:
1) Step 1: Train a baseline network to infer the initial pseudo-labels for the unlabeled data. This training is done only once, at the start.
cd train_model/
python tr_baseline.py --no_of_tr_imgs=tr2 --comb_tr_imgs=c1 --dataset=acdc
2) Step 2: After Step 1, we infer pseudo-labels for the unlabeled data and perform joint training with the contrastive loss and the segmentation loss. This training is iterative: the pseudo-labels are refined periodically.
python prop_method_joint_tr_rand_init.py --no_of_tr_imgs=tr2 --comb_tr_imgs=c1 --dataset=acdc
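The joint objective of Step 2 can be sketched as follows. This is a minimal NumPy illustration under our own assumptions, not the repository's TensorFlow implementation: pseudo-labels are taken as the argmax of the network's softmax output, the contrastive term is a generic pixel-level supervised contrastive loss (SupCon-style) in which pixels sharing a (pseudo-)label are positives, and the weight lam and the temperature are hypothetical values. The paper's local loss (for example, how pixels are sampled) may differ.

```python
import numpy as np


def pseudo_labels_from_probs(probs):
    """Hard pseudo-labels: argmax over the last (class) axis of softmax outputs."""
    return np.argmax(probs, axis=-1)


def pixel_contrastive_loss(features, labels, temperature=0.1):
    """Generic pixel-level supervised contrastive loss (SupCon-style sketch).

    features: (N, D) embeddings of N sampled pixels.
    labels:   (N,) class ids, from ground truth or from pseudo-labels.
    Pixels sharing the anchor's label are positives; all others are negatives.
    """
    labels = np.asarray(labels)
    z = features / np.linalg.norm(features, axis=1, keepdims=True)
    logits = z @ z.T / temperature
    # Shift each row by its max for numerical stability (cancels in log-softmax).
    logits = logits - logits.max(axis=1, keepdims=True)
    not_self = ~np.eye(len(labels), dtype=bool)
    # Log-probability of each pair, normalised over all pixels except the anchor.
    denom = (np.exp(logits) * not_self).sum(axis=1, keepdims=True)
    log_prob = logits - np.log(denom)
    positives = (labels[:, None] == labels[None, :]) & not_self
    n_pos = positives.sum(axis=1)
    valid = n_pos > 0  # anchors without any positive are skipped
    if not valid.any():
        return 0.0
    mean_log_prob_pos = (log_prob * positives).sum(axis=1)[valid] / n_pos[valid]
    return float(-mean_log_prob_pos.mean())


def joint_loss(seg_loss, contrastive_loss, lam=0.1):
    """Weighted sum of segmentation and contrastive terms (lam is hypothetical)."""
    return seg_loss + lam * contrastive_loss
```

In each self-training round, the pseudo-labels for unlabeled pixels would be recomputed with the current network before the next block of joint training, which is how the periodic refinement described in Step 2 could be realised.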

V) Config files.
Modify the contents of the two config files below, located in the experiment_init directory, to run the required experiments.
Example for the ACDC dataset:

  1. init_acdc.py
    --> contains config details such as the target resolution, the image dimensions, the path where the dataset is stored and the path where trained models are saved.
  2. data_cfg_acdc.py
    --> contains an example data config where one can set the patient ids to use as training, validation and test images.
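To illustrate how the --no_of_tr_imgs and --comb_tr_imgs flags could select a labeled split, here is a hypothetical sketch of the kind of settings these two files hold. Every name, path and value in it (target_resolution, image_size, TRAIN_SPLITS, train_ids and the placeholder patient ids) is an assumption for illustration, not the repository's actual content; consult the files themselves for the real settings.

```python
# Hypothetical sketch only: names, paths and values are placeholders, not the
# actual contents of experiment_init/init_acdc.py or data_cfg_acdc.py.

# init_acdc.py-style settings (placeholders)
target_resolution = (1.0, 1.0)   # pixel size after resampling, in mm (placeholder value)
image_size = (192, 192)          # fixed size after crop/zero-pad (placeholder value)
data_path = "/path/to/acdc_preprocessed/"
save_path = "/path/to/trained_models/"

# data_cfg_acdc.py-style split, keyed by (--no_of_tr_imgs, --comb_tr_imgs).
# The patient ids are placeholders, not the split used in the paper.
TRAIN_SPLITS = {
    ("tr2", "c1"): ["patient_A", "patient_B"],
    ("tr2", "c2"): ["patient_C", "patient_D"],
}
VAL_IDS = ["patient_V1", "patient_V2"]
TEST_IDS = ["patient_T1", "patient_T2", "patient_T3"]


def train_ids(no_of_tr_imgs, comb_tr_imgs):
    """Return the labeled training ids for flags such as tr2 / c1."""
    key = (no_of_tr_imgs, comb_tr_imgs)
    if key not in TRAIN_SPLITS:
        raise KeyError("no split defined for {}".format(key))
    ids = TRAIN_SPLITS[key]
    # 'tr2' is read as "2 labeled volumes"; check that the split agrees.
    if len(ids) != int(no_of_tr_imgs[2:]):
        raise ValueError("split size does not match {}".format(no_of_tr_imgs))
    # Training ids must not leak into validation or test.
    if set(ids) & (set(VAL_IDS) | set(TEST_IDS)):
        raise ValueError("training ids overlap validation/test ids")
    return ids
```

Keeping each labeled combination (c1, c2, ...) as a separate entry makes it easy to repeat an experiment over several random labeled subsets of the same size.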

BibTeX citation:

@article{chaitanya2021local,
  title={Local contrastive loss with pseudo-label based self-training for semi-supervised medical image segmentation},
  author={Chaitanya, Krishna and Erdil, Ertunc and Karani, Neerav and Konukoglu, Ender},
  journal={arXiv preprint arXiv:2112.09645},
  year={2021}
}

License: GNU General Public License v3.0