lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

RuntimeError: Error(s) in loading state_dict for DALLE

065294847 opened this issue

Hi all, I'm training a model on a new dataset. I would like to test it, but when I run generate.py, I get the following error:

Traceback (most recent call last):
  File "/content/DALLE-pytorch/generate.py", line 95, in <module>
    dalle.load_state_dict(weights, strict=False)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DALLE:
size mismatch for image_emb.weight: copying a param with shape torch.Size([1024, 512]) from checkpoint, the shape in current model is torch.Size([8192, 512]).
size mismatch for transformer.pos_emb: copying a param with shape torch.Size([1, 513, 60]) from checkpoint, the shape in current model is torch.Size([1, 1281, 60]).
size mismatch for to_logits.1.weight: copying a param with shape torch.Size([50688, 512]) from checkpoint, the shape in current model is torch.Size([57856, 512]).
size mismatch for to_logits.1.bias: copying a param with shape torch.Size([50688]) from...

etc.

Can anybody help me out?

Thanks,
Daniel

I have the exact same problem and don't know how to fix it.

I think the VAE you used for training isn't the one you are giving it during generation. I can probably throw a more informative error.
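
Note that `strict=False` in the traceback does not help here: in PyTorch it only tolerates missing or unexpected keys, and a parameter present in both the checkpoint and the model with a different shape still raises. A more informative error could come from comparing shapes before loading. Below is a minimal sketch of that idea; the helper names are hypothetical, and it works on plain `{name: shape}` dicts so it runs without PyTorch. With a real model you could build those dicts as `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
# Hypothetical pre-load check: report every shape mismatch between a
# checkpoint and a model in one readable error, with a hint about the VAE.


def find_shape_mismatches(ckpt_shapes: dict, model_shapes: dict) -> list:
    """Return (name, ckpt_shape, model_shape) for keys present in both with different shapes."""
    return [
        (name, tuple(shape), tuple(model_shapes[name]))
        for name, shape in ckpt_shapes.items()
        if name in model_shapes and tuple(model_shapes[name]) != tuple(shape)
    ]


def check_compatible(ckpt_shapes: dict, model_shapes: dict) -> None:
    """Raise ValueError listing all mismatched parameters, if any."""
    mismatches = find_shape_mismatches(ckpt_shapes, model_shapes)
    if mismatches:
        details = "\n".join(
            f"  {name}: checkpoint {ck} vs model {md}" for name, ck, md in mismatches
        )
        raise ValueError(
            "Checkpoint does not match this model. If image_emb.weight differs, "
            "the VAE used for generation is probably not the one used for training.\n"
            + details
        )


# Shapes taken from the traceback above.
ckpt = {"image_emb.weight": (1024, 512), "to_logits.1.bias": (50688,)}
model = {"image_emb.weight": (8192, 512), "to_logits.1.bias": (57856,)}
```

Calling `check_compatible(ckpt, model)` raises a `ValueError` that names both parameters and points at the VAE, instead of PyTorch's generic size-mismatch list.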

I downloaded several checkpoints from https://github.com/robvanvolt/DALLE-models, but none of them work. Could someone share one that works with the current generate.py? Thanks.

Ah yes, I forgot to use the --taming flag for generation. Thanks.
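
A missing `--taming` flag accounts for every mismatch in the original traceback: the checkpoint was apparently trained with a VQGAN, while the model built at generation time used a VAE with a larger codebook and more image tokens. The sketch below reproduces the reported shapes. The formulas (logits = text vocabulary + text_seq_len + codebook size; position embeddings = text_seq_len + image tokens + 1) and the constants (a 49408-token text vocabulary, text_seq_len 256, 1024 codes on a 16x16 grid versus 8192 codes on a 32x32 grid) are assumptions inferred from the numbers in the error, not read from DALLE-pytorch's source.

```python
# Hypothetical shape arithmetic; formulas and constants are ASSUMED,
# chosen because they reproduce the shapes in the traceback.

NUM_TEXT_TOKENS = 49408  # assumed text vocabulary size
TEXT_SEQ_LEN = 256       # assumed text sequence length
DIM = 512                # model width, from the traceback


def expected_shapes(codebook_size: int, fmap_size: int) -> dict:
    """Shapes that depend on the VAE, for a given codebook size and token grid side."""
    image_seq_len = fmap_size ** 2
    total_tokens = NUM_TEXT_TOKENS + TEXT_SEQ_LEN + codebook_size
    return {
        "image_emb.weight": (codebook_size, DIM),
        "pos_emb length": TEXT_SEQ_LEN + image_seq_len + 1,
        "to_logits.1.weight": (total_tokens, DIM),
    }


checkpoint = expected_shapes(codebook_size=1024, fmap_size=16)     # VQGAN (assumed)
current_model = expected_shapes(codebook_size=8192, fmap_size=32)  # default VAE (assumed)

for name in checkpoint:
    print(f"{name}: checkpoint {checkpoint[name]} vs model {current_model[name]}")
```

This prints `(1024, 512)` vs `(8192, 512)`, `513` vs `1281`, and `(50688, 512)` vs `(57856, 512)`, matching the error. All three differences trace back to the codebook size and the number of image tokens, which is why passing the matching VAE flag fixes the load.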

The model parameters changed in recent commits, so these old checkpoints are not compatible with the current code. You will have to retrain.

Yeah, there are two issues here: I should save the type of VAE used, and raise a proper error message.

I should also save the dalle-pytorch version number in each checkpoint and assert that it matches the installed version when loading.
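
A minimal sketch of that approach follows: store the VAE type and the library version next to the weights, and fail loudly on load if either differs. The dict keys (`"vae_class_name"`, `"version"`), the helper names, and the example values are hypothetical; this illustrates the idea and is not the code from the actual fix.

```python
# Hypothetical checkpoint format: weights plus the metadata needed to
# detect a VAE or version mismatch before load_state_dict is ever called.


def make_checkpoint(weights: dict, vae_class_name: str, version: str) -> dict:
    """Bundle weights with the VAE type and library version used for training."""
    return {"weights": weights, "vae_class_name": vae_class_name, "version": version}


def validate_checkpoint(ckpt: dict, vae_class_name: str, version: str) -> dict:
    """Return the weights if the metadata matches; otherwise raise a clear error."""
    saved_vae = ckpt.get("vae_class_name")
    if saved_vae != vae_class_name:
        raise ValueError(
            f"Checkpoint was trained with VAE {saved_vae!r}, but generation is using "
            f"{vae_class_name!r}. Pass the matching VAE options when generating."
        )
    saved_version = ckpt.get("version")
    # Equality assert on the version, as proposed in the comment above.
    assert saved_version == version, (
        f"Checkpoint was saved with dalle-pytorch {saved_version}, but {version} "
        "is installed; model parameters may have changed between versions."
    )
    return ckpt["weights"]


ckpt = make_checkpoint({"w": 1}, vae_class_name="vqgan", version="0.0.1")  # placeholder values
```

With this, `validate_checkpoint(ckpt, "openai_dvae", "0.0.1")` fails with a message naming both VAEs, covering the missing `--taming` case, and a version bump fails the assert, covering the stale-checkpoint case.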

I'll fix both of these today.

OK, the newest commit (1c25f54) should resolve both issues.