Tramac / Fast-SCNN-pytorch

A PyTorch Implementation of Fast-SCNN: Fast Semantic Segmentation Network

Issues with and / Missing key(s) in state_dict

kame-hameha opened this issue

  1. With the and scripts, loading the pretrained weights fails because keys are missing from the state_dict:
    Traceback (most recent call last):
      File "", line 55, in <module>
        demo()
      File "", line 43, in demo
        model = get_fast_scnn(args.dataset, pretrained=True, root=args.weights_folder, map_cpu=args.cpu).to(device)
      File "/mnt/git/Fast-SCNN-pytorch/models/", line 251, in get_fast_scnn
        model.load_state_dict(torch.load(os.path.join(root, 'fast_scnn_%s.pth' % acronyms[dataset])))
      File "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/", line 769, in load_state_dict
        "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for FastSCNN:
        Missing key(s) in state_dict: "learning_to_downsample.conv.conv.0.weight",

  2. Training itself works with:
    python --model fast_scnn --dataset citys --batch-size 60 --epochs 80

  3. I am using the following nvidia-docker image:
    docker pull anibali/pytorch:cuda-10.0
    --> CUDA 10.0
    --> PyTorch 1.0.0
    --> Python 3.6.5 :: Anaconda, Inc.
    --> Torchvision 0.2.1
    --> Using 2 Titan RTX GPUs
    --> nvidia-smi: NVIDIA-SMI 410.104 Driver Version: 410.104 CUDA Version: 10.0

Thank you in advance!

Hi @kame-hameha,

I faced the same issue and solved it based on the following PyTorch discussion post:

In short, a model trained with nn.DataParallel has state_dict keys prefixed with 'module.', whereas a model loaded without nn.DataParallel expects the same keys without the 'module.' prefix.
As a result, load_state_dict reports both missing keys and unexpected keys.
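
To see why both kinds of error appear together, here is a simplified pure-Python model of the key comparison that a strict load_state_dict call performs. It is a sketch, not PyTorch's actual implementation (which also checks tensor shapes), and the key names are hypothetical, modelled on the one in the traceback:

```python
# Simplified model of strict state_dict key matching; not PyTorch's real code.
# Key names are hypothetical, modelled on the one in the traceback.

# Keys the unwrapped model expects.
model_keys = {
    "learning_to_downsample.conv.conv.0.weight",
    "classifier.conv.weight",
}

# Keys in a checkpoint saved from an nn.DataParallel wrapper: each gains "module.".
checkpoint_keys = {"module." + k for k in model_keys}


def compare_keys(expected, provided):
    """Return (missing, unexpected) key lists, as a strict load would report them."""
    missing = sorted(expected - provided)     # wanted by the model, absent from the checkpoint
    unexpected = sorted(provided - expected)  # present in the checkpoint, unknown to the model
    return missing, unexpected


missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print("Missing:", missing)
print("Unexpected:", unexpected)

# Stripping the prefix (only where present) makes the key sets match again.
prefix = "module."
stripped = {k[len(prefix):] if k.startswith(prefix) else k for k in checkpoint_keys}
print(compare_keys(model_keys, stripped))  # ([], [])
```

The missing list holds every key the model expects, and the unexpected list holds the same keys with the 'module.' prefix, which is why the error names both.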

The following code snippet is what I used to fix this issue.
Hope it helps you :-)

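# models/
import os
from collections import OrderedDict

import torch


def get_fast_scnn(model_file, dataset='citys', pretrained=False, root='./weights', map_cpu=False, **kwargs):
    # FastSCNN is assumed to be defined earlier in this same models file.
    from data_loader import datasets

    def convert_state_dict(data):
        """Remove the 'module.' prefix that nn.DataParallel adds to every key."""
        new_state_dict = OrderedDict()
        for k, v in data.items():
            # Strip only when present, so checkpoints saved without DataParallel still load.
            name = k[len('module.'):] if k.startswith('module.') else k
            new_state_dict[name] = v
        return new_state_dict

    # Not used in this version, since the weights file name is passed in as model_file.
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }

    model = FastSCNN(datasets[dataset].NUM_CLASS, **kwargs)
    if pretrained:
        path = os.path.join(root, model_file)
        if map_cpu:
            # Load GPU-trained weights onto the CPU.
            state_dict = torch.load(path, map_location='cpu')
        else:
            state_dict = torch.load(path)
        model.load_state_dict(convert_state_dict(state_dict))
    return model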

@priancho @kame-hameha There is an easier way to solve this: when you train the model, change how the model is saved in :, save_path) into, save_path)
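
A common form of this save-side fix is to save the state_dict of the model wrapped inside nn.DataParallel (its .module attribute) instead of the wrapper's own state_dict, so the saved keys carry no 'module.' prefix. A minimal sketch under that assumption: save_checkpoint is a hypothetical helper rather than the repository's training code, a tiny nn.Sequential stands in for FastSCNN, and an in-memory buffer stands in for save_path:

```python
import io

import torch
import torch.nn as nn


def save_checkpoint(model, destination):
    """Hypothetical helper: save weights without the DataParallel 'module.' prefix."""
    # Unwrap nn.DataParallel so the state_dict keys match a plain model.
    if isinstance(model, nn.DataParallel):
        model = model.module
    torch.save(model.state_dict(), destination)


net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())  # stand-in for FastSCNN
wrapped = nn.DataParallel(net)  # as a multi-GPU training script would do

buffer = io.BytesIO()  # stands in for save_path
save_checkpoint(wrapped, buffer)
buffer.seek(0)

# The checkpoint now loads straight into an unwrapped model.
fresh = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
fresh.load_state_dict(torch.load(buffer, map_location="cpu"))
print(sorted(fresh.state_dict().keys()))  # ['0.bias', '0.weight']
```

This only affects checkpoints written after the change; weights already saved with the prefix still need the load-side conversion.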