dnth / yolov5-deepsparse-blogpost

By the end of this post, you will learn how to: train a SOTA YOLOv5 model on your own data; sparsify the model using SparseML quantization-aware training, sparse transfer learning, and one-shot quantization; and export the sparsified model and run it with the DeepSparse engine at insane speeds. P/S: The end result - YOLOv5 on CPU at 180+ FPS.
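
For context, the last step of that workflow (running the exported ONNX file with the DeepSparse engine) can be sketched roughly as below; the model path is hypothetical and the input is random data, assuming the 416x416 image size used in the commands in this issue:

import numpy as np
from deepsparse import compile_model

# Hypothetical path to the ONNX file produced by export.py
onnx_path = "yolov5-deepsparse/yolov5s-sgd-pruned-quantized/weights/last.onnx"

# Compile the sparsified model for CPU inference
engine = compile_model(onnx_path, batch_size=1)

# Dummy 416x416 RGB input in NCHW float32 layout
dummy_input = np.random.rand(1, 3, 416, 416).astype(np.float32)

# Run inference; returns a list of raw YOLOv5 output arrays
outputs = engine.run([dummy_input])
print([o.shape for o in outputs])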

Home Page: https://dicksonneoh.com/portfolio/supercharging_yolov5_180_fps_cpu/


export.py error

lucheng07082221 opened this issue · comments

Thank you for your great work!

I have trained a pruned-quantized model with this command:

python3 train.py --cfg ./models_v5.0/yolov5s.yaml --recipe ../recipes/yolov5.transfer_learn_pruned_quantized.md --data pistols.yaml --hyp data/hyps/hyp.scratch.yaml --weights yolov5s.pt --img 416 --batch-size 8 --optimizer SGD --project yolov5-deepsparse --name yolov5s-sgd-pruned-quantized

But when I convert this .pt checkpoint to ONNX, I get an error. My command is:

python3 export.py --weights /home/lc/work/det/yolov5-deepsparse-blogpost/yolov5-train/yolov5-deepsparse/yolov5s-sgd-pruned-quantized3/weights/last.pt --include onnx --imgsz 416 --dynamic --simplify
export: data=data/coco128.yaml, weights=['/home/lc/work/det/yolov5-deepsparse-blogpost/yolov5-train/yolov5-deepsparse/yolov5s-sgd-pruned-quantized3/weights/last.pt'], imgsz=[416], batch_size=1, device=cpu, half=False, inplace=False, train=False, optimize=False, int8=False, dynamic=True, simplify=True, opset=12, verbose=False, workspace=4, nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45, conf_thres=0.25, remove_grid=False, include=['onnx']
YOLOv5 🚀 12612f2 torch 1.9.0+cu102 CPU

Fusing layers...
YOLOv5s summary: 224 layers, 7053910 parameters, 0 gradients, 16.3 GFLOPs
2022-06-26 17:04:59 sparseml.optim.manager INFO Created recipe manager with metadata: {
"metadata": null
}
Created recipe manager with metadata: {
"metadata": null
}
Traceback (most recent call last):
File "export.py", line 715, in
main(opt)
File "export.py", line 704, in main
run(**vars(opt))
File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
return func(*args, **kwargs)
File "export.py", line 593, in run
model, extras = load_checkpoint(type_='ensemble', weights=weights, device=device) # load FP32 model
File "export.py", line 529, in load_checkpoint
state_dict = load_state_dict(model, state_dict, run_mode=not ensemble_type, exclude_anchors=exclude_anchors)
File "export.py", line 553, in load_state_dict
model.load_state_dict(state_dict, strict=not run_mode) # load
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: "model.0.conv.conv.quant.activation_post_process.scale", "model.0.conv.conv.quant.activation_post_process.zero_point", "model.0.conv.conv.quant.activation_post_process.fake_quant_enabled", "model.0.conv.conv.quant.activation_post_process.observer_enabled", "model.0.conv.conv.quant.activation_post_process.scale", "model.0.conv.conv.quant.activation_post_process.zero_point", "model.0.conv.conv.quant.activation_post_process.activation_post_process.min_val", "model.0.conv.conv.quant.activation_post_process.activation_post_process.max_val", "model.0.conv.conv.module.weight", "model.0.conv.conv.module.bias", "model.0.conv.conv.module.weight_fake_quant.scale", "model.0.conv.conv.module.weight_fake_quant.zero_point", "model.0.conv.conv.module.weight_fake_quant.fake_quant_enabled", "model.0.conv.conv.module.weight_fake_quant.observer_enabled", "model.0.conv.conv.module.weight_fake_quant.scale", "model.0.conv.conv.module.weight_fake_quant.zero_point", "model.0.conv.conv.module.weight_fake_quant.activation_post_process.min_val", "model.0.conv.conv.module.weight_fake_quant.activation_post_process.max_val", "model.0.conv.conv.module.activation_post_process.scale", "model.0.conv.conv.module.activation_post_process.zero_point", "model.0.conv.conv.module.activation_post_process.fake_quant_enabled",

Thank you for your reply!