yangjianxin1 / ClipCap-Chinese

An Image Caption model based on ClipCap

Error when loading weights

sunyclj opened this issue · comments

```
sh scripts/predict_finerune_gpt2.sh
2024-02-27 15:20:22.435851: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2024-02-27 15:20:22.478654: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-27 15:20:23.206885: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-02-27 15:20:26.745 | INFO | models.model:init:75 - succeed to load pretrain gpt2 model
Traceback (most recent call last):
  File "predict.py", line 186, in
    main(args)
  File "predict.py", line 129, in main
    model.load_state_dict(torch.load(args.model_path, map_location=args.device))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ClipCaptionModel:
Unexpected key(s) in state_dict: "gpt2.transformer.h.0.attn.bias", "gpt2.transformer.h.0.attn.masked_bias", "gpt2.transformer.h.1.attn.bias", "gpt2.transformer.h.1.attn.masked_bias", "gpt2.transformer.h.2.attn.bias", "gpt2.transformer.h.2.attn.masked_bias", "gpt2.transformer.h.3.attn.bias", "gpt2.transformer.h.3.attn.masked_bias", "gpt2.transformer.h.4.attn.bias", "gpt2.transformer.h.4.attn.masked_bias", "gpt2.transformer.h.5.attn.bias", "gpt2.transformer.h.5.attn.masked_bias", "gpt2.transformer.h.6.attn.bias", "gpt2.transformer.h.6.attn.masked_bias", "gpt2.transformer.h.7.attn.bias", "gpt2.transformer.h.7.attn.masked_bias", "gpt2.transformer.h.8.attn.bias", "gpt2.transformer.h.8.attn.masked_bias", "gpt2.transformer.h.9.attn.bias", "gpt2.transformer.h.9.attn.masked_bias", "gpt2.transformer.h.10.attn.bias", "gpt2.transformer.h.10.attn.masked_bias", "gpt2.transformer.h.11.attn.bias", "gpt2.transformer.h.11.attn.masked_bias", "clip_project.bert.embeddings.position_ids".
```

@sunyclj I changed line 129 of predict.py from

model.load_state_dict(torch.load(args.model_path, map_location=args.device))

to

model.load_state_dict(torch.load(args.model_path, map_location=args.device), strict=False)

and it worked. The second argument sets strict=False, so the unexpected keys are ignored. In my case the model was trained on a GPU, and loading it on a CPU raised this error.
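For reference, a minimal sketch of what `strict=False` does, using a stand-in `nn.Linear` module rather than the actual `ClipCaptionModel`. The extra `attn.bias` / `attn.masked_bias` entries are attention-mask buffers that older `transformers` versions stored in GPT-2 checkpoints; with `strict=False`, `load_state_dict` skips such keys and reports them instead of raising:

```python
import torch
import torch.nn as nn

# Stand-in module; in the real script this would be ClipCaptionModel.
model = nn.Linear(4, 2)

# Simulate a checkpoint that carries an extra buffer key,
# analogous to the "attn.masked_bias" keys in the traceback above.
state_dict = model.state_dict()
state_dict["attn.masked_bias"] = torch.tensor(-1e4)

# strict=False ignores keys present in the checkpoint but absent
# from the model, returning them instead of raising RuntimeError.
result = model.load_state_dict(state_dict, strict=False)
print(result.unexpected_keys)  # ['attn.masked_bias']
print(result.missing_keys)     # []
```

It is worth inspecting the returned `missing_keys` and `unexpected_keys` after loading, to confirm that only harmless buffer keys were skipped and no trained weights are actually missing.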