johannakarras / DreamPose

Official implementation of "DreamPose: Fashion Image-to-Video Synthesis via Stable Diffusion"

How should I fix this error?

zzk88862 opened this issue

Hi, I fixed this error by reverting this commit, i.e. changing this line:

    unet.conv_in.weight[:, 4:] = torch.zeros(unet.conv_in.weight[:, 3:].shape)

back to what was there before:

    unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape)
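For reference, here is a small standalone sketch of why the newer line fails while the reverted one works: in the newer line the target slice `[:, 4:]` has one fewer input channel than the `[:, 3:]` slice used to size the zeros tensor, so the shapes cannot match. It uses a hypothetical `torch.nn.Conv2d` in place of `unet.conv_in` (the 8 input channels are an assumption, not necessarily DreamPose's actual layer), and it wraps the writes in `torch.no_grad()` because an in-place write to a parameter that requires grad raises its own error otherwise.

```python
import torch

# Hypothetical stand-in for unet.conv_in: 320 output channels and 8 input
# channels (the input-channel count is an assumption for this sketch).
conv_in = torch.nn.Conv2d(in_channels=8, out_channels=320, kernel_size=3, padding=1)


def zero_channels(layer, target_start, size_start):
    """Write zeros into layer.weight[:, target_start:], using a zeros tensor
    shaped like layer.weight[:, size_start:]."""
    # In-place writes to a parameter that requires grad must run under no_grad().
    with torch.no_grad():
        layer.weight[:, target_start:] = torch.zeros(layer.weight[:, size_start:].shape)


# Newer line: the target [:, 4:] has 4 input channels here, but the zeros
# tensor is shaped like [:, 3:] (5 channels), so PyTorch raises a RuntimeError.
try:
    zero_channels(conv_in, 4, 3)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("newer line fails:", err)

# Reverted line: both sides use [:, 3:], so the shapes agree.
zero_channels(conv_in, 3, 3)
reverted_zeroed = bool((conv_in.weight[:, 3:] == 0).all())
print("reverted line zeroes weight[:, 3:]:", reverted_zeroed)
```

Either start index avoids the shape error as long as both sides of the assignment use the same one; which input channels should actually be zeroed depends on how the UNet's input channels are laid out, which this sketch does not settle.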

@LaiaTarres thanks for your method. I changed the line to unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape), but now a different error happens.

How should I fix it?

The workaround is in this other issue.

@zzk88862 I added it just after this line
https://github.com/johannakarras/DreamPose/blob/main/test.py#L54

I fixed this by modifying the VAE loading code. This is a common issue: all you need to do is rename the state dict keys so that they match the names the model expects:

    # Rename old-style attention keys in the VAE checkpoint to the names the VAE expects
    new_state_dict = {}
    for k, v in vae_state_dict.items():
        name = k.replace('module.', '')        # strip the DataParallel/DDP prefix
        name = name.replace('query', 'to_q')
        name = name.replace('key', 'to_k')
        name = name.replace('value', 'to_v')
        name = name.replace('proj_attn', 'to_out.0')
        new_state_dict[name] = v
    pipe.vae.load_state_dict(new_state_dict)
    pipe.vae = pipe.vae.cuda()
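The same renaming can be pulled out into a function and checked on its own, without loading the model. This is a minimal sketch: the helper name `rename_vae_keys`, the `RENAMES` table, and the sample keys are hypothetical, and because `str.replace` renames every occurrence of a substring, it assumes words like `key` or `value` only appear in the attention weight names of your checkpoint.

```python
# Hypothetical helper mirroring the renaming loop above: maps old-style VAE
# attention key names onto the to_q / to_k / to_v / to_out.0 naming.
RENAMES = [
    ("module.", ""),          # prefix added when a model is wrapped in DataParallel/DDP
    ("query", "to_q"),
    ("key", "to_k"),
    ("value", "to_v"),
    ("proj_attn", "to_out.0"),
]


def rename_vae_keys(state_dict):
    """Return a new dict whose keys have every rename in RENAMES applied, in order."""
    renamed = {}
    for old_key, tensor in state_dict.items():
        new_key = old_key
        for old, new in RENAMES:
            new_key = new_key.replace(old, new)
        renamed[new_key] = tensor
    return renamed


# Usage with made-up keys and placeholder values (no torch needed here).
old = {
    "module.encoder.mid_block.attentions.0.query.weight": 1,
    "module.encoder.mid_block.attentions.0.key.bias": 2,
    "module.decoder.mid_block.attentions.0.value.weight": 3,
    "module.decoder.mid_block.attentions.0.proj_attn.weight": 4,
    "module.encoder.conv_in.weight": 5,
}
new = rename_vae_keys(old)
for k in new:
    print(k)
```

If some keys still do not line up, `pipe.vae.load_state_dict(new_state_dict, strict=False)` returns the `missing_keys` and `unexpected_keys`, which shows which names a plain substring replace did not cover.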