Loading the pretrained MViT model fails with a state_dict error
aiot-tech opened this issue · comments
Running the code below:
from pytorchvideo.models.hub import mvit_base_16x4
model = mvit_base_16x4(pretrained=True)
The output is:
RuntimeError: Error(s) in loading state_dict for MultiscaleVisionTransformers:
Missing key(s) in state_dict: "blocks.0.attn._attention_pool_k.pool.weight", "blocks.0.attn._attention_pool_k.norm.weight", "blocks.0.attn._attention_pool_k.norm.bias", "blocks.0.attn._attention_pool_v.pool.weight", "blocks.0.attn._attention_pool_v.norm.weight", "blocks.0.attn._attention_pool_v.norm.bias", "blocks.1.attn._attention_pool_q.pool.weight", "blocks.1.attn._attention_pool_q.norm.weight", "blocks.1.attn._attention_pool_q.norm.bias", "blocks.1.attn._attention_pool_k.pool.weight", "blocks.1.attn._attention_pool_k.norm.weight", "blocks.1.attn._attention_pool_k.norm.bias", "blocks.1.attn._attention_pool_v.pool.weight", "blocks.1.attn._attention_pool_v.norm.weight", "blocks.1.attn._attention_pool_v.norm.bias", "blocks.2.attn._attention_pool_k.pool.weight", "blocks.2.attn._attention_pool_k.norm.weight", "blocks.2.attn._attention_pool_k.norm.bias", "blocks.2.attn._attention_pool_v.pool.weight", "blocks.2.attn._attention_pool_v.norm.weight", "blocks.2.attn._attention_pool_v.norm.bias", "blocks.3.attn._attention_pool_q.pool.weight", "blocks.3.attn._attention_pool_q.norm.weight", "blocks.3.attn._attention_pool_q.norm.bias", "blocks.3.attn._attention_pool_k.pool.weight", "blocks.3.attn._attention_pool_k.norm.weight", "blocks.3.attn._attention_pool_k.norm.bias", "blocks.3.attn._attention_pool_v.pool.weight", "blocks.3.attn._attention_pool_v.norm.weight", "blocks.3.attn._attention_pool_v.norm.bias", "blocks.4.attn._attention_pool_k.pool.weight", "blocks.4.attn._attention_pool_k.norm.weight", "blocks.4.attn._attention_pool_k.norm.bias", "blocks.4.attn._attention_pool_v.pool.weight", "blocks.4.attn._attention_pool_v.norm.weight", "blocks.4.attn._attention_pool_v.norm.bias", "blocks.5.attn._attention_pool_k.pool.weight", "blocks.5.attn._attention_pool_k.norm.weight", "blocks.5.attn._attention_pool_k.norm.bias", "blocks.5.attn._attention_pool_v.pool.weight", "blocks.5.attn._attention_pool_v.norm.weight", "blocks.5.attn._attention_pool_v.norm.bias", 
"blocks.6.attn._attention_pool_k.pool.weight", "blocks.6.attn._attention_pool_k.norm.weight", "blocks.6.attn._attention_pool_k.norm.bias", "blocks.6.attn._attention_pool_v.pool.weight", "blocks.6.attn._attention_pool_v.norm.weight", "blocks.6.attn._attention_pool_v.norm.bias", "blocks.7.attn._attention_pool_k.pool.weight", "blocks.7.attn._attention_pool_k.norm.weight", "blocks.7.attn._attention_pool_k.norm.bias", "blocks.7.attn._attention_pool_v.pool.weight", "blocks.7.attn._attention_pool_v.norm.weight", "blocks.7.attn._attention_pool_v.norm.bias", "blocks.8.attn._attention_pool_k.pool.weight", "blocks.8.attn._attention_pool_k.norm.weight", "blocks.8.attn._attention_pool_k.norm.bias", "blocks.8.attn._attention_pool_v.pool.weight", "blocks.8.attn._attention_pool_v.norm.weight", "blocks.8.attn._attention_pool_v.norm.bias", "blocks.9.attn._attention_pool_k.pool.weight", "blocks.9.attn._attention_pool_k.norm.weight", "blocks.9.attn._attention_pool_k.norm.bias", "blocks.9.attn._attention_pool_v.pool.weight", "blocks.9.attn._attention_pool_v.norm.weight", "blocks.9.attn._attention_pool_v.norm.bias", "blocks.10.attn._attention_pool_k.pool.weight", "blocks.10.attn._attention_pool_k.norm.weight", "blocks.10.attn._attention_pool_k.norm.bias", "blocks.10.attn._attention_pool_v.pool.weight", "blocks.10.attn._attention_pool_v.norm.weight", "blocks.10.attn._attention_pool_v.norm.bias", "blocks.11.attn._attention_pool_k.pool.weight", "blocks.11.attn._attention_pool_k.norm.weight", "blocks.11.attn._attention_pool_k.norm.bias", "blocks.11.attn._attention_pool_v.pool.weight", "blocks.11.attn._attention_pool_v.norm.weight", "blocks.11.attn._attention_pool_v.norm.bias", "blocks.12.attn._attention_pool_k.pool.weight", "blocks.12.attn._attention_pool_k.norm.weight", "blocks.12.attn._attention_pool_k.norm.bias", "blocks.12.attn._attention_pool_v.pool.weight", "blocks.12.attn._attention_pool_v.norm.weight", "blocks.12.attn._attention_pool_v.norm.bias", 
"blocks.13.attn._attention_pool_k.pool.weight", "blocks.13.attn._attention_pool_k.norm.weight", "blocks.13.attn._attention_pool_k.norm.bias", "blocks.13.attn._attention_pool_v.pool.weight", "blocks.13.attn._attention_pool_v.norm.weight", "blocks.13.attn._attention_pool_v.norm.bias", "blocks.14.attn._attention_pool_q.pool.weight", "blocks.14.attn._attention_pool_q.norm.weight", "blocks.14.attn._attention_pool_q.norm.bias", "blocks.14.attn._attention_pool_k.pool.weight", "blocks.14.attn._attention_pool_k.norm.weight", "blocks.14.attn._attention_pool_k.norm.bias", "blocks.14.attn._attention_pool_v.pool.weight", "blocks.14.attn._attention_pool_v.norm.weight", "blocks.14.attn._attention_pool_v.norm.bias", "blocks.15.attn._attention_pool_k.pool.weight", "blocks.15.attn._attention_pool_k.norm.weight", "blocks.15.attn._attention_pool_k.norm.bias", "blocks.15.attn._attention_pool_v.pool.weight", "blocks.15.attn._attention_pool_v.norm.weight", "blocks.15.attn._attention_pool_v.norm.bias".
As a workaround, you can manually download the weights and load them with strict=False.
I believe the problem is that the optional _attention_pool_* poolers, which are only created under specific configurations, have no corresponding entries in the checkpoint, so strict loading fails.
When running the code below:
import torch
from pytorchvideo.models.hub import mvit_base_16x4

path = "/root/.cache/torch/hub/checkpoints/MVIT_B_16x4.pyth"
model = mvit_base_16x4(pretrained=False)
model.load_state_dict(torch.load(path), strict=False)
I found that, in addition to the _attention_pool_* keys, the missing keys include the following:
['patch_embed.patch_model.weight',
'patch_embed.patch_model.bias',
'cls_positional_encoding.cls_token',
'cls_positional_encoding.pos_embed_spatial',
'cls_positional_encoding.pos_embed_temporal',
'cls_positional_encoding.pos_embed_class',
'blocks.0.norm1.weight',
'blocks.0.norm1.bias',
'blocks.0.attn.q.weight',
'blocks.0.attn.q.bias',
'blocks.0.attn.k.weight',
'blocks.0.attn.k.bias',
'blocks.0.attn.v.weight',
'blocks.0.attn.v.bias',
'blocks.0.attn.proj.weight',
'blocks.0.attn.proj.bias',
'blocks.0.attn.pool_k.weight',
'blocks.0.attn.norm_k.weight',
'blocks.0.attn.norm_k.bias',
'blocks.0.attn.pool_v.weight',
'blocks.0.attn.norm_v.weight',
'blocks.0.attn.norm_v.bias',
'blocks.0.norm2.weight',
'blocks.0.norm2.bias',
'blocks.0.mlp.fc1.weight',
'blocks.0.mlp.fc1.bias',
'blocks.0.mlp.fc2.weight',
'blocks.0.mlp.fc2.bias',
'blocks.0.proj.weight',
'blocks.0.proj.bias',
'blocks.1.norm1.weight',
'blocks.1.norm1.bias',
'blocks.1.attn.q.weight',
'blocks.1.attn.q.bias',
'blocks.1.attn.k.weight',
'blocks.1.attn.k.bias',
'blocks.1.attn.v.weight',
'blocks.1.attn.v.bias',
'blocks.1.attn.proj.weight',
'blocks.1.attn.proj.bias',
'blocks.1.attn.pool_q.weight',
'blocks.1.attn.norm_q.weight',
'blocks.1.attn.norm_q.bias',
'blocks.1.attn.pool_k.weight',
'blocks.1.attn.norm_k.weight',
'blocks.1.attn.norm_k.bias',
'blocks.1.attn.pool_v.weight',
'blocks.1.attn.norm_v.weight',
'blocks.1.attn.norm_v.bias',
'blocks.1.norm2.weight',
'blocks.1.norm2.bias',
'blocks.1.mlp.fc1.weight',
'blocks.1.mlp.fc1.bias',
'blocks.1.mlp.fc2.weight',
'blocks.1.mlp.fc2.bias',
'blocks.2.norm1.weight',
'blocks.2.norm1.bias',
'blocks.2.attn.q.weight',
'blocks.2.attn.q.bias',
'blocks.2.attn.k.weight',
'blocks.2.attn.k.bias',
'blocks.2.attn.v.weight',
'blocks.2.attn.v.bias',
'blocks.2.attn.proj.weight',
'blocks.2.attn.proj.bias',
'blocks.2.attn.pool_k.weight',
'blocks.2.attn.norm_k.weight',
'blocks.2.attn.norm_k.bias',
'blocks.2.attn.pool_v.weight',
'blocks.2.attn.norm_v.weight',
'blocks.2.attn.norm_v.bias',
'blocks.2.norm2.weight',
'blocks.2.norm2.bias',
'blocks.2.mlp.fc1.weight',
'blocks.2.mlp.fc1.bias',
'blocks.2.mlp.fc2.weight',
'blocks.2.mlp.fc2.bias',
'blocks.2.proj.weight',
'blocks.2.proj.bias',
'blocks.3.norm1.weight',
'blocks.3.norm1.bias',
'blocks.3.attn.q.weight',
'blocks.3.attn.q.bias',
'blocks.3.attn.k.weight',
'blocks.3.attn.k.bias',
'blocks.3.attn.v.weight',
'blocks.3.attn.v.bias',
'blocks.3.attn.proj.weight',
'blocks.3.attn.proj.bias',
'blocks.3.attn.pool_q.weight',
'blocks.3.attn.norm_q.weight',
'blocks.3.attn.norm_q.bias',
'blocks.3.attn.pool_k.weight',
'blocks.3.attn.norm_k.weight',
'blocks.3.attn.norm_k.bias',
'blocks.3.attn.pool_v.weight',
'blocks.3.attn.norm_v.weight',
'blocks.3.attn.norm_v.bias',
'blocks.3.norm2.weight',
'blocks.3.norm2.bias',
'blocks.3.mlp.fc1.weight',
'blocks.3.mlp.fc1.bias',
'blocks.3.mlp.fc2.weight',
'blocks.3.mlp.fc2.bias',
'blocks.4.norm1.weight',
'blocks.4.norm1.bias',
'blocks.4.attn.q.weight',
'blocks.4.attn.q.bias',
'blocks.4.attn.k.weight',
'blocks.4.attn.k.bias',
'blocks.4.attn.v.weight',
'blocks.4.attn.v.bias',
'blocks.4.attn.proj.weight',
'blocks.4.attn.proj.bias',
'blocks.4.attn.pool_k.weight',
'blocks.4.attn.norm_k.weight',
'blocks.4.attn.norm_k.bias',
'blocks.4.attn.pool_v.weight',
'blocks.4.attn.norm_v.weight',
'blocks.4.attn.norm_v.bias',
'blocks.4.norm2.weight',
'blocks.4.norm2.bias',
'blocks.4.mlp.fc1.weight',
'blocks.4.mlp.fc1.bias',
'blocks.4.mlp.fc2.weight',
'blocks.4.mlp.fc2.bias',
'blocks.5.norm1.weight',
'blocks.5.norm1.bias',
'blocks.5.attn.q.weight',
'blocks.5.attn.q.bias',
'blocks.5.attn.k.weight',
'blocks.5.attn.k.bias',
'blocks.5.attn.v.weight',
'blocks.5.attn.v.bias',
'blocks.5.attn.proj.weight',
'blocks.5.attn.proj.bias',
'blocks.5.attn.pool_k.weight',
'blocks.5.attn.norm_k.weight',
'blocks.5.attn.norm_k.bias',
'blocks.5.attn.pool_v.weight',
'blocks.5.attn.norm_v.weight',
'blocks.5.attn.norm_v.bias',
'blocks.5.norm2.weight',
'blocks.5.norm2.bias',
'blocks.5.mlp.fc1.weight',
'blocks.5.mlp.fc1.bias',
'blocks.5.mlp.fc2.weight',
'blocks.5.mlp.fc2.bias',
'blocks.6.norm1.weight',
'blocks.6.norm1.bias',
'blocks.6.attn.q.weight',
'blocks.6.attn.q.bias',
'blocks.6.attn.k.weight',
'blocks.6.attn.k.bias',
'blocks.6.attn.v.weight',
'blocks.6.attn.v.bias',
'blocks.6.attn.proj.weight',
'blocks.6.attn.proj.bias',
'blocks.6.attn.pool_k.weight',
'blocks.6.attn.norm_k.weight',
'blocks.6.attn.norm_k.bias',
'blocks.6.attn.pool_v.weight',
'blocks.6.attn.norm_v.weight',
'blocks.6.attn.norm_v.bias',
'blocks.6.norm2.weight',
'blocks.6.norm2.bias',
'blocks.6.mlp.fc1.weight',
'blocks.6.mlp.fc1.bias',
'blocks.6.mlp.fc2.weight',
'blocks.6.mlp.fc2.bias',
'blocks.7.norm1.weight',
'blocks.7.norm1.bias',
'blocks.7.attn.q.weight',
'blocks.7.attn.q.bias',
'blocks.7.attn.k.weight',
'blocks.7.attn.k.bias',
'blocks.7.attn.v.weight',
'blocks.7.attn.v.bias',
'blocks.7.attn.proj.weight',
'blocks.7.attn.proj.bias',
'blocks.7.attn.pool_k.weight',
'blocks.7.attn.norm_k.weight',
'blocks.7.attn.norm_k.bias',
'blocks.7.attn.pool_v.weight',
'blocks.7.attn.norm_v.weight',
'blocks.7.attn.norm_v.bias',
'blocks.7.norm2.weight',
'blocks.7.norm2.bias',
'blocks.7.mlp.fc1.weight',
'blocks.7.mlp.fc1.bias',
'blocks.7.mlp.fc2.weight',
'blocks.7.mlp.fc2.bias',
'blocks.8.norm1.weight',
'blocks.8.norm1.bias',
'blocks.8.attn.q.weight',
'blocks.8.attn.q.bias',
'blocks.8.attn.k.weight',
'blocks.8.attn.k.bias',
'blocks.8.attn.v.weight',
'blocks.8.attn.v.bias',
'blocks.8.attn.proj.weight',
'blocks.8.attn.proj.bias',
'blocks.8.attn.pool_k.weight',
'blocks.8.attn.norm_k.weight',
'blocks.8.attn.norm_k.bias',
'blocks.8.attn.pool_v.weight',
'blocks.8.attn.norm_v.weight',
'blocks.8.attn.norm_v.bias',
'blocks.8.norm2.weight',
'blocks.8.norm2.bias',
'blocks.8.mlp.fc1.weight',
'blocks.8.mlp.fc1.bias',
'blocks.8.mlp.fc2.weight',
'blocks.8.mlp.fc2.bias',
'blocks.9.norm1.weight',
'blocks.9.norm1.bias',
'blocks.9.attn.q.weight',
'blocks.9.attn.q.bias',
'blocks.9.attn.k.weight',
'blocks.9.attn.k.bias',
'blocks.9.attn.v.weight',
'blocks.9.attn.v.bias',
'blocks.9.attn.proj.weight',
'blocks.9.attn.proj.bias',
'blocks.9.attn.pool_k.weight',
'blocks.9.attn.norm_k.weight',
'blocks.9.attn.norm_k.bias',
'blocks.9.attn.pool_v.weight',
'blocks.9.attn.norm_v.weight',
'blocks.9.attn.norm_v.bias',
'blocks.9.norm2.weight',
'blocks.9.norm2.bias',
'blocks.9.mlp.fc1.weight',
'blocks.9.mlp.fc1.bias',
'blocks.9.mlp.fc2.weight',
'blocks.9.mlp.fc2.bias',
'blocks.10.norm1.weight',
'blocks.10.norm1.bias',
'blocks.10.attn.q.weight',
'blocks.10.attn.q.bias',
'blocks.10.attn.k.weight',
'blocks.10.attn.k.bias',
'blocks.10.attn.v.weight',
'blocks.10.attn.v.bias',
'blocks.10.attn.proj.weight',
'blocks.10.attn.proj.bias',
'blocks.10.attn.pool_k.weight',
'blocks.10.attn.norm_k.weight',
'blocks.10.attn.norm_k.bias',
'blocks.10.attn.pool_v.weight',
'blocks.10.attn.norm_v.weight',
'blocks.10.attn.norm_v.bias',
'blocks.10.norm2.weight',
'blocks.10.norm2.bias',
'blocks.10.mlp.fc1.weight',
'blocks.10.mlp.fc1.bias',
'blocks.10.mlp.fc2.weight',
'blocks.10.mlp.fc2.bias',
'blocks.11.norm1.weight',
'blocks.11.norm1.bias',
'blocks.11.attn.q.weight',
'blocks.11.attn.q.bias',
'blocks.11.attn.k.weight',
'blocks.11.attn.k.bias',
'blocks.11.attn.v.weight',
'blocks.11.attn.v.bias',
'blocks.11.attn.proj.weight',
'blocks.11.attn.proj.bias',
'blocks.11.attn.pool_k.weight',
'blocks.11.attn.norm_k.weight',
'blocks.11.attn.norm_k.bias',
'blocks.11.attn.pool_v.weight',
'blocks.11.attn.norm_v.weight',
'blocks.11.attn.norm_v.bias',
'blocks.11.norm2.weight',
'blocks.11.norm2.bias',
'blocks.11.mlp.fc1.weight',
'blocks.11.mlp.fc1.bias',
'blocks.11.mlp.fc2.weight',
'blocks.11.mlp.fc2.bias',
'blocks.12.norm1.weight',
'blocks.12.norm1.bias',
'blocks.12.attn.q.weight',
'blocks.12.attn.q.bias',
'blocks.12.attn.k.weight',
'blocks.12.attn.k.bias',
'blocks.12.attn.v.weight',
'blocks.12.attn.v.bias',
'blocks.12.attn.proj.weight',
'blocks.12.attn.proj.bias',
'blocks.12.attn.pool_k.weight',
'blocks.12.attn.norm_k.weight',
'blocks.12.attn.norm_k.bias',
'blocks.12.attn.pool_v.weight',
'blocks.12.attn.norm_v.weight',
'blocks.12.attn.norm_v.bias',
'blocks.12.norm2.weight',
'blocks.12.norm2.bias',
'blocks.12.mlp.fc1.weight',
'blocks.12.mlp.fc1.bias',
'blocks.12.mlp.fc2.weight',
'blocks.12.mlp.fc2.bias',
'blocks.13.norm1.weight',
'blocks.13.norm1.bias',
'blocks.13.attn.q.weight',
'blocks.13.attn.q.bias',
'blocks.13.attn.k.weight',
'blocks.13.attn.k.bias',
'blocks.13.attn.v.weight',
'blocks.13.attn.v.bias',
'blocks.13.attn.proj.weight',
'blocks.13.attn.proj.bias',
'blocks.13.attn.pool_k.weight',
'blocks.13.attn.norm_k.weight',
'blocks.13.attn.norm_k.bias',
'blocks.13.attn.pool_v.weight',
'blocks.13.attn.norm_v.weight',
'blocks.13.attn.norm_v.bias',
'blocks.13.norm2.weight',
'blocks.13.norm2.bias',
'blocks.13.mlp.fc1.weight',
'blocks.13.mlp.fc1.bias',
'blocks.13.mlp.fc2.weight',
'blocks.13.mlp.fc2.bias',
'blocks.13.proj.weight',
'blocks.13.proj.bias',
'blocks.14.norm1.weight',
'blocks.14.norm1.bias',
'blocks.14.attn.q.weight',
'blocks.14.attn.q.bias',
'blocks.14.attn.k.weight',
'blocks.14.attn.k.bias',
'blocks.14.attn.v.weight',
'blocks.14.attn.v.bias',
'blocks.14.attn.proj.weight',
'blocks.14.attn.proj.bias',
'blocks.14.attn.pool_q.weight',
'blocks.14.attn.norm_q.weight',
'blocks.14.attn.norm_q.bias',
'blocks.14.attn.pool_k.weight',
'blocks.14.attn.norm_k.weight',
'blocks.14.attn.norm_k.bias',
'blocks.14.attn.pool_v.weight',
'blocks.14.attn.norm_v.weight',
'blocks.14.attn.norm_v.bias',
'blocks.14.norm2.weight',
'blocks.14.norm2.bias',
'blocks.14.mlp.fc1.weight',
'blocks.14.mlp.fc1.bias',
'blocks.14.mlp.fc2.weight',
'blocks.14.mlp.fc2.bias',
'blocks.15.norm1.weight',
'blocks.15.norm1.bias',
'blocks.15.attn.q.weight',
'blocks.15.attn.q.bias',
'blocks.15.attn.k.weight',
'blocks.15.attn.k.bias',
'blocks.15.attn.v.weight',
'blocks.15.attn.v.bias',
'blocks.15.attn.proj.weight',
'blocks.15.attn.proj.bias',
'blocks.15.attn.pool_k.weight',
'blocks.15.attn.norm_k.weight',
'blocks.15.attn.norm_k.bias',
'blocks.15.attn.pool_v.weight',
'blocks.15.attn.norm_v.weight',
'blocks.15.attn.norm_v.bias',
'blocks.15.norm2.weight',
'blocks.15.norm2.bias',
'blocks.15.mlp.fc1.weight',
'blocks.15.mlp.fc1.bias',
'blocks.15.mlp.fc2.weight',
'blocks.15.mlp.fc2.bias',
'norm_embed.weight',
'norm_embed.bias',
'head.proj.weight',
'head.proj.bias']
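Since load_state_dict with strict=False silently skips mismatches, it helps to compute the missing and unexpected key sets yourself before deciding how to load. A minimal torch-free sketch of that diagnostic (the function name and the small key sets below are illustrative, not from pytorchvideo):

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Report keys the model expects but the checkpoint lacks ("missing"),
    and keys the checkpoint carries that the model does not use ("unexpected")."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected

# Hypothetical key sets illustrating the situation above: every model key is
# "missing" because the checkpoint's top level holds training info, not weights.
model_keys = [
    "patch_embed.patch_model.weight",
    "blocks.0.attn.q.weight",
    "head.proj.bias",
]
ckpt_top_level_keys = ["model_state", "epoch"]
missing, unexpected = diff_state_dict_keys(model_keys, ckpt_top_level_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If every model key shows up as missing, as here, the checkpoint is almost certainly a wrapper dict rather than a raw state_dict.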
@aiot-tech I'm not a maintainer of PyTorchVideo, but I've already spoken with @lyttonhao about the issue and it's on his radar to fix. As it happens, I'm currently working on adding this specific model to TorchVision, and the above workaround works for me: I managed to "translate" the weights in question at pytorch/vision#6179 without a problem. This might indicate you are using a different version of PyTorchVideo (I'm currently on the main branch).
Thanks for the reply. The version of pytorchvideo I'm using is 0.1.5.
@aiot-tech I just noticed that you are not loading the weights properly. The issue is that the checkpoint contains not only the model weights but also other training-specific info. Try:
model.load_state_dict(torch.load(path)["model_state"], strict=False)
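The unwrapping step can be made defensive with a small helper that handles both wrapped training checkpoints and raw state_dicts. This is a sketch assuming the wrapper key is "model_state" as in the checkpoint above; the helper name is mine, and the torch calls in the comment show the intended usage:

```python
def unwrap_checkpoint(ckpt, key="model_state"):
    """If ckpt is a wrapped training checkpoint (weights plus optimizer/epoch
    info under a known key), return the inner state_dict; otherwise return
    ckpt unchanged, assuming it is already a raw state_dict."""
    if isinstance(ckpt, dict) and key in ckpt:
        return ckpt[key]
    return ckpt

# With torch, this would be used as:
#   state = unwrap_checkpoint(torch.load(path))
#   model.load_state_dict(state, strict=True)
wrapped = {"model_state": {"head.proj.weight": 0.5}, "epoch": 30}
raw = {"head.proj.weight": 0.5}
print(unwrap_checkpoint(wrapped) == raw)
print(unwrap_checkpoint(raw) == raw)
```

Once the inner state_dict is extracted, strict=True should work again, which is preferable: strict loading will surface any remaining mismatch instead of silently leaving layers randomly initialized.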
Brilliant!