KKallidromitis / Contrastive-Neural-Processes

Implementation of Contrastive Neural Processes in PyTorch

Contrastive Neural Processes

PyTorch implementation of "Contrastive Neural Processes for Self-Supervised Learning", accepted as a long oral presentation at ACML 2021.

Implementation Details

This folder contains the code for Contrastive Neural Processes and the baselines. The code has been adapted to the needs of the project; the original sources are cited below.

Folders:

  • Baselines: all baselines, together with the hyperparameters and evaluation metrics used for the baseline experiments
  • ContrNP: code and resources for the ContrNP method
  • Results: location where weights and results are saved
  • Data: location of the datasets. Use data_name_load.py to download and extract the data. (Please note: some of the data-extraction commands are Ubuntu-specific.)

The Baselines folder includes implementations of Tloss [1], CPC [2], TNC [2] and SimCLR [3].
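SimCLR [3] trains its encoder with the NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss. The pure-Python sketch below illustrates that loss for a batch of paired embeddings. It is not taken from the Baselines folder; the function names and the default temperature `tau=0.5` are illustrative assumptions.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent(z1: List[Sequence[float]], z2: List[Sequence[float]], tau: float = 0.5) -> float:
    """NT-Xent loss for N positive pairs (z1[i], z2[i]), averaged over all 2N anchors.

    For each anchor, the other view of the same sample is the positive and the
    remaining 2N - 2 embeddings in the batch act as negatives.
    """
    z = list(z1) + list(z2)
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)  # index of the anchor's positive (its other view)
        # Temperature-scaled similarities to every embedding except the anchor itself.
        logits = [cosine(z[i], z[k]) / tau for k in range(2 * n) if k != i]
        positive = cosine(z[i], z[j]) / tau
        # Cross-entropy of the positive: log-sum-exp over candidates minus the positive logit.
        top = max(logits)  # shift by the max logit for numerical stability
        log_denom = top + math.log(sum(math.exp(s - top) for s in logits))
        total += log_denom - positive
    return total / (2 * n)
```

With a single pair the loss is exactly zero, because there are no negatives to contrast against; the loss only becomes informative as the batch grows, which is why SimCLR-style training favours large batches.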

The npf and utils folders contain the implementation of Neural Processes [4].

  • Main implementation of the contrastive convolutional conditional neural process (ConvCNP): Contrastive-ConvCNP-SSL.ipynb
  • Implementation of the self-supervised ConvCNP: ConvCNP-SSL.ipynb
  • Implementation of the self-supervised conditional neural process (CNP): CNP-SSL.ipynb
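A conditional neural process encodes each context pair, aggregates the encodings into a single representation, and decodes a predictive distribution at each target input. The toy, pure-Python sketch below shows that forward pass. It is not the npf implementation: the fixed weights stand in for trained networks and are hypothetical, and the ConvCNP used in the notebooks replaces the mean pooling shown here with a translation-equivariant convolutional encoder.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical fixed weights standing in for trained networks (illustration only).
ENC_W = [(0.5, -0.2), (0.1, 0.8), (-0.3, 0.4)]  # (x, y) -> 3-dim representation
DEC_MEAN_W = (0.7, 0.2, -0.5, 0.3)              # (x*, r) -> predictive mean
DEC_SCALE_W = (0.1, -0.1, 0.2, 0.05)            # (x*, r) -> raw scale


def encode(x: float, y: float) -> List[float]:
    """Encoder h: one context pair (x_i, y_i) -> representation r_i."""
    return [math.tanh(wx * x + wy * y) for wx, wy in ENC_W]


def aggregate(reps: Sequence[Sequence[float]]) -> List[float]:
    """Mean over context points, so the result does not depend on their order."""
    n = len(reps)
    return [sum(r[k] for r in reps) / n for k in range(len(reps[0]))]


def decode(x_target: float, r: Sequence[float]) -> Tuple[float, float]:
    """Decoder g: (x*, r) -> mean and standard deviation of a Gaussian prediction."""
    features = [x_target, *r]
    mean = sum(w * f for w, f in zip(DEC_MEAN_W, features))
    raw = sum(w * f for w, f in zip(DEC_SCALE_W, features))
    std = 0.1 + math.log1p(math.exp(raw))  # softplus keeps std positive; 0.1 is an arbitrary floor
    return mean, std


def cnp_predict(context: Sequence[Tuple[float, float]],
                targets: Sequence[float]) -> List[Tuple[float, float]]:
    """Encode every context pair, aggregate once, then decode each target input."""
    r = aggregate([encode(x, y) for x, y in context])
    return [decode(xt, r) for xt in targets]


if __name__ == "__main__":
    context = [(-1.0, 0.3), (0.0, 0.9), (1.5, -0.4)]
    for mean, std in cnp_predict(context, [0.5, 2.0]):
        print(f"mean={mean:.3f} std={std:.3f}")
```

Because the aggregation is a mean, shuffling the context set leaves every prediction unchanged, which is the permutation invariance that lets a CNP condition on an unordered set of observations.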

Citing this work

[arXiv] [PMLR] [ACML2021]

@misc{kallidromitis2021contrastive,
      title={Contrastive Neural Processes for Self-Supervised Learning}, 
      author={Konstantinos Kallidromitis and Denis Gudovskiy and Kazuki Kozuka and Ohama Iku and Luca Rigazio},
      journal={arXiv preprint arXiv:2110.13623},
      year={2021}
}

Reproduced Results

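                  |          AFDB           |       IMS Bearing       |         Urban8K
Method            |  Acc  AUPRC  Sil↑  DBI↓ |  Acc  AUPRC  Sil↑  DBI↓ |  Acc  AUPRC  Sil↑  DBI↓
------------------+-------------------------+-------------------------+------------------------
CPC               | 71.6   62.6  0.22  1.74 | 72.4   84.4  0.12  2.20 | 83.3   94.5  0.24  1.64
Tloss             | 74.8   59.8  0.14  2.04 | 73.2   87.6  0.17  1.79 | 81.5   93.8  0.26  1.30
TNC               | 74.5   56.3  0.24  1.44 | 70.3   86.3  0.31  0.94 | 80.7   93.9  0.36  0.72
SimCLR            | 82.3   71.5  0.34  1.49 | 41.5   70.7  0.24  1.47 | 82.8   94.1  0.35  1.13
ContrNP (ours)    | 94.2   89.1  0.36  1.35 | 73.6   89.3  0.38  0.91 | 84.2   95.4  0.42  0.89
Fully supervised  | 98.4   81.6  0.43  0.83 | 86.3   94.8  0.47  0.77 | 99.9   99.9  0.49  0.80

Acc = accuracy; AUPRC = area under the precision-recall curve; Sil = silhouette score (↑ higher is better); DBI = Davies-Bouldin index (↓ lower is better).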
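The Sil and DBI columns measure how well the learned representations cluster. As a hedged reference, the pure-Python sketch below computes both scores from their standard definitions, using Euclidean distance and, for DBI, a cluster's scatter as the mean distance of its members to the centroid. It is not taken from the repository's evaluation code; the function names are hypothetical, and scikit-learn's `silhouette_score` and `davies_bouldin_score` implement the same definitions.

```python
import math
from typing import Dict, List, Sequence

Point = Sequence[float]


def euclid(p: Point, q: Point) -> float:
    """Euclidean distance (written out because math.dist needs Python 3.8+)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def group_by_label(points: Sequence[Point], labels: Sequence[int]) -> Dict[int, List[Point]]:
    """Collect the points belonging to each cluster label."""
    groups: Dict[int, List[Point]] = {}
    for p, lab in zip(points, labels):
        groups.setdefault(lab, []).append(p)
    return groups


def silhouette(points: Sequence[Point], labels: Sequence[int]) -> float:
    """Mean silhouette coefficient in [-1, 1]; requires at least two distinct labels."""
    groups = group_by_label(points, labels)
    scores = []
    for p, lab in zip(points, labels):
        own = groups[lab]
        if len(own) == 1:
            scores.append(0.0)  # convention for singleton clusters
            continue
        # a: mean distance to the other members of p's cluster (p's self-distance is 0).
        a = sum(euclid(p, q) for q in own) / (len(own) - 1)
        # b: mean distance to the members of the closest other cluster.
        b = min(sum(euclid(p, q) for q in members) / len(members)
                for other, members in groups.items() if other != lab)
        scores.append(0.0 if max(a, b) == 0 else (b - a) / max(a, b))
    return sum(scores) / len(scores)


def davies_bouldin(points: Sequence[Point], labels: Sequence[int]) -> float:
    """Davies-Bouldin index (>= 0); assumes no two cluster centroids coincide."""
    groups = group_by_label(points, labels)
    centroids = {lab: [sum(c) / len(ms) for c in zip(*ms)] for lab, ms in groups.items()}
    # Scatter: mean distance of a cluster's members to its centroid.
    scatter = {lab: sum(euclid(p, centroids[lab]) for p in ms) / len(ms)
               for lab, ms in groups.items()}
    # For each cluster, keep its worst (largest) similarity ratio to any other cluster.
    worst = [max((scatter[k] + scatter[m]) / euclid(centroids[k], centroids[m])
                 for m in groups if m != k)
             for k in groups]
    return sum(worst) / len(worst)
```

Compact, well-separated clusters push the silhouette towards 1 and the Davies-Bouldin index towards 0, which is why the table marks Sil with ↑ and DBI with ↓.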

Requirements

python>=3.6.9
skorch==0.8
pytorch>=1.3.1
scikit-image
wfdb

License: MIT License