

GALS: Guiding Visual Attention with Language Specification


This is the official implementation for the CVPR 2022 paper On Guiding Visual Attention with Language Specification by Suzanne Petryk*, Lisa Dunlap*, Keyan Nasseri, Joseph Gonzalez, Trevor Darrell, and Anna Rohrbach.

If you find our code or paper useful, please cite:

@article{petryk2022gals,
  title={On Guiding Visual Attention with Language Specification},
  author={Petryk, Suzanne and Dunlap, Lisa and Nasseri, Keyan and Gonzalez, Joseph and Darrell, Trevor and Rohrbach, Anna},
  journal={Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}

Setting Up the Environment

Conda

conda env create -f env.yaml
conda activate gals

Pip

pip install -r requirements.txt

Download Datasets

Please see the original dataset pages for further details:

  • Waterbirds
    • Waterbirds-95%: The original dataset page includes instructions on how to generate different biased splits. Please download Waterbirds-95% from the link for waterbird_complete95_forest2water2 on their page.
    • Waterbirds-100%: We use the script from Sagawa & Koh et al. to generate the 100% biased split. For convenience, we supply this split of the dataset here: Waterbirds-100%.
  • Food101: Original dataset page. Please download the images from the original dataset page; we construct the 5-class Red Meat subset from these images.
  • MSCOCO-ApparentGender: Original dataset page. Please use the original dataset page to download the COCO 2014 train and validation images and annotations. We base MSCOCO-ApparentGender on the dataset used in Women Also Snowboard (by Burns & Hendricks et al.). We modify the training IDs slightly but keep the same evaluation set. Please download the split files here: MSCOCO-ApparentGender.

The data is expected to be under the folder ./data. More specifically, here is the suggested data file structure:

  • ./data
    • waterbird_complete95_forest2water2/ (Waterbirds-95%)
    • waterbird_1.0_forest2water2/ (Waterbirds-100%)
    • food-101/ (Red Meat)
    • COCO/
      • annotations/ (COCO annotations from original dataset page)
      • train2014/ (COCO images from original dataset page)
      • val2014/ (COCO images from original dataset page)
      • COCO_gender/ (ApparentGender files we provided)
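Before training, it can help to confirm that the folders are where the code expects them. The sketch below is a hypothetical helper (not part of this repo) that checks for the directories listed in the suggested structure; `missing_data_dirs` and `EXPECTED_DIRS` are illustrative names, and you only need the folders for the datasets you actually use.

```python
from pathlib import Path
from typing import List, Sequence

# Relative paths copied from the suggested ./data structure above.
EXPECTED_DIRS = (
    "waterbird_complete95_forest2water2",  # Waterbirds-95%
    "waterbird_1.0_forest2water2",         # Waterbirds-100%
    "food-101",                            # Red Meat
    "COCO/annotations",
    "COCO/train2014",
    "COCO/val2014",
    "COCO/COCO_gender",                    # ApparentGender split files
)


def missing_data_dirs(root: str = "./data",
                      expected: Sequence[str] = EXPECTED_DIRS) -> List[str]:
    """Return the expected sub-directories of `root` that do not exist yet."""
    base = Path(root)
    return [rel for rel in expected if not (base / rel).is_dir()]


if __name__ == "__main__":
    for rel in missing_data_dirs():
        print(f"missing: ./data/{rel}")
```

Pass a subset of `EXPECTED_DIRS` as `expected` to check only one dataset, e.g. `missing_data_dirs(expected=["food-101"])`.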

Repo Organization

  • main.py is the point of entry for model training.
  • configs/ contains .yaml configuration files for each dataset and model type.
  • extract_attention.py is the script to precompute attention with CLIP ResNet50 GradCAM and CLIP ViT transformer attention.
  • approaches/ contains training code. approaches/base.py handles general training and is extended by model-specific approaches (such as approaches/abn.py) and by approaches for datasets that require extra evaluation (such as approaches/coco_gender.py).
  • datasets/ contains PyTorch dataset creation files.
  • models/ contains architectures for both vanilla and ABN ResNet50 classification models.
  • utils/ contains helper functions for general training, loss and attention computation.

This repo also expects the following additional folders:

  • ./data: contains the dataset folders
  • ./weights: contains pretrained ImageNet ResNet50 weights for the ABN model, named resnet50_abn_imagenet.pth.tar. These weights are provided by Hiroshi Fukui & Tsubasa Hirakawa in their codebase. For convenience, you may also find the weights with the correct naming here.

We use Weights & Biases to log experiments. This requires the user to be logged in to a (free) W&B account. Instructions for setting up an account are here.

Training & Evaluation

Training models using GALS is a two-stage process:

  1. Generate and store attention per image
  2. Train model using attention

Example commands for training networks with GALS, as well as the baselines from the paper, are below.

NOTE: To change .yaml configuration values on the command line, add text of the form ATTRIBUTE.NESTED=new_value to the end of the command. For example:

CUDA_VISIBLE_DEVICES=0 python main.py --config configs/waterbirds_100_gals.yaml DATA.BATCH_SIZE=96
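To make the ATTRIBUTE.NESTED=new_value convention concrete, here is a minimal sketch of how such overrides can be applied to a nested config loaded from a .yaml file. It is an illustration under assumptions, not the parser main.py actually uses: `apply_overrides` and `parse_value` are hypothetical names, and values are interpreted as Python literals where possible (so 96 becomes an int), otherwise kept as strings.

```python
import ast
from typing import Any, Dict, List


def parse_value(raw: str) -> Any:
    """Interpret a CLI string as a Python literal (int, float, bool, list, ...),
    falling back to the raw string if it is not a valid literal."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply ATTRIBUTE.NESTED=new_value overrides to a nested dict, in place."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected ATTRIBUTE.NESTED=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        # Walk down to the dict that holds the final key.
        for name in parents:
            if not isinstance(node.get(name), dict):
                raise KeyError(f"unknown config section {name!r} in {key!r}")
            node = node[name]
        # Refuse to create new keys so that typos fail loudly.
        if leaf not in node:
            raise KeyError(f"unknown config key {key!r}")
        node[leaf] = parse_value(raw)
    return cfg
```

For example, `apply_overrides({"DATA": {"BATCH_SIZE": 32}}, ["DATA.BATCH_SIZE=96"])` sets the batch size to the integer 96. Rejecting unknown keys is a design choice in this sketch; it turns a misspelled override into an error instead of a silently ignored setting.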

Stage 1: Generate Attention

Important files:

Sample command:

CUDA_VISIBLE_DEVICES=0 python extract_attention.py --config configs/coco_attention.yaml
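extract_attention.py precomputes, among other things, CLIP ResNet50 GradCAM attention. As a reminder of what that map is, the sketch below implements the standard Grad-CAM combination step on toy, pure-Python tensors: each channel's activation map is weighted by the spatial mean of its gradient, the weighted maps are summed, and a ReLU keeps only positive evidence. It is a generic illustration, not this repo's code; obtaining real activations and gradients from CLIP (and any resizing or normalization the repo applies) is outside its scope.

```python
from typing import List

Matrix = List[List[float]]


def grad_cam(activations: List[Matrix], gradients: List[Matrix]) -> Matrix:
    """Standard Grad-CAM on per-channel H x W maps; returns an H x W map in [0, 1]."""
    if not activations or len(activations) != len(gradients):
        raise ValueError("need one gradient map per activation channel")
    h, w = len(activations[0]), len(activations[0][0])
    cam = [[0.0] * w for _ in range(h)]
    for act, grad in zip(activations, gradients):
        # Channel importance: global-average-pooled gradient for this channel.
        alpha = sum(sum(row) for row in grad) / (h * w)
        for i in range(h):
            for j in range(w):
                cam[i][j] += alpha * act[i][j]
    # ReLU: keep only regions that increase the target score.
    cam = [[max(v, 0.0) for v in row] for row in cam]
    # Rescale so the strongest location has value 1 (skip if the map is all zeros).
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam
```

In practice the activations and gradients come from the last convolutional stage of the image encoder, with the gradient taken of the image-text similarity score for the chosen text prompt.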

Stage 2: Train model

Important files:

The model configs include the hyperparameters and attention settings used to reproduce results in our paper.

An example command to train a model with GALS on Waterbirds-100%:

CUDA_VISIBLE_DEVICES=0,1,2 python main.py --name waterbirds100_gals --config configs/waterbirds_100_gals.yaml

The --name flag is used for Weights & Biases logging. You can add --dryrun to the command to run locally without uploading to the W&B server. This can be useful for debugging.

Model evaluation

To evaluate a model on the test split of a given dataset, use the --test_checkpoint flag and provide the path to a trained checkpoint. For example, to evaluate a Waterbirds-95% GALS model with weights under a trained_weights directory:

CUDA_VISIBLE_DEVICES=0 python main.py --config configs/waterbirds_95_gals.yaml --test_checkpoint trained_weights/waterbirds_95_gals.ckpt

Note: For MSCOCO-ApparentGender, the Ratio Delta reported in our paper corresponds to 1 - test_ratio in the output results.
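The conversion is a single subtraction. The sketch assumes the evaluation results are available as a dict with a `test_ratio` entry (a hypothetical shape; check the logged output for the actual format). For instance, the GALS Ratio Delta of 0.160 reported below corresponds to a test_ratio of 0.840.

```python
from typing import Mapping


def ratio_delta(results: Mapping[str, float]) -> float:
    """Convert the logged test_ratio into the Ratio Delta reported in the paper."""
    return 1.0 - results["test_ratio"]
```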

Checkpoints/Results

In our paper, we report the mean and standard deviation over 10 trials. Below, we include a checkpoint from a single trial per experiment.
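Aggregating several runs into a "mean ± std" entry can be sketched as below. The trial values in any call are placeholders, not results from the paper, and whether the paper uses the sample standard deviation (statistics.stdev, assumed here) or the population version (statistics.pstdev) is not stated in this README.

```python
import statistics
from typing import Sequence, Tuple


def mean_and_std(trials: Sequence[float]) -> Tuple[float, float]:
    """Mean and sample standard deviation over repeated trials (needs >= 2 values)."""
    if len(trials) < 2:
        raise ValueError("need at least two trials for a standard deviation")
    return statistics.mean(trials), statistics.stdev(trials)


def format_mean_std(trials: Sequence[float]) -> str:
    """Format as 'mean ± std' with two decimals."""
    mean, std = mean_and_std(trials)
    return f"{mean:.2f} ± {std:.2f}"
```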

Waterbirds 100%

| Method | Per Group Acc (%) | Worst Group Acc (%) |
|---|---|---|
| GALS | 80.67 | 57.00 |
| Vanilla | 72.36 | 32.20 |
| UpWeight | 72.22 | 37.29 |
| ABN | 71.96 | 44.39 |

Waterbirds 95%

| Method | Per Group Acc (%) | Worst Group Acc (%) |
|---|---|---|
| GALS | 89.03 | 79.91 |
| Vanilla | 86.91 | 73.21 |
| UpWeight | 87.51 | 76.48 |
| ABN | 86.85 | 69.31 |
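The Waterbirds columns summarize accuracy over groups defined by bird type and background (e.g. waterbird on land). The sketch below computes per-group accuracy, its unweighted mean, and the worst-group accuracy from predictions. It assumes "Per Group Acc" is the unweighted mean of the group accuracies; the group labels and the `group_accuracies` helper are illustrative, not taken from this repo.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Tuple


def group_accuracies(preds: Sequence[int], labels: Sequence[int],
                     groups: Sequence[Hashable]
                     ) -> Tuple[Dict[Hashable, float], float, float]:
    """Return (accuracy per group, mean over groups, worst-group accuracy), all in %."""
    if not preds or not (len(preds) == len(labels) == len(groups)):
        raise ValueError("preds, labels and groups must be non-empty and equally long")
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    per_group = {g: 100.0 * correct[g] / total[g] for g in total}
    # Unweighted mean: every group counts equally, regardless of its size.
    mean_acc = sum(per_group.values()) / len(per_group)
    worst = min(per_group.values())
    return per_group, mean_acc, worst
```

Weighting every group equally, rather than every image, is what exposes spurious correlations: a model can score well overall while failing on a rare group such as waterbirds on land backgrounds.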

Red Meat (Food101)

| Method | Acc (%) | Worst Group Acc (%) |
|---|---|---|
| GALS | 72.24 | 58.00 |
| Vanilla | 69.20 | 48.80 |
| ABN | 69.28 | 52.80 |

MSCOCO-ApparentGender

| Method | Ratio Delta | Outcome Divergence |
|---|---|---|
| GALS | 0.160 | 0.022 |
| Vanilla | 0.349 | 0.071 |
| UpWeight | 0.272 | 0.040 |
| ABN | 0.334 | 0.068 |

Acknowledgements

We are very grateful to the following people, whose code we have used or adapted throughout this repository:
