This repository contains a PyTorch implementation of the paper Residual Attention Network for Image Classification. The code is based on ResidualAttentionNetwork-pytorch.
The original repository contains the following implementations:
- ResidualAttentionModel92U for training on the CIFAR10 dataset
- ResidualAttentionModel92 for training on the ImageNet dataset
- ResidualAttentionModel448 for images with larger resolution
Example usage for training on CIFAR10:
$ python main.py --name ResNet-92-32U --tensorboard
or
$ python main_mixup.py --name ResNet-92-32U --tensorboard
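main_mixup.py trains with mixup augmentation, which blends pairs of training examples (and combines their labels' losses with the same weight). A minimal sketch of the idea in plain Python, not the repository's exact code; `alpha` is the Beta-distribution parameter from the mixup paper:

```python
import random

def mixup(x1, y1, x2, y2, alpha=1.0):
    """Sketch of mixup: blend two examples with a Beta(alpha, alpha) coefficient."""
    lam = random.betavariate(alpha, alpha)  # mixing coefficient in [0, 1]
    mixed = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    # The loss is combined the same way: lam * loss(y1) + (1 - lam) * loss(y2)
    return mixed, y1, y2, lam

random.seed(0)
mixed, ya, yb, lam = mixup([1.0, 0.0], 0, [0.0, 1.0], 1)
```

In a real training loop the second example is usually a random permutation of the same batch, so no extra data loading is needed.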
Example usage for testing your trained model (be sure to use the same network model):
$ python main.py --test ResNet-92-32U
Note: To switch the model you train, replace the imported model in main.py or main_mixup.py (depending on whether you use mixup):
from residual_attention_network import ResidualAttentionModel92U as ResidualAttentionModel
To track training progress, this implementation uses TensorBoard, which makes it easy to monitor and compare multiple experiments. To log PyTorch experiments to TensorBoard, we use tensorboard_logger, which can be installed with
pip install tensorboard_logger
Results:
Model | Dataset | Top-1 error (%)
---|---|---
RAN92U | CIFAR10 | 4.6 |
RAN92U (with mixup) | CIFAR10 | 3.35 |
RAN92U (with mixup and simpler attention module) | CIFAR10 | 3.16 |
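Top-1 error is the percentage of test images whose highest-scoring prediction does not match the ground-truth label. A small illustrative helper (hypothetical, not part of this repository):

```python
def top1_error(predictions, labels):
    """Percentage of samples where the predicted class differs from the label."""
    wrong = sum(int(p != y) for p, y in zip(predictions, labels))
    return 100.0 * wrong / len(labels)

err = top1_error([3, 1, 2, 0], [3, 1, 0, 0])  # one of four predictions wrong -> 25.0
```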
If you use Residual Attention Networks in your work, please cite the original paper as:
@misc{1704.06904,
  Author = {Fei Wang and Mengqing Jiang and Chen Qian and Shuo Yang and Cheng Li and Honggang Zhang and Xiaogang Wang and Xiaoou Tang},
  Title = {Residual Attention Network for Image Classification},
  Year = {2017},
  Eprint = {arXiv:1704.06904},
}
If this implementation is useful in your work, please also consider citing or acknowledging this repository.