HKUDS / FlashST

[ICML'2024] "FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction"

Home Page: https://arxiv.org/abs/2405.17898

Reproduce results

moghadas76 opened this issue

Hi,

python Run.py -dataset_test PEMS07M -mode eval -model MTGNN

produces:

============================scaler_mae_loss
Applying learning rate decay.
2024-08-10 16:46: Experiment log path in: /home/seyed/PycharmProjects/step/FlashST/model/../SAVE/eval/MTGNN
0%| | 0/20 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/seyed/PycharmProjects/step/FlashST/model/Run.py", line 173, in <module>
    trainer.train_eval()
  File "/home/seyed/PycharmProjects/step/FlashST/model/Trainer.py", line 128, in train_eval
    train_epoch_loss, loss_pre = self.eval_trn_eps()
  File "/home/seyed/PycharmProjects/step/FlashST/model/Trainer.py", line 180, in eval_trn_eps
    out, q = self.model(data, data, self.args.dataset_test, self.batch_seen, nadj=nadj, lpls=lpls, useGNN=True, DSU=True)
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/seyed/PycharmProjects/step/FlashST/model/FlashST.py", line 152, in forward
    return self.forward_pretrain(source, label, select_dataset, batch_seen, nadj, lpls, useGNN, DSU)
  File "/home/seyed/PycharmProjects/step/FlashST/model/FlashST.py", line 155, in forward_pretrain
    x_prompt_return = self.pretrain_model(source[..., :self.input_base_dim], source, None, nadj, lpls, useGNN)
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/seyed/PycharmProjects/step/FlashST/model/PromptNet.py", line 118, in forward
    hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=-1).transpose(1, 3)
RuntimeError: Sizes of tensors must match except in dimension 2. Got 228 and 114 (The offending index is 1)

What should I do?
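A likely explanation for the 228 vs 114 mismatch, offered as an unconfirmed hypothesis rather than something established in this thread: PEMS07M has 228 sensor nodes, and 114 is exactly half of that. `nn.DataParallel` splits every tensor argument it receives, keyword arguments such as `nadj=` and `lpls=` included, along dimension 0 before handing one chunk to each replica. A batch-first input like `data` is split along the batch axis and keeps all 228 nodes, but a node-first tensor (assuming `nadj` or `lpls` has the node axis first) would reach each of two GPUs with only 114 nodes. The pure-Python sketch below reproduces just that shape arithmetic; the helpers `chunk_sizes` and `scatter_shapes` and all shapes in it are hypothetical and are not part of FlashST or PyTorch.

```python
import math
from typing import List, Tuple

Shape = Tuple[int, ...]


def chunk_sizes(length: int, num_devices: int) -> List[int]:
    """Row counts produced by torch.chunk-style splitting of `length` rows.

    Each chunk holds ceil(length / num_devices) rows; the last one may be smaller,
    and there can be fewer chunks than devices.
    """
    if length == 0:
        return []
    step = math.ceil(length / num_devices)
    return [min(step, length - start) for start in range(0, length, step)]


def scatter_shapes(shape: Shape, num_devices: int) -> List[Shape]:
    """Per-replica shapes when a tensor is scattered along dim 0."""
    return [(rows,) + tuple(shape[1:]) for rows in chunk_sizes(shape[0], num_devices)]


# Hypothetical shapes: batch of 64, 12 time steps, 228 PEMS07M nodes, 8-dim node encoding.
data_shape = (64, 12, 228, 1)  # batch-first: dim 0 is the batch axis
lpls_shape = (228, 8)          # node-first: dim 0 is the node axis

for name, shape in [("data", data_shape), ("lpls", lpls_shape)]:
    print(name, scatter_shapes(shape, num_devices=2))
```

With two GPUs this gives per-replica shapes of (32, 12, 228, 1) for `data` but (114, 8) for `lpls`, which is the kind of 228-vs-114 disagreement the `torch.cat` in `PromptNet.py` reports.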

Running with -model ori instead:

python Run.py -dataset_test PEMS07M -mode eval -model ori

0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/seyed/PycharmProjects/step/FlashST/model/Run.py", line 175, in <module>
    trainer.train_eval()
  File "/home/seyed/PycharmProjects/step/FlashST/model/Trainer.py", line 128, in train_eval
    train_epoch_loss, loss_pre = self.eval_trn_eps()
  File "/home/seyed/PycharmProjects/step/FlashST/model/Trainer.py", line 180, in eval_trn_eps
    out, q = self.model(data, data, self.args.dataset_test, self.batch_seen, nadj=nadj, lpls=lpls, useGNN=True, DSU=True)
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/seyed/miniconda3/envs/FlashST/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 154, in forward
    raise RuntimeError("module must have its parameters and buffers "
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1
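This second error comes from a sanity check at the start of `nn.DataParallel.forward`: every parameter and buffer of the wrapped module must already live on `device_ids[0]`, and the message says at least one of them is on `cuda:1`. The sketch below mimics that check without PyTorch, using (name, device) string pairs; the function `first_misplaced` and the module layout in it are hypothetical illustrations, not FlashST code.

```python
from typing import Iterable, List, Optional, Tuple


def first_misplaced(
    tensors: Iterable[Tuple[str, str]], device_ids: List[int]
) -> Optional[Tuple[str, str]]:
    """Return the first (name, device) pair not on cuda:<device_ids[0]>, or None.

    Mirrors the spirit of the check nn.DataParallel.forward performs over the
    wrapped module's parameters and buffers before scattering any inputs.
    """
    source_device = f"cuda:{device_ids[0]}"
    for name, device in tensors:
        if device != source_device:
            return name, device
    return None


# Hypothetical layout: the prompt network sits on cuda:0, but a frozen backbone
# was moved to cuda:1 somewhere else in the code.
params = [("prompt.weight", "cuda:0"), ("backbone.weight", "cuda:1")]
print(first_misplaced(params, device_ids=[0, 1]))  # ('backbone.weight', 'cuda:1')
```

In real PyTorch code, collecting `{p.device for p in model.parameters()}` is a quick way to see where a model's weights live; moving the whole model with `model.to("cuda:0")` before wrapping it, or not wrapping it at all, should avoid this particular check.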

You can solve this problem by disabling multi-GPU parallelism. We will update the code accordingly. Thanks.
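One common way to disable multi-GPU parallelism without editing the training code is to expose a single GPU to the process, for example by launching with `CUDA_VISIBLE_DEVICES=0 python Run.py -dataset_test PEMS07M -mode eval -model MTGNN`. The sketch below does the same from inside a Python script. It assumes, without having checked FlashST's `Run.py`, that the code picks its GPUs from whatever is visible rather than from a hard-coded list of device ids; if the list is hard-coded, that list has to be changed instead. The helper name `pin_single_gpu` is hypothetical.

```python
import os


def pin_single_gpu(gpu_index: int = 0) -> str:
    """Make only one physical GPU visible to this process.

    CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so call this before
    `import torch` (or at least before the first CUDA call). The chosen GPU then
    appears to PyTorch as cuda:0, and torch.cuda.device_count() should report 1.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return os.environ["CUDA_VISIBLE_DEVICES"]


if __name__ == "__main__":
    pin_single_gpu(0)
    # Import torch and build the model only after this point.
```

Because the single visible GPU is renumbered as `cuda:0`, this should also sidestep the `cuda:0` vs `cuda:1` placement error, provided nothing in the code refers to `cuda:1` explicitly.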