FirasGit / medicaldiffusion

Medical Diffusion: This repository contains the code for our paper Medical Diffusion: Denoising Diffusion Probabilistic Models for 3D Medical Image Synthesis

Training DDPM on LIDC: Out-of-memory problem and the data preprocessing code

Ryann-Ran opened this issue

Hi there,
Thank you for your valuable contribution to the project.

I've been reproducing results on the LIDC dataset, and I have a couple of questions regarding the experimental setup.

I trained the DDPM model on the LIDC dataset with a batch size of 50 on an NVIDIA RTX 3090 (24 GB of GPU memory), using this command:
python train_ddpm.py model=ddpm dataset=lidc model.vqgan_ckpt='/home/ps/data/wangyuran/code/medicaldiffusion/checkpoints/vq_gan/LIDC/low_compression/lightning_logs/version_1/checkpoints/epoch98-step100000-10000-train/recon_loss0.09.ckpt' model.diffusion_img_size=32 model.diffusion_depth_size=32 model.diffusion_num_channels=8 model.dim_mults=[1,2,4,8] model.batch_size=50 model.gpus=0
I ran into a CUDA out-of-memory error:
RuntimeError: CUDA out of memory. Tried to allocate 1.56 GiB (GPU 0; 23.69 GiB total capacity; 20.77 GiB already allocated; 856.56 MiB free; 20.79 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I then tried smaller batch sizes (40, 30, 20), but the error persisted until I reduced the batch size to 10. Even with that change, the training results don't look good.
Could something in my LIDC data preprocessing be causing the out-of-memory error? Could you either release the LIDC data preprocessing code or offer guidance on the out-of-memory problem?
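
Rather than stepping down through 50, 40, 30, 20 by hand, the largest batch size that fits can be found by bisection. A minimal sketch, assuming memory use grows monotonically with batch size; `fits` is a hypothetical probe supplied by the caller (for example, one that runs a single training step at that batch size and returns False if CUDA reports out of memory). It is not part of the repository.

```python
from typing import Callable


def largest_fitting_batch(fits: Callable[[int], bool], low: int = 1, high: int = 64) -> int:
    """Return the largest batch size in [low, high] for which fits(b) is True.

    Assumes fits() is True up to some threshold and False above it
    (memory grows monotonically with batch size). Returns 0 if even
    `low` does not fit. `fits` is a hypothetical caller-supplied probe.
    """
    if not fits(low):
        return 0
    best = low
    lo, hi = low + 1, high
    while lo <= hi:
        mid = (lo + hi) // 2
        if fits(mid):
            best = mid      # mid fits: look for something larger
            lo = mid + 1
        else:
            hi = mid - 1    # mid does not fit: look smaller
    return best


if __name__ == "__main__":
    # Toy probe standing in for a real GPU trial: pretend sizes up to 12 fit
    print(largest_fitting_batch(lambda b: b <= 12, high=50))  # 12
```

Each probe costs one trial step, so bisection over 1..50 needs about six trials instead of a linear sweep.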

I would greatly appreciate your assistance in clarifying these questions. Thank you in advance for your help!

Hello, I'm using train_vqgan to compress images into the latent space. After 50,000 training steps, the loss becomes NaN. How can I fix this? During the first 50,000 steps there is a lot of white noise around the reconstructed image, even though the reconstruction loss is already down to 9%. Is this reasonable? I would greatly appreciate an answer when you have time.

@xin-wo-1 Two things may help with the NaN problem: (1) set model.precision=32; (2) check your data preprocessing.
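
Why model.precision=32 can help: in 16-bit mixed precision, any intermediate value above float16's maximum (65504) overflows to infinity, and arithmetic such as inf - inf then produces NaN, which propagates into the loss. A small NumPy illustration of that mechanism (not the repository's code), plus a hypothetical `is_finite_loss` check one could apply before an optimizer step:

```python
import numpy as np

# Largest finite float16 value
print(float(np.finfo(np.float16).max))  # 65504.0

x = np.float16(300.0)
with np.errstate(over="ignore", invalid="ignore"):
    sq = x * x      # 90000 exceeds the float16 range, so this overflows to inf
    diff = sq - sq  # inf - inf is undefined, so this is nan
print(sq, diff)     # inf nan

# The same computation in float32 stays finite
x32 = np.float32(300.0)
print(x32 * x32 - x32 * x32)  # 0.0


def is_finite_loss(value: float) -> bool:
    """Return False for NaN or +/-inf, e.g. to skip or log a bad training step."""
    return bool(np.isfinite(value))
```
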
I'm also struggling with the white noise. My reconstruction loss ends up at 0.3 and there is still a lot of white noise. How did you get the loss down to 0.09?

@Ryann-Ran I trained the model on a brain MRI dataset with limited diversity, so I reduced the number of feature channels in the intermediate layers of the network. This let it fit within 35 GB of GPU memory. I also set model.precision=32 in the configuration file. With these changes, the reconstruction loss has dropped to 5%, but there is still some white noise in the reconstructions. I hope this helps.
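
For intuition on why narrowing the intermediate layers saves memory: a 3D convolution's weight count grows with the product of its input and output channels, so halving both cuts the parameters (and their gradients and optimizer state) roughly fourfold, while each activation tensor shrinks linearly with the channel count. A back-of-envelope helper, assuming plain k×k×k convolutions with bias; `conv3d_params` is a hypothetical name, not from the repository.

```python
def conv3d_params(c_in: int, c_out: int, k: int = 3, bias: bool = True) -> int:
    """Number of learnable parameters in a k x k x k Conv3d layer."""
    return k ** 3 * c_in * c_out + (c_out if bias else 0)


full = conv3d_params(128, 128)  # 27 * 128 * 128 + 128 = 442,496
half = conv3d_params(64, 64)    # 27 * 64 * 64 + 64 = 110,656
print(full, half, round(full / half, 2))
```
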

@xin-wo-1 Thanks for your answer. Since the paper doesn't show reconstructed images, and given time constraints, I went on to train the DDPM model on my custom dataset. The generated results seem okay.

Hello, I'm currently facing an issue while training DDPM. My latent space dimensions are 256, 32, 32, 32. When I set model_channel in the UNet to 32, the loss consistently stays around 70%. When I increase model_channel to 128, the loss gradually decreases to 45%, but as training continues it climbs back to around 70%. When I try to increase the UNet's initial channels further, I run out of memory. Could you please tell me which initial model_channel value you used for DDPM training? A prompt response would be greatly appreciated.
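
To see why raising the UNet's base channel count from 32 to 128 runs out of memory: each activation tensor has batch × channels × H × W × D elements, so a 4× wider first level means 4× larger first-level activations. A rough per-tensor estimate, assuming float32 storage; a real UNet keeps many such tensors alive for backpropagation, so actual usage is a multiple of this. `feature_map_bytes` is a hypothetical helper.

```python
from math import prod


def feature_map_bytes(batch: int, channels: int, spatial: tuple, bytes_per_elem: int = 4) -> int:
    """Memory of one activation tensor of shape (batch, channels, *spatial)."""
    return batch * channels * prod(spatial) * bytes_per_elem


MiB = 2 ** 20
for base in (32, 128):
    # One first-level feature map for a 32x32x32 latent, batch size 1
    size = feature_map_bytes(1, base, (32, 32, 32))
    print(f"base channels {base}: {size / MiB:.0f} MiB per tensor")
```

Memory also scales linearly with batch size, which is why the batch-size and channel-count reductions discussed in this thread trade off against each other.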

@xin-wo-1 Hi there, my input images are 256×256×32 (h, w, d), and the latent produced by the VQGAN I trained is 8×64×64×8 (c, h, w, d). So I set model.diffusion_img_size=64 model.diffusion_depth_size=8 model.diffusion_num_channels=8. The loss is around 0.05 at step 150k.
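
The latent shape follows from the input size and the VQGAN's downsampling factor: 256×256×32 reduced 4× per axis gives 64×64×8, which is what diffusion_img_size=64 and diffusion_depth_size=8 encode, with diffusion_num_channels=8 latent channels. The same factor applied to the 128³ volumes produced by the LIDC preprocessing script gives 32³, consistent with the diffusion_img_size=32 and diffusion_depth_size=32 in the first post. A small sketch, assuming one integer factor on every axis (the factor of 4 is inferred from these numbers, not read from the config); `latent_shape` is a hypothetical helper.

```python
def latent_shape(input_hwd: tuple, downsample: int, channels: int) -> tuple:
    """Return the (c, h, w, d) latent shape for an input of size (h, w, d).

    Assumes the VQGAN encoder downsamples every spatial axis by the same
    integer factor; raises if an axis is not divisible by it.
    """
    for n in input_hwd:
        if n % downsample:
            raise ValueError(f"axis of size {n} not divisible by {downsample}")
    return (channels, *(n // downsample for n in input_hwd))


print(latent_shape((256, 256, 32), 4, 8))   # (8, 64, 64, 8)
print(latent_shape((128, 128, 128), 4, 8))  # (8, 32, 32, 32)
```
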

@Ryann-Ran Thanks for your reply. I found that with batch_size set to 4, the loss oscillates between 0.06 and 0.30, and increasing the batch size alleviates this. What batch size works well, and how long did it take you to train a relatively stable generative model?
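
A smoother loss at larger batch sizes is expected: the reported loss is an average over the batch, and for independent samples the standard deviation of that average shrinks roughly as 1/√B. A NumPy simulation with synthetic per-sample losses (an exponential distribution with mean 0.1 is an arbitrary assumption, not real training data):

```python
import numpy as np

rng = np.random.default_rng(0)


def batch_mean_std(batch_size: int, n_batches: int = 20000) -> float:
    """Std of the batch-mean loss when per-sample losses are i.i.d. Exp(mean 0.1)."""
    # Synthetic per-sample losses; the distribution is an illustrative assumption
    losses = rng.exponential(scale=0.1, size=(n_batches, batch_size))
    return float(losses.mean(axis=1).std())


for b in (4, 20, 80):
    # Theory: about 0.1 / sqrt(b), i.e. roughly 0.050, 0.022, 0.011
    print(b, round(batch_mean_std(b), 3))
```
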

@xin-wo-1 I set model.batch_size=10 on two NVIDIA RTX 3090s (24 GB each), for a total batch size of 20. The loss stabilizes around step 150k, just as the paper says, and this takes about 12 days.
However, compared with the batch size mentioned in the paper, I'm still concerned that mine may not be large enough, even though the generated volumes look satisfactory to the naked eye. The memory requirements are quite high.
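
With data-parallel training, the effective batch size is the per-GPU batch size times the number of GPUs (10 × 2 = 20 here), times the number of gradient-accumulation steps if accumulation is used. Accumulation reaches a larger effective batch without extra activation memory: for a mean loss and equal-sized micro-batches, averaging the micro-batch gradients gives exactly the full-batch gradient. A NumPy check on a linear least-squares model, a generic stand-in rather than the DDPM or the repository's trainer; the helper names are hypothetical.

```python
import numpy as np


def effective_batch_size(per_gpu: int, n_gpus: int = 1, accum_steps: int = 1) -> int:
    return per_gpu * n_gpus * accum_steps


def mse_grad(w: np.ndarray, X: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.normal(size=20)
w = rng.normal(size=3)

full = mse_grad(w, X, y)  # one batch of 20

# Two micro-batches of 10 with their gradients averaged (accumulation)
micro = [mse_grad(w, X[i:i + 10], y[i:i + 10]) for i in (0, 10)]
accumulated = sum(micro) / len(micro)

print(effective_batch_size(10, n_gpus=2))  # 20
print(np.allclose(full, accumulated))      # True
```

The equivalence relies on per-sample losses being independent of the rest of the batch; layers that normalize across the batch dimension would break it.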

Thank you for your response. I'd like to ask whether 'timesteps: 300' in the configuration file needs to be changed to 1000. There is no 'ddim' sampler in the code, and 'ddpm' doesn't seem to cope with so few timesteps. With 300 timesteps, the loss is around 8% after training, and the samples show no MRI radiological features. I would greatly appreciate a prompt reply.

@xin-wo-1 I didn't change the timesteps in my experiments. I know little about MRI. Could you give some examples of the missing MRI radiological features?

My results show only the brain region, and the tissue gray levels do not match typical MRI appearance. Furthermore, from searching online, I found that the number of DDPM timesteps should not be too small (such as 300), since that leads to blurry samples, which matches my experimental findings. So I wanted to ask how your experiments turned out.
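
One way to see why the number of timesteps can matter: with the linear β schedule of Ho et al. (β rising from 1e-4 to 0.02 over T steps), ᾱ_T = ∏(1 − β_t) measures how much of the clean signal survives at the last step. At T = 1000 it is essentially zero, but at T = 300 a noticeable fraction remains, so the forward process no longer ends near pure noise. This is a sketch under that assumption only; if the configuration uses a different schedule (a cosine schedule, for example, is constructed to drive ᾱ_T to nearly zero for any T), these numbers do not apply directly.

```python
import numpy as np


def linear_alpha_bar_T(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> float:
    """Return alpha_bar_T = prod(1 - beta_t) for a linear beta schedule of T steps."""
    betas = np.linspace(beta_start, beta_end, T, dtype=np.float64)
    return float(np.prod(1.0 - betas))


for T in (300, 1000):
    a = linear_alpha_bar_T(T)
    # sqrt(alpha_bar_T) is the weight of the clean sample x_0 inside x_T
    print(f"T={T}: alpha_bar_T={a:.2e}, signal weight={np.sqrt(a):.3f}")
```
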

Could you share some code for LIDC data preprocessing?

@kanydao Here's the official code:

"""Adapted from https://github.com/peterhan91/cycleGAN/blob/db8f1d958c0879c29cf3932cae74a166317be812/prepro.py#L39"""

import os
import numpy as np
from glob import glob
import pydicom
import scipy.ndimage
from pathlib import Path
import argparse
from multiprocessing import Pool, cpu_count
from tqdm import tqdm



class CTExtractor:
    def __init__(self, input_path, out_path):
        super(CTExtractor, self).__init__()

        self.MIN_BOUND = -1000.0
        self.MAX_BOUND = 400.0
        self.PIXEL_MEAN = 0.25
        self.roi = 320
        self.size = 128

        self.path = input_path
        self.outpath = out_path
        self.slices = []
        self.fname = ''

    # Load the scans in given folder path
    def load_scan(self):
        # dcmread is pydicom's reader; read_file is a deprecated alias
        slices_ = [pydicom.dcmread(s) for s in glob(
            os.path.join(self.path, self.fname, '*/*/*.dcm'))]

        # Problem when CXR is available. This fixes it.
        num_subfolders = len(os.listdir(os.path.join(self.path, self.fname)))
        if num_subfolders > 1:
            print(f"Filename: {self.fname}, No. Subfolders: {num_subfolders}")
            slices = []
            for s in slices_:
                if s.Modality == 'CT':
                    slices.append(s)
                else:
                    print(s.Modality)
        else:
            slices = slices_

        slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
        try:
            slice_thickness = np.abs(
                slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])
        except:
            slice_thickness = np.abs(
                slices[0].SliceLocation - slices[1].SliceLocation)

        for s in slices:
            s.SliceThickness = slice_thickness
            if s.Modality != 'CT':
                print(f"NOT A CT. This is a {s.Modality}")

        return slices

    def get_pixels_hu(self, slices):
        image = np.stack([s.pixel_array for s in slices])
        # Convert to int16 (from sometimes uint16);
        # this should be safe since values should always be low enough (<32k)
        image = image.astype(np.int16)

        # Set outside-of-scan pixels to 0
        # The intercept is usually -1024, so air is approximately 0
        image[image == -2000] = 0

        # Convert to Hounsfield units (HU)
        for slice_number in range(len(slices)):

            intercept = slices[slice_number].RescaleIntercept
            slope = slices[slice_number].RescaleSlope

            if slope != 1:
                image[slice_number] = slope * \
                    image[slice_number].astype(np.float64)
                image[slice_number] = image[slice_number].astype(np.int16)

            image[slice_number] += np.int16(intercept)

        return np.array(image, dtype=np.int16)

    def resample(self, image, scan, new_spacing=[1.0, 1.0, 1.0]):
        # Determine current pixel spacing
        # print(scan[0].SliceThickness)
        # print(scan[0].PixelSpacing)
        spacing = np.array([scan[0].SliceThickness] +
                           list(scan[0].PixelSpacing), dtype=np.float32)

        resize_factor = spacing / new_spacing
        new_real_shape = image.shape * resize_factor
        new_shape = np.round(new_real_shape)
        real_resize_factor = new_shape / image.shape
        new_spacing = spacing / real_resize_factor

        # scipy.ndimage.zoom (the scipy.ndimage.interpolation namespace is deprecated)
        image = scipy.ndimage.zoom(
            image, real_resize_factor, mode='nearest')

        return image, new_spacing

    def normalize(self, image):
        image = (image - self.MIN_BOUND) / (self.MAX_BOUND - self.MIN_BOUND)
        image[image > 1] = 1.
        image[image < 0] = 0.
        return image*2-1.

    def zero_center(self, image):
        image = image - self.PIXEL_MEAN
        return image

    def pad_center(self, pix_resampled):
        pad_z = max(self.roi - pix_resampled.shape[0], 0)
        pad_x = max(self.roi - pix_resampled.shape[1], 0)
        pad_y = max(self.roi - pix_resampled.shape[2], 0)
        try:
            pad = np.pad(pix_resampled,
                         [(pad_z//2, pad_z-pad_z//2),
                          (pad_x//2, pad_x-pad_x//2),
                          (pad_y//2, pad_y-pad_y//2)],
                         mode='constant',
                         # Pad with a voxel near the corner of the first
                         # slice, which is usually air
                         constant_values=pix_resampled[0][10][10])
        except (ValueError, IndexError):
            # Report the offending shape, then re-raise instead of
            # returning an undefined `pad`
            print(pix_resampled.shape)
            raise
        return pad

    def crop_center(self, vol, cropz, cropy, cropx):
        z, y, x = vol.shape
        startx = x//2-(cropx//2)
        starty = y//2-(cropy//2)
        startz = z//2-(cropz//2)
        return vol[startz:startz+cropz, starty:starty+cropy, startx:startx+cropx]

    def save(self):
        path = os.path.join(self.outpath, self.fname, '128.npy')
        Path(os.path.join(self.outpath, self.fname)).mkdir(
            parents=True, exist_ok=True)
        np.save(path, self.vol)

    def run(self, fname):
        self.fname = fname
        self.patient = self.load_scan()
        self.vol = self.get_pixels_hu(self.patient)
        self.vol, _ = self.resample(self.vol, self.patient)
        if self.vol.shape[0] >= self.roi and self.vol.shape[1] >= self.roi and self.vol.shape[2] >= self.roi:
            self.vol = self.crop_center(self.vol, self.roi, self.roi, self.roi)
        else:
            self.vol = self.pad_center(self.vol)
            self.vol = self.crop_center(self.vol, self.roi, self.roi, self.roi)
        assert self.vol.shape == (self.roi, self.roi, self.roi)
        self.vol = scipy.ndimage.zoom(self.vol,
                                      [self.size/self.roi, self.size /
                                          self.roi, self.size/self.roi],
                                      mode='nearest')
        assert self.vol.shape == (self.size, self.size, self.size)
        self.vol = self.normalize(self.vol)
        self.save()


def worker(fname, extractor):
    try:
        extractor.run(fname)
    except Exception as e:
        # Log the failing scan and keep processing the others
        print(f'Error extracting the lung CT: {fname} ({e})')


if __name__ == "__main__":
    from functools import partial

    # Argument Parsing
    parser = argparse.ArgumentParser(description='CTExtractor for processing CT scans.')
    parser.add_argument('--input_path', type=str, required=True, help='Path to the input CT scans directory')
    parser.add_argument('--path_output', type=str, required=True, help='Path to the directory to save processed CT scans')
    args = parser.parse_args()

    input_path = args.input_path
    path_output = args.path_output

    extractor = CTExtractor(input_path, path_output)

    # A functools.partial of the module-level worker can be pickled under every
    # multiprocessing start method; a function defined inside this __main__
    # block cannot be looked up by "spawn"/"forkserver" worker processes
    worker_partial = partial(worker, extractor=extractor)

    fnames = os.listdir(input_path)
    print('total # of scans', len(fnames))

    with Pool(processes=4) as pool:
        res = list(tqdm(pool.imap(
            worker_partial, iter(fnames)), total=len(fnames)))
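
After running the script, a quick sanity check on the saved volumes can catch preprocessing problems, one of the NaN suspects raised earlier in the thread. The script's normalize step maps HU linearly from [MIN_BOUND, MAX_BOUND] = [-1000, 400] to [-1, 1], clipping values outside that window, and each scan is saved as a 128×128×128 array. The sketch below re-implements that mapping for reference and checks a volume against those expectations; `normalize_hu` and `check_volume` are hypothetical helpers, not part of the repository.

```python
import numpy as np

MIN_BOUND, MAX_BOUND, SIZE = -1000.0, 400.0, 128  # values taken from CTExtractor


def normalize_hu(image: np.ndarray) -> np.ndarray:
    """Same mapping as CTExtractor.normalize: HU -> [0, 1] (clipped) -> [-1, 1]."""
    image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
    return np.clip(image, 0.0, 1.0) * 2 - 1.0


def check_volume(vol: np.ndarray) -> list:
    """Return a list of problems found in one preprocessed volume (empty if OK)."""
    problems = []
    if vol.shape != (SIZE, SIZE, SIZE):
        problems.append(f"unexpected shape {vol.shape}")
    if not np.all(np.isfinite(vol)):
        problems.append("contains NaN or inf")
    elif vol.min() < -1.0 or vol.max() > 1.0:
        problems.append(f"values outside [-1, 1]: [{vol.min():.3f}, {vol.max():.3f}]")
    return problems


# Air (-1000 HU) and below map to -1, -300 HU to 0, and 400 HU and above to 1
print(normalize_hu(np.array([-2000.0, -1000.0, -300.0, 400.0, 3000.0])))

# Usage on the script's output (layout from CTExtractor.save; placeholders in <>):
# vol = np.load("<path_output>/<scan_id>/128.npy"); print(check_volume(vol))
```
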

@Ryann-Ran Thanks for sharing the code!