sua-choi / CMS

[CVPR'24] Official PyTorch implementation of Contrastive Mean-Shift Learning for Generalized Category Discovery

Contrastive Mean-Shift Learning for Generalized Category Discovery

Pohang University of Science and Technology (POSTECH)

Environment installation

This project is built upon the following environment:

The package requirements can be installed via requirements.txt:

pip install -r requirements.txt

Datasets

We use fine-grained benchmarks in this paper, including:

We also use generic object recognition datasets, including:

Please follow this repo to set up the data.

Download the datasets, SSB splits, and pretrained backbone, arrange them according to the file structure below, and set DATASET_ROOT={YOUR DIRECTORY} in config.py.

    DATASET_ROOT/
    ├── cifar100/
    │   ├── cifar-100-python/
    │   │   ├── meta
    │   │   ├── ...
    ├── CUB_200_2011/
    │   ├── attributes/
    │   ├── ...
    ├── ...
    CMS/
    ├── data/
    │   ├── ssb_splits/
    ├── models/
    │   ├── dino_vitbase16_pretrain.pth
    ├── ...
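
Before training, it can help to confirm that the layout above is in place. The following is a minimal sketch and is not part of the repository: the helper name missing_paths and the two path lists are hypothetical and cover only the entries the tree spells out, so extend them for the other datasets you download.

```python
from pathlib import Path

# Hypothetical checklist copied from the directory tree above; it only covers
# the entries that the tree names explicitly.
EXPECTED_UNDER_DATASET_ROOT = [
    "cifar100/cifar-100-python/meta",
    "CUB_200_2011/attributes",
]
EXPECTED_UNDER_REPO_ROOT = [
    "data/ssb_splits",
    "models/dino_vitbase16_pretrain.pth",
]


def missing_paths(dataset_root, repo_root):
    """Return every expected file or directory that does not exist yet."""
    candidates = [Path(dataset_root) / rel for rel in EXPECTED_UNDER_DATASET_ROOT]
    candidates += [Path(repo_root) / rel for rel in EXPECTED_UNDER_REPO_ROOT]
    # Path.exists() is true for both files and directories.
    return [str(p) for p in candidates if not p.exists()]
```

For example, run missing_paths(DATASET_ROOT, ".") from the CMS/ directory, with DATASET_ROOT set to the same value you put in config.py; an empty list means every listed path was found.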

Training

bash bash_scripts/contrastive_meanshift_training.sh

Example bash commands for training are as follows:

# GCD
python -m methods.contrastive_meanshift_training  \
            --dataset_name 'cub' \
            --lr 0.05 \
            --temperature 0.25 \
            --wandb 

# Inductive GCD
python -m methods.contrastive_meanshift_training  \
            --dataset_name 'cub' \
            --lr 0.05 \
            --temperature 0.25 \
            --inductive \
            --wandb 
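
If you prefer to launch both settings from Python (for instance, inside a sweep), the commands above can be assembled as argument lists. This is a hypothetical convenience wrapper, not part of the repository: training_command and run_both_settings are illustrative names, only the flags shown above are reproduced, and the meaning of --wandb (presumably Weights & Biases logging) is an assumption.

```python
import subprocess
import sys


def training_command(dataset_name="cub", lr=0.05, temperature=0.25,
                     inductive=False, use_wandb=True):
    """Build the argv for one training run, mirroring the example flags above."""
    cmd = [sys.executable, "-m", "methods.contrastive_meanshift_training",
           "--dataset_name", dataset_name,
           "--lr", str(lr),
           "--temperature", str(temperature)]
    if inductive:
        cmd.append("--inductive")
    if use_wandb:
        # Assumption: --wandb is a plain on/off switch for experiment logging.
        cmd.append("--wandb")
    return cmd


def run_both_settings(dataset_name="cub"):
    """Run the GCD setting, then the inductive GCD setting, one after the other."""
    for inductive in (False, True):
        subprocess.run(training_command(dataset_name, inductive=inductive), check=True)
```

Call run_both_settings() from the CMS/ repository root so that python -m can locate the methods package. Passing the arguments as a list avoids shell quoting; the quotes around 'cub' in the bash examples are stripped by the shell anyway.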

Evaluation

bash bash_scripts/meanshift_clustering.sh

An example evaluation command is given below; change --model_name to the name of the trained model you want to evaluate.

python -m methods.meanshift_clustering \
        --dataset_name 'cub' \
        --model_name 'cub_best'

Results and checkpoints

Experimental results (accuracy, %) on the GCD task.

Dataset          All    Old    Novel   Checkpoints
CIFAR100         82.3   85.7   75.5    link
ImageNet100      84.7   95.6   79.2    link
CUB              68.2   76.5   64.0    link
Stanford Cars    56.9   76.1   47.6    link
FGVC-Aircraft    56.0   63.4   52.3    link
Herbarium19      36.4   54.9   26.4    link

Experimental results (accuracy, %) on the inductive GCD task.

Dataset          All    Old    Novel   Checkpoints
CIFAR100         80.7   84.4   65.9    link
ImageNet100      85.7   95.7   75.8    link
CUB              69.7   76.5   63.0    link
Stanford Cars    57.8   75.2   41.0    link
FGVC-Aircraft    53.3   62.7   43.8    link
Herbarium19      46.2   53.0   38.9    link

Citation

If you find our code or paper useful, please consider citing our paper:

  @inproceedings{choi2024contrastive,
    title={Contrastive Mean-Shift Learning for Generalized Category Discovery},
    author={Choi, Sua and Kang, Dahyun and Cho, Minsu},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year={2024}
  }

Related Repos

The codebase is largely built on Generalized Category Discovery and PromptCAL.

Acknowledgements

This work was supported by the NRF grant (NRF-2021R1A2C3012728 (50%)) and the IITP grants (2022-0-00113: Developing a Sustainable Collaborative Multi-modal Lifelong Learning Framework (45%), 2019-0-01906: AI Graduate School Program at POSTECH (5%)) funded by Ministry of Science and ICT, Korea.
