pytorch-3dunet

PyTorch implementation of the 3D U-Net and its variants.

The code allows for training the U-Net on both semantic segmentation (binary and multi-class) and regression problems (e.g. de-noising, learning deconvolutions).

2D U-Net

Training the standard 2D U-Net is also possible; see train_config_2d for an example configuration. Just make sure to keep the singleton z-dimension in your H5 dataset (i.e. (1, Y, X) instead of (Y, X)), because data loading / data augmentation always requires tensors of rank 3.
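
For example, a minimal sketch (using h5py; file and dataset names are just placeholders) of adding the singleton z-dimension to an existing 2D dataset:

import h5py
import numpy as np

with h5py.File('my_2d_data.h5', 'r') as f:
    raw = f['raw'][...]  # shape (Y, X)

with h5py.File('my_2d_data_with_z.h5', 'w') as f:
    # add the singleton z-dimension expected by the data loader; repeat for the label dataset
    f.create_dataset('raw', data=np.expand_dims(raw, axis=0))  # shape (1, Y, X)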

Prerequisites

  • Linux
  • NVIDIA GPU
  • CUDA + cuDNN

Running on Windows

The package has not been tested on Windows, but some users have reported running it there. One thing to keep in mind when training with CrossEntropyLoss: the label type in the config file should be changed from long to int64, otherwise you will get the error: RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'.
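
A minimal illustration of the underlying dtype requirement (synthetic tensors, purely for demonstration):

import torch
import torch.nn as nn

logits = torch.randn(2, 3, 8, 8, 8)                            # (N, C, D, H, W) network output
target = torch.randint(0, 3, (2, 8, 8, 8), dtype=torch.int32)  # an Int target triggers the error above
loss = nn.CrossEntropyLoss()(logits, target.long())            # casting to int64 (Long) avoids it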

Supported Loss Functions

Semantic Segmentation

  • BCEWithLogitsLoss (binary cross-entropy)
  • DiceLoss (standard Dice loss defined as 1 - DiceCoefficient, used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the Dice loss per channel and averages the values; a minimal per-channel sketch is shown after the reference below).
  • BCEDiceLoss (Linear combination of BCE and Dice losses, i.e. alpha * BCE + beta * Dice, alpha, beta can be specified in the loss section of the config)
  • CrossEntropyLoss (one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • PixelWiseCrossEntropyLoss (one can specify not only class weights but also per pixel weights in order to give more gradient to important (or under-represented) regions in the ground truth)
  • WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the paper below for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • GeneralizedDiceLoss (see 'Generalized Dice Loss (GDL)' in the paper below for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config). Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise use the standard DiceLoss.

For a detailed explanation of some of the supported loss functions see: 'Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations' by Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin and M. Jorge Cardoso.
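
A minimal sketch of the per-channel Dice computation described above (illustrative only, not the package's exact implementation):

import torch

def dice_per_channel(pred, target, epsilon=1e-6):
    # pred, target: (C, D, H, W) tensors with values in [0, 1]
    pred = pred.flatten(start_dim=1)
    target = target.flatten(start_dim=1)
    intersection = (pred * target).sum(dim=1)
    denominator = pred.sum(dim=1) + target.sum(dim=1)
    return (2 * intersection + epsilon) / (denominator + epsilon)  # one score per channel

def dice_loss(pred, target):
    # DiceLoss = 1 - DiceCoefficient, averaged over the channels
    return 1.0 - dice_per_channel(pred, target).mean()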

Regression

  • MSELoss
  • L1Loss
  • SmoothL1Loss
  • WeightedSmoothL1Loss - an extension of SmoothL1Loss which allows weighting the voxel values above (or below) a given threshold differently (see the sketch below)
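
A rough sketch of the idea behind WeightedSmoothL1Loss (threshold and weight are illustrative parameters, not the package's exact signature):

import torch
import torch.nn.functional as F

def weighted_smooth_l1(pred, target, threshold=0.0, weight=2.0):
    # per-voxel SmoothL1 loss, with voxels whose target lies below the threshold up-weighted
    loss = F.smooth_l1_loss(pred, target, reduction='none')
    weights = torch.where(target < threshold, weight * torch.ones_like(loss), torch.ones_like(loss))
    return (weights * loss).mean()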

Supported Evaluation Metrics

Semantic Segmentation

  • MeanIoU - Mean intersection over union
  • DiceCoefficient - Dice Coefficient (computes the per channel Dice Coefficient and returns the average)

If a 3D U-Net was trained to predict cell boundaries, one can use the following semantic instance segmentation metrics (they are computed by running connected components on the thresholded boundary map and comparing the resulting instances to the ground truth instance segmentation; see the sketch below):

  • BoundaryAveragePrecision - Average Precision applied to the boundary probability maps: thresholds the boundary maps given by the network, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth
  • AdaptedRandError - Adapted Rand Error (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)

If not specified, MeanIoU will be used by default.
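
A minimal sketch of the thresholding and connected-components step the boundary-based metrics above rely on (using scikit-image; the threshold value is illustrative):

import numpy as np
from skimage.measure import label

def boundary_map_to_instances(boundary_probs, threshold=0.5):
    # voxels below the boundary threshold form the object interiors
    interiors = boundary_probs < threshold
    # connected components over the interiors yield an instance segmentation
    return label(interiors)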

Regression

  • PSNR - peak signal to noise ratio
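
PSNR follows the usual definition; a minimal sketch (assuming the intensity range of the target is known):

import numpy as np

def psnr(prediction, target, max_value=1.0):
    mse = np.mean((prediction - target) ** 2)
    return 10 * np.log10(max_value ** 2 / mse)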

Getting Started

Installation

  • The easiest way to install pytorch-3dunet package is via conda:
conda create -n 3dunet -c conda-forge -c awolny python=3.7 pytorch-3dunet
conda activate 3dunet

After installation the following commands are accessible within the conda environment: train3dunet for training the network and predict3dunet for prediction (see below).

  • One can also install directly from source:
python setup.py install

Installation tips

Make sure that the installed pytorch is compatible with your CUDA version, otherwise training/prediction will fail to run on the GPU. You can re-install a pytorch build compatible with your CUDA version in the 3dunet env with:

conda install -c pytorch torchvision cudatoolkit=<YOUR_CUDA_VERSION> pytorch

Train

Given that the pytorch-3dunet package was installed via conda as described above, one can train the network by simply invoking:

train3dunet --config <CONFIG>

where CONFIG is the path to a YAML configuration file, which specifies all aspects of the training procedure.

See e.g. train_config_ce.yaml which describes how to train a standard 3D U-Net on a randomly generated 3D volume and random segmentation mask (random_label3D.h5) with cross-entropy loss (just a demo).

In order to train on your own data, just provide the paths to your HDF5 training and validation datasets in train_config_ce.yaml. The HDF5 files should contain the raw/label data sets in the following axis order: DHW (in the case of 3D) or CDHW (in the case of 4D).
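
For example, a minimal sketch of writing a compliant HDF5 file with h5py (the raw/label dataset names follow the sample configs; the shapes are illustrative):

import h5py
import numpy as np

raw = np.random.rand(64, 128, 128).astype('float32')                  # DHW
label = np.random.randint(0, 2, size=(64, 128, 128)).astype('int64')  # DHW, one integer id per voxel

with h5py.File('my_training_data.h5', 'w') as f:
    f.create_dataset('raw', data=raw, compression='gzip')
    f.create_dataset('label', data=label, compression='gzip')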

One can monitor the training progress with Tensorboard: tensorboard --logdir <checkpoint_dir>/logs/ (you need tensorflow installed in your conda env), where <checkpoint_dir> is the path to the checkpoint directory specified in the config.

To try out training on randomly generated data right away, just check out the repository and run:

cd pytorch3dunet
train3dunet --config ../resources/train_config_ce.yaml # train with CrossEntropyLoss (segmentation)
#train3dunet --config ../resources/train_config_dice.yaml # train with DiceLoss (segmentation)
#train3dunet --config ../resources/train_config_regression.yaml # train with SmoothL1Loss (regression)

To try out a boundary prediction task given a sample 3D confocal volume of plant cells (cell membrane marker), run:

cd pytorch3dunet
train3dunet --config ../resources/train_boundary.yaml

Training tips

When training with binary-based losses, i.e.: BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss:

  1. the label data has to be 4D (one target binary mask per channel). If you have 3D binary data (foreground/background), just change the ToTensor transform for the label to contain expand_dims: true, see e.g. train_config_dice.yaml.
  2. final_sigmoid=True has to be present in the model section of the config, since every output channel gives the probability of the foreground. When training with cross-entropy based losses (WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss), set final_sigmoid=False so that Softmax normalization is applied to the output instead (see the short illustration below).
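
A minimal illustration of the difference (synthetic logits):

import torch

logits = torch.randn(1, 3, 8, 8, 8)           # (N, C, D, H, W) raw network output
probs_sigmoid = torch.sigmoid(logits)         # final_sigmoid=True: each channel is an independent foreground probability
probs_softmax = torch.softmax(logits, dim=1)  # final_sigmoid=False: channels sum to 1 per voxel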

Prediction

Given that the pytorch-3dunet package was installed via conda as described above, one can run the prediction via:

predict3dunet --config <CONFIG>

To run the prediction on a randomly generated 3D volume (just for demonstration purposes) from random_label3D.h5 and a network trained with cross-entropy loss:

cd pytorch3dunet
predict3dunet --config ../resources/test_config_ce.yaml

or if trained with DiceLoss:

cd pytorch3dunet
predict3dunet --config ../resources/test_config_dice.yaml

The predicted volume will be saved to resources/random_label3D_probabilities.h5.

In order to predict on your own data, just provide the path to your model as well as paths to the HDF5 test files (see test_config_ce.yaml).

Prediction tips

In order to avoid checkerboard artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the patch/stride params lead to overlapping blocks, e.g. patch: [64 128 128] stride: [32 96 96] will give you a 'halo' of 32 voxels in each direction.
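
The overlap ('halo') per axis is simply patch minus stride; for the example above:

patch = [64, 128, 128]
stride = [32, 96, 96]
halo = [p - s for p, s in zip(patch, stride)]  # [32, 32, 32] voxels of overlap per axis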

Data Parallelism

By default, if multiple GPUs are available, training/prediction will run on all of them using DataParallel. If training/prediction on all available GPUs is not desirable, restrict the number of GPUs using CUDA_VISIBLE_DEVICES, e.g.

CUDA_VISIBLE_DEVICES=0,1 train3dunet --config <CONFIG>

or

CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>

Sample configuration files

Sample YAML configuration files for semantic segmentation, regression and 2D semantic segmentation can be found in the resources directory (see e.g. the train/test configs referenced above).

Contribute

If you want to contribute back, please make a pull request.

Cite

If you use this code for your research, please cite as:

@article {Wolny2020.01.17.910562,
	author = {Wolny, Adrian and Cerrone, Lorenzo and Vijayan, Athul and Tofanelli, Rachele and Barro,
              Amaya Vilches and Louveaux, Marion and Wenzl, Christian and Steigleder, Susanne and Pape, 
              Constantin and Bailoni, Alberto and Duran-Nebreda, Salva and Bassel, George and Lohmann,
              Jan U. and Hamprecht, Fred A. and Schneitz, Kay and Maizel, Alexis and Kreshuk, Anna},
	title = {Accurate And Versatile 3D Segmentation Of Plant Tissues At Cellular Resolution},
	elocation-id = {2020.01.17.910562},
	year = {2020},
	doi = {10.1101/2020.01.17.910562},
	publisher = {Cold Spring Harbor Laboratory},
	URL = {https://www.biorxiv.org/content/early/2020/01/18/2020.01.17.910562}, 
	eprint = {https://www.biorxiv.org/content/early/2020/01/18/2020.01.17.910562.full.pdf},
	journal = {bioRxiv}
}

Additional notes (written by 손준원)

When you add a new model, predictor, dataset, etc., registering the new python file in the search scope of the corresponding get function lets it be used in the same way as the existing ones. For example, see the modules list in the get_model function in model.py:

import importlib

def get_model(config):
    def _model_class(class_name):
        # modules searched for the model class; add your own module here to make a new model available
        modules = ['pytorch3dunet.unet3d.model', 'pytorch3dunet.unet3d.rev_model', 'pytorch3dunet.unet3d.unetr_model', 'pytorch3dunet.unet3d.revunetr_model']
        for module in modules:
            m = importlib.import_module(module)
            clazz = getattr(m, class_name, None)
            if clazz:
                return clazz
        raise RuntimeError(f'Unsupported model class: {class_name}')

    assert 'model' in config, 'Could not find model configuration'
    model_config = config['model']
    model_class = _model_class(model_config['name'])
    return model_class(**model_config)
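
For example, a hypothetical model class MyUNet (illustrative name, not part of the package) defined in one of the modules listed above can then be selected by name from the config. Note that the whole model section is passed to the constructor as keyword arguments, so the constructor must accept (or ignore via **kwargs) every key in it, including name:

config = {
    'model': {
        'name': 'MyUNet',     # resolved by _model_class via the modules list
        'in_channels': 1,
        'out_channels': 2,
        'final_sigmoid': True,
    }
}
model = get_model(config)     # effectively MyUNet(**config['model'])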

Unlike wolny's original code, the modified code forwards the whole image at once by default instead of patch by patch; for this, use DicomDataset and StandardPredictor. If you want to forward patch by patch instead, pre-process the training data with dicom2npz.py and train with NpzDataset, and at prediction time use PatchwisePredictor with PatchwiseDcmDataset. The bone and nasal cavity models currently forward only a cropped sub-volume: use NpzDataset for training and ABDataset with ABPredictor for prediction. When run, ABPredictor prompts for x, y and z ranges and forwards only the corresponding crop.

For NpzDataset, each file must store a dictionary with the CT volume under the 'ct' key and the mask array under the 'mask' key.
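
A minimal sketch of creating such a file (assuming .npz files written with np.savez; the shapes are illustrative):

import numpy as np

ct = np.random.rand(64, 256, 256).astype('float32')                  # CT volume
mask = np.random.randint(0, 2, size=(64, 256, 256)).astype('uint8')  # segmentation mask
np.savez('case_0001.npz', ct=ct, mask=mask)                          # keys 'ct' and 'mask' as expected by NpzDataset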

About

License: MIT