
Few-Shot Learning by Integrating Spatial and Frequency Representation

A few-shot classification algorithm: Few-Shot Learning by Integrating Spatial and Frequency Representation.

Framework

Our code is built upon the code bases of:

A Closer Look at Few-shot Classification (code)

Charting the Right Manifold: Manifold Mixup for Few-shot Learning (code)

Learning in the Frequency Domain (code)

Leveraging the Feature Distribution in Transfer-based Few-Shot Learning (code)

Running the code

Dataset: mini-ImageNet, CIFAR-FS, CUB

Prerequisites:

  • python 3.6.9

  • libjpeg-turbo 2.0.3

============================================================================

Install:

To run S2M2_R and the other algorithms, install with: pip install -r requirement_dct.txt

To run PT+MAP, install with: pip install -r requirement_map.txt (titanxp)

(I haven't tried combining the two environment setups.)

============================================================================

Prepare the dataset:

CUB

  • Change directory to filelists/CUB/
  • Run 'source ./download_CUB.sh'
  • Or download the dataset manually: CUB (splits included)
  • Then run 'python make_json.py' to generate train.json, val.json and novel.json

CIFAR-FS

  • Change directory to filelists/cifar/
  • Run 'source ./download_cifar.sh'
  • Or download the dataset manually: CIFAR-FS (splits included)
  • Then run 'python make_json.py' to generate train.json, val.json and novel.json

miniImagenet

  • Change directory to filelists/miniImagenet/
  • Run 'source ./download_miniImagenet.sh'
  • Or download the dataset manually: mini-Imagenet (the split files train.csv, val.csv and test.csv must be obtained via download_miniImagenet.sh or from this repo, under filelists/miniImagenet/)
  • Then run 'python make_json.py' to generate train.json, val.json and novel.json
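The three JSON files produced above are class/image index files in the style of the Closer-Look few-shot code base this repo builds on. A minimal sketch of the assumed format (field names taken from that code base, not verified against make_json.py; the image paths are made up):

```python
import json

# Hypothetical miniature train.json: three parallel lists giving the class
# names, the image file paths, and each image's integer class label.
split = {
    "label_names": ["n01532829", "n01558993"],
    "image_names": [
        "filelists/miniImagenet/images/n01532829/n01532829_1.jpg",
        "filelists/miniImagenet/images/n01558993/n01558993_1.jpg",
    ],
    "image_labels": [0, 1],
}

with open("train.json", "w") as f:
    json.dump(split, f)
```

val.json and novel.json follow the same layout, differing only in which classes they index.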

============================================================================

Training

DATASETNAME: miniImagenet/cifar/CUB

METHODNAME: S2M2_R/rotation

(To run S2M2_R, train with the rotation method first, since the S2M2_R algorithm builds on a rotation-pretrained model.)

  1. Train the frequency version (8x8 DCT filter with the top-left 24 channels selected):
python train_dct.py --dataset [DATASETNAME] --method [METHODNAME] --model WideResNet28_10 --train_aug --dct_status
  2. Train the spatial version:
python train_dct.py --dataset [DATASETNAME] --method [METHODNAME] --model WideResNet28_10 --train_aug

All results are saved in the folder "checkpoints"; the best model is saved as "best.tar".
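For intuition about the frequency input (an 8x8 DCT with the top-left 24 channels kept): the code base itself decodes DCT coefficients with jpeg2dct/libjpeg-turbo, but the channel selection can be sketched in plain numpy, under the assumption that channels are ranked low-to-high frequency from the top-left corner of each block. This is illustrative only, not the repo's implementation:

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II basis matrix (the same per-block transform JPEG uses)."""
    k = np.arange(n)
    M = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * k[None, :] + 1) * k[:, None] / (2 * n))
    M[0] = np.sqrt(1.0 / n)
    return M

def block_dct_channels(img, block=8, keep=24):
    """Cut a grayscale image into block x block tiles, DCT each tile, and
    keep the `keep` lowest-frequency (top-left) channels per tile.
    Hypothetical sketch of the paper's frequency input, not the repo code."""
    h = (img.shape[0] // block) * block
    w = (img.shape[1] // block) * block
    tiles = img[:h, :w].reshape(h // block, block, w // block, block).swapaxes(1, 2)
    M = dct_matrix(block)
    coeffs = M @ tiles @ M.T  # 2-D DCT of every tile at once via broadcasting
    # zigzag-style order: sort frequency pairs (u, v) by u + v, lowest first
    order = sorted(((u, v) for u in range(block) for v in range(block)),
                   key=lambda p: (p[0] + p[1], p))
    # one feature map per kept frequency channel: shape (keep, h//block, w//block)
    return np.stack([coeffs[..., u, v] for u, v in order[:keep]], axis=0)
```

For a 32x32 image this yields a (24, 4, 4) tensor: 24 frequency channels at 1/8 spatial resolution, which is the kind of input the frequency-domain backbone consumes.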

============================================================================

Save features

S2M2_R algorithm

  1. Save the frequency version:
python save_features.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --train_aug --dct_status
  2. Save the spatial version:
python save_features.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --train_aug
  3. Save the spatial and frequency version:
python save_features_both.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --train_aug

All features will be saved in the folder "./features".
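The _both variant integrates the two representations; conceptually, the integration amounts to concatenating each image's spatial and frequency feature vectors before classification (a sketch with made-up dimensions and random stand-in features, not the repo's exact code):

```python
import numpy as np

# Hypothetical stand-ins: 5 images, 640-dim WideResNet28-10 embeddings
feat_spatial = np.random.randn(5, 640)  # from the spatial-domain backbone
feat_freq = np.random.randn(5, 640)     # from the frequency-domain backbone

# Integrate spatial and frequency representations by concatenation
feat_both = np.concatenate([feat_spatial, feat_freq], axis=1)
```

The combined features are then used by the downstream few-shot classifier exactly like single-branch features, just with twice the dimensionality.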

PT-MAP algorithm

  1. Save the frequency version:
python save_plk.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --train_aug --dct_status
  2. Save the spatial version:
python save_plk.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --train_aug
  3. Save the spatial and frequency version:
python save_plk_both.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --train_aug

All features will be saved in the folder "./checkpoints/[DATASETNAME]/WideResNet28_10_S2M2_R_5way_1shot_aug/last".

=================================================================================

Testing

You could download checkpoints here.

S2M2_R algorithm

  1. Test the frequency version:
python test_dct.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --n_shot [1/5] --train_aug --dct_status
  2. Test the spatial version:
python test_dct.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --n_shot [1/5] --train_aug
  3. Test the spatial and frequency version:
python test_dct_both.py --dataset [DATASETNAME] --method S2M2_R --model WideResNet28_10 --n_shot [1/5] --train_aug

PT-MAP algorithm

python test_standard.py

Set the .plk file path in FSLtask.py to select the frequency (out.plk), spatial (out.plk) or frequency+spatial (out_both.plk) features.

Set n_shot in test_standard.py to get the 5-shot or 1-shot result.
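Concretely, the two edits look something like this (the variable name and exact path are hypothetical; check the actual files — the feature folder follows the pattern shown in the save-features section above):

```python
# In FSLtask.py: point at the desired feature pickle (path hypothetical)
dataset_path = "./checkpoints/miniImagenet/WideResNet28_10_S2M2_R_5way_1shot_aug/last/out_both.plk"

# In test_standard.py: choose the evaluation setting
n_shot = 5  # or 1 for the 1-shot result
```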

Comparison with the state of the art on the mini-ImageNet, CUB and CIFAR-FS datasets ((s) = spatial, (f) = frequency, (s+f) = spatial + frequency):

| Method | mini-ImageNet 1-shot | mini-ImageNet 5-shot | CUB 1-shot | CUB 5-shot | CIFAR-FS 1-shot | CIFAR-FS 5-shot |
|---|---|---|---|---|---|---|
| S2M2_R [1] | 64.93 ± 0.18 | 83.18 ± 0.11 | 80.68 ± 0.81 | 90.85 ± 0.44 | 74.81 ± 0.19 | 87.47 ± 0.13 |
| S2M2_R (s) | 63.51 ± 0.18 | 81.54 ± 0.12 | 80.55 ± 0.78 | 91.52 ± 0.39 | 73.54 ± 0.20 | 86.90 ± 0.13 |
| S2M2_R (f) | 63.03 ± 0.18 | 80.80 ± 0.11 | 81.00 ± 0.76 | 91.08 ± 0.40 | 72.21 ± 0.20 | 85.72 ± 0.13 |
| S2M2_R (s+f) | 66.96 ± 0.18 | 84.31 ± 0.10 | 84.87 ± 0.72 | 93.52 ± 0.35 | 76.60 ± 0.19 | 88.55 ± 0.13 |
| PT-MAP [2] | 82.92 ± 0.26 | 88.82 ± 0.13 | 91.55 ± 0.19 | 93.99 ± 0.10 | 87.69 ± 0.23 | 90.68 ± 0.15 |
| PT-MAP (s) | 81.01 ± 0.25 | 88.07 ± 0.13 | 92.25 ± 0.18 | 94.62 ± 0.09 | 87.82 ± 0.22 | 91.00 ± 0.16 |
| PT-MAP (f) | 82.04 ± 0.23 | 88.68 ± 0.12 | 93.18 ± 0.16 | 95.02 ± 0.08 | 86.57 ± 0.23 | 90.28 ± 0.15 |
| PT-MAP (s+f) | 85.01 ± 0.22 | 90.72 ± 0.11 | 95.45 ± 0.13 | 96.70 ± 0.07 | 89.39 ± 0.21 | 92.08 ± 0.15 |

=================================================================================

Training frequency version with different filter sizes:

DATASETNAME: miniImagenet/cifar/CUB

MODEL: ResNet10dct/ResNet34dct

METHOD: baseline++

Before training, set:

  • the number of channels to select: ./data/datamgr.py (channels = 24 by default)

  • which channels to select: main/__init__.py (revise the channel indices according to Figure 3 in the paper)

python train_dct.py --dataset [DATASETNAME] --method [METHOD] --model [MODEL] --train_aug --dct_status --filter_size [2/4/6/8/10...]

References

A Closer Look at Few-shot Classification

Meta-Learning with Latent Embedding Optimization

Meta Learning with Differentiable Convex Optimization

Manifold Mixup: Better Representations by Interpolating Hidden States

Leveraging the Feature Distribution in Transfer-based Few-Shot Learning (code)

jpeg2dct
