lliuz / ARFlow

The official PyTorch implementation of the paper "Learning by Analogy: Reliable Supervision from Transformations for Unsupervised Optical Flow Estimation".

What is the correspondence between config files and checkpoints?

Kewenjing1020 opened this issue

Hi,
I noticed there are several config files and checkpoints for each dataset. Take KITTI as an example: what is the correspondence between the config files kitti15_ft_ar.json, kitti15_ft.json and kitti_raw.json and the checkpoints pwclite_ar_mv.tar, pwclite_ar.tar and pwclite_raw.tar? Which config should I use to reproduce each of these three models?

In addition, I tried to evaluate the checkpoint pwclite_ar_mv.tar with each of the three config files and always got the following error:

[INFO] => using pre-trained weights checkpoints/KITTI15/pwclite_ar_mv.tar.
Traceback (most recent call last):
  File "train.py", line 50, in <module>
    basic_train.main(cfg, _log)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/basic_train.py", line 53, in main
    train_loader, valid_loader, model, loss, _log, cfg.save_root, cfg.train)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/trainer/kitti_trainer.py", line 13, in __init__
    train_loader, valid_loader, model, loss_func, _log, save_root, config)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/trainer/base_trainer.py", line 26, in __init__
    self.model = self._init_model(model)
  File "/proj/xcdhdstaff1/wenjingk/SLAM/ARFlow-master/trainer/base_trainer.py", line 75, in _init_model
    model.load_state_dict(weights)
  File "/scratch/workspace/wenjingk/anaconda-3.6/envs/python3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for PWCLite:
	size mismatch for flow_estimators.conv1.0.weight: copying a param with shape torch.Size([128, 198, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 115, 3, 3]).
	size mismatch for context_networks.convs.0.0.weight: copying a param with shape torch.Size([128, 68, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 34, 3, 3]).
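The two mismatched shapes fit a simple pattern in which input channels grow with the number of frames. The sketch below is an inference from the numbers in the traceback, not code taken from ARFlow. It assumes that the flow estimator's input is 32 feature channels plus, for each additional frame, an 81-channel correlation volume (a 9x9 search window) and a 2-channel flow. It also assumes that the context network's input is 34 channels (32 + 2) per additional frame. The constant names and the function decoder_in_channels are hypothetical.

```python
# Hypothetical channel bookkeeping, inferred only from the shapes in the traceback.
FEAT_CH = 32                 # assumed feature channels fed to each decoder
CORR_CH = (2 * 4 + 1) ** 2   # assumed 9x9 correlation window -> 81 channels
FLOW_CH = 2                  # (u, v) flow components


def decoder_in_channels(n_frames: int) -> dict:
    """Assumed input channels of the first conv in each decoder for n_frames frames."""
    extra = n_frames - 1  # frames beyond the reference frame
    return {
        "flow_estimators.conv1.0": FEAT_CH + (CORR_CH + FLOW_CH) * extra,
        "context_networks.convs.0.0": (FEAT_CH + FLOW_CH) * extra,
    }


for n in (2, 3):
    print(n, decoder_in_channels(n))
```

With n_frames=2 this gives 115 and 34, the shapes of the model built from the config. With n_frames=3 it gives 198 and 68, the shapes stored in pwclite_ar_mv.tar. This is consistent with the checkpoint having been trained on 3 frames.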

I guess you need to set the number of frames to 3 in the config to load the multi-view checkpoint.
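Before choosing a config, one way to check how many frames a checkpoint expects is to read the input width of its first flow-estimator convolution. This sketch is hypothetical. It assumes that the state dict uses the parameter name shown in the traceback and follows the channel pattern implied by the reported shapes, 32 + 83 * (n_frames - 1). The helper infer_n_frames is illustrative and not part of ARFlow.

```python
from typing import Any, Mapping

# Assumed layout (inferred from the traceback): 32 feature channels plus
# 81 correlation + 2 flow channels for each frame beyond the first.
FEAT_CH = 32
PER_FRAME_CH = 81 + 2
KEY = "flow_estimators.conv1.0.weight"


def infer_n_frames(state_dict: Mapping[str, Any]) -> int:
    """Guess n_frames from the first flow-estimator conv's input channels."""
    weight = state_dict[KEY]
    # Accept torch tensors (via .shape) or plain shape tuples.
    shape = tuple(weight.shape) if hasattr(weight, "shape") else tuple(weight)
    in_ch = shape[1]  # conv weight layout: (out_ch, in_ch, kH, kW)
    extra, rem = divmod(in_ch - FEAT_CH, PER_FRAME_CH)
    if rem != 0 or extra < 1:
        raise ValueError(f"unexpected input channels: {in_ch}")
    return extra + 1


print(infer_n_frames({KEY: (128, 115, 3, 3)}))  # -> 2 (current model)
print(infer_n_frames({KEY: (128, 198, 3, 3)}))  # -> 3 (pwclite_ar_mv.tar)
```

With a real checkpoint, the mapping would come from torch.load(path, map_location="cpu"). Whether ARFlow's .tar files keep the weights under a "state_dict" key is an assumption, so print the loaded object's keys first.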