gavinuhma / neuralsort

Code for "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.

Stochastic Optimization of Sorting Networks via Continuous Relaxations

This repository provides a reference implementation for learning NeuralSort-based models as described in the paper:

Stochastic Optimization of Sorting Networks via Continuous Relaxations
Aditya Grover, Eric Wang, Aaron Zweig and Stefano Ermon.
International Conference on Learning Representations (ICLR), 2019.
Paper: https://openreview.net/pdf?id=H1eSS3CcKX

Requirements

The codebase is implemented in Python 3.7. To install the necessary requirements, run the following command:

pip install -r requirements.txt

Datasets

The scripts for downloading and loading the MNIST and CIFAR10 datasets are included in the datasets_loader folder. They are called automatically the first time one of the training scripts is run.

Options

Learning and inference of differentiable kNN models is handled by the pytorch/run_dknn.py script, which provides the following command-line arguments:

  --k INT                 number of nearest neighbors
  --tau FLOAT             temperature of sorting operator
  --nloglr FLOAT          negative log10 of learning rate
  --method STRING         one of 'deterministic', 'stochastic'
  --dataset STRING        one of 'mnist', 'fashion-mnist', 'cifar10'
  --num_train_queries INT number of queries to evaluate during training
  --num_train_neighbors INT number of neighbors to consider during training
  --num_samples INT       number of samples for stochastic methods
  --num_epochs INT        number of epochs to train
  -resume                 resume from a previously saved model instead of starting a new one
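
As a rough illustration of how these flags fit together, the sketch below rebuilds a subset of them with argparse and converts --nloglr to a learning rate as 10^(-nloglr), following the "negative log10" description above. The parser, its default values, and the helper name learning_rate are assumptions made for this example, not the repository's actual code.

```python
import argparse


def build_parser():
    # Hypothetical re-creation of a few options listed above.
    # Default values are assumptions chosen for illustration only.
    parser = argparse.ArgumentParser(description="dKNN options (sketch)")
    parser.add_argument("--k", type=int, default=9,
                        help="number of nearest neighbors")
    parser.add_argument("--tau", type=float, default=1.0,
                        help="temperature of sorting operator")
    parser.add_argument("--nloglr", type=float, default=3.0,
                        help="negative log10 of learning rate")
    parser.add_argument("--method", default="deterministic",
                        choices=["deterministic", "stochastic"])
    parser.add_argument("-resume", action="store_true")
    return parser


def learning_rate(nloglr):
    # "negative log10 of learning rate" means lr = 10 ** (-nloglr).
    return 10.0 ** (-nloglr)


# Parse an explicit argument list so the sketch runs without a real CLI call.
args = build_parser().parse_args(["--nloglr", "3", "--tau", "64"])
lr = learning_rate(args.nloglr)
print(f"tau={args.tau}, lr={lr}, resume={args.resume}")
```

With --nloglr=3 the learning rate is 0.001; larger values of --nloglr give smaller learning rates.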

Learning and inference of quantile-regression models is handled by the tf/run_median.py script, which provides the following command-line arguments:

  --M INT                 minibatch size
  --n INT                 number of elements to compare at a time
  --l INT                 number of digits in each multi-mnist dataset element
  --tau FLOAT             temperature (either of sinkhorn or neuralsort relaxation)
  --method STRING         one of 'vanilla', 'sinkhorn', 'gumbel_sinkhorn', 'deterministic_neuralsort', 'stochastic_neuralsort'
  --n_s INT               number of samples for stochastic methods
  --num_epochs INT        number of epochs to train
  --lr FLOAT              initial learning rate
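
The role of --tau can be illustrated with a minimal NumPy sketch of the deterministic NeuralSort relaxation from the paper, where row i of the relaxed permutation matrix is softmax(((n + 1 - 2i) s - A_s 1) / tau) and A_s holds the pairwise absolute differences |s_j - s_k|. The function name neuralsort and the toy scores are assumptions for illustration; the repository's TensorFlow and PyTorch implementations may differ in batching and other details.

```python
import numpy as np


def neuralsort(s, tau=1.0):
    """Relaxed permutation matrix for a 1-D score vector s (sketch)."""
    s = np.asarray(s, dtype=float)
    n = s.shape[0]
    # A[j, k] = |s_j - s_k|; B[j] = sum_k |s_j - s_k| (the vector A_s 1).
    A = np.abs(s[:, None] - s[None, :])
    B = A.sum(axis=1)
    # Row coefficients (n + 1 - 2i) for i = 1..n.
    scaling = n + 1 - 2 * np.arange(1, n + 1)
    logits = (scaling[:, None] * s[None, :] - B[None, :]) / tau
    # Numerically stable row-wise softmax.
    logits -= logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)


s = np.array([2.0, 5.0, 1.0])
P_hat = neuralsort(s, tau=0.1)

# Each row is a distribution over the inputs; row 0 concentrates on the largest.
soft_sorted = P_hat @ s                      # close to [5, 2, 1] for small tau
soft_median = P_hat[(len(s) - 1) // 2] @ s   # middle row gives a soft median
print(soft_sorted, soft_median)
```

As tau decreases, the rows approach one-hot vectors and P_hat @ s approaches the scores sorted in descending order. Larger values of tau give smoother, less peaked rows.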

Learning and inference of sorting models is handled by the tf/run_sort.py script, which provides the following command-line arguments:

  --M INT                 minibatch size
  --n INT                 number of elements to compare at a time
  --l INT                 number of digits in each multi-mnist dataset element
  --tau FLOAT             temperature (either of sinkhorn or neuralsort relaxation)
  --method STRING         one of 'vanilla', 'sinkhorn', 'gumbel_sinkhorn', 'deterministic_neuralsort', 'stochastic_neuralsort'
  --n_s INT               number of samples for stochastic methods
  --num_epochs INT        number of epochs to train
  --lr FLOAT              initial learning rate
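
For the stochastic methods, --n_s sets how many samples are drawn per input. One standard way to draw random orderings, and the reparameterization the paper builds on, is to add independent Gumbel(0, 1) noise to the log-scores and sort the result: the descending order of the perturbed values is a sample from the Plackett-Luce distribution with those scores. The sketch below assumes NumPy; the function name and toy scores are hypothetical, and this is not the repository's implementation.

```python
import numpy as np


def sample_perturbed_scores(log_s, n_s, rng):
    """Return n_s Gumbel-perturbed copies of the log-scores, shape (n_s, n)."""
    log_s = np.asarray(log_s, dtype=float)
    g = rng.gumbel(loc=0.0, scale=1.0, size=(n_s, log_s.shape[0]))
    return log_s[None, :] + g


rng = np.random.default_rng(0)
scores = np.array([8.0, 1.0, 1.0])  # toy positive scores (assumption)
z = sample_perturbed_scores(np.log(scores), n_s=5, rng=rng)

# Sorting each perturbed row in descending order gives one sampled ordering.
perms = np.argsort(-z, axis=1)
print(perms)
```

Each row of perms is one sampled ordering of the input indices. Feeding the perturbed values z to a relaxed sorting operator instead of argsort keeps the sampling step differentiable with respect to the scores.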

Examples

Training a dKNN model to classify CIFAR10 images

cd pytorch
python run_dknn.py --k=9 --tau=64 --nloglr=3 --method=deterministic --dataset=cifar10

Training a quantile regression model to predict the median of sets of nine 5-digit numbers

cd tf
python run_median.py --M=100 --n=9 --l=5 --method=deterministic_neuralsort

Training a sorting model to sort sets of five 4-digit numbers

cd tf
python run_sort.py --M=100 --n=5 --l=4 --method=deterministic_neuralsort

Citing

If you find NeuralSort useful in your research, please consider citing the following paper:

@inproceedings{
grover2018stochastic,
title={Stochastic Optimization of Sorting Networks via Continuous Relaxations},
author={Aditya Grover and Eric Wang and Aaron Zweig and Stefano Ermon},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=H1eSS3CcKX},
}
