facebookresearch / pytorchvideo

A deep learning library for video understanding research.

Home Page: https://pytorchvideo.org/


load mvit pretrained model, some error occurred

aiot-tech opened this issue · comments

commented

Running the code below:

from pytorchvideo.models.hub import mvit_base_16x4
model = mvit_base_16x4(pretrained=True)

the output is:

RuntimeError: Error(s) in loading state_dict for MultiscaleVisionTransformers:
	Missing key(s) in state_dict: "blocks.0.attn._attention_pool_k.pool.weight", "blocks.0.attn._attention_pool_k.norm.weight", "blocks.0.attn._attention_pool_k.norm.bias", "blocks.0.attn._attention_pool_v.pool.weight", "blocks.0.attn._attention_pool_v.norm.weight", "blocks.0.attn._attention_pool_v.norm.bias", "blocks.1.attn._attention_pool_q.pool.weight", "blocks.1.attn._attention_pool_q.norm.weight", "blocks.1.attn._attention_pool_q.norm.bias", "blocks.1.attn._attention_pool_k.pool.weight", "blocks.1.attn._attention_pool_k.norm.weight", "blocks.1.attn._attention_pool_k.norm.bias", "blocks.1.attn._attention_pool_v.pool.weight", "blocks.1.attn._attention_pool_v.norm.weight", "blocks.1.attn._attention_pool_v.norm.bias", "blocks.2.attn._attention_pool_k.pool.weight", "blocks.2.attn._attention_pool_k.norm.weight", "blocks.2.attn._attention_pool_k.norm.bias", "blocks.2.attn._attention_pool_v.pool.weight", "blocks.2.attn._attention_pool_v.norm.weight", "blocks.2.attn._attention_pool_v.norm.bias", "blocks.3.attn._attention_pool_q.pool.weight", "blocks.3.attn._attention_pool_q.norm.weight", "blocks.3.attn._attention_pool_q.norm.bias", "blocks.3.attn._attention_pool_k.pool.weight", "blocks.3.attn._attention_pool_k.norm.weight", "blocks.3.attn._attention_pool_k.norm.bias", "blocks.3.attn._attention_pool_v.pool.weight", "blocks.3.attn._attention_pool_v.norm.weight", "blocks.3.attn._attention_pool_v.norm.bias", "blocks.4.attn._attention_pool_k.pool.weight", "blocks.4.attn._attention_pool_k.norm.weight", "blocks.4.attn._attention_pool_k.norm.bias", "blocks.4.attn._attention_pool_v.pool.weight", "blocks.4.attn._attention_pool_v.norm.weight", "blocks.4.attn._attention_pool_v.norm.bias", "blocks.5.attn._attention_pool_k.pool.weight", "blocks.5.attn._attention_pool_k.norm.weight", "blocks.5.attn._attention_pool_k.norm.bias", "blocks.5.attn._attention_pool_v.pool.weight", "blocks.5.attn._attention_pool_v.norm.weight", "blocks.5.attn._attention_pool_v.norm.bias", "blocks.6.attn._attention_pool_k.pool.weight", "blocks.6.attn._attention_pool_k.norm.weight", "blocks.6.attn._attention_pool_k.norm.bias", "blocks.6.attn._attention_pool_v.pool.weight", "blocks.6.attn._attention_pool_v.norm.weight", "blocks.6.attn._attention_pool_v.norm.bias", "blocks.7.attn._attention_pool_k.pool.weight", "blocks.7.attn._attention_pool_k.norm.weight", "blocks.7.attn._attention_pool_k.norm.bias", "blocks.7.attn._attention_pool_v.pool.weight", "blocks.7.attn._attention_pool_v.norm.weight", "blocks.7.attn._attention_pool_v.norm.bias", "blocks.8.attn._attention_pool_k.pool.weight", "blocks.8.attn._attention_pool_k.norm.weight", "blocks.8.attn._attention_pool_k.norm.bias", "blocks.8.attn._attention_pool_v.pool.weight", "blocks.8.attn._attention_pool_v.norm.weight", "blocks.8.attn._attention_pool_v.norm.bias", "blocks.9.attn._attention_pool_k.pool.weight", "blocks.9.attn._attention_pool_k.norm.weight", "blocks.9.attn._attention_pool_k.norm.bias", "blocks.9.attn._attention_pool_v.pool.weight", "blocks.9.attn._attention_pool_v.norm.weight", "blocks.9.attn._attention_pool_v.norm.bias", "blocks.10.attn._attention_pool_k.pool.weight", "blocks.10.attn._attention_pool_k.norm.weight", "blocks.10.attn._attention_pool_k.norm.bias", "blocks.10.attn._attention_pool_v.pool.weight", "blocks.10.attn._attention_pool_v.norm.weight", "blocks.10.attn._attention_pool_v.norm.bias", "blocks.11.attn._attention_pool_k.pool.weight", "blocks.11.attn._attention_pool_k.norm.weight", "blocks.11.attn._attention_pool_k.norm.bias", 
"blocks.11.attn._attention_pool_v.pool.weight", "blocks.11.attn._attention_pool_v.norm.weight", "blocks.11.attn._attention_pool_v.norm.bias", "blocks.12.attn._attention_pool_k.pool.weight", "blocks.12.attn._attention_pool_k.norm.weight", "blocks.12.attn._attention_pool_k.norm.bias", "blocks.12.attn._attention_pool_v.pool.weight", "blocks.12.attn._attention_pool_v.norm.weight", "blocks.12.attn._attention_pool_v.norm.bias", "blocks.13.attn._attention_pool_k.pool.weight", "blocks.13.attn._attention_pool_k.norm.weight", "blocks.13.attn._attention_pool_k.norm.bias", "blocks.13.attn._attention_pool_v.pool.weight", "blocks.13.attn._attention_pool_v.norm.weight", "blocks.13.attn._attention_pool_v.norm.bias", "blocks.14.attn._attention_pool_q.pool.weight", "blocks.14.attn._attention_pool_q.norm.weight", "blocks.14.attn._attention_pool_q.norm.bias", "blocks.14.attn._attention_pool_k.pool.weight", "blocks.14.attn._attention_pool_k.norm.weight", "blocks.14.attn._attention_pool_k.norm.bias", "blocks.14.attn._attention_pool_v.pool.weight", "blocks.14.attn._attention_pool_v.norm.weight", "blocks.14.attn._attention_pool_v.norm.bias", "blocks.15.attn._attention_pool_k.pool.weight", "blocks.15.attn._attention_pool_k.norm.weight", "blocks.15.attn._attention_pool_k.norm.bias", "blocks.15.attn._attention_pool_v.pool.weight", "blocks.15.attn._attention_pool_v.norm.weight", "blocks.15.attn._attention_pool_v.norm.bias".

You can manually download the weights and load them with strict=False.

I believe the problem is that the optional _attention_pool_* poolers, which are only created under specific configurations, have no counterparts in the checkpoint and so remain uninitialised, which is why the strict load fails with missing keys.
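
A minimal sketch of that workaround (the local path below is an assumption; load_state_dict returns the lists of missing and unexpected keys, which makes it easy to see how much of the checkpoint actually matched):

import torch
from pytorchvideo.models.hub import mvit_base_16x4

# Assumed local path to the manually downloaded checkpoint.
path = "MVIT_B_16x4.pyth"
model = mvit_base_16x4(pretrained=False)
# With strict=False, mismatched keys are reported instead of raising.
result = model.load_state_dict(torch.load(path), strict=False)
print(result.missing_keys)
print(result.unexpected_keys)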

commented

When running the code below:

import torch
from pytorchvideo.models.hub import mvit_base_16x4

path = "/root/.cache/torch/hub/checkpoints/MVIT_B_16x4.pyth"
model = mvit_base_16x4(pretrained=False)
model.load_state_dict(torch.load(path), strict=False)

I found that, in addition to the _attention_pool_* keys, the missing keys include the following:

['patch_embed.patch_model.weight',
 'patch_embed.patch_model.bias',
 'cls_positional_encoding.cls_token',
 'cls_positional_encoding.pos_embed_spatial',
 'cls_positional_encoding.pos_embed_temporal',
 'cls_positional_encoding.pos_embed_class',
 'blocks.0.norm1.weight',
 'blocks.0.norm1.bias',
 'blocks.0.attn.q.weight',
 'blocks.0.attn.q.bias',
 'blocks.0.attn.k.weight',
 'blocks.0.attn.k.bias',
 'blocks.0.attn.v.weight',
 'blocks.0.attn.v.bias',
 'blocks.0.attn.proj.weight',
 'blocks.0.attn.proj.bias',
 'blocks.0.attn.pool_k.weight',
 'blocks.0.attn.norm_k.weight',
 'blocks.0.attn.norm_k.bias',
 'blocks.0.attn.pool_v.weight',
 'blocks.0.attn.norm_v.weight',
 'blocks.0.attn.norm_v.bias',
 'blocks.0.norm2.weight',
 'blocks.0.norm2.bias',
 'blocks.0.mlp.fc1.weight',
 'blocks.0.mlp.fc1.bias',
 'blocks.0.mlp.fc2.weight',
 'blocks.0.mlp.fc2.bias',
 'blocks.0.proj.weight',
 'blocks.0.proj.bias',
 'blocks.1.norm1.weight',
 'blocks.1.norm1.bias',
 'blocks.1.attn.q.weight',
 'blocks.1.attn.q.bias',
 'blocks.1.attn.k.weight',
 'blocks.1.attn.k.bias',
 'blocks.1.attn.v.weight',
 'blocks.1.attn.v.bias',
 'blocks.1.attn.proj.weight',
 'blocks.1.attn.proj.bias',
 'blocks.1.attn.pool_q.weight',
 'blocks.1.attn.norm_q.weight',
 'blocks.1.attn.norm_q.bias',
 'blocks.1.attn.pool_k.weight',
 'blocks.1.attn.norm_k.weight',
 'blocks.1.attn.norm_k.bias',
 'blocks.1.attn.pool_v.weight',
 'blocks.1.attn.norm_v.weight',
 'blocks.1.attn.norm_v.bias',
 'blocks.1.norm2.weight',
 'blocks.1.norm2.bias',
 'blocks.1.mlp.fc1.weight',
 'blocks.1.mlp.fc1.bias',
 'blocks.1.mlp.fc2.weight',
 'blocks.1.mlp.fc2.bias',
 'blocks.2.norm1.weight',
 'blocks.2.norm1.bias',
 'blocks.2.attn.q.weight',
 'blocks.2.attn.q.bias',
 'blocks.2.attn.k.weight',
 'blocks.2.attn.k.bias',
 'blocks.2.attn.v.weight',
 'blocks.2.attn.v.bias',
 'blocks.2.attn.proj.weight',
 'blocks.2.attn.proj.bias',
 'blocks.2.attn.pool_k.weight',
 'blocks.2.attn.norm_k.weight',
 'blocks.2.attn.norm_k.bias',
 'blocks.2.attn.pool_v.weight',
 'blocks.2.attn.norm_v.weight',
 'blocks.2.attn.norm_v.bias',
 'blocks.2.norm2.weight',
 'blocks.2.norm2.bias',
 'blocks.2.mlp.fc1.weight',
 'blocks.2.mlp.fc1.bias',
 'blocks.2.mlp.fc2.weight',
 'blocks.2.mlp.fc2.bias',
 'blocks.2.proj.weight',
 'blocks.2.proj.bias',
 'blocks.3.norm1.weight',
 'blocks.3.norm1.bias',
 'blocks.3.attn.q.weight',
 'blocks.3.attn.q.bias',
 'blocks.3.attn.k.weight',
 'blocks.3.attn.k.bias',
 'blocks.3.attn.v.weight',
 'blocks.3.attn.v.bias',
 'blocks.3.attn.proj.weight',
 'blocks.3.attn.proj.bias',
 'blocks.3.attn.pool_q.weight',
 'blocks.3.attn.norm_q.weight',
 'blocks.3.attn.norm_q.bias',
 'blocks.3.attn.pool_k.weight',
 'blocks.3.attn.norm_k.weight',
 'blocks.3.attn.norm_k.bias',
 'blocks.3.attn.pool_v.weight',
 'blocks.3.attn.norm_v.weight',
 'blocks.3.attn.norm_v.bias',
 'blocks.3.norm2.weight',
 'blocks.3.norm2.bias',
 'blocks.3.mlp.fc1.weight',
 'blocks.3.mlp.fc1.bias',
 'blocks.3.mlp.fc2.weight',
 'blocks.3.mlp.fc2.bias',
 'blocks.4.norm1.weight',
 'blocks.4.norm1.bias',
 'blocks.4.attn.q.weight',
 'blocks.4.attn.q.bias',
 'blocks.4.attn.k.weight',
 'blocks.4.attn.k.bias',
 'blocks.4.attn.v.weight',
 'blocks.4.attn.v.bias',
 'blocks.4.attn.proj.weight',
 'blocks.4.attn.proj.bias',
 'blocks.4.attn.pool_k.weight',
 'blocks.4.attn.norm_k.weight',
 'blocks.4.attn.norm_k.bias',
 'blocks.4.attn.pool_v.weight',
 'blocks.4.attn.norm_v.weight',
 'blocks.4.attn.norm_v.bias',
 'blocks.4.norm2.weight',
 'blocks.4.norm2.bias',
 'blocks.4.mlp.fc1.weight',
 'blocks.4.mlp.fc1.bias',
 'blocks.4.mlp.fc2.weight',
 'blocks.4.mlp.fc2.bias',
 'blocks.5.norm1.weight',
 'blocks.5.norm1.bias',
 'blocks.5.attn.q.weight',
 'blocks.5.attn.q.bias',
 'blocks.5.attn.k.weight',
 'blocks.5.attn.k.bias',
 'blocks.5.attn.v.weight',
 'blocks.5.attn.v.bias',
 'blocks.5.attn.proj.weight',
 'blocks.5.attn.proj.bias',
 'blocks.5.attn.pool_k.weight',
 'blocks.5.attn.norm_k.weight',
 'blocks.5.attn.norm_k.bias',
 'blocks.5.attn.pool_v.weight',
 'blocks.5.attn.norm_v.weight',
 'blocks.5.attn.norm_v.bias',
 'blocks.5.norm2.weight',
 'blocks.5.norm2.bias',
 'blocks.5.mlp.fc1.weight',
 'blocks.5.mlp.fc1.bias',
 'blocks.5.mlp.fc2.weight',
 'blocks.5.mlp.fc2.bias',
 'blocks.6.norm1.weight',
 'blocks.6.norm1.bias',
 'blocks.6.attn.q.weight',
 'blocks.6.attn.q.bias',
 'blocks.6.attn.k.weight',
 'blocks.6.attn.k.bias',
 'blocks.6.attn.v.weight',
 'blocks.6.attn.v.bias',
 'blocks.6.attn.proj.weight',
 'blocks.6.attn.proj.bias',
 'blocks.6.attn.pool_k.weight',
 'blocks.6.attn.norm_k.weight',
 'blocks.6.attn.norm_k.bias',
 'blocks.6.attn.pool_v.weight',
 'blocks.6.attn.norm_v.weight',
 'blocks.6.attn.norm_v.bias',
 'blocks.6.norm2.weight',
 'blocks.6.norm2.bias',
 'blocks.6.mlp.fc1.weight',
 'blocks.6.mlp.fc1.bias',
 'blocks.6.mlp.fc2.weight',
 'blocks.6.mlp.fc2.bias',
 'blocks.7.norm1.weight',
 'blocks.7.norm1.bias',
 'blocks.7.attn.q.weight',
 'blocks.7.attn.q.bias',
 'blocks.7.attn.k.weight',
 'blocks.7.attn.k.bias',
 'blocks.7.attn.v.weight',
 'blocks.7.attn.v.bias',
 'blocks.7.attn.proj.weight',
 'blocks.7.attn.proj.bias',
 'blocks.7.attn.pool_k.weight',
 'blocks.7.attn.norm_k.weight',
 'blocks.7.attn.norm_k.bias',
 'blocks.7.attn.pool_v.weight',
 'blocks.7.attn.norm_v.weight',
 'blocks.7.attn.norm_v.bias',
 'blocks.7.norm2.weight',
 'blocks.7.norm2.bias',
 'blocks.7.mlp.fc1.weight',
 'blocks.7.mlp.fc1.bias',
 'blocks.7.mlp.fc2.weight',
 'blocks.7.mlp.fc2.bias',
 'blocks.8.norm1.weight',
 'blocks.8.norm1.bias',
 'blocks.8.attn.q.weight',
 'blocks.8.attn.q.bias',
 'blocks.8.attn.k.weight',
 'blocks.8.attn.k.bias',
 'blocks.8.attn.v.weight',
 'blocks.8.attn.v.bias',
 'blocks.8.attn.proj.weight',
 'blocks.8.attn.proj.bias',
 'blocks.8.attn.pool_k.weight',
 'blocks.8.attn.norm_k.weight',
 'blocks.8.attn.norm_k.bias',
 'blocks.8.attn.pool_v.weight',
 'blocks.8.attn.norm_v.weight',
 'blocks.8.attn.norm_v.bias',
 'blocks.8.norm2.weight',
 'blocks.8.norm2.bias',
 'blocks.8.mlp.fc1.weight',
 'blocks.8.mlp.fc1.bias',
 'blocks.8.mlp.fc2.weight',
 'blocks.8.mlp.fc2.bias',
 'blocks.9.norm1.weight',
 'blocks.9.norm1.bias',
 'blocks.9.attn.q.weight',
 'blocks.9.attn.q.bias',
 'blocks.9.attn.k.weight',
 'blocks.9.attn.k.bias',
 'blocks.9.attn.v.weight',
 'blocks.9.attn.v.bias',
 'blocks.9.attn.proj.weight',
 'blocks.9.attn.proj.bias',
 'blocks.9.attn.pool_k.weight',
 'blocks.9.attn.norm_k.weight',
 'blocks.9.attn.norm_k.bias',
 'blocks.9.attn.pool_v.weight',
 'blocks.9.attn.norm_v.weight',
 'blocks.9.attn.norm_v.bias',
 'blocks.9.norm2.weight',
 'blocks.9.norm2.bias',
 'blocks.9.mlp.fc1.weight',
 'blocks.9.mlp.fc1.bias',
 'blocks.9.mlp.fc2.weight',
 'blocks.9.mlp.fc2.bias',
 'blocks.10.norm1.weight',
 'blocks.10.norm1.bias',
 'blocks.10.attn.q.weight',
 'blocks.10.attn.q.bias',
 'blocks.10.attn.k.weight',
 'blocks.10.attn.k.bias',
 'blocks.10.attn.v.weight',
 'blocks.10.attn.v.bias',
 'blocks.10.attn.proj.weight',
 'blocks.10.attn.proj.bias',
 'blocks.10.attn.pool_k.weight',
 'blocks.10.attn.norm_k.weight',
 'blocks.10.attn.norm_k.bias',
 'blocks.10.attn.pool_v.weight',
 'blocks.10.attn.norm_v.weight',
 'blocks.10.attn.norm_v.bias',
 'blocks.10.norm2.weight',
 'blocks.10.norm2.bias',
 'blocks.10.mlp.fc1.weight',
 'blocks.10.mlp.fc1.bias',
 'blocks.10.mlp.fc2.weight',
 'blocks.10.mlp.fc2.bias',
 'blocks.11.norm1.weight',
 'blocks.11.norm1.bias',
 'blocks.11.attn.q.weight',
 'blocks.11.attn.q.bias',
 'blocks.11.attn.k.weight',
 'blocks.11.attn.k.bias',
 'blocks.11.attn.v.weight',
 'blocks.11.attn.v.bias',
 'blocks.11.attn.proj.weight',
 'blocks.11.attn.proj.bias',
 'blocks.11.attn.pool_k.weight',
 'blocks.11.attn.norm_k.weight',
 'blocks.11.attn.norm_k.bias',
 'blocks.11.attn.pool_v.weight',
 'blocks.11.attn.norm_v.weight',
 'blocks.11.attn.norm_v.bias',
 'blocks.11.norm2.weight',
 'blocks.11.norm2.bias',
 'blocks.11.mlp.fc1.weight',
 'blocks.11.mlp.fc1.bias',
 'blocks.11.mlp.fc2.weight',
 'blocks.11.mlp.fc2.bias',
 'blocks.12.norm1.weight',
 'blocks.12.norm1.bias',
 'blocks.12.attn.q.weight',
 'blocks.12.attn.q.bias',
 'blocks.12.attn.k.weight',
 'blocks.12.attn.k.bias',
 'blocks.12.attn.v.weight',
 'blocks.12.attn.v.bias',
 'blocks.12.attn.proj.weight',
 'blocks.12.attn.proj.bias',
 'blocks.12.attn.pool_k.weight',
 'blocks.12.attn.norm_k.weight',
 'blocks.12.attn.norm_k.bias',
 'blocks.12.attn.pool_v.weight',
 'blocks.12.attn.norm_v.weight',
 'blocks.12.attn.norm_v.bias',
 'blocks.12.norm2.weight',
 'blocks.12.norm2.bias',
 'blocks.12.mlp.fc1.weight',
 'blocks.12.mlp.fc1.bias',
 'blocks.12.mlp.fc2.weight',
 'blocks.12.mlp.fc2.bias',
 'blocks.13.norm1.weight',
 'blocks.13.norm1.bias',
 'blocks.13.attn.q.weight',
 'blocks.13.attn.q.bias',
 'blocks.13.attn.k.weight',
 'blocks.13.attn.k.bias',
 'blocks.13.attn.v.weight',
 'blocks.13.attn.v.bias',
 'blocks.13.attn.proj.weight',
 'blocks.13.attn.proj.bias',
 'blocks.13.attn.pool_k.weight',
 'blocks.13.attn.norm_k.weight',
 'blocks.13.attn.norm_k.bias',
 'blocks.13.attn.pool_v.weight',
 'blocks.13.attn.norm_v.weight',
 'blocks.13.attn.norm_v.bias',
 'blocks.13.norm2.weight',
 'blocks.13.norm2.bias',
 'blocks.13.mlp.fc1.weight',
 'blocks.13.mlp.fc1.bias',
 'blocks.13.mlp.fc2.weight',
 'blocks.13.mlp.fc2.bias',
 'blocks.13.proj.weight',
 'blocks.13.proj.bias',
 'blocks.14.norm1.weight',
 'blocks.14.norm1.bias',
 'blocks.14.attn.q.weight',
 'blocks.14.attn.q.bias',
 'blocks.14.attn.k.weight',
 'blocks.14.attn.k.bias',
 'blocks.14.attn.v.weight',
 'blocks.14.attn.v.bias',
 'blocks.14.attn.proj.weight',
 'blocks.14.attn.proj.bias',
 'blocks.14.attn.pool_q.weight',
 'blocks.14.attn.norm_q.weight',
 'blocks.14.attn.norm_q.bias',
 'blocks.14.attn.pool_k.weight',
 'blocks.14.attn.norm_k.weight',
 'blocks.14.attn.norm_k.bias',
 'blocks.14.attn.pool_v.weight',
 'blocks.14.attn.norm_v.weight',
 'blocks.14.attn.norm_v.bias',
 'blocks.14.norm2.weight',
 'blocks.14.norm2.bias',
 'blocks.14.mlp.fc1.weight',
 'blocks.14.mlp.fc1.bias',
 'blocks.14.mlp.fc2.weight',
 'blocks.14.mlp.fc2.bias',
 'blocks.15.norm1.weight',
 'blocks.15.norm1.bias',
 'blocks.15.attn.q.weight',
 'blocks.15.attn.q.bias',
 'blocks.15.attn.k.weight',
 'blocks.15.attn.k.bias',
 'blocks.15.attn.v.weight',
 'blocks.15.attn.v.bias',
 'blocks.15.attn.proj.weight',
 'blocks.15.attn.proj.bias',
 'blocks.15.attn.pool_k.weight',
 'blocks.15.attn.norm_k.weight',
 'blocks.15.attn.norm_k.bias',
 'blocks.15.attn.pool_v.weight',
 'blocks.15.attn.norm_v.weight',
 'blocks.15.attn.norm_v.bias',
 'blocks.15.norm2.weight',
 'blocks.15.norm2.bias',
 'blocks.15.mlp.fc1.weight',
 'blocks.15.mlp.fc1.bias',
 'blocks.15.mlp.fc2.weight',
 'blocks.15.mlp.fc2.bias',
 'norm_embed.weight',
 'norm_embed.bias',
 'head.proj.weight',
 'head.proj.bias']
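
A quick diagnostic sketch (reusing the cache path above): when every single model key is reported missing, the loaded dict is probably not a flat state dict at all, and printing its top-level keys reveals its real layout.

import torch

path = "/root/.cache/torch/hub/checkpoints/MVIT_B_16x4.pyth"
ckpt = torch.load(path)
# A flat state dict would print parameter names like "blocks.0.attn.q.weight";
# a wrapped training checkpoint prints wrapper keys instead.
print(list(ckpt.keys()))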

@aiot-tech I'm not a maintainer of PyTorchVideo, but I have already spoken with @lyttonhao about the issue and it's on his radar to fix. As it happens, I'm working on adding this specific model to TorchVision at the same time, and the above workaround works for me: I managed to "translate" the weights in question in pytorch/vision#6179 without a problem. This might indicate that you are using a different version of PyTorchVideo (I'm currently on the main branch).
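
The actual conversion lives in pytorch/vision#6179; purely as an illustration of the kind of "translation" involved, a hypothetical renaming between the two layouts visible in this thread (the flat pool_k/norm_k names in the checkpoint versus the nested _attention_pool_k.pool/_attention_pool_k.norm names the newer model expects) could look like:

import re

def translate_key(key):
    # Illustrative only: rewrite the flat pooling parameter names into the
    # nested module names, e.g.
    #   blocks.0.attn.pool_k.weight -> blocks.0.attn._attention_pool_k.pool.weight
    #   blocks.1.attn.norm_q.bias   -> blocks.1.attn._attention_pool_q.norm.bias
    return re.sub(r"attn\.(pool|norm)_([qkv])\.",
                  r"attn._attention_pool_\2.\1.", key)

assert translate_key("blocks.0.attn.pool_k.weight") == "blocks.0.attn._attention_pool_k.pool.weight"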

commented

Thanks for the reply. The version of pytorchvideo I'm using is 0.1.5.

@aiot-tech I just noticed that you are not loading the weights properly. The issue is that the checkpoint contains not only the model weights but also other training-specific info, with the actual state dict nested under the "model_state" key. Try doing:

model.load_state_dict(torch.load(path)["model_state"], strict=False)
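
For completeness, a self-contained version of that fix, reusing the cache path from above (strict=False is still needed on versions where the model defines optional pooling keys the checkpoint lacks):

import torch
from pytorchvideo.models.hub import mvit_base_16x4

path = "/root/.cache/torch/hub/checkpoints/MVIT_B_16x4.pyth"
model = mvit_base_16x4(pretrained=False)
# The checkpoint bundles training metadata; the weights live under "model_state".
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state"], strict=False)
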
commented

Brilliant!