openvinotoolkit / training_extensions

Train, Evaluate, Optimize, Deploy Computer Vision Models via OpenVINO™

Home Page: https://openvinotoolkit.github.io/training_extensions/

Repository from GitHub: https://github.com/openvinotoolkit/training_extensions

FMeasureCallable bug during inference with OV model

eugene123tw opened this issue

Describe the bug

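The error appears to originate in FMeasureCallable: `otx test` fails on the exported OpenVINO model when this metric is passed via `--metric`, while the default mAP metric from torchmetrics works fine.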

Steps to Reproduce

  1. otx train --config src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml --data_root tests/assets/car_tree_bug --work_dir otx-workspace
  2. otx export --work_dir otx-workspace --checkpoint otx-workspace/.latest/train/best_checkpoint.ckpt
  3. otx test --config otx-workspace/.latest/train/configs.yaml --checkpoint otx-workspace/.latest/export/exported_model.xml --metric otx.core.metrics.fmeasure.FMeasureCallable
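For convenience, the three steps above could be scripted roughly as follows. This is only a sketch: the helper name `build_repro_commands` and the `WORK_DIR` constant are made up for illustration, and the script only assembles and prints the commands exactly as listed in the steps rather than executing them.

```python
import shlex

# Hypothetical constant; the value is the work directory used in the steps above.
WORK_DIR = "otx-workspace"


def build_repro_commands(work_dir: str = WORK_DIR) -> list[list[str]]:
    """Return the three otx invocations from the steps above as argv lists."""
    latest = f"{work_dir}/.latest"
    # Step 1: train Mask R-CNN on the small test asset dataset.
    train = [
        "otx", "train",
        "--config", "src/otx/recipe/instance_segmentation/maskrcnn_r50.yaml",
        "--data_root", "tests/assets/car_tree_bug",
        "--work_dir", work_dir,
    ]
    # Step 2: export the best checkpoint to an OpenVINO model.
    export = [
        "otx", "export",
        "--work_dir", work_dir,
        "--checkpoint", f"{latest}/train/best_checkpoint.ckpt",
    ]
    # Step 3: test the exported model with FMeasureCallable as the metric.
    test = [
        "otx", "test",
        "--config", f"{latest}/train/configs.yaml",
        "--checkpoint", f"{latest}/export/exported_model.xml",
        "--metric", "otx.core.metrics.fmeasure.FMeasureCallable",
    ]
    return [train, export, test]


if __name__ == "__main__":
    # Print each command as a copy-pasteable shell line.
    for argv in build_repro_commands():
        print(shlex.join(argv))
```

To actually run the steps, each argv list could be passed to `subprocess.run(argv, check=True)` from the repository root; the failure is expected at the third command.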

Environment:

╭────────────────────────────────────────────────────────────────────────────────────────────────────────────── Traceback (most recent call last) ───────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ /home/yuchunli/git/otx-2x/src/otx/cli/cli.py:522 in run                                                                                                                                                                                                        │
│                                                                                                                                                                                                                                                                │
│   519 │   │   │   fn_kwargs = self.prepare_subcommand_kwargs(self.subcommand)                                                                                                                                                                                  │
│   520 │   │   │   fn = getattr(self.engine, self.subcommand)                                                                                                                                                                                                   │
│   521 │   │   │   try:                                                                                                                                                                                                                                         │
│ ❱ 522 │   │   │   │   fn(**fn_kwargs)                                                                                                                                                                                                                          │
│   523 │   │   │   except Exception:                                                                                                                                                                                                                            │
│   524 │   │   │   │   self.console.print_exception(width=self.console.width)                                                                                                                                                                                   │
│   525 │   │   │   │   raise                                                                                                                                                                                                                                    │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/src/otx/engine/engine.py:380 in test                                                                                                                                                                                                 │
│                                                                                                                                                                                                                                                                │
│   377 │   │   self._build_trainer(**kwargs)                                                                                                                                                                                                                    │
│   378 │   │                                                                                                                                                                                                                                                    │
│   379 │   │   with override_metric_callable(model=model, new_metric_callable=metric) as model:                                                                                                                                                                 │
│ ❱ 380 │   │   │   self.trainer.test(                                                                                                                                                                                                                           │
│   381 │   │   │   │   model=model,                                                                                                                                                                                                                             │
│   382 │   │   │   │   dataloaders=datamodule,                                                                                                                                                                                                                  │
│   383 │   │   │   )                                                                                                                                                                                                                                            │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:754 in test                                                                                                                                                   │
│                                                                                                                                                                                                                                                                │
│    751 │   │   self.state.fn = TrainerFn.TESTING                                                                                                                                                                                                               │
│    752 │   │   self.state.status = TrainerStatus.RUNNING                                                                                                                                                                                                       │
│    753 │   │   self.testing = True                                                                                                                                                                                                                             │
│ ❱  754 │   │   return call._call_and_handle_interrupt(                                                                                                                                                                                                         │
│    755 │   │   │   self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule                                                                                                                                                                   │
│    756 │   │   )                                                                                                                                                                                                                                               │
│    757                                                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44 in _call_and_handle_interrupt                                                                                                                                 │
│                                                                                                                                                                                                                                                                │
│    41 │   try:                                                                                                                                                                                                                                                 │
│    42 │   │   if trainer.strategy.launcher is not None:                                                                                                                                                                                                        │
│    43 │   │   │   return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer,                                                                                                                                                                  │
│ ❱  44 │   │   return trainer_fn(*args, **kwargs)                                                                                                                                                                                                               │
│    45 │                                                                                                                                                                                                                                                        │
│    46 │   except _TunerExitException:                                                                                                                                                                                                                          │
│    47 │   │   _call_teardown_hook(trainer)                                                                                                                                                                                                                     │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:794 in _test_impl                                                                                                                                             │
│                                                                                                                                                                                                                                                                │
│    791 │   │   ckpt_path = self._checkpoint_connector._select_ckpt_path(                                                                                                                                                                                       │
│    792 │   │   │   self.state.fn, ckpt_path, model_provided=model_provided, model_connected=sel                                                                                                                                                                │
│    793 │   │   )                                                                                                                                                                                                                                               │
│ ❱  794 │   │   results = self._run(model, ckpt_path=ckpt_path)                                                                                                                                                                                                 │
│    795 │   │   # remove the tensors from the test results                                                                                                                                                                                                      │
│    796 │   │   results = convert_tensors_to_scalars(results)                                                                                                                                                                                                   │
│    797                                                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:989 in _run                                                                                                                                                   │
│                                                                                                                                                                                                                                                                │
│    986 │   │   # ----------------------------                                                                                                                                                                                                                  │
│    987 │   │   # RUN THE TRAINER                                                                                                                                                                                                                               │
│    988 │   │   # ----------------------------                                                                                                                                                                                                                  │
│ ❱  989 │   │   results = self._run_stage()                                                                                                                                                                                                                     │
│    990 │   │                                                                                                                                                                                                                                                   │
│    991 │   │   # ----------------------------                                                                                                                                                                                                                  │
│    992 │   │   # POST-Training CLEAN UP                                                                                                                                                                                                                        │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1028 in _run_stage                                                                                                                                            │
│                                                                                                                                                                                                                                                                │
│   1025 │   │   self.lightning_module.zero_grad(**zero_grad_kwargs)                                                                                                                                                                                             │
│   1026 │   │                                                                                                                                                                                                                                                   │
│   1027 │   │   if self.evaluating:                                                                                                                                                                                                                             │
│ ❱ 1028 │   │   │   return self._evaluation_loop.run()                                                                                                                                                                                                          │
│   1029 │   │   if self.predicting:                                                                                                                                                                                                                             │
│   1030 │   │   │   return self.predict_loop.run()                                                                                                                                                                                                              │
│   1031 │   │   if self.training:                                                                                                                                                                                                                               │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:182 in _decorator                                                                                                                                             │
│                                                                                                                                                                                                                                                                │
│   179 │   │   else:                                                                                                                                                                                                                                            │
│   180 │   │   │   context_manager = torch.no_grad                                                                                                                                                                                                              │
│   181 │   │   with context_manager():                                                                                                                                                                                                                          │
│ ❱ 182 │   │   │   return loop_run(self, *args, **kwargs)                                                                                                                                                                                                       │
│   183 │                                                                                                                                                                                                                                                        │
│   184 │   return _decorator                                                                                                                                                                                                                                    │
│   185                                                                                                                                                                                                                                                          │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:141 in run                                                                                                                                              │
│                                                                                                                                                                                                                                                                │
│   138 │   │   │   finally:                                                                                                                                                                                                                                     │
│   139 │   │   │   │   self._restarting = False                                                                                                                                                                                                                 │
│   140 │   │   self._store_dataloader_outputs()                                                                                                                                                                                                                 │
│ ❱ 141 │   │   return self.on_run_end()                                                                                                                                                                                                                         │
│   142 │                                                                                                                                                                                                                                                        │
│   143 │   def setup_data(self) -> None:                                                                                                                                                                                                                        │
│   144 │   │   trainer = self.trainer                                                                                                                                                                                                                           │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:253 in on_run_end                                                                                                                                       │
│                                                                                                                                                                                                                                                                │
│   250 │   │   self.trainer._logger_connector._evaluation_epoch_end()                                                                                                                                                                                           │
│   251 │   │                                                                                                                                                                                                                                                    │
│   252 │   │   # hook                                                                                                                                                                                                                                           │
│ ❱ 253 │   │   self._on_evaluation_epoch_end()                                                                                                                                                                                                                  │
│   254 │   │                                                                                                                                                                                                                                                    │
│   255 │   │   logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory                                                                                                                                                                   │
│   256 │   │   # include any logged outputs on epoch_end                                                                                                                                                                                                        │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:329 in _on_evaluation_epoch_end                                                                                                                         │
│                                                                                                                                                                                                                                                                │
│   326 │   │                                                                                                                                                                                                                                                    │
│   327 │   │   hook_name = "on_test_epoch_end" if trainer.testing else "on_validation_epoch_end                                                                                                                                                                 │
│   328 │   │   call._call_callback_hooks(trainer, hook_name)                                                                                                                                                                                                    │
│ ❱ 329 │   │   call._call_lightning_module_hook(trainer, hook_name)                                                                                                                                                                                             │
│   330 │   │                                                                                                                                                                                                                                                    │
│   331 │   │   trainer._logger_connector.on_epoch_end()                                                                                                                                                                                                         │
│   332                                                                                                                                                                                                                                                          │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:157 in _call_lightning_module_hook                                                                                                                               │
│                                                                                                                                                                                                                                                                │
│   154 │   pl_module._current_fx_name = hook_name                                                                                                                                                                                                               │
│   155 │                                                                                                                                                                                                                                                        │
│   156 │   with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hoo                                                                                                                                                                 │
│ ❱ 157 │   │   output = fn(*args, **kwargs)                                                                                                                                                                                                                     │
│   158 │                                                                                                                                                                                                                                                        │
│   159 │   # restore current_fx when nested context                                                                                                                                                                                                             │
│   160 │   pl_module._current_fx_name = prev_fx_name                                                                                                                                                                                                            │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/src/otx/core/model/base.py:248 in on_test_epoch_end                                                                                                                                                                                  │
│                                                                                                                                                                                                                                                                │
│   245 │                                                                                                                                                                                                                                                        │
│   246 │   def on_test_epoch_end(self) -> None:                                                                                                                                                                                                                 │
│   247 │   │   """Callback triggered when the test epoch ends."""                                                                                                                                                                                               │
│ ❱ 248 │   │   self._log_metrics(self.metric, "test")                                                                                                                                                                                                           │
│   249 │                                                                                                                                                                                                                                                        │
│   250 │   def setup(self, stage: str) -> None:                                                                                                                                                                                                                 │
│   251 │   │   """Lightning hook that is called at the beginning of fit (train + validate), val                                                                                                                                                                 │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/src/otx/core/model/instance_segmentation.py:698 in _log_metrics                                                                                                                                                                      │
│                                                                                                                                                                                                                                                                │
│   695 │   def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs)                                                                                                                                                                 │
│   696 │   │   best_confidence_threshold = self.hparams.get("best_confidence_threshold", 0.0)                                                                                                                                                                   │
│   697 │   │   compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}                                                                                                                                                                        │
│ ❱ 698 │   │   return super()._log_metrics(meter, key, **compute_kwargs)                                                                                                                                                                                        │
│   699                                                                                                                                                                                                                                                          │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/src/otx/core/model/base.py:324 in _log_metrics                                                                                                                                                                                       │
│                                                                                                                                                                                                                                                                │
│   321 │   │   │   msg = f"These keyword arguments are removed since they are not in the functi                                                                                                                                                                 │
│   322 │   │   │   logger.debug(msg)                                                                                                                                                                                                                            │
│   323 │   │                                                                                                                                                                                                                                                    │
│ ❱ 324 │   │   results: dict[str, Tensor] = meter.compute(**filtered_kwargs)                                                                                                                                                                                    │
│   325 │   │                                                                                                                                                                                                                                                    │
│   326 │   │   if not isinstance(results, dict):                                                                                                                                                                                                                │
│   327 │   │   │   raise TypeError(results)                                                                                                                                                                                                                     │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/torchmetrics/metric.py:616 in wrapped_func                                                                                                                                                         │
│                                                                                                                                                                                                                                                                │
│    613 │   │   │   │   should_sync=self._to_sync,                                                                                                                                                                                                              │
│    614 │   │   │   │   should_unsync=self._should_unsync,                                                                                                                                                                                                      │
│    615 │   │   │   ):                                                                                                                                                                                                                                          │
│ ❱  616 │   │   │   │   value = _squeeze_if_scalar(compute(*args, **kwargs))                                                                                                                                                                                    │
│    617 │   │   │                                                                                                                                                                                                                                               │
│    618 │   │   │   if self.compute_with_cache:                                                                                                                                                                                                                 │
│    619 │   │   │   │   self._computed = value                                                                                                                                                                                                                  │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/src/otx/core/metrics/fmeasure.py:711 in compute                                                                                                                                                                                      │
│                                                                                                                                                                                                                                                                │
│   708 │   │                                                                                                                                                                                                                                                    │
│   709 │   │   if best_confidence_threshold is not None:                                                                                                                                                                                                        │
│   710 │   │   │   (index,) = np.where(                                                                                                                                                                                                                         │
│ ❱ 711 │   │   │   │   np.isclose(list(np.arange(*boxes_pair.confidence_range)), best_confidenc                                                                                                                                                                 │
│   712 │   │   │   )                                                                                                                                                                                                                                            │
│   713 │   │   │   computed_f_measure = result.per_confidence.all_classes_f_measure_curve[int(i                                                                                                                                                                 │
│   714 │   │   else:                                                                                                                                                                                                                                            │
│                                                                                                                                                                                                                                                                │
│ /home/yuchunli/git/otx-2x/venv/lib/python3.10/site-packages/numpy/core/numeric.py:2345 in isclose                                                                                                                                                              │
│                                                                                                                                                                                                                                                                │
│   2342 │   #       timedelta works if `atol` is an integer or also a timedelta.                                                                                                                                                                                │
│   2343 │   #       Although, the default tolerances are unlikely to be useful                                                                                                                                                                                  │
│   2344 │   if y.dtype.kind != "m":                                                                                                                                                                                                                             │
│ ❱ 2345 │   │   dt = multiarray.result_type(y, 1.)                                                                                                                                                                                                              │
│   2346 │   │   y = asanyarray(y, dtype=dt)                                                                                                                                                                                                                     │
│   2347 │                                                                                                                                                                                                                                                       │
│   2348 │   xfin = isfinite(x)                                                                                                                                                                                                                                  │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
DTypePromotionError: The DType <class 'numpy._FloatAbstractDType'> could not be promoted by <class 'numpy.dtypes.StrDType'>. This means that no common DType exists for the given inputs. For example they cannot be stored in a single array unless the dtype is 
`object`. The full list of DTypes is: (<class 'numpy.dtypes.StrDType'>, <class 'numpy._FloatAbstractDType'>)
Backend TkAgg is interactive backend. Turning interactive mode on.
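
The traceback suggests that `best_confidence_threshold` reached `np.isclose` as a string rather than a float. That is an assumption read off the `StrDType` in the error message, not something confirmed in this issue. A minimal sketch of the failure and one possible guard follows; the `confidence_range` values and the helper name `find_threshold_index` are hypothetical and not taken from the OTX code:

```python
import numpy as np

# Hypothetical (start, stop, step) range, for illustration only.
confidence_range = (0.025, 1.0, 0.025)
thresholds = np.arange(*confidence_range)


def find_threshold_index(thresholds: np.ndarray, best_confidence_threshold) -> int:
    """Return the index of the threshold closest to the requested value.

    The value is cast to float first, so a string such as "0.5"
    (assumed here to be how the value arrives) does not reach np.isclose.
    """
    value = float(best_confidence_threshold)
    (index,) = np.where(np.isclose(thresholds, value))
    if index.size == 0:
        msg = f"{value} is not one of the sampled confidence thresholds"
        raise ValueError(msg)
    return int(index[0])


# Passing a string straight to np.isclose fails with a TypeError
# (reported as DTypePromotionError in the traceback above).
try:
    np.isclose(thresholds, "0.5")
except TypeError as err:
    print(type(err).__name__)

# With the cast, a string and a float resolve to the same index.
print(find_threshold_index(thresholds, "0.5") == find_threshold_index(thresholds, 0.5))
```

Casting at the point of use is only one option; converting the value where it is loaded would keep the metric code unchanged.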

#3441 - related PR is merged