Tramac / Fast-SCNN-pytorch

A PyTorch Implementation of Fast-SCNN: Fast Semantic Segmentation Network

Issues with eval.py and demo.py / Missing key(s) in state_dict

kame-hameha opened this issue

  1. The eval.py and demo.py scripts fail because keys are missing in the state_dict:
    ##########
    Traceback (most recent call last):
      File "demo.py", line 55, in <module>
        demo()
      File "demo.py", line 43, in demo
        model = get_fast_scnn(args.dataset, pretrained=True, root=args.weights_folder, map_cpu=args.cpu).to(device)
      File "/mnt/git/Fast-SCNN-pytorch/models/fast_scnn.py", line 251, in get_fast_scnn
        model.load_state_dict(torch.load(os.path.join(root, 'fast_scnn_%s.pth' % acronyms[dataset])))
      File "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for FastSCNN:
        Missing key(s) in state_dict: "learning_to_downsample.conv.conv.0.weight",
    ...
    "module.classifier.conv.1.bias".
    ##########

  2. Training itself ran without problems via:
    python train.py --model fast_scnn --dataset citys --batch-size 60 --epochs 80

  3. I am using the following nvidia-docker image:
    docker pull anibali/pytorch:cuda-10.0
    --> CUDA10.0
    --> PyTorch 1.0.0
    --> Python 3.6.5 :: Anaconda, Inc.
    --> Torchvision 0.2.1
    --> Using 2 Titan RTX GPUs
    --> nvidia-smi: NVIDIA-SMI 410.104 Driver Version: 410.104 CUDA Version: 10.0

Thank you in advance!

Hi @kame-hameha ,

I faced the same issue and solved it based on a PyTorch discussion post.

In short, a model trained with nn.DataParallel has state_dict keys with the prefix module., whereas a model loaded without nn.DataParallel expects keys without that prefix.
As a result, load_state_dict reports both missing keys and unexpected keys.
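
To make the mismatch concrete, here is a minimal sketch that compares key names only, with no PyTorch needed. The helper `compare_keys` is hypothetical (it is not part of PyTorch or this repo) and only imitates the set comparison a strict `load_state_dict` call performs. The two key names come from the traceback above, and the checkpoint side assumes the weights were saved from an nn.DataParallel-wrapped model.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Hypothetical helper: imitate the key check of a strict state_dict load."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    # Keys the model expects but the checkpoint does not contain.
    missing = sorted(model_keys - checkpoint_keys)
    # Keys the checkpoint contains but the model does not expect.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Names the plain (unwrapped) FastSCNN expects, taken from the traceback.
model_keys = [
    "learning_to_downsample.conv.conv.0.weight",
    "classifier.conv.1.bias",
]
# Assumption: the checkpoint was saved from a DataParallel-wrapped model,
# so every key carries the 'module.' prefix.
checkpoint_keys = ["module." + k for k in model_keys]

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print("Missing key(s):", missing)
print("Unexpected key(s):", unexpected)
```

Every expected key shows up as missing and every prefixed key as unexpected, which matches the shape of the error message in the original report.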

The following code snippet is what I used to fix this issue.
Hope it helps you :-)

#
# models/fast_scnn.py:get_fast_scnn()
#
def get_fast_scnn(model_file, dataset='citys', pretrained=False, root='./weights', map_cpu=False, **kwargs):
    from data_loader import datasets
    from collections import OrderedDict

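    def convert_state_dict(data):
        """Remove the 'module.' prefix that nn.DataParallel adds to each key."""
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in data.items():
            # Strip the prefix only when it is present, so checkpoints saved
            # without nn.DataParallel still load correctly.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        return new_state_dict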

    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }

    model = FastSCNN(datasets[dataset].NUM_CLASS, **kwargs)
    if pretrained:
        # Load onto the CPU when requested; None keeps torch.load's default behavior.
        map_location = 'cpu' if map_cpu else None
        state_dict = torch.load(os.path.join(root, model_file), map_location=map_location)
        model.load_state_dict(convert_state_dict(state_dict))

    return model

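@priancho @kame-hameha There is an easier way to solve it: when you train the model, change how the model is saved in train.py, from
torch.save(model.state_dict(), save_path) to torch.save(model.module.state_dict(), save_path)
This works because model.module is the underlying network inside the nn.DataParallel wrapper, so its state_dict keys have no module. prefix.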
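
Note that `model.module` only exists when the model is actually wrapped, so a small guard keeps the save code working for single-GPU or CPU runs too. The following is a rough sketch, not the repo's code: `unwrap` is a hypothetical helper, a tiny `nn.Linear` stands in for FastSCNN, and an in-memory buffer stands in for `save_path`.

```python
import io

import torch
import torch.nn as nn


def unwrap(model):
    """Hypothetical helper: return the inner module if `model` is wrapped."""
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, wrappers) else model


# Stand-in for FastSCNN so the sketch also runs on a CPU-only machine.
net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

# The wrapper's keys carry the 'module.' prefix; the inner module's do not.
wrapped_keys = list(wrapped.state_dict().keys())
plain_keys = list(unwrap(wrapped).state_dict().keys())

# Save the unwrapped weights (the buffer stands in for save_path).
buffer = io.BytesIO()
torch.save(unwrap(wrapped).state_dict(), buffer)
buffer.seek(0)

# A plain, unwrapped model can now load the checkpoint without key errors.
fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load(buffer))
```

With this guard, checkpoints are always written without the prefix, so eval.py and demo.py can load them regardless of how many GPUs were used for training.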