dreamquark-ai / tabnet

PyTorch implementation of the TabNet paper: https://arxiv.org/pdf/1908.07442.pdf

Home Page: https://dreamquark-ai.github.io/tabnet/

Interpreting Sparsity on Global Importance

SantaTitular opened this issue

Hello again,

I'm using TabNetClassifier to study a binary classification problem. Although my balanced datasets are quite small (~1k samples), I'm testing accuracy with 1601 and 300 features and getting pretty good results (without overfitting). After tuning the hyperparameters according to your suggestions in the paper, selecting the best model, and computing the global feature importance on the original training dataset (global_expl = loaded_clf._compute_feature_importances(X_original)  # global explainability), I often get a very sparse vector for high AUCs (90%+) and a somewhat sparse one for lower AUCs (~75-80%). The sparse vector is often a single peak on one specific feature. Am I correct in interpreting that this single feature alone can be used to perform the binary classification? And why do the less optimal models give me several other, roughly equally important features?

Cheers,
Tomás

Hello @SantaTitular,

Just to clarify, this is not the official repository of the original TabNet paper, simply an unofficial PyTorch implementation.

Global feature importance gives you an overall idea of which features the attention layers look at most. However, TabNet has a built-in attention mechanism and can therefore select features separately for each input sample. That means a feature which is rarely used will have a low global importance; it does not mean that the feature is useless or never used by the model.

If you want to know which features are useless and can be dropped without any loss in prediction accuracy, you'll have to perform a different feature importance analysis: for example permutation importance, or simply removing weak columns from the training set and monitoring the drop in prediction accuracy.
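
The "remove columns and monitor the accuracy" approach can be sketched as a drop-column loop: retrain once per removed column and record how much the validation score falls. This is a minimal sketch, not part of pytorch-tabnet: `fit` and `score` are hypothetical callables standing in for your own TabNet training and evaluation code, and the toy nearest-centroid model and synthetic data exist only to keep the example runnable.

```python
import numpy as np


def drop_column_importance(fit, score, X_train, y_train, X_val, y_val):
    """Validation-score drop when each column is removed and the model retrained.

    fit(X, y) is a hypothetical callable returning a predict function;
    score(y_true, y_pred) is a higher-is-better metric.
    """
    n_features = X_train.shape[1]
    baseline = score(y_val, fit(X_train, y_train)(X_val))
    drops = np.zeros(n_features)
    for j in range(n_features):
        keep = [k for k in range(n_features) if k != j]  # every column except j
        predict = fit(X_train[:, keep], y_train)
        drops[j] = baseline - score(y_val, predict(X_val[:, keep]))
    return drops


def nearest_centroid_fit(X, y):
    """Toy stand-in model (not TabNet): predict the class with the closer mean."""
    c0, c1 = X[y == 0].mean(axis=0), X[y == 1].mean(axis=0)

    def predict(X_new):
        d0 = ((X_new - c0) ** 2).sum(axis=1)
        d1 = ((X_new - c1) ** 2).sum(axis=1)
        return (d1 < d0).astype(int)

    return predict


def accuracy(y_true, y_pred):
    return float((y_true == y_pred).mean())


# Synthetic data: column 0 carries the label, columns 1 and 2 are pure noise.
rng = np.random.default_rng(0)
n = 400
y = rng.integers(0, 2, n)
X = np.column_stack([y + 0.2 * rng.normal(size=n), rng.normal(size=(n, 2))])
drops = drop_column_importance(
    nearest_centroid_fit, accuracy, X[:300], y[:300], X[300:], y[300:]
)
```

With ~1k samples and hundreds of features, one retrain per column is expensive; retraining once per group of features is a cheaper variant of the same idea.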

To sum up:

  • you can't conclude that the most important feature alone reaches the same level of accuracy as the full set of features
  • you could also see no drop in accuracy when removing an important feature: the same information may be available in other columns. Imagine taking one important column and creating two new columns that sum to it. It's easier for the model to look only at the original column, but if that column is not available, the model can use the two others to recover the same information.
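
The two-columns-that-sum example can be checked numerically. The sketch below uses ordinary least squares as a stand-in model (an assumption made purely for illustration; it is not TabNet): removing the important column `x` costs nothing while `a` and `b` with `a + b == x` remain, whereas keeping only `a` loses the information.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500
x = rng.normal(size=n)          # the "important" column
a = rng.normal(size=n)          # random split of x ...
b = x - a                       # ... so that a + b == x exactly
target = 3.0 * x + 0.1 * rng.normal(size=n)


def fit_mse(features, y):
    """Ordinary least squares with an intercept; returns the training MSE."""
    design = np.column_stack([np.ones(len(y)), *features])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return float(np.mean((design @ coef - y) ** 2))


mse_with_x = fit_mse([x, a, b], target)    # all three columns
mse_without_x = fit_mse([a, b], target)    # x removed, a and b still span it
mse_only_a = fit_mse([a], target)          # the information is genuinely lost
```

Because `x` lies in the span of `a` and `b`, the first two fits have the same column space and therefore the same residuals.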

Hi @Optimox, thanks for the clarification!

Regarding the attention, maybe I'm confusing the global attention used for interpretability with the feature selection process itself. Although TabNet introduces sparsity in its feature selection, global attention is an aggregation of the local masks, so I was not expecting such sparse, Dirac-like outputs in high-accuracy cases. Since it is just a binary classification, I thought this meant TabNet did not need so many features and that we could discard the very low-importance ones. But I'm guessing that by keeping only a couple of features (the Dirac ones) and discarding the very low-importance ones, we fundamentally alter the information the TabNet model used to make its decisions? I guess that's why you suggest a different study, such as permutation importance or removing variables, to find out which features are not required.

With that, I do think your suggestions are correct. I'm just trying to work out how to better understand and interpret the results I'm getting.

Possibly outside the scope of the question, but have you tried using sklearn's permutation importance function with TabNetClassifier (binary)? I tried r = permutation_importance(net, X_test, y_test, scoring="balanced_accuracy", n_repeats=5, random_state=0) after training, but I just get a vector of zeros.

Mean aggregation is not perfect: without knowing the distribution of attention per example, you can't tell whether a feature is unused 99% of the time but very useful the other 1%. Also, unless a feature's attention score is exactly 0, it can still have an impact on the final score even with 1% attention.
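
The 99%/1% point can be illustrated with a synthetic per-sample attention matrix. This is a sketch under stated assumptions: the matrix is made up, and global importance is modelled as a normalized sum over samples, which is the kind of aggregation discussed in this thread rather than a copy of the library's code. In pytorch-tabnet, the `explain` method should give a per-sample matrix that could take the place of `masks`.

```python
import numpy as np

n_samples, n_features = 1000, 4
masks = np.zeros((n_samples, n_features))
masks[:, 0] = 1.0        # every sample puts all its attention on feature 0 ...
masks[:10, 0] = 0.0      # ... except 1% of the samples,
masks[:10, 3] = 1.0      # which rely entirely on feature 3


def global_importance(m):
    """Sum per-sample attention over samples, then normalize to sum to 1."""
    totals = m.sum(axis=0)
    return totals / totals.sum()


importance = global_importance(masks)
share_of_samples_using = (masks > 0).mean(axis=0)  # how often each feature is used
max_per_sample = masks.max(axis=0)                 # how strongly, when it is used
```

Feature 3 ends up with 1% global importance even though it fully determines the attention for the samples that use it; looking at `share_of_samples_using` and `max_per_sample` next to the global vector makes that visible.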

About the sklearn permutation importance function: I've never tried it with TabNet. It's possible that it's not compatible, as the repo is not 100% scikit-learn compatible. But this is a fairly simple algorithm that you can code yourself.
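
A hand-rolled permutation importance needs only a predict function and a metric. The sketch below is a minimal NumPy version, assuming `predict` maps a NumPy array to class labels (for a fitted TabNetClassifier that would be its `predict` method); `manual_permutation_importance` and `toy_predict` are hypothetical names, and the toy predictor and synthetic data only keep the example runnable.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall; assumes every class appears in y_true."""
    classes = np.unique(y_true)
    return float(np.mean([(y_pred[y_true == c] == c).mean() for c in classes]))


def manual_permutation_importance(predict, X, y, score=balanced_accuracy,
                                  n_repeats=5, seed=0):
    """Mean and std of the score drop when each column is shuffled in turn."""
    rng = np.random.default_rng(seed)
    baseline = score(y, predict(X))
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # break column j's link to y
            drops[j, r] = baseline - score(y, predict(X_perm))
    return drops.mean(axis=1), drops.std(axis=1)


# Synthetic data: column 0 carries the label, column 1 is noise.
rng = np.random.default_rng(0)
y = rng.integers(0, 2, 500)
X = np.column_stack([y + 0.2 * rng.normal(size=500), rng.normal(size=500)])


def toy_predict(X):
    # Stand-in for a fitted classifier's predict; it only looks at column 0.
    return (X[:, 0] > 0.5).astype(int)


mean_drop, std_drop = manual_permutation_importance(toy_predict, X, y)
```

If every mean drop comes out as exactly zero, shuffling never changed a prediction, which is worth checking directly by comparing `predict(X)` with `predict(X_perm)`.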

Hi again,

I checked the attention issue, and I do have 0s on several features when I use all of them. However, it was more enlightening to divide the features into subsets and evaluate the accuracy (or AUC). Ultimately, I guess it depends on the strategy for studying the data itself!

PS: Is there an easier way to get the best validation accuracy/AUC after calling clf.fit() (since it reverts to the best epoch) than predicting the scores again?
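
If the validation metric has been recorded once per epoch, the best value can be read from that record instead of predicting again. The sketch below assumes such a per-epoch list is available: pytorch-tabnet keeps a training history, but the attribute and key names depend on the version and on the names passed to `fit`, so `best_epoch` and the `val_auc` list here are hypothetical placeholders.

```python
def best_epoch(val_scores, higher_is_better=True):
    """Return (epoch_index, score) of the best entry; ties go to the earliest epoch."""
    if not val_scores:
        raise ValueError("val_scores is empty")
    pick = max if higher_is_better else min
    idx = pick(range(len(val_scores)), key=lambda i: val_scores[i])
    return idx, val_scores[idx]


# Hypothetical per-epoch validation AUCs, e.g. copied from the training history.
val_auc = [0.71, 0.80, 0.86, 0.85, 0.86]
epoch, auc = best_epoch(val_auc)
```

For a loss-like metric, pass `higher_is_better=False`.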

Edit: Does the library change the seed? I tried configuring an identical seed so that the weights are always initialized the same way, but after calling a configure_seed() function before any splitting and initialization, I don't get reproducible AUCs.

Yes, there is a seed parameter in the fit method. This makes sure that results are reproducible. This is tested in CI, so it works for sure.
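
The configure_seed() function mentioned above is not shown, so the helper below is a hypothetical equivalent: it seeds Python's `random`, NumPy's global generator and, if installed, PyTorch. As the reply says, the library's own seed parameter is what controls TabNet's internal randomness; a global helper like this only covers randomness outside the model, such as data splits.

```python
import random

import numpy as np


def set_global_seeds(seed: int) -> None:
    """Seed Python, NumPy and (if available) PyTorch global random generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not installed: nothing more to seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)            # safe to call without a GPU
    torch.backends.cudnn.deterministic = True   # trade speed for determinism
    torch.backends.cudnn.benchmark = False


set_global_seeds(42)
first = (random.random(), np.random.rand(3))
set_global_seeds(42)
second = (random.random(), np.random.rand(3))
```

Passing an explicit `random_state` to the splitting function as well keeps the data partitioning independent of global state.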