Newbeeer / TRM

Learning Representations that Support Robust Transfer of Predictors

Transfer Risk Minimization (TRM)

Code for Learning Representations that Support Robust Transfer of Predictors

Yilun Xu, Tommi Jaakkola

TL;DR: We introduce a simple robust estimation criterion, transfer risk, that is specifically geared towards optimizing transfer to new environments. Effectively, the criterion amounts to finding a representation that minimizes the risk of applying any optimal predictor trained on one environment to another. The transfer risk decomposes into two terms: a direct transfer term and a weighted gradient-matching term arising from the optimality of the per-environment predictors.
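To make the criterion concrete, here is a brute-force sketch that scores the transfer risk of a fixed representation on toy data. It assumes squared loss, a scalar linear head fitted in closed form on each environment, and an unweighted average over ordered environment pairs. These choices and all names (`optimal_head`, `transfer_risk`, the toy environments) are illustrative only; this is not the repository's TRM training objective and does not use the direct-transfer / gradient-matching decomposition.

```python
from itertools import permutations
from typing import Callable, Dict, List, Sequence, Tuple

# One example: (input features, scalar target). An environment is a list of examples.
Example = Tuple[Sequence[float], float]
Representation = Callable[[Sequence[float]], float]


def optimal_head(phi: Representation, env: List[Example]) -> float:
    """Closed-form least-squares scalar head: argmin_w mean((w * phi(x) - y) ** 2).

    Assumes phi(x) is not identically zero on the environment.
    """
    num = sum(phi(x) * y for x, y in env)
    den = sum(phi(x) ** 2 for x, _ in env)
    return num / den


def risk(w: float, phi: Representation, env: List[Example]) -> float:
    """Mean squared error of the predictor x -> w * phi(x) on one environment."""
    return sum((w * phi(x) - y) ** 2 for x, y in env) / len(env)


def transfer_risk(phi: Representation, envs: Dict[str, List[Example]]) -> float:
    """Average risk of env e's optimal head evaluated on every other env e2."""
    heads = {name: optimal_head(phi, env) for name, env in envs.items()}
    pairs = list(permutations(envs, 2))  # ordered (source env, target env) pairs
    return sum(risk(heads[e], phi, envs[e2]) for e, e2 in pairs) / len(pairs)


# Hypothetical toy data: y = x[0] in both environments, while the spurious
# feature x[1] = a_e * y has an environment-dependent slope (a_A = 1, a_B = 2).
envs = {
    "A": [((1.0, 1.0), 1.0), ((2.0, 2.0), 2.0), ((3.0, 3.0), 3.0)],
    "B": [((1.0, 2.0), 1.0), ((2.0, 4.0), 2.0), ((3.0, 6.0), 3.0)],
}


def invariant(x: Sequence[float]) -> float:
    return x[0]  # keeps only the feature whose relation to y is stable


def spurious(x: Sequence[float]) -> float:
    return x[1]  # keeps only the feature whose relation to y shifts


print(transfer_risk(invariant, envs))  # 0.0
print(transfer_risk(spurious, envs))   # about 2.9167 (= 35/12)
```

The invariant representation transfers perfectly, while the spurious one is penalized: the head fitted on A (w = 1) overshoots on B, and the head fitted on B (w = 0.5) undershoots on A, even though each head is optimal on its own environment.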

Prepare the Datasets

Download the PACS, Office-Home, and MNIST datasets:

python scripts/download.py --data_dir {data_dir}

The Places dataset can be downloaded from:

http://data.csail.mit.edu/places/places365/train_256_places365standard.tar

The COCO dataset can be downloaded from:

http://images.cocodataset.org/annotations/annotations_trainval2017.zip

Preprocess COCO and Places, then generate the SceneCOCO dataset:

# preprocess COCO
python coco.py
# preprocess Places
python places.py

# generate SceneCOCO dataset
python cocoplaces.py
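The three preprocessing commands are listed in a fixed order. A minimal driver sketch that runs them from the repository root and stops at the first failure could look like this; the helper names (`prepare_scenecoco`, `run_checked`) are hypothetical, and it assumes each script takes no arguments and depends on the steps before it, as the order of the commands suggests.

```python
import subprocess
import sys
from typing import Callable, List, Sequence

# Scripts from the README, in the order given there.
STEPS: List[str] = ["coco.py", "places.py", "cocoplaces.py"]


def run_checked(cmd: Sequence[str]) -> None:
    """Run one command; raises subprocess.CalledProcessError on a non-zero exit."""
    subprocess.run(list(cmd), check=True)


def prepare_scenecoco(
    run: Callable[[Sequence[str]], None] = run_checked,
    python: str = sys.executable,
) -> List[List[str]]:
    """Run the preprocessing steps in order, stopping at the first failure.

    Returns the commands that completed successfully.
    """
    completed: List[List[str]] = []
    for script in STEPS:
        cmd = [python, script]
        run(cmd)  # an exception here aborts the remaining steps
        completed.append(cmd)
    return completed


# Dry run: print the commands instead of executing them.
prepare_scenecoco(run=lambda cmd: print("would run:", " ".join(cmd)))
```

Calling `prepare_scenecoco()` with no arguments executes the scripts for real; injecting `run` keeps the ordering logic testable without touching the data.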

Running the Experiments

  • Datasets:

    • Synthetic datasets for controlled experiments: ColoredMNIST / SceneCOCO
    • Real-world datasets: PACS / Office-Home
python -m domainbed.scripts.train --data_dir {root} --algorithm {alg} \
	--dataset {dataset} --trial_seed {t_seed} --epochs {epochs} (--shift {shift}) (--resnet50) (--test_eval)

root: root directory for the data
alg: ERM, VREx, IRM, GroupDRO, Fish, MLDG, TRM
t_seed: seed for data splitting
dataset: PACS or OfficeHome or ColoredMNIST or SceneCOCO
epochs: training epochs
resnet50: set ResNet50 as the backbone (default: ResNet18)
shift: for ColoredMNIST and SceneCOCO only; 0: label-correlated shift, 1: label-uncorrelated shift, 2: combined shift
test_eval: use the test-domain validation set (default: the training-domain validation set)
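For sweeps over algorithms and seeds, it can help to build the command programmatically. The sketch below only assembles argument lists from the flags documented above and enforces the `--shift` constraint; the function name `train_command`, the placeholder data path, and the epoch count are hypothetical, and nothing is launched.

```python
import itertools
import shlex
from typing import List, Optional

SHIFT_DATASETS = ("ColoredMNIST", "SceneCOCO")  # --shift applies to these only


def train_command(
    data_dir: str,
    algorithm: str,
    dataset: str,
    trial_seed: int,
    epochs: int,
    shift: Optional[int] = None,
    resnet50: bool = False,
    test_eval: bool = False,
) -> List[str]:
    """Build the argv for one domainbed.scripts.train run from the flags listed above."""
    cmd = [
        "python", "-m", "domainbed.scripts.train",
        "--data_dir", data_dir,
        "--algorithm", algorithm,
        "--dataset", dataset,
        "--trial_seed", str(trial_seed),
        "--epochs", str(epochs),
    ]
    if shift is not None:
        if dataset not in SHIFT_DATASETS:
            raise ValueError(f"--shift only applies to {SHIFT_DATASETS}, got {dataset!r}")
        if shift not in (0, 1, 2):
            raise ValueError(f"shift must be 0, 1 or 2, got {shift}")
        cmd += ["--shift", str(shift)]
    if resnet50:
        cmd.append("--resnet50")  # ResNet-50 backbone instead of the default ResNet-18
    if test_eval:
        cmd.append("--test_eval")  # test-domain instead of training-domain validation
    return cmd


# Hypothetical sweep: ERM vs. TRM on ColoredMNIST with combined shift, three split seeds.
# The data directory and epoch count are placeholders, not recommended settings.
for alg, seed in itertools.product(["ERM", "TRM"], [0, 1, 2]):
    print(shlex.join(train_command("/path/to/data", alg, "ColoredMNIST", seed, 100, shift=2)))
```

Returning an argument list rather than a string lets it be passed straight to `subprocess.run` without shell quoting; `shlex.join` is used only for display.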

This implementation is based on / inspired by:
