mashaan14 / ADDA-toy

A PyTorch implementation of the ADDA paper (CVPR 2017) with a 2D toy example.

ADDA PyTorch implementation with a toy example

Requirements: Python 3.10+, PyTorch, torchvision

Adversarial Discriminative Domain Adaptation (ADDA) is a well-known baseline method for unsupervised domain adaptation. It was introduced in this paper:

@InProceedings{Tzeng_2017_CVPR,
  author =    {Tzeng, Eric and Hoffman, Judy and Saenko, Kate and Darrell, Trevor},
  title =     {Adversarial Discriminative Domain Adaptation},
  booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year =      {2017},
  month =     {July}
}

This implementation runs ADDA on a 2D toy dataset and includes built-in plots that visualize how the algorithm adapts the target features.

Two dimensional dataset

The code starts by loading the source dataset from the data folder. It then applies a rotation (the domain shift) to a copy of that dataset; the rotated copy is the target dataset. The source and target datasets are plotted for comparison.
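As a rough illustration of how a rotation can act as a domain shift, here is a minimal sketch. The `rotate_2d` helper, the Gaussian-blob data, and the 45-degree angle are all assumptions made for this example; the repository loads its own data and may use a different angle.

```python
import math
import torch

def rotate_2d(points: torch.Tensor, degrees: float) -> torch.Tensor:
    """Rotate an (N, 2) tensor of points counter-clockwise about the origin."""
    theta = math.radians(degrees)
    rotation = torch.tensor(
        [[math.cos(theta), -math.sin(theta)],
         [math.sin(theta),  math.cos(theta)]],
        dtype=points.dtype,
    )
    # Points are stored as row vectors, so multiply by the transposed matrix.
    return points @ rotation.T

# Hypothetical stand-in for the source dataset: two Gaussian blobs.
torch.manual_seed(0)
source_x = torch.cat([
    torch.randn(100, 2) + torch.tensor([2.0, 0.0]),   # class 0
    torch.randn(100, 2) + torch.tensor([-2.0, 0.0]),  # class 1
])
source_y = torch.cat([torch.zeros(100), torch.ones(100)]).long()

# The target domain is a rotated copy of the source; labels are unchanged.
target_x = rotate_2d(source_x, degrees=45.0)  # the angle is an assumption
target_y = source_y.clone()
```

A rotation keeps each point's distance from the origin, so the class structure survives while the decision boundary learned on the source no longer lines up with the target.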

Source domain classifier

The encoder and classifier networks are trained to separate source class 0 from source class 1. Most of this logic lives in the core.train_src function. The trained model is then evaluated on the test data.
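The source training stage can be sketched as ordinary supervised learning on the source data. This is not the code of core.train_src: the network sizes, the optimizer settings, and the `make_blobs` helper below are assumptions chosen to keep the example short.

```python
import torch
from torch import nn

torch.manual_seed(0)

def make_blobs(n_per_class: int):
    """Hypothetical toy data: class 0 around (3, 3), class 1 around (-3, -3)."""
    x = torch.cat([torch.randn(n_per_class, 2) + 3.0,
                   torch.randn(n_per_class, 2) - 3.0])
    y = torch.cat([torch.zeros(n_per_class), torch.ones(n_per_class)]).long()
    return x, y

train_x, train_y = make_blobs(100)
test_x, test_y = make_blobs(50)

# Assumed architectures; the repository's networks may differ.
encoder = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
classifier = nn.Linear(2, 2)

# Encoder and classifier are optimized jointly with a cross-entropy loss.
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()), lr=1e-2
)
criterion = nn.CrossEntropyLoss()

for epoch in range(200):
    optimizer.zero_grad()
    loss = criterion(classifier(encoder(train_x)), train_y)
    loss.backward()
    optimizer.step()

# Evaluate on held-out source samples.
with torch.no_grad():
    pred = classifier(encoder(test_x)).argmax(dim=1)
    accuracy = (pred == test_y).float().mean().item()
print(f"source test accuracy: {accuracy:.2%}")
```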

Adversarial adaptation

The adversarial adaptation takes place in the core.train_tgt function. A discriminator is trained to tell whether a feature vector comes from the source or the target domain, and the target encoder is trained to confuse it so that target features become indistinguishable from source features. For comparison, the target data is then tested using both the source encoder and the adapted target encoder:
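One adversarial update, as described in the ADDA paper, can be sketched as follows. This is an illustration and not the code of core.train_tgt: the networks, learning rates, random batches, and the `adapt_step` helper are assumptions, and the source encoder here is untrained where a real run would use the one from the previous stage.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)

# Assumed stand-ins for the trained source encoder and the discriminator.
src_encoder = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
tgt_encoder = copy.deepcopy(src_encoder)  # target encoder starts from source weights
discriminator = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))

# The source encoder stays fixed during adaptation.
for p in src_encoder.parameters():
    p.requires_grad_(False)

opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
opt_t = torch.optim.Adam(tgt_encoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

def adapt_step(src_x: torch.Tensor, tgt_x: torch.Tensor):
    """Run one discriminator update followed by one target-encoder update."""
    # 1) Discriminator: label source features 1 and target features 0.
    opt_d.zero_grad()
    feats = torch.cat([src_encoder(src_x), tgt_encoder(tgt_x)]).detach()
    domain = torch.cat([torch.ones(len(src_x)), torch.zeros(len(tgt_x))]).long()
    loss_d = criterion(discriminator(feats), domain)
    loss_d.backward()
    opt_d.step()

    # 2) Target encoder: use inverted labels so that target features
    #    are pushed toward being classified as "source".
    opt_t.zero_grad()
    inverted = torch.ones(len(tgt_x)).long()
    loss_t = criterion(discriminator(tgt_encoder(tgt_x)), inverted)
    loss_t.backward()
    opt_t.step()
    return loss_d.item(), loss_t.item()

# Hypothetical batches standing in for source and target samples.
src_x = torch.randn(64, 2)
tgt_x = torch.randn(64, 2) + 1.0
for step in range(5):
    loss_d, loss_t = adapt_step(src_x, tgt_x)
    print(f"step {step}: discriminator loss {loss_d:.4f}, target encoder loss {loss_t:.4f}")
```

Only the discriminator and the target encoder receive updates; the source encoder is frozen so that its feature space remains the reference the target features are aligned to.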

>>> Testing target data using source encoder <<<
Avg Loss = 0.41065776348114014, Avg Accuracy = 89.000000%, ARI = 0.60646

>>> Testing target data using target encoder <<<
Avg Loss = 0.3132730381829398, Avg Accuracy = 100.000000%, ARI = 1.00000
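The logs above report an ARI (adjusted Rand index) next to the accuracy. For readers unfamiliar with it, here is a plain-Python sketch of the standard formula. The `adjusted_rand_index` function is written for this example and is not necessarily how the repository computes the value.

```python
from collections import Counter
from math import comb

def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand index between two labelings of the same samples."""
    n = len(labels_true)
    if n < 2:
        return 1.0  # no pairs to compare; treat as trivially identical

    # Pair counts from the contingency table and from its row and column sums.
    cells = Counter(zip(labels_true, labels_pred))
    sum_cells = sum(comb(c, 2) for c in cells.values())
    sum_rows = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_cols = sum(comb(c, 2) for c in Counter(labels_pred).values())

    expected = sum_rows * sum_cols / comb(n, 2)  # agreement expected by chance
    max_index = (sum_rows + sum_cols) / 2
    if max_index == expected:
        return 1.0  # degenerate case, e.g. a single cluster in both labelings
    return (sum_cells - expected) / (max_index - expected)

# A perfect prediction gives an ARI of 1.0, even if the label names are swapped.
print(adjusted_rand_index([0, 0, 1, 1], [0, 0, 1, 1]))
print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))
```

Unlike accuracy, the ARI is corrected for chance agreement, so it falls faster than accuracy as predictions degrade.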
