deepchecks / deepchecks

Deepchecks: Tests for Continuous Validation of ML Models & Data. Deepchecks is a holistic open-source solution for all of your AI & ML validation needs, enabling you to thoroughly test your data and models from research to production.

Home Page: https://docs.deepchecks.com/stable



[BUG] MyModelWrapper is incorrectly interpreted as "Regressor" for classification metrics

chris-santiago opened this issue

Describe the bug
I'm unable to use scorers that rely on the predict_proba method in the model evaluation suite. I receive this warning:

[deepchecks][WARNING] - ROC AUC failed with error message - "MyModelWrapper should either be a classifier to be used with response_method=predict_proba or the response_method should be 'predict'. Got a regressor with response_method=predict_proba instead.". setting scores as None

A little digging shows that it originates in the deepchecks.tabular.metric_utils.scorers module, where the user's model is wrapped before being passed to scikit-learn scorer functions.

The error occurs when the sklearn scorer checks whether the estimator is a classifier or a regressor, and whether the scorer calls the predict or the predict_proba method. See this if-else block.

Because the MyModelWrapper class (used in the DeepcheckScorer class) does not declare an _estimator_type attribute, sklearn treats the wrapper as a regressor and raises a ValueError from the else branch:

else:  # estimator is a regressor
    if response_method != "predict":
        raise ValueError(
            f"{estimator.__class__.__name__} should either be a classifier to be "
            f"used with response_method={response_method} or the response_method "
            "should be 'predict'. Got a regressor with response_method="
            f"{response_method} instead."
        )
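To make the failure mode concrete, here is a simplified, pure-Python re-creation of the dispatch described above. It is not scikit-learn's code: is_classifier_like mirrors how scikit-learn 1.3 decides classifier status (reading _estimator_type with getattr and a default of None), while check_response_method and BareWrapper are hypothetical names used only for illustration.

```python
class BareWrapper:
    """Hypothetical stand-in for a wrapper that, like MyModelWrapper,
    exposes predict/predict_proba but declares no _estimator_type."""

    def predict(self, X):
        return [0 for _ in X]

    def predict_proba(self, X):
        return [[1.0, 0.0] for _ in X]


def is_classifier_like(estimator):
    # Simplified classifier check: anything that does not declare
    # _estimator_type == "classifier" is treated as a non-classifier.
    return getattr(estimator, "_estimator_type", None) == "classifier"


def check_response_method(estimator, response_method):
    # Simplified version of the if-else block quoted above.
    if is_classifier_like(estimator):
        return getattr(estimator, response_method)
    # Otherwise the estimator is treated as a regressor.
    if response_method != "predict":
        raise ValueError(
            f"{estimator.__class__.__name__} should either be a classifier to be "
            f"used with response_method={response_method} or the response_method "
            "should be 'predict'. Got a regressor with response_method="
            f"{response_method} instead."
        )
    return estimator.predict


wrapper = BareWrapper()
check_response_method(wrapper, "predict")  # accepted: plain predict is allowed

message = None
try:
    check_response_method(wrapper, "predict_proba")
except ValueError as err:
    message = str(err)  # same wording as the deepchecks warning above
print(message)
```

Note that the wrapper does implement predict_proba; the rejection comes purely from the missing type declaration, not from a missing method.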

Suggested Fix

Adding self._estimator_type = "classifier" to MyModelWrapper.__init__ may resolve this issue.
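A minimal sketch of the suggested fix, under stated assumptions: PatchedModelWrapper is a hypothetical stand-in for MyModelWrapper (whose real constructor is not shown in this issue), ToyModel is a made-up binary classifier, and the is_classification flag is a hypothetical parameter. With scikit-learn 1.3.x, sklearn.base.is_classifier reads exactly this _estimator_type attribute; newer scikit-learn releases moved estimator-type detection to the __sklearn_tags__ tags API, so the attribute alone may not be sufficient there.

```python
class ToyModel:
    """Hypothetical binary classifier used only for this example."""

    def predict(self, X):
        return [int(x > 0) for x in X]

    def predict_proba(self, X):
        # Probability of class 1 is 0.9 for positive inputs, 0.1 otherwise.
        return [[0.1, 0.9] if x > 0 else [0.9, 0.1] for x in X]


class PatchedModelWrapper:
    """Hypothetical wrapper illustrating the suggested fix."""

    def __init__(self, model, is_classification=True):
        self.model = model
        # The suggested fix: declare the estimator type so that the
        # classifier check accepts response_method="predict_proba".
        self._estimator_type = "classifier" if is_classification else "regressor"

    def predict(self, X):
        return self.model.predict(X)

    def predict_proba(self, X):
        return self.model.predict_proba(X)


wrapper = PatchedModelWrapper(ToyModel())
print(getattr(wrapper, "_estimator_type", None) == "classifier")  # True
print(wrapper.predict_proba([-1, 2]))  # [[0.9, 0.1], [0.1, 0.9]]
```

Setting the attribute from the task type, rather than hard-coding "classifier", keeps regression models from being mislabeled when the same wrapper serves both tasks.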

Environment

  • OS: Mac
  • Python Version: 3.11
  • Deepchecks Version: 0.18
  • sklearn Version: 1.3.2
