deepchecks / deepchecks

Deepchecks: Tests for Continuous Validation of ML Models & Data. Deepchecks is a holistic open-source solution for all of your AI & ML validation needs, enabling you to thoroughly test your data and models from research to production.

Home Page: https://docs.deepchecks.com/stable

Error adding custom scorers to SimpleModelComparison check

noamzbr opened this issue

Discussed in #2588

Originally posted by cam2rogers June 7, 2023
Hi there!

I have been having trouble passing a list of custom scorers to the SimpleModelComparison check; it works just fine with the default configuration. The issue is shown below: the check's display renders correctly, but an error is raised for the result of the check. I believe the issue stems from the CheckResult call in the run_logic() function.

[screenshot of the error from the original issue]

Is there an easy fix to this issue? Code is attached below for reproducibility.

from deepchecks.tabular.checks import SimpleModelComparison
from deepchecks.tabular import Dataset
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer

# Sample data
bc = load_breast_cancer(as_frame=True)
X = bc['data']
y = bc['target']

# Keep only the 'worst' features
X = X[[
    'worst radius', 'worst texture', 'worst perimeter', 'worst area',
    'worst smoothness', 'worst compactness', 'worst concavity',
    'worst concave points', 'worst symmetry', 'worst fractal dimension',
]]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2000, stratify=y)

# Fit model
model = XGBClassifier()
model.fit(X_train, y_train)

# Custom scorers
alt_scorers = ['accuracy', 'roc_auc', 'f1']

# Create train and test DC Dataset objects
train_ds = Dataset(X_train, label=y_train, cat_features=[], label_type='binary')
test_ds = Dataset(X_test, label=y_test, cat_features=[], label_type='binary')

# Simple model check
sm_check = SimpleModelComparison(strategy='most_frequent', scorers=alt_scorers)
sm_check.add_condition_gain_greater_than(0.25)
sm_check.run(train_ds, test_ds, model)
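
For reference, one variant that may be worth trying is passing the scorers as a dict mapping display names to sklearn scorer strings, which the deepchecks metrics guide describes as an accepted form for the scorers argument. The sketch below assumes that dict form is supported and captures the returned CheckResult explicitly; it is not confirmed whether this avoids the error above.

# Same check, with scorers given as {display name: sklearn scorer string}
# (assumption: the dict form is accepted by `scorers`; not confirmed to avoid the error)
alt_scorers_dict = {
    'Accuracy': 'accuracy',
    'ROC AUC': 'roc_auc',
    'F1': 'f1',
}

sm_check_dict = SimpleModelComparison(strategy='most_frequent', scorers=alt_scorers_dict)
sm_check_dict.add_condition_gain_greater_than(0.25)
result = sm_check_dict.run(train_ds, test_ds, model)
result.show()          # render the check's display
print(result.value)    # raw result value of the check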