akamaster / pytorch_resnet_cifar10

Proper implementation of ResNet-s for CIFAR10/100 in pytorch that matches the description in the original paper.

How to load a pretrained model?

kamadforge opened this issue · comments

It seems that although there is a flag for a pre-trained model in trainer.py, it is not used to load the model, so training proceeds from scratch.
Note: I ended up loading the model through the checkpoint-loading function instead.
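
For reference, here is a minimal sketch of what that checkpoint-based loading can look like. This is not the repo's exact code; it assumes the released .th files store the weights under a 'state_dict' key (the snippets below in this thread make the same assumption), and it wraps the model in DataParallel because, as noted below, the checkpoints were saved from a DataParallel model.

import torch
from resnet import resnet20

# build the architecture and wrap it so the 'module.'-prefixed keys in the checkpoint match
model = torch.nn.DataParallel(resnet20())

checkpoint = torch.load('pretrained_models/resnet20-12fca82f.th', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()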

True, this needs to be corrected. Are you willing to make a contribution?

Hi, I solved this problem. Hope this helps you.

import torch
from resnet import *
import dill  # dill is needed to pickle the Lambda Layer

# your devices
device_ids = [0, 1]

# the network architecture corresponding to the checkpoint
model = resnet20()

# remember to set map_location
check_point = torch.load('resnet20-12fca82f.th', map_location='cuda:%d' % device_ids[0])

# because the checkpoint was saved from a DataParallel model, we need to wrap ours the same way
model = torch.nn.DataParallel(model, device_ids=device_ids)
model.load_state_dict(check_point['state_dict'])

# pay attention to .module! Without it, the saved object keeps the DataParallel wrapper,
# which leads to trouble when you load it later
torch.save(model.module, 'resnet20_check_point.pth', pickle_module=dill)

# load the converted pretrained model
net = torch.load('resnet20_check_point.pth', map_location='cuda:%d' % device_ids[0])
x = torch.rand(size=(1, 3, 32, 32)).cuda(device_ids[0])
out = net(x)
print(out)

A year has passed... I hope you have solved this problem by now. If you haven't, give my solution a try. Good luck!
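
If you would rather not wrap the model in DataParallel or depend on dill, another common option is to strip the 'module.' prefix that DataParallel adds to the state_dict keys. This is only a sketch, under the same assumption that the checkpoint stores its weights under a 'state_dict' key:

import torch
from resnet import resnet20

check_point = torch.load('resnet20-12fca82f.th', map_location='cpu')

# drop the leading 'module.' that DataParallel added to every key
state_dict = {k.replace('module.', '', 1): v for k, v in check_point['state_dict'].items()}

model = resnet20()
model.load_state_dict(state_dict)
model.eval()

x = torch.rand(1, 3, 32, 32)
print(model(x).shape)  # expected: torch.Size([1, 10])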

commented

# shorter variant: load the checkpoint's state_dict straight into the model
# (the model's keys must match the checkpoint's, e.g. by wrapping it in DataParallel as above)
model = model.cuda()
param = torch.load("pretrained_models/resnet56-4bfd9763.th")
model.load_state_dict(param['state_dict'])

Thanks, this also works.
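
As a final sanity check that the weights really loaded, you could run the model over the CIFAR-10 test set and look at the accuracy. This is only a sketch: the normalization constants below are an assumption on my part, so match them to whatever trainer.py actually uses.

import torch
import torchvision
import torchvision.transforms as transforms
from resnet import resnet20

# rebuild the model from the checkpoint with the 'module.' prefix stripped
check_point = torch.load('resnet20-12fca82f.th', map_location='cpu')
state_dict = {k.replace('module.', '', 1): v for k, v in check_point['state_dict'].items()}
net = resnet20()
net.load_state_dict(state_dict)
net.eval()

# NOTE: assumed normalization values; use the same ones as trainer.py
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        preds = net(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print('CIFAR-10 test accuracy: %.2f%%' % (100.0 * correct / total))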