pfriedri / wdm-3d

PyTorch implementation for "WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis" (2024)

Home Page: https://pfriedri.github.io/wdm-3d-io


WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis

License: MIT

This is the official PyTorch implementation of the paper WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis by Paul Friedrich, Julia Wolleb, Florentin Bieder, Alicia Durrer and Philippe C. Cattin.

If you find our work useful, please consider giving this repository a ⭐ star and 📝 citing our paper:

@article{friedrich2024wdm,
  title={WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis},
  author={Paul Friedrich and Julia Wolleb and Florentin Bieder and Alicia Durrer and Philippe C. Cattin},
  year={2024},
  journal={arXiv preprint arXiv:2402.19043}
}

Paper Abstract

Due to the three-dimensional nature of CT- or MR-scans, generative modeling of medical images is a particularly challenging task. Existing approaches mostly apply patch-wise, slice-wise, or cascaded generation techniques to fit the high-dimensional data into the limited GPU memory. However, these approaches may introduce artifacts and potentially restrict the model's applicability for certain downstream tasks. This work presents WDM, a wavelet-based medical image synthesis framework that applies a diffusion model on wavelet decomposed images. The presented approach is a simple yet effective way of scaling diffusion models to high resolutions and can be trained on a single 40 GB GPU. Experimental results on BraTS and LIDC-IDRI unconditional image generation at a resolution of 128 x 128 x 128 show state-of-the-art image fidelity (FID) and sample diversity (MS-SSIM) scores compared to GANs, Diffusion Models, and Latent Diffusion Models. Our proposed method is the only one capable of generating high-quality images at a resolution of 256 x 256 x 256.
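
WDM runs the diffusion model on wavelet-decomposed images rather than on raw voxels. As a minimal illustration, assuming a single-level orthonormal Haar transform written in NumPy (the repository's own PyTorch transform may differ in filter choice, normalization and subband order), the sketch below shows the effect: a D x H x W volume becomes 8 subbands of size D/2 x H/2 x W/2, so the network operates on a spatial grid that is 8 times smaller, and the inverse transform reconstructs the input exactly. The function names are hypothetical.

```python
import math

import numpy as np

# Orthonormal Haar scaling factor (a plain Python float keeps the input dtype).
INV_SQRT2 = 1.0 / math.sqrt(2.0)


def _axis_slice(ndim, axis, start):
    """Index tuple selecting every second element along `axis`, starting at `start`."""
    index = [slice(None)] * ndim
    index[axis] = slice(start, None, 2)
    return tuple(index)


def _haar_split(x, axis):
    """One Haar step along a single axis: returns (low-pass, high-pass) halves."""
    even = x[_axis_slice(x.ndim, axis, 0)]
    odd = x[_axis_slice(x.ndim, axis, 1)]
    return (even + odd) * INV_SQRT2, (even - odd) * INV_SQRT2


def _haar_merge(low, high, axis):
    """Inverse of _haar_split: rebuild even/odd samples and interleave them."""
    even = (low + high) * INV_SQRT2
    odd = (low - high) * INV_SQRT2
    shape = list(even.shape)
    shape[axis] *= 2
    out = np.empty(shape, dtype=even.dtype)
    out[_axis_slice(out.ndim, axis, 0)] = even
    out[_axis_slice(out.ndim, axis, 1)] = odd
    return out


def haar_dwt3d(volume):
    """Single-level 3D Haar DWT: (D, H, W) with even sizes -> (8, D/2, H/2, W/2).

    Band k = 4*b0 + 2*b1 + b2, where b_i is 0 (low) or 1 (high) along axis i,
    so band 0 is the low-pass 'LLL' approximation.
    """
    bands = [volume]
    for axis in range(3):
        bands = [half for band in bands for half in _haar_split(band, axis)]
    return np.stack(bands)


def haar_idwt3d(coeffs):
    """Inverse of haar_dwt3d: (8, d, h, w) -> (2d, 2h, 2w)."""
    bands = list(coeffs)
    for axis in (2, 1, 0):  # undo the splits in reverse order
        bands = [_haar_merge(bands[i], bands[i + 1], axis) for i in range(0, len(bands), 2)]
    return bands[0]


vol = np.random.rand(128, 128, 128).astype(np.float32)
coeffs = haar_dwt3d(vol)     # shape (8, 64, 64, 64)
recon = haar_idwt3d(coeffs)  # shape (128, 128, 128), equal to vol up to rounding
```

The total number of values is unchanged (8 x 64^3 = 128^3); what shrinks is the spatial extent each network layer has to process.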

Dependencies

We recommend installing the required dependencies in a conda environment. The commands below use mamba, a faster drop-in replacement for conda (the equivalent conda commands work as well), to create and activate an environment called wdm:

mamba env create -f environment.yml
mamba activate wdm

Training & Sampling

To train a new model or sample from an already trained one, adapt and run the script run.sh. All hyperparameters needed to reproduce our results are set automatically once the corresponding MODEL is selected in the general settings. To execute the script, run:

bash run.sh

Supported settings (set in run.sh):

  • MODE: 'training', 'sampling'
  • MODEL: 'ours_unet_128', 'ours_unet_256', 'ours_wnet_128', 'ours_wnet_256'
  • DATASET: 'brats', 'lidc-idri'
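
The settings themselves live in the shell script run.sh. To make the valid combinations explicit, here is a small Python sketch; the helper check_settings is hypothetical and not part of the repository. It assumes, based only on the model names and the pretrained-model list, that the last token of MODEL is the cubic image resolution and the middle token names the backbone.

```python
# Allowed values, copied from the supported-settings list in this README.
MODES = ("training", "sampling")
MODELS = ("ours_unet_128", "ours_unet_256", "ours_wnet_128", "ours_wnet_256")
DATASETS = ("brats", "lidc-idri")


def check_settings(mode, model, dataset):
    """Validate a MODE/MODEL/DATASET choice and return (backbone, resolution).

    Assumption: MODEL follows the pattern 'ours_<backbone>_<resolution>',
    inferred from the names only.
    """
    if mode not in MODES:
        raise ValueError(f"unsupported MODE {mode!r}; expected one of {MODES}")
    if model not in MODELS:
        raise ValueError(f"unsupported MODEL {model!r}; expected one of {MODELS}")
    if dataset not in DATASETS:
        raise ValueError(f"unsupported DATASET {dataset!r}; expected one of {DATASETS}")
    _, backbone, resolution = model.split("_")
    return backbone, int(resolution)


backbone, resolution = check_settings("sampling", "ours_wnet_256", "lidc-idri")  # ("wnet", 256)
```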

Pretrained Models

We have released pretrained models on Hugging Face.

Currently available models:

  • BraTS 128: BraTS, 128 x 128 x 128, U-Net backbone, 1.2M Iterations
  • LIDC-IDRI 128: LIDC-IDRI, 128 x 128 x 128, U-Net backbone, 1.2M Iterations

Data

To ensure good reproducibility, we trained and evaluated our network on two publicly available datasets:

  • BraTS 2023: Adult Glioma, a dataset containing routine clinically acquired, multi-site multiparametric magnetic resonance imaging (MRI) scans of brain tumor patients. We only used the T1-weighted images for training. The data is available here.

  • LIDC-IDRI, a dataset containing multi-site, thoracic computed tomography (CT) scans of lung cancer patients. The data is available here.

The provided code works for the following data structure (you might need to adapt the DATA_DIR variable in run.sh):

data
└───BRATS
    └───BraTS-GLI-00000-000
        └───BraTS-GLI-00000-000-seg.nii.gz
        └───BraTS-GLI-00000-000-t1c.nii.gz
        └───BraTS-GLI-00000-000-t1n.nii.gz
        └───BraTS-GLI-00000-000-t2f.nii.gz
        └───BraTS-GLI-00000-000-t2w.nii.gz  
    └───BraTS-GLI-00001-000
    └───BraTS-GLI-00002-000
    ...

└───LIDC-IDRI
    └───LIDC-IDRI-0001
        └───preprocessed.nii.gz
    └───LIDC-IDRI-0002
    └───LIDC-IDRI-0003
    ...
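
To check that a local copy matches the layout above before adjusting DATA_DIR, a small standard-library sketch like the following can list the volumes of each dataset. The helper names are hypothetical, and selecting the *-t1n.nii.gz files for BraTS is an assumption (the text only says that T1-weighted images were used; the repository's own data loaders may pick files differently).

```python
from pathlib import Path


def list_brats_volumes(root, modality="t1n"):
    """Return sorted paths of the form <root>/<case>/<case>-<modality>.nii.gz.

    The default modality 't1n' is an assumption, not taken from the repository.
    """
    root = Path(root)
    paths = []
    for case_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        candidate = case_dir / f"{case_dir.name}-{modality}.nii.gz"
        if candidate.is_file():  # skip cases where this modality is missing
            paths.append(candidate)
    return paths


def list_lidc_volumes(root):
    """Return sorted paths of the form <root>/<case>/preprocessed.nii.gz."""
    root = Path(root)
    return sorted(p for p in root.glob("*/preprocessed.nii.gz") if p.is_file())
```

Cases without the expected file are silently skipped, which makes it easy to spot incomplete downloads by comparing the returned count with the number of case folders.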

We provide a script for preprocessing LIDC-IDRI. Run the following command, replacing DICOM_PATH with the path to the downloaded DICOM files and NIFTI_PATH with the directory in which the processed NIfTI files should be stored:

python utils/preproc_lidc-idri.py --dicom_dir DICOM_PATH --nifti_dir NIFTI_PATH
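
What utils/preproc_lidc-idri.py does internally is not described in this README. Purely as an illustration of a common step when preparing CT data for diffusion models, here is a sketch that clips a volume in Hounsfield units (HU) and rescales it linearly to [-1, 1]. The window of [-1000, 1000] HU and the function name normalize_ct are assumptions chosen for the example, not values taken from the script.

```python
import numpy as np

# Hypothetical clip window in Hounsfield units, chosen for illustration only.
HU_MIN, HU_MAX = -1000.0, 1000.0


def normalize_ct(volume, hu_min=HU_MIN, hu_max=HU_MAX):
    """Clip a CT volume to [hu_min, hu_max] and map that range linearly onto [-1, 1]."""
    clipped = np.clip(volume.astype(np.float32), hu_min, hu_max)
    return 2.0 * (clipped - hu_min) / (hu_max - hu_min) - 1.0
```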

Implementation Details for Comparing Methods

All experiments were performed on a system with an AMD EPYC 7742 CPU and an NVIDIA A100 (40 GB) GPU.
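
The abstract compares methods by image fidelity (FID). As a reference for what that score measures, the sketch below computes the Fréchet distance between Gaussians fitted to two sets of feature vectors: FID = ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2)). In this repository the features come from the pretrained model named in the Acknowledgements; here they are simply assumed to be given as (N, d) NumPy arrays, and the matrix square root is computed with a symmetric eigendecomposition instead of scipy.linalg.sqrtm.

```python
import numpy as np


def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    mat = (mat + mat.T) / 2.0                  # remove tiny asymmetries from round-off
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)      # negative eigenvalues here are numerical noise
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two (N, d) feature arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((S_a S_b)^(1/2)) equals Tr((S_a^(1/2) S_b S_a^(1/2))^(1/2)),
    # and the inner matrix is symmetric PSD, so eigh can be used.
    root_a = _sqrtm_psd(cov_a)
    cross = _sqrtm_psd(root_a @ cov_b @ root_a)
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * np.trace(cross))


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))
print(frechet_distance(real, real + 0.5))  # approximately 1.0 (shift of 0.5 in each of 4 dims)
```

Shifting every feature by a constant leaves the covariances unchanged, so only the mean term 4 x 0.5^2 remains, which makes this a convenient sanity check.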

TODOs

We plan to add further functionality to our framework:

  • Add compatibility for more datasets like MRNet, ADNI, or fastMRI
  • Release further pre-trained models
  • Extend the framework for 3D image inpainting
  • Extend the framework for 3D image-to-image translation

Acknowledgements

Our code is based on / inspired by the following repositories:

For computing FID scores we use a pretrained model (resnet_50_23dataset.pth) from:

Thanks for making these projects open-source.
