Issues with eval.py and demo.py / Missing key(s) in state_dict
kame-hameha opened this issue · comments
I ran into an issue with the eval.py and demo.py scripts: loading the weights fails because of missing keys in the state_dict:
##########
Traceback (most recent call last):
  File "demo.py", line 55, in <module>
    demo()
  File "demo.py", line 43, in demo
    model = get_fast_scnn(args.dataset, pretrained=True, root=args.weights_folder, map_cpu=args.cpu).to(device)
  File "/mnt/git/Fast-SCNN-pytorch/models/fast_scnn.py", line 251, in get_fast_scnn
    model.load_state_dict(torch.load(os.path.join(root, 'fast_scnn_%s.pth' % acronyms[dataset])))
  File "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for FastSCNN:
Missing key(s) in state_dict: "learning_to_downsample.conv.conv.0.weight",
...
"module.classifier.conv.1.bias".
##########
I was able to run the training with:
python train.py --model fast_scnn --dataset citys --batch-size 60 --epochs 80
I am using the following nvidia-docker image:
docker pull anibali/pytorch:cuda-10.0
--> CUDA10.0
--> PyTorch 1.0.0
--> Python 3.6.5 :: Anaconda, Inc.
--> Torchvision 0.2.1
--> Using 2 Titan RTX GPUs
--> nvidia-smi: NVIDIA-SMI 410.104 Driver Version: 410.104 CUDA Version: 10.0
Thank you in advance!
Hi @kame-hameha,
I faced the same issue and solved it based on a PyTorch discussion post.
In short, a model trained with nn.DataParallel has state names with the prefix "module.", whereas loading this model without nn.DataParallel expects state names without the "module." prefix.
As a result, missing-key and unexpected-key errors occur.
The following code snippet is what I used to fix this issue.
Hope it helps you :-)
#
# models/fast_scnn.py:get_fast_scnn()
#
import os
from collections import OrderedDict

import torch


def get_fast_scnn(model_file, dataset='citys', pretrained=False, root='./weights', map_cpu=False, **kwargs):
    from data_loader import datasets

    def convert_state_dict(data):
        """Remove the 'module.' prefix that nn.DataParallel adds to each key."""
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in data.items():
            # Strip the prefix only when present, so unwrapped checkpoints still load.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        return new_state_dict

    # Kept from the original function; unused here because model_file is passed in directly.
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    model = FastSCNN(datasets[dataset].NUM_CLASS, **kwargs)
    if pretrained:
        # Map the tensors to the CPU when requested; otherwise keep the saved device.
        map_location = 'cpu' if map_cpu else None
        state_dict = torch.load(os.path.join(root, model_file), map_location=map_location)
        model.load_state_dict(convert_state_dict(state_dict))
    return model
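To make the mismatch described above concrete, here is a minimal sketch that compares key names only. It is a simplified illustration, not PyTorch's actual load_state_dict logic: diff_keys is a hypothetical helper, and the two key names are taken from the traceback in this issue (the unprefixed classifier key is assumed to be the prefixed one minus "module.").

```python
def diff_keys(expected_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, preserving input order."""
    expected_set, found_set = set(expected_keys), set(checkpoint_keys)
    missing = [k for k in expected_keys if k not in found_set]
    unexpected = [k for k in checkpoint_keys if k not in expected_set]
    return missing, unexpected


# Keys a plain (unwrapped) model would expect.
expected = [
    "learning_to_downsample.conv.conv.0.weight",
    "classifier.conv.1.bias",
]
# The same parameters as named in a checkpoint saved from an nn.DataParallel wrapper.
checkpoint = ["module." + k for k in expected]

missing, unexpected = diff_keys(expected, checkpoint)
print("Missing:", missing)
print("Unexpected:", unexpected)

# Once the prefix is stripped, both sides agree and nothing is reported.
prefix = "module."
stripped = [k[len(prefix):] if k.startswith(prefix) else k for k in checkpoint]
print(diff_keys(expected, stripped))
```

Every expected key shows up as missing and every prefixed key as unexpected, which matches the shape of the error message in the original report.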
@priancho @kame-hameha There is an easier way to solve it: when you train the model, change how the model is saved in train.py, from
torch.save(model.state_dict(), save_path)
to
torch.save(model.module.state_dict(), save_path)
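Note that model.module only exists when the model is wrapped, so the changed line would fail on a single-GPU or CPU run. A minimal sketch of a save step that handles both cases follows, under a few assumptions: unwrap is a hypothetical helper (not part of this repository), nn.Linear stands in for FastSCNN, and an in-memory buffer stands in for save_path.

```python
import io

import torch
from torch import nn


def unwrap(model):
    """Return the underlying module if `model` is wrapped for multi-GPU training."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


# Tiny stand-in network; the real model would be FastSCNN.
model = nn.DataParallel(nn.Linear(4, 2))

# An in-memory buffer replaces save_path so the sketch touches no files.
buffer = io.BytesIO()
torch.save(unwrap(model).state_dict(), buffer)
buffer.seek(0)

# The saved keys carry no "module." prefix, so an unwrapped model can load them.
saved_keys = list(torch.load(buffer, map_location="cpu").keys())
print(saved_keys)
```

Checkpoints that were already saved with the prefix still need the load-side conversion shown earlier in this thread.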