naszilla / tabzilla

model TabNet + dataset openml__arrhythmia__5 issue

duncanmcelfresh opened this issue

Full traceback:

Traceback (most recent call last):
  File "/home/shared/tabzilla/TabSurvey/tabzilla_experiment.py", line 134, in __call__
    result = cross_validation(model, self.dataset, self.time_limit)
  File "/home/shared/tabzilla/TabSurvey/tabzilla_utils.py", line 237, in cross_validation
    loss_history, val_loss_history = curr_model.fit(
  File "/home/shared/tabzilla/TabSurvey/models/tabnet.py", line 39, in fit
    self.model.fit(
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py", line 185, in fit
    self.update_fit_params(
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/tab_model.py", line 54, in update_fit_params
    check_output_dim(train_labels, y)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/multiclass_utils.py", line 384, in check_output_dim
    raise ValueError(
ValueError: Valid set -- {0, 1, 3, 4, 5, 6, 7, 8, 9, 11, 12} --
                             contains unkown targets from training --
                             {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12}
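Judging by the message, `check_output_dim` rejects a validation set whose labels are not a subset of the training labels. The first set in the message is the validation labels and the second is the training labels. Taking the difference of the two shows which class triggers the error. This is a minimal sketch with the label sets copied from the message above. The helper `unseen_valid_labels` is hypothetical and is not part of tabzilla or pytorch_tabnet.

```python
# Label sets copied verbatim from the ValueError message above.
train_labels = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12}
valid_labels = {0, 1, 3, 4, 5, 6, 7, 8, 9, 11, 12}


def unseen_valid_labels(train, valid):
    """Hypothetical helper: labels present in the validation set but absent from training."""
    return sorted(set(valid) - set(train))


# The subset condition the error message describes: it fails here.
print(valid_labels.issubset(train_labels))  # → False
# The offending class.
print(unseen_valid_labels(train_labels, valid_labels))  # → [11]
```

Label 11 appears in this validation fold but never in the training fold. Labels 2 and 10 appear in training but not in validation, and that direction presumably passes the subset check. If so, the cross-validation split for this dataset, rather than TabNet itself, is what trips the check.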

Looks like this is related to #27?