johannakarras / DreamPose

Official implementation of "DreamPose: Fashion Image-to-Video Synthesis via Stable Diffusion"

Training Script Issue

jmahajan117 opened this issue

Was anyone able to get the training script to work? I'm getting this error:

Traceback (most recent call last):
  File "/home/jaym2/DreamPose/train.py", line 481, in <module>
    main(args)
  File "/home/jaym2/DreamPose/train.py", line 341, in main
    latents = vae.encode(batch["frame_j"].to(dtype=weight_dtype)).latent_dist.sample()
  File "/home/jaym2/miniconda3/envs/DP/lib/python3.9/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/home/jaym2/miniconda3/envs/DP/lib/python3.9/site-packages/diffusers/models/autoencoder_kl.py", line 164, in encode
    h = self.encoder(x)
  File "/home/jaym2/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jaym2/miniconda3/envs/DP/lib/python3.9/site-packages/diffusers/models/vae.py", line 139, in forward
    sample = down_block(sample)
  File "/home/jaym2/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jaym2/miniconda3/envs/DP/lib/python3.9/site-packages/diffusers/models/unet_2d_blocks.py", line 1081, in forward
    hidden_states = resnet(hidden_states, temb=None)
  File "/home/jaym2/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jaym2/miniconda3/envs/DP/lib/python3.9/site-packages/diffusers/models/resnet.py", line 596, in forward
    hidden_states = self.norm1(hidden_states)
  File "/home/jaym2/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jaym2/.local/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 272, in forward
    return F.group_norm(
  File "/home/jaym2/.local/lib/python3.9/site-packages/torch/nn/functional.py", line 2516, in group_norm
    return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [128] and input of shape [128, 640, 512]
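
The shapes in the error message point at the cause: `F.group_norm` expects input of shape `(N, C, ...)`, with the channel count in dimension 1. If a single frame of shape `[C, H, W]` reaches the VAE instead of a batched `[B, C, H, W]` tensor, the encoder's first conv still runs (PyTorch convs accept unbatched 3D input), but the following GroupNorm then reads dimension 1 of the 3D activation (640) as the channel count, which no longer matches its 128-element weight. A minimal sketch reproducing the mismatch with the shapes from the traceback (128 channels matches the diffusers VAE encoder's first norm layer; the group count of 32 is an assumption for illustration):

```python
import torch
import torch.nn as nn

# Norm layer with a 128-element weight, like the one in the traceback.
norm = nn.GroupNorm(num_groups=32, num_channels=128)

# Batched 4D input [B, C, H, W]: channels sit in dim 1 (128), so this works.
norm(torch.randn(1, 128, 64, 64))

# Unbatched 3D input: group_norm reads dim 1 (640) as the channel count,
# which mismatches the 128-element weight -> the RuntimeError above.
try:
    norm(torch.randn(128, 640, 512))
except RuntimeError as e:
    print(e)
```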

Figured it out... `batch["frame_j"]` was missing its batch dimension; unsqueezing the batch fixes the error.
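
For reference, a sketch of the fix against the line from the traceback (train.py, line 341), assuming `batch["frame_j"]` arrives as a single 3D frame `[C, H, W]`:

```python
# Before: the VAE receives an unbatched [C, H, W] tensor.
# latents = vae.encode(batch["frame_j"].to(dtype=weight_dtype)).latent_dist.sample()

# After: unsqueeze to add the batch dimension, giving [1, C, H, W].
latents = vae.encode(
    batch["frame_j"].unsqueeze(0).to(dtype=weight_dtype)
).latent_dist.sample()
```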