s-chh / PatchRot

Official PyTorch implementation of PatchRot: A Self-Supervised Technique for Training Vision Transformers


PatchRot

This is the official PyTorch implementation of our upcoming BMVC 2024 paper "PatchRot: Self-Supervised Training of Vision Transformers by Rotation Prediction".

Introduction

PatchRot rotates both whole images and individual image patches, and trains the network to predict the rotation angle of each. Through this pretext task, the network learns to extract global image-level as well as local patch-level features. Pretraining with PatchRot yields superior features and improved downstream performance.
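The target-generation idea described above can be sketched as follows. This is a minimal, illustrative sketch, not the paper's exact recipe: the function name, the default `patch_size`, and the patch-extraction details are our assumptions.

```python
import torch

def make_patchrot_batch(img, patch_size=8):
    """Generate PatchRot-style inputs and rotation labels for one image.

    img: (C, H, W) tensor with H and W divisible by patch_size.
    Returns the rotated image, its angle class, the independently rotated
    patches, and their angle classes (0/1/2/3 -> 0/90/180/270 degrees).
    """
    # 1) Rotate the whole image by a random multiple of 90 degrees.
    k_img = torch.randint(0, 4, (1,)).item()
    rot_img = torch.rot90(img, k_img, dims=(1, 2))

    # 2) Split the image into non-overlapping patches and rotate each
    #    patch independently by its own random multiple of 90 degrees.
    C, H, W = img.shape
    patches = (img.unfold(1, patch_size, patch_size)
                  .unfold(2, patch_size, patch_size)   # (C, H/p, W/p, p, p)
                  .permute(1, 2, 0, 3, 4)              # (H/p, W/p, C, p, p)
                  .reshape(-1, C, patch_size, patch_size))
    k_patch = torch.randint(0, 4, (patches.size(0),))
    rot_patches = torch.stack(
        [torch.rot90(p, int(k), dims=(1, 2)) for p, k in zip(patches, k_patch)])

    # Labels: one angle class for the image, one per patch. A classifier
    # head on the transformer's tokens would be trained to predict these.
    return rot_img, k_img, rot_patches, k_patch
```

The network never sees the labels directly; predicting them from the rotated inputs is the self-supervised objective.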

Run commands (also available in run_cifar10.sh):

Run `main_pretrain.py` to pre-train the network with PatchRot, followed by `main_finetune.py --init patchrot` to finetune it. Use `main_finetune.py --init none` to train the network without any pretraining (from random initialization).
Below is an example on CIFAR10:

| Method | Run Command |
| --- | --- |
| PatchRot pretraining | `python main_pretrain.py --dataset cifar10` |
| Finetuning pretrained model | `python main_finetune.py --dataset cifar10 --init patchrot` |
| Training from random init | `python main_finetune.py --dataset cifar10 --init none` |

Replace `cifar10` with the appropriate dataset.
Supported datasets: CIFAR10, CIFAR100, FashionMNIST, SVHN, TinyImageNet, Animals10N, and ImageNet100.

CIFAR10, CIFAR100, FashionMNIST, and SVHN are downloaded automatically to the path specified by the `data_path` argument (default: `./data`).
TinyImageNet, Animals10N, and ImageNet100 must be downloaded manually, and their location must be provided via the `data_path` argument.

Results

| Dataset | Without PatchRot Pretraining | With PatchRot Pretraining |
| --- | --- | --- |
| CIFAR10 | 84.4 | 91.3 |
| CIFAR100 | 56.5 | 66.7 |
| FashionMNIST | 93.4 | 94.6 |
| SVHN | 92.9 | 96.4 |
| Animals10N | 69.6 | 79.5 |
| TinyImageNet | 38.4 | 48.8 |
| ImageNet100 | 64.6 | 75.4 |

License

MIT License

