PyTorch implementation of:
Few-shot Classification with LS-SVM Base Learner and Transductive Modules
Haoqing Wang, Zhi-hong Deng
Few-shot classification aims to recognize unseen classes with few labeled samples from each class. Base learners suited to low-data scenarios and auxiliary information from the query samples are critical to the performance of meta-learning models for few-shot classification. For this reason, we make the following improvements in this work: 1) we introduce the multi-class least squares support vector machine as a new base learner, which has less computational overhead and better classification ability than existing ones; 2) to utilize the information from the query samples, we propose two transductive modules which adjust the support set based on the query samples and can be conveniently applied to various base learners to improve their few-shot classification accuracy. Combining the above improvements, we obtain the final model, FSLSTM (Few-Shot classification with LS-svm base learner and Transductive Modules). We conduct extensive experiments and ablation studies on the mini-ImageNet and CIFAR-FS few-shot classification benchmarks. Experimental results show that FSLSTM outperforms recent transductive meta-learning models, and ablation studies verify the effectiveness of each of our improvements.
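For context, the least squares SVM replaces the hinge loss of the standard SVM with a squared error on the margin constraints, so fitting the base learner in each episode reduces to solving a single linear system instead of an iterative quadratic program; this is where the lower computational overhead comes from. A standard binary formulation (the paper's multi-class variant may differ in detail):

```latex
% LS-SVM primal: squared slacks with equality constraints
\min_{w,\,b,\,e}\ \tfrac{1}{2}\lVert w\rVert^{2} + \tfrac{\gamma}{2}\sum_{i=1}^{n} e_i^{2}
\quad\text{s.t.}\quad y_i\bigl(w^{\top}\phi(x_i)+b\bigr) = 1 - e_i,\quad i=1,\dots,n

% Eliminating w and e via the KKT conditions leaves one (n+1)x(n+1) linear system:
\begin{bmatrix} 0 & \mathbf{1}^{\top} \\ \mathbf{1} & \Omega + \gamma^{-1} I \end{bmatrix}
\begin{bmatrix} b \\ \alpha \end{bmatrix}
=
\begin{bmatrix} 0 \\ \mathbf{1} \end{bmatrix},
\qquad \Omega_{ij} = y_i\,y_j\,k(x_i, x_j)
```

With only a handful of support samples per episode, this solve is cheap and differentiable, which is what makes it attractive as a meta-learning base learner.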
If you use this code for your research, please cite our paper:
```
@article{wang2020few,
  title={Few-shot Learning with LSSVM Base Learner and Transductive Modules},
  author={Wang, Haoqing and Deng, Zhi-Hong},
  journal={arXiv preprint arXiv:2009.05786},
  year={2020}
}
```
Dependencies:

- Python 3.5+
- PyTorch 1.4.0
- qpth 0.0.15
- tqdm
Download and decompress the mini-ImageNet and CIFAR-FS datasets.

- mini-ImageNet: 100 classes built upon the ImageNet dataset, with 600 images of size 84x84 per class. The 100 classes are divided into 64, 16 and 20 classes for meta-training, meta-validation and meta-testing, respectively.
- CIFAR-FS: 100 classes built upon the CIFAR-100 dataset, with 600 images of size 32x32 per class. The 100 classes are divided into 64, 16 and 20 classes for meta-training, meta-validation and meta-testing, respectively.
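Training and evaluation are episodic over these splits: each episode samples N classes, with K labeled support images and Q query images per class. A minimal sketch of such a sampler (illustrative only; the function name and signature are ours, and the repository's data loaders may differ):

```python
import random
from collections import defaultdict

def sample_episode(labels, n_way=5, k_shot=5, n_query=15):
    """Sample one N-way K-shot episode from a split.

    `labels` is a sequence mapping image index -> class id within one split
    (e.g. the 64 meta-training classes). Returns (support_idx, query_idx).
    """
    by_class = defaultdict(list)
    for idx, c in enumerate(labels):
        by_class[c].append(idx)
    episode_classes = random.sample(sorted(by_class), n_way)
    support_idx, query_idx = [], []
    for c in episode_classes:
        chosen = random.sample(by_class[c], k_shot + n_query)
        support_idx += chosen[:k_shot]   # labeled examples the base learner fits
        query_idx += chosen[k_shot:]     # examples the episode is scored on
    return support_idx, query_idx
```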
To explore the impact of using network parameters pre-trained on a large dataset as the initialization on the final classification results, we pre-train the backbone on the Places365-standard dataset.
- Download and decompress the Places365-standard dataset:

```bash
cd ./pre_train
wget http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar
tar -xf places365standard_easyformat.tar
```
- Pre-train ResNet12 and Conv4 for mini-ImageNet (84x84) and CIFAR-FS (32x32):

```bash
cd ./pre_train
python pretrain_Conv4.py --resize 36 --img_size 32 --filename Conv4_32
python pretrain_Conv4.py --resize 92 --img_size 84 --filename Conv4_84
python pretrain_ResNet12.py --resize 36 --img_size 32 --filename ResNet12_32
python pretrain_ResNet12.py --resize 92 --img_size 84 --filename ResNet12_84
```
- To train different base learners without the Inverse Attention Module:

```bash
python train.py --gpu 0,1,2,3 --save_path 'Logs/CIFAR_FS_Conv4_LSSVM' --pretrain 'Logs/pretrain/Conv4_32_best.pth' --lr 5e-3 --train_shot 5 --val_shot 5 --head LSSVM --network Conv4 --dataset CIFAR_FS
python train.py --gpu 0,1,2,3 --save_path 'Logs/CIFAR_FS_ResNet12_LSSVM' --pretrain 'Logs/pretrain/ResNet12_32_best.pth' --lr 5e-4 --train_shot 5 --val_shot 5 --head LSSVM --network ResNet12 --dataset CIFAR_FS
python train.py --gpu 0,1,2,3 --save_path 'Logs/miniImageNet_Conv4_LSSVM_5s' --pretrain 'Logs/pretrain/Conv4_84_best.pth' --lr 5e-3 --train_shot 15 --val_shot 5 --head LSSVM --network Conv4 --dataset miniImageNet
python train.py --gpu 0,1,2,3 --save_path 'Logs/miniImageNet_ResNet12_LSSVM_5s' --pretrain 'Logs/pretrain/ResNet12_84_best.pth' --lr 5e-4 --train_shot 15 --val_shot 5 --head LSSVM --network ResNet12 --dataset miniImageNet
```

Set `--head` to `NN`, `RR` or `SVM` for other base learners.
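With `--head LSSVM`, the base learner fits a one-vs-all LS-SVM on the support features of each episode, which, as noted above, amounts to a single closed-form linear solve. A self-contained sketch of such a head (illustrative; the function name, signature and `gamma` default are ours, the repository's head has its own hyperparameters, and this uses the modern `torch.linalg` API rather than the PyTorch 1.4 listed above). Here the ±1 targets are moved to the right-hand side of the KKT system, an equivalent reparameterization of the classification form:

```python
import torch

def lssvm_head(support, support_labels, query, n_way, gamma=0.1):
    """One-vs-all LS-SVM base learner with a linear kernel (sketch).

    support: (n, d) float support features; support_labels: (n,) long class
    ids in [0, n_way); query: (q, d) float query features.
    Returns (q, n_way) classification scores.
    """
    n, device = support.shape[0], support.device
    # +1/-1 targets for each of the n_way one-vs-all problems: (n, n_way)
    Y = torch.full((n, n_way), -1.0, device=device)
    Y[torch.arange(n, device=device), support_labels] = 1.0
    K = support @ support.t()  # linear kernel matrix, (n, n)
    # KKT system with targets on the right-hand side:
    # [[0, 1^T], [1, K + I/gamma]] @ [b; alpha] = [0; Y]
    A = torch.zeros(n + 1, n + 1, device=device)
    A[0, 1:] = 1.0
    A[1:, 0] = 1.0
    A[1:, 1:] = K + torch.eye(n, device=device) / gamma
    rhs = torch.cat([torch.zeros(1, n_way, device=device), Y], dim=0)
    sol = torch.linalg.solve(A, rhs)        # one (n+1)x(n+1) solve, all classes
    b, alpha = sol[0], sol[1:]              # (n_way,), (n, n_way)
    return query @ support.t() @ alpha + b  # (q, n_way) scores
```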
- To train different base learners with our Inverse Attention Module:
```bash
python train_IAM.py --gpu 0,1,2,3 --save_path 'Logs/CIFAR_FS_Conv4_LSSVM_IAM' --pretrain 'Logs/pretrain/Conv4_32_best.pth' --dim 2048 --reduction 8 --IAM_lr 1e-2 --lr 5e-3 --train_shot 5 --val_shot 5 --head LSSVM --network Conv4 --dataset CIFAR_FS
python train_IAM.py --gpu 0,1,2,3 --save_path 'Logs/CIFAR_FS_ResNet12_LSSVM_IAM' --pretrain 'Logs/pretrain/ResNet12_32_best.pth' --dim 2560 --reduction 8 --IAM_lr 5e-3 --lr 5e-4 --train_shot 5 --val_shot 5 --head LSSVM --network ResNet12 --dataset CIFAR_FS
python train_IAM.py --gpu 0,1,2,3 --save_path 'Logs/miniImageNet_Conv4_LSSVM_IAM_5s' --pretrain 'Logs/pretrain/Conv4_84_best.pth' --dim 4608 --reduction 32 --IAM_lr 1e-2 --lr 5e-3 --train_shot 15 --val_shot 5 --head LSSVM --network Conv4 --dataset miniImageNet
python train_IAM.py --gpu 0,1,2,3 --save_path 'Logs/miniImageNet_ResNet12_LSSVM_IAM_5s' --pretrain 'Logs/pretrain/ResNet12_84_best.pth' --dim 5760 --reduction 16 --IAM_lr 5e-3 --lr 5e-4 --train_shot 15 --val_shot 5 --head LSSVM --network ResNet12 --dataset miniImageNet
```
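Here `--dim` is the flattened feature dimension of the chosen backbone at the chosen image size (e.g. 2560 for ResNet12 on 32x32 inputs) and `--reduction` shrinks the module's hidden layer. Purely to illustrate the semantics of these two flags, a hypothetical bottleneck layer that re-weights support features from a query summary might look like the sketch below; this is NOT the repository's actual Inverse Attention Module:

```python
import torch
import torch.nn as nn

class BottleneckReweight(nn.Module):
    """Hypothetical squeeze-style re-weighting layer, shown only to explain
    the --dim / --reduction flags. NOT the repository's actual IAM."""

    def __init__(self, dim, reduction):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim // reduction),  # `reduction` controls the bottleneck
            nn.ReLU(inplace=True),
            nn.Linear(dim // reduction, dim),
            nn.Sigmoid(),
        )

    def forward(self, support, query):
        # Summarize the query set and use it to gate the support features,
        # so the support set is adjusted based on the query samples.
        gate = self.mlp(query.mean(dim=0, keepdim=True))  # (1, dim)
        return support * gate                             # (n_support, dim)
```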
- To test the base learners without our transductive modules:

```bash
python test.py --gpu 0,1,2,3 --load 'Logs/CIFAR_FS_ResNet12_LSSVM/best_model.pth' --episode 10000 --way 5 --shot 5 --query 15 --head LSSVM --network ResNet12 --dataset CIFAR_FS
python test.py --gpu 0,1,2,3 --load 'Logs/miniImageNet_ResNet12_LSSVM_5s/best_model.pth' --episode 10000 --way 5 --shot 5 --query 15 --head LSSVM --network ResNet12 --dataset miniImageNet
```

Set `--head` to `NN`, `RR` or `SVM` for other base learners.
- To test the base learners with our Inverse Attention Module:
```bash
python test_IAM.py --gpu 0,1,2,3 --load 'Logs/CIFAR_FS_ResNet12_LSSVM_IAM/best_model.pth' --dim 2560 --reduction 8 --episode 10000 --way 5 --shot 5 --query 15 --head LSSVM --network ResNet12 --dataset CIFAR_FS
python test_IAM.py --gpu 0,1,2,3 --load 'Logs/miniImageNet_ResNet12_LSSVM_IAM_5s/best_model.pth' --dim 5760 --reduction 16 --episode 10000 --way 5 --shot 5 --query 15 --head LSSVM --network ResNet12 --dataset miniImageNet
```
- To test the base learners with our Inverse Attention Module and Pseudo Support Module:
```bash
python test_IAM.py --gpu 0,1,2,3 --load 'Logs/CIFAR_FS_ResNet12_LSSVM_IAM/best_model.pth' --psm_iters 10 --dim 2560 --reduction 8 --episode 10000 --way 5 --shot 5 --query 15 --head LSSVM --network ResNet12 --dataset CIFAR_FS
python test_IAM.py --gpu 0,1,2,3 --load 'Logs/miniImageNet_ResNet12_LSSVM_IAM_5s/best_model.pth' --psm_iters 10 --dim 5760 --reduction 16 --episode 10000 --way 5 --shot 5 --query 15 --head LSSVM --network ResNet12 --dataset miniImageNet
```
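`--psm_iters` controls the number of Pseudo Support Module iterations. As an illustration of the general pseudo-support idea (the exact selection and stopping rules are the repository's; this loop, its name and its signature are ours), one could repeatedly fit the base learner, promote the most confident query of each predicted class into the support set with its predicted label, and re-fit:

```python
import torch

def pseudo_support_loop(support, support_labels, query, n_way, head, iters=10):
    """Hypothetical pseudo-support loop (NOT the paper's exact PSM).

    `head` is any base learner with the signature of `lssvm_head` above.
    For simplicity this sketch allows the same query to be re-added on
    later iterations; a real module would guard against that.
    """
    s, sl = support, support_labels
    for _ in range(iters):
        probs = head(s, sl, query, n_way).softmax(dim=1)  # (Q, n_way)
        conf, pred = probs.max(dim=1)
        for c in range(n_way):
            mask = pred == c
            if mask.any():
                # index (into query) of the most confident class-c prediction
                idx = torch.nonzero(mask).squeeze(1)[conf[mask].argmax()].item()
                s = torch.cat([s, query[idx : idx + 1]], dim=0)
                sl = torch.cat([sl, pred[idx : idx + 1]], dim=0)
    return head(s, sl, query, n_way)  # final scores from the enlarged support set
```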
This code is based on the implementation of MetaOptNet.