dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf

Home Page: https://dreamquark-ai.github.io/tabnet/


TabNetClassifier explainability

ranellout opened this issue · comments

Hello,
First of all, I want to thank you for the awesome package; I've had some really nice results with the TabNetClassifier!

I wanted to ask: is there a way to interpret the TabNetClassifier feature importances with SHAP values or some other nice visualization package?
I was able to visualize the feature importances, which is nice, but they don't show the direction of each feature's contribution to the prediction. The local explainability with the masks didn't work for me at all.
Thanks!!
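For reference, the mask-based local explainability mentioned above is exposed through the model's explain method. A minimal sketch, assuming clf is a fitted TabNetClassifier and X_valid is a numpy array:

# global importances: one normalized weight per feature, summing to 1
importances = clf.feature_importances_

# local importances: a per-sample feature attribution matrix, plus the
# raw attention masks for each decision step
explain_matrix, masks = clf.explain(X_valid)

Note that both outputs are non-negative attention weights: they rank features per sample but, as observed above, they don't show the direction of the contribution.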

The code I've tried for SHAP visualization:

import shap

# mask out features using up to 100 validation rows as the background distribution
background_adult = shap.maskers.Independent(X_valid, max_samples=100)
explainer = shap.Explainer(clf.predict_proba, background_adult)
shap_values = explainer(X_valid[:100])
shap.plots.beeswarm(shap_values)

The above code raised the following error: "The passed model is not callable and cannot be analyzed directly with the given masker!"

Other SHAP visualizations didn't work either.

I'm not looking for a SHAP implementation specifically, just some sort of visualization of the feature importances.

Thank you!

What happens if you change the explainer to this: explainer = shap.Explainer(clf, background_adult)

Unfortunately, it still raises the same error.
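For reference, one black-box fallback that sidesteps the callable check entirely is shap.KernelExplainer, which only needs a prediction function and a background dataset rather than a recognized model object. A minimal sketch, assuming clf is a fitted TabNetClassifier and X_valid is a numpy array (KernelExplainer is slow, so the sample sizes are kept small here):

import shap

# summarize the background distribution to keep KernelExplainer tractable
background = shap.sample(X_valid, 100)

# treat the model as a black box: only predict_proba is needed
explainer = shap.KernelExplainer(clf.predict_proba, background)

# explain a small batch; depending on the shap version this returns a list
# of per-class arrays or a single 3-D array
shap_values = explainer.shap_values(X_valid[:50])

# summary plot for one class (on older shap versions, index 1 selects the
# positive class in binary classification)
shap.summary_plot(shap_values[1], X_valid[:50])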

Does the issue still persist?

As it's independent of the repo, it will probably persist for quite a while!