LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701

Question about training RDM with main_rdm.py

wyyy04 opened this issue · comments

Thanks for your excellent work! I am running into an error while training the RDM and would appreciate your help.

When training the RDM, in the get_input function of ddpm.py (line 564), the input x of shape (32, 256, 256, 3) has shape (32, 197, 768) after feature extraction, where 32 is the batch size. However, line 578, "rep = self.pretrained_encoder.head(rep)", raises the following error:

File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
File "/home/user/Diffusion_FSAR/rcg-main/rdm/models/diffusion/ddpm.py", line 578, in get_input
    rep = self.pretrained_encoder.head(rep)
File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 197 elements not 4096

This appears to be a mismatch between the shape of rep and what "self.pretrained_encoder.head" expects.
I am unsure of the cause and would appreciate your clarification. Thank you!
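For context, the error above can be reproduced with plain PyTorch. The sketch below assumes a MoCo v3-style projection head (Linear(768, 4096) followed by BatchNorm1d(4096)); the exact head in the repo may differ, but the mechanism is the same: a pooled (B, 768) input works, while a (B, 197, 768) token sequence makes BatchNorm1d treat dim 1 (197) as the channel dimension, producing exactly this RuntimeError.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the projection head: Linear then BatchNorm1d(4096).
head = nn.Sequential(nn.Linear(768, 4096), nn.BatchNorm1d(4096))

# Pooled CLS-token features, shape (32, 768): the head works as intended.
pooled = torch.randn(32, 768)
out = head(pooled)  # shape (32, 4096)

# Full token sequence, shape (32, 197, 768): Linear maps it to
# (32, 197, 4096), and BatchNorm1d then treats dim 1 (197) as the
# channel dimension, so it complains that running_mean (4096 entries)
# does not match the 197 "channels" it sees.
tokens = torch.randn(32, 197, 768)
try:
    head(tokens)
except RuntimeError as e:
    print(e)  # running_mean should contain 197 elements not 4096
```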

Thanks for your interest! Please make sure your timm version is 0.3.2; later versions use a different forward_features implementation. Please see issue #9 for a similar problem and its solution.
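To illustrate the version difference: in timm 0.3.2, a ViT's forward_features returns the pooled CLS token with shape (B, 768), while newer timm versions return the full token sequence (B, 197, 768). A hypothetical workaround (not the repo's fix, which is pinning timm==0.3.2) would be to pool the CLS token before the head:

```python
import torch

def pool_if_needed(rep: torch.Tensor) -> torch.Tensor:
    # Hypothetical shim: if forward_features returned the full token
    # sequence (B, 197, 768), keep only the CLS token at index 0 so the
    # head sees (B, 768), matching the timm 0.3.2 behaviour.
    if rep.dim() == 3:
        rep = rep[:, 0]
    return rep

rep = torch.randn(32, 197, 768)
print(pool_if_needed(rep).shape)  # torch.Size([32, 768])
```

Pinning the dependency (pip install timm==0.3.2) is the safer route, since other parts of the pipeline may also rely on the old behaviour.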

Thank you very much for your guidance. I have resolved the issue by switching timm to version 0.3.2. Best wishes!