WenmuZhou / PytorchOCR

基于Pytorch的OCR工具库,支持常用的文字检测和识别算法

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

DB模型model.half()输出都成0了,这是为什么?

June-Li opened this issue · comments

您好,请教一下,我用您的DB模型训练好的模型half()的时候,输出都成0了,但是我自己训练的模型可以half(),这是为什么呢?

问题原因:
百度训练的DB模型整体参数偏大,bn层尤其明显,达到了1e8量级,所以fp32转fp16的时候会超出fp16范围;

解决办法:
把bn层的均值E[x]缩小1e2倍,Var[x]缩小1e4倍,并且把bn上一层卷积的weights同样缩小1e2,这样可以保证精度范围和值的范围不超fp16范围,并且bn层后的输出可以与原模型保持一致;
*注:选1e2是因为1e2正好可以将bn中最大的数缩小到fp16右端点范围内,而且缩小倍数太大会导致精度超出fp16左端点范围;

代码:
baidu_ckpt = torch.load('/volume/weights/Detector_text_model.pt', map_location='cpu')
cfg = baidu_ckpt['cfg']

baidu_state_dict = {}
for k, v in baidu_ckpt['state_dict'].items(): # k == 'head.binarize.conv1.weight' or
if k == 'head.binarize.conv1.weight':
v = torch.div(v, 1e2)
elif k == 'head.binarize.conv_bn1.running_mean':
v = torch.div(v, 1e2)
elif k == 'head.binarize.conv_bn1.running_var':
v = torch.div(v, 1e4)
baidu_state_dict[k.replace('module.', '')] = v
torch.save({'state_dict': baidu_state_dict, 'cfg': cfg}, 'output/DBNet/checkpoint/baidu.pth')