We release the code and trained models for our paper Gate-Shift Networks for Video Action Recognition. If you find our work useful in your research, please cite:
```
@InProceedings{gsm,
  author    = {Sudhakaran, Swathikiran and Escalera, Sergio and Lanz, Oswald},
  title     = {{Gate-Shift Networks for Video Action Recognition}},
  booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2020}
}
```
Prerequisites:
- Python 3.5
- PyTorch 1.2
- tensorboardX
Data preparation:
- Something Something-v1: Download the frames from the official website. Copy the directory containing the frames and the train-val files to `dataset-->something-v1`. Run `python data_scripts/process_dataset_something.py` to create the train/val list files.
- Diving48: Download the videos and the annotations from the official website. Copy the directory containing the videos and the annotations to `dataset-->Diving48`. Run `python data_scripts/extract_frames_diving48.py` to extract the frames from the videos, then run `python data_scripts/process_dataset_diving.py` to create the train/test list files.
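Assuming the generated list files follow the convention of the TRN-pytorch codebase this repository builds on (one video per line: frame directory, number of frames and class index, separated by spaces), a minimal sketch for reading them could look like the following. `VideoRecord`, `read_video_list` and the sample entries are hypothetical and not part of this repository; check the files produced by the scripts above before relying on this format.

```python
from typing import List, NamedTuple


class VideoRecord(NamedTuple):
    path: str        # directory containing the extracted frames
    num_frames: int  # number of frames in that directory
    label: int       # class index


def read_video_list(lines: List[str]) -> List[VideoRecord]:
    """Parse lines of the ASSUMED form '<frame_dir> <num_frames> <label>'."""
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # rsplit keeps directory names that contain spaces intact
        path, num_frames, label = line.rsplit(" ", 2)
        records.append(VideoRecord(path, int(num_frames), int(label)))
    return records


if __name__ == "__main__":
    # Hypothetical entries; real paths and labels come from the list files.
    sample = ["frames/00001 47 12", "frames/00002 36 5"]
    for rec in read_video_list(sample):
        print(rec.path, rec.num_frames, rec.label)
```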
To train a model, e.g. GSM with a BNInception backbone and 8 segments on Something Something-v1:
```
python main.py something-v1 RGB --arch BNInception \
    --num_segments 8 --consensus_type avg \
    --batch-size 16 --iter_size 2 --dropout 0.5 \
    --lr 0.01 --warmup 10 --epochs 60 --eval-freq 5 \
    --gd 20 --run_iter 1 -j 16 --npb --gsm
```
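The training command combines `--batch-size 16` with `--iter_size 2`. Assuming `--iter_size` accumulates gradients over that many mini-batches before each optimizer step (as the name suggests; check `main.py` to confirm), each update sees an effective batch of 16 × 2 = 32 videos. The pure-Python sketch below uses a hypothetical one-parameter least-squares model to show why scaling each mini-batch gradient by `1 / iter_size` makes the accumulated gradient equal to the gradient over the full batch; it illustrates the general technique, not the code in `main.py`.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (input x, target y)


def grad_mse(w: float, batch: List[Sample]) -> float:
    """Gradient of mean((w*x - y)^2) with respect to w over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: List[Sample], iter_size: int) -> float:
    """Split `batch` into `iter_size` equal mini-batches and accumulate their
    gradients, each scaled by 1 / iter_size (gradient accumulation)."""
    assert len(batch) % iter_size == 0, "batch must split evenly"
    mini = len(batch) // iter_size
    total = 0.0
    for i in range(iter_size):
        chunk = batch[i * mini:(i + 1) * mini]
        total += grad_mse(w, chunk) / iter_size
    return total


if __name__ == "__main__":
    data = [(float(i), 2.0 * i + 1.0) for i in range(32)]  # toy data, 32 samples
    full = grad_mse(0.5, data)                               # one batch of 32
    accum = accumulated_grad(0.5, data, iter_size=2)         # 2 mini-batches of 16
    print(full, accum)
```

Accumulation trades extra forward/backward passes for lower peak memory, which is why it is commonly used when a full batch of multi-frame clips does not fit on the GPU.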
To evaluate a trained model, e.g. the InceptionV3 model with 16 segments:
```
python test_models.py something-v1 RGB models/something-v1_RGB_InceptionV3_avg_segment16_checkpoint.pth.tar \
    --arch InceptionV3 --crop_fusion_type avg \
    --test_segments 16 --test_crops 1 --num_clips 1 --gsm
```
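`--num_segments` (training) and `--test_segments` (evaluation) set how many frames are sampled from each video. Assuming the TSN-style segment-based sampling inherited from TSN-pytorch and TRN-pytorch, a video is split into equal-length segments and one frame is taken from each: at a random position within the segment during training, and at the segment centre during testing. The sketch below illustrates that idea only; it is not this repository's data loader, and `sample_train_indices` / `sample_test_indices` are hypothetical names.

```python
import random
from typing import List, Optional

# Sketch ASSUMING TSN-style segment sampling; not the repository's loader.


def sample_train_indices(num_frames: int, num_segments: int,
                         rng: Optional[random.Random] = None) -> List[int]:
    """Pick one random frame index from each of `num_segments` equal segments."""
    rng = rng or random.Random()
    seg_len = num_frames / num_segments
    indices = []
    for k in range(num_segments):
        start = int(seg_len * k)
        # Guarantee a non-empty range even when num_frames < num_segments.
        end = max(start + 1, int(seg_len * (k + 1)))
        indices.append(min(rng.randrange(start, end), num_frames - 1))
    return indices


def sample_test_indices(num_frames: int, num_segments: int) -> List[int]:
    """Pick the centre frame of each of `num_segments` equal segments."""
    seg_len = num_frames / num_segments
    return [int(seg_len / 2 + seg_len * k) for k in range(num_segments)]


if __name__ == "__main__":
    print(sample_train_indices(40, 8, random.Random(0)))
    print(sample_test_indices(40, 8))
```

Deterministic centre sampling at test time makes evaluation reproducible, while random offsets during training act as temporal data augmentation.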
To evaluate using 2 clips sampled from each video, change `--num_clips 1` to `--num_clips 2`.
To predict with an ensemble of models, evaluate each model with the `--save_scores` option to save its prediction scores, then run `python average_scores.py`.
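Assuming both multi-clip evaluation and ensembling fuse predictions by averaging class scores, the following sketch averages score matrices and computes top-1 accuracy. It assumes each set of scores holds one row of class scores per video in the same video order; the actual format written by `--save_scores` and the exact procedure in `average_scores.py` may differ, and `average_scores` / `top1_accuracy` here are hypothetical helpers.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]  # [num_videos][num_classes]


def average_scores(score_sets: Sequence[Matrix]) -> List[List[float]]:
    """Element-wise mean of several score matrices with identical shapes."""
    num_sets = len(score_sets)
    num_videos = len(score_sets[0])
    num_classes = len(score_sets[0][0])
    return [
        [sum(s[v][c] for s in score_sets) / num_sets for c in range(num_classes)]
        for v in range(num_videos)
    ]


def top1_accuracy(scores: Matrix, labels: Sequence[int]) -> float:
    """Percentage of videos whose highest-scoring class equals the label."""
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(scores, labels)
    )
    return 100.0 * correct / len(labels)


if __name__ == "__main__":
    # Toy scores for 3 videos x 3 classes from two hypothetical models.
    model_a = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.5, 0.3, 0.2]]
    model_b = [[0.5, 0.2, 0.3], [0.1, 0.3, 0.6], [0.1, 0.2, 0.7]]
    labels = [0, 2, 2]
    fused = average_scores([model_a, model_b])
    print(top1_accuracy(model_a, labels), top1_accuracy(fused, labels))
```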
The trained models can be downloaded by running `python download_models.py` or from Google Drive.
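The table below lists the Something Something-v1 results reported in the paper. To reproduce a result, run the script linked from its accuracy score. In the first column, Nx2 denotes N frames per clip with 2 clips per video, and the last row is an ensemble of the 8x2, 12x2, 16 and 24 configurations.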
| No. of frames | Top-1 Accuracy on Something Something-v1 (%) |
|---|---|
| 8 | 49.01 |
| 12 | 51.58 |
| 16 | 50.63 |
| 24 | 49.63 |
| 8x2 | 50.43 |
| 12x2 | 51.98 |
| 8x2 + 12x2 + 16 + 24 | 55.16 |
To reproduce the results on the Diving48 dataset (39.03% with 16 frames and 40.27% with 16x2 frames), run the scripts linked from these scores.
This implementation is built upon the TRN-pytorch codebase, which is based on TSN-pytorch. We thank Yuanjun Xiong and Bolei Zhou for releasing the TSN-pytorch and TRN-pytorch repositories.