How to load a pretrained model?
kamadforge opened this issue:
It seems that although trainer.py has a flag for a pretrained model, the flag is never used to load the model, so training proceeds from scratch.
Note: I ended up loading it using the load checkpoint function.
True, this needs to be corrected. Are you willing to make a contribution?
Hi, I solved this problem. Hope this helps you.
import torch
import dill  # needed to pickle the lambda inside resnet.py's LambdaLayer; the standard pickle cannot
from resnet import resnet20

# GPUs to use
device_ids = [0, 1]
# build the network architecture that corresponds to the checkpoint
model = resnet20()
# set map_location so the checkpoint tensors land on the first GPU in device_ids
check_point = torch.load('resnet20-12fca82f.th', map_location='cuda:%d' % device_ids[0])
# the checkpoint was saved from a DataParallel-wrapped model (its keys start with "module."),
# so wrap the model the same way before loading the state dict
model = torch.nn.DataParallel(model, device_ids=device_ids)
model.load_state_dict(check_point['state_dict'])
# save model.module, not model: otherwise the saved object is the DataParallel wrapper,
# and everything you load from it has to go through an extra .module
torch.save(model.module, 'resnet20_check_point.pth', pickle_module=dill)
# load the converted pretrained model (on PyTorch >= 2.6 add weights_only=False here,
# because torch.load now defaults to weights_only=True and refuses full pickled modules)
net = torch.load('resnet20_check_point.pth', map_location='cuda:%d' % device_ids[0])
net.eval()  # put BatchNorm layers in inference mode
x = torch.rand(size=(1, 3, 32, 32)).cuda(device_ids[0])
with torch.no_grad():
    out = net(x)
print(out)
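If you would rather skip DataParallel and dill entirely, a common alternative is to strip the "module." prefix that DataParallel adds to every state-dict key and load the weights into a plain model. The helper below is a minimal sketch; strip_module_prefix is a hypothetical name, not part of this repo or of PyTorch, and it assumes the checkpoint's state_dict keys look like "module.conv1.weight".

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed from each key.

    Keys without the prefix are kept unchanged, so the helper is also safe to call
    on checkpoints that were not saved from a DataParallel-wrapped model.
    The input mapping is not modified.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Hypothetical usage with the resnet20 checkpoint from this thread (not executed here):
#   model = resnet20()
#   ckpt = torch.load('resnet20-12fca82f.th', map_location='cpu')
#   model.load_state_dict(strip_module_prefix(ckpt['state_dict']))
```

Because the result is a plain state dict, the model can then be saved with torch.save(model.state_dict(), ...) and no custom pickle module is needed.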
A year has passed... I hope you have solved this problem by now. If you haven't, give my solution a try. Good luck!
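import torch
from resnet import resnet56
# the checkpoint was saved from a DataParallel-wrapped model, so wrap it the same way (as trainer.py does)
model = torch.nn.DataParallel(resnet56())
model = model.cuda()
param = torch.load("pretrained_models/resnet56-4bfd9763.th")
model.load_state_dict(param['state_dict'])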
Thanks, this also works.
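The shorter snippet only loads cleanly when model is wrapped the same way as the model that saved the checkpoint; otherwise load_state_dict raises an error listing missing and unexpected keys (with strict=False it instead returns them as missing_keys and unexpected_keys). The pure-Python sketch below, built around a hypothetical diagnose_prefix helper, compares the two key lists and guesses whether the "module." prefix explains the mismatch. It assumes you pass in checkpoint['state_dict'].keys() and model.state_dict().keys().

```python
def diagnose_prefix(checkpoint_keys, model_keys, prefix="module."):
    """Compare checkpoint and model state-dict keys.

    Returns (missing, unexpected, verdict):
      missing    -- model keys absent from the checkpoint (sorted)
      unexpected -- checkpoint keys the model does not have (sorted)
      verdict    -- "match", "checkpoint_prefixed", "model_prefixed" or "other"
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    missing = sorted(model - ckpt)
    unexpected = sorted(ckpt - model)
    if not missing and not unexpected:
        verdict = "match"
    elif {prefix + k for k in model} == ckpt:
        # checkpoint saved from DataParallel, model is a bare network
        verdict = "checkpoint_prefixed"
    elif {prefix + k for k in ckpt} == model:
        # model is wrapped in DataParallel, checkpoint was saved from a bare network
        verdict = "model_prefixed"
    else:
        verdict = "other"
    return missing, unexpected, verdict


# Hypothetical usage (not executed here):
#   ckpt = torch.load("pretrained_models/resnet56-4bfd9763.th", map_location="cpu")
#   print(diagnose_prefix(ckpt["state_dict"].keys(), model.state_dict().keys()))
```

A "checkpoint_prefixed" verdict means wrapping the model in torch.nn.DataParallel (or stripping the prefix) should fix the load; "other" points to a genuine architecture mismatch, such as loading resnet56 weights into resnet20.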