tiepvupsu / tabml



Auto infer trainer from model_wrapper

tiepvupsu opened this issue

In many cases, a model is tied to a particular trainer; for example, scikit-learn models are trained through their fit method. Tabml should:

  1. Automatically infer the associated trainer if possible
  2. Allow users to specify trainers

For 1., each model wrapper could expose an extra method, get_associated_trainer (*), that returns its associated trainer. The base implementation should return None, meaning that no trainer can be inferred for that model.
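A minimal sketch of point 1, assuming a class-based design. BaseModelWrapper, SklearnModelWrapper, BaseTrainer and SklearnTrainer are hypothetical names used for illustration, not tabml's actual classes. The base wrapper returns None, and a scikit-learn wrapper overrides the method to return a trainer that calls fit.

```python
from typing import Optional, Type


class BaseTrainer:
    """Hypothetical base trainer standing in for whatever trainers.py defines."""

    def train(self, model, X, y):
        raise NotImplementedError


class SklearnTrainer(BaseTrainer):
    """Hypothetical trainer that delegates to a scikit-learn style fit()."""

    def train(self, model, X, y):
        model.fit(X, y)
        return model


class BaseModelWrapper:
    """Hypothetical base class for all model wrappers."""

    def get_associated_trainer(self) -> Optional[Type[BaseTrainer]]:
        # By default no trainer can be inferred from the model.
        return None


class SklearnModelWrapper(BaseModelWrapper):
    """Hypothetical wrapper for scikit-learn estimators, which train via fit()."""

    def get_associated_trainer(self) -> Optional[Type[BaseTrainer]]:
        # scikit-learn models always train through fit(), so the trainer is known.
        return SklearnTrainer
```

Returning the trainer class rather than an instance is a design assumption here; it lets the pipeline construct the trainer with whatever arguments it already has.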

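For 2., users can keep specifying trainers as in the current flow, i.e. in pipeline_config. A trainer specified in pipeline_config overrides the inferred (default) one. If no trainer is specified in pipeline_config, get_associated_trainer must return a trainer rather than None; otherwise there is no trainer to use and the pipeline should fail with a clear error.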
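The override rule for point 2 can be sketched as a small resolution function. This assumes, for simplicity, that the trainer read from pipeline_config has already been turned into a class (or is None when absent); resolve_trainer and all class names below are hypothetical.

```python
from typing import Optional, Type


class BaseTrainer:
    """Hypothetical base trainer."""


class SklearnTrainer(BaseTrainer):
    """Hypothetical trainer inferred for scikit-learn models."""


class CustomTrainer(BaseTrainer):
    """Hypothetical trainer a user might name in pipeline_config."""


class BaseModelWrapper:
    def get_associated_trainer(self) -> Optional[Type[BaseTrainer]]:
        return None  # no trainer can be inferred by default


class SklearnModelWrapper(BaseModelWrapper):
    def get_associated_trainer(self) -> Optional[Type[BaseTrainer]]:
        return SklearnTrainer


def resolve_trainer(
    model_wrapper: BaseModelWrapper,
    configured_trainer: Optional[Type[BaseTrainer]] = None,
) -> Type[BaseTrainer]:
    """Return the trainer to use: an explicit config value wins, else infer it."""
    if configured_trainer is not None:
        # A trainer given in pipeline_config overrides the inferred one.
        return configured_trainer
    inferred = model_wrapper.get_associated_trainer()
    if inferred is None:
        # Nothing configured and nothing inferable: fail loudly.
        raise ValueError(
            f"{type(model_wrapper).__name__} has no associated trainer; "
            "please specify one in pipeline_config."
        )
    return inferred
```

For example, resolve_trainer(SklearnModelWrapper()) falls back to the inferred SklearnTrainer, while passing CustomTrainer explicitly takes precedence.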

(*) This introduces a potential circular import, since model_wrappers.py and trainers.py would then need to import each other.
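One common way to break such a cycle is to defer the import into the method body, so it runs only when get_associated_trainer is called, after both modules have finished loading. The sketch below writes two hypothetical, heavily simplified versions of model_wrappers.py and trainers.py to a temporary directory so it runs standalone; the module contents are illustrative assumptions, not tabml's real code.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical trainers.py: imports model_wrappers at the top of the module.
TRAINERS_SRC = '''
from model_wrappers import BaseModelWrapper


class SklearnTrainer:
    def __init__(self, model_wrapper):
        if not isinstance(model_wrapper, BaseModelWrapper):
            raise TypeError("expected a model wrapper")
        self.model_wrapper = model_wrapper
'''

# Hypothetical model_wrappers.py: does NOT import trainers at module level.
MODEL_WRAPPERS_SRC = '''
class BaseModelWrapper:
    def get_associated_trainer(self):
        return None


class SklearnModelWrapper(BaseModelWrapper):
    def get_associated_trainer(self):
        # Deferred import: executed at call time, when model_wrappers is
        # already fully initialized, so the cycle never bites at load time.
        from trainers import SklearnTrainer
        return SklearnTrainer
'''

# Write both modules to a temp dir and make it importable.
module_dir = Path(tempfile.mkdtemp())
(module_dir / "trainers.py").write_text(TRAINERS_SRC)
(module_dir / "model_wrappers.py").write_text(MODEL_WRAPPERS_SRC)
sys.path.insert(0, str(module_dir))
importlib.invalidate_caches()

model_wrappers = importlib.import_module("model_wrappers")
wrapper = model_wrappers.SklearnModelWrapper()
trainer_cls = wrapper.get_associated_trainer()  # triggers the deferred import
trainer = trainer_cls(wrapper)
```

If trainers.py only needs model_wrappers for type hints, guarding that import with typing.TYPE_CHECKING is another common option.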