Runtime error during gradient computation when running train.py
Salvatore-tech opened this issue · comments
Describe the bug
Good evening, I wanted to launch train.py using a default config file for RAFT on a standard dataset (KITTI_2015).
I followed the instructions to install MMFlow from source successfully.
Reproduction
python tools/train.py configs/raft/raft_8x2_50k_kitti2015_and_Aug_288x960.py \
--load-from /home/s.starace/FlowNets/mmflow/checkpoints/raft/raft_8x2_100k_mixed_368x768.pth
Did you make any modifications on the code or config? Did you understand what you have modified?
I just changed the name of the symlink that I created under /data (uppercase).
What dataset did you use?
KITTI_2015
Environment
I launched the command on my PC and also on a small cluster, and the output error is the same.
Error traceback
See log attached: slurm-53090.out.txt
Bug fix
Not sure about it; it could either be a configuration issue in the Encoder/Decoder or a regression.
I'll try train.py with other models as well and update the report if I understand the problem better.
Training models other than RAFT (PWC-Net, LiteFlowNet, etc.) with the same dataset-loading script does NOT reproduce the issue reported above.
The same error is raised when training GMA.
Sorry @wz940216, but I did not understand your comment. Do you need additional details about the issue?
@Salvatore-tech I'm not a developer, but I had the same problem as you when training GMA and RAFT while using mmflow.
@wz940216 that's interesting, I hope that the owners of this repository could give us a clue (I'd like to use RAFT in my use case because it should give better performance).
@MeowZheng @Zachary-66
I hit the same bug using PyTorch 1.12.1. Below is my log:
raft_kitti_bug.txt
However, when I use PyTorch 1.8.0, this bug no longer appears. Below is the normal log:
raft_kitti_normal.txt
I notice you are using PyTorch 1.12.1, which is the latest version, so there may be some unexpected bugs. To save you time, I recommend following the official command below to install PyTorch 1.8.0 instead:
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
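Since the failure appears tied to the PyTorch version (1.12.1 fails, 1.8.0 works), a small guard at the top of a training script can fail fast instead of crashing mid-run. This helper is hypothetical, not part of MMFlow; the 1.12 cutoff is only what this thread's logs suggest:

```python
def torch_version_ok(version: str, max_exclusive=(1, 12)) -> bool:
    """True if `version` (e.g. '1.8.0') predates the first release
    observed to trigger the backward error in this thread (1.12)."""
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) < max_exclusive


# Example: known-good vs. reported-failing versions.
# (torch.__version__ may carry a local suffix like '1.12.1+cu102',
#  which the two-component parse above tolerates.)
print(torch_version_ok("1.8.0"))   # True
print(torch_version_ok("1.12.1"))  # False
```

In a real script you would pass `torch.__version__` and print a warning (or raise) when the check fails, before building the model.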
We will modify MMFlow as soon as possible to avoid similar bugs. Thanks a lot!
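For context, runtime errors during gradient computation of the form "one of the variables needed for gradient computation has been modified by an inplace operation" usually mean an in-place tensor op overwrote a value that autograd had saved for the backward pass; newer PyTorch releases catch more of these cases. A minimal reproduction of that error class (illustrative only, not MMFlow's actual code):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()   # exp() saves its output tensor for the backward pass
y.add_(1.0)   # in-place update overwrites that saved output

try:
    y.sum().backward()
except RuntimeError as err:
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    print(type(err).__name__)  # RuntimeError
```

Replacing the in-place op with an out-of-place one (`y = y + 1.0`) lets backward succeed; fixes for this class of bug are typically of that shape. `torch.autograd.set_detect_anomaly(True)` can help pinpoint the offending operator in a larger model.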
Thanks @Zachary-66, your fix did the job. I had not noticed that operator, otherwise I would have opened a pull request myself.
I'm closing the issue.