A PyTorch implementation of EGFF, based on the NPL 2022 paper Energy-Guided Feature Fusion for Zero-Shot Sketch-Based Image Retrieval.
conda install pytorch=1.10.1 torchvision cudatoolkit -c pytorch
pip install pytorch-metric-learning
pip install timm
pip install opencv-python
This repo uses the Sketchy Extended and TU-Berlin Extended datasets. You can download them from their official websites or from Google Drive. The data directory should be structured as follows:
├── sketchy
    ├── train
        ├── sketch
            ├── airplane
                ├── n02691156_58-1.jpg
                └── ...
            ...
        ├── photo
            same structure as sketch
    ├── val
        same structure as train
        ...
├── tuberlin
    same structure as sketchy
    ...
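The layout above can be traversed with a few lines of standard-library Python. The following is a minimal sketch, not the repo's actual data loader: the function name `list_samples` and the set of accepted file suffixes are assumptions made for illustration.

```python
from pathlib import Path

# Assumption: images use one of these suffixes; the README only shows .jpg files.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def list_samples(data_root, data_name, split, domain):
    """Return sorted (image_path, class_name) pairs for one split and domain.

    Expects the layout <data_root>/<data_name>/<split>/<domain>/<class>/<image>,
    e.g. /home/data/sketchy/train/sketch/airplane/n02691156_58-1.jpg.
    """
    domain_dir = Path(data_root) / data_name / split / domain
    samples = []
    # Each sub-directory of the domain directory is treated as one class.
    for class_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir()):
        for image_path in sorted(class_dir.iterdir()):
            if image_path.suffix.lower() in IMAGE_SUFFIXES:
                samples.append((image_path, class_dir.name))
    return samples
```

For example, `list_samples("/home/data", "sketchy", "train", "sketch")` would list every training sketch together with its class name, which is a quick way to check that a downloaded dataset matches the expected structure.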
python train.py --data_name tuberlin
optional arguments:
--data_root Datasets root path [default value is '/home/data']
--data_name Dataset name [default value is 'sketchy'](choices=['sketchy', 'tuberlin'])
--backbone_type Backbone type [default value is 'resnet50'](choices=['resnet50', 'vgg16'])
--proj_dim Projected embedding dim [default value is 512]
--batch_size Number of images in each mini-batch [default value is 64]
--epochs Number of epochs over the model to train [default value is 10]
--warmup Number of warmups over the model to train [default value is 1]
--save_root Result saved root path [default value is 'result']
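The options listed above map naturally onto `argparse`. The sketch below mirrors the documented flags and defaults; it is an illustration, not a copy of `train.py`. The helper `result_path` is hypothetical, and the file-name pattern it produces is only inferred from the default paths shown elsewhere in this README (such as 'result/sketchy_resnet50_512_model.pth').

```python
import argparse


def build_train_parser():
    """Build a parser exposing the training options documented in this README."""
    parser = argparse.ArgumentParser(description="Train model (illustrative sketch)")
    parser.add_argument("--data_root", default="/home/data", type=str,
                        help="Datasets root path")
    parser.add_argument("--data_name", default="sketchy", type=str,
                        choices=["sketchy", "tuberlin"], help="Dataset name")
    parser.add_argument("--backbone_type", default="resnet50", type=str,
                        choices=["resnet50", "vgg16"], help="Backbone type")
    parser.add_argument("--proj_dim", default=512, type=int,
                        help="Projected embedding dim")
    parser.add_argument("--batch_size", default=64, type=int,
                        help="Number of images in each mini-batch")
    parser.add_argument("--epochs", default=10, type=int,
                        help="Number of epochs over the model to train")
    parser.add_argument("--warmup", default=1, type=int,
                        help="Number of warmups over the model to train")
    parser.add_argument("--save_root", default="result", type=str,
                        help="Result saved root path")
    return parser


def result_path(args, kind):
    """Hypothetical helper: assumed naming pattern <data>_<backbone>_<dim>_<kind>.pth."""
    return f"{args.save_root}/{args.data_name}_{args.backbone_type}_{args.proj_dim}_{kind}.pth"


# Equivalent to: python train.py --data_name tuberlin
args = build_train_parser().parse_args(["--data_name", "tuberlin"])
print(result_path(args, "model"))
```

If the naming assumption holds, changing `--proj_dim` changes the saved file name, which would explain why the other scripts take the model or vector file path as an explicit argument.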
python test.py --num 4
optional arguments:
--data_root Datasets root path [default value is '/home/data']
--query_name Query image name [default value is '/home/data/sketchy/val/sketch/cow/n01887787_591-14.jpg']
--data_base Queried database [default value is 'result/sketchy_resnet50_512_vectors.pth']
--num Retrieval number [default value is 8]
--save_root Result saved root path [default value is 'result']
python vis.py --model_name result/sketchy_resnet50_2048_model.pth
optional arguments:
--vis_name Visual image name [default value is '/home/data/sketchy/val/photo/helicopter/ext_5.jpg']
--model_name Model name [default value is 'result/sketchy_resnet50_512_model.pth']
--save_root Result saved root path [default value is 'result']
The models are trained on one NVIDIA GeForce RTX 3090 (24G) GPU. AdamW is used to optimize the model, with a learning rate of 1e-5 and a weight decay of 5e-4. All hyper-parameters are set to their default values.
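As a minimal sketch of the optimizer settings described above, the snippet below configures `torch.optim.AdamW` with the stated learning rate and weight decay. The `nn.Linear` layer is only a placeholder standing in for the real network, and no warm-up schedule is shown because the README does not describe how it is implemented.

```python
import torch
from torch import nn

# Placeholder module (assumption): stands in for the actual backbone and projection head.
model = nn.Linear(2048, 512)

# Optimizer settings as stated in this README: AdamW, lr 1e-5, weight decay 5e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=5e-4)
```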
Backbone | Dim | Sketchy Extended mAP@200 | Sketchy Extended mAP@all | Sketchy Extended P@100 | Sketchy Extended P@200 | TU-Berlin Extended mAP@200 | TU-Berlin Extended mAP@all | TU-Berlin Extended P@100 | TU-Berlin Extended P@200 | Download |
---|---|---|---|---|---|---|---|---|---|---|
VGG16 | 64 | 36.1 | 39.8 | 52.8 | 48.1 | 44.2 | 39.3 | 57.1 | 53.9 | u7qg |
VGG16 | 512 | 42.7 | 45.1 | 58.9 | 53.6 | 48.6 | 42.8 | 60.7 | 57.2 | 6up4 |
VGG16 | 4096 | 44.6 | 47.3 | 60.1 | 55.2 | 50.0 | 44.1 | 61.8 | 58.5 | hznm |
ResNet50 | 64 | 43.3 | 46.6 | 58.6 | 54.3 | 50.7 | 47.7 | 61.1 | 58.5 | uhkp |
ResNet50 | 512 | 52.6 | 55.4 | 66.0 | 61.7 | 58.0 | 53.5 | 67.5 | 65.0 | u8ct |
ResNet50 | 2048 | 53.7 | 56.8 | 66.4 | 62.5 | 60.4 | 56.1 | 69.4 | 67.1 | ipr3 |
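For readers unfamiliar with the metrics in the table, the sketch below shows one common way to compute P@k and mAP@k from a ranked list of 0/1 relevance labels (1 means the retrieved photo shares the query sketch's class). It is not taken from this repo's evaluation code. In particular, normalizing the average precision by the number of relevant items found in the top k is an assumption; other conventions exist and give different numbers. The function names are illustrative.

```python
def precision_at_k(relevance, k):
    """Fraction of the top-k retrieved items that are relevant."""
    return sum(relevance[:k]) / k


def average_precision_at_k(relevance, k=None):
    """Average precision over the top-k items (k=None uses the whole ranking).

    Assumption: the sum of precisions at each hit is divided by the number
    of relevant items found within the top k.
    """
    top = relevance if k is None else relevance[:k]
    hits = 0
    precision_sum = 0.0
    for rank, is_relevant in enumerate(top, start=1):
        if is_relevant:
            hits += 1
            precision_sum += hits / rank  # precision at this hit's rank
    return precision_sum / hits if hits else 0.0


def mean_average_precision(rankings, k=None):
    """Mean of the per-query average precision over all queries."""
    return sum(average_precision_at_k(r, k) for r in rankings) / len(rankings)


# One query whose ranked gallery has relevant items at ranks 1, 3 and 4.
ranking = [1, 0, 1, 1, 0]
print(precision_at_k(ranking, 2))        # 0.5
print(average_precision_at_k(ranking))   # (1/1 + 2/3 + 3/4) / 3
```

The table reports these quantities scaled to percentages, with mAP@all corresponding to `k=None` over the full gallery.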