model TabNet + dataset openml__arrhythmia__5 issue
duncanmcelfresh opened this issue
duncanmcelfresh commented
Full traceback:
Traceback (most recent call last):
  File "/home/shared/tabzilla/TabSurvey/tabzilla_experiment.py", line 134, in __call__
    result = cross_validation(model, self.dataset, self.time_limit)
  File "/home/shared/tabzilla/TabSurvey/tabzilla_utils.py", line 237, in cross_validation
    loss_history, val_loss_history = curr_model.fit(
  File "/home/shared/tabzilla/TabSurvey/models/tabnet.py", line 39, in fit
    self.model.fit(
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py", line 185, in fit
    self.update_fit_params(
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/tab_model.py", line 54, in update_fit_params
    check_output_dim(train_labels, y)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/multiclass_utils.py", line 384, in check_output_dim
    raise ValueError(
ValueError: Valid set -- {0, 1, 3, 4, 5, 6, 7, 8, 9, 11, 12} --
contains unkown targets from training --
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12}
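Reading the two sets in the message, the validation fold contains class 11, which never appears in the training fold. The training fold also has classes 2 and 10 that are absent from validation, but the error wording suggests the check only rejects validation labels missing from training. A minimal sketch (not pytorch_tabnet's code) that reproduces that subset check on the label sets copied from the error; the helper name `unseen_labels` is hypothetical:

```python
# Label sets copied verbatim from the ValueError above.
valid_labels = {0, 1, 3, 4, 5, 6, 7, 8, 9, 11, 12}
train_labels = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12}


def unseen_labels(train, valid):
    """Hypothetical helper: return validation labels that never occur in the training fold."""
    return sorted(set(valid) - set(train))


missing = unseen_labels(train_labels, valid_labels)
print(missing)  # [11]
```

A check like this could be run on each fold before calling `fit`, to flag folds where a rare class lands entirely in the validation split.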
This looks related to #27?