Weizhi-Zhong / IP_LAP

Implementation of the CVPR 2023 talking-face paper "Identity-Preserving Talking Face Generation With Landmark and Appearance Priors"

load trained_checkpoint error

yaleimeng opened this issue · comments

After training the landmark and renderer models, running inference_single.py for inference fails when loading the checkpoints: the state dict is missing some keys. The output is as follows:

landmark_generator_model loaded from : checkpoints/landmark_generation/Pro_landmarkT5_d512_fe1024_lay4_head4/landmarkT5_d512_fe1024_lay4_head4_epoch_2020_checkpoint_step000012120.pth
renderer loaded from : checkpoints/renderer/Pro_renderer_T1_ref_N3/renderer_T1_ref_N3_epoch_7000_checkpoint_step000042000.pth
Load checkpoint from: checkpoints/landmark_generation/Pro_landmarkT5_d512_fe1024_lay4_head4/landmarkT5_d512_fe1024_lay4_head4_epoch_2020_checkpoint_step000012120.pth
--local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
--local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=VGG19_Weights.IMAGENET1K_V1. You can also use weights=VGG19_Weights.DEFAULT to get the most up-to-date weights.
warnings.warn(msg)

Perceptual loss:
Mode: vgg19
Load checkpoint from: checkpoints/renderer/Pro_renderer_T1_ref_N3/renderer_T1_ref_N3_epoch_7000_checkpoint_step000042000.pth
Traceback (most recent call last):
File "IP_LAP/inference_single.py", line 194, in
renderer = load_model(model=Renderer(), path=renderer_checkpoint_path)
File "IP_LAP/inference_single.py", line 173, in load_model
model.load_state_dict(new_s)
File "local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Renderer:
Missing key(s) in state_dict: "flow_module.conv1.weight", "flow_module.conv1.bias", "flow_module.conv1_bn.weight", "flow_module.conv1_bn.bias", "flow_module.conv1_bn.running_mean", "flow_module.conv1_bn.running_var", "flow_module.conv2.weight", "flow_module.conv2.bias", "flow_module.conv2_bn.weight", "flow_module.conv2_bn.bias", "flow_module.conv2_bn.running_mean", "flow_module.conv2_bn.running_var", "flow_module.spade_layer_1.conv_1.weight", "flow_module.spade_layer_1.conv_1.bias", "flow_module.spade_layer_1.conv_2.weight", "flow_module.spade_layer_1.conv_2.bias", "flow_module.spade_layer_1.spade_layer_1.conv1.weight", "flow_module.spade_layer_1.spade_layer_1.conv1.bias", "flow_module.spade_layer_1.spade_layer_1.gamma.weight", "flow_module.spade_layer_1.spade_layer_1.gamma.bias", "flow_module.spade_layer_1.spade_layer_1.beta.weight", "flow_module.spade_layer_1.spade_layer_1.beta.bias", "flow_module.spade_layer_1.spade_layer_2.conv1.weight", "flow_module.spade_layer_1.spade_layer_2.conv1.bias", "flow_module.spade_layer_1.spade_layer_2.gamma.weight", "flow_module.spade_layer_1.spade_layer_2.gamma.bias", "flow_module.spade_layer_1.spade_layer_2.beta.weight", "flow_module.spade_layer_1.spade_layer_2.beta.bias", "flow_module.spade_layer_2.conv_1.weight", "flow_module.spade_layer_2.conv_1.bias", "flow_module.spade_layer_2.conv_2.weight", "flow_module.spade_layer_2.conv_2.bias", "flow_module.spade_layer_2.spade_layer_1.conv1.weight", "flow_module.spade_layer_2.spade_layer_1.conv1.bias", "flow_module.spade_layer_2.spade_layer_1.gamma.weight", "flow_module.spade_layer_2.spade_layer_1.gamma.bias", "flow_module.spade_layer_2.spade_layer_1.beta.weight", "flow_module.spade_layer_2.spade_layer_1.beta.bias", "flow_module.spade_layer_2.spade_layer_2.conv1.weight", "flow_module.spade_layer_2.spade_layer_2.conv1.bias", "flow_module.spade_layer_2.spade_layer_2.gamma.weight", "flow_module.spade_layer_2.spade_layer_2.gamma.bias", "flow_module.spade_layer_2.spade_layer_2.beta.weight", "flow_module.spade_layer_2.spade_layer_2.beta.bias", "flow_module.spade_layer_4.conv_1.weight", "flow_module.spade_layer_4.conv_1.bias", "flow_module.spade_layer_4.conv_2.weight", "flow_module.spade_layer_4.conv_2.bias", "flow_module.spade_layer_4.spade_layer_1.conv1.weight", "flow_module.spade_layer_4.spade_layer_1.conv1.bias", "flow_module.spade_layer_4.spade_layer_1.gamma.weight", "flow_module.spade_layer_4.spade_layer_1.gamma.bias", "flow_module.spade_layer_4.spade_layer_1.beta.weight", "flow_module.spade_layer_4.spade_layer_1.beta.bias", "flow_module.spade_layer_4.spade_layer_2.conv1.weight", "flow_module.spade_layer_4.spade_layer_2.conv1.bias", "flow_module.spade_layer_4.spade_layer_2.gamma.weight", "flow_module.spade_layer_4.spade_layer_2.gamma.bias", "flow_module.spade_layer_4.spade_layer_2.beta.weight", "flow_module.spade_layer_4.spade_layer_2.beta.bias", "flow_module.conv_4.weight", "flow_module.conv_4.bias", "flow_module.conv_5.0.weight", "flow_module.conv_5.0.bias", "flow_module.conv_5.2.weight", "flow_module.conv_5.2.bias".

@yaleimeng Hi~, thanks for your interest, and sorry for the bug.
The problem may be related to the line new_s[k.replace('module.', '', 1)] = v in inference_single.py and to self.flow_module = DenseFlowNetwork() in video_renderer.py.

I guess you trained the renderer with a single GPU.
Try replacing that line in inference_single.py with the following code:

for k, v in s.items():
    if k[:6] == 'module':
        new_s[k.replace('module.', '', 1)] = v

Could you tell me whether it works after you try it? Thank you very much.
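(For context, a minimal sketch, assuming the multi-GPU training path wraps the model in torch.nn.DataParallel; the toy nn.Linear below is only a stand-in for the real Renderer and shows why a single-GPU checkpoint has no 'module.' prefix on its keys:)

import torch.nn as nn

model = nn.Linear(4, 2)                   # stand-in for Renderer(), not the real model
print(list(model.state_dict().keys()))    # ['weight', 'bias']  <- single-GPU keys, no prefix

wrapped = nn.DataParallel(model)          # multi-GPU training wrapper
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']  <- prefixed keys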

Yes, I did train with a single GPU.
After the change, the renderer model loads fine, but loading the landmark model now fails: many keys in the state dict cannot be found.
Load checkpoint from: checkpoints/landmark_generation/Pro_landmarkT5_d512_fe1024_lay4_head4/landmarkT5_d512_fe1024_lay4_head4_epoch_2020_checkpoint_step000012120.pth
Traceback (most recent call last):
File "/IP_LAP/inference_single.py", line 193, in
landmark_generator_model = load_model(
File "/IP_LAP/inference_single.py", line 175, in load_model
model.load_state_dict(new_s)
File "/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Landmark_generator:
Missing key(s) in state_dict: "mel_encoder.0.conv_block.0.weight", "mel_encoder.0.conv_block.0.bias", "mel_encoder.0.conv_block.1.weight", "mel_encoder.0.conv_block.1.bias", "mel_encoder.0.conv_block.1.running_mean", "mel_encoder.0.conv_block.1.running_var", "mel_encoder.1.conv_block.0.weight", "mel_encoder.1.conv_block.0.bias", "mel_encoder.1.conv_block.1.weight", "mel_encoder.1.conv_block.1.bias", "mel_encoder.1.conv_block.1.running_mean", "mel_encoder.1.conv_block.1.running_var", "mel_encoder.2.conv_block.0.weight", "mel_encoder.2.conv_block.0.bias", "mel_encoder.2.conv_block.1.weight", "mel_encoder.2.conv_block.1.bias", "mel_encoder.2.conv_block.1.running_mean", "mel_encoder.2.conv_block.1.running_var", "mel_encoder.3.conv_block.0.weight", "mel_encoder.3.conv_block.0.bias", "mel_encoder.3.conv_block.1.weight", "mel_encoder.3.conv_block.1.bias", "mel_encoder.3.conv_block.1.running_mean", "mel_encoder.3.conv_block.1.running_var", "mel_encoder.4.conv_block.0.weight", "mel_encoder.4.conv_block.0.bias",
(... many more missing keys, omitted)

@yaleimeng Sorry for my negligence.
Try replacing the load_model function in inference_single.py with the following code:

def load_model(model, path):
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # Checkpoints saved from a DataParallel model prefix every key with
        # 'module.'; single-GPU checkpoints do not. Strip the prefix only
        # when it is present so both kinds of checkpoints load correctly.
        if k[:6] == 'module':
            new_k = k.replace('module.', '', 1)
        else:
            new_k = k
        new_s[new_k] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()

Again, could you tell me whether it works after you try it? Thank you very much.
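(As a quick sanity check, a hypothetical snippet that is not part of inference_single.py: print a few state_dict keys from a checkpoint to see whether it was saved from a DataParallel model, i.e. 'module.'-prefixed keys, or from a single-GPU model with no prefix:)

import torch

# Hypothetical diagnostic, not part of the repository: inspect the key names
# stored in a checkpoint. The path is the renderer checkpoint from the log above.
ckpt = torch.load('checkpoints/renderer/Pro_renderer_T1_ref_N3/renderer_T1_ref_N3_epoch_7000_checkpoint_step000042000.pth',
                  map_location='cpu')
print(list(ckpt['state_dict'].keys())[:5])  # single-GPU: e.g. 'flow_module.conv1.weight'; DataParallel: 'module.flow_module.conv1.weight'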

Thanks, it works.

Thank you very much~