akamaster / pytorch_resnet_cifar10

Proper implementation of ResNet-s for CIFAR10/100 in PyTorch that matches the description in the original paper.

Issues loading pre-trained model

DeeperCS opened this issue

Thank you for providing these very useful pre-trained models. However, I ran into trouble when loading them. Here is what I did:

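import torch
from resnet import resnet20  # assumes the repository's resnet.py is importable

res20 = resnet20()
weights = torch.load('pytorch_resnet_cifar10/pretrained_models/resnet20.th')
res20.load_state_dict(weights)  # fails: the keys do not match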

It fails because the keys do not match: for example, the constructed model has "conv1.weight", while the pre-trained weights have "module.conv1.weight".
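A minimal sketch of where the "module." prefix comes from, using a toy torch.nn.Linear as a stand-in for resnet20() (the toy layer is an assumption for illustration only). torch.nn.DataParallel keeps the wrapped network in its `module` attribute, so every key in its state_dict gains that prefix, and a strict load into an unwrapped model raises an error:

```python
import torch

plain = torch.nn.Linear(4, 2)
wrapped = torch.nn.DataParallel(torch.nn.Linear(4, 2))

print(list(plain.state_dict()))    # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

# load_state_dict is strict by default, so the mismatched names raise RuntimeError
failed = False
try:
    plain.load_state_dict(wrapped.state_dict())
except RuntimeError:
    failed = True
print(failed)  # True
```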

Would it be possible to provide example code for loading the pre-trained models? Or how else can I solve this problem? Thanks.

I just solved it; the following code works for me:

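import torch
from resnet import resnet20  # assumes the repository's resnet.py is importable

# The weights were saved from a DataParallel-wrapped model, so wrap it the same way
model = torch.nn.DataParallel(resnet20())
model.cuda()

# The .th file is a dict; the weights themselves are under the 'state_dict' key
checkpoint = torch.load('pytorch_resnet_cifar10/pretrained_models/resnet20.th')
model.load_state_dict(checkpoint['state_dict'])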
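If you want a plain (unwrapped) model, or have no GPU, one possible approach is to strip the "module." prefix yourself and load with `map_location`. This is a sketch under assumptions, not the repository's own loader: the helper name `load_without_dataparallel` is hypothetical, and a toy torch.nn.Linear with a hand-built checkpoint stands in for resnet20() and the real .th file.

```python
import os
import tempfile

import torch


def load_without_dataparallel(model, path, device="cpu"):
    """Load a checkpoint saved from a DataParallel-wrapped model into a plain model."""
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    checkpoint = torch.load(path, map_location=device)
    # Assumption: weights sit under 'state_dict' as in the snippet above;
    # otherwise treat the whole file as the state_dict
    state_dict = checkpoint.get("state_dict", checkpoint)
    prefix = "module."  # the prefix torch.nn.DataParallel adds to every key
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    return model.to(device)


# Self-check: build a checkpoint shaped like a DataParallel save, then reload it
source = torch.nn.Linear(4, 2)
wrapped_state = {"module." + k: v for k, v in source.state_dict().items()}
path = os.path.join(tempfile.mkdtemp(), "toy.th")
torch.save({"state_dict": wrapped_state}, path)

restored = load_without_dataparallel(torch.nn.Linear(4, 2), path)
print(torch.equal(restored.weight, source.weight))  # True
```

With the repository's files, the call would presumably be `load_without_dataparallel(resnet20(), 'pytorch_resnet_cifar10/pretrained_models/resnet20.th')`, though that exact call is untested here.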