Jing--Li / HyMOS

Hyperspherical classification for Multi-source Open-Set domain adaptation

HyMOS (Hyperspherical classification for Multi-source Open-Set domain adaptation)

PyTorch official implementation of "Distance-based Hyperspherical Classification for Multi-source Open-Set Domain Adaptation". Video presentation available here.

Vision systems trained in closed-world scenarios will inevitably fail when presented with new environmental conditions, new data distributions and novel classes at deployment time. How to move towards open-world learning is a long-standing research question, but existing solutions mainly focus on specific aspects of the problem (single-domain open-set, multi-domain closed-set), or propose complex strategies that combine multiple losses and manually tuned hyperparameters. In this work we tackle multi-source open-set domain adaptation by introducing HyMOS: a straightforward supervised model that exploits the power of contrastive learning and the properties of its hyperspherical feature space to correctly predict known labels on the target, while rejecting samples belonging to any unknown class. HyMOS includes a tailored data-balancing strategy to enforce cross-source alignment and introduces style transfer among the instance transformations of contrastive learning for source-target adaptation, avoiding the risk of negative transfer. Finally, a self-training strategy refines the model without the need for handcrafted thresholds. We validate our method on three challenging datasets and provide an extensive quantitative and qualitative experimental analysis. The results show that HyMOS outperforms several open-set and universal domain adaptation approaches, setting a new state of the art.
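To make "distance-based hyperspherical classification" concrete, here is a minimal sketch of a generic nearest-centroid classifier with rejection on the unit sphere. It is not the HyMOS code: the function names are hypothetical, and the per-class rejection radius (a quantile of source-to-centroid cosine distances) is an assumption chosen only for illustration. It uses NumPy instead of PyTorch so it runs without a GPU.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Project feature vectors onto the unit hypersphere."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def fit_centroids(feats, labels, num_classes, q=0.95):
    """Class centroids on the sphere plus a per-class rejection radius.

    Hypothetical rule: the radius is the q-quantile of the cosine distances
    between a class's source samples and its centroid.
    """
    z = l2_normalize(feats)
    centroids = l2_normalize(
        np.stack([z[labels == c].mean(axis=0) for c in range(num_classes)])
    )
    radii = np.array([
        np.quantile(1.0 - z[labels == c] @ centroids[c], q)
        for c in range(num_classes)
    ])
    return centroids, radii


def predict_open_set(feats, centroids, radii, unknown=-1):
    """Nearest-centroid label, or `unknown` if outside that class's radius."""
    z = l2_normalize(feats)
    dist = 1.0 - z @ centroids.T                     # cosine distance, shape (N, C)
    nearest = dist.argmin(axis=1)
    nearest_dist = dist[np.arange(len(z)), nearest]
    return np.where(nearest_dist <= radii[nearest], nearest, unknown)


# Toy 2-D example: class 0 points along x, class 1 along y.
src = np.array([[1.0, 0.1], [1.0, -0.1], [0.1, 1.0], [-0.1, 1.0]])
src_y = np.array([0, 0, 1, 1])
centroids, radii = fit_centroids(src, src_y, num_classes=2)
# Expected: class 0, unknown (-1) for the diagonal point, class 1.
print(predict_open_set(np.array([[2.0, 0.0], [1.0, 1.0], [0.0, 3.0]]), centroids, radii))
```

Because every feature lies on the sphere, cosine distance alone decides both the label and the rejection, which is what makes a threshold tied to each class's own spread a natural (if here hypothetical) choice.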

Office-31 (HOS)

| Method | D,A → W | W,A → D | W,D → A | Avg. |
|---|---|---|---|---|
| HyMOS | 90.2 | 89.9 | 60.8 | 80.3 |

Office-Home (HOS)

| Method | Ar,Pr,Cl → Rw | Ar,Pr,Rw → Cl | Cl,Pr,Rw → Ar | Cl,Ar,Rw → Pr | Avg. |
|---|---|---|---|---|---|
| HyMOS | 71.0 | 64.6 | 62.2 | 71.1 | 67.2 |

DomainNet (HOS)

| Method | I,P → S | I,P → C | Avg. |
|---|---|---|---|
| HyMOS | 57.5 | 61.0 | 59.3 |
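The README does not define HOS. In open-set domain adaptation benchmarks it is commonly the harmonic mean of OS* (average accuracy on the known classes) and UNK (accuracy on the unknown class); the sketch below assumes that standard definition.

```python
def hos(os_star: float, unk: float) -> float:
    """Harmonic mean of known-class accuracy (OS*) and unknown accuracy (UNK)."""
    if os_star + unk == 0:
        return 0.0
    return 2.0 * os_star * unk / (os_star + unk)


print(hos(90.0, 10.0))  # 18.0: high known accuracy cannot hide poor rejection
print(hos(60.0, 60.0))  # 60.0: balanced performance is rewarded
```

Unlike an arithmetic mean, the harmonic mean collapses toward the weaker of the two scores, so a model that never predicts "unknown" (UNK = 0) scores 0 regardless of its known-class accuracy.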

Code

1. Requirements

Packages

The code requires these packages:

  • python 3.6+
  • torch 1.6+
  • torchvision 0.7+
  • CUDA 10.1+
  • scikit-learn 0.22+
  • tensorboardX
  • tqdm
  • torchlars == 0.1.2
  • apex == 0.1
  • diffdist == 0.1
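A quick way to check the minimum versions listed above is sketched below. It is a hypothetical helper, not part of the repository: it compares numeric version tuples (so "1.10" correctly counts as newer than "1.6") and only covers packages with a ">=" requirement; the pinned packages (torchlars, apex, diffdist) and the CUDA version are not checked.

```python
import importlib
import sys


def version_tuple(version: str, length: int = 3) -> tuple:
    """Parse the leading numeric part of a version such as '1.6.0+cu101'."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:          # stop at suffixes like 'post1' or 'rc1'
            break
        parts.append(int(digits))
    parts += [0] * (length - len(parts))
    return tuple(parts[:length])


def meets_minimum(installed: str, minimum: str) -> bool:
    return version_tuple(installed) >= version_tuple(minimum)


# Import names and minimum versions from the list above.
MINIMUMS = {"torch": "1.6", "torchvision": "0.7", "sklearn": "0.22"}

if __name__ == "__main__":
    print("python OK" if sys.version_info >= (3, 6) else "python needs >= 3.6")
    for module_name, minimum in MINIMUMS.items():
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            print(f"{module_name}: not installed")
            continue
        installed = getattr(module, "__version__", "0")
        ok = meets_minimum(installed, minimum)
        print(f"{module_name} {installed}: {'OK' if ok else 'needs >= ' + minimum}")
```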

Datasets

The OfficeHome, Office31 and DomainNet datasets should be placed in ~/data. They can be downloaded from their official websites.

Make sure that ~/data/OfficeHome/<domain> points to the correct domain directory for every domain. You may need to rename the Real World folder to RealWorld to remove the space. Similarly, check ~/data/DomainNet/<domain> and ~/data/Office31/<domain>.
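The layout check above can be scripted. This is a sketch under an assumption: the expected folder names are taken from the --test_domain values accepted by train.py, which may not match the folder names the data loaders actually expect. DomainNet is left out because its domain folder names are not given here.

```python
from pathlib import Path

# Assumed folder names (copied from the --test_domain values in this README).
EXPECTED_DOMAINS = {
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
    "Office31": ["Amazon", "Webcam", "Dslr"],
}


def fix_realworld_folder(officehome_root: Path) -> None:
    """Rename 'Real World' to 'RealWorld' if only the spaced folder exists."""
    src = officehome_root / "Real World"
    dst = officehome_root / "RealWorld"
    if src.is_dir() and not dst.exists():
        src.rename(dst)


def missing_domains(data_root: Path) -> dict:
    """Per dataset, list expected domain folders that are missing (symlinks count)."""
    return {
        dataset: [d for d in domains if not (data_root / dataset / d).is_dir()]
        for dataset, domains in EXPECTED_DOMAINS.items()
    }


if __name__ == "__main__":
    root = Path.home() / "data"
    fix_realworld_folder(root / "OfficeHome")
    print(missing_domains(root))
```

An empty list for every dataset means all the expected domain directories were found.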

Pretrained model

We use a ResNet50 pretrained with SupCLR, taken from the official GitHub repository. We converted the checkpoint to PyTorch format following this guide: here. The converted model is available here.

AdaIN model

We use a freely available PyTorch-based AdaIN implementation, which can be found here. Follow its instructions to train a model, putting the source data in train_content_dir and the target data in train_style_dir. This code also includes a model we trained for the Office31 Dslr,Webcam → Amazon shift, named Amazon_adain.pth.
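For reference, the core operation of an AdaIN style-transfer model is Adaptive Instance Normalization: each channel of the content features is normalized over its spatial positions and then given the per-channel mean and standard deviation of the style features. The NumPy sketch below shows only that operation on a single (C, H, W) feature map; the linked implementation applies it to encoder features in PyTorch and adds a trained decoder, which this sketch omits.

```python
import numpy as np


def adain(content: np.ndarray, style: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Adaptive Instance Normalization on (C, H, W) feature maps.

    Each content channel is shifted to zero mean and scaled to unit std over
    H x W, then rescaled by the style channel's std and shifted by its mean.
    """
    c_mean = content.mean(axis=(1, 2), keepdims=True)
    c_std = content.std(axis=(1, 2), keepdims=True) + eps
    s_mean = style.mean(axis=(1, 2), keepdims=True)
    s_std = style.std(axis=(1, 2), keepdims=True)
    return s_std * (content - c_mean) / c_std + s_mean


rng = np.random.default_rng(0)
source_feats = rng.normal(size=(3, 8, 8))                # "content": source image
target_feats = 5.0 * rng.normal(size=(3, 8, 8)) + 2.0    # "style": target image
stylized = adain(source_feats, target_feats)
```

This is why the training data split matters: source images provide content and target images provide style, so the stylized outputs keep source labels while carrying target-domain statistics.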

2. Training

In the examples below, training is performed on multiple GPUs. You can use more or fewer GPUs by changing the value of --nproc_per_node=2 and setting CUDA_VISIBLE_DEVICES accordingly. To obtain domain and class balance in each training mini-batch, the number of known classes in the dataset must be divisible by the number of GPUs used. For example, for OfficeHome we use 3 GPUs because there are 45 known classes.
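The divisibility constraint can be checked before launching a run. The sketch below lists the GPU counts that divide a given number of known classes; split_classes is a hypothetical illustration of an even partition and does not describe how train.py actually assigns classes to processes.

```python
def valid_gpu_counts(num_known_classes: int, max_gpus: int = 8) -> list:
    """GPU counts that split the known classes evenly across processes."""
    return [n for n in range(1, max_gpus + 1) if num_known_classes % n == 0]


def split_classes(num_known_classes: int, num_gpus: int) -> list:
    """Hypothetical even partition: one contiguous chunk of class ids per GPU."""
    if num_known_classes % num_gpus:
        raise ValueError(
            f"{num_known_classes} known classes cannot be split evenly "
            f"over {num_gpus} GPUs"
        )
    per_gpu = num_known_classes // num_gpus
    return [list(range(r * per_gpu, (r + 1) * per_gpu)) for r in range(num_gpus)]


print(valid_gpu_counts(45))  # OfficeHome, 45 known classes: [1, 3, 5]
```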

We use test_domain to refer to the target domain. Training output, including saved models and log files, is stored in the logs/ directory.

Office31

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port=10001 train.py --dataset Office31 \
    --test_domain <test_domain> --pretrain <path_to_resnet50_supclr_pretrained.pth> --adain_ckpt <path_to_adain_checkpoint.pth>

Test domain should be one of "Amazon", "Webcam", "Dslr".

For example, to train the experiment with Amazon as target, using the provided AdaIN model and a SupCLR-pretrained ResNet50, we use:

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port=10001 train.py --dataset Office31 \
    --test_domain Amazon --pretrain pretrained/resnet50_SupCLR.pth --adain_ckpt Amazon_adain.pth

Use a different port if 10001 is already taken.
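If you are unsure whether port 10001 is free, a small helper like the one below (hypothetical, standard library only) can test it and otherwise ask the operating system for an unused port. The check is not atomic: another process could grab the port between the check and the launch.

```python
import socket


def port_in_use(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if something is already listening on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex((host, port)) == 0


def find_free_port() -> int:
    """Let the OS pick a currently unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


if __name__ == "__main__":
    port = 10001 if not port_in_use(10001) else find_free_port()
    print(f"--master_port={port}")
```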

OfficeHome

CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --master_port=10001 train.py --dataset OfficeHome \
    --test_domain <test_domain> --pretrain <path_to_resnet50_supclr_pretrained.pth> --adain_ckpt <path_to_adain_checkpoint.pth>

Test domain should be one of "Art", "Clipart", "Product", "RealWorld".

DomainNet

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port=10001 train.py --dataset DomainNet \
    --test_domain <test_domain> --pretrain <path_to_resnet50_supclr_pretrained.pth> --adain_ckpt <path_to_adain_checkpoint.pth>

Test domain should be one of "ipc", "ips" (the I,P → C and I,P → S settings in the DomainNet table above, respectively).
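The three training commands differ only in the dataset, the GPU count and the accepted test domains, so they can be generated programmatically. The sketch below (hypothetical helper, using only the flags shown in this README) validates the target name and builds the command line. Keep in mind that the AdaIN checkpoint is specific to each source/target shift, since it is trained with source data as content and target data as style.

```python
import shlex

# Accepted --test_domain values and GPU counts used in the examples above.
TEST_DOMAINS = {
    "Office31": ["Amazon", "Webcam", "Dslr"],
    "OfficeHome": ["Art", "Clipart", "Product", "RealWorld"],
    "DomainNet": ["ipc", "ips"],
}
NUM_GPUS = {"Office31": 2, "OfficeHome": 3, "DomainNet": 2}


def train_command(dataset, test_domain, pretrain, adain_ckpt, port=10001):
    """Build the torch.distributed.launch command line for one target domain."""
    if test_domain not in TEST_DOMAINS[dataset]:
        raise ValueError(f"{test_domain!r} is not a valid target for {dataset}")
    n = NUM_GPUS[dataset]
    gpus = ",".join(str(i) for i in range(n))
    args = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={n}", f"--master_port={port}", "train.py",
        "--dataset", dataset, "--test_domain", test_domain,
        "--pretrain", pretrain, "--adain_ckpt", adain_ckpt,
    ]
    return f"CUDA_VISIBLE_DEVICES={gpus} " + " ".join(shlex.quote(a) for a in args)


if __name__ == "__main__":
    print(train_command("Office31", "Amazon",
                        "pretrained/resnet50_SupCLR.pth", "Amazon_adain.pth"))
```

For the Office31/Amazon case this reproduces the example command given earlier, written on a single line.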

3. Evaluation

Periodic evaluation is performed during training. The final model can be tested using:

CUDA_VISIBLE_DEVICES=0 python eval.py --dataset <dataset> --test_domain <target_domain> --load_path <path_to_last.model>

For example, to test the model trained for the Office31 shift with Amazon as target, we use:

CUDA_VISIBLE_DEVICES=0 python eval.py --dataset Office31 --test_domain Amazon --load_path logs/Dataset-Office31_Target-Amazon_Mode-HyMOS_st_batchK-20_batchP-2_iterative_ProbST-0.5/last.model
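Since the run directory name encodes training options, a helper can locate the right last.model instead of typing the full path. This sketch assumes only that run directories start with Dataset-<dataset>_Target-<target>_, as in the example above; the rest of the name depends on the training options. It picks the most recently modified match.

```python
from pathlib import Path


def find_last_models(dataset, target, logs_dir="logs"):
    """last.model files of runs named Dataset-<dataset>_Target-<target>_*, newest first."""
    pattern = f"Dataset-{dataset}_Target-{target}_*/last.model"
    return sorted(Path(logs_dir).glob(pattern),
                  key=lambda p: p.stat().st_mtime, reverse=True)


def eval_command(dataset, target, logs_dir="logs"):
    """Argument list for eval.py pointing at the newest matching checkpoint."""
    runs = find_last_models(dataset, target, logs_dir)
    if not runs:
        raise FileNotFoundError(f"no last.model for {dataset}/{target} in {logs_dir}")
    return ["python", "eval.py", "--dataset", dataset,
            "--test_domain", target, "--load_path", str(runs[0])]


if __name__ == "__main__":
    print(" ".join(eval_command("Office31", "Amazon")))
```

The returned list can be passed to subprocess.run with CUDA_VISIBLE_DEVICES=0 set in the environment.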

Citation

To cite, please use the following reference:

@inproceedings{bucci2022distance,
  title={Distance-based Hyperspherical Classification for Multi-source Open-Set Domain Adaptation},
  author={Bucci, Silvia and Cappio Borlino, Francesco and Caputo, Barbara and Tommasi, Tatiana},
  booktitle={Winter Conference on Applications of Computer Vision},
  year={2022}
} 
