andreas128 / RePaint

Official PyTorch Code and Models of "RePaint: Inpainting using Denoising Diffusion Probabilistic Models", CVPR 2022

Model size other than 256?

ExponentialML opened this issue

Hello, thanks for the release. I've trained a diffusion model using the guided-diffusion repository, but I can't seem to load it in RePaint. It was trained at a resolution of 64x64, and it is a diffusion model (trained with the improved-diffusion codebase), not a classifier.

When I set the parameters and run inference, I get an error. I've tested both trained classifiers and diffusion models with guided-diffusion as well as improved-diffusion, and they worked without issue.

With RePaint, the error is as follows:

Traceback (most recent call last):
  File "test.py", line 180, in <module>
    main(conf_arg)
  File "test.py", line 69, in main
    model.load_state_dict(
  File "/conda/env/repaint/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel:
	Missing key(s) in state_dict: "label_emb.weight", "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.7.2.in_layers.0.weight", "output_blocks.7.2.in_layers.0.bias", "output_blocks.7.2.in_layers.2.weight", "output_blocks.7.2.in_layers.2.bias", "output_blocks.7.2.emb_layers.1.weight", "output_blocks.7.2.emb_layers.1.bias", "output_blocks.7.2.out_layers.0.weight", "output_blocks.7.2.out_layers.0.bias", "output_blocks.7.2.out_layers.3.weight", "output_blocks.7.2.out_layers.3.bias", "output_blocks.11.1.in_layers.0.weight", "output_blocks.11.1.in_layers.0.bias", "output_blocks.11.1.in_layers.2.weight", "output_blocks.11.1.in_layers.2.bias", "output_blocks.11.1.emb_layers.1.weight", "output_blocks.11.1.emb_layers.1.bias", "output_blocks.11.1.out_layers.0.weight", "output_blocks.11.1.out_layers.0.bias", "output_blocks.11.1.out_layers.3.weight", "output_blocks.11.1.out_layers.3.bias". 
	Unexpected key(s) in state_dict: "input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "output_blocks.3.2.conv.weight", "output_blocks.3.2.conv.bias", "output_blocks.7.2.conv.weight", "output_blocks.7.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias". 
	size mismatch for out.2.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 128, 3, 3]).
	size mismatch for out.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
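
From the error message, the checkpoint and the configured model seem to disagree in two ways: the unexpected op.*/conv.* keys are the names guided-diffusion uses for its convolution-based down/upsampling modules, while the missing in_layers/emb_layers/out_layers keys belong to ResBlock-based resampling (resblock_updown: true); and the out.2 shape mismatch (3 vs. 6 output channels) is what a learn_sigma mismatch looks like. The checkpoint keys can be inspected directly; a minimal sketch, using the model_path from the config below:

import torch

# Load the raw checkpoint; the EMA snapshots saved by guided-diffusion
# should be plain state dicts of tensors.
state = torch.load("./trees/ema_0.9999_050000.pt", map_location="cpu")

# Conv-based downsampling stores "...op.weight"; ResBlock-based
# downsampling stores "...in_layers/emb_layers/out_layers" instead.
print(sorted(k for k in state if k.startswith("input_blocks.4.0")))

# 3 output channels = mean only (learn_sigma false); 6 = mean and variance.
print(state["out.2.weight"].shape)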

My model was trained using these parameters:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
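
Note that any flag not passed on the command line falls back to guided-diffusion's defaults; in particular, learn_sigma and resblock_updown both default to False there, which can be verified like this (a minimal sketch):

from guided_diffusion.script_util import model_and_diffusion_defaults

# Flags omitted from MODEL_FLAGS / DIFFUSION_FLAGS keep these default values.
defaults = model_and_diffusion_defaults()
print(defaults["learn_sigma"])      # False
print(defaults["resblock_updown"])  # False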

With this custom YAML configuration (I'm not using a classifier here, since I can't get the diffusion model itself to load):

attention_resolutions: 16,8
class_cond: false
diffusion_steps: 4000
learn_sigma: true
noise_schedule: linear
num_channels: 128
num_head_channels: 128
num_heads: 4
num_res_blocks: 3
resblock_updown: true
use_fp16: false
use_scale_shift_norm: true
classifier_scale: 4.0
lr_kernel_n_std: 2
num_samples: 100
show_progress: true
timestep_respacing: '250'
use_kl: false
predict_xstart: false
rescale_timesteps: false
rescale_learned_sigmas: false
classifier_use_fp16: false
classifier_width: 128
classifier_depth: 2
classifier_attention_resolutions: 16,8
classifier_use_scale_shift_norm: true
classifier_resblock_updown: true
classifier_pool: attention
num_heads_upsample: -1
channel_mult: ''
dropout: 0.0
use_checkpoint: false
use_new_attention_order: false
clip_denoised: true
use_ddim: false
latex_name: RePaint
method_name: Repaint
image_size: 64
model_path: ./trees/ema_0.9999_050000.pt
name: trees
inpa_inj_sched_prev: true
n_jobs: 1
print_estimated_vars: true
inpa_inj_sched_prev_cumnoise: false

Any help would be appreciated!

Can you print the parameters of the UNet from training in this line?

And then compare them with the parameters in this line?

Then try to adapt the config file. If it works, could you please add your config file to this repo?
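
For example, something like this (a minimal sketch; the create_model arguments mirror the YAML above, and the checkpoint path comes from it):

import torch
from guided_diffusion.script_util import create_model

# Build the UNet the way the config describes it, then diff the parameter
# names against what the checkpoint actually contains.
model = create_model(
    image_size=64, num_channels=128, num_res_blocks=3,
    learn_sigma=True, attention_resolutions="16,8",
    num_heads=4, num_head_channels=128,
    use_scale_shift_norm=True, resblock_updown=True,
)
ckpt = torch.load("./trees/ema_0.9999_050000.pt", map_location="cpu")

model_keys = set(model.state_dict())
ckpt_keys = set(ckpt)
print("missing from checkpoint:", sorted(model_keys - ckpt_keys))
print("unexpected in checkpoint:", sorted(ckpt_keys - model_keys))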

Does that solve it?

Thanks, you led me in the right direction. I ended up retraining the model, because it seems I needed to pass --learn_sigma True during training. I also retrained with a cosine schedule, which meant adding the cosine parts back in from the guided-diffusion repository, since they were removed from this one.
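
For anyone who needs them, the cosine parts I copied back are essentially this (from guided-diffusion's gaussian_diffusion.py; reproduced here as a sketch):

import math
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # Build betas so that the cumulative product of (1 - beta_t) follows alpha_bar.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

# The "cosine" branch of get_named_beta_schedule, here for 4000 steps:
betas = betas_for_alpha_bar(
    4000,
    lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)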

After doing that and changing the configuration, everything worked. I'll post the configuration when I'm able to, then I'll close the issue.

Great, thanks!