andreas128 / RePaint

Official PyTorch Code and Models of "RePaint: Inpainting using Denoising Diffusion Probabilistic Models", CVPR 2022

Guided diffusion pretrain can't sample in the Repaint

xiongsua opened this issue · comments

I get an error when sampling with RePaint. I trained a model with guided-diffusion and then used the EMA checkpoint (ema_pt) to sample. Does the guided-diffusion code need to be changed? If so, what do I need to change in it?

Same problem. Do you have a solution?

It seems that the model architecture has changed.

Try to use the model code from your training by copying it here.

Does that work?
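
To see exactly where a checkpoint and the sampling model disagree, it can help to compare parameter names and shapes directly. The following is a minimal sketch, not code from RePaint or guided-diffusion: `diff_state_dicts` is a hypothetical helper that works on plain dictionaries mapping parameter names to shape tuples, and the sample entries are only illustrative.

```python
def diff_state_dicts(checkpoint_shapes, model_shapes):
    """Compare two {parameter_name: shape_tuple} mappings.

    Returns (missing, unexpected, mismatched), using the same meaning as the
    PyTorch error message: "missing" keys are expected by the model but absent
    from the checkpoint, "unexpected" keys are the reverse.
    """
    missing = sorted(k for k in model_shapes if k not in checkpoint_shapes)
    unexpected = sorted(k for k in checkpoint_shapes if k not in model_shapes)
    mismatched = {
        k: (checkpoint_shapes[k], model_shapes[k])
        for k in model_shapes
        if k in checkpoint_shapes and checkpoint_shapes[k] != model_shapes[k]
    }
    return missing, unexpected, mismatched


# Illustrative entries only; real mappings would come from the checkpoint
# and from the model that the sampling script builds.
checkpoint = {"input_blocks.4.0.op.weight": (128, 128, 3, 3), "out.2.bias": (3,)}
model = {"input_blocks.4.0.in_layers.0.weight": (128,), "out.2.bias": (6,)}

missing, unexpected, mismatched = diff_state_dicts(checkpoint, model)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

With PyTorch, such a mapping can be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`, once for the dictionary returned by `torch.load(path, map_location="cpu")` and once for `model.state_dict()`.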

Hi, did you solve the problem?

@xiongsua Hi, did you solve the problem?

I ran into the same problem: the parameters in the checkpoint do not match the network.
RuntimeError: Error(s) in loading state_dict for UNetModel:
Missing key(s) in state_dict: "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.5.1.norm.weight", "input_blocks.5.1.norm.bias", "input_blocks.5.1.qkv.weight", "input_blocks.5.1.qkv.bias", "input_blocks.5.1.proj_out.weight", "input_blocks.5.1.proj_out.bias", "input_blocks.6.1.norm.weight", "input_blocks.6.1.norm.bias", "input_blocks.6.1.qkv.weight", "input_blocks.6.1.qkv.bias", "input_blocks.6.1.proj_out.weight", "input_blocks.6.1.proj_out.bias", "input_blocks.7.1.norm.weight", "input_blocks.7.1.norm.bias", "input_blocks.7.1.qkv.weight", "input_blocks.7.1.qkv.bias", "input_blocks.7.1.proj_out.weight", "input_blocks.7.1.proj_out.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.7.2.in_layers.0.weight", "output_blocks.7.2.in_layers.0.bias", "output_blocks.7.2.in_layers.2.weight", "output_blocks.7.2.in_layers.2.bias", "output_blocks.7.2.emb_layers.1.weight", "output_blocks.7.2.emb_layers.1.bias", "output_blocks.7.2.out_layers.0.weight", "output_blocks.7.2.out_layers.0.bias", "output_blocks.7.2.out_layers.3.weight", "output_blocks.7.2.out_layers.3.bias", "output_blocks.8.1.norm.weight", "output_blocks.8.1.norm.bias", "output_blocks.8.1.qkv.weight", "output_blocks.8.1.qkv.bias", "output_blocks.8.1.proj_out.weight", "output_blocks.8.1.proj_out.bias", "output_blocks.9.1.norm.weight", "output_blocks.9.1.norm.bias", "output_blocks.9.1.qkv.weight", "output_blocks.9.1.qkv.bias", "output_blocks.9.1.proj_out.weight", "output_blocks.9.1.proj_out.bias", "output_blocks.10.1.norm.weight", "output_blocks.10.1.norm.bias", "output_blocks.10.1.qkv.weight", "output_blocks.10.1.qkv.bias", "output_blocks.10.1.proj_out.weight", "output_blocks.10.1.proj_out.bias", "output_blocks.11.1.norm.weight", "output_blocks.11.1.norm.bias", "output_blocks.11.1.qkv.weight", "output_blocks.11.1.qkv.bias", "output_blocks.11.1.proj_out.weight", "output_blocks.11.1.proj_out.bias", "output_blocks.11.2.in_layers.0.weight", "output_blocks.11.2.in_layers.0.bias", "output_blocks.11.2.in_layers.2.weight", "output_blocks.11.2.in_layers.2.bias", "output_blocks.11.2.emb_layers.1.weight", "output_blocks.11.2.emb_layers.1.bias", "output_blocks.11.2.out_layers.0.weight", "output_blocks.11.2.out_layers.0.bias", "output_blocks.11.2.out_layers.3.weight", "output_blocks.11.2.out_layers.3.bias".
Unexpected key(s) in state_dict: "input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "output_blocks.3.2.conv.weight", "output_blocks.3.2.conv.bias", "output_blocks.7.2.conv.weight", "output_blocks.7.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias".
size mismatch for out.2.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 128, 3, 3]).
size mismatch for out.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
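
One plausible reading of the error above, assuming the parameter naming and options of the guided-diffusion UNet (an assumption, nobody in this thread confirms it): the checkpoint stores its down/upsampling layers as plain convolutions (`.op`, `.conv`), while the model built for sampling expects residual blocks there (`in_layers`, `out_layers`), as it would with `resblock_updown` enabled. The 3-versus-6 output channels would similarly fit a checkpoint trained without `learn_sigma` and a sampling model built with it, and the missing `norm`/`qkv`/`proj_out` keys suggest the attention settings differ as well. The sketch below turns the first two observations into a heuristic; `guess_training_flags` is a hypothetical helper and assumes 3-channel RGB images.

```python
def guess_training_flags(checkpoint_shapes):
    """Heuristically guess two model options from checkpoint parameters.

    `checkpoint_shapes` maps parameter names to shape tuples. A value of
    None in the result means "cannot tell from these keys".
    Assumption: names follow the guided-diffusion UNet, images are RGB.
    """
    guesses = {"resblock_updown": None, "learn_sigma": None}

    # A plain downsampling layer keeps its convolution under "<block>.op".
    if any(k.startswith("input_blocks.") and ".op." in k for k in checkpoint_shapes):
        guesses["resblock_updown"] = False

    # The final convolution has 3 output channels when only the mean/noise is
    # predicted and 6 when the variance is predicted as well.
    out_shape = checkpoint_shapes.get("out.2.weight")
    if out_shape is not None and out_shape[0] in (3, 6):
        guesses["learn_sigma"] = out_shape[0] == 6

    return guesses


example = {
    "input_blocks.4.0.op.weight": (128, 128, 3, 3),  # shape is illustrative
    "out.2.weight": (3, 128, 3, 3),                  # shape taken from the error above
}
print(guess_training_flags(example))
```

If that reading is right, the options used to build the model for sampling would have to match the ones used for training before the checkpoint can load.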

@fangtun Hi, I have the same issue. Have you solved it?

Hi, I also have the same issue.

Traceback (most recent call last):
  File "test.py", line 180, in <module>
    main(conf_arg)
  File "test.py", line 69, in main
    model.load_state_dict(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 2001, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel:
Missing key(s) in state_dict: "input_blocks.3.0.in_layers.0.weight", "input_blocks.3.0.in_layers.0.bias", "input_blocks.3.0.in_layers.2.weight", "input_blocks.3.0.in_layers.2.bias", "input_blocks.3.0.emb_layers.1.weight", "input_blocks.3.0.emb_layers.1.bias", "input_blocks.3.0.out_layers.0.weight", "input_blocks.3.0.out_layers.0.bias", "input_blocks.3.0.out_layers.3.weight", "input_blocks.3.0.out_layers.3.bias", "input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", "input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias", "input_blocks.9.0.in_layers.0.weight", "input_blocks.9.0.in_layers.0.bias", "input_blocks.9.0.in_layers.2.weight", "input_blocks.9.0.in_layers.2.bias", "input_blocks.9.0.emb_layers.1.weight", "input_blocks.9.0.emb_layers.1.bias", "input_blocks.9.0.out_layers.0.weight", "input_blocks.9.0.out_layers.0.bias", "input_blocks.9.0.out_layers.3.weight", "input_blocks.9.0.out_layers.3.bias", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias", "input_blocks.15.0.in_layers.0.weight", "input_blocks.15.0.in_layers.0.bias", "input_blocks.15.0.in_layers.2.weight", "input_blocks.15.0.in_layers.2.bias", "input_blocks.15.0.emb_layers.1.weight", "input_blocks.15.0.emb_layers.1.bias", "input_blocks.15.0.out_layers.0.weight", "input_blocks.15.0.out_layers.0.bias", "input_blocks.15.0.out_layers.3.weight", "input_blocks.15.0.out_layers.3.bias", 
"output_blocks.2.2.in_layers.0.weight", "output_blocks.2.2.in_layers.0.bias", "output_blocks.2.2.in_layers.2.weight", "output_blocks.2.2.in_layers.2.bias", "output_blocks.2.2.emb_layers.1.weight", "output_blocks.2.2.emb_layers.1.bias", "output_blocks.2.2.out_layers.0.weight", "output_blocks.2.2.out_layers.0.bias", "output_blocks.2.2.out_layers.3.weight", "output_blocks.2.2.out_layers.3.bias", "output_blocks.5.2.in_layers.0.weight", "output_blocks.5.2.in_layers.0.bias", "output_blocks.5.2.in_layers.2.weight", "output_blocks.5.2.in_layers.2.bias", "output_blocks.5.2.emb_layers.1.weight", "output_blocks.5.2.emb_layers.1.bias", "output_blocks.5.2.out_layers.0.weight", "output_blocks.5.2.out_layers.0.bias", "output_blocks.5.2.out_layers.3.weight", "output_blocks.5.2.out_layers.3.bias", "output_blocks.8.2.in_layers.0.weight", "output_blocks.8.2.in_layers.0.bias", "output_blocks.8.2.in_layers.2.weight", "output_blocks.8.2.in_layers.2.bias", "output_blocks.8.2.emb_layers.1.weight", "output_blocks.8.2.emb_layers.1.bias", "output_blocks.8.2.out_layers.0.weight", "output_blocks.8.2.out_layers.0.bias", "output_blocks.8.2.out_layers.3.weight", "output_blocks.8.2.out_layers.3.bias", "output_blocks.11.1.in_layers.0.weight", "output_blocks.11.1.in_layers.0.bias", "output_blocks.11.1.in_layers.2.weight", "output_blocks.11.1.in_layers.2.bias", "output_blocks.11.1.emb_layers.1.weight", "output_blocks.11.1.emb_layers.1.bias", "output_blocks.11.1.out_layers.0.weight", "output_blocks.11.1.out_layers.0.bias", "output_blocks.11.1.out_layers.3.weight", "output_blocks.11.1.out_layers.3.bias", "output_blocks.14.1.in_layers.0.weight", "output_blocks.14.1.in_layers.0.bias", "output_blocks.14.1.in_layers.2.weight", "output_blocks.14.1.in_layers.2.bias", "output_blocks.14.1.emb_layers.1.weight", "output_blocks.14.1.emb_layers.1.bias", "output_blocks.14.1.out_layers.0.weight", "output_blocks.14.1.out_layers.0.bias", "output_blocks.14.1.out_layers.3.weight", "output_blocks.14.1.out_layers.3.bias".
Unexpected key(s) in state_dict: "input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "input_blocks.6.0.op.weight", "input_blocks.6.0.op.bias", "input_blocks.9.0.op.weight", "input_blocks.9.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.15.0.op.weight", "input_blocks.15.0.op.bias", "output_blocks.2.2.conv.weight", "output_blocks.2.2.conv.bias", "output_blocks.5.2.conv.weight", "output_blocks.5.2.conv.bias", "output_blocks.8.2.conv.weight", "output_blocks.8.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias", "output_blocks.14.1.conv.weight", "output_blocks.14.1.conv.bias".

@handsomecong001 Have you tried it?

It might not work perfectly, depending on the arguments. However, in my case, replacing UNetModel with the version below and adding SiLU to the imports worked.

```python
# Relies on the helpers already available in the UNet module this class is
# pasted into (linear, conv_nd, normalization, zero_module, timestep_embedding,
# ResBlock, AttentionBlock, Downsample, Upsample, ...), plus SiLU.
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        conf=None
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.conf = conf

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Always a plain Downsample layer; `resblock_updown` is
                # accepted by the constructor but not used here.
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
                )
                input_block_chans.append(ch)
                ds *= 2

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock(
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample, dims=dims))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            normalization(ch),
            SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    @property
    def inner_dtype(self):
        """
        Get the dtype used by the torso of the model.
        """
        return next(self.input_blocks.parameters()).dtype

    def forward(self, x, timesteps, y=None, gt=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # `gt` and any extra kwargs are accepted but not used.
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        h = h.type(x.dtype)
        return self.out(h)

    def get_feature_vectors(self, x, timesteps, y=None):
        """
        Apply the model and return all of the intermediate tensors.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a dict with the following keys:
                 - 'down': a list of hidden state tensors from downsampling.
                 - 'middle': the tensor of the output of the lowest-resolution
                             block in the model.
                 - 'up': a list of hidden state tensors from upsampling.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        result = dict(down=[], up=[])
        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
            result["down"].append(h.type(x.dtype))
        h = self.middle_block(h, emb)
        result["middle"] = h.type(x.dtype)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
            result["up"].append(h.type(x.dtype))
        return result


class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)

    def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
        # x is [N, C, H, W], as in forward(), so height and width come last.
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().get_feature_vectors(x, timesteps, **kwargs)
```
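
The block indices that appear in the error messages in this thread can be reproduced by tracing the downsampling loop of the constructor above without building any layers. This is a minimal sketch: `input_block_layout` and `downsample_indices` are hypothetical helpers, and the argument values are guesses chosen to match the reported key names, not settings confirmed by anyone in the thread.

```python
def input_block_layout(channel_mult, num_res_blocks, attention_resolutions=()):
    """Mirror the input_blocks loop: one label per entry of `input_blocks`."""
    layout = ["conv"]  # input_blocks[0] is the initial convolution
    ds = 1             # current downsampling factor
    for level, _ in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            layout.append("res+attn" if ds in attention_resolutions else "res")
        if level != len(channel_mult) - 1:
            layout.append("down")  # no downsampling after the last level
            ds *= 2
    return layout


def downsample_indices(layout):
    """Return the positions in input_blocks that hold a downsampling layer."""
    return [i for i, kind in enumerate(layout) if kind == "down"]


# Four levels with three residual blocks each.
print(downsample_indices(input_block_layout((1, 2, 4, 8), 3)))        # [4, 8, 12]
# Six levels with two residual blocks each (multipliers are illustrative).
print(downsample_indices(input_block_layout((1, 1, 2, 2, 4, 4), 2)))  # [3, 6, 9, 12, 15]
```

The first result lines up with the `input_blocks.4.0.op`, `input_blocks.8.0.op` and `input_blocks.12.0.op` keys in the first report, and the second with blocks 3, 6, 9, 12 and 15 in the later traceback, so the two checkpoints were likely trained with different depths.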