FirasGit / medicaldiffusion

Medical Diffusion: This repository contains the code for our paper "Medical Diffusion: Denoising Diffusion Probabilistic Models for 3D Medical Image Synthesis".

NaN issue

CaiwenXu opened this issue

Hi, many thanks for your excellent work! I have a problem when training the VQ-GAN: the loss suddenly becomes NaN. Do you know why this happens? I used the LIDC dataset.

I'm currently having the same problem. I used the exact same configs provided here and still had no luck; training is very unstable.
The model also suffers from mode collapse once the discriminator starts training.

I believe this problem may stem from the accumulate_grad_batches parameter. I completed one run of more than 50,000 steps successfully, but trying to replicate the training with accumulate_grad_batches > 1 runs into the NaN problem. @CWX-student can you confirm this, or do you have any other info on your end?

Update: Setting the precision parameter in the config to at least 32 seems to alleviate this problem. See https://discuss.pytorch.org/t/distributed-training-gives-nan-loss-but-single-gpu-training-is-fine/63664/6
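
For anyone wondering why precision could matter here, below is a minimal sketch of one common mechanism, not a diagnosis of this repository. It assumes the NaN comes from half-precision overflow: float16 cannot represent values above 65504, so a large intermediate becomes inf, and inf - inf is NaN. The helper names (`to_float16`, `to_float32`, `squared_difference`) and the input values are made up for illustration, and the rounding is simulated with Python's standard library rather than with PyTorch.

```python
import math
import struct


def to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (stdlib only)."""
    if math.isnan(x) or math.isinf(x):
        return x
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        # struct refuses values beyond the float16 range; real half-precision
        # arithmetic would yield a signed infinity instead.
        return math.copysign(math.inf, x)


def to_float32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    if math.isnan(x) or math.isinf(x):
        return x
    return struct.unpack("<f", struct.pack("<f", x))[0]


def squared_difference(a: float, b: float, cast) -> float:
    """Compute a^2 - b^2, rounding every intermediate result with `cast`."""
    a2 = cast(cast(a) * cast(a))
    b2 = cast(cast(b) * cast(b))
    return cast(a2 - b2)


# Hypothetical activations whose squares exceed the float16 maximum of 65504.
half = squared_difference(300.0, 280.0, to_float16)
single = squared_difference(300.0, 280.0, to_float32)

print("float16 result:", half)
print("float32 result:", single)
```

Both squares overflow to inf in the simulated float16 path, so their difference is NaN, while the float32 path stays finite. If something similar happens inside the loss or the gradients, a single overflowing step would be enough to poison all later updates, which would be consistent with the loss turning NaN suddenly.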