zuenko / neural-ode

Experiments with Neural Ordinary Differential Equations on image and text classification tasks


For image classification we use a ResNet model on the MNIST and CIFAR-10 datasets; for text classification we use a VdCNN model on the Ag-News dataset.
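A scalar toy example shows how the two model families are connected. A residual block computes h + f(h), which is exactly one explicit Euler step of size 1 for the ODE dh/dt = f(h). An ODE block instead integrates that same f over an interval, so it keeps the parameter count of a single block. The layer f below (a tanh unit with made-up weights w and b) and the helper names are hypothetical. This is plain Python, not the repository's PyTorch code.

```python
import math

def f(h, w=0.5, b=0.1):
    """Hypothetical one-unit layer; also read as the dynamics dh/dt = f(h)."""
    return math.tanh(w * h + b)

def residual_block(h):
    """One residual block: identity shortcut plus the layer output."""
    return h + f(h)

def euler_odeint(h0, t1=1.0, steps=1):
    """Integrate dh/dt = f(h) from t=0 to t=t1 with fixed-step explicit Euler."""
    dt = t1 / steps
    h = h0
    for _ in range(steps):
        h = h + dt * f(h)
    return h

h0 = 1.0
# One residual block gives the same result as one Euler step of size 1.
print(residual_block(h0), euler_odeint(h0, t1=1.0, steps=1))

# Applying the SAME block 6 times (shared weights, unlike ResNet(6)) equals 6 Euler steps up to t=6.
h = h0
for _ in range(6):
    h = residual_block(h)
print(h, euler_odeint(h0, t1=6.0, steps=6))

# An ODE solver can instead take many small steps over an interval, still with one set of weights.
print(euler_odeint(h0, t1=1.0, steps=100))
```

Note that ResNet(6) in the tables below has six blocks with separate weights. The weight-shared loop only illustrates the Euler-step view.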

Requirements

  • PyTorch >= 1.0
  • NumPy

Spiral experiment

Run the ODE model:

PYTHONPATH=. python ./experiments/spiral-torch.py
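The README does not describe the dynamics used in spiral-torch.py. As a hypothetical illustration of a spiral-shaped target, consider the 2-D linear system dy/dt = A y. When A has complex eigenvalues with a negative real part, its trajectories spiral inward. The sketch integrates such a system with the classical fourth-order Runge-Kutta method. A, the initial state, the time span and the step count are all assumptions.

```python
import math

# Hypothetical damped-rotation dynamics: eigenvalues -0.1 +/- 2i, so trajectories spiral inward.
A = [[-0.1, 2.0],
     [-2.0, -0.1]]

def dynamics(y):
    """Right-hand side dy/dt = A @ y for a 2-D state y."""
    return [A[0][0] * y[0] + A[0][1] * y[1],
            A[1][0] * y[0] + A[1][1] * y[1]]

def rk4_step(y, dt):
    """One classical fourth-order Runge-Kutta (RK4) step."""
    k1 = dynamics(y)
    k2 = dynamics([y[i] + 0.5 * dt * k1[i] for i in range(2)])
    k3 = dynamics([y[i] + 0.5 * dt * k2[i] for i in range(2)])
    k4 = dynamics([y[i] + dt * k3[i] for i in range(2)])
    return [y[i] + dt / 6.0 * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i]) for i in range(2)]

def spiral(y0=(2.0, 0.0), t1=10.0, steps=1000):
    """Return the trajectory [y(0), y(dt), ..., y(t1)] as a list of 2-D points."""
    dt = t1 / steps
    ys = [list(y0)]
    for _ in range(steps):
        ys.append(rk4_step(ys[-1], dt))
    return ys

traj = spiral()
radius = [math.hypot(x, y) for x, y in traj]
print(len(traj), radius[0], radius[-1])
```

For this A the exact radius decays as 2 * exp(-0.1 t), which gives a simple check on the integrator at t = 10.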

Result

[Figure: spiral]

MNIST classification

Run ResNet with 6 blocks:

PYTHONPATH=. python ./experiments/train.py  --data mnist --save ./log_resnet6_mnist --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 6

Run ResNet with 1 block:

PYTHONPATH=. python ./experiments/train.py  --data mnist --save ./log_resnet1_mnist --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 1

Run OdeNet with an explicit Runge-Kutta solver and tolerance 1e-2:

PYTHONPATH=. python ./experiments/train.py  --data mnist --save ./log_odenet_mnist --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --solver runge_kutta  --tol 1e-2 --use_ode

* Alternatively, use the explicit Euler solver: --solver euler
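The two --solver options differ in how they step through time. This README does not show the repository's solver code, so the sketch below illustrates the idea under stated assumptions and is not its implementation. It places fixed-step explicit Euler next to an adaptive embedded Runge-Kutta pair. The pair is Bogacki-Shampine (orders 3 and 2), a swapped-in choice, and a tolerance such as 1e-2 decides whether each step is accepted. The function names, the initial step h=0.1 and the mixed absolute/relative error test are hypothetical. Counting right-hand-side evaluations shows why tighter tolerances make the ODE models slower.

```python
import math

def euler(f, y0, t0, t1, steps):
    """Fixed-step explicit Euler for a scalar ODE; returns (y(t1), number of f evaluations)."""
    h = (t1 - t0) / steps
    t, y = t0, y0
    for _ in range(steps):
        y = y + h * f(t, y)
        t += h
    return y, steps

def bogacki_shampine(f, y0, t0, t1, tol, h=0.1):
    """Adaptive RK23 (Bogacki-Shampine) for a scalar ODE; returns (y(t1), evaluations)."""
    t, y, nfe = t0, y0, 0
    while t < t1:
        h = min(h, t1 - t)                      # do not step past t1
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + 3 * h / 4, y + 3 * h / 4 * k2)
        y_new = y + h * (2 / 9 * k1 + 1 / 3 * k2 + 4 / 9 * k3)                  # 3rd order
        k4 = f(t + h, y_new)
        z_new = y + h * (7 / 24 * k1 + 1 / 4 * k2 + 1 / 3 * k3 + 1 / 8 * k4)   # 2nd order
        nfe += 4
        err = abs(y_new - z_new)                # local error estimate
        scale = tol * (1.0 + abs(y))            # mixed abs/rel tolerance (an assumption)
        if err <= scale:
            t, y = t + h, y_new                 # accept the step
        # Grow or shrink the step with a safety factor, clamped to [0.2, 5].
        factor = 0.9 * (scale / err) ** (1 / 3) if err > 0 else 5.0
        h *= min(5.0, max(0.2, factor))
    return y, nfe

decay = lambda t, y: -y                          # test problem dy/dt = -y, exact y(1) = e^-1
y_fix, n_fix = euler(decay, 1.0, 0.0, 1.0, steps=10)
y_ad, n_ad = bogacki_shampine(decay, 1.0, 0.0, 1.0, tol=1e-2)
print(abs(y_fix - math.exp(-1)), n_fix)
print(abs(y_ad - math.exp(-1)), n_ad)
```

Lowering tol forces smaller steps and more evaluations of f. In an OdeNet each evaluation is a full forward pass through the block, which is consistent with the large per-epoch times in the tables.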

Results

[Figures: Test Accuracy (mnist_score), Loss (mnist_loss)]

Model                  Test error, %   # parameters   Time (s/epoch)
ResNet(6)              0.34            577 K          13.18
ResNet(1)              0.37            207 K          11.21
OdeNet (Runge-Kutta)   0.45            207 K          254.42
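The MNIST parameter counts can be decomposed under two assumptions that the README does not state: every residual block has the same number of weights, and the layers outside the blocks are identical across variants. Going from 1 to 6 blocks then adds 577 K - 207 K = 370 K, or about 74 K per block, which leaves roughly 133 K for the remaining layers. The table's counts are rounded, so these figures are approximate. OdeNet presumably matches ResNet(1) at 207 K because it reuses one block's weights at every solver step. The helper split_params is hypothetical.

```python
def split_params(p_many, p_one, n_many):
    """Split parameter totals into (per_block, rest), assuming identical blocks.

    p_many: parameters of the model with n_many blocks; p_one: of the 1-block model.
    """
    per_block = (p_many - p_one) / (n_many - 1)
    rest = p_one - per_block
    return per_block, rest

# Thousands of parameters, from the MNIST table (ResNet(6) vs ResNet(1)).
per_block, rest = split_params(577, 207, 6)
print(per_block, rest)
```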

CIFAR-10 classification

Run ResNet with 6 blocks:

PYTHONPATH=. python ./experiments/train.py  --data cifar --save ./log_resnet6_cifar --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 6

Run ResNet with 1 block:

PYTHONPATH=. python ./experiments/train.py  --data cifar --save ./log_resnet1_cifar --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --num_res 1

Run OdeNet with an explicit Runge-Kutta solver and tolerance 1e-2 (this can take a very long time):

PYTHONPATH=. python ./experiments/train.py  --data cifar --save ./log_odenet_cifar --save_every 50 \
--log_every 1 --optimizer sgd --lr 0.1 --use_ode --solver runge_kutta  --tol 1e-2 

Results

[Figures: Test Accuracy (cifar_score), Loss (cifar_loss)]

Model                  Accuracy, %     # parameters   Time (s/epoch)
ResNet(6)              86.7            577 K          12.25
ResNet(1)              84.19           207 K          9.84
OdeNet (Runge-Kutta)   84.85           207 K          1860.31
OdeNet (Euler)         84.62           207 K          159.02
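The last three rows of the CIFAR-10 table quantify the solver trade-off. The short computation below uses only the numbers in that table and compares each variant with ResNet(1), which has the same 207 K parameters.

```python
# (accuracy %, seconds per epoch) taken from the CIFAR-10 table
results = {
    "ResNet(1)": (84.19, 9.84),
    "OdeNet (Runge-Kutta)": (84.85, 1860.31),
    "OdeNet (Euler)": (84.62, 159.02),
}

base_acc, base_time = results["ResNet(1)"]
for name, (acc, secs) in results.items():
    slowdown = secs / base_time
    print(f"{name:22s} {acc - base_acc:+.2f} pts  {slowdown:7.1f}x time")

rk_acc, rk_time = results["OdeNet (Runge-Kutta)"]
eu_acc, eu_time = results["OdeNet (Euler)"]
print(f"Euler is {rk_time / eu_time:.1f}x faster than Runge-Kutta "
      f"and {rk_acc - eu_acc:.2f} points less accurate")
```

On these numbers, Euler is about 11.7x faster per epoch than the adaptive Runge-Kutta run and loses 0.23 accuracy points. Both ODE variants remain slower than ResNet(1) by a factor of more than 16.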

Text classification

Download the Ag-News data and create the class list:

mkdir -p .data/ag_news
cd .data/ag_news
wget https://raw.githubusercontent.com/tothanhtung0205/VDCNN/master/ag_news_csv/test.csv
wget https://raw.githubusercontent.com/tothanhtung0205/VDCNN/master/ag_news_csv/train.csv
printf 'World\nSports\nBusiness\nSci/Tech\n' > classes.txt
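In the standard ag_news_csv layout, each row of train.csv and test.csv holds a 1-based class index followed by a quoted title and description, with no header row. classes.txt lists the class names in index order. The README does not show how vdcnn.py reads these files, so the loader below is a hypothetical sketch that uses only the standard csv module. The function name load_ag_news and the sample rows are made up for illustration.

```python
import csv
import io

def load_ag_news(csv_file, classes_file):
    """Yield (label_name, text) pairs from an ag_news_csv-style file.

    Assumes rows of the form: "<1-based class index>","<title>","<description>".
    """
    class_names = [line.strip() for line in classes_file if line.strip()]
    for row in csv.reader(csv_file):
        label = int(row[0]) - 1                              # 1-based index -> 0-based
        text = " ".join(part.strip() for part in row[1:])    # title + description
        yield class_names[label], text

# In-memory demo with made-up rows; for the real data pass
# open(".data/ag_news/train.csv", newline="") and open(".data/ag_news/classes.txt").
classes = io.StringIO("World\nSports\nBusiness\nSci/Tech\n")
rows = io.StringIO('"3","Stocks rally","Shares rose on Monday."\n'
                   '"2","Final score","The home team won 2-1."\n')
for label, text in load_ag_news(rows, classes):
    print(label, "|", text)
```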

Run VdCNN with 6 blocks:

PYTHONPATH='.' python ./experiments/texts/vdcnn.py --batch_size 256 --max_epo 20 --save vdcnn6  

Run VdCNN with 1 block:

PYTHONPATH='.' python ./experiments/texts/vdcnn.py --batch_size 256 --max_epo 20 --save vdcnn1 \
--num_blocks 1
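In its original design, VDCNN reads characters rather than words. Each document therefore becomes a fixed-length sequence of character ids before the embedding and convolution blocks. The README does not say whether experiments/texts/vdcnn.py uses exactly this scheme. The sketch below rests on several hypothetical choices: the alphabet, ids 0 and 1 reserved for padding and unknown characters, and a max_len of 1024.

```python
# Hypothetical alphabet: lowercase letters, digits, punctuation and space.
ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/|_#$%^&*~`+=<>()[]{} "
PAD, UNK = 0, 1
CHAR_TO_ID = {ch: i + 2 for i, ch in enumerate(ALPHABET)}  # ids 0 and 1 are reserved

def encode(text, max_len=1024):
    """Map a document to exactly max_len character ids (lowercase, truncate, then pad)."""
    ids = [CHAR_TO_ID.get(ch, UNK) for ch in text.lower()[:max_len]]
    return ids + [PAD] * (max_len - len(ids))

print(encode("Hi, AG News!", max_len=16))
```

Fixing the length lets documents be batched into one tensor. The --num_blocks flag then only changes how many convolutional blocks process that tensor.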

Run OdeNet with the explicit Euler solver and tolerance 1e-2 (this can take a very long time):

PYTHONPATH='.' python ./experiments/texts/vdcnn.py --batch_size 256 --max_epo 20 --save vdcnn_ode \
--use_ode --solver euler  --tol 1e-2 

Results

[Figures: Test Accuracy (text_score), Loss (text_loss)]

Model                  Accuracy, %     # parameters   Time (s/epoch)
VdCNN(6)               88.46           287 K          311
VdCNN(1)               87.75           162 K          172
OdeNet (Euler)         84.21           162 K          4874
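The text commands above pass --max_epo 20. If that flag sets the number of training epochs and the per-epoch times in the Ag-News table stay constant (both are assumptions), the total training time of each run follows directly:

```python
EPOCHS = 20  # value of --max_epo in the commands above (assumed to be the epoch count)

# Seconds per epoch, from the Ag-News table
seconds_per_epoch = {"VdCNN(6)": 311, "VdCNN(1)": 172, "OdeNet (Euler)": 4874}

for name, secs in seconds_per_epoch.items():
    total = secs * EPOCHS
    print(f"{name:15s} {total:6d} s  ~ {total / 3600:.1f} h")
```

Under these assumptions, the Euler OdeNet needs roughly 27 hours for 20 epochs, compared with about one hour for VdCNN(1).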

References

Original implementation

About


License: MIT License


Language: Python 100.0%