amirvhd / Uncertainty_aware_SSL

This repository contains the code for the paper "Diversified Ensemble of Independent Sub-Networks for Robust Self-Supervised Representation Learning", with example commands for each of the tasks described below.

Badges

MIT license

Dependencies

pytorch_lightning, torch

Installation

Install the requirements:

pip install -r requirements.txt

Usage/Examples

Pretraining

You need to specify the paths to the dataset and to the saved_models directory for the commands below. To run UA-SSL pretraining on CIFAR-10, use:

python main_pretrain.py --cosine --nh 10 --dataset cifar10 --lamda1 1 --lamda2 0.08 --epoch 800
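Pretraining builds on SimCLR (see Acknowledgements). For reference, here is a minimal NumPy sketch of the standard SimCLR NT-Xent contrastive loss for one batch of positive pairs. This README does not describe how UA-SSL applies the loss across the --nh sub-networks or what the --lamda1/--lamda2 weights multiply, so the sketch covers only the single-head SimCLR objective. The function name `nt_xent` and the default temperature are illustrative assumptions, not taken from this repository.

```python
import numpy as np


def nt_xent(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """Standard SimCLR NT-Xent loss for N positive pairs (z1[i], z2[i]).

    Illustrative sketch only; not the repository's implementation.
    """
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)                 # (2N, d) embeddings
    z = z / np.linalg.norm(z, axis=1, keepdims=True)     # L2-normalise each row
    sim = z @ z.T / temperature                          # scaled cosine similarities
    np.fill_diagonal(sim, -np.inf)                       # a sample is not its own negative
    # The positive of row i is row i + N (and vice versa).
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log-sum-exp over each row.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), pos] - log_denom
    return float(-log_prob_pos.mean())
```

With perfectly aligned, mutually orthogonal views the loss is lower than when each view is paired with the wrong partner, which is the behaviour the contrastive objective rewards.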

Linear evaluation

You can run linear evaluation of the UA-SSL model on CIFAR-10 with the following command. For semi-supervised results, add --semi and use --semi_percent to set the percentage of labeled data used for linear evaluation; for example, --semi_percent 10 uses 10 percent of the data.

python main_linear.py --nh 10 --dataset cifar10 --lamda1 1 --lamda2 0.08
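The idea behind --semi_percent can be illustrated with a short sketch that draws a reproducible random subset covering a given percentage of the training set. This README does not say whether main_linear.py samples uniformly at random or per class, so uniform sampling here is an assumption, and `semi_subset_indices` is a hypothetical helper name.

```python
import random


def semi_subset_indices(num_samples: int, percent: float, seed: int = 0) -> list[int]:
    """Hypothetical helper: pick `percent`% of indices uniformly at random (reproducible)."""
    if not 0 < percent <= 100:
        raise ValueError("percent must be in (0, 100]")
    k = max(1, round(num_samples * percent / 100))  # number of labeled samples kept
    rng = random.Random(seed)                       # local RNG so the choice is repeatable
    return sorted(rng.sample(range(num_samples), k))
```

For the 50,000-image CIFAR-10 training set, `semi_subset_indices(50000, 10)` keeps 5,000 indices, which could then be passed to something like `torch.utils.data.Subset`.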

Uncertainty evaluation

You can compute several metrics (e.g. accuracy, NLL, ECE, OE, ...) for the model with the following command. You need to specify the paths to the pretrained and linear-evaluation models.

python uncertainty_metric.py --nh 10 --dataset cifar10 --lamda1 1 --lamda2 0.08
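As a reference for what some of these metrics measure, here is a minimal NumPy sketch of accuracy, negative log-likelihood (NLL), and expected calibration error (ECE). It assumes, hypothetically, that the ensemble prediction is the mean of the softmax outputs of the --nh heads and that ECE uses 15 equal-width confidence bins; neither choice is confirmed to match uncertainty_metric.py, `ensemble_metrics` is an illustrative name, and OE is not covered.

```python
import numpy as np


def ensemble_metrics(head_probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> dict:
    """head_probs: (nh, N, C) softmax outputs per head; labels: (N,) integer classes.

    Illustrative sketch; assumes the ensemble averages head probabilities.
    """
    probs = head_probs.mean(axis=0)          # assumed ensembling: average the heads -> (N, C)
    conf = probs.max(axis=1)                 # confidence of the predicted class
    pred = probs.argmax(axis=1)
    correct = (pred == labels).astype(float)

    accuracy = correct.mean()
    eps = 1e-12                              # avoid log(0)
    nll = -np.log(probs[np.arange(len(labels)), labels] + eps).mean()

    # ECE: bin-size-weighted average of |accuracy - confidence| over (lo, hi] bins.
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())

    return {"accuracy": float(accuracy), "nll": float(nll), "ece": float(ece)}
```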

Out-of-distribution detection

You can compute out-of-distribution detection results for CIFAR-10 with the following command. You need to specify the paths to the datasets and to the pretrained and linear-evaluation models.

python main_execute_method.py --nh 10 --dataset cifar10 --lamda1 1 --lamda2 0.08
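A common way to summarize out-of-distribution detection is the AUROC between scores on the in-distribution and OOD test sets. The sketch below uses the maximum softmax probability as the detection score, which is an assumption (a standard baseline) and not necessarily the score main_execute_method.py computes; `max_softmax_score` and `auroc` are hypothetical helper names.

```python
from bisect import bisect_left, bisect_right


def max_softmax_score(probs: list[list[float]]) -> list[float]:
    """Assumed OOD score: higher max-softmax means 'more in-distribution'."""
    return [max(row) for row in probs]


def auroc(in_scores: list[float], out_scores: list[float]) -> float:
    """P(in-distribution score > OOD score), counting ties as 1/2 (Mann-Whitney form)."""
    out_sorted = sorted(out_scores)
    wins = 0.0
    for s in in_scores:
        below = bisect_left(out_sorted, s)            # OOD scores strictly below s
        ties = bisect_right(out_sorted, s) - below    # OOD scores equal to s
        wins += below + 0.5 * ties
    return wins / (len(in_scores) * len(out_sorted))
```

Sorting the OOD scores once and using binary search keeps the cost at O((n + m) log m) rather than comparing every pair.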

To get results under covariate shift (corrupted datasets), add the --c flag to the command.

Acknowledgements

The base SimCLR code is adapted from the following repository:

The out-of-distribution detection code is adapted from the following repository:

The metrics for uncertainty analysis are taken from the following repository:
