d-ailin / SSProbing

Repo for the paper "Trust, but Verify: Using Self-Supervised Probing to Improve Trustworthiness" (ECCV 2022)

SSProbing

Code for the paper: Trust, but Verify: Using Self-Supervised Probing to Improve Trustworthiness (ECCV'22)

Demo Visualization

We provide a notebook, vis_demo.ipynb, for a demo visualization of self-supervised probing confidence scores.

Environment Setup

Set up the environment and install the following packages:

  • Python==3.6.13
  • torch==1.10.1
  • torchvision==0.11.2

Our environment is partly based on ConfidNet.

    $ git clone https://github.com/valeoai/ConfidNet
    $ pip install -e ConfidNet

After installing the above packages, install the packages listed in requirements.txt:

    pip install -r requirements.txt

JupyterLab is not included in requirements.txt. If you would like to use it, install it manually.

Data

All datasets should be placed under the data/ directory.

    mkdir data/

OOD related datasets:

Pre-trained models

The pre-trained model weights we used in the paper can be accessed via this link.

Training the base model

If you prefer to train a new base model and test it, follow the guidance below.

Because we use MCDropout as one of our baselines, the base models must include Dropout components. For implementation details, refer to VGG16 and to our model file under train_base/models/.
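As background on why Dropout is required: MC Dropout keeps dropout active at inference time and averages the softmax outputs of several stochastic forward passes. The sketch below illustrates the idea in plain Python on a toy linear classifier. The function names, weights, dropout rate, and number of passes are illustrative assumptions and are not taken from this repo's implementation.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dropout(features, p, rng):
    """Inverted dropout: zero each feature with probability p, rescale the rest."""
    keep = 1.0 - p
    return [f / keep if rng.random() < keep else 0.0 for f in features]


def linear(features, weights, bias):
    """Dense layer with one row of weights per output class."""
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, bias)]


def mc_dropout_predict(features, weights, bias, p=0.5, passes=50, seed=0):
    """Average softmax outputs over several stochastic forward passes."""
    rng = random.Random(seed)
    mean = [0.0] * len(bias)
    for _ in range(passes):
        probs = softmax(linear(dropout(features, p, rng), weights, bias))
        mean = [m + q / passes for m, q in zip(mean, probs)]
    return mean


# Toy example: 4 features, 3 classes (all values are made up for illustration).
features = [1.0, -0.5, 2.0, 0.3]
weights = [[0.5, 0.1, 0.8, -0.2],
           [-0.3, 0.4, 0.1, 0.6],
           [0.2, -0.7, -0.5, 0.1]]
bias = [0.0, 0.1, -0.1]

mean_probs = mc_dropout_predict(features, weights, bias)
print("mean class probabilities:", mean_probs)
print("MC Dropout confidence:", max(mean_probs))
```

Without Dropout layers every pass would return the same output, so the averaged prediction would carry no extra uncertainty information.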

For the VGG16 model training, please refer to ConfidNet.

For resnet18 model training, follow the example below:

    $ cd train_base
    $ python train_cifar10.py -e 300

After training, the models should be saved in train_base/snapshots/xxx

Applying SSProbing on the pre-trained models

The following commands assume that you have downloaded the pre-trained models and unzipped them in the repo directory. Otherwise, adjust the commands to your setup, e.g. set the config path argument to the path of your own saved or trained model.

    # create the res output directory
    mkdir res_dir/

Misclassification Detection

Example:

    python -u mis_detect.py -c snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt -m mcp -t ./task_configs/rot4_trans5.yaml -sf cifar10_mis_res.txt -se 5

  • -c: path to a trained model
  • -m: the baseline method; options: ['mcp', 'mcdropout', 'trustscore', 'tcp']
  • -t: path to the task config, which sets the probing task and its details
  • -sf: the result output filename; the file is written under res_dir after successful execution
  • -se: number of training epochs for the self-supervised probing task
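For reference, the mcp baseline is commonly defined as the maximum class probability of the softmax output, and misclassification detection flags predictions whose confidence is low. The sketch below shows that baseline score only, in plain Python. It does not show how mis_detect.py combines the score with probing confidence, and the function names, logits, and threshold are illustrative assumptions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mcp_confidence(logits):
    """Return (predicted class index, maximum class probability)."""
    probs = softmax(logits)
    pred = max(range(len(probs)), key=probs.__getitem__)
    return pred, probs[pred]


def flag_suspected_errors(batch_logits, threshold):
    """Indices of samples whose MCP confidence falls below the threshold."""
    return [i for i, logits in enumerate(batch_logits)
            if mcp_confidence(logits)[1] < threshold]


# Made-up logits for three samples and three classes.
batch_logits = [
    [4.0, 0.5, 0.1],  # peaked distribution: confident prediction
    [1.0, 0.9, 0.8],  # flat distribution: low confidence
    [0.2, 3.5, 0.3],  # peaked distribution: confident prediction
]

# Assumption: 0.6 is an arbitrary threshold chosen for this toy example.
suspects = flag_suspected_errors(batch_logits, threshold=0.6)
print("suspected misclassifications:", suspects)
```

In practice such scores are usually evaluated with threshold-free metrics rather than a single fixed cutoff.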

OOD Detection

Example:

    python -u ood_detect.py -c snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt -m mcp -se 5 -sf cifar10_ood_res.text

  • -c: path to a trained model
  • -m: the baseline method; options: ['mcp', 'entropy']
  • -sf: the result output filename; the file is written under res_dir after successful execution
  • -se: number of training epochs for the self-supervised probing task
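As an illustration of the entropy baseline, the sketch below computes the Shannon entropy of a softmax distribution. A flatter distribution gives higher entropy, which is typically read as a sign that the input may be out of distribution. This is a generic sketch in plain Python: the function names and logits are illustrative assumptions, and it does not reproduce ood_detect.py.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_score(logits):
    """Shannon entropy (in nats) of the softmax distribution."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0.0)


# Made-up logits: one peaked (in-distribution-like), one flat (OOD-like).
in_dist_logits = [6.0, 0.2, 0.1, 0.3]
ood_like_logits = [0.5, 0.4, 0.6, 0.5]

h_in = entropy_score(in_dist_logits)
h_ood = entropy_score(ood_like_logits)
max_entropy = math.log(len(ood_like_logits))  # entropy of a uniform distribution

print("entropy (peaked):", h_in)
print("entropy (flat):", h_ood)
print("upper bound:", max_entropy)
```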

Calibration

Example:

    python -u cal.py -c snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt -sf cifar10_cal_res.txt

  • -c: path to a trained model
  • -sf: the result output filename; the file is written under res_dir after successful execution
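Calibration quality is often summarized with the expected calibration error (ECE), which bins predictions by confidence and averages the gap between accuracy and mean confidence in each bin. Whether cal.py reports exactly this metric is an assumption. The sketch below is a generic plain-Python version with equal-width bins, and the function name and sample values are illustrative.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: bin-size-weighted mean of |accuracy - mean confidence| per bin."""
    if len(confidences) != len(correct):
        raise ValueError("confidences and correct must have the same length")
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        # Map a confidence in [0, 1] to a bin; 1.0 goes into the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, hit))
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, h in members if h) / len(members)
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece


# Made-up confidences and correctness flags for six predictions.
confidences = [0.95, 0.9, 0.85, 0.6, 0.55, 0.3]
correct = [True, True, False, True, False, False]

ece = expected_calibration_error(confidences, correct)
print("ECE:", ece)
```

A value of 0 means confidence matches accuracy in every bin; larger values indicate over- or under-confidence.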

For more details, see run_mis.sh, run_ood.sh, and run_cal.sh.
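To compare several baselines in one go, the commands can also be assembled programmatically. The sketch below builds one mis_detect.py command per baseline method using only the flags documented above. It is not a reproduction of run_mis.sh: the helper name and the per-method output filename pattern are hypothetical.

```python
import shlex

# Paths copied from the misclassification detection example above.
CHECKPOINT = "snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt"
TASK_CONFIG = "./task_configs/rot4_trans5.yaml"
MIS_METHODS = ["mcp", "mcdropout", "trustscore", "tcp"]


def build_mis_command(method, checkpoint=CHECKPOINT,
                      task_config=TASK_CONFIG, probe_epochs=5):
    """Assemble the argument list for one mis_detect.py run."""
    return [
        "python", "-u", "mis_detect.py",
        "-c", checkpoint,
        "-m", method,
        "-t", task_config,
        # Assumption: one result file per method; this naming is hypothetical.
        "-sf", "cifar10_mis_{}.txt".format(method),
        "-se", str(probe_epochs),
    ]


commands = [build_mis_command(m) for m in MIS_METHODS]
for argv in commands:
    # Print a shell-safe command line; nothing is executed here.
    print(" ".join(shlex.quote(a) for a in argv))
```

To launch a run from the repo root, each argument list could be passed to subprocess.run(argv, check=True).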

Citation

If you find this repo or our work useful for your research, please consider citing the paper:

@inproceedings{deng2022trust,
  title={Trust, but Verify: Using Self-supervised Probing to Improve Trustworthiness},
  author={Deng, Ailin and Li, Shen and Xiong, Miao and Chen, Zhirui and Hooi, Bryan},
  booktitle={European Conference on Computer Vision},
  pages={361--377},
  year={2022},
  organization={Springer}
}

Acknowledgement

About


License: MIT License

