jcjohnson / pytorch-vgg

vgg model checkpoint needs a change of classifier weight names

soumith opened this issue

See https://discuss.pytorch.org/t/upgrading-torchvision-module-makes-old-model-useless/1719

This changed because of pytorch/vision#107, where Sam realized that he had put dropout in the wrong location.

So the state_dict needs the names changed appropriately.
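To make the index shift concrete, here is a small sketch of how moving the dropout layers renumbers the Linear layers inside `classifier`. The two layouts below are an assumption inferred from the key renames discussed in this thread (1 to 0, 4 to 3), not copied from the torchvision source, and `linear_positions` is a hypothetical helper name.

```python
# Assumed classifier layouts before and after the dropout fix.
# These lists are inferred from the renames in this thread, not from torchvision itself.
OLD = ["Dropout", "Linear", "ReLU", "Dropout", "Linear", "ReLU", "Linear"]
NEW = ["Linear", "ReLU", "Dropout", "Linear", "ReLU", "Dropout", "Linear"]


def linear_positions(layout):
    """Return the Sequential indices that hold a Linear layer."""
    return [i for i, name in enumerate(layout) if name == "Linear"]


# Pair each old Linear index with its new index, in order.
index_map = dict(zip(linear_positions(OLD), linear_positions(NEW)))
print(index_map)  # {1: 0, 4: 3, 6: 6}
```

Only the Linear layers carry parameters, so under this assumption `classifier.1.*` and `classifier.4.*` are the only keys that need new names; `classifier.6.*` stays where it is.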

Just ran into this myself. The change is pretty simple; this should do it:

import torch
from torch.utils.model_zoo import load_url

# Download the old checkpoint (a state_dict using the old classifier indices).
sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth")

# First Linear layer: classifier.1 -> classifier.0
sd['classifier.0.weight'] = sd['classifier.1.weight']
sd['classifier.0.bias'] = sd['classifier.1.bias']
del sd['classifier.1.weight']
del sd['classifier.1.bias']

# Second Linear layer: classifier.4 -> classifier.3
sd['classifier.3.weight'] = sd['classifier.4.weight']
sd['classifier.3.bias'] = sd['classifier.4.bias']
del sd['classifier.4.weight']
del sd['classifier.4.bias']

# Note: the renamed keys end up at the end of the dict.
torch.save(sd, "vgg19-d01eb7cb.pth")

Would be great if you could upload the newer versions to S3.

To keep the parameter order (the state_dict is an OrderedDict), rebuild it with the renamed keys:

from collections import OrderedDict
from torch.utils.model_zoo import load_url
import torch

sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth")
# Old key -> new key (named `renames` to avoid shadowing the built-in `map`).
renames = {
    'classifier.1.weight': 'classifier.0.weight',
    'classifier.1.bias': 'classifier.0.bias',
    'classifier.4.weight': 'classifier.3.weight',
    'classifier.4.bias': 'classifier.3.bias',
}
# Rebuild in the original order; .items() works on both Python 2 and 3.
sd = OrderedDict((renames.get(k, k), v) for k, v in sd.items())
torch.save(sd, "vgg19-d01eb7cb.pth")
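The order-preserving rename can be checked without downloading the checkpoint. The sketch below uses placeholder strings in place of tensors and a made-up, shortened key list; `rename_keys` is a hypothetical helper name, not part of PyTorch.

```python
from collections import OrderedDict

# Old -> new names for the two classifier layers whose index shifted.
RENAMES = {
    "classifier.1.weight": "classifier.0.weight",
    "classifier.1.bias": "classifier.0.bias",
    "classifier.4.weight": "classifier.3.weight",
    "classifier.4.bias": "classifier.3.bias",
}


def rename_keys(state_dict, renames=RENAMES):
    """Return a new OrderedDict with keys renamed and the original order kept."""
    return OrderedDict((renames.get(k, k), v) for k, v in state_dict.items())


# Placeholder values stand in for tensors; the key list is illustrative only.
old = OrderedDict([
    ("features.0.weight", "w"),
    ("classifier.1.weight", "a"),
    ("classifier.1.bias", "b"),
    ("classifier.4.weight", "c"),
    ("classifier.4.bias", "d"),
    ("classifier.6.weight", "e"),
])
new = rename_keys(old)
print(list(new))
```

Each renamed key stays in the position its old name occupied, which is what the in-place assign-and-delete approach does not guarantee.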