inference time not as fast as expected
fabricecarles opened this issue · comments
Hi, thanks for open-sourcing the code and model weights
As I said in a previous post, I would like to use EfficientLoFTR in a comparative benchmark for our study.
I found strange results in my benchmark: at an input size of 1x1x256x256, the EfficientLoFTR inference time is close to 26 ms.
This is better than LoFTR, which runs at ~40 ms for this resolution, but very close to topicFMfast when measured on my GeForce RTX 2070 Mobile GPU.
Since topicFMfast is not in your benchmark, I would like to know whether I am making a mistake when using your code.
Here is my inference code:
import time
import cv2
import numpy as np
import pytorch_lightning as pl
import argparse
import pprint
import torch
import kornia as K
import kornia.feature as KF
import matplotlib.pyplot as plt
from kornia_moons.viz import draw_LAF_matches
from loguru import logger as loguru_logger
from src.config.default import get_cfg_defaults
from src.utils.profiler import build_profiler
from src.lightning.data import MultiSceneDataModule
from src.lightning.lightning_loftr import PL_LoFTR
def parse_args():
    """Parse the command-line options of the benchmark script.

    Script-specific options are registered first; the flags of
    ``pl.Trainer`` are then appended through
    ``pl.Trainer.add_argparse_args`` before ``sys.argv`` is parsed.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    # Custom parser that the pl.Trainer flags are added to below.
    # Trainer flag documentation:
    # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Configuration files, checkpoint and output locations.
    cli.add_argument('--data_cfg_path', type=str, default="configs/data/megadepth_test_1500.py", help='data config path')
    cli.add_argument('--main_cfg_path', type=str, default="configs/loftr/eloftr_optimized.py", help='main config path')
    cli.add_argument('--ckpt_path', type=str, default="weights/eloftr_outdoor.ckpt", help='path to the checkpoint')
    cli.add_argument('--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir")
    cli.add_argument('--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')

    # Data loading.
    cli.add_argument('--batch_size', type=int, default=1, help='batch_size per gpu')
    cli.add_argument('--num_workers', type=int, default=2)

    # Matching / pose-estimation overrides (None keeps the config value).
    cli.add_argument('--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
    cli.add_argument('--pixel_thr', type=float, default=None, help='modify the RANSAC threshold.')
    cli.add_argument('--ransac', type=str, default=None, help='modify the RANSAC method')

    # Dataset resize overrides.
    cli.add_argument('--scannetX', type=int, default=832, help='ScanNet resize X')
    cli.add_argument('--scannetY', type=int, default=832, help='ScanNet resize Y')
    cli.add_argument('--megasize', type=int, default=1152, help='MegaDepth resize')

    # Boolean switches and remaining evaluation knobs.
    cli.add_argument('--npe', action='store_true', default=False, help='')
    cli.add_argument('--fp32', action='store_true', default=False, help='')
    cli.add_argument('--ransac_times', type=int, default=None, help='repeat ransac multiple times for more robust evaluation')
    cli.add_argument('--rmbd', type=int, default=None, help='remove border matches')
    cli.add_argument('--deter', action='store_true', default=False, help='use deterministic mode for testing')

    cli = pl.Trainer.add_argparse_args(cli)
    return cli.parse_args()
def inplace_relu(m):
    """Set ``m.inplace = True`` when the class name of ``m`` contains 'ReLU'.

    Objects whose class name does not contain 'ReLU' are left untouched.
    The function returns ``None``; ``m`` is modified in place.
    """
    is_relu = 'ReLU' in m.__class__.__name__
    if is_relu:
        m.inplace = True
if __name__ == '__main__':
    # Stand-alone latency benchmark: build the config, load the matcher,
    # time repeated forward passes on one image pair, then plot the matches.

    # parse arguments
    args = parse_args()
    pprint.pprint(vars(args))

    # init default-cfg and merge it with the main- and data-cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    if args.deter:
        torch.backends.cudnn.deterministic = True
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility

    # tune when testing
    if args.thr is not None:
        config.LOFTR.MATCH_COARSE.THR = args.thr

    if args.scannetX is not None and args.scannetY is not None:
        config.DATASET.SCAN_IMG_RESIZEX = args.scannetX
        config.DATASET.SCAN_IMG_RESIZEY = args.scannetY
    if args.megasize is not None:
        config.DATASET.MGDPT_IMG_RESIZE = args.megasize

    # NOTE(review): indentation was lost in the pasted source; the `else`
    # below is assumed to pair with `if args.npe` -- confirm against the
    # original script.
    if args.npe:
        if config.LOFTR.COARSE.ROPE:
            assert config.DATASET.NPE_NAME is not None
        if config.DATASET.NPE_NAME is not None:
            if config.DATASET.NPE_NAME == 'megadepth':
                config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.MGDPT_IMG_RESIZE, config.DATASET.MGDPT_IMG_RESIZE]  # [832, 832, 1152, 1152]
            elif config.DATASET.NPE_NAME == 'scannet':
                # NOTE(review): SCAN_IMG_RESIZEX is used for both trailing
                # entries; confirm whether SCAN_IMG_RESIZEY was intended.
                config.LOFTR.COARSE.NPE = [832, 832, config.DATASET.SCAN_IMG_RESIZEX, config.DATASET.SCAN_IMG_RESIZEX]  # [832, 832, 640, 640]
    else:
        config.LOFTR.COARSE.NPE = [832, 832, 832, 832]

    if args.ransac_times is not None:
        config.LOFTR.EVAL_TIMES = args.ransac_times

    if args.rmbd is not None:
        config.LOFTR.MATCH_COARSE.BORDER_RM = args.rmbd

    if args.pixel_thr is not None:
        config.TRAINER.RANSAC_PIXEL_THR = args.pixel_thr

    if args.ransac is not None:
        config.TRAINER.POSE_ESTIMATION_METHOD = args.ransac
        # LO-RANSAC: a pixel threshold still at 0.5 is replaced by 2.0.
        if args.ransac == 'LO-RANSAC' and config.TRAINER.RANSAC_PIXEL_THR == 0.5:
            config.TRAINER.RANSAC_PIXEL_THR = 2.0

    if args.fp32:
        config.LOFTR.FP16 = False

    loguru_logger.info("Args and config initialized!")

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
    loguru_logger.info("LoFTR-lightning initialized!")

    model.matcher = model.matcher.eval().cuda()
    # NOTE(review): the matcher is timed as loaded; if the backbone supports
    # structural reparameterization, applying it here before timing may be
    # required for representative numbers -- confirm with the project docs.
    # model.matcher = torch.compile(model.matcher)

    print('start inference')

    # Load example images
    img0_pth = "assets/01.BMP"
    img1_pth = "assets/02.BMP"
    img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)
    img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals an unreadable file by returning None instead of
    # raising; fail here with the offending path rather than inside resize.
    for pth, raw in ((img0_pth, img0_raw), (img1_pth, img1_raw)):
        if raw is None:
            raise FileNotFoundError(f"could not read image: {pth}")

    size = 256
    img0_raw = cv2.resize(img0_raw, (size, size))  # input size should be divisible by 8
    img1_raw = cv2.resize(img1_raw, (size, size))

    # [None][None] adds two leading axes -> 1x1xHxW; pixel values scaled by 1/255.
    img0 = torch.from_numpy(img0_raw)[None][None].cuda() / 255.
    img1 = torch.from_numpy(img1_raw)[None][None].cuda() / 255.
    data_dict = {'image0': img0, 'image1': img1, 'pair_names': ('01', '02'), 'dataset_name' : 'scan4all'}

    print('image 0 size', img0.shape)
    print('image 1 size', img1.shape)

    # inference (with warmup)
    num_warmup = 5  # leading runs excluded from the reported statistics
    num_inferences = 105
    times = np.zeros(num_inferences)
    with torch.no_grad():
        with torch.autocast(enabled=config.LOFTR.FP16, device_type='cuda', dtype=torch.float16):
            for i in range(num_inferences):
                # Synchronize before and after the call so the measured
                # interval covers the GPU work of this forward pass only.
                torch.cuda.current_stream().synchronize()
                # perf_counter is monotonic and high-resolution, unlike the
                # wall clock returned by time.time().
                t0 = time.perf_counter()
                model.matcher(data_dict)
                torch.cuda.current_stream().synchronize()
                t1 = time.perf_counter()
                current_time = (t1 - t0) * 1000
                print(f"inference pytorch {current_time :.1f} [ms]")
                times[i] = current_time

    print('times ', times)
    timed = times[num_warmup:]
    print(f"average inference time = {timed.mean() :.1f} [ms] std {timed.std() :.1f} for {num_inferences - num_warmup} samples")

    # The matcher's outputs are read back from data_dict (last iteration).
    print('data_dict.keys()', data_dict.keys())
    print('mconf', data_dict['mconf'].shape)
    print('data_dict', data_dict['mkpts0_f'].shape)
    print('data_dict', data_dict['mkpts1_f'].shape)
    # print('mconf', data_dict['mconf'])

    mkpts0 = data_dict['mkpts0_f']
    mkpts1 = data_dict['mkpts1_f']
    mconf = data_dict['mconf']

    mkpts0 = mkpts0.cpu().numpy()
    mkpts1 = mkpts1.cpu().numpy()

    # inliers filtering: boolean mask of matches with confidence above 0.2
    # (presumably shape (N, 1) after unsqueeze(1) -- verify mconf is 1-D)
    mconf = mconf.unsqueeze(1)
    mconf = mconf.cpu().numpy()
    mconf = mconf > 0.2
    print("mconf", mconf.shape)

    # plot matches
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    draw_LAF_matches(
        KF.laf_from_center_scale_ori(
            torch.from_numpy(mkpts0).view(1, -1, 2),
            torch.ones(mkpts0.shape[0]).view(1, -1, 1, 1),
            torch.ones(mkpts0.shape[0]).view(1, -1, 1),
        ),
        KF.laf_from_center_scale_ori(
            torch.from_numpy(mkpts1).view(1, -1, 2),
            torch.ones(mkpts1.shape[0]).view(1, -1, 1, 1),
            torch.ones(mkpts1.shape[0]).view(1, -1, 1),
        ),
        # Match i in image 0 pairs with match i in image 1.
        torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2),
        K.tensor_to_image(img0),
        K.tensor_to_image(img1),
        mconf,
        draw_dict={"inlier_color": (0.2, 1, 0.2), "tentative_color": None, "feature_color": (0.2, 0.5, 1), "vertical": False},
        ax=ax
    )
    # The file name records the total match count (len of the mask, not the
    # number above threshold) and the duration of the last timed run.
    plt.savefig(f"assets/output_filtered_by_confidence_size{size}_num-match{len(mconf)}_{(t1 - t0) *1000 :0.1f}_ms.png")
here is my environment setup :
conda env create -f environment.yaml
conda activate eloftr
pip install torch==2.0.0+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
pip install kornia_moons
python inference.py
- average inference time = 26.0 [ms] std 2.5 for 100 samples at 1x1x256x256
- average inference time for topicFMfast is close to 30 ms on the same image pair, same PC, and same GPU
Did I miss something that would make your code run more efficiently?
Best regards
Thank you for sharing your results. Here are some suggestions at first glance.
-
Please apply reparameterization before inference (i.e. `self.matcher = reparameter(self.matcher)`),
which has a significant effect on inference speed.
-
We only compare with TopicFM, as TopicFM+ employs a significantly higher number of OpenCV RANSAC iterations (10k vs. the standard 1k in other baselines) in their code, which greatly improves AUC but also substantially slows down RANSAC. Evaluating inference speed without considering accuracy isn't meaningful.
| Method | MegaDepth AUC@(5,10,20) |
|---|---|
| LoFTR | 52.8 / 69.2 / 81.2 |
| TopicFM | 54.1 / 70.1 / 81.6 |
| TopicFM+ | 52.2 / 68.8 / 81.1 |
| Ours | 56.4 / 72.2 / 83.5 |
| Ours (Opt.) | 55.4 / 71.4 / 82.9 |
| TopicFM+ (10k) | 58.2 / 72.8 / 83.2 |
| Ours (10k) | 59.3 / 74.1 / 84.6 |

-
If you run our timing scripts on an RTX3090, you will achieve the exact timings of 34ms (Full) and 27ms (Opt.) as reported in our paper.
We will provide a jupyter notebook demo to show how to use our model later, please stay tuned!
Thanks for your advice. With `self.matcher = reparameter(self.matcher)`, the inference time improved a little.
In your README, you plan to "Add options of flash-attention and torch.compiler for better performance".
Are any other performance improvements expected?
Sorry for the late reply. Yes, there's also FP16 inference. We have already modified some of the code and added a Jupyter notebook to demonstrate how to use FP16 inference (on modern GPUs) to accelerate our model. This will provide even faster speeds than mixed precision with almost no loss in accuracy.