naszilla / tabzilla

More Hopular Issues

duncanmcelfresh opened this issue

The following error was encountered with dataset openml__collins__3567:

Traceback (most recent call last):
  File "/home/shared/tabzilla/TabSurvey/tabzilla_experiment.py", line 137, in __call__
    result = cross_validation(model, self.dataset, self.time_limit)
  File "/home/shared/tabzilla/TabSurvey/tabzilla_utils.py", line 247, in cross_validation
    train_predictions, train_probs = curr_model.predict_wrapper(X_train)
  File "/home/shared/tabzilla/TabSurvey/models/basemodel.py", line 108, in predict_wrapper
    self.predictions, self.prediction_probabilities = self.predict(X)
  File "/home/shared/tabzilla/TabSurvey/models/basemodel_torch.py", line 163, in predict
    self.predict_proba(X)
  File "/home/shared/tabzilla/TabSurvey/models/basemodel_torch.py", line 170, in predict_proba
    probas = self.predict_helper(X)
  File "/home/shared/tabzilla/TabSurvey/models/hopular_model.py", line 141, in predict_helper
    preds = self.model(batch_X[0])
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/shared/tabzilla/TabSurvey/models/hopular/blocks.py", line 791, in forward
    embeddings = self.embeddings(input).unsqueeze(dim=0)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/shared/tabzilla/TabSurvey/models/hopular/blocks.py", line 119, in forward
    input_embedded = torch.cat(tuple(feature_embedding(
  File "/home/shared/tabzilla/TabSurvey/models/hopular/blocks.py", line 119, in <genexpr>
    input_embedded = torch.cat(tuple(feature_embedding(
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/torch/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x1 and 2x32)
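
A note on reading the last line, in case it helps with triage. In `F.linear(input, self.weight, self.bias)`, mat1 is the input and mat2 is the transposed weight, so `(256x1 and 2x32)` suggests that a feature-embedding layer built for 2 input features received a batch of 256 rows with only 1 column. The traceback alone does not show whether the mismatch comes from how the dataset is encoded or from the Hopular wrapper. The sketch below just decodes the shapes from the message; `explain_linear_mismatch` is a hypothetical helper, not part of tabzilla or PyTorch, and it assumes the input was 2-D so that the first number is the batch size.

```python
import re

# Hypothetical helper (not part of tabzilla or PyTorch): decode the shape pair
# reported by a "mat1 and mat2 shapes cannot be multiplied" RuntimeError.
_SHAPE_RE = re.compile(r"\((\d+)x(\d+) and (\d+)x(\d+)\)")


def explain_linear_mismatch(message: str) -> dict:
    """Return the four dimensions named in the error message as a dict."""
    match = _SHAPE_RE.search(message)
    if match is None:
        raise ValueError("no '(AxB and CxD)' shape pair found in message")
    rows, received, expected, out_features = map(int, match.groups())
    return {
        "batch_rows": rows,              # rows of mat1 (batch size for 2-D input)
        "features_received": received,   # columns of mat1: what the layer was given
        "features_expected": expected,   # rows of mat2: the layer's in_features
        "out_features": out_features,    # columns of mat2: the layer's out_features
    }


msg = "RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x1 and 2x32)"
info = explain_linear_mismatch(msg)
print(info)
```

For this traceback the helper reports 1 feature received against 2 expected, so a reasonable next step would be to compare the per-feature sizes used to build the embeddings in `hopular/blocks.py` with the columns actually present in `batch_X[0]` for this dataset.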