Official PyTorch implementation of Semi-Supervised 3D Tooth Segmentation Using nn-UNet with Axial Attention and Positional Correction (MICCAI 2023 STS).
Install nnU-Net as below. Our method has no requirements beyond those of nnU-Net itself. For more details, please refer to https://github.com/MIC-DKFZ/nnUNet/tree/v1.7.1
git clone --branch v1.7.1 https://github.com/MIC-DKFZ/nnUNet.git
cd nnUNet
pip install -e .
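nnU-Net v1 locates its data through three environment variables: nnUNet_raw_data_base, nnUNet_preprocessed and RESULTS_FOLDER. Before running the commands below, a quick check like the following minimal sketch can catch an unset or mistyped path. The helper name `missing_nnunet_paths` is our own illustration, not part of nnU-Net.

```python
import os
from pathlib import Path

# nnU-Net v1 reads these three environment variables to locate its data.
REQUIRED_VARS = ("nnUNet_raw_data_base", "nnUNet_preprocessed", "RESULTS_FOLDER")


def missing_nnunet_paths(env=None):
    """Return the required variables that are unset or do not point to an existing directory."""
    env = os.environ if env is None else env
    missing = []
    for name in REQUIRED_VARS:
        value = env.get(name)
        if not value or not Path(value).is_dir():
            missing.append(name)
    return missing


if __name__ == "__main__":
    problems = missing_nnunet_paths()
    if problems:
        print("Set these before running nnU-Net:", ", ".join(problems))
    else:
        print("nnU-Net paths look fine.")
```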
Following nnU-Net, assign a TaskID (e.g. Task001) to the labeled data and organize it according to nnU-Net's dataset format.
nnUNet_raw_data_base/nnUNet_raw_data/Task001_Tooth/
├── dataset.json
├── imagesTr
├── imagesTs
└── labelsTr
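The dataset.json file can be generated by scanning the three folders. This is a minimal sketch assuming nnU-Net v1 conventions: single-modality images named CASE_0000.nii.gz, labels named CASE.nii.gz, and image paths listed in dataset.json without the _0000 suffix. The binary label map (background/tooth), the modality name "CT" and the function name `write_dataset_json` are hypothetical choices for illustration; nnU-Net uses the modality name to choose its intensity normalization, so set it deliberately.

```python
import json
from pathlib import Path


def _case_ids(folder, strip_modality):
    """List case IDs in a folder of .nii.gz files, optionally dropping the _0000 suffix."""
    if not folder.is_dir():
        return []
    ids = []
    for f in sorted(folder.glob("*.nii.gz")):
        case = f.name[: -len(".nii.gz")]
        if strip_modality:
            case = case.rsplit("_", 1)[0]  # drop the modality suffix, e.g. _0000
        ids.append(case)
    return ids


def write_dataset_json(task_dir, labels=None, modality="CT", name="Tooth"):
    """Write an nnU-Net v1 style dataset.json for task_dir and return it as a dict."""
    task_dir = Path(task_dir)
    train_ids = _case_ids(task_dir / "labelsTr", strip_modality=False)
    test_ids = _case_ids(task_dir / "imagesTs", strip_modality=True)
    dataset = {
        "name": name,
        "description": "Tooth segmentation",
        "tensorImageSize": "3D",
        "reference": "",
        "licence": "",
        "release": "0.0",
        "modality": {"0": modality},  # assumed single modality
        "labels": labels or {"0": "background", "1": "tooth"},  # assumed binary labels
        "numTraining": len(train_ids),
        "numTest": len(test_ids),
        # Image paths are listed WITHOUT the _0000 modality suffix
        "training": [
            {"image": f"./imagesTr/{c}.nii.gz", "label": f"./labelsTr/{c}.nii.gz"}
            for c in train_ids
        ],
        "test": [f"./imagesTs/{c}.nii.gz" for c in test_ids],
    }
    with open(task_dir / "dataset.json", "w") as fh:
        json.dump(dataset, fh, indent=4)
    return dataset
```

Running `write_dataset_json("$nnUNet_raw_data_base/nnUNet_raw_data/Task001_Tooth")` (with the path expanded) before preprocessing lets `--verify_dataset_integrity` check the result.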
We use the default nnU-Net settings.
nnUNet_plan_and_preprocess -t 1 --verify_dataset_integrity
for FOLD in 0 1 2 3 4
do
CUDA_VISIBLE_DEVICES=0,1 nnUNet_train_DP 3d_fullres nnUNetTrainerV2_DP 1 $FOLD -gpus 2 -c --npz
done
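Because the five folds run one after another, it can help to confirm that each one actually finished before predicting. The sketch below assumes nnU-Net v1's default output layout, RESULTS_FOLDER/nnUNet/3d_fullres/TASK/TRAINER__nnUNetPlansv2.1/fold_N/model_final_checkpoint.model; the function name `unfinished_folds` is hypothetical.

```python
from pathlib import Path


def unfinished_folds(results_folder, task="Task001_Tooth",
                     trainer="nnUNetTrainerV2_DP", plans="nnUNetPlansv2.1",
                     folds=range(5)):
    """Return the folds whose final checkpoint file is missing (assumed v1 layout)."""
    base = Path(results_folder) / "nnUNet" / "3d_fullres" / task / f"{trainer}__{plans}"
    return [f for f in folds
            if not (base / f"fold_{f}" / "model_final_checkpoint.model").is_file()]
```

For later rounds, pass `task="Task002_..."` or `task="Task003_..."` with your own task names.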
nnUNet_predict -i $INPUTS_FOLDER -o $OUTPUTS_FOLDER -t 1 -m 3d_fullres --save_npz
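With `--save_npz`, nnU-Net v1 writes a CASE.npz file next to each predicted segmentation; to our knowledge it stores the class probabilities under the key `softmax` with shape (classes, X, Y, Z), but check `np.load(path).files` on your own output. As an illustrative sketch only, a per-case confidence score (mean winning-class probability over predicted foreground voxels) could be computed like this; the function name and the score itself are our assumptions, not part of the method described here.

```python
import numpy as np


def case_confidence(npz_path, key="softmax"):
    """Mean winning-class probability over voxels predicted as foreground."""
    with np.load(npz_path) as data:
        probs = data[key].astype(np.float32)  # assumed shape: (classes, X, Y, Z)
    seg = probs.argmax(axis=0)   # predicted label per voxel
    conf = probs.max(axis=0)     # probability of the predicted label
    fg = seg > 0
    return float(conf[fg].mean()) if fg.any() else 0.0
```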
- Give a new TaskID (e.g. Task002) and organize the labeled data and the pseudo-labeled data as above.
- Run nnU-Net's automatic preprocessing as above.
nnUNet_plan_and_preprocess -t 2 --verify_dataset_integrity
- Train a new nnU-Net on all training data (labeled and pseudo-labeled).
for FOLD in 0 1 2 3 4
do
CUDA_VISIBLE_DEVICES=0,1 nnUNet_train_DP 3d_fullres nnUNetTrainerV2_DP 2 $FOLD -gpus 2 -c --npz
done
- Generate new pseudo labels for unlabeled data.
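Assembling the new task folder amounts to copying the labeled cases and adding each unlabeled image together with its pseudo label. The sketch below assumes predictions named CASE.nii.gz and unlabeled images named CASE_0000.nii.gz, and that case names do not collide between the labeled and unlabeled sets; `build_pseudo_task` is a hypothetical helper. Regenerate dataset.json afterwards so that numTraining and the training list include the new cases.

```python
import shutil
from pathlib import Path


def build_pseudo_task(labeled_task, unlabeled_images, pseudo_labels, new_task):
    """Copy labeled cases, then add each unlabeled image paired with its pseudo label."""
    labeled_task, new_task = Path(labeled_task), Path(new_task)
    for sub in ("imagesTr", "labelsTr"):
        shutil.copytree(labeled_task / sub, new_task / sub, dirs_exist_ok=True)
    added = []
    for label in sorted(Path(pseudo_labels).glob("*.nii.gz")):
        case = label.name[: -len(".nii.gz")]
        image = Path(unlabeled_images) / f"{case}_0000.nii.gz"
        if not image.is_file():
            continue  # skip predictions without a matching input image
        shutil.copy2(image, new_task / "imagesTr" / image.name)
        shutil.copy2(label, new_task / "labelsTr" / label.name)
        added.append(case)
    return added
```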
We compare the pseudo labels from different rounds and filter out those with high variance between rounds, using:
select_pseudo_label.ipynb
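One way to measure agreement between rounds is the foreground Dice score of the two predictions for each case, keeping only cases that barely change. This is a hypothetical sketch, not the logic of select_pseudo_label.ipynb: it treats all non-zero labels as foreground, and the 0.9 threshold is an arbitrary illustrative value. Volumes can be loaded as arrays with `SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(path))`.

```python
import numpy as np


def dice(a, b):
    """Foreground Dice between two label arrays (any non-zero label counts as foreground)."""
    a, b = a > 0, b > 0
    denom = a.sum() + b.sum()
    return 1.0 if denom == 0 else 2.0 * np.logical_and(a, b).sum() / denom


def select_stable_cases(round1, round2, threshold=0.9):
    """Keep cases present in both rounds whose Dice is at least threshold (hypothetical rule)."""
    kept = {}
    for case in sorted(round1.keys() & round2.keys()):
        score = dice(round1[case], round2[case])
        if score >= threshold:
            kept[case] = float(score)
    return kept
```

For multi-class tooth labels, computing Dice per label and averaging would be a natural refinement.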
We modify ./nnunet/network_architecture/generic_UNet.py of the nnU-Net source code for efficiency. Please back up your nnU-Net code first, then copy the whole repo into your nnU-Net environment.
for FOLD in 0 1 2 3 4
do
CUDA_VISIBLE_DEVICES=0,1 nnUNet_train_DP 3d_fullres nnUNetTrainerV2_DP 3 $FOLD -gpus 2 -c --npz
done
python add_suffix.py
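nnUNet_predict expects input images named with a modality suffix (CASE_0000.nii.gz for the first modality). We assume, without having checked, that add_suffix.py performs this renaming on the test images; the sketch below shows such a renaming under that assumption, with `add_modality_suffix` as a hypothetical name.

```python
from pathlib import Path


def add_modality_suffix(folder, suffix="_0000", ext=".nii.gz"):
    """Rename CASE.nii.gz to CASE_0000.nii.gz, skipping files that already have the suffix."""
    renamed = []
    # sorted() materializes the listing, so renaming inside the loop is safe
    for f in sorted(Path(folder).glob(f"*{ext}")):
        stem = f.name[: -len(ext)]
        if stem.endswith(suffix):
            continue
        target = f.with_name(f"{stem}{suffix}{ext}")
        f.rename(target)
        renamed.append(target.name)
    return renamed
```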
nnUNet_predict -i $INPUTS_FOLDER -o $OUTPUTS_FOLDER -t 3 -m 3d_fullres --save_npz
python position_correction.py