qiqihaer / RandLA-Net-pytorch

A RandLA-Net implementation in PyTorch


Question about how to get the prediction of all points in each frame?

LeopoldACC opened this issue · comments

Hi @qiqihaer, thanks for reproducing this work in PyTorch.
I want to get predictions for all points in each frame, but I found that neither RandLA-Net nor KPConv supports that form of input: both works use a sampler that generates sub-sampled crops in a cyclic way. If I change the code as below to use the whole point cloud, the evaluation results are far worse than the published results and your test results:

eval mean acc: 0.182902
eval mean loss: 53.351186
mean IoU: 4.3
IoU: 0.34  0.13  0.00  0.50  0.13  0.00  0.00  0.00  6.07  2.26  6.12  0.31 26.94  0.13 23.01  2.05 10.04  2.29  0.99
# Imports needed by these methods (module level in the full file); DP,
# remap_lut, and grid_size come from the repo's data-preparation utilities.
import os
import pickle
import numpy as np
from os.path import join
from sklearn.neighbors import KDTree

    def __getitem__(self, item):
        # Each item is one spatially regular crop, not a whole frame
        selected_pc, selected_labels, selected_idx, cloud_ind = self.spatially_regular_gen(item)
        return selected_pc, selected_labels, selected_idx, cloud_ind



    def spatially_regular_gen(self, item):
        # Generator loop

        if self.mode != 'test':
            cloud_ind = item
            pc_path = self.data_list[cloud_ind]
            pc, tree, labels = self.get_data(pc_path)
            # crop a small point cloud
            pick_idx = np.random.choice(len(pc), 1)[0]
            selected_pc, selected_labels, selected_idx = self.crop_pc(pc, labels, tree, pick_idx)
        else:
            # Pick the least-visited frame, and within it the least-visited
            # point, according to the running `possibility` scores
            cloud_ind = int(np.argmin(self.min_possibility))
            pick_idx = np.argmin(self.possibility[cloud_ind])
            pc_path = self.data_list[cloud_ind]
            pc, tree, labels = self.get_data(pc_path)
            selected_pc, selected_labels, selected_idx = self.crop_pc(pc, labels, tree, pick_idx)

            # Update the possibility of the selected points: points closer to
            # the picked centre get a larger increase, so they are less likely
            # to be sampled again
            dists = np.sum(np.square((selected_pc - pc[pick_idx]).astype(np.float32)), axis=1)
            delta = np.square(1 - dists / np.max(dists))
            self.possibility[cloud_ind][selected_idx] += delta
            self.min_possibility[cloud_ind] = np.min(self.possibility[cloud_ind])

        return selected_pc.astype(np.float32), selected_labels.astype(np.int32), selected_idx.astype(np.int32), np.array([cloud_ind], dtype=np.int32)

    def get_data(self, file_path):
        seq_id = file_path.split('/')[-3]
        frame_id = file_path.split('/')[-1][:-4]
        point_file_path = file_path
        # Derive the matching label path: .../velodyne/<frame>.bin -> .../labels/<frame>.label
        path_list = file_path.split('/')
        path_list[-2] = "labels"
        path_list[-1] = frame_id + '.label'
        label_file_path = '/'.join(path_list)
        kd_tree_path = join(self.dataset_path, seq_id, 'KDTree', frame_id + '.pkl')
        if os.path.exists(kd_tree_path):
            # Read pkl with search tree
            with open(kd_tree_path, 'rb') as f:
                search_tree = pickle.load(f)
            points = np.array(search_tree.data, copy=False)
            # Load the cached sub-sampled labels; test sequences (>= 11) carry
            # no ground truth, so use dummy zeros there
            if int(seq_id) >= 11:
                labels = np.zeros(np.shape(points)[0], dtype=np.uint8)
            else:
                label_path = join(self.dataset_path, seq_id, 'labels', frame_id + '.npy')
                labels = np.squeeze(np.load(label_path))
        else:
            # No cached tree yet: sub-sample the raw scan and build one on the fly
            points_origin = DP.load_pc_kitti(point_file_path)
            labels = DP.load_label_kitti(label_file_path, remap_lut)
            sub_points, sub_labels = DP.grid_sub_sampling(points_origin, labels=labels, grid_size=grid_size)
            search_tree = KDTree(sub_points)
            # KDTree_save = join(KDTree_path_out, str(scan_id[:-4]) + '.pkl')
            # np.save(join(pc_path_out, scan_id)[:-4], sub_points)
            # np.save(join(label_path_out, scan_id)[:-4], sub_labels)
            # with open(KDTree_save, 'wb') as f:
            #     pickle.dump(search_tree, f)
            # if seq_id == '08':
            #     proj_path = join(seq_path_out, 'proj')
            #     os.makedirs(proj_path) if not exists(proj_path) else None
            #     proj_inds = np.squeeze(search_tree.query(points_origin, return_distance=False))
            #     proj_inds = proj_inds.astype(np.int32)
            #     proj_save = join(proj_path, str(scan_id[:-4]) + '_proj.pkl')
            #     with open(proj_save, 'wb') as f:
            #         pickle.dump([proj_inds], f)
            points = np.array(search_tree.data, copy=False)
            labels = sub_labels
        return points, search_tree, labels
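
For context, the `possibility` machinery in the test branch above supports exactly this use case: rather than pushing a whole frame through the network at once, crops are drawn repeatedly from the least-visited regions and per-point class scores are accumulated until every point has been covered. A rough sketch of that voting loop, assuming a `dataset` in test mode and a `model` returning per-point logits (the real dataloader feeds a batched input dict with precomputed neighbour indices, omitted here; the 0.5 threshold and 0.95 smoothing factor are illustrative):

import numpy as np
import torch

num_classes = 19  # SemanticKITTI classes after remapping (matches the 19 IoUs above)

# One vote buffer per frame; `possibility` has one score per sub-sampled point,
# so its lengths give the per-frame point counts
test_probs = [np.zeros((len(p), num_classes), dtype=np.float32)
              for p in dataset.possibility]

while np.min(dataset.min_possibility) < 0.5:        # until every region has been visited
    # in test mode, spatially_regular_gen ignores `item` and picks the
    # least-visited frame and point itself
    pc, labels, idx, cloud_ind = dataset[0]
    with torch.no_grad():
        logits = model(torch.from_numpy(pc)[None])  # (1, N, num_classes); simplified input
    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    ci = int(cloud_ind[0])
    # exponential moving average smooths the votes over repeated visits
    test_probs[ci][idx] = 0.95 * test_probs[ci][idx] + 0.05 * probs

sub_preds = [p.argmax(axis=1) for p in test_probs]  # one label per sub-sampled point

The resulting `sub_preds` still live on the grid-subsampled clouds; mapping them back to every original point is a separate projection step.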

If you would like, see this repo and this PR: tsunghan-wu/RandLA-Net-pytorch#8 (comment)
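
A common way to recover predictions for every original point, consistent with the commented-out `proj_inds` code in `get_data` above, is to predict on the grid-subsampled cloud and then give each original point the label of its nearest sub-sampled neighbour through the saved KDTree. A minimal sketch; `sub_preds` and `full_points` are illustrative names, not repo API:

import numpy as np

def project_to_full_cloud(search_tree, sub_preds, full_points):
    # search_tree: sklearn KDTree over the sub-sampled points (as returned by get_data)
    # sub_preds:   (M,) predicted labels for the M sub-sampled points
    # full_points: (N, 3) full-resolution scan, e.g. from DP.load_pc_kitti
    # Each original point inherits the label of its nearest sub-sampled point.
    proj_inds = np.squeeze(search_tree.query(full_points, return_distance=False))
    return sub_preds[proj_inds.astype(np.int32)]

Evaluation against the full-resolution ground truth is then just a comparison between the projected labels and the raw .label file, while the network itself only ever sees the fixed-size crops it was trained on.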