tjuskyzhang / Scaled-YOLOv4-TensorRT

Gets 100 FPS on a Jetson TX2 and 500 FPS on a GeForce GTX 1660 Ti. If the project is useful to you, please star it.



Hello, when running the yolov4-tiny-tensorrt example, gen_wts.py fails with the following error:

HeuMindFusion opened this issue.

Traceback (most recent call last):
  File "gen_wts.py", line 13, in <module>
    model.load_state_dict(torch.load(weights, map_location=device)['model'])
  File "/home/sany/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Darknet:
	Missing key(s) in state_dict: "total_ops", "total_params", "module_list.total_ops", "module_list.total_params", "module_list.0.total_ops", "module_list.0.total_params", "module_list.1.total_ops", "module_list.1.total_params", "module_list.2.total_ops", "module_list.2.total_params", "module_list.3.total_ops", "module_list.3.total_params", "module_list.4.total_ops", "module_list.4.total_params", "module_list.5.total_ops", "module_list.5.total_params", "module_list.6.total_ops", "module_list.6.total_params", "module_list.7.total_ops", "module_list.7.total_params", "module_list.8.total_ops", "module_list.8.total_params", "module_list.10.total_ops", "module_list.10.total_params", "module_list.11.total_ops", "module_list.11.total_params", "module_list.12.total_ops", "module_list.12.total_params", "module_list.13.total_ops", "module_list.13.total_params", "module_list.14.total_ops", "module_list.14.total_params", "module_list.15.total_ops", "module_list.15.total_params", "module_list.16.total_ops", "module_list.16.total_params", "module_list.18.total_ops", "module_list.18.total_params", "module_list.19.total_ops", "module_list.19.total_params", "module_list.20.total_ops", "module_list.20.total_params", "module_list.21.total_ops", "module_list.21.total_params", "module_list.22.total_ops", "module_list.22.total_params", "module_list.23.total_ops", "module_list.23.total_params", "module_list.24.total_ops", "module_list.24.total_params", "module_list.26.total_ops", "module_list.26.total_params", "module_list.27.total_ops", "module_list.27.total_params", "module_list.28.total_ops", "module_list.28.total_params", "module_list.29.total_ops", "module_list.29.total_params", "module_list.30.total_ops", "module_list.30.total_params", "module_list.31.total_ops", "module_list.31.total_params", "module_list.32.total_ops", "module_list.32.total_params", "module_list.34.total_ops", "module_list.34.total_params", "module_list.35.total_ops", "module_list.35.total_params", "module_list.36.total_ops", "module_list.36.total_params", "module_list.37.total_ops", "module_list.37.total_params".
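Every missing key ends in `total_ops` or `total_params`, which are the names of the buffers the thop profiler registers on each module it counts. One plausible cause (an assumption, not confirmed in this issue) is that profiling ran on the live Darknet instance while it was being built, so the model carries buffers that the checkpoint never saved. A minimal workaround sketch under that assumption: load with `strict=False`, confirm that only these profiler keys were missing, and skip them when writing the .wts file so the count on its first line still matches the entries. `PROFILER_BUFFERS`, `is_profiler_key`, `check_missing_keys`, and `weight_items` are hypothetical helpers, not part of the repository.

```python
# Buffer names the thop profiler registers on each module it visits.
# Assumption: these are the only extra keys, as in the traceback above.
PROFILER_BUFFERS = ("total_ops", "total_params")


def is_profiler_key(key):
    """True for 'total_ops', 'module_list.3.total_params', and similar keys."""
    return key.rsplit(".", 1)[-1] in PROFILER_BUFFERS


def check_missing_keys(missing_keys, unexpected_keys):
    """Raise unless the only mismatch is profiler buffers absent from the checkpoint."""
    real_missing = [k for k in missing_keys if not is_profiler_key(k)]
    if real_missing or unexpected_keys:
        raise RuntimeError("state_dict mismatch: missing={}, unexpected={}".format(
            real_missing, list(unexpected_keys)))


def weight_items(state_dict):
    """Return (name, tensor) pairs from a state_dict, without profiler buffers."""
    return [(k, v) for k, v in state_dict.items() if not is_profiler_key(k)]


# Hypothetical use in gen_wts.py (needs torch, so it is left as comments here):
#   result = model.load_state_dict(
#       torch.load(weights, map_location=device)['model'], strict=False)
#   check_missing_keys(result.missing_keys, result.unexpected_keys)
#   items = weight_items(model.state_dict())
#   f.write('{}\n'.format(len(items)))
#   for k, v in items:
#       ...  # same per-tensor writing as before
```

`strict=False` makes `load_state_dict` return the lists of missing and unexpected keys instead of raising, so the check keeps the safety of strict loading for every key except the profiler buffers. Running with an environment where thop does not leave these buffers on the model may avoid the mismatch entirely, but that is untested here.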

gen_wts.py
import struct
import sys
from models import *
from utils.utils import *
from utils.torch_utils import select_device
model = Darknet('cfg/yolov4-tiny.cfg', (416, 416))
weights = sys.argv[1]
device = select_device('cpu')

dev = '0'
print(model)
if weights.endswith('.pt'):  # PyTorch format
    model.load_state_dict(torch.load(weights, map_location=device)['model'])
    print("------------------------------")
else:  # Darknet format
    load_darknet_weights(model, weights)

# Write yolov4-tiny.wts: a tensor count, then "name size hex hex ..." per tensor.
with open('yolov4-tiny.wts', 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        vr = v.reshape(-1).cpu().numpy()  # flatten the tensor to 1-D
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())  # big-endian float32 as 8 hex digits
        f.write('\n')
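The write loop above produces a plain-text format: the first line holds the number of tensors, and each following line holds a tensor name, its element count, and each value as a big-endian float32 in hex. A small round-trip sketch of that layout, using only `struct`, can help sanity-check a generated yolov4-tiny.wts (for example, that the count line matches the number of entries) before it goes to the TensorRT side. `encode_wts` and `decode_wts` are hypothetical helpers that mirror this script only; they are not the repository's C++ loader.

```python
import struct


def encode_wts(weights):
    """Serialize {name: [float, ...]} in the same layout gen_wts.py writes."""
    lines = [str(len(weights))]  # first line: number of tensors
    for name, values in weights.items():
        # Each value: one space, then the big-endian float32 bytes as 8 hex digits.
        hex_vals = "".join(" " + struct.pack(">f", float(x)).hex() for x in values)
        lines.append("{} {} {}".format(name, len(values), hex_vals))
    return "\n".join(lines) + "\n"


def decode_wts(text):
    """Parse .wts text back into {name: [float, ...]}, checking both counts."""
    lines = text.splitlines()
    count = int(lines[0])
    weights = {}
    for line in lines[1:]:
        if not line.strip():
            continue  # tolerate a trailing blank line
        parts = line.split()
        name, size, hex_vals = parts[0], int(parts[1]), parts[2:]
        if len(hex_vals) != size:
            raise ValueError("{}: expected {} values, found {}".format(
                name, size, len(hex_vals)))
        weights[name] = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_vals]
    if len(weights) != count:
        raise ValueError("header says {} tensors, found {}".format(count, len(weights)))
    return weights
```

For a real file, `decode_wts(open('yolov4-tiny.wts').read())` would raise if the header count disagrees with the number of tensor lines, which is what happens if profiler buffers are counted but later dropped, or vice versa.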