Necas209 / ResidualAttentionNetwork-PyTorch

A PyTorch implementation for Residual Attention Networks


This repository contains a PyTorch implementation of the paper Residual Attention Network for Image Classification. The code is based on ResidualAttentionNetwork-pytorch.

The original repository contains the following implementations:

  • ResidualAttentionModel92U for training on the CIFAR10 dataset
  • ResidualAttentionModel92 for training on the ImageNet dataset
  • ResidualAttentionModel448 for images with larger resolution

Train and test your model

Example usage for training on CIFAR10:

$ python main.py --name ResNet-92-32U --tensorboard

or

$ python main_mixup.py --name ResNet-92-32U --tensorboard

Example usage for testing your trained model (be sure to use the same network model):

$ python main.py --test ResNet-92-32U
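The exact flag handling lives in main.py and main_mixup.py; as a rough guide, a minimal argparse setup consistent with the commands above might look like this (names and defaults here are assumptions, not the repository's actual definitions):

```python
import argparse

def build_parser():
    # Hypothetical sketch of the CLI implied by the example commands;
    # the real argument definitions in main.py may differ.
    parser = argparse.ArgumentParser(
        description="Train or test a Residual Attention Network")
    parser.add_argument("--name", default="ResNet-92-32U",
                        help="experiment name used for checkpoints and logs")
    parser.add_argument("--tensorboard", action="store_true",
                        help="log training progress to TensorBoard")
    parser.add_argument("--test", metavar="NAME",
                        help="evaluate the trained model saved under NAME")
    return parser
```

With this sketch, `build_parser().parse_args(["--name", "ResNet-92-32U", "--tensorboard"])` yields a namespace with `name` set and `tensorboard` enabled, matching the training command above.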

Note: To switch the model you're training on, be sure to replace the imported model:

from residual_attention_network import ResidualAttentionModel92U as ResidualAttentionModel

in either main.py or main_mixup.py, depending on whether you use mixup.
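As a hypothetical alternative to editing the import by hand, the model class could be resolved at runtime from a small lookup table. The class names below are the ones listed earlier in this README; the helper itself is not part of the repository:

```python
# Map a short key to the model class name (names from this README).
MODEL_NAMES = {
    "cifar10": "ResidualAttentionModel92U",   # CIFAR10 training
    "imagenet": "ResidualAttentionModel92",   # ImageNet training
    "large": "ResidualAttentionModel448",     # larger-resolution images
}

def resolve_model(key):
    """Look up the model class from residual_attention_network by name.

    Hypothetical helper: assumes the module exposes the classes above.
    """
    import residual_attention_network as ran
    return getattr(ran, MODEL_NAMES[key])
```

For example, `resolve_model("imagenet")` would return the ImageNet variant without touching the import line.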

Tracking training progress with TensorBoard

To track training progress, this implementation uses TensorBoard, which offers a convenient way to track and compare multiple experiments. To log PyTorch experiments to TensorBoard, we use tensorboard_logger, which can be installed with:

pip install tensorboard_logger
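A sketch of per-epoch metric logging with tensorboard_logger is shown below. The metric names and log directory are assumptions for illustration, not the repository's actual logging calls, and the snippet falls back to a no-op when the package is not installed:

```python
# Guarded import: logging becomes a no-op without tensorboard_logger.
try:
    from tensorboard_logger import configure, log_value
    configure("runs/ResNet-92-32U")  # one log directory per experiment
    HAVE_TB = True
except ImportError:
    HAVE_TB = False

def log_epoch(epoch, train_loss, val_error):
    """Record scalar metrics for one epoch; returns them for convenience."""
    if HAVE_TB:
        log_value("train_loss", train_loss, step=epoch)
        log_value("val_error", val_error, step=epoch)
    return train_loss, val_error
```

Calling `log_epoch` once per epoch produces scalar curves that TensorBoard can overlay across experiments, which is what makes run-to-run comparison straightforward.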

Dependencies

  • PyTorch

optional:

  • tensorboard_logger (for tracking training progress with TensorBoard)

Results (from source repository)

Model                                              Dataset   Top-1 error (%)
RAN92U                                             CIFAR10   4.6
RAN92U (with mixup)                                CIFAR10   3.35
RAN92U (with mixup and simpler attention module)   CIFAR10   3.16
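Mixup, used in the second and third rows above, trains on convex combinations of pairs of examples and their labels. A minimal NumPy sketch of the idea (not the repository's implementation, which lives in main_mixup.py):

```python
import numpy as np

def mixup(x1, y1, x2, y2, alpha=1.0, rng=None):
    """Blend two samples and their one-hot labels with a shared weight.

    The mixing coefficient lam is drawn from Beta(alpha, alpha), so
    alpha controls how aggressively samples are interpolated.
    """
    rng = rng if rng is not None else np.random.default_rng(0)
    lam = rng.beta(alpha, alpha)
    x = lam * x1 + (1.0 - lam) * x2
    y = lam * y1 + (1.0 - lam) * y2
    return x, y, lam
```

In a training loop, each batch would be mixed with a shuffled copy of itself and the loss computed against the blended labels.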

Cite

If you use Residual Attention Networks in your work, please cite the original paper as:

@misc{1704.06904,
    Author = {Fei Wang and Mengqing Jiang and Chen Qian and Shuo Yang and Cheng Li and Honggang Zhang and Xiaogang Wang and Xiaoou Tang},
    Title = {Residual Attention Network for Image Classification},
    Year = {2017},
    Eprint = {arXiv:1704.06904},
}

If this implementation is useful to you or your project, please also consider citing or acknowledging this code repository.
