mit-han-lab / litepose

[CVPR'22] Lite Pose: Efficient Architecture Design for 2D Human Pose Estimation

Home Page: https://hanlab.mit.edu

Jetson Nano inference is slow

happy-wook-kim opened this issue

Hello, I have recently been using LitePose on COCO and ran into a few issues, so please help me when you have time.
I saw that your reported result for LitePose-Auto-M is 144 ms latency.
However, running LitePose-Auto-S I measured 300 ms latency, which is too slow. Do you know why it is this slow?
I think it might be faster if I use the TVM conversion.
Please advise how I can solve this problem.

Hello, you can just speak English.

We have released the code for running our model on Jetson Nano with a pre-built TVM binary in nano_demo. To convert the PyTorch model to a TVM binary, you may need to check the TVM Auto Scheduler Tutorial.
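
For what it's worth, the usual first step is exporting the PyTorch model to ONNX. A minimal sketch (the model below is just a stand-in; swap in the actual LitePose network with your trained checkpoint loaded, and adjust the resolution):

import torch
import torch.nn as nn

# Stand-in model: replace this with the LitePose network built from this repo's
# config, with your checkpoint loaded, before exporting for real.
model = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU())
model.eval()

dummy = torch.ones(1, 3, 512, 512)  # adjust to your model's input resolution
torch.onnx.export(
    model, dummy, 'litepose.onnx',
    opset_version=11,
    input_names=['input.1'],
    do_constant_folding=True,
)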

thx a lot!

@lmxyy Thanks for the repo.
Have you first converted the LitePose model to ONNX, and then to TVM by cross-compiling it for the target tvm.target.Target("nvidia/jetson-nano") using the auto scheduler?

@lmxyy If possible, would you make the TVM auto-scheduler code available? I am not able to cross-compile it for Jetson. Thank you.

@lmxyy Thanks for the repo.
Have you first converted the LitePose model to ONNX, and then to TVM by cross-compiling it for the target tvm.target.Target("nvidia/jetson-nano") using the auto scheduler?

Yes, that is how we did it. The target should be 'cuda -arch=sm_53 -max_num_threads=256 -max_threads_per_block=256' for CUDA, and the target_host should be 'llvm -mtriple=aarch64-linux-gnu'.
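
For reference, those strings correspond to TVM target objects roughly like this (a sketch; the 'nvidia/jetson-nano' preset tag mentioned above may also work, depending on your TVM build):

import tvm

# Explicit Jetson Nano GPU target using the flags quoted above,
# with the aarch64 host for cross compilation.
target = tvm.target.Target(
    'cuda -arch=sm_53 -max_num_threads=256 -max_threads_per_block=256',
    host='llvm -mtriple=aarch64-linux-gnu',
)
# Alternative, if your TVM build ships the board preset:
# target = tvm.target.Target('nvidia/jetson-nano', host='llvm -mtriple=aarch64-linux-gnu')
print(target)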

An example script using the TVM 0.8 auto-scheduler is as follows:

import argparse
import os
from typing import Tuple

import numpy as np
import onnx
import torch
import tvm
from tvm import auto_scheduler, relay
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib.graph_executor import GraphModule

# RPC tracker settings: the Nano's RPC server registers under this key;
# replace the host/port with your tracker's address.
device_key = 'nano'
remote_host = '0.0.0.0'
remote_port = 9190


# Tune all extracted tasks, measuring candidate schedules on the remote device via RPC.
def run_tuning(args, tasks, task_weights):
    print("Begin tuning...")
    runner = auto_scheduler.RPCRunner(
        key=device_key,
        host=remote_host,
        port=remote_port,
        timeout=120,
        repeat=1,
        min_repeat_ms=200,
        enable_cpu_cache_flush=True,
        n_parallel=args.n_parallel
    )
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=args.resume_from)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=args.num_trials,
        runner=runner,
        measure_callbacks=[auto_scheduler.RecordToFile(args.log_file)],
    )
    tuner.tune(tune_option)


# Extract tuning tasks from the Relay module and run the auto-scheduler on them.
def auto_schedule(args, mod, params, target, target_host):
    tasks, task_weights = auto_scheduler.extract_tasks(mod['main'], params, target,
                                                       target_host=target_host,
                                                       opt_level=args.opt_level)
    for idx, task in enumerate(tasks):
        print("========== Task %d  (workload key: %s) ==========" % (idx, task.workload_key))
        print(task.compute_dag)
    run_tuning(args, tasks, task_weights)


# Import the ONNX model into Relay, tune it (unless skipped), and build the TVM library.
def onnx2tvm(args, inputs, target, target_host):
    input_shapes = {}
    for index, torch_input in enumerate(inputs):
        name = 'input.%d' % (index + 1)
        input_shapes[name] = torch_input.shape
    onnx_model = onnx.load_model(args.onnx_path)
    relay_module, params = relay.frontend.from_onnx(onnx_model, shape=input_shapes)
    if not args.skip_tune:
        auto_schedule(args, relay_module, params, target, target_host)
    with auto_scheduler.ApplyHistoryBest(args.log_file):
        with tvm.transform.PassContext(opt_level=args.opt_level, config={"relay.backend.use_auto_scheduler": True}):
            lib = relay.build(relay_module, target, params=params, target_host=target_host)
    return lib


# Build (or reuse) the compiled library, upload it to the device over RPC,
# and wrap it in a graph executor.
def get_executor(args, inputs, target, target_host):
    if not os.path.exists(args.engine_path) or args.force_rebuild:
        assert args.onnx_path is not None
        lib = onnx2tvm(args, inputs, target, target_host)
        os.makedirs(os.path.dirname(args.engine_path), exist_ok=True)
        lib.export_library(args.engine_path)
    remote = auto_scheduler.utils.request_remote(device_key, remote_host, remote_port, timeout=10000)
    remote.upload(args.engine_path)
    rlib = remote.load_module(os.path.basename(args.engine_path))

    device = remote.cuda() if args.device == 'gpu' else remote.cpu()
    gmod = GraphModule(rlib['default'](device))

    def executor(inputs: Tuple[tvm.nd.NDArray]):
        for index, value in enumerate(inputs):
            gmod.set_input(index, value)
        gmod.run()
        return tuple(gmod.get_output(index) for index in range(2))  # the model has two outputs

    return rlib, executor, gmod, device


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--onnx_path', type=str, default=None)
    parser.add_argument('--force_rebuild', action='store_true')
    parser.add_argument('--engine_path', type=str, required=True)
    parser.add_argument('--warmup_times', type=int, default=10)
    parser.add_argument('--test_times', type=int, default=20)
    parser.add_argument('--resolution', type=int, default=512)
    parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'gpu'])
    parser.add_argument('--opt_level', type=int, default=4)
    parser.add_argument('--log_file', type=str, default='./log.json')
    parser.add_argument('--num_trials', type=int, default=20000)
    parser.add_argument('--resume_from', type=str, default=None)
    parser.add_argument('--skip_tune', action='store_true')
    parser.add_argument('--n_parallel', type=int, default=1)
    args = parser.parse_args()

    inputs = (torch.ones(1, 3, args.resolution, args.resolution),)
    if args.device == 'cpu':
        target = 'llvm -mtriple=aarch64-linux-gnu'
    else:
        set_cuda_target_arch('sm_53')
        target = 'cuda -arch=sm_53 -max_num_threads=256 -max_threads_per_block=256'
    target_host = 'llvm -mtriple=aarch64-linux-gnu'
    lib, executor, gmod, device = get_executor(args, inputs, target, target_host)
    x = tvm.nd.array(np.ones([1, 3, args.resolution, args.resolution], dtype=np.float32), device=device)
    inputs = (x,)
    outputs = executor(inputs)
    for i, output in enumerate(outputs):
        print('Output Absolute Sum%d: %.3f' % (i, float(abs(output.asnumpy()).sum())))
    gmod.set_input(0, x)
    ftimer = gmod.module.time_evaluator('run', device, repeat=3, min_repeat_ms=500)
    prof_res = np.array(ftimer().results) * 1e3  # convert to millisecond
    print('Mean inference time (std dev): %.2f ms (%.2f ms)' % (np.mean(prof_res), np.std(prof_res)))

You can refer to the TVM Auto Scheduler Tutorial to learn how to tune the model with RPC communication and cross compilation, starting from an ONNX model.
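
As a quick sanity check of the RPC setup before tuning (a sketch; it assumes the tracker address and device key match the constants at the top of the script above):

# On the tuning machine, start the RPC tracker:
#   python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190
# On the Jetson Nano, start an RPC server that registers with the tracker:
#   python -m tvm.exec.rpc_server --tracker=<tracker-ip>:9190 --key=nano
from tvm import rpc

tracker = rpc.connect_tracker('127.0.0.1', 9190)  # use your tracker's address
# The summary should show a free 'nano' device; otherwise the auto-scheduler's
# RPCRunner has nothing to measure on.
print(tracker.text_summary())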

@lmxyy Thanks a lot for the help. Really appreciate it.

@lmxyy

Is this restriction ('-max_num_threads=256 -max_threads_per_block=256') important for performance after auto-scheduling?