cszn / KAIR

Image Restoration Toolbox (PyTorch). Training and testing code for DPIR, USRNet, DnCNN, FFDNet, SRMD, DPSR, BSRGAN, SwinIR

Home Page: https://cszn.github.io/

CUDA Out of memory when training SwinIR with L2 Loss

IMasI2Cat opened this issue

Hi, I'm trying to train SwinIR. I used the JSON in options/swinir/train_swinir_sr_realworld_x4_psnr.json as a template for my own data (RGB), with only slight modifications (i.e. dataset_type set to "sr", LR, checkpoints, paths to my data, and dataloader_num_workers set to 1). I have run several experiments with this configuration without issues. However, when I switch to L2 loss ("G_lossfn_type": "l2"), I get a CUDA out-of-memory error at a certain step of the first validation. The same happens with "l2sum", and also after reducing the batch size to 1. The error trace follows. Any ideas would help.
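For readers who want to reproduce the setup, the option changes described above can be sketched in Python as overrides on a dictionary. This is only an illustration: the nesting of the keys (`datasets` / `train`) and the placeholder template values are assumptions, not copied from the repository's JSON, and `apply_overrides` is a hypothetical helper.

```python
import copy

# Hypothetical stand-in for the relevant part of the template options.
# The key nesting and the "<template value>" placeholders are assumptions.
template = {
    "datasets": {
        "train": {
            "dataset_type": "<template value>",
            "dataloader_num_workers": "<template value>",
        }
    },
    "train": {"G_lossfn_type": "<template value>"},
}

def apply_overrides(opt, loss_type="l2"):
    """Return a copy of opt with the modifications described in the post."""
    opt = copy.deepcopy(opt)  # leave the template untouched
    opt["datasets"]["train"]["dataset_type"] = "sr"
    opt["datasets"]["train"]["dataloader_num_workers"] = 1
    opt["train"]["G_lossfn_type"] = loss_type  # "l2" or "l2sum" in the post
    return opt

opt = apply_overrides(template)
print(opt["train"]["G_lossfn_type"])
```

The learning rate, checkpoint settings and data paths mentioned in the post are omitted here because their values are not given.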

Traceback (most recent call last):                                                                                                                                                                                 
  File "/home/imas/code/Repos/KAIR/main_train_psnr.py", line 281, in <module>                                                                                                                             
    main()                                                                                                                                                                                                         
  File "/home/imas/code/Repos/KAIR/main_train_psnr.py", line 242, in main                                                                                                                                 
    model.test()                                                                                                                                                                                                   
  File "/home/imas/code/Repos/KAIR/models/model_plain.py", line 219, in test                                                                                                                              
    self.netG_forward()                                                                                                                                                                                            
  File "/home/imas/code/Repos/KAIR/models/model_plain.py", line 169, in netG_forward                                                                                                                      
    self.E = self.netG(self.L)                                                                                                                                                                                     
  File "/home/imas/miniconda3/envs/pytorch3d2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl                                                                             
    return forward_call(*input, **kwargs)                                                                                                                                                                          
  File "/home/imas/miniconda3/envs/pytorch3d2/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 169, in forward                                                                         
    return self.module(*inputs[0], **kwargs[0])                                                                                                                                                                    
  File "/home/imas/miniconda3/envs/pytorch3d2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl                                                                             
    return forward_call(*input, **kwargs)                                                                                                                                                                          
  File "/home/imas/code/Repos/KAIR/models/network_swinir.py", line 826, in forward                                                                                                                        
    x = self.conv_after_body(self.forward_features(x)) + x                                                                                                                                                         
  File "/home/imas/code/Repos/KAIR/models/network_swinir.py", line 798, in forward_features                                                                                                               
    x = layer(x, x_size)                                                                                                                                                                                           
  File "/home/imas/miniconda3/envs/pytorch3d2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl                                                                             
    return forward_call(*input, **kwargs)                                                                                                                                                                          
  File "/home/imas/code/Repos/KAIR/models/network_swinir.py", line 482, in forward                                                                                                                        
    return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x                                                                                                             
  File "/home/imas/miniconda3/envs/pytorch3d2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl                                                                             
    return forward_call(*input, **kwargs)                                                                                                                                                                          
  File "/home/imas/code/Repos/KAIR/models/network_swinir.py", line 402, in forward                                                                                                                        
    x = blk(x, x_size)                                                                                                                                                                                             
  File "/home/imas/miniconda3/envs/pytorch3d2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl                                                                             
    return forward_call(*input, **kwargs)                                                                                                                                                                          
  File "/home/imas/code/Repos/KAIR/models/network_swinir.py", line 262, in forward                                                                                                                        
    attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))                                                                                                                             
  File "/home/imas/miniconda3/envs/pytorch3d2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl                                                                             
    return forward_call(*input, **kwargs)                                                                                                                                                                          
  File "/home/imas/code/Repos/KAIR/models/network_swinir.py", line 130, in forward                                                                                                                        
    attn = attn + relative_position_bias.unsqueeze(0)                                                                                                                                                              
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.05 GiB (GPU 0; 23.68 GiB total capacity; 17.74 GiB already allocated; 796.50 MiB free; 21.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
srun: error: i2c01: task 0: Exited with exit code 1
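
The trace fails inside `model.test()`, at the line that adds `relative_position_bias` to the window-attention tensor, so the size of the validation input may matter more than the loss function. A rough back-of-the-envelope estimate of that tensor's size can be sketched as follows. It assumes float32 activations, a window size of 8, 6 attention heads and a 64x64 training patch; none of these values appear in the post, and `attn_bytes` is a hypothetical helper, not part of KAIR.

```python
import math

def attn_bytes(height, width, window_size=8, num_heads=6, batch=1, bytes_per_elem=4):
    """Bytes in one attention tensor of shape (batch * windows, heads, N, N)."""
    # The input is assumed to be padded up to a multiple of the window size.
    windows = math.ceil(height / window_size) * math.ceil(width / window_size)
    tokens = window_size * window_size  # N, the number of tokens per window
    return batch * windows * num_heads * tokens * tokens * bytes_per_elem

GIB = 2 ** 30
requested = 1.05 * GIB  # "Tried to allocate 1.05 GiB" in the trace

per_window = attn_bytes(8, 8)        # cost of a single 8x8 window
windows = requested / per_window     # windows needed to reach the request
side = math.sqrt(windows * 8 * 8)    # side of a square input with that many windows

print(f"{per_window} bytes per window")
print(f"about {windows:.0f} windows, a square input of roughly {side:.0f} px per side")
print(f"assumed 64x64 training patch: {attn_bytes(64, 64) / 2**20:.1f} MiB")
```

Under these assumptions, a single attention tensor for a large validation image is orders of magnitude bigger than for a training patch, which would be consistent with the error appearing only at a certain validation step and not changing with the training batch size.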