Correct-N-Contrast: A Contrastive Approach for Improving Robustness to Spurious Correlations

This repository contains code for Correct-N-Contrast (CNC), presented at ICML 2022.

Requirements

To install requirements, we recommend setting up a virtual environment with conda:

conda env create -f environment.yml  
conda activate cnc

We also include a requirements.txt file for installing dependencies with pip install -r requirements.txt.

List of (installable) dependencies:

  • python 3.7.9
  • matplotlib 3.3.2
  • numpy 1.19.2
  • pandas 1.1.3
  • pillow 8.0.1
  • pytorch 1.7.0
  • scikit-learn 0.23.2
  • scipy 1.5.2
  • transformers 4.4.2
  • torchvision 0.8.1
  • tqdm 4.54.0
  • umap-learn 0.4.6
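
As a quick sanity check after installation, the following sketch (our addition, not a repository script) prints the installed versions of the core dependencies; run it inside the activated cnc environment:

# check_env.py: print versions of the core dependencies (a sketch).
import numpy
import pandas
import scipy
import sklearn
import torch
import torchvision
import transformers

for mod in (numpy, pandas, scipy, sklearn, torch, torchvision, transformers):
    print(f"{mod.__name__}: {mod.__version__}")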

Datasets and Code

Colored MNIST: Running the training command below should automatically download and set up the Colored MNIST dataset.

Waterbirds: Download the dataset from here. Unzipping this should result in a folder waterbird_complete95_forest2water2, which should be moved to ./datasets/data/Waterbirds/.
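
For reference, a minimal setup sketch, assuming the download is a zip archive saved locally as waterbirds.zip (a placeholder filename; adjust to the actual download, which may also be a tarball):

# setup_waterbirds.py: extract the archive and move the resulting
# folder into place (a sketch; "waterbirds.zip" is hypothetical).
import shutil
import zipfile
from pathlib import Path

target = Path("./datasets/data/Waterbirds")
target.mkdir(parents=True, exist_ok=True)

with zipfile.ZipFile("waterbirds.zip") as zf:
    zf.extractall(".")  # should produce waterbird_complete95_forest2water2/

shutil.move("waterbird_complete95_forest2water2", str(target))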

CelebA: Download the dataset files from this Kaggle link. Then move the files to ./datasets/data/CelebA/ so that you have the following structure:

# In `./datasets/data/CelebA/`:
|-- list_attr_celeba.csv
|-- list_eval_partition.csv
|-- img_align_celeba/
    |-- image1.png
    |-- ...
    |-- imageN.png
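
A small check (our sketch, not part of the repository) to confirm the files landed in the right place:

# check_celeba.py: verify the expected CelebA layout (a sketch).
from pathlib import Path

root = Path("./datasets/data/CelebA")
for name in ["list_attr_celeba.csv", "list_eval_partition.csv", "img_align_celeba"]:
    path = root / name
    print(f"{path}: {'found' if path.exists() else 'MISSING'}")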

CivilComments-WILDS: Loading this dataset requires the transformers package. Download the source csv from here, then move all_data_with_identities.csv to ./datasets/data/CivilComments/all_data_with_identities.csv.
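
To confirm the csv loads correctly, a quick inspection sketch (we assume nothing about the column names; the script just prints whatever the file contains):

# inspect_civilcomments.py: load the csv and print its shape and
# columns (a sketch; requires pandas).
import pandas as pd

df = pd.read_csv("./datasets/data/CivilComments/all_data_with_identities.csv")
print(df.shape)
print(df.columns.tolist())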

Training and Evaluation

For all datasets except Colored MNIST, the commands below require loading an initial trained ERM model, which we provide through the download links in the Pre-trained Models section below. These models are trained as described in the paper, using standard mini-batch SGD on the dataset labels. We recommend downloading them, since training the initial ERM model can take a fair amount of time: roughly 1.5 hours for Waterbirds on a machine with 8 CPUs and 1 NVIDIA V100 GPU, and roughly 3 hours for CelebA on a machine with 32 CPUs and 4 NVIDIA V100 GPUs.
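
For readers who prefer to train their own initial model, here is a minimal sketch of that standard ERM setup (our illustration with placeholder hyperparameters; the repository's own scripts add logging, checkpointing, and dataset-specific details):

# erm_sketch.py: standard ERM training with mini-batch SGD on the
# dataset labels (an illustrative sketch, not the repository's script).
import torch
import torch.nn as nn

def train_erm(model, loader, epochs=100, lr=1e-3, device="cuda"):
    model = model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=0.9, weight_decay=1e-4)
    for _ in range(epochs):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
    return model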

Colored MNIST

Train:

python train_supervised_contrast.py --dataset colored_mnist --train_encoder --arch cnn --data_cmap hsv --test_shift random -tc 0 1 -tc 2 3 -tc 4 5 -tc 6 7 -tc 8 9 --p_correlation 0.995 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0 --max_epoch_s 5 --bs_trn_s 32 --num_anchor 32 --num_positive 32 --num_negative 32 --num_negative_easy 32 --batch_factor 32 --optim sgd --lr 1e-3 --momentum 0.9 --weight_decay 1e-4 --weight_decay_c 1e-4 --target_sample_ratio 1 --temperature 0.05 --max_epoch 3 --no_projection_head --contrastive_weight 0.75 --bs_trn 32 --bs_val 32 --num_workers 0 --log_loss_interval 10 --checkpoint_interval 10000 --log_visual_interval 40000 --verbose --replicate 42 --seed 42

Evaluate:

python train_supervised_contrast.py --dataset colored_mnist --arch cnn --evaluate --load_encoder [model_file_name.pth.tar] --data_cmap hsv --test_shift random -tc 0 1 -tc 2 3 -tc 4 5 -tc 6 7 -tc 8 9 --p_correlation 0.995 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0 -tcr 1.0

Waterbirds

Train:

python train_supervised_contrast.py --dataset waterbirds --arch resnet50_pt --train_encoder --pretrained_spurious_path "./model/waterbirds/waterbirds_erm_regularized.pt" --num_anchor 17 --num_positive 17 --num_negative 17 --num_negative_easy 17 --batch_factor 32 --optim sgd --lr 1e-4 --momentum 0.9 --weight_decay 1e-3 --weight_decay_c 1e-3 --target_sample_ratio 1 --temperature 0.1 --max_epoch 5 --no_projection_head --contrastive_weight 0.75 --log_visual_interval 10000 --checkpoint_interval 10000 --verbose --log_loss_interval 10 --replicate 42 --seed 42

Evaluate:

python train_supervised_contrast.py --dataset waterbirds --arch resnet50_pt --evaluate --load_encoder waterbirds_cnc_pretrained.pth.tar 

CelebA

Train:

python train_supervised_contrast.py --dataset celebA --arch resnet50_pt --train_encoder --pretrained_spurious_path "./model/celebA/celeba_erm_regularized.pt" --num_anchor 64 --num_positive 64 --num_negative 64 --num_negative_easy 64 --batch_factor 32 --optim sgd --lr 1e-5 --momentum 0.9 --weight_decay 1e-1 --weight_decay_c 1e-1 --target_sample_ratio 0.1 --temperature 0.05 --max_epoch 15 --no_projection_head --contrastive_weight 0.75 --log_visual_interval 10000 --checkpoint_interval 10000 --verbose --log_loss_interval 10 --replicate 42 --seed 42

Evaluate:

python train_supervised_contrast.py --dataset celebA --arch resnet50_pt --evaluate --load_encoder celebA_cnc_pretrained.pth.tar 

CivilComments-WILDS

Train:

python -W ignore train_supervised_contrast.py --dataset civilcomments --arch bert-base-uncased_pt --train_encoder --pretrained_spurious_path ./model/civilcomments/civilcomments_erm_early.pth.tar --num_anchor 16 --num_positive 16 --num_negative 16 --num_negative_easy 16 --batch_factor 128 --bs_trn 16 --clip_grad_norm --optim AdamW --lr 1e-4 --weight_decay 1e-2 --target_sample_ratio 0.1 --temperature 0.1 --max_epoch 10 --no_projection_head --contrastive_weight 0.75 --log_loss_interval 10 --checkpoint_interval 10000 --log_visual_interval 400000 --verbose --replicate 42 --seed 42

Evaluate:

python -W ignore train_supervised_contrast.py --dataset civilcomments --arch bert-base-uncased_pt --evaluate --load_encoder civilcomments_cnc_pretrained.pth.tar 

Pre-trained Models

Both pretrained initial ERM models and the trained Correct-N-Contrast models are available to download here.

Once downloaded, models should be moved to the following file paths:

Waterbirds

  • ERM model: ./model/waterbirds/waterbirds_erm_regularized.pt
  • CNC model: ./model/waterbirds/config-tn=waterbird_complete95-cn=['forest2water2']/waterbirds_cnc_pretrained.pth.tar

CelebA

  • ERM model: ./model/celebA/celeba_erm_regularized.pt
  • CNC model: ./model/celebA/config/celebA_cnc_pretrained.pth.tar

CivilComments-WILDS

  • ERM model: ./model/civilcomments/civilcomments_erm_early.pth.tar
  • CNC model: ./model/civilcomments/config/civilcomments_cnc_pretrained.pth.tar
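
Once in place, you can peek inside a downloaded checkpoint to confirm it loads (a sketch; the exact contents depend on how each checkpoint was saved):

# inspect_checkpoint.py: load a checkpoint on CPU and list its keys
# (a sketch; assumes a standard torch.save file).
import torch

ckpt = torch.load("./model/waterbirds/waterbirds_erm_regularized.pt",
                  map_location="cpu")
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])
else:
    print(type(ckpt))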

License

Apache License 2.0
