ZlatanWilliams / StochasticDisturbanceLearning

Certifiable Out-of-Distribution Generalization

This is the code repository for the paper Certifiable Out-of-Distribution Generalization. It contains the code for obtaining OoD-Bench results on diversity-shift and correlation-shift datasets, and is modified from the PyTorch suite DomainBed.

Environment preparation

Recommended PyTorch environment

Environment:
    Python: 3.9.12
    PyTorch: 1.10.1+cu111
    Torchvision: 0.11.2+cu111
    CUDA: 11.1
    CUDNN: 8005
    NumPy: 1.21.2
    PIL: 9.1.0

The code should work in other PyTorch environments, but we do not recommend very old versions of PyTorch: they can cause version conflicts with the wilds package and errors when installing torch-scatter. Experimental results may fluctuate slightly between environments, depending on factors such as the PyTorch version, the CUDA version, and the GPU hardware. To reproduce the results exactly, we recommend using the trained weights we provide.
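As a quick sanity check before running anything, the installed package versions can be compared against the recommended ones. The following is a minimal sketch, not part of the repository: the helper names are hypothetical, and the pip distribution names (for example "Pillow" for PIL) are assumptions about how the packages were installed.

```python
from importlib import metadata

# Versions copied from the recommended environment listed above.
# The distribution names are assumptions (e.g. PIL is usually installed as "Pillow").
RECOMMENDED = {
    "torch": "1.10.1+cu111",
    "torchvision": "0.11.2+cu111",
    "numpy": "1.21.2",
    "Pillow": "9.1.0",
}


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def compare_environment(recommended=RECOMMENDED):
    """Map each package name to a tuple (installed, recommended, matches)."""
    report = {}
    for package, wanted in recommended.items():
        found = installed_version(package)
        report[package] = (found, wanted, found == wanted)
    return report


# Print one line per package so mismatches are easy to spot.
for package, (found, wanted, ok) in compare_environment().items():
    status = "ok" if ok else "differs"
    print(f"{package}: installed={found} recommended={wanted} [{status}]")
```

A mismatch is not necessarily a problem; as noted above, other environments should work, with possibly slightly different results.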

Pre-trained weights preparation

Download the ImageNet pre-trained weights of ResNet-18 from https://download.pytorch.org/models/resnet18-5c106cde.pth, and place the file under pretrained_weights. The directory structure should be:

StochasticDisturbanceLearning
├── datasets
├── DomainBed
├── pretrained_weights
├── ...
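The download step can also be scripted. The sketch below uses only the standard library and skips the download when the file is already in place. The function names are hypothetical, and it assumes the file keeps the name it has in the URL.

```python
import urllib.request
from pathlib import Path

WEIGHTS_URL = "https://download.pytorch.org/models/resnet18-5c106cde.pth"


def weights_path(repo_root):
    """Location of the checkpoint inside the repository root."""
    # Assumption: the file keeps the name it has in the URL.
    return Path(repo_root) / "pretrained_weights" / WEIGHTS_URL.rsplit("/", 1)[-1]


def ensure_weights(repo_root="."):
    """Download the checkpoint unless it is already present; return its path."""
    target = weights_path(repo_root)
    if target.exists():
        return target
    target.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(WEIGHTS_URL, str(target))
    return target
```

Calling ensure_weights("StochasticDisturbanceLearning") from the parent directory of the checkout would then create pretrained_weights if needed and fetch the file once.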

Data preparation

Most of the datasets (all except CelebA and NICO) can be downloaded by running the script DomainBed/domainbed/scripts/download.py. For NICO we provide a download link, because the original dataset contains some irregular file extensions. After downloading, place the datasets under datasets and make sure the directory structures are as follows:

PACS
└── kfold
    ├── art_painting
    ├── cartoon
    ├── photo
    └── sketch
office_home
├── Art
├── Clipart
├── Product
├── Real World
├── ImageInfo.csv
└── imagelist.txt
terra_incognita
├── location_38
├── location_43
├── location_46
└── location_100
WILDS
└── camelyon17_v1.0
    ├── patches
    └── metadata.csv
MNIST
└── processed
    ├── training.pt
    └── test.pt
celeba
├── img_align_celeba
└── blond_split
    ├── tr_env1_df.pickle
    ├── tr_env2_df.pickle
    └── te_env_df.pickle
NICO
├── animal
├── vehicle
└── mixed_split
    ├── env_train1.csv
    ├── env_train2.csv
    ├── env_val.csv
    └── env_test.csv

Note: the data split files of CelebA and NICO are already provided under datasets.
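Before launching experiments, it can help to confirm that the layout above is in place. The following is a minimal sketch of such a check, not a script shipped with the repository; the names EXPECTED_LAYOUT and missing_entries are hypothetical, and the entries are copied from the directory trees above.

```python
from pathlib import Path

# Expected entries per dataset, copied from the directory trees above.
EXPECTED_LAYOUT = {
    "PACS": ["kfold/art_painting", "kfold/cartoon", "kfold/photo", "kfold/sketch"],
    "office_home": ["Art", "Clipart", "Product", "Real World",
                    "ImageInfo.csv", "imagelist.txt"],
    "terra_incognita": ["location_38", "location_43", "location_46", "location_100"],
    "WILDS": ["camelyon17_v1.0/patches", "camelyon17_v1.0/metadata.csv"],
    "MNIST": ["processed/training.pt", "processed/test.pt"],
    "celeba": ["img_align_celeba", "blond_split/tr_env1_df.pickle",
               "blond_split/tr_env2_df.pickle", "blond_split/te_env_df.pickle"],
    "NICO": ["animal", "vehicle", "mixed_split/env_train1.csv",
             "mixed_split/env_train2.csv", "mixed_split/env_val.csv",
             "mixed_split/env_test.csv"],
}


def missing_entries(datasets_root, layout=EXPECTED_LAYOUT):
    """Return the expected paths (relative to datasets_root) that do not exist."""
    root = Path(datasets_root)
    missing = []
    for dataset, entries in layout.items():
        for entry in entries:
            if not (root / dataset / entry).exists():
                missing.append(f"{dataset}/{entry}")
    return missing
```

For example, missing_entries("datasets") returns an empty list when every listed file and folder is present, and otherwise names what is still missing.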

Running the experiments

To run experiments for a given dataset with a given algorithm, use DomainBed/run.py.
Example usage:

# Launch the runs
python run.py launch --dataset PACS --algorithm SDL_Gaussian
# Delete incomplete runs
python run.py delete_incomplete --dataset PACS --algorithm SDL_Gaussian
# List the status of the runs
python run.py list --dataset PACS --algorithm SDL_Gaussian

If you want to try a specific set of hyperparameters, first edit DomainBed/sweep/${dataset}/hparams.json to fix the hyperparameters, then run DomainBed/tuning.py as in the following example:

python tuning.py launch --dataset PACS --algorithm SDL_Gaussian --lr 6e-5 --worst_case_p 0.1
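To try several settings in a row, the command above can be generated for each combination of values. The sketch below only builds the command lines and does not run them. The helper name and the example value grids are hypothetical; only the flags shown in the command above are used.

```python
import itertools
import sys


def build_tuning_commands(dataset, algorithm, lrs, worst_case_ps):
    """Build one `tuning.py launch` command per (lr, worst_case_p) pair."""
    commands = []
    for lr, p in itertools.product(lrs, worst_case_ps):
        commands.append([
            sys.executable, "tuning.py", "launch",
            "--dataset", dataset,
            "--algorithm", algorithm,
            "--lr", str(lr),
            "--worst_case_p", str(p),
        ])
    return commands


# Hypothetical grid: two learning rates and two worst_case_p values.
for command in build_tuning_commands("PACS", "SDL_Gaussian", [6e-5, 1e-4], [0.1, 0.2]):
    print(" ".join(command))
```

Each command could then be passed to subprocess.run(command, cwd="DomainBed", check=True), assuming tuning.py is meant to be started from the DomainBed directory as in the example above.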

To show the results of a run or of tuning, run DomainBed/collect_run_results.py or DomainBed/collect_adjust_results.py as in the following examples:

python collect_run_results.py --dataset PACS --algorithm SDL_Gaussian
python collect_tuning_results.py --dataset PACS --algorithm SDL_Gaussian --lr 6e-5 --worst_case_p 0.1

See the arguments in DomainBed/collect_adjust_results.py for all the hyperparameters (the default version covers only the SDL algorithms; for other algorithms you can add your own). The model weights and results.txt are stored in DomainBed/sweep/${dataset}/outputs/. We also provide a Python script, built on DomainBed/domainbed/scripts/list_top_hparams.py, that lists the details of the hyperparameter search. Example usage:

python list.py --dataset PACS --algorithm SDL_Gaussian --test_env 0

Citation

We welcome any suggestions and comments. If you find our work helpful, please cite our paper. Thanks!

@inproceedings{ye2023certifiable,
  title={Certifiable out-of-distribution generalization},
  author={Ye, Nanyang and Zhu, Lin and Wang, Jia and Zeng, Zhaoyu and Shao, Jiayao and Peng, Chensheng and Pan, Bikang and Li, Kaican and Zhu, Jun},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={37},
  number={9},
  pages={10927--10935},
  year={2023}
}
