LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701

Mismatched sizes when loading a pretrained VQGAN

creatorcao opened this issue

commented

Hi! Thank you for your great work! I tried to custom-train a VQGAN and load its checkpoint into the MAGE pixel generator, but I received the error below. Do you know why? I trained the VQGAN on a single GPU and didn't change the config file.

Traceback (most recent call last):
  File "main_mage.py", line 297, in <module>
    main(args)
  File "main_mage.py", line 197, in main
    model = models_mage.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
  File "./rcg/pixel_generator/mage/models_mage.py", line 594, in mage_vit_base_patch16
    model = MaskedGenerativeEncoderViT(
  File "./rcg/pixel_generator/mage/models_mage.py", line 299, in __init__
    self.vqgan = VQModel(ddconfig=vqgan_config.params.ddconfig,
  File "./rcg/pixel_generator/mage/taming/models/vqgan.py", line 28, in __init__
    self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
  File "./rcg/pixel_generator/mage/taming/models/vqgan.py", line 50, in init_from_ckpt
    self.load_state_dict(sd, strict=False)
  File ".local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VQModel:
	size mismatch for encoder.down.2.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
	size mismatch for encoder.down.4.block.0.nin_shortcut.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 1, 1]).
	size mismatch for encoder.conv_out.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for decoder.up.1.block.0.nin_shortcut.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
	size mismatch for decoder.up.3.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).

MAGE uses a slightly different VQGAN network architecture from the original VQGAN. You could consider using the original network architecture: https://github.com/CompVis/taming-transformers/blob/master/taming/models/vqgan.py
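
For reference, a minimal sketch of instantiating the original taming-transformers VQModel and loading a custom-trained checkpoint. This is not rcg's code: the file names are placeholders, and the cfg.model.params layout assumes the standard taming-transformers config format.

from omegaconf import OmegaConf
from taming.models.vqgan import VQModel   # original architecture, not MAGE's bundled copy

cfg = OmegaConf.load("custom_vqgan_config.yaml")          # config used to train the VQGAN (placeholder path)
vqgan = VQModel(ddconfig=cfg.model.params.ddconfig,
                lossconfig=cfg.model.params.lossconfig,   # e.g. VQLPIPSWithDiscriminator
                n_embed=cfg.model.params.n_embed,
                embed_dim=cfg.model.params.embed_dim,
                ckpt_path="custom_vqgan.ckpt")            # placeholder; triggers init_from_ckpt internally
vqgan.eval()                                              # used as a frozen tokenizer afterwards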

commented

Thank you for your quick reply!
I used the original VQGAN and it worked! But it also requires setting the loss config, i.e. lossconfig: target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator. The logged loss starts around 7 on a toy dataset. Do you know whether this affects MAGE training? Can you share your MAGE training log?

[12:34:50.471893] Epoch: [0]  [ 0/68]  eta: 0:43:28  lr: 0.000000  loss: 7.0747 (7.0747)  time: 38.3553  data: 2.9274  max mem: 6478
[12:35:10.637951] Epoch: [0]  [20/68]  eta: 0:02:13  lr: 0.000001  loss: 6.7850 (6.7941)  time: 1.0063  data: 0.0093  max mem: 7586
[12:35:28.299013] Epoch: [0]  [40/68]  eta: 0:00:51  lr: 0.000002  loss: 5.6179 (6.2519)  time: 0.8827  data: 0.0030  max mem: 7586
[12:35:47.316502] Epoch: [0]  [60/68]  eta: 0:00:12  lr: 0.000003  loss: 4.9155 (5.8232)  time: 0.9507  data: 0.0027  max mem: 7586
[12:35:54.612478] Epoch: [0]  [67/68]  eta: 0:00:01  lr: 0.000004  loss: 4.7324 (5.7004)  time: 0.9852  data: 0.0033  max mem: 7586

It won't affect MAGE training. The MAGE training loss is unrelated to this VQGAN lossconfig; the lossconfig only specifies the loss used when training the VQGAN itself.
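
To illustrate why, here is a rough sketch of how a frozen VQGAN is typically used as a tokenizer during MAGE-style training. Method names follow taming-transformers' VQModel; the exact call sites in rcg may differ, and vqgan here is assumed to be the model loaded in the sketch above.

import torch

imgs = torch.randn(1, 3, 256, 256)   # dummy image batch, just for illustration
with torch.no_grad():
    # Only the encoder/quantizer (and, at generation time, the decoder) are used;
    # the VQGAN training loss (self.loss) is never called here.
    quant, _, (_, _, token_indices) = vqgan.encode(imgs)   # discrete token ids
    recon = vqgan.decode(quant)                            # decode back to pixels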

MAGE's training loss will be around 5.7 on ImageNet. However, the training loss can vary a lot depending on the dataset -- some datasets are easier than others. Your training loss looks reasonable. To check whether training is working, I typically look at generation performance rather than the training loss.

commented

Great. Thanks a lot. 👍

commented

Hi, I ran into the following error when evaluating MAGE. I used vqgan.py from the taming repo to tokenize my own data and added that lossconfig to VQModel, but this error reports mismatched sizes after loading the checkpoint. Could you help me understand it? Is it the issue mentioned earlier, where loading the state_dict fails after training the VQGAN on a single GPU (because the saved weights have no module. prefix), or is it because MAGE's VQGAN architecture differs from the original vqgan.py?

Traceback (most recent call last):
  File "/gpfs/space/home/etais/hpc_ping/rcg/main_mage.py", line 298, in <module>
    main(args)
  File "/gpfs/space/home/etais/hpc_ping/rcg/main_mage.py", line 198, in main
    model = models_mage.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/models_mage.py", line 595, in mage_vit_base_patch16
    model = MaskedGenerativeEncoderViT(
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/models_mage.py", line 299, in __init__
    self.vqgan = VQModel(ddconfig=vqgan_config.params.ddconfig,
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/taming/models/vqgan.py", line 50, in __init__
    self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
  File "/gpfs/space/home/etais/hpc_ping/rcg/pixel_generator/mage/taming/models/vqgan.py", line 66, in init_from_ckpt
    self.load_state_dict(sd, strict=False)
  File "/gpfs/space/home/etais/hpc_ping/.conda/envs/mages/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VQModel:
	size mismatch for encoder.down.2.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
	size mismatch for encoder.down.4.block.0.nin_shortcut.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 1, 1]).
	size mismatch for encoder.conv_out.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for decoder.up.1.block.0.nin_shortcut.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
	size mismatch for decoder.up.3.block.0.nin_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).

This is because MAGE's VQGAN architecture differs from the original vqgan.py.

commented

Thanks for the explanation!
The original vqgan.py contains self.loss = instantiate_from_config(lossconfig), so I added lossconfig=vqgan_config.params.lossconfig where the pretrained VQGAN is loaded in pixel_generator/mage/models_mage.py. It seems the only difference between MAGE's VQGAN and the original vqgan.py is this loss, but that change produces the error above. Could you tell me how to fix it?

As far as I remember, the main changes in MAGE's VQGAN are in the encoder/decoder network structure (for example, no attention). Since the VQGAN loss is not needed for MAGE training, I'd suggest removing it on both sides (drop it from the checkpoint after the VQGAN is trained), as in the sketch below.
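
A minimal sketch of stripping the loss weights from a trained checkpoint. It assumes the checkpoint was saved by the taming-transformers (PyTorch Lightning) trainer, so the weights sit under the state_dict key and the discriminator/LPIPS weights under the loss. prefix; the filenames are placeholders.

import torch

ckpt = torch.load("custom_vqgan.ckpt", map_location="cpu")
sd = ckpt["state_dict"]
# Drop the VQGAN training loss (discriminator + LPIPS), which is not needed
# when the VQGAN is only used as a frozen tokenizer for MAGE.
ckpt["state_dict"] = {k: v for k, v in sd.items() if not k.startswith("loss.")}
torch.save(ckpt, "custom_vqgan_noloss.ckpt")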

commented

I printed both VQGANs (the pretrained VQGAN checkpoint and MAGE's VQGAN) and removed the structures that exist only in the pretrained checkpoint, such as the loss and the differing encoder/decoder layers, but I still get the same error as above. The configs on both sides are also identical. What else could be the cause? And would stripping those weights affect the generation results?

The error occurs because layers with the same name have different shapes in MAGE's VQGAN and in the original VQGAN. Since you have your own trained VQGAN checkpoint, I'd suggest replacing the VQGAN file in MAGE with the original vqgan.py; loading an original-VQGAN checkpoint with MAGE's VQGAN file will not work.
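
If the original taming-transformers package is importable, the swap can be as small as changing the VQModel import where the pixel generator builds its tokenizer; a hypothetical sketch (the exact import path of rcg's bundled copy may differ from the commented line):

# Hypothetical one-line change in pixel_generator/mage/models_mage.py, assuming the
# original taming-transformers package is on PYTHONPATH.
# from pixel_generator.mage.taming.models.vqgan import VQModel   # rcg's bundled, modified VQGAN (path illustrative)
from taming.models.vqgan import VQModel                          # original VQGAN architecture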