Issues loading pre-trained models
DeeperCS opened this issue
Thank you for providing these very useful pre-trained models. However, I ran into some trouble when loading them. Here is what I did:
res20 = resnet20()
weights = torch.load('pytorch_resnet_cifar10/pretrained_models/resnet20.th')
res20.load_state_dict(weights)
It fails because the keys do not match: for example, the constructed model has "conv1.weight" while the pre-trained weights have "module.conv1.weight".
Would it be possible to provide example code for loading a pre-trained model, or could you tell me how to solve this problem? Thanks.
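The "module." prefix is what you get when a model is wrapped in torch.nn.DataParallel before its state_dict is saved: the wrapper keeps the original network in an attribute called `module`, so every parameter name is nested one level deeper. A minimal illustration, using a small nn.Linear as a stand-in for resnet20():

```python
import torch.nn as nn

# A plain module: its state_dict keys carry no prefix.
plain = nn.Linear(4, 2)
print(list(plain.state_dict().keys()))  # ['weight', 'bias']

# DataParallel stores the wrapped network as `self.module`,
# so every key in its state_dict gains a "module." prefix.
wrapped = nn.DataParallel(nn.Linear(4, 2))
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```

So the checkpoint keys only line up with a model that is also wrapped in DataParallel, or with a plain model after the prefix has been removed.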
I just solved it; the following code works for me:
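import torch
from resnet import resnet20  # assumes the repository's resnet.py is importable
# Wrap in DataParallel so parameter names carry the "module." prefix used in the checkpoint.
model = torch.nn.DataParallel(resnet20())
model.cuda()
checkpoint = torch.load('pytorch_resnet_cifar10/pretrained_models/resnet20.th')
# The .th file is a dict; the weights themselves are stored under 'state_dict'.
model.load_state_dict(checkpoint['state_dict'])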
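If you would rather load into a plain, unwrapped model (for example on a CPU-only machine), one option is to strip the "module." prefix from the keys and pass map_location="cpu" to torch.load. The sketch below assumes the checkpoint is a dict with the weights under 'state_dict', as in the snippet above; `strip_prefix` and `load_plain` are hypothetical helper names, and an nn.Linear stands in for resnet20() so the example runs on its own. Recent PyTorch versions also provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which does the same prefix removal in place.

```python
import io

import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from every key that starts with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_plain(model, path_or_buffer):
    """Load a {'state_dict': ...} checkpoint saved from a DataParallel model into an unwrapped model on the CPU."""
    # map_location="cpu" remaps tensors that were saved on a GPU.
    checkpoint = torch.load(path_or_buffer, map_location="cpu")
    model.load_state_dict(strip_prefix(checkpoint["state_dict"]))
    return model


# Self-contained demo: save a DataParallel-style checkpoint to memory, then reload it.
source = nn.DataParallel(nn.Linear(4, 2))
buffer = io.BytesIO()
torch.save({"state_dict": source.state_dict()}, buffer)
buffer.seek(0)

target = load_plain(nn.Linear(4, 2), buffer)
print(torch.equal(target.weight, source.module.weight.cpu()))  # True
```

With the repository's model this would presumably become load_plain(resnet20(), 'pytorch_resnet_cifar10/pretrained_models/resnet20.th'), keeping load_state_dict's default strict=True so any remaining key mismatch is still reported.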