KexianHust / ViTA

ViTA: Video Transformer Adaptor for Robust Video Depth Estimation


With the hybrid model I get an incorrect state dictionary

magicse opened this issue

With the hybrid model I get an incorrect state dictionary when running:
demo.py --model_type dpt_hybrid --attn_interval=2

Python38\lib\site-packages\torch\nn\modules\module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.pretrained.model.blocks.2.attn.relative_position_bias_table", "module.pretrained.model.blocks.2.attn.relative_position_index", "module.pretrained.model.blocks.4.attn.relative_position_bias_table", "module.pretrained.model.blocks.4.attn.relative_position_index", "module.pretrained.model.blocks.8.attn.relative_position_bias_table", "module.pretrained.model.blocks.8.attn.relative_position_index", "module.pretrained.model.blocks.10.attn.relative_position_bias_table", "module.pretrained.model.blocks.10.attn.relative_position_index".
    Unexpected key(s) in state_dict: "module.pretrained.model.blocks.3.attn.relative_position_bias_table", "module.pretrained.model.blocks.3.attn.relative_position_index", "module.pretrained.model.blocks.9.attn.relative_position_bias_table", "module.pretrained.model.blocks.9.attn.relative_position_index".

All of the missing and unexpected keys relate to attn.relative_position_bias_table and attn.relative_position_index.

Comparing the model's state dict against the checkpoint:

import torch
from dpt.models import DPTDepthModel

attn_interval = 2

# Instantiate the model
model = DPTDepthModel(
    backbone="vitb_rn50_384",
    #backbone="vitl16_384",
    non_negative=True,
    enable_attention_hooks=False,
    attn_interval=attn_interval,
)

# Load the checkpoint file
checkpoint = torch.load('checkpoints/vita-hybrid.pth', map_location=torch.device('cpu'))
#checkpoint = torch.load('checkpoints/vita-large.pth', map_location=torch.device('cpu'))

# Retrieve model state dictionary keys
model_state_dict_keys = set(model.state_dict().keys())

# Strip the DataParallel "module." prefix from the checkpoint keys
checkpoint_keys = {key.replace("module.", "", 1) for key in checkpoint.keys()}

# Find missing and unexpected keys
missing_keys = model_state_dict_keys - checkpoint_keys
unexpected_keys = checkpoint_keys - model_state_dict_keys

# Print the missing and unexpected keys
print("Missed keys in the model's state_dict from checkpoint:")
for key in missing_keys:
    print(key)
print("\nUnexpected keys in checkpoint's state_dict for the model:")
for key in unexpected_keys:
    print(key)

Output

Missed keys in the model's state_dict from checkpoint:
pretrained.model.blocks.10.attn.relative_position_index
pretrained.model.blocks.2.attn.relative_position_index
pretrained.model.blocks.8.attn.relative_position_index
pretrained.model.blocks.10.attn.relative_position_bias_table
pretrained.model.blocks.8.attn.relative_position_bias_table
pretrained.model.blocks.4.attn.relative_position_bias_table
pretrained.model.blocks.4.attn.relative_position_index
pretrained.model.blocks.2.attn.relative_position_bias_table

Unexpected keys in checkpoint's state_dict for the model:
pretrained.model.blocks.3.attn.relative_position_index
pretrained.model.blocks.9.attn.relative_position_index
pretrained.model.blocks.3.attn.relative_position_bias_table
pretrained.model.blocks.9.attn.relative_position_bias_table
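As an aside, the "module." prefix that the diff script strips comes from the checkpoint having been saved while the model was wrapped in torch.nn.DataParallel, which prepends "module." to every parameter name. A minimal, ViTA-independent sketch of a helper that removes this prefix before calling load_state_dict (the function name and example keys are illustrative, not from the repo):

```python
def strip_dataparallel_prefix(state_dict):
    """Remove the leading 'module.' that nn.DataParallel adds to every key."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Example: keys as saved from a DataParallel-wrapped model
saved = {"module.pretrained.model.cls_token": 1, "scratch.weight": 2}
print(strip_dataparallel_prefix(saved))
# {'pretrained.model.cls_token': 1, 'scratch.weight': 2}
```

Alternatively, wrapping the freshly built model in nn.DataParallel before load_state_dict makes the prefixes match without rewriting the keys.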

We just released the 'attn_interval = 3' model for dpt-hybrid, so please set attn_interval = 3 when you use the hybrid one.
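For reference, the corrected invocation only changes the interval value from the command in the issue (the leading `python` is assumed; run from the repository root):

```shell
python demo.py --model_type dpt_hybrid --attn_interval=3
```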