Problem using pre-trained weights
wreed opened this issue
When I run
python test_models.py jester RGBFlow pretrained_models/MFF_jester_RGBFlow_BNInception_segment4_3f1c_best.pth.tar --arch BNInception --consensus_type MLP --test_crops 1 --num_motion 3 --test_segments 4
I get this error:
Initializing TSN with base model: BNInception.
TSN Configurations:
input_modality: RGBFlow
num_segments: 4
new_length: 3
consensus_module: MLP
dropout_ratio: 0.8
img_feature_dim: 256
/home/wilreed/dev/MFF-pytorch/.env/lib/python3.5/site-packages/torch/nn/modules/module.py:514: UserWarning: src is not broadcastable to dst, but they have the same number of elements. Falling back to deprecated pointwise behavior.
own_state[name].copy_(param)
Converting the ImageNet model to RGB+Flow init model
Done. RGBFlow model ready.
model epoch 38 best prec@1: 92.17555961317373
Traceback (most recent call last):
File "/home/wilreed/dev/MFF-pytorch/.env/lib/python3.5/site-packages/torch/nn/modules/module.py", line 514, in load_state_dict
own_state[name].copy_(param)
RuntimeError: invalid argument 2: sizes do not match at /pytorch/torch/lib/THC/generic/THCTensorCopy.c:101
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test_models.py", line 85, in <module>
net.load_state_dict(base_dict, strict=False)
File "/home/wilreed/dev/MFF-pytorch/.env/lib/python3.5/site-packages/torch/nn/modules/module.py", line 519, in load_state_dict
.format(name, own_state[name].size(), param.size()))
RuntimeError: While copying the parameter named consensus.classifier.3.weight, whose dimensions in the model are torch.Size([174, 512]) and whose dimensions in the checkpoint are torch.Size([27, 512]).
Note that I encountered a different error with PyTorch 0.4; I was only able to get this far by using 0.3.1.
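The traceback means that the model built by test_models.py and the checkpoint disagree on the shape of at least one parameter. Below is a minimal sketch for listing every such mismatch before calling load_state_dict. It assumes both state dicts have already been reduced to plain name-to-shape tuples and that checkpoint keys carry a "module." prefix from nn.DataParallel. The helper name find_shape_mismatches is hypothetical and is not part of MFF-pytorch or PyTorch.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    model_shapes: Dict[str, Shape],
    ckpt_shapes: Dict[str, Shape],
    prefix: str = "module.",
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, model_shape, checkpoint_shape) for every parameter
    present in both dicts whose shapes differ.

    `prefix` is stripped from checkpoint keys (assumption: the checkpoint
    was saved from an nn.DataParallel wrapper)."""
    mismatches = []
    for raw_name, ckpt_shape in ckpt_shapes.items():
        name = raw_name[len(prefix):] if raw_name.startswith(prefix) else raw_name
        model_shape = model_shapes.get(name)
        # Keys absent from the model are skipped, similar to how
        # strict=False ignores unexpected keys.
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches.append((name, tuple(model_shape), tuple(ckpt_shape)))
    return mismatches


if __name__ == "__main__":
    # Shapes taken from the error message above.
    model = {"consensus.classifier.3.weight": (174, 512)}
    ckpt = {"module.consensus.classifier.3.weight": (27, 512)}
    for name, m, c in find_shape_mismatches(model, ckpt):
        print("{}: model {} vs checkpoint {}".format(name, m, c))
```

With PyTorch, the inputs could be built as `{k: tuple(v.size()) for k, v in net.state_dict().items()}` and likewise from `checkpoint['state_dict']`, assuming the checkpoint stores its weights under that key.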
Hello, did you prepare the Jester dataset properly? Most probably you have the wrong "category.txt" file: your model was created with 174 classes in its final layer, but it should have 27 for the Jester dataset. You are probably also working on the Something-Something dataset, since it has 174 classes. :)
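The diagnosis rests on a simple piece of arithmetic. The final classifier's weight has one row per class, so its first dimension must equal the number of entries in the category file. The following minimal sketch performs that check. It assumes the category file lists one class name per line, and it uses the class counts given in this thread (27 for Jester, 174 for Something-Something). The function names are hypothetical.

```python
import os
import tempfile
from typing import Sequence

# Class counts from this thread.
EXPECTED_NUM_CLASSES = {"jester": 27, "something-something": 174}


def count_categories(category_file: str) -> int:
    """Count class names, assuming one per line (blank lines ignored)."""
    with open(category_file) as f:
        return sum(1 for line in f if line.strip())


def check_category_file(dataset: str, category_file: str,
                        classifier_weight_shape: Sequence[int]) -> int:
    """Raise ValueError if the category file disagrees with the dataset or
    with the checkpoint's final classifier; otherwise return the class count."""
    n_classes = count_categories(category_file)
    expected = EXPECTED_NUM_CLASSES[dataset]
    if n_classes != expected:
        raise ValueError("{} lists {} classes, but {} has {}".format(
            category_file, n_classes, dataset, expected))
    # A Linear weight is [out_features, in_features]; out_features = classes.
    if classifier_weight_shape[0] != n_classes:
        raise ValueError(
            "checkpoint classifier outputs {} classes, category file has {}".format(
                classifier_weight_shape[0], n_classes))
    return n_classes


if __name__ == "__main__":
    # Reproduce the situation from the traceback: a 174-line category file
    # (Something-Something) used with a Jester checkpoint of shape [27, 512].
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write("\n".join("class_{}".format(i) for i in range(174)))
    try:
        check_category_file("jester", tmp.name, (27, 512))
    except ValueError as err:
        print(err)
    finally:
        os.remove(tmp.name)
```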
You can keep working with PyTorch 0.3.1. Please let me know if you encounter any more exceptions.
Thanks, that makes sense. The Jester dataset link in the README points to the Something-Something dataset, which threw me off.