Error when loading weights for large model
mrlzla opened this issue
Code:
net_large.load_state_dict(torch.load('pretrained/mobilenetv3-large-657e7b3d.pth'))
Error:
RuntimeError: Error(s) in loading state_dict for MobileNetV3:
Missing key(s) in state_dict: "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias".
Unexpected key(s) in state_dict: "classifier.1.weight", "classifier.1.bias", "classifier.5.weight", "classifier.5.bias".
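One likely reason for this error is that `nn.Sequential` names its children by position, and that index counts every submodule, including parameterless ones such as activations and dropout. Adding, removing, or reordering such a layer shifts the numeric keys even when the Linear weights themselves are identical. The sketch below reproduces the pattern with two small heads. Both layouts are hypothetical, because the thread does not show the checkpoint's actual classifier, and the sizes 16/8/4 are arbitrary.

```python
import torch.nn as nn

# Hypothetical "checkpoint-era" head: the layout is made up so that the
# two Linear layers land at indices 1 and 5, as in the error message.
old_head = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(16, 8),
    nn.Hardswish(),
    nn.Dropout(0.2),
    nn.Identity(),
    nn.Linear(8, 4),
)

# Hypothetical "current" head: the Linear layers sit at indices 0 and 3.
new_head = nn.Sequential(
    nn.Linear(16, 8),
    nn.Hardswish(),
    nn.Dropout(0.2),
    nn.Linear(8, 4),
)

# Parameterless layers own no keys, but they still consume an index.
old_keys = list(old_head.state_dict().keys())
new_keys = list(new_head.state_dict().keys())
print(old_keys)
print(new_keys)

# Loading one into the other fails on key names alone, even though the shapes match.
try:
    new_head.load_state_dict(old_head.state_dict())
    load_error = None
except RuntimeError as err:
    load_error = str(err)
print(load_error)
```

The printed key lists show the same two Linear layers under different numeric prefixes, which is exactly the Missing/Unexpected pair reported above.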
Just rename the classifier keys in the checkpoint to the names the model expects:
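import torch
from collections import OrderedDict

# Checkpoint key -> key expected by the current MobileNetV3 definition
key_map = {
    "classifier.1.weight": "classifier.0.weight",
    "classifier.1.bias": "classifier.0.bias",
    "classifier.5.weight": "classifier.3.weight",
    "classifier.5.bias": "classifier.3.bias",
}

state_dict = torch.load("mobilenetv3-large-657e7b3d.pth")
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Rename the mismatched classifier keys; all other keys pass through unchanged
    new_state_dict[key_map.get(k, k)] = v

# net_large is the MobileNetV3 (large) instance from the original issue
net_large.load_state_dict(new_state_dict)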
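Hard-coding the four names works for this checkpoint. If the indices shift again, a hypothetical helper like `remap_by_order` below could pair the checkpoint's classifier keys with the model's in registration order, refusing to proceed if any shapes disagree. It assumes the two classifiers hold the same tensors in the same order and differ only in their Sequential indices. It is a sketch demonstrated on toy modules (`OldNet`, `NewNet` are made up), not something the repository provides.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def remap_by_order(state_dict, model, prefix="classifier."):
    """Hypothetical helper: rename checkpoint keys under `prefix` to the model's
    keys under the same prefix, pairing them by position.

    Assumes both sides hold the same tensors in the same order.
    """
    model_sd = model.state_dict()
    expected = [k for k in model_sd if k.startswith(prefix)]
    found = [k for k in state_dict if k.startswith(prefix)]
    if len(expected) != len(found):
        raise ValueError(
            f"{len(found)} checkpoint keys vs {len(expected)} model keys under {prefix!r}"
        )

    key_map = {}
    for old, new in zip(found, expected):
        # Refuse to pair tensors whose shapes differ: the layouts really differ then.
        if state_dict[old].shape != model_sd[new].shape:
            raise ValueError(f"shape mismatch: {old} vs {new}")
        key_map[old] = new

    # Keys outside `prefix` (the backbone) pass through unchanged.
    return OrderedDict((key_map.get(k, k), v) for k, v in state_dict.items())


# Toy stand-ins for the checkpoint layout and the current model layout.
class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 16)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2), nn.Linear(16, 8), nn.Hardswish(),
            nn.Dropout(0.2), nn.Identity(), nn.Linear(8, 4),
        )


class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 16)
        self.classifier = nn.Sequential(
            nn.Linear(16, 8), nn.Hardswish(), nn.Dropout(0.2), nn.Linear(8, 4),
        )


checkpoint = OldNet().state_dict()
model = NewNet()
model.load_state_dict(remap_by_order(checkpoint, model))  # strict=True by default
print(torch.equal(model.classifier[3].weight, checkpoint["classifier.5.weight"]))
```

Against the real checkpoint this would presumably be called as `net_large.load_state_dict(remap_by_order(torch.load("mobilenetv3-large-657e7b3d.pth"), net_large))`; that call is untested here. Keeping `strict=True` means any key the helper failed to map still raises instead of being silently skipped.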
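If you only need the convolutional part (e.g. as a feature extractor), you can load with strict=False, which skips the mismatched classifier keys; the classifier layers then keep their freshly initialized values: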
net_large.load_state_dict(torch.load('pretrained/mobilenetv3-large-657e7b3d.pth'), strict=False)
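With `strict=False`, `load_state_dict` does not raise on mismatched names; it returns a named tuple with `missing_keys` and `unexpected_keys`. Checking that result confirms only the classifier was skipped, not part of the backbone. Note that shape mismatches on matching key names still raise even with `strict=False`. The sketch below uses made-up toy modules in place of `net_large` and the real checkpoint.

```python
import torch
import torch.nn as nn


# Toy stand-ins: identical backbone, classifier Linear layers at different indices.
class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 16)
        self.classifier = nn.Sequential(
            nn.Dropout(0.2), nn.Linear(16, 8), nn.Hardswish(),
            nn.Dropout(0.2), nn.Identity(), nn.Linear(8, 4),
        )


class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 16)
        self.classifier = nn.Sequential(
            nn.Linear(16, 8), nn.Hardswish(), nn.Dropout(0.2), nn.Linear(8, 4),
        )


checkpoint = OldNet().state_dict()
model = NewNet()

result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# Only classifier keys should be reported; anything else means part of the
# backbone did not load either.
skipped = result.missing_keys + result.unexpected_keys
backbone_ok = all(k.startswith("classifier.") for k in skipped)
print("backbone loaded:", backbone_ok)
print(torch.equal(model.features.weight, checkpoint["features.weight"]))
```

The order of the keys inside each list depends on how PyTorch walks the module tree, so compare them as sets rather than relying on a particular order.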