lmbxmu / HRank

Pytorch implementation of our paper accepted by CVPR 2020 (Oral) -- HRank: Filter Pruning using High-Rank Feature Map

Home Page:https://128.84.21.199/abs/2002.10179

HRank: Filter Pruning using High-Rank Feature Map (Link).

Pytorch implementation of HRank (CVPR 2020, Oral).

This repository contains the code we used while preparing our CVPR 2020 paper. Enjoy!

A better version of HRank (in both fine-tuning efficiency and accuracy) has been released as HRankPlus. We highly recommend it!

Framework of HRank. In the left column, images are first passed through the convolutional layers to obtain the feature maps. In the middle column, the rank of each feature map is estimated and used as the pruning criterion. The right column shows pruning (the red filters) and fine-tuning, where the green filters are updated and the blue filters are frozen.
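
The framework described above can be sketched in a few lines. This is a minimal, dependency-free illustration and not the repository's implementation: `matrix_rank`, `average_ranks`, and `filters_to_prune` are hypothetical helpers, the toy 2x2 feature maps are made up, and the rounding of the number of pruned filters (`int(...)`) is an assumption. It also assumes, as the paper title suggests, that filters whose feature maps have higher average rank are the ones kept.

```python
from typing import List

Matrix = List[List[float]]


def matrix_rank(m: Matrix, tol: float = 1e-8) -> int:
    """Rank of a small dense matrix via Gaussian elimination with partial pivoting."""
    a = [row[:] for row in m]  # work on a copy so the input is untouched
    rows, cols = len(a), len(a[0])
    rank = 0
    for c in range(cols):
        if rank == rows:
            break
        # choose the remaining row with the largest absolute entry in column c
        pivot = max(range(rank, rows), key=lambda r: abs(a[r][c]))
        if abs(a[pivot][c]) < tol:
            continue  # column has no usable pivot
        a[rank], a[pivot] = a[pivot], a[rank]
        for r in range(rank + 1, rows):
            factor = a[r][c] / a[rank][c]
            for k in range(c, cols):
                a[r][k] -= factor * a[rank][k]
        rank += 1
    return rank


def average_ranks(batch: List[List[Matrix]]) -> List[float]:
    """batch[i][j] is the feature map produced by filter j for image i."""
    n_filters = len(batch[0])
    totals = [0.0] * n_filters
    for image_maps in batch:
        for j, fmap in enumerate(image_maps):
            totals[j] += matrix_rank(fmap)
    return [t / len(batch) for t in totals]


def filters_to_prune(avg_ranks: List[float], compress_rate: float) -> List[int]:
    """Indices of the lowest-rank filters (rounding down is an assumption)."""
    n_prune = int(len(avg_ranks) * compress_rate)
    order = sorted(range(len(avg_ranks)), key=lambda j: avg_ranks[j])
    return sorted(order[:n_prune])


# Toy data: 2 images, 4 filters, 2x2 feature maps (all values are made up).
full = [[1.0, 0.0], [0.0, 1.0]]   # rank 2
low = [[1.0, 2.0], [2.0, 4.0]]    # rank 1 (second row is twice the first)
zero = [[0.0, 0.0], [0.0, 0.0]]   # rank 0
batch = [[full, low, zero, full], [full, low, low, full]]

avg = average_ranks(batch)            # [2.0, 1.0, 0.5, 2.0]
pruned = filters_to_prune(avg, 0.5)   # [1, 2]
print(avg, pruned)
```

For real feature maps one would more likely rely on a library routine such as `numpy.linalg.matrix_rank` than on hand-written elimination.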

Tips

If you run into any problem, please contact the authors by email: lmbxmu@stu.xmu.edu.cn or ethan.zhangyc@gmail.com. Please avoid posting issues on GitHub where possible, since we may not receive GitHub's notification emails and could therefore overlook them.

Citation

If you find HRank useful in your research, please consider citing:

@inproceedings{lin2020hrank,
  title={HRank: Filter Pruning using High-Rank Feature Map},
  author={Lin, Mingbao and Ji, Rongrong and Wang, Yan and Zhang, Yichen and Zhang, Baochang and Tian, Yonghong and Shao, Ling},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  pages={1529--1538},
  year={2020}
}

Running Code

This code runs our models on the CIFAR-10 and ImageNet datasets. It has been tested with Python 3.6, PyTorch 1.0, and CUDA 9.0 on Ubuntu 16.04.

Rank Generation

python rank_generation.py \
--resume [pre-trained model dir] \
--arch [model arch name] \
--limit [batch numbers] \
--gpu [gpu_id]

Model Training

For ease of reproducibility, we provide some of the experimental results and the corresponding pruning rate of every layer below:

Attention! The actual pruning rates are much higher than those presented in the paper, since we do not count the next-layer channel removal. (For example, if 50 filters are removed in the first layer, then the corresponding 50 channels in the second-layer filters should be removed as well.)
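
The effect of next-layer channel removal can be seen with a small parameter count. The sketch below is only an illustration: the layer shapes (3 -> 64 -> 128 channels, 3x3 kernels) are hypothetical, biases are ignored, and `conv_params` is a helper introduced here. Only the figure of 50 removed filters comes from the note above.

```python
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights in a k x k convolution, ignoring the bias term."""
    return c_in * c_out * k * k


# Hypothetical two-layer stack: 3 -> 64 -> 128 channels, 3x3 kernels.
c_in, c1, c2 = 3, 64, 128
removed = 50  # filters removed from the first layer, as in the note above

before = conv_params(c_in, c1) + conv_params(c1, c2)

# Counting only the filters removed from layer 1.
filters_only = conv_params(c_in, c1 - removed) + conv_params(c1, c2)

# Also dropping the 50 matching input channels of every layer-2 filter.
with_next_layer = conv_params(c_in, c1 - removed) + conv_params(c1 - removed, c2)

print(before, filters_only, with_next_layer)  # 75456 74106 16506
print(f"filters only:    {1 - filters_only / before:.1%} fewer parameters")
print(f"with next layer: {1 - with_next_layer / before:.1%} fewer parameters")
```

In this toy setting almost all of the saving comes from the second layer's input channels, which is why the per-layer rates understate the overall reduction.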

1. VGG-16

Params | FLOPs | Accuracy
2.64M (82.1%) | 108.61M (65.3%) | 92.34%

python main.py \
--job_dir ./result/vgg_16_bn/[folder name] \
--resume [pre-trained model dir] \
--arch vgg_16_bn \
--compress_rate [0.95]+[0.5]*6+[0.9]*4+[0.8]*2 \
--gpu [gpu_id]

2. ResNet56

Params | FLOPs | Accuracy
0.49M (42.4%) | 62.72M (50.0%) | 93.17%

python main.py \
--job_dir ./result/resnet_56/[folder name] \
--resume [pre-trained model dir] \
--arch resnet_56 \
--compress_rate [0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4] \
--gpu [gpu_id]

3. ResNet110

Note that the paper mistakenly reports the FLOPs as 148.70M (41.2%). We apologize for the error and will update the arXiv version as soon as possible.

Params | FLOPs | Accuracy
1.04M (38.7%) | 156.90M (37.9%) | 94.23%

python main.py \
--job_dir ./result/resnet_110/[folder name] \
--resume [pre-trained model dir] \
--arch resnet_110 \
--compress_rate [0.1]+[0.40]*36+[0.40]*36+[0.4]*36 \
--gpu [gpu_id]

4. DenseNet40

Params | FLOPs | Accuracy
0.66M (36.5%) | 167.41M (40.8%) | 94.24%

python main.py \
--job_dir ./result/densenet_40/[folder name] \
--resume [pre-trained model dir] \
--arch densenet_40 \
--compress_rate [0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*5+[0.0] \
--gpu [gpu_id]

5. GoogLeNet

Params | FLOPs | Accuracy
1.86M (69.8%) | 0.45B (70.4%) | 94.07%

python main.py \
--job_dir ./result/googlenet/[folder name] \
--resume [pre-trained model dir] \
--arch googlenet \
--compress_rate [0.10]+[0.8]*5+[0.85]+[0.8]*3 \
--gpu [gpu_id]

6. ResNet50

Params | FLOPs | Top-1 Acc | Top-5 Acc
13.77M | 1.55B | 71.98% | 91.01%

python main.py \
--dataset imagenet \
--data_dir [ImageNet dataset dir] \
--job_dir ./result/resnet_50/[folder name] \
--resume [pre-trained model dir] \
--arch resnet_50 \
--compress_rate [0.2]+[0.8]*10+[0.8]*13+[0.55]*19+[0.45]*10 \
--gpu [gpu_id]
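
The `--compress_rate` strings used in the commands above read as one rate per layer, with `[r]*n` repeating a rate n times and `+` concatenating the pieces. The sketch below expands such a string into a flat list. It is one plausible reading of the format and not necessarily how `main.py` parses it; `expand_compress_rate` is a hypothetical helper.

```python
import re
from typing import List

# One term looks like "[0.5]" or "[0.5]*6".
_TERM = re.compile(r"\[\s*([0-9]*\.?[0-9]+)\s*\](?:\s*\*\s*(\d+))?")


def expand_compress_rate(spec: str) -> List[float]:
    """Expand e.g. "[0.95]+[0.5]*6" into a list with one rate per layer."""
    rates: List[float] = []
    for term in spec.split("+"):
        match = _TERM.fullmatch(term.strip())
        if match is None:
            raise ValueError(f"cannot parse term: {term!r}")
        value = float(match.group(1))
        repeat = int(match.group(2)) if match.group(2) else 1
        rates.extend([value] * repeat)
    return rates


vgg = expand_compress_rate("[0.95]+[0.5]*6+[0.9]*4+[0.8]*2")
print(len(vgg))  # 13 entries: 1 + 6 + 4 + 2
print(vgg[:3])   # [0.95, 0.5, 0.5]
```

Parsing the string with a regular expression instead of `eval` means malformed or unexpected input raises a `ValueError` rather than being executed.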

After training, checkpoints and logs can be found in the job_dir. The pruned model is named [arch]_cov[i] for stage i, so the final pruned model is the one with the largest i.
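
Picking the final pruned model therefore amounts to finding the largest stage index among the saved files. A minimal sketch, assuming file names that start with `[arch]_cov[i]`: the `.pt` extension, the example file names, and the helper `final_pruned_model` are hypothetical.

```python
import re
from typing import Iterable, Optional


def final_pruned_model(names: Iterable[str], arch: str) -> Optional[str]:
    """Return the name whose `<arch>_cov<i>` prefix has the largest i, or None."""
    pattern = re.compile(re.escape(arch) + r"_cov(\d+)")
    best_name, best_stage = None, -1
    for name in names:
        match = pattern.match(name)
        if match and int(match.group(1)) > best_stage:
            best_name, best_stage = name, int(match.group(1))
    return best_name


# Hypothetical directory listing; in practice this could come from os.listdir(job_dir).
names = ["vgg_16_bn_cov2.pt", "vgg_16_bn_cov10.pt", "vgg_16_bn_cov9.pt", "logger.log"]
print(final_pruned_model(names, "vgg_16_bn"))  # vgg_16_bn_cov10.pt
```

The stage index is compared as an integer because a plain string sort would rank `cov9` above `cov10`.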

Get FLOPS & Params

python cal_flops_params.py \
--arch resnet_56_convwise \
--compress_rate [0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]

Evaluate Final Performance

python evaluate.py \
--dataset [dataset name] \
--data_dir [dataset dir] \
--test_model_dir [job dir of test model] \
--arch [arch name] \
--gpu [gpu id]

Other optional arguments

optional arguments:
    --data_dir           Dataset directory.
                         default: ./data
    --dataset            Dataset name.
                         default: cifar10
                         Optional: cifar10, imagenet
    --lr                 Initial learning rate.
                         default: 0.01
    --lr_decay_step      Learning rate decay step.
                         default: 5,10
    --resume             Load the model from the specified checkpoint.
    --resume_mask        Mask loading directory.
    --gpu                Select the GPU to use.
                         default: 0
    --job_dir            The directory where the summaries will be stored.
    --epochs             The number of epochs to train.
                         default: 30
    --train_batch_size   Batch size for training.
                         default: 128
    --eval_batch_size    Batch size for validation.
                         default: 100
    --start_cov          The index of the conv layer at which to start pruning.
                         default: 0
    --compress_rate      Compress rate of each conv layer.
    --arch               The architecture to prune.
                         default: vgg_16_bn
                         Optional: resnet_50, vgg_16_bn, resnet_56, resnet_110, densenet_40, googlenet
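
For readers who want to script around these options, the list above maps naturally onto an `argparse` parser. The sketch below mirrors only the names and defaults listed above. The argument types, the `None` defaults for options with no listed default, and the comma-splitting of `--lr_decay_step` are assumptions, and `build_parser` is a hypothetical helper, not the repository's own parser.

```python
import argparse

ARCHS = ("resnet_50", "vgg_16_bn", "resnet_56", "resnet_110", "densenet_40", "googlenet")


def build_parser() -> argparse.ArgumentParser:
    # Names and defaults follow the option list; types are assumptions.
    p = argparse.ArgumentParser(description="Hypothetical mirror of the option list")
    p.add_argument("--data_dir", type=str, default="./data")
    p.add_argument("--dataset", type=str, default="cifar10", choices=("cifar10", "imagenet"))
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--lr_decay_step", type=str, default="5,10")
    p.add_argument("--resume", type=str, default=None)       # no default listed
    p.add_argument("--resume_mask", type=str, default=None)  # no default listed
    p.add_argument("--gpu", type=str, default="0")
    p.add_argument("--job_dir", type=str, default=None)      # no default listed
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--train_batch_size", type=int, default=128)
    p.add_argument("--eval_batch_size", type=int, default=100)
    p.add_argument("--start_cov", type=int, default=0)
    p.add_argument("--compress_rate", type=str, default=None)  # no default listed
    p.add_argument("--arch", type=str, default="vgg_16_bn", choices=ARCHS)
    return p


args = build_parser().parse_args(["--arch", "resnet_56", "--epochs", "5"])
# Assumed interpretation: "5,10" lists the epochs at which the learning rate decays.
decay_epochs = [int(s) for s in args.lr_decay_step.split(",")]
print(args.arch, args.epochs, args.lr, decay_epochs)  # resnet_56 5 0.01 [5, 10]
```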

Pre-trained Models

Additionally, we provide the pre-trained models used in our experiments.

CIFAR-10:

Vgg-16 | ResNet56 | ResNet110 | DenseNet-40 | GoogLeNet

ImageNet:

ResNet50
