Model mismatch for MAGE in model.load_state_dict(checkpoint['model']) when running viz_rcg.ipynb

Yisher opened this issue · comments

Hello! Thank you for your great work.
I trained rdm.pth from, and trained mage.pth from. When I try to visualize the generation, I encounter this problem:
RuntimeError Traceback (most recent call last)
Cell In[9], line 2
1 checkpoint = torch.load(os.path.join('output/checkpoint-last.pth'), map_location='cpu')
----> 2 model.load_state_dict(checkpoint['model'], strict=True)
3 model.cuda()
4 _ = model.eval()

RuntimeError: Error(s) in loading state_dict for MaskedGenerativeEncoderViT:
size mismatch for cls_token: copying a param with shape torch.Size([1, 1, 768]) from checkpoint, the shape in current model is torch.Size([1, 1, 1024]).
size mismatch for pos_embed: copying a param with shape torch.Size([1, 257, 768]) from checkpoint, the shape in current model is torch.Size([1, 257, 1024]).
size mismatch for mask_token: copying a param with shape torch.Size([1, 1, 768]) from checkpoint, the shape in current model is torch.Size([1, 1, 1024]).
size mismatch for decoder_pos_embed: copying a param with shape torch.Size([1, 257, 768]) from checkpoint, the shape in current model is torch.Size([1, 257, 1024]).
size mismatch for decoder_pos_embed_learned: copying a param with shape torch.Size([1, 257, 768]) from checkpoint, the shape in current model is torch.Size([1, 257, 1024]).
size mismatch for token_emb.word_embeddings.weight: copying a param with shape torch.Size([2025, 768]) from checkpoint, the shape in current model is torch.Size([2025, 1024]).
size mismatch for token_emb.position_embeddings.weight: copying a param with shape torch.Size([257, 768]) from checkpoint, the shape in current model is torch.Size([257, 1024]).
size mismatch for decoder_pred.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([768, 1024]).
size mismatch for mlm_layer.fc.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
size mismatch for mlm_layer.fc.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for mlm_layer.ln.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for mlm_layer.ln.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1024]).
I can't understand why I run into this problem.
I trained on my own dataset and did not use distributed training.

I also trained the base model, not the large or huge one. Here are the details:
Model = MaskedGenerativeEncoderViT(
(token_emb): BertEmbeddings(
(word_embeddings): Embedding(2025, 768)
(position_embeddings): Embedding(257, 768)
(LayerNorm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
(blocks): ModuleList(
(0-11): 12 x Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.1, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.1, inplace=False)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.1, inplace=False)
(norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(decoder_embed): Linear(in_features=768, out_features=768, bias=True)
(decoder_blocks): ModuleList(
(0-7): 8 x Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.1, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.1, inplace=False)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.1, inplace=False)
(decoder_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(decoder_pred): Linear(in_features=768, out_features=768, bias=True)
(mlm_layer): MlmLayer(
(fc): Linear(in_features=768, out_features=768, bias=True)
(gelu): GELU(approximate='none')
(ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(criterion): LabelSmoothingCrossEntropy()
(pretrained_encoder): VisionTransformerMoCo(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
(pos_drop): Dropout(p=0.0, inplace=False)
(blocks): ModuleList(
(0-11): 12 x Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
(drop_path): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop): Dropout(p=0.0, inplace=False)
(norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(head): Sequential(
(0): Linear(in_features=768, out_features=4096, bias=False)
(1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Linear(in_features=4096, out_features=4096, bias=False)
(4): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): Linear(in_features=4096, out_features=256, bias=False)
(7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
(vqgan): VQModel(
(encoder): Encoder(
(conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(down): ModuleList(
(0-1): 2 x Module(
(block): ModuleList(
(0-1): 2 x ResnetBlock(
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(downsample): Downsample()
(2): Module(
(block): ModuleList(
(0): ResnetBlock(
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(nin_shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ResnetBlock(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(downsample): Downsample()
(3): Module(
(block): ModuleList(
(0-1): 2 x ResnetBlock(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(downsample): Downsample()
(4): Module(
(block): ModuleList(
(0): ResnetBlock(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(nin_shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ResnetBlock(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(mid): Module(
(block_1): ResnetBlock(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(block_2): ResnetBlock(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv_out): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
(decoder): Decoder(
(conv_in): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(mid): Module(
(block_1): ResnetBlock(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(block_2): ResnetBlock(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(up): ModuleList(
(0): Module(
(block): ModuleList(
(0-1): 2 x ResnetBlock(
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): Module(
(block): ModuleList(
(0): ResnetBlock(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(nin_shortcut): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ResnetBlock(
(norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(upsample): Upsample(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): Module(
(block): ModuleList(
(0-1): 2 x ResnetBlock(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(upsample): Upsample(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): Module(
(block): ModuleList(
(0): ResnetBlock(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(nin_shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ResnetBlock(
(norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(upsample): Upsample(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): Module(
(block): ModuleList(
(0-1): 2 x ResnetBlock(
(norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(upsample): Upsample(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
(conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(quantize): VectorQuantizer2(
(embedding): Embedding(1024, 256)

I changed the notebook code from "model = models_mage.mage_vit_large_patch16" to "model = models_mage.mage_vit_base_patch16", because I trained the model in base mode. Now the error has changed to a new one:

RuntimeError Traceback (most recent call last)
Cell In[13], line 2
1 checkpoint = torch.load(os.path.join('output/checkpoint-last.pth'), map_location='cpu')
----> 2 model.load_state_dict(checkpoint['model'], strict=True)
3 model.cuda()
4 _ = model.eval()
RuntimeError: Error(s) in loading state_dict for MaskedGenerativeEncoderViT:
Missing key(s) in state_dict: "latent_prior_proj.weight", "latent_prior_proj.bias".
It seems that only two keys are missing, but I don't know why.
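
Both failures above (the size mismatches, then the missing keys) can be spotted before calling load_state_dict by comparing parameter names and shapes. The following is a minimal sketch, not code from the repository: diff_state_dicts is a hypothetical helper, the toy dictionaries reuse shapes from the error messages, and the latent_prior_proj shapes are illustrative assumptions.

```python
def diff_state_dicts(ckpt_shapes, model_shapes):
    """Compare two {param_name: shape} mappings.

    Returns (missing, unexpected, mismatched):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint that the model does not have
      mismatched - {name: (ckpt_shape, model_shape)} where shapes differ
    """
    missing = sorted(set(model_shapes) - set(ckpt_shapes))
    unexpected = sorted(set(ckpt_shapes) - set(model_shapes))
    mismatched = {
        name: (ckpt_shapes[name], model_shapes[name])
        for name in sorted(set(ckpt_shapes) & set(model_shapes))
        if tuple(ckpt_shapes[name]) != tuple(model_shapes[name])
    }
    return missing, unexpected, mismatched


# Toy shapes taken from the error messages in this thread.
ckpt = {"cls_token": (1, 1, 768), "pos_embed": (1, 257, 768)}
large_model = {"cls_token": (1, 1, 1024), "pos_embed": (1, 257, 1024)}
base_model_with_rep = {
    "cls_token": (1, 1, 768),
    "pos_embed": (1, 257, 768),
    # Assumed shapes, for illustration only.
    "latent_prior_proj.weight": (768, 256),
    "latent_prior_proj.bias": (768,),
}

# Base checkpoint vs. large model: every shared parameter mismatches.
print(diff_state_dicts(ckpt, large_model))
# Base checkpoint vs. base model built with use_rep: two keys are missing.
print(diff_state_dicts(ckpt, base_model_with_rep))
```

With PyTorch, the two mappings can be built as {k: tuple(v.shape) for k, v in checkpoint['model'].items()} and {k: tuple(v.shape) for k, v in model.state_dict().items()}.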

Did you set use_rep when training MAGE? It seems there's no latent_prior_proj in your trained model, which should be initialized here

I'll try it later when my GPU is available. Thanks for the reply!
One more thing: because I train on Windows with a single 4070 GPU, I can't initialize with torch.distributed.launch, so I always run into a problem with the function in, namely concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
It fails with "RuntimeError: Default process group has not been initialized, please make sure to call init_process_" when the code reaches "torch.distributed.get_world_size()" and "torch.distributed.all_gather".
So I changed the code to:
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(1)]
    output = torch.cat(tensors_gather, dim=0)
    return output
Then the code runs.
Do you think this change is acceptable? I am afraid it will break the structure of the layers.

No -- you should simply comment out the concat_all_gather line. Your modification will return an output full of ones, because you use tensors_gather = [torch.ones_like(tensor) for _ in range(1)] and nothing ever overwrites those placeholders.
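
An alternative to commenting out each call is to make the helper safe for single-process runs. The sketch below is not the repository's code. It assumes the helper keeps the signature quoted earlier in this thread, and it adds a guard based on torch.distributed.is_available() and torch.distributed.is_initialized(), so the input is returned unchanged when no process group exists.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def concat_all_gather(tensor):
    """Gather `tensor` from all ranks and concatenate along dim 0.

    Falls back to returning the input unchanged when no process group
    is initialized (e.g. a single-GPU run started without a launcher).
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single process: nothing to gather, so keep the real values.
        return tensor

    # Placeholders that all_gather overwrites with each rank's tensor.
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)


# Single-process check: the values pass through untouched.
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
out = concat_all_gather(x)
```

With this guard, call sites such as gen_images_batch = misc.concat_all_gather(gen_images_batch) could stay as they are, and the multi-GPU path is unchanged.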

OK, I think that will be fine, but how should I comment it out? The code contains this call:

gen_images_batch = misc.concat_all_gather(gen_images_batch)

When I comment it out, the new error is:
  File "D:\DeepLearning\rcg-main\", line 297, in <module>
    main(args)
  File "D:\DeepLearning\rcg-main\", line 270, in main
    gen_img(model, args, epoch, batch_size=16, log_writer=log_writer, cfg=0)
  File "D:\DeepLearning\rcg-main\", line 102, in gen_img
    gen_images_batch, _ = model(None, None,
                          ^^^^^^^^^^^^^^^^^
  File "C:\Users\Yisher\anaconda3\Lib\site-packages\torch\nn\modules\", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepLearning\rcg-main\pixel_generator\mage\", line 455, in forward
    return self.gen_image(bsz, num_iter, choice_temperature, sampled_rep, rdm_steps, eta, cfg, class_label_gen)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepLearning\rcg-main\pixel_generator\mage\", line 533, in gen_image
    input_embeddings[:, 0] = self.latent_prior_proj(sampled_rep)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Yisher\anaconda3\Lib\site-packages\torch\nn\modules\", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Yisher\anaconda3\Lib\site-packages\torch\nn\modules\", line 114, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x256 and 768x768)
(In that run I had set use_rep.)

You can just comment it out and it should be fine. This error is caused by use_rep, not by commenting out the line. You need to set --rep_dim=256. Please follow the command provided in the README, including its arguments:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --node_rank=0 \ \
--pretrained_enc_arch mocov3_vit_base \
--pretrained_enc_path pretrained_enc_ckpts/mocov3/vitb.pth.tar --rep_drop_prob 0.1 \
--use_rep --rep_dim 256 --pretrained_enc_withproj --pretrained_enc_proj_dim 256 \
--pretrained_rdm_cfg ${RDM_CFG_PATH} --pretrained_rdm_ckpt ${RDM_CKPT_PATH} \
--rdm_steps 250 --eta 1.0 --temp 6.0 --num_iter 20 --num_images 50000 --cfg 0.0 \
--batch_size 64 --input_size 256 \
--model mage_vit_base_patch16 \
--mask_ratio_min 0.5 --mask_ratio_max 1.0 --mask_ratio_mu 0.75 --mask_ratio_std 0.25 \
--epochs 200 \
--warmup_epochs 10 \
--blr 1.5e-4 --weight_decay 0.05 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214
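
The shape error quoted earlier in this thread follows directly from the linear-layer rule: the last dimension of the input must equal the layer's input width. The sketch below reproduces that arithmetic in plain Python. It assumes that the input width of latent_prior_proj is set by --rep_dim, which the 768x768 weight in the traceback suggests but the thread does not show; linear_output_shape is a hypothetical helper, not part of PyTorch or the repository.

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Return the output shape of a linear layer applied to `input_shape`,
    or raise ValueError if the last input dimension does not match."""
    if input_shape[-1] != in_features:
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({input_shape[0]}x{input_shape[-1]} and {in_features}x{out_features})"
        )
    return (*input_shape[:-1], out_features)


batch, rep_width, embed_dim = 16, 256, 768

# Layer with input width 768 (what the traceback shows): fails.
try:
    linear_output_shape((batch, rep_width), in_features=768, out_features=embed_dim)
except ValueError as err:
    print(err)

# Layer with input width 256 (assumed effect of --rep_dim 256): works.
print(linear_output_shape((batch, rep_width), in_features=256, out_features=embed_dim))
```

The sampled representation has 256 features per sample, so the projection only accepts it once its input width is 256 as well.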

Sorry for asking so many questions. When I follow these arguments it works, and the code is running now. The --pretrained_enc_withproj flag is also important. When I get the new output tomorrow, I'll post an update here on whether the results look good.
Thank you for your reply!

No worries -- please let me know if you encounter other problems.