xiaolai-sqlai / mobilenetv3

MobileNetV3 with PyTorch, pre-trained models provided

The downloaded model cannot be decompressed. Could you upload it again, or send it to my email (328511963@qq.com)? Thanks a lot.

cuicuizhang1989 opened this issue

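Just load the model file directly; it does not need to be decompressed. It was trained on multiple GPUs, so you may need to adjust the model before loading it.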

@excllent123 Hello, how does the model need to be modified so that it can be loaded? I also get an error when I load the model directly.

commented

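import torch
import torch.nn as nn

def load_checkpoint(model, checkpoint_PATH):
    # Checkpoints saved from a multi-GPU (nn.DataParallel) model prefix each
    # key with 'module.'; remove it so the keys match a plain model.
    if checkpoint_PATH is not None:
        model_CKPT = torch.load(checkpoint_PATH, map_location='cpu')
        model.load_state_dict({k.replace('module.', ''): v for k, v in model_CKPT['state_dict'].items()})
        print('loading checkpoint!')
    return model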

def mobilenet_large_v3(pretrained=False, **kwargs):
    if pretrained:
        model = MobileNetV3_Large(**kwargs)
        return load_checkpoint(model, 'mbv3_large.pth.tar')

    return MobileNetV3_Large(**kwargs)

class Finetune_MobileNetV3_Large(nn.Module):
    def __init__(self, class_nums):
        super(Finetune_MobileNetV3_Large, self).__init__()
        self.class_nums = class_nums
        self.base = mobilenet_large_v3(pretrained=True)
        # Replace the final classifier layer so it outputs class_nums classes.
        self.base.linear4 = nn.Linear(1280, self.class_nums)

    def forward(self, x):
        x = self.base(x)
        return x
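
The key-renaming step in load_checkpoint can also be written so that it removes only a leading 'module.' prefix instead of every occurrence of that substring. This is a minimal sketch, assuming the checkpoint's state dict is an ordinary mapping from parameter names to tensors. The helper name strip_module_prefix is hypothetical and not part of this repository.

```python
from collections import OrderedDict

def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Hypothetical helper: nn.DataParallel stores the wrapped model in an
    attribute named 'module', so checkpoints saved from the wrapper have
    keys such as 'module.conv1.weight'.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Strip only at the start of the key so inner names are left alone.
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned

# Plain integers stand in for tensors; the function only touches the keys.
example = OrderedDict([
    ('module.conv1.weight', 1),
    ('module.bn1.bias', 2),
    ('linear4.weight', 3),
])
print(list(strip_module_prefix(example)))
# prints ['conv1.weight', 'bn1.bias', 'linear4.weight']
```

With a real checkpoint, the cleaned dictionary would be passed on in the same way as above, for example model.load_state_dict(strip_module_prefix(model_CKPT['state_dict'])).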

Thanks