Update:
- The manuscript has been accepted at CVPR 2024.
- Core code has been updated.
This is the official PyTorch implementation of the paper VRP-SAM: SAM with Visual Reference Prompt.
Authors: Yanpeng Sun, Jiahui Chen, Shan Zhang, Xinyu Zhang, Qiang Chen, Gang Zhang, Errui Ding, Jingdong Wang, Zechao Li
- Python 3.10
- PyTorch 1.12
- CUDA 11.6
Conda environment settings:
conda create -n vrpsam python=3.10
conda activate vrpsam
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
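Optionally, you can confirm that the environment resolved a CUDA-enabled build (a quick sanity check, not required):

python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"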
Segment-Anything-Model setting:
cd ./segment-anything
pip install -v -e .
cd ..
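Optionally, verify that the editable install is importable:

python -c "from segment_anything import sam_model_registry; print('segment-anything OK')"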
Download the following datasets:
Download PASCAL VOC2012 devkit (train/val data):
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
Download PASCAL VOC2012 SDS extended mask annotations from our [Google Drive].
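After downloading, the devkit can be unpacked with standard tar (the archive extracts to VOCdevkit/VOC2012; move it to match the directory layout below):

tar -xf VOCtrainval_11-May-2012.tar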
Download COCO2014 train/val images and annotations:
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
Download COCO2014 train/val mask annotations from our Google Drive: [train2014.zip], [val2014.zip], and place both train2014/ and val2014/ under the annotations/ directory.
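Likewise, the COCO archives can be unpacked with standard unzip (move the extracted folders to match the directory layout below):

unzip train2014.zip                 # training images -> train2014/
unzip val2014.zip                   # validation images -> val2014/
unzip annotations_trainval2014.zip  # official json files -> annotations/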
Create a directory '../Datasets_HSN' for the above few-shot segmentation datasets and place each dataset so that the directory structure looks as follows:
../ # parent directory
├── ./ # current (project) directory
│ ├── common/ # (dir.) helper functions
│ ├── data/ # (dir.) dataloaders and splits for each FSSS dataset
│ ├── model/ # (dir.) implementation of VRP-SAM
│ ├── segment-anything/ # code for SAM
│ ├── README.md # instructions for reproduction
│ ├── train.py # code for training VRP-SAM
│ └── SAM2Pred.py # code for prediction module
│
└── Datasets_HSN/
├── VOC2012/ # PASCAL VOC2012 devkit
│ ├── Annotations/
│ ├── ImageSets/
│ ├── ...
│ └── SegmentationClassAug/
└── COCO2014/
├── annotations/
│ ├── train2014/ # (dir.) training masks (from Google Drive)
│ ├── val2014/ # (dir.) validation masks (from Google Drive)
│ └── ..some json files..
├── train2014/
└── val2014/
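Before training, a quick sanity check of the layout can save a failed run (optional; the paths are assumed from the tree above):

for d in ../Datasets_HSN/VOC2012/SegmentationClassAug \
         ../Datasets_HSN/COCO2014/annotations/train2014 \
         ../Datasets_HSN/COCO2014/annotations/val2014 \
         ../Datasets_HSN/COCO2014/train2014 \
         ../Datasets_HSN/COCO2014/val2014; do
  [ -d "$d" ] && echo "ok: $d" || echo "missing: $d"
done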
We provide an example training script, "train.sh". Detailed training arguments are as follows:
python3 -m torch.distributed.launch --nproc_per_node=$GPUs$ train.py \
    --datapath $PATH_TO_YOUR_DATA$ \
    --logpath $PATH_TO_YOUR_LOG$ \
    --benchmark {coco, pascal} \
    --backbone {vgg16, resnet50, resnet101} \
    --fold {0, 1, 2, 3} \
    --condition {point, scribble, box, mask} \
    --num_query 50 \
    --epochs 50 \
    --lr 1e-4 \
    --bsz 2
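For instance, a single-node 4-GPU run on PASCAL fold 0 with a ResNet-50 backbone and mask-type reference prompts might look like this (illustrative values; --datapath and --logpath are placeholders to adjust to your setup):

python3 -m torch.distributed.launch --nproc_per_node=4 train.py \
    --datapath ../Datasets_HSN \
    --logpath ./logs/pascal_fold0_resnet50 \
    --benchmark pascal \
    --backbone resnet50 \
    --fold 0 \
    --condition mask \
    --num_query 50 \
    --epochs 50 \
    --lr 1e-4 \
    --bsz 2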
If you use this code for your research, please consider citing:
@inproceedings{sun2024vrp,
  title={VRP-SAM: SAM with Visual Reference Prompt},
  author={Sun, Yanpeng and Chen, Jiahui and Zhang, Shan and Zhang, Xinyu and Chen, Qiang and Zhang, Gang and Ding, Errui and Wang, Jingdong and Li, Zechao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}