sail-sg / volo

VOLO: Vision Outlooker for Visual Recognition



There is a problem when loading the pretrained weights

JonnesLin opened this issue


A problem occurred when I loaded the pretrained weights you provided.


UnpicklingError Traceback (most recent call last)
in
9 # as we interpolate the position embeding for different image size.
10 load_pretrained_weights(model, "/home/featurize/work/checkpoints/archive/data.pkl", use_ema=False,
---> 11 strict=False, num_classes=1000)

/cloud/volo/utils/utils.py in load_pretrained_weights(model, checkpoint_path, use_ema, strict, num_classes)
140 num_classes=1000):
141 '''load pretrained weight for VOLO models'''
--> 142 state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes)
143 model.load_state_dict(state_dict, strict=strict)
144

/cloud/volo/utils/utils.py in load_state_dict(checkpoint_path, model, use_ema, num_classes)
92 if checkpoint_path and os.path.isfile(checkpoint_path):
93 # checkpoint = torch.load(checkpoint_path, map_location='cpu')
---> 94 checkpoint = torch.load(checkpoint_path)
95 state_dict_key = 'state_dict'
96 if isinstance(checkpoint, dict):

/environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
591 return torch.jit.load(opened_file)
592 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
--> 593 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
594
595

/environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
760 "functionality.")
761
--> 762 magic_number = pickle_module.load(f, **pickle_load_args)
763 if magic_number != MAGIC_NUMBER:
764 raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.
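This error can be reproduced without PyTorch. The `archive/data.pkl` path suggests the checkpoint is a zip-format file written by torch.save (PyTorch 1.6 and later). Assuming that, the pickle inside it stores only references (persistent IDs) to tensor data kept in other members of the archive. Because a lone data.pkl is not a zip, torch.load falls back to `_legacy_load`, as the traceback shows, and plain unpickling hits the first reference. The stdlib sketch below uses hypothetical names (`Storage`, `RefPickler`, `RefUnpickler`, `dump_state`). It shows that such a pickle fails with `UnpicklingError` when loaded on its own, and loads once a `persistent_load` hook can resolve the references. That hook is roughly the job torch.load does when it is given the whole archive.

```python
import io
import pickle


class Storage:
    """Hypothetical stand-in for tensor data stored outside the pickle stream."""

    def __init__(self, key: str):
        self.key = key


class RefPickler(pickle.Pickler):
    # Write Storage objects as persistent IDs instead of pickling them inline.
    def persistent_id(self, obj):
        if isinstance(obj, Storage):
            return ("storage", obj.key)
        return None  # everything else is pickled normally


class RefUnpickler(pickle.Unpickler):
    def __init__(self, file, storages):
        super().__init__(file)
        # Hypothetical: data that would be read from the other archive members.
        self.storages = storages

    def persistent_load(self, pid):
        kind, key = pid
        if kind != "storage":
            raise pickle.UnpicklingError(f"unknown persistent id: {pid!r}")
        return self.storages[key]


def dump_state(state) -> bytes:
    buf = io.BytesIO()
    RefPickler(buf, protocol=2).dump(state)
    return buf.getvalue()


state = {"weight": Storage("0"), "num_classes": 1000}
payload = dump_state(state)  # plays the role of archive/data.pkl

# 1) Unpickling the payload alone: no persistent_load hook, so this raises.
try:
    pickle.loads(payload)
except pickle.UnpicklingError as exc:
    print("plain load failed:", exc)

# 2) With access to the external storages, the references resolve.
restored = RefUnpickler(io.BytesIO(payload), {"0": [0.1, 0.2]}).load()
print(restored)
```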

Hi, when you load the pretrained models, don't untar the downloaded file: torch.load() can load the '.pth.tar' file directly. Pass the path of the '.pth.tar' file itself to load_pretrained_weights().
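To catch this mistake before calling `load_pretrained_weights`, a quick stdlib check can confirm that the path points at the intact archive and not at an extracted member. This is a heuristic sketch. It assumes the checkpoint uses the zip format written by torch.save in PyTorch 1.6 and later, whose members include a `<prefix>/data.pkl` entry. The function name `looks_like_zip_checkpoint` is hypothetical.

```python
import zipfile


def looks_like_zip_checkpoint(path: str) -> bool:
    """Heuristic: True if `path` is a zip archive containing '<prefix>/data.pkl'.

    Assumption: zip-format torch checkpoints (PyTorch >= 1.6) have this layout,
    while an extracted data.pkl is a bare pickle and not a zip archive.
    """
    # is_zipfile returns False for missing or non-zip files instead of raising.
    if not zipfile.is_zipfile(path):
        return False
    with zipfile.ZipFile(path) as zf:
        return any(name.endswith("/data.pkl") for name in zf.namelist())
```

A path that passes this check is the one to hand to `load_pretrained_weights`. Note that older, legacy-format checkpoints are not zip files, so they return False here even though torch.load can still read them.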