Boolean data columns lead to wrong predictions
pschleiter opened this issue
- ebm2onnx version: 3.1.1
- Python version: 3.10.13
- Operating System: Windows
Description
I would like to convert an ExplainableBoostingClassifier into an ONNX model. The data contains columns with dtype bool. Converting to ONNX works, but the resulting ONNX model does not produce the same probabilities (not even close).
The workaround I found was to convert the columns of type bool to int. With that change, the probabilities were almost equal.
Although the workaround works fine for me, I would be thankful if this could be fixed.
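The workaround described above (casting bool columns to int before fitting and before feeding the ONNX session) can be sketched as follows. This is a minimal sketch using plain pandas; the helper name `bool_columns_to_int` is hypothetical and is not part of ebm2onnx or interpret.

```python
import pandas as pd


def bool_columns_to_int(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with every bool column cast to int64 (0/1).

    Hypothetical helper mirroring the workaround from this issue; columns
    of any other dtype are left unchanged.
    """
    bool_cols = df.select_dtypes(include="bool").columns
    return df.astype({col: "int64" for col in bool_cols})


# Same toy data as in the reproducer, stored as bool columns
X = pd.DataFrame(
    {
        "feature1": [0, 0, 1, 1] * 8,
        "feature2": [0] * 16 + [1] * 16,
    }
).astype("bool")

X_int = bool_columns_to_int(X)
print(X_int.dtypes)
```

To apply the workaround, the model would be fitted on `X_int`, the converter would be given `ebm2onnx.get_dtype_from_pandas(X_int)`, and the `input_feed` would use the columns of `X_int`, so that training and ONNX inference both see integer inputs.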
What I Did
Here is a minimal example demonstrating the issue:
import pandas as pd
from interpret.glassbox import ExplainableBoostingClassifier
import ebm2onnx
import onnxruntime
# Small dataset
X = pd.DataFrame(
{
"feature1": [0,0,1,1] * 8,
"feature2": [0]*16 + [1]*16,
}
).astype("bool")
y = pd.Series( [0]*28 + [1]*4)
# Train model
model = ExplainableBoostingClassifier(
feature_types=["ordinal", "nominal"],
interactions=0
)
model.fit(X=X, y=y)
# Compute probability with the EBM model
proba = model.predict_proba(X)
# Compute probability with onnx
onnx_model = ebm2onnx.to_onnx(
model,
ebm2onnx.get_dtype_from_pandas(X),
predict_proba=True,
)
sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
onnx_proba = sess.run(
output_names=["probabilities"],
input_feed={
"feature1": X.loc[:, "feature1"].to_numpy(),
"feature2": X.loc[:, "feature2"].to_numpy(),
}
)[0]
from numpy.testing import assert_almost_equal
# Test
assert_almost_equal(
actual=onnx_proba,
desired=proba,
decimal=2,
)
# Example output (first two rows):
# onnx_proba[:2]
# array([[0.8545578 , 0.14544217],
# [0.8545578 , 0.14544217]], dtype=float32)
# proba[:2]
# array([[0.90316828, 0.09683172],
# [0.90316828, 0.09683172]])
This is supposed to work. Thanks for the code sample; I will investigate it.
v3.1.2 fixes this issue, but I am reopening it because one conversion failure remains.
If all values of a boolean feature are True, or all are False, the conversion fails.
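Until the remaining case is fixed, it can help to detect such constant boolean features before calling `ebm2onnx.to_onnx`. The sketch below is a hypothetical diagnostic in plain pandas (the name `constant_bool_columns` is an assumption, not an ebm2onnx API); it only flags the columns described above and does not by itself avoid the failure.

```python
import pandas as pd


def constant_bool_columns(df: pd.DataFrame) -> list[str]:
    """List bool columns whose values are all True or all False.

    Hypothetical helper: per this issue, conversion still failed for such
    features after v3.1.2, so flagging them up front makes the failure
    easier to diagnose.
    """
    bool_cols = df.select_dtypes(include="bool").columns
    return [col for col in bool_cols if df[col].nunique(dropna=False) == 1]


X = pd.DataFrame(
    {
        "always_true": [True, True, True, True],
        "mixed": [True, False, True, False],
        "count": [1, 2, 3, 4],  # not bool, so never reported
    }
)
print(constant_bool_columns(X))  # ['always_true']
```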
@MainRo Thanks a lot for fixing it.