kongbia / pytorch-classification

Config File is All You Need: An Image Classification Codebase Written in PyTorch

This project aims to provide the necessary building blocks for easily creating image classification models with PyTorch.

Note: I finished this project in my spare time within a week, so there is still a lot of work to be done.

Highlights

  • Convenient: You can use a config file to create an image classification model and train it on your own datasets without writing any code.
  • Extensible: You can write your own modules (Dataset, Transform, Network, Loss and so on) and easily register them in the default config.
  • Parameter-is-Module: You can create a module from a single parameter consisting of a module name and an argument list.
  • Multi-GPU training and inference: You can train your model on one GPU or on multiple GPUs in parallel.
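The Extensible and Parameter-is-Module highlights can be illustrated with a small name-to-builder registry. The sketch below is plain Python and is not this codebase's actual API: `Registry`, `NETWORKS`, `MLP`, and `build_from_param` are hypothetical names, and it assumes a parameter has the form `(module_name, [args...])`.

```python
from typing import Any, Callable, Sequence, Tuple


class Registry(dict):
    """Maps a module name (str) to the class or function that builds it."""

    def register(self, name: str) -> Callable:
        def decorator(builder: Callable) -> Callable:
            if name in self:
                raise KeyError(f"'{name}' is already registered")
            self[name] = builder
            return builder

        return decorator


# Hypothetical registry for networks; real code would register nn.Module subclasses.
NETWORKS = Registry()


@NETWORKS.register("MLP")
class MLP:
    """Placeholder 'network' that just records its constructor arguments."""

    def __init__(self, in_dim: int, hidden: int, num_classes: int):
        self.in_dim = in_dim
        self.hidden = hidden
        self.num_classes = num_classes


def build_from_param(registry: Registry, param: Tuple[str, Sequence[Any]]):
    """Build a module from a parameter of the form (name, [arg1, arg2, ...])."""
    name, args = param
    if name not in registry:
        raise KeyError(f"unknown module '{name}', registered: {sorted(registry)}")
    return registry[name](*args)


# Hypothetical config value: ("MLP", [3072, 512, 10])
net = build_from_param(NETWORKS, ("MLP", [3072, 512, 10]))
```

With this design, adding a new module only requires decorating it with `register`; the config then refers to it by name without any other code changes.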

Accuracy

The top-1 accuracy (%) of different models on CIFAR-10 is shown below, with NETWORK_STRIDE set to (2,2,2,2,2) in the first column and (1,1,2,2,2) in the second. Refer to DETAILS.md for more details about the NETWORK_STRIDE parameter.

Model                NETWORK_STRIDE=(2,2,2,2,2)  NETWORK_STRIDE=(1,1,2,2,2)
ResNet-18            86.10                       92.64
ResNet-34            86.14                       92.73
ResNet-50            86.65                       92.20
ResNet-101           87.41                       93.27
ResNet-152           87.01                       93.25
ResNeXt-50 (32x4d)   87.56                       93.65
ResNeXt-101 (32x8d)  88.24                       93.75
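To get a feel for what the two NETWORK_STRIDE settings do to a 32x32 CIFAR-10 image, the arithmetic below tracks the spatial size through the five strides. This is an assumption-laden sketch: it assumes the five entries are the strides of the five downsampling points of a standard ResNet (stem conv, max-pool, and the first blocks of stages 2 to 4); DETAILS.md is the authority on what the parameter actually controls. `final_feature_size` and `overall_stride` are hypothetical helpers.

```python
def final_feature_size(input_size: int, strides) -> int:
    """Spatial size of the last feature map after each stride is applied in turn.

    Uses ceil division, which matches the padded 7x7/3x3 convolutions and the
    padded 3x3 max-pool of a standard ResNet.
    """
    size = input_size
    for s in strides:
        size = -(-size // s)  # ceil(size / s) using integers only
    return size


def overall_stride(strides) -> int:
    """Total downsampling factor: the product of the individual strides."""
    total = 1
    for s in strides:
        total *= s
    return total


for strides in [(2, 2, 2, 2, 2), (1, 1, 2, 2, 2)]:
    # CIFAR-10 images are 32x32
    print(strides, "overall stride", overall_stride(strides),
          "-> final map", final_feature_size(32, strides))
```

Under that assumption, (2,2,2,2,2) shrinks a CIFAR-10 image to a 1x1 map before pooling (overall stride 32), while (1,1,2,2,2) keeps a 4x4 map (overall stride 8), which may help explain the accuracy gap between the two columns.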

Installation

pip3 install -r requirements.txt

Inference in a few lines

We provide a helper class, ClsDemo, to simplify writing inference pipelines with pre-trained models. Run the following code from the demo folder. (The pre-trained model and sample images can be downloaded from [Baidu (PWD: f25u)] or [OneDrive]. You can use any config in ./configs.)

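import sys

# Make the repository root importable (this snippet is run from ./demo)
sys.path.append("../")

import cv2

from predictor import ClsDemo
from pytorch_classification.config import cfg

# Replace these placeholders with your own paths
config = "/path_to_config"
img_path = "/path_to_image"
checkpoint_path = "/path_to_pre-trained_model"

# Load the config file, then point CHECKPOINT at the pre-trained weights
cfg.merge_from_file(config)
cfg.merge_from_list(["CHECKPOINT", checkpoint_path])

cls_demo = ClsDemo(cfg)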

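# cv2.imread returns None instead of raising when the file cannot be read
image = cv2.imread(img_path)
if image is None:
    raise FileNotFoundError(img_path)
pred = cls_demo.run_on_opencv_image(image)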

Perform training on the CIFAR-10 dataset

You need to download the CIFAR-10 dataset and convert it to the GeneralDataset format used by this codebase. (You can also download an already-converted copy from [Baidu (PWD: f25u)] or [OneDrive].) We recommend symlinking the CIFAR-10 dataset into ./datasets as follows:

# symlink the cifar-10 dataset
cd pytorch-classification
mkdir -p datasets
ln -s /path_to_cifar-10_dataset datasets/cifar-10

You can also configure your own dataset paths. To do so, modify ./pytorch_classification/config/data_catalog.py to point to the location where your dataset is stored (see DETAILS.md for more details).
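For orientation, here is a hypothetical sketch of a dataset catalog, modeled loosely on the DatasetCatalog pattern from maskrcnn-benchmark (which this codebase cites as an influence). The class layout, the dataset names, the `root` key, and the `"GeneralDataset"` factory string are assumptions for illustration; check the real data_catalog.py for its actual keys.

```python
import os


class DatasetCatalog:
    """Hypothetical sketch: maps a dataset name to the arguments used to build it."""

    # Relative paths below are resolved against this directory (e.g. ./datasets).
    DATA_DIR = "datasets"
    DATASETS = {
        "cifar10_train": {"root": "cifar-10/train"},
        "cifar10_test": {"root": "cifar-10/test"},
        # Add your own dataset here; an absolute path ignores DATA_DIR.
        "my_dataset_train": {"root": "/data/my_dataset/train"},
    }

    @staticmethod
    def get(name: str) -> dict:
        if name not in DatasetCatalog.DATASETS:
            raise KeyError(f"dataset not available: {name}")
        attrs = DatasetCatalog.DATASETS[name]
        return {
            "factory": "GeneralDataset",
            "args": {"root": os.path.join(DatasetCatalog.DATA_DIR, attrs["root"])},
        }


print(DatasetCatalog.get("cifar10_train"))
print(DatasetCatalog.get("my_dataset_train"))
```

Because `os.path.join` discards earlier components when it meets an absolute path, a dataset stored outside ./datasets can be registered with its absolute path and no symlink.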

Single GPU training

You can run the following without modifications to train your model on a single GPU.

python3 tools/train.py --config-file "configs/config_cifar10_R50_1gpu.yaml"

Multi-GPU training

Internally, we use torch.distributed.launch to start multi-GPU training. This PyTorch utility spawns as many Python processes as the number of GPUs we want to use, and each Python process uses only a single GPU.

export NGPUS=8
python3 -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train.py --config-file "configs/config_cifar10_R50_8gpu.yaml"
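Each spawned process needs to find out which GPU it owns. The sketch below shows one common way a training script can do this; it is not copied from tools/train.py, and `parse_launch_args` and `setup_process` are hypothetical names. It assumes the launcher passes `--local_rank` (spelled `--local-rank` in recent PyTorch releases) or sets the `LOCAL_RANK` environment variable, and exports `WORLD_SIZE`.

```python
import argparse
import os


def parse_launch_args(argv=None, environ=None):
    """Return (local_rank, num_gpus, distributed) for the current process."""
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    # torch.distributed.launch appends --local_rank=N (--local-rank in recent
    # PyTorch releases); torchrun-style launchers set LOCAL_RANK instead.
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank", type=int,
                        default=int(environ.get("LOCAL_RANK", 0)))
    args, _unknown = parser.parse_known_args(argv)  # ignore e.g. --config-file
    # The launcher exports WORLD_SIZE = total number of processes (one per GPU here).
    num_gpus = int(environ.get("WORLD_SIZE", 1))
    return args.local_rank, num_gpus, num_gpus > 1


def setup_process(local_rank: int, distributed: bool) -> None:
    """Pin this process to one GPU and join the process group (requires CUDA)."""
    if distributed:
        import torch  # imported lazily so the parsing above works without PyTorch
        import torch.distributed as dist

        torch.cuda.set_device(local_rank)
        # env:// reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE set by the launcher
        dist.init_process_group(backend="nccl", init_method="env://")
```

In a script, `local_rank, num_gpus, distributed = parse_launch_args()` followed by `setup_process(local_rank, distributed)` would run once per process; with a single GPU, `WORLD_SIZE` is unset and the setup step is skipped.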

If you want to train your model on more GPUs, you should adjust the batch size (SOLVER.BATCH_SIZE) and the learning rate (SOLVER.BASE_LR) accordingly.
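One widely used heuristic for this adjustment is the linear scaling rule: keep the per-GPU batch size fixed and scale the learning rate by the same factor as the total batch size. The sketch assumes SOLVER.BATCH_SIZE is the total batch across all GPUs; the helper `scale_solver` and the numbers in the example are hypothetical, not values taken from the shipped configs.

```python
def scale_solver(base_lr: float, base_batch_size: int,
                 base_num_gpus: int, new_num_gpus: int):
    """Linear scaling rule: scale batch size and learning rate by the same factor.

    Keeps the per-GPU batch size constant, assuming the batch size is the total
    across all GPUs.
    """
    if base_batch_size % base_num_gpus != 0:
        raise ValueError("base batch size must be divisible by the number of GPUs")
    per_gpu_batch = base_batch_size // base_num_gpus
    new_batch_size = per_gpu_batch * new_num_gpus
    new_lr = base_lr * new_batch_size / base_batch_size
    return new_lr, new_batch_size


# Hypothetical numbers: going from 8 to 16 GPUs doubles both values.
lr, batch = scale_solver(base_lr=0.1, base_batch_size=128, base_num_gpus=8, new_num_gpus=16)
print(f"SOLVER.BASE_LR={lr} SOLVER.BATCH_SIZE={batch}")
```

This is only a starting point; very large batch sizes may need further tuning.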

Evaluation

You can test your model directly on single or multiple GPUs. Here is an example for multi-GPU testing:

export NGPUS=8
python3 -m torch.distributed.launch --nproc_per_node=$NGPUS tools/test.py --config-file "configs/config_cifar10_R50_8gpu.yaml"
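The Accuracy table reports top-1 accuracy. As a reference for what that metric means, here is a plain-Python sketch of top-k accuracy over per-class scores; it is not this repository's evaluation code, and `topk_accuracy` is a hypothetical helper. With multi-GPU testing, the predictions from all processes would need to be gathered before computing a single number.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        correct += label in topk
    return 100.0 * correct / len(labels)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.6, 0.1, 0.3]]
labels = [1, 0, 1, 2]
print(topk_accuracy(scores, labels, k=1))  # 50.0
print(topk_accuracy(scores, labels, k=2))  # 100.0
```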

Details

You can refer to DETAILS.md for more details.

License

This project is released under the MIT license. See LICENSE for additional details.

Acknowledgement

This codebase is heavily influenced by the maskrcnn-benchmark project.
