virajprabhu / CLUE

PyTorch code for Active Domain Adaptation via Clustering Uncertainty-weighted Embeddings (ICCV 2021)

Viraj Prabhu, Arjun Chandrasekaran, Kate Saenko, Judy Hoffman

Generalizing deep neural networks to new target domains is critical to their real-world utility. In practice, it may be feasible to get some target data labeled, but to be cost-effective it is desirable to select a maximally-informative subset via active learning (AL). We study the problem of AL under a domain shift, called Active Domain Adaptation (Active DA). We demonstrate how existing AL approaches based solely on model uncertainty or diversity sampling are less effective for Active DA. We propose Clustering Uncertainty-weighted Embeddings (CLUE), a novel label acquisition strategy for Active DA that performs uncertainty-weighted clustering to identify target instances for labeling that are both uncertain under the model and diverse in feature space. CLUE consistently outperforms competing label acquisition strategies for Active DA and AL across learning settings on 6 diverse domain shifts for image classification.
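To make the selection idea concrete, here is a minimal pure-Python sketch of uncertainty-weighted clustering as described above: each target embedding is weighted by the entropy of its predicted class distribution, a weighted k-means is run, and the instance nearest each centroid is selected. This is not the repository's implementation. The function name clue_select, the choice of one cluster per label in the budget, the evenly spaced centroid initialisation, and the toy data are all assumptions made for illustration.

```python
import math


def entropy(probs):
    """Shannon entropy (natural log) of one predictive distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def clue_select(embeddings, probs, budget, n_iter=20):
    """Pick `budget` indices via entropy-weighted k-means (illustrative sketch)."""
    n = len(embeddings)
    if not 0 < budget <= n:
        raise ValueError("budget must be between 1 and the number of instances")
    dim = len(embeddings[0])
    weights = [entropy(p) for p in probs]  # uncertainty weight per instance
    # Naive deterministic initialisation at evenly spaced indices (assumption);
    # k-means++ style seeding would be more robust in practice.
    centroids = [list(embeddings[k * n // budget]) for k in range(budget)]
    for _ in range(n_iter):
        # Assignment step: each instance goes to its nearest centroid.
        assign = [
            min(range(budget), key=lambda k: sq_dist(x, centroids[k]))
            for x in embeddings
        ]
        # Update step: each centroid becomes the entropy-weighted mean of its members.
        for k in range(budget):
            members = [i for i, a in enumerate(assign) if a == k]
            total = sum(weights[i] for i in members)
            if total > 0.0:
                centroids[k] = [
                    sum(weights[i] * embeddings[i][d] for i in members) / total
                    for d in range(dim)
                ]
    # Selection step: nearest not-yet-chosen instance to each centroid.
    chosen = []
    for c in centroids:
        candidates = [i for i in range(n) if i not in chosen]
        chosen.append(min(candidates, key=lambda i: sq_dist(embeddings[i], c)))
    return chosen


# Toy data: two well-separated groups, each with one uncertain instance.
embeddings = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
probs = [
    [0.99, 0.01], [0.5, 0.5], [0.99, 0.01],
    [0.01, 0.99], [0.01, 0.99], [0.5, 0.5],
]
selected = clue_select(embeddings, probs, budget=2)
print(selected)  # [1, 5]
```

In this toy run the two selected instances come from different groups (diversity), and within each group the centroid is pulled toward the high-entropy instance (uncertainty), so that instance is the one chosen.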

Setup and Dependencies

  1. Create an anaconda environment with Python 3.6 and activate:
conda create -n CLUE python=3.6.8
conda activate CLUE
  2. Navigate into the code directory: cd CLUE/
  3. Install dependencies: (Takes ~2-3 minutes)
pip install -r requirements.txt
  4. [If running the demo] Install nb_conda:
conda install -c anaconda-nb-extensions nb_conda

And you're all set up!

Usage

Train Active Domain Adaptation model

Run python train.py with the appropriate arguments to train an active adaptation model from scratch.

Hyperparameter configurations for reproducing the paper's numbers on DIGITS and DomainNet are included in the config folder. For instance, to reproduce DIGITS (SVHN->MNIST) results with CLUE+MME, run:

python train.py --load_from_cfg True \
                --cfg_file config/digits/clue_mme.yml \
                --use_cuda False

To run a custom train job, you can create a custom config file and pass it to the train script. Pass --use_cuda False if you'd like to train on CPU instead.
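If you prefer to launch jobs from Python (for example, to loop over several config files), the command above can be assembled programmatically. The sketch below uses only the three flags shown in this README; the helper name build_train_command is hypothetical, and passing True to --use_cuda is an assumption inferred from the False form documented above.

```python
import sys


def build_train_command(cfg_file, use_cuda=True):
    """Return the argv list for train.py, using only flags shown in this README."""
    return [
        sys.executable, "train.py",
        "--load_from_cfg", "True",
        "--cfg_file", cfg_file,
        "--use_cuda", str(use_cuda),  # "True" or "False", as the script expects strings
    ]


cmd = build_train_command("config/digits/clue_mme.yml", use_cuda=False)
print(" ".join(cmd[1:]))

# To actually launch the job, run this from the CLUE/ directory:
#   import subprocess
#   subprocess.run(cmd, check=True)
```

Building the command as a list rather than a single string avoids shell quoting problems when a config path contains spaces.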

Download data

Data for SVHN->MNIST is downloaded automatically via PyTorch. For DomainNet, follow these steps:

  1. Download the original dataset for the domains of interest (e.g., Clipart and Sketch) from this link.
  2. Run:
python preprocess_domainnet.py --input_dir <input_directory> \
                               --domains 'clipart,sketch' \
                               --output_dir 'data/'

Pretrained checkpoints

At round 0, active adaptation begins from a model trained on the source domain, or from a model first trained on source and then adapted to the target via unsupervised domain adaptation. Checkpoints for reproducing DIGITS experiments have been included in the checkpoints/ directory, and those for reproducing DomainNet results on Clipart->Sketch can be downloaded at this link. Note that checkpoints for models after active adaptation are not included.
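The round structure implied above can be pictured with the schematic loop below. It is a sketch under stated assumptions, not the logic of train.py: the function run_active_da and the callbacks select_fn, label_fn, and train_fn are hypothetical placeholders, and a fixed labeling budget per round is assumed.

```python
def run_active_da(model, pool, num_rounds, budget, select_fn, label_fn, train_fn):
    """Schematic Active DA loop: select, label, retrain, repeat.

    `model` is the round-0 starting point (e.g. a source-trained checkpoint);
    all callbacks are hypothetical stand-ins supplied by the caller.
    """
    labeled = {}          # target instance id -> acquired label
    pool = list(pool)     # unlabeled target instance ids
    for _ in range(num_rounds):
        picked = select_fn(model, pool, budget)        # label acquisition strategy
        for i in picked:
            labeled[i] = label_fn(i)                   # query the oracle / annotator
        pool = [i for i in pool if i not in labeled]   # remove newly labeled ids
        model = train_fn(model, labeled)               # adapt using acquired labels
    return model, labeled


# Toy stand-ins so the loop can be run end to end.
final_model, acquired = run_active_da(
    model="source",
    pool=range(6),
    num_rounds=2,
    budget=2,
    select_fn=lambda m, pool, b: pool[:b],         # trivial "strategy": take the first b
    label_fn=lambda i: i % 2,                      # fake oracle
    train_fn=lambda m, lab: f"{m}+{len(lab)}",     # records how many labels were used
)
print(final_model, sorted(acquired))
```

Swapping select_fn is where strategies such as CLUE or the baselines would differ; the rest of the loop stays the same.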

Evaluation and plotting results

Run python evaluate.py with the appropriate arguments (see the file for instructions). It pretty-prints the raw results and saves them as a figure in the plots/ directory. By default, the figure compares CLUE + MME against a subset of representative Active DA and AL baselines.

Demo

  1. Start a Jupyter server by running jupyter notebook, and set the conda environment to CLUE (the environment created during setup)
  2. Run the Jupyter notebook demo.ipynb, which will walk you through:
    • Loading SVHN, MNIST datasets and pretrained checkpoints
    • Label acquisition with baseline strategies and CLUE+MME
    • Training (on CPU) with acquired labels
    • Plotting performance after one round of Active DA on SVHN->MNIST

Reference

If you found this code useful, please consider citing:

@inproceedings{prabhu2021active,
  title={Active domain adaptation via clustering uncertainty-weighted embeddings},
  author={Prabhu, Viraj and Chandrasekaran, Arjun and Saenko, Kate and Hoffman, Judy},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={8505--8514},
  year={2021}
}

Acknowledgements

We would like to thank the developers of PyTorch for building an excellent framework, the Deep Active Learning repository for implementations of some of our baselines, and the numerous contributors to all the open-source packages we use.

License

MIT
