GuillaumeMougeot / tensorflow_resnet_cifar10

A TensorFlow implementation of the ResNets for the CIFAR10 dataset, following the original paper.

tensorflow_resnet_cifar10

This repository contains an implementation of the original ResNet paper, written with TensorFlow 2 and Keras and trained on the CIFAR10 dataset.

Requirements

Before running the training, make sure the following Python libraries are installed:

  • tensorflow 2.3.1 or higher
  • matplotlib 3.3.2 or higher
  • numpy 1.18.5 or higher
  • scikit-image 0.16.2 or higher
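The minimum versions listed above can be checked before training. The following is a minimal sketch, not part of the repository: the helper names (`version_tuple`, `meets_minimum`, `check_requirements`) are hypothetical, and it assumes the libraries were installed under their usual PyPI names.

```python
import re
from importlib import metadata

# Minimum versions from the list above, keyed by PyPI distribution name.
MINIMUM_VERSIONS = {
    "tensorflow": "2.3.1",
    "matplotlib": "3.3.2",
    "numpy": "1.18.5",
    "scikit-image": "0.16.2",
}


def version_tuple(version):
    """Turn '2.3.1' (or '2.4.0rc1') into a comparable tuple of ints."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad so that '2.4' and '2.4.0' compare as equal.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(installed, minimum):
    """True if the installed version is at least the minimum version."""
    return version_tuple(installed) >= version_tuple(minimum)


def check_requirements(minimums=MINIMUM_VERSIONS):
    """Return a list of human-readable problems; empty if all is fine."""
    problems = []
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed (need >= {minimum})")
            continue
        if not meets_minimum(installed, minimum):
            problems.append(f"{name} {installed} is older than {minimum}")
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print(problem)
```

Running the script prints one line per missing or outdated library and nothing when every requirement is met.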

Run

Before training, the CIFAR10 dataset needs to be converted into tfrecord files. To do so, run the following command, replacing /path/to/cifar10 with the location of the dataset:

python prepare_data.py --data_path='/path/to/cifar10'

To train all the ResNets, run:

python train.py

To train only a particular ResNet or to change the training hyperparameters, edit the global variables defined at the beginning of train.py.

Logs

During training, this implementation regularly stores:

  • the keras model
  • the tensorboard logs
  • images of the model predictions on a batch of test samples

Performance

The results below come from a single run of each model, reporting the lowest test error reached during training. With model selection, the test errors would likely improve.

Name        # layers  # params  Test err. (paper)  Test err. (this impl.)
ResNet20    20        0.27M     8.75%              8.68%
ResNet32    32        0.46M     7.51%              7.69%
ResNet44    44        0.66M     7.17%              7.31%
ResNet56    56        0.85M     6.97%              7.04%
ResNet110   110       1.7M      6.43%              6.75%
ResNet1202  1202      19.4M     7.93%              7.33%
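The "# params" column can be reproduced with a little arithmetic. The sketch below counts the trainable parameters of a CIFAR-style ResNet of depth 6n + 2 as described in the original paper. It rests on assumptions that may not match this implementation exactly: 3x3 convolutions without bias, batch normalization with a trainable scale and offset after every convolution, and parameter-free shortcut connections. The function name is hypothetical.

```python
def cifar_resnet_param_count(depth, num_classes=10):
    """Count trainable parameters of a CIFAR-style ResNet of the given depth.

    Assumes (not confirmed for this repository): bias-free 3x3 convolutions,
    batch norm with scale and offset, and parameter-free shortcuts.
    """
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError("depth must be of the form 6n + 2 with n >= 1")
    n = (depth - 2) // 6

    def conv_bn(c_in, c_out):
        # 3x3 convolution weights plus batch-norm scale and offset.
        return 3 * 3 * c_in * c_out + 2 * c_out

    total = conv_bn(3, 16)  # initial convolution on the RGB input
    c_in = 16
    for c_out in (16, 32, 64):  # three stages of 2n convolutions each
        for _ in range(2 * n):
            total += conv_bn(c_in, c_out)
            c_in = c_out
    total += c_in * num_classes + num_classes  # final dense layer
    return total


if __name__ == "__main__":
    for depth in (20, 32, 44, 56, 110, 1202):
        millions = cifar_resnet_param_count(depth) / 1e6
        print(f"ResNet{depth}: {millions:.2f}M parameters")
```

Under these assumptions the counts round to the values in the table, for example 269,722 parameters (0.27M) for ResNet20.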

Acknowledgement

This code is inspired by the two following repositories:

License: MIT License

