NVIDIA-AI-IOT / CUDA-PointPillars

A project demonstrating how to use CUDA-PointPillars to process point cloud data from lidar.

Training and evaluation work well when I change MAX_POINTS_PER_VOXEL from 32 to 100, but after converting the model to ONNX, TensorRT inference gives wrong predictions

Allamrahul opened this issue.

Dataset: I am using a custom dataset with .npy point cloud files and annotations. I followed all the steps required for custom dataset preparation, and I get great results in PyTorch, with 90% mAP on my eval set.

With MAX_POINTS_PER_VOXEL at its default value of 32, I get good results during evaluation, and more or less the same predictions during TensorRT inference as well.

I then increased MAX_POINTS_PER_VOXEL to 100 for better performance, and I see better results during evaluation. However, when I convert the model to ONNX and run TensorRT inference, the predictions are wrong.
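To see why this parameter matters at export time, here is a minimal NumPy sketch of pillar voxelization. It assumes the usual hard-voxelization semantics (one pillar per grid cell, at most MAX_POINTS_PER_VOXEL points kept per pillar, zero padding, extra points dropped); `voxelize` is a hypothetical helper, not OpenPCDet's or CUDA-PointPillars' implementation. The point it illustrates is that MAX_POINTS_PER_VOXEL is baked into the tensor shape: going from 32 to 100 changes `voxels` from [V, 32, 4] to [V, 100, 4], so the exported graph and whatever preprocessing fills that tensor at inference time have to agree on the value.

```python
import numpy as np


# Hypothetical helper: a simplified hard-voxelization sketch, not the
# OpenPCDet or CUDA-PointPillars implementation.
def voxelize(points, voxel_size, pc_range, max_points_per_voxel, max_voxels):
    """Group (N, C) points into pillars, keeping at most max_points_per_voxel each.

    Returns voxels (V, P, C) zero-padded, coords (V, 3) as (z, y, x), and
    num_points (V,), where V <= max_voxels is the number of non-empty pillars.
    """
    voxel_size = np.asarray(voxel_size, dtype=np.float64)
    pc_range = np.asarray(pc_range, dtype=np.float64)
    # Round rather than truncate so floating-point error cannot drop a grid column.
    grid = np.round((pc_range[3:] - pc_range[:3]) / voxel_size).astype(np.int64)

    voxels = np.zeros((max_voxels, max_points_per_voxel, points.shape[1]), dtype=np.float32)
    coords = np.zeros((max_voxels, 3), dtype=np.int32)
    num_points = np.zeros(max_voxels, dtype=np.int32)
    slot_of = {}  # (z, y, x) grid cell -> row in `voxels`

    for p in points:
        idx = np.floor((p[:3] - pc_range[:3]) / voxel_size).astype(np.int64)
        if np.any(idx < 0) or np.any(idx >= grid):
            continue  # point lies outside POINT_CLOUD_RANGE
        key = (int(idx[2]), int(idx[1]), int(idx[0]))
        if key not in slot_of:
            if len(slot_of) >= max_voxels:
                continue  # MAX_NUMBER_OF_VOXELS reached: new pillars are dropped
            slot_of[key] = len(slot_of)
            coords[slot_of[key]] = key
        v = slot_of[key]
        if num_points[v] < max_points_per_voxel:
            voxels[v, num_points[v]] = p
            num_points[v] += 1
        # Points beyond max_points_per_voxel are silently discarded.

    n = len(slot_of)
    return voxels[:n], coords[:n], num_points[:n]


# Usage: 150 identical points fall into a single pillar.
cloud = np.tile(np.array([[1.0, 1.0, 0.0, 0.5]], dtype=np.float32), (150, 1))
kitti_range = [0, -39.68, -3, 69.12, 39.68, 1]
for p_max in (32, 100):
    voxels, coords, num_points = voxelize(cloud, [0.16, 0.16, 4], kitti_range, p_max, 16)
    print(voxels.shape, num_points[0])
```

With a cap of 32 the pillar keeps 32 of the 150 points; with 100 it keeps 100, and the tensor's second dimension changes accordingly.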

Export script evolution:
For the export itself, the script uses exporter.py and simplifier_onnx.py. However, both scripts are hardcoded for the 3 classes of the KITTI dataset, and I have only one class to detect. Hence, I followed the commits in this PR to make the ONNX export work: https://github.com/NVIDIA-AI-IOT/CUDA-PointPillars/pull/77/commits.
After that, I was able to export, but I then ran into issue #82. I resolved it by tinkering with the export script, as described in this comment: #77 (comment).
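For reference, the class-dependent output shapes that `simplify_postprocess` declares in the pasted script (`2 * N * N`, `14 * N`, `4 * N` channels) follow from the anchor head's channel counts. The sketch below assumes the anchor-head defaults of the KITTI PointPillars config: one anchor class per detected class, two rotations per anchor class, a 7-value box encoding and 2 direction bins. `head_channels` is a hypothetical helper written only to make that arithmetic explicit.

```python
# Hypothetical helper; the default arguments are assumptions matching the
# KITTI PointPillars anchor head (2 rotations, 7 box params, 2 direction bins).
def head_channels(num_classes, rotations_per_class=2, box_code_size=7, num_dir_bins=2):
    """Last-dimension channel counts of the NHWC head outputs."""
    anchors_per_location = rotations_per_class * num_classes
    return {
        "cls_preds": anchors_per_location * num_classes,
        "box_preds": anchors_per_location * box_code_size,
        "dir_cls_preds": anchors_per_location * num_dir_bins,
    }


print(head_channels(3))  # KITTI: {'cls_preds': 18, 'box_preds': 42, 'dir_cls_preds': 12}
print(head_channels(1))  # single class: {'cls_preds': 2, 'box_preds': 14, 'dir_cls_preds': 4}
```

These channel counts depend only on the number of classes, not on MAX_POINTS_PER_VOXEL, so the single-class changes from PR #77 should be unaffected by the 32 to 100 switch.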

After this, I also changed the hardcoded MAX_VOXELS from 10000 so that it is read from the config file instead (40000).
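The loop in `main()` that reads these values from `DATA_PROCESSOR` can be factored into a function that fails loudly when the voxelization step is missing. In the pasted script, `VOXEL_SIZES` and `MAX_VOXELS` are only assigned inside the loop, so a missing step raises a `NameError` at the debug print before the `None` check runs. `voxel_params` is a hypothetical helper, and the config fragment is illustrative: only MAX_POINTS_PER_VOXEL = 100 and the test-time MAX_VOXELS = 40000 come from this issue.

```python
# Hypothetical helper mirroring the DATA_PROCESSOR loop in exporter.py's main().
def voxel_params(data_processor, mode="test"):
    """Return (MAX_POINTS_PER_VOXEL, VOXEL_SIZE, MAX_VOXELS) for the given mode."""
    for step in data_processor:
        if step["NAME"] == "transform_points_to_voxels":
            return (
                step["MAX_POINTS_PER_VOXEL"],
                step["VOXEL_SIZE"],
                step["MAX_NUMBER_OF_VOXELS"][mode],
            )
    raise ValueError("no transform_points_to_voxels step in DATA_PROCESSOR")


# Illustrative fragment; the train value is a placeholder.
data_processor = [
    {"NAME": "mask_points_and_boxes_outside_range", "REMOVE_OUTSIDE_BOXES": True},
    {"NAME": "shuffle_points", "SHUFFLE_ENABLED": {"train": True, "test": False}},
    {"NAME": "transform_points_to_voxels", "VOXEL_SIZE": [0.16, 0.16, 4],
     "MAX_POINTS_PER_VOXEL": 100, "MAX_NUMBER_OF_VOXELS": {"train": 40000, "test": 40000}},
]
print(voxel_params(data_processor))  # (100, [0.16, 0.16, 4], 40000)
```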

I believe there are still bugs lurking in the export script. Please look into this.
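One way to narrow down where the exported model diverges is to feed the same voxelized frame to the PyTorch model and to the TensorRT engine and compare the raw head outputs (`cls_preds`, `box_preds`, `dir_cls_preds`) before any box decoding or NMS. The helper below is a hypothetical sketch; capturing the two sets of outputs depends on your inference harness and is not shown. It assumes each candidate array has the same element count and layout as its reference (for example a flat TensorRT output buffer in NHWC order).

```python
import numpy as np


# Hypothetical debugging helper, not part of CUDA-PointPillars.
def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-3):
    """Compare named outputs tensor by tensor.

    Returns {name: (max_abs_diff, allclose)} for every name in `reference`.
    """
    report = {}
    for name, ref in reference.items():
        ref = np.asarray(ref, dtype=np.float32)
        # Assumes the candidate uses the same memory layout as the reference.
        cand = np.asarray(candidate[name], dtype=np.float32).reshape(ref.shape)
        max_diff = float(np.max(np.abs(ref - cand)))
        report[name] = (max_diff, bool(np.allclose(cand, ref, rtol=rtol, atol=atol)))
    return report
```

If `cls_preds` already diverges, the problem lies upstream of decoding and NMS, that is, in the exported graph or in the tensors fed to it.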

I am pasting my export script for reference:

exporter.py file

# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import onnx
import torch
import argparse
import numpy as np

from pathlib import Path
from onnxsim import simplify
from pcdet.utils import common_utils
from pcdet.models import build_network
from pcdet.datasets import DatasetTemplate
from pcdet.config import cfg, cfg_from_yaml_file

from exporter_paramters import export_paramters as export_paramters
from simplifier_onnx import simplify_preprocess, simplify_postprocess

class DemoDataset(DatasetTemplate):
    def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None, ext='.bin'):
        """
        Args:
            root_path:
            dataset_cfg:
            class_names:
            training:
            logger:
        """
        super().__init__(
            dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
        )
        self.root_path = root_path
        self.ext = ext
        data_file_list = glob.glob(str(root_path / f'*{self.ext}')) if self.root_path.is_dir() else [self.root_path]

        data_file_list.sort()
        self.sample_file_list = data_file_list

    def __len__(self):
        return len(self.sample_file_list)

    def __getitem__(self, index):
        if self.ext == '.bin':
            points = np.fromfile(self.sample_file_list[index], dtype=np.float32).reshape(-1, 4)
        elif self.ext == '.npy':
            points = np.load(self.sample_file_list[index])
        else:
            raise NotImplementedError

        input_dict = {
            'points': points,
            'frame_id': index,
        }

        data_dict = self.prepare_data(data_dict=input_dict)
        return data_dict

def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--cfg_file', type=str, default='cfgs/kitti_models/pointpillar.yaml',
                        help='specify the config for demo')
    parser.add_argument('--data_path', type=str, default='demo_data',
                        help='specify the point cloud data file or directory')
    parser.add_argument('--ckpt', type=str, default=None, help='specify the pretrained model')
    parser.add_argument('--ext', type=str, default='.bin', help='specify the extension of your point cloud data file')

    args = parser.parse_args()

    cfg_from_yaml_file(args.cfg_file, cfg)

    return args, cfg

def main():
    args, cfg = parse_config()
    export_paramters(cfg)
    logger = common_utils.create_logger()
    logger.info('------ Convert OpenPCDet model for TensorRT ------')
    demo_dataset = DemoDataset(
        dataset_cfg=cfg.DATA_CONFIG, class_names=cfg.CLASS_NAMES, training=False,
        root_path=Path(args.data_path), ext=args.ext, logger=logger
    )

    model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=demo_dataset)
    model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=True)
    model.cuda()
    model.eval()
    np.set_printoptions(threshold=np.inf)
    with torch.no_grad():

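        NUMBER_OF_CLASSES = len(cfg.CLASS_NAMES)
        # MAX_VOXELS is read from the config below instead of the previously
        # hardcoded 10000; all three start as None so a missing voxelization
        # step is caught by the check below instead of raising a NameError.
        MAX_POINTS_PER_VOXEL = None
        VOXEL_SIZES = None
        MAX_VOXELS = None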

        DATA_PROCESSOR = cfg.DATA_CONFIG.DATA_PROCESSOR
        POINT_CLOUD_RANGE = cfg.DATA_CONFIG.POINT_CLOUD_RANGE

        for i in DATA_PROCESSOR:
            if i['NAME'] == "transform_points_to_voxels":
                MAX_POINTS_PER_VOXEL = i['MAX_POINTS_PER_VOXEL']
                VOXEL_SIZES = i['VOXEL_SIZE']
                MAX_VOXELS = i['MAX_NUMBER_OF_VOXELS']['test']
                break

        print("ra35 DEBUG MAX_POINTS_PER_VOXEL, VOXEL_SIZES, MAX_VOXELS ", MAX_POINTS_PER_VOXEL, VOXEL_SIZES, MAX_VOXELS)

        if MAX_POINTS_PER_VOXEL is None:
            logger.error('Could not parse config... Exiting')
            import sys
            sys.exit(1)

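        # Despite their names, VOXEL_SIZE_X/Y are grid dimensions (number of pillars
        # along X and Y), e.g. 69.12 / 0.16 = 432 for the default KITTI range.
        VOXEL_SIZE_X = abs(POINT_CLOUD_RANGE[0] - POINT_CLOUD_RANGE[3]) / VOXEL_SIZES[0]
        VOXEL_SIZE_Y = abs(POINT_CLOUD_RANGE[1] - POINT_CLOUD_RANGE[4]) / VOXEL_SIZES[1]

        # Feature-map size at the detection head, assuming the 2D backbone's
        # output stride is 2 (as in the default pointpillar.yaml).
        FEATURE_SIZE_X = VOXEL_SIZE_X / 2
        FEATURE_SIZE_Y = VOXEL_SIZE_Y / 2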

        print("ra35 DEBUG FEATURE_SIZE_X FEATURE_SIZE_Y ", FEATURE_SIZE_X, FEATURE_SIZE_Y)

        dummy_voxels = torch.zeros(
          (MAX_VOXELS, MAX_POINTS_PER_VOXEL, 4),
          dtype=torch.float32,
          device='cuda:0')

        dummy_voxel_idxs = torch.zeros(
          (MAX_VOXELS, 4),
          dtype=torch.int32,
          device='cuda:0')

        dummy_voxel_num = torch.zeros(
          (1),
          dtype=torch.int32,
          device='cuda:0')
        print("ra35 DEBUG MAX_VOXELS  MAX_POINTS_PER_VOXEL", MAX_VOXELS, MAX_POINTS_PER_VOXEL)
        dummy_input = dict()
        dummy_input['voxels'] = dummy_voxels
        dummy_input['voxel_num_points'] = dummy_voxel_num
        dummy_input['voxel_coords'] = dummy_voxel_idxs
        dummy_input['batch_size'] = torch.tensor(1)

        torch.onnx.export(model,       # model being run
          dummy_input,               # model input (or a tuple for multiple inputs)
          "./pointpillar_raw.onnx",  # where to save the model (can be a file or file-like object)
          export_params=True,        # store the trained parameter weights inside the model file
          opset_version=11,          # the ONNX version to export the model to
          do_constant_folding=True,  # whether to execute constant folding for optimization
          keep_initializers_as_inputs=True,
          input_names = ['voxels', 'voxel_num', 'voxel_idxs'],   # the model's input names
          output_names = ['cls_preds', 'box_preds', 'dir_cls_preds'], # the model's output names
          )

        onnx_raw = onnx.load("./pointpillar_raw.onnx")  # load onnx model
        onnx_trim_post = simplify_postprocess(onnx_raw, FEATURE_SIZE_X, FEATURE_SIZE_Y, NUMBER_OF_CLASSES)

        onnx_simp, check = simplify(onnx_trim_post)
        assert check, "Simplified ONNX model could not be validated"

        onnx_final = simplify_preprocess(onnx_simp, VOXEL_SIZE_Y, VOXEL_SIZE_X, MAX_POINTS_PER_VOXEL)
        onnx.save(onnx_final, "pointpillar.onnx")
        print('finished exporting onnx')

    logger.info('[PASS] ONNX EXPORTED.')


if __name__ == '__main__':
    main()
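In `main()` above, `VOXEL_SIZE_X`/`VOXEL_SIZE_Y` end up as the PPScatter `dense_shape` in `simplify_preprocess`, and `FEATURE_SIZE_X`/`FEATURE_SIZE_Y` set the declared output shapes in `simplify_postprocess`. A small sketch of the same arithmetic, assuming a head stride of 2 as in the default KITTI backbone; `grid_and_feature_size` is a hypothetical helper. It uses `round()` because `int()` truncates, so a quotient that lands just below an integer in floating point would lose a whole column.

```python
# Hypothetical helper; head_stride=2 is an assumption matching the default
# KITTI pointpillar.yaml backbone. Adjust it if your backbone differs.
def grid_and_feature_size(point_cloud_range, voxel_size, head_stride=2):
    """Return the pillar grid (nx, ny) and the head feature map (fx, fy)."""
    # round() instead of int() so floating-point error cannot drop a column.
    nx = round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size[0])
    ny = round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size[1])
    return (nx, ny), (nx // head_stride, ny // head_stride)


kitti = grid_and_feature_size([0, -39.68, -3, 69.12, 39.68, 1], [0.16, 0.16, 4])
print(kitti)  # ((432, 496), (216, 248))
```

None of these values depend on MAX_POINTS_PER_VOXEL, which only enters the export through the `voxels` input shape.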

simplifier_onnx.py

# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import onnx
import numpy as np
import onnx_graphsurgeon as gs

@gs.Graph.register()
def replace_with_clip(self, inputs, outputs,  voxel_array):
    for inp in inputs:
        inp.outputs.clear()

    for out in outputs:
        out.inputs.clear()

    op_attrs = dict()
    op_attrs["dense_shape"] =  voxel_array

    return self.layer(name="PPScatter_0", op="PPScatterPlugin", inputs=inputs, outputs=outputs, attrs=op_attrs)


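def loop_node(graph, current_node, loop_time=0):
    # Walk `loop_time` steps downstream, each time taking the first node whose
    # first input is the current node's first output.
    for i in range(loop_time):
        next_node = [node for node in graph.nodes if len(node.inputs) != 0 and len(current_node.outputs) != 0 and node.inputs[0] == current_node.outputs[0]][0]
        current_node = next_node
    # Return current_node so that loop_time=0 does not hit an unbound next_node.
    return current_node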


def simplify_postprocess(onnx_model, FEATURE_SIZE_X, FEATURE_SIZE_Y, NUMBER_OF_CLASSES):
    print("Use onnx_graphsurgeon to adjust postprocessing part in the onnx...")
    graph = gs.import_onnx(onnx_model)

    cls_preds = gs.Variable(name="cls_preds", dtype=np.float32, shape=(1, int(FEATURE_SIZE_Y), int(FEATURE_SIZE_X), 2 * NUMBER_OF_CLASSES * NUMBER_OF_CLASSES))
    box_preds = gs.Variable(name="box_preds", dtype=np.float32, shape=(1, int(FEATURE_SIZE_Y), int(FEATURE_SIZE_X), 14 * NUMBER_OF_CLASSES))
    dir_cls_preds = gs.Variable(name="dir_cls_preds", dtype=np.float32, shape=(1, int(FEATURE_SIZE_Y), int(FEATURE_SIZE_X), 4 * NUMBER_OF_CLASSES))

    tmap = graph.tensors()
    new_inputs = [tmap["voxels"], tmap["voxel_idxs"], tmap["voxel_num"]]
    new_outputs = [cls_preds, box_preds, dir_cls_preds]

    for inp in graph.inputs:
      if inp not in new_inputs:
        inp.outputs.clear()

    for out in graph.outputs:
      out.inputs.clear()

    first_ConvTranspose_node = [node for node in graph.nodes if node.op == "ConvTranspose"][0]
    concat_node = loop_node(graph, first_ConvTranspose_node, 3)
    assert concat_node.op == "Concat"

    first_node_after_concat = [node for node in graph.nodes if len(node.inputs) != 0 and len(concat_node.outputs) != 0 and node.inputs[0] == concat_node.outputs[0]]

    for i in range(3):
        transpose_node = loop_node(graph, first_node_after_concat[i], 1)
        assert transpose_node.op == "Transpose"
        transpose_node.outputs = [new_outputs[i]]

    graph.inputs = new_inputs
    graph.outputs = new_outputs
    graph.cleanup().toposort()

    return gs.export_onnx(graph)


def simplify_preprocess(onnx_model, VOXEL_SIZE_Y, VOXEL_SIZE_X, MAX_POINTS_PER_VOXEL):
    print("Use onnx_graphsurgeon to modify onnx...")
    graph = gs.import_onnx(onnx_model)

    tmap = graph.tensors()
    MAX_VOXELS = tmap["voxels"].shape[0]
    print("ra35 DEBUG VOXEL_SIZE_Y, VOXEL_SIZE_X ", VOXEL_SIZE_Y, VOXEL_SIZE_X)

    VOXEL_ARRAY = np.array([int(VOXEL_SIZE_Y), int(VOXEL_SIZE_X)])

    # voxels: [V, P, C']
    # V is the maximum number of voxels per frame
    # P is the maximum number of points per voxel
    # C' is the number of channels(features) per point in voxels.
    input_new = gs.Variable(name="voxels", dtype=np.float32, shape=(MAX_VOXELS, MAX_POINTS_PER_VOXEL, 10))

    # voxel_idxs: [V, 4]
    # V is the maximum number of voxels per frame
    # 4 is just the length of indexs encoded as (frame_id, z, y, x).
    X = gs.Variable(name="voxel_idxs", dtype=np.int32, shape=(MAX_VOXELS, 4))

    # voxel_num: [1]
    # Gives valid voxels number for each frame
    Y = gs.Variable(name="voxel_num", dtype=np.int32, shape=(1,))

    first_node_after_pillarscatter = [node for node in graph.nodes if node.op == "Conv"][0]

    first_node_pillarvfe = [node for node in graph.nodes if node.op == "MatMul"][0]

    next_node = current_node = first_node_pillarvfe
    for i in range(6):
        next_node = [node for node in graph.nodes if len(node.inputs) != 0 and node.inputs[0] == current_node.outputs[0]][0]
        if i == 5:              # ReduceMax
            current_node.attrs['keepdims'] = [0]
            break
        current_node = next_node

    last_node_pillarvfe = current_node

    #merge some layers into one layer between inputs and outputs as below
    graph.inputs.append(Y)
    inputs = [last_node_pillarvfe.outputs[0], X, Y]
    outputs = [first_node_after_pillarscatter.inputs[0]]
    graph.replace_with_clip(inputs, outputs,  VOXEL_ARRAY)

    # Remove the now-dangling subgraph.
    graph.cleanup().toposort()

    #just keep some layers between inputs and outputs as below
    graph.inputs = [first_node_pillarvfe.inputs[0] , X, Y]
    graph.outputs = [tmap["cls_preds"], tmap["box_preds"], tmap["dir_cls_preds"]]

    graph.cleanup()

    #Rename the first tensor for the first layer
    graph.inputs = [input_new, X, Y]
    first_add = [node for node in graph.nodes if node.op == "MatMul"][0]
    first_add.inputs[0] = input_new

    graph.cleanup().toposort()

    return gs.export_onnx(graph)


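if __name__ == '__main__':
    # Standalone use: simplify_preprocess also needs the grid size and points per
    # voxel. 496, 432 and 32 match the default KITTI pointpillar.yaml; replace
    # them with the values from your own config.
    model_file = "pointpillar-native-sim.onnx"
    simplify_preprocess(onnx.load(model_file), 496, 432, 32)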
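To make the graph surgery in `simplify_preprocess` easier to reason about, here is a NumPy reference for the two steps it stitches together, under my reading of the code: the ReduceMax whose `keepdims` is set to 0 (so pillar features come out as [V, C] rather than [V, 1, C]), and the `PPScatterPlugin` node that places pillar features onto the `dense_shape` grid using `voxel_idxs` and `voxel_num`. This is not the plugin's CUDA implementation; the function names and the channel-first canvas layout are assumptions. Note that P (MAX_POINTS_PER_VOXEL) disappears after the max-pool, so it only has to match between the engine's `voxels` input and the buffer that fills it; if those were laid out for different P values, the rows would be misinterpreted.

```python
import numpy as np


# Hypothetical reference for the ReduceMax(keepdims=0) step of PillarVFE.
def pillar_max_pool(point_features):
    """[V, P, C] -> [V, C]: max over the points axis."""
    return point_features.max(axis=1)


# Hypothetical reference for what PPScatterPlugin is assumed to compute.
def scatter_to_bev(pillar_features, voxel_idxs, voxel_num, nx, ny):
    """Scatter the first voxel_num pillars into a dense [C, ny, nx] canvas.

    voxel_idxs rows are (frame_id, z, y, x), as documented in
    simplify_preprocess; rows past voxel_num are padding and are ignored.
    """
    num_channels = pillar_features.shape[1]
    canvas = np.zeros((num_channels, ny, nx), dtype=pillar_features.dtype)
    n = int(voxel_num[0])
    ys = voxel_idxs[:n, 2]
    xs = voxel_idxs[:n, 3]
    canvas[:, ys, xs] = pillar_features[:n].T  # (C, n) values into (C, n) slots
    return canvas


# Usage: 2 pillars x 3 points x 4 channels, only the first pillar is valid.
feats = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
pooled = pillar_max_pool(feats)
idxs = np.array([[0, 0, 1, 2], [0, 0, 0, 0]], dtype=np.int32)
bev = scatter_to_bev(pooled, idxs, np.array([1], dtype=np.int32), nx=4, ny=3)
print(bev[:, 1, 2])  # [ 8.  9. 10. 11.]
```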

Hello, I am facing the same problem.

Since my point clouds are very dense, I increased MAX_POINTS_PER_VOXEL to 200, which gives much better results in PyTorch. However, after ONNX conversion and TensorRT inference, all predicted bounding boxes are wrong.