yandex-research / rtdl

Research on Tabular Deep Learning: Papers & Packages



Is it possible to provide a scikit-learn interface?

hengzhe-zhang opened this issue

This project is interesting, and I would like to use it as a baseline algorithm for my paper. However, it seems that making a prediction requires several steps. Would it be possible to provide a scikit-learn interface so that different algorithms can be compared more conveniently?

UPD

There are several approaches to training and prediction:

  • implement training and prediction manually (as done in the official example);
  • use a general-purpose framework (for example, Lightning);
  • (use with caution) use a high-level library (for example, skorch provides a scikit-learn interface for PyTorch, which looks like what you are looking for). WARNING: high-level libraries usually provide default training parameters that can be suboptimal for the task at hand. Do not rely on these defaults. Instead, tune them or take inspiration from pipelines for similar tasks. In all cases, explicitly pass all training parameters (optimizer, batch size, learning rate, weight decay, early-stopping settings, number of epochs, etc.) to the corresponding functions.
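A minimal sketch of the first option wrapped in a scikit-learn-style fit/predict interface, assuming PyTorch and NumPy. The class name TorchRegressor and its arguments are hypothetical (they are not part of rtdl or skorch), MSE loss and AdamW are chosen purely for illustration, and every training parameter is a required argument so that nothing falls back to a hidden default. The wrapped module takes a single input; a model that takes numerical and categorical features separately would need an adapter.

```python
import numpy as np
import torch
from torch import nn


class TorchRegressor:
    """Hypothetical scikit-learn-style wrapper (fit/predict) around a PyTorch module.

    Every training parameter is a required keyword argument, so nothing
    silently falls back to a library default.
    """

    def __init__(self, module_factory, *, lr, weight_decay, batch_size, n_epochs, seed=0):
        self.module_factory = module_factory  # callable returning a fresh nn.Module
        self.lr = lr
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.seed = seed

    def fit(self, X, y):
        torch.manual_seed(self.seed)  # reproducible init and shuffling
        self.module_ = self.module_factory()
        optimizer = torch.optim.AdamW(
            self.module_.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        X_t = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        y_t = torch.as_tensor(np.asarray(y), dtype=torch.float32).reshape(-1, 1)
        self.module_.train()
        for _ in range(self.n_epochs):
            perm = torch.randperm(len(X_t))  # reshuffle every epoch
            for start in range(0, len(X_t), self.batch_size):
                idx = perm[start : start + self.batch_size]
                loss = nn.functional.mse_loss(self.module_(X_t[idx]), y_t[idx])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        return self

    @torch.no_grad()
    def predict(self, X):
        self.module_.eval()
        X_t = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        return self.module_(X_t).squeeze(1).numpy()  # shape (n_samples,)


# Usage on a toy noiseless linear problem.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))
y = X @ np.array([1.0, -2.0, 0.5])
reg = TorchRegressor(
    lambda: nn.Linear(3, 1), lr=0.05, weight_decay=0.0, batch_size=32, n_epochs=50
)
pred = reg.fit(X, y).predict(X)
```

Because this class does not inherit from sklearn.base.BaseEstimator, tools such as GridSearchCV cannot clone it as-is. Early stopping on a validation split is omitted for brevity, although in practice it should be configured explicitly along with the other training parameters.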

Do you mean that I only need to wrap FTTransformer with skorch?

NOTE: I have updated the previous answer; please read it first.


In theory, yes. Note that rtdl.FTTransformer expects two arguments (numerical and categorical features), so you will need to read this section.
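One way to reconcile the single-matrix fit(X, y)/predict(X) convention with a model that takes numerical and categorical features as two separate arguments is a thin adapter module. The sketch below assumes PyTorch and assumes that the categorical columns hold ordinal codes 0..cardinality-1 stored as floats in the same matrix. SplitInputs and TwoInputToy are hypothetical names; TwoInputToy only mimics the two-argument call pattern and is not FTTransformer.

```python
import torch
from torch import nn


class SplitInputs(nn.Module):
    """Hypothetical adapter: takes one float matrix and calls the wrapped
    module as module(x_num, x_cat). Not part of rtdl or skorch."""

    def __init__(self, module: nn.Module, n_features: int, cat_columns: list[int]):
        super().__init__()
        self.module = module
        cat = sorted(set(cat_columns))
        num = [i for i in range(n_features) if i not in cat]
        # Buffers follow .to(device) together with the wrapped module's weights.
        self.register_buffer("num_idx", torch.tensor(num, dtype=torch.long))
        self.register_buffer("cat_idx", torch.tensor(cat, dtype=torch.long))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_num = x[:, self.num_idx]
        # Categorical codes are stored as floats; embeddings need int64 indices.
        # (With no categorical columns this is an empty tensor; some models expect None.)
        x_cat = x[:, self.cat_idx].long()
        return self.module(x_num, x_cat)


class TwoInputToy(nn.Module):
    """Hypothetical stand-in that only mimics the two-argument call pattern."""

    def __init__(self, n_num: int, cardinalities: list[int], d_emb: int = 4):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(c, d_emb) for c in cardinalities])
        self.head = nn.Linear(n_num + d_emb * len(cardinalities), 1)

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:
        parts = [x_num] + [emb(x_cat[:, i]) for i, emb in enumerate(self.embs)]
        return self.head(torch.cat(parts, dim=1))


# Column 1 is categorical with 3 levels; columns 0 and 2 are numerical.
x = torch.tensor([[0.5, 2.0, -1.0], [1.5, 0.0, 3.0]])
model = SplitInputs(TwoInputToy(n_num=2, cardinalities=[3]), n_features=3, cat_columns=[1])
out = model(x)  # shape (2, 1)
```

The adapted module takes a single tensor, so it can be handed to a single-input wrapper such as skorch's NeuralNetRegressor. Alternatively, skorch accepts X as a dict and passes its entries to the module's forward method as keyword arguments, so dict keys that match the model's argument names make the adapter unnecessary; check the argument names in the rtdl version you use.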

Feel free to reopen the issue if you have more questions on the topic.

@Yura52 I have implemented a scikit-learn-compatible interface for the algorithms in this library and open-sourced it on GitHub (https://github.com/zhenlingcn/scikit-rtdl).