naszilla / tabzilla

TabNet error - dimensions

duncanmcelfresh opened this issue

This error occurs with the TabNet algorithm on dataset openml__sulfur__360966, for every hyperparameter sample tested:

Traceback (most recent call last):
  File "/home/shared/tabzilla/TabSurvey/tabzilla_experiment.py", line 136, in __call__
    result = cross_validation(model, self.dataset, self.time_limit)
  File "/home/shared/tabzilla/TabSurvey/tabzilla_utils.py", line 237, in cross_validation
    loss_history, val_loss_history = curr_model.fit(
  File "/home/shared/tabzilla/TabSurvey/models/tabnet.py", line 45, in fit
    self.model.fit(
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py", line 223, in fit
    self._train_epoch(train_dataloader)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py", line 434, in _train_epoch
    batch_logs = self._train_batch(X, y)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/abstract_model.py", line 469, in _train_batch
    output, M_loss = self.network(X)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/tab_network.py", line 583, in forward
    return self.tabnet(x)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/tab_network.py", line 468, in forward
    steps_output, M_loss = self.encoder(x)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/pytorch_tabnet/tab_network.py", line 150, in forward
    x = self.initial_bn(x)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/functional.py", line 2419, in batch_norm
    _verify_batch_size(input.size())
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/functional.py", line 2387, in _verify_batch_size
    raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 6])

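This happens when the last training batch has size 1: the input reaching the initial batch norm layer is torch.Size([1, 6]), a single sample with 6 features, and batch norm in training mode needs more than one value per channel to compute batch statistics. I have pushed a fix that makes TabNet drop the last batch from the dataloader when this happens.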
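
For illustration, here is a minimal sketch of the check described above. It is not the code from the pushed fix; the helper names `last_batch_size` and `should_drop_last` are hypothetical, and the sketch assumes the dataloader splits the training set into consecutive batches of a fixed size, with a smaller final batch when the sizes do not divide evenly.

```python
def last_batch_size(n_samples: int, batch_size: int) -> int:
    """Size of the final batch when n_samples are split into fixed-size batches."""
    remainder = n_samples % batch_size
    # A zero remainder means the final batch is full (or the whole set, if smaller).
    return remainder if remainder else min(batch_size, n_samples)


def should_drop_last(n_samples: int, batch_size: int) -> bool:
    """True when the final batch would contain exactly one sample.

    Assumption: dropping is only useful if at least one other batch remains,
    so a training set that fits in a single batch is never flagged.
    """
    return n_samples > batch_size and last_batch_size(n_samples, batch_size) == 1


# Hypothetical sizes: 1025 training samples with a batch size of 1024
# leave a final batch holding a single sample.
print(should_drop_last(1025, 1024))
print(should_drop_last(2048, 1024))
```

The resulting flag could then be passed as the `drop_last` argument of `torch.utils.data.DataLoader`, which discards an incomplete final batch when set to `True`. Dropping only when the final batch has exactly one sample keeps the rest of the training data in use for all other dataset sizes.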