With the hybrid model I get a mismatched state dictionary.
magicse opened this issue · comments
With the hybrid model I get a mismatched state dictionary.
demo.py --model_type dpt_hybrid --attn_interval=2
Python38\lib\site-packages\torch\nn\modules\module.py", line 1604, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.pretrained.model.blocks.2.attn.relative_position_bias_table", "module.pretrained.model.blocks.2.attn.relative_position_index", "module.pretrained.model.blocks.4.attn.relative_position_bias_table", "module.pretrained.model.blocks.4.attn.relative_position_index", "module.pretrained.model.blocks.8.attn.relative_position_bias_table", "module.pretrained.model.blocks.8.attn.relative_position_index", "module.pretrained.model.blocks.10.attn.relative_position_bias_table", "module.pretrained.model.blocks.10.attn.relative_position_index".
Unexpected key(s) in state_dict: "module.pretrained.model.blocks.3.attn.relative_position_bias_table", "module.pretrained.model.blocks.3.attn.relative_position_index", "module.pretrained.model.blocks.9.attn.relative_position_bias_table", "module.pretrained.model.blocks.9.attn.relative_position_index".
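All of the missing and unexpected keys are `attn.relative_position_bias_table` and `attn.relative_position_index` entries, but for different blocks: the model expects them in blocks 2, 4, 8 and 10, while the checkpoint provides them in blocks 3 and 9.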
To investigate, I compared the keys of the model's state dict with the keys stored in the checkpoint:
import torch
from dpt.models import DPTDepthModel
# Configuration. The released hybrid checkpoint (vita-hybrid.pth) was trained
# with attn_interval = 3, as stated by the maintainers on this issue. With
# attn_interval = 2 the model expects relative-position tables in blocks
# 2, 4, 8 and 10 while the checkpoint provides them in blocks 3 and 9, which is
# exactly the mismatch load_state_dict reported.
BACKBONE = "vitb_rn50_384"
CHECKPOINT_PATH = "checkpoints/vita-hybrid.pth"
ATTN_INTERVAL = 3
# To check the large model instead, use BACKBONE = "vitl16_384" and
# CHECKPOINT_PATH = "checkpoints/vita-large.pth", and set ATTN_INTERVAL to the
# value that checkpoint was trained with.

# Prefix that nn.DataParallel puts in front of every key of the wrapped module.
DATA_PARALLEL_PREFIX = "module."


def strip_prefix(key, prefix=DATA_PARALLEL_PREFIX):
    """Return *key* with a single leading *prefix* removed.

    Keys that do not start with *prefix* are returned unchanged. Unlike
    ``key.replace(prefix, "")``, occurrences of the prefix text elsewhere in
    the key are preserved. (``str.removeprefix`` does the same but needs
    Python 3.9+.)
    """
    return key[len(prefix):] if key.startswith(prefix) else key


def diff_state_dict_keys(model_keys, checkpoint_keys, prefix=DATA_PARALLEL_PREFIX):
    """Compare a model's state-dict keys with the keys stored in a checkpoint.

    Args:
        model_keys: Iterable of key names the model's ``state_dict()`` has.
        checkpoint_keys: Iterable of key names stored in the checkpoint.
        prefix: Leading text (by default the DataParallel ``module.`` prefix)
            stripped from both sides before comparing, so a wrapped and an
            unwrapped model can be compared directly.

    Returns:
        A ``(missing, unexpected)`` pair of sorted lists: keys the model has
        but the checkpoint lacks, and keys the checkpoint has but the model
        does not.
    """
    expected = {strip_prefix(key, prefix) for key in model_keys}
    provided = {strip_prefix(key, prefix) for key in checkpoint_keys}
    return sorted(expected - provided), sorted(provided - expected)


def main():
    """Build the model, load the checkpoint and print the key mismatches."""
    model = DPTDepthModel(
        backbone=BACKBONE,
        non_negative=True,
        enable_attention_hooks=False,
        attn_interval=ATTN_INTERVAL,
    )
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints obtained from a trusted source.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))

    missing_keys, unexpected_keys = diff_state_dict_keys(
        model.state_dict().keys(), checkpoint.keys()
    )

    print("Missed keys in the model's state_dict from checkpoint:")
    for key in missing_keys:
        print(key)
    print("\nUnexpected keys in checkpoint's state_dict for the model:")
    for key in unexpected_keys:
        print(key)


if __name__ == "__main__":
    main()
Output:
Missed keys in the model's state_dict from checkpoint:
pretrained.model.blocks.10.attn.relative_position_index
pretrained.model.blocks.2.attn.relative_position_index
pretrained.model.blocks.8.attn.relative_position_index
pretrained.model.blocks.10.attn.relative_position_bias_table
pretrained.model.blocks.8.attn.relative_position_bias_table
pretrained.model.blocks.4.attn.relative_position_bias_table
pretrained.model.blocks.4.attn.relative_position_index
pretrained.model.blocks.2.attn.relative_position_bias_table
Unexpected keys in checkpoint's state_dict for the model:
pretrained.model.blocks.3.attn.relative_position_index
pretrained.model.blocks.9.attn.relative_position_index
pretrained.model.blocks.3.attn.relative_position_bias_table
pretrained.model.blocks.9.attn.relative_position_bias_table
We just released the `attn_interval = 3` model for dpt-hybrid, so please set `attn_interval = 3` when you use the hybrid model.