ranjiewwen / TF_cifar10

A CIFAR-10 classification project implemented in TensorFlow. It covers the full pipeline (training, prediction, and visualization), keeps each module of the project independent, and is easy to extend.

Cifar10 Tensorflow Project

Get Started

  • environment: tensorflow-gpu 1.8 + CUDA 9.0.
  • dataset: first download the train and test sets from the Kaggle competition "CIFAR-10 - Object Recognition in Images".
  • then run the utils/get_data_list.py and utils/get_dataset_mean.py scripts to generate train.txt and val.txt.
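The exact output format of utils/get_data_list.py is not described here. As a sketch only, assuming Kaggle's trainLabels.csv (header id,label) and a hypothetical list format of one "<image path> <class index>" pair per line, a shuffled train/validation split could be produced like this. The helper names make_lists and write_list, the 10% validation fraction, and the train/<id>.png layout are all assumptions.

```python
import csv
import io
import random

# CIFAR-10 class names; assumed index order for the hypothetical list files.
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def make_lists(label_csv, image_dir="train", val_fraction=0.1, seed=0):
    """Hypothetical helper: split labelled images into train/val list lines.

    label_csv: file-like object with an 'id,label' header (as in trainLabels.csv).
    Returns (train_lines, val_lines); each line is '<path> <class index>'.
    """
    reader = csv.DictReader(label_csv)
    lines = [f"{image_dir}/{row['id']}.png {CLASSES.index(row['label'])}"
             for row in reader]
    random.Random(seed).shuffle(lines)  # seeded shuffle so the split is reproducible
    n_val = int(len(lines) * val_fraction)
    return lines[n_val:], lines[:n_val]


def write_list(path, lines):
    """Write one entry per line, e.g. to train.txt or val.txt."""
    with open(path, "w") as f:
        f.write("\n".join(lines) + "\n")


if __name__ == "__main__":
    # Small in-memory sample standing in for the real CSV file.
    sample = io.StringIO("id,label\n1,frog\n2,truck\n3,truck\n4,deer\n5,automobile\n"
                         "6,automobile\n7,bird\n8,horse\n9,ship\n10,cat\n")
    train, val = make_lists(sample, val_fraction=0.2)
    print(len(train), len(val))  # 8 2
```

A seeded shuffle keeps the split identical across runs, so validation accuracy stays comparable between experiments.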

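What utils/get_dataset_mean.py computes is not spelled out. A common choice is one mean per RGB channel over every training pixel, later subtracted from inputs during loading; the plain-Python sketch below shows that computation, assuming images are nested lists of (R, G, B) tuples. The function name channel_mean is hypothetical, and the real script may instead compute something else, such as a per-pixel mean image.

```python
def channel_mean(images):
    """Hypothetical helper: mean of each RGB channel over all pixels of all images.

    images: iterable of images; each image is a list of rows, and each row is a
    list of (r, g, b) tuples.
    """
    totals = [0.0, 0.0, 0.0]
    count = 0
    for image in images:
        for row in image:
            for pixel in row:
                for c in range(3):
                    totals[c] += pixel[c]
                count += 1
    if count == 0:
        raise ValueError("no pixels given")
    return [t / count for t in totals]


if __name__ == "__main__":
    img_a = [[(0, 0, 0), (255, 255, 255)]]    # 1x2 toy image
    img_b = [[(100, 50, 0)], [(100, 50, 0)]]  # 2x1 toy image
    print(channel_mean([img_a, img_b]))  # [113.75, 88.75, 63.75]
```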
How to Learn this project

  • Step 1: adjust the training parameters in config/cifar10_config.json.
  • Step 2: see how the dataset is loaded before training in src/datasets/cifar10_dataloader.py.
  • Step 3: see how the network is written in src/models/layers and src/models/simple_model.py; from these you can easily create your own model.
  • Step 4: complete the training script tools/train_cifar10.py. Along the way you will implement the loss function and the metric function in src/loss/cross_entropy.py and src/metrics/acc_metric.py. The script first builds the graph and then runs it in a session. During training it saves model checkpoints to experiments/checkpoint and writes loss and accuracy summaries for TensorBoard to experiments/summary.
  • Step 5: run the training script tools/train_cifar10.py.
  • Step 6: once you have a trained model, use demo/predict.py to classify an image and get its class name.
  • Step 7: use demo/visual.py to inspect extra information, such as the weights or feature maps.
  • Other: tools/utils.py shows how to use some of the helper functions.
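The loss and metric from step four are not reproduced here. In TensorFlow 1.x they would typically be built from graph ops such as tf.nn.sparse_softmax_cross_entropy_with_logits; the framework-free sketch below only illustrates the underlying math, assuming raw logits and integer class labels. The function names are hypothetical, not the project's API.

```python
import math


def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy over a batch.

    logits: list of per-example score lists; labels: list of int class indices.
    """
    total = 0.0
    for scores, label in zip(logits, labels):
        m = max(scores)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[label]  # equals -log(softmax(scores)[label])
    return total / len(labels)


def accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the label."""
    correct = sum(
        1 for scores, label in zip(logits, labels)
        if max(range(len(scores)), key=scores.__getitem__) == label
    )
    return correct / len(labels)


if __name__ == "__main__":
    logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 1.5]]
    labels = [0, 1]
    print(round(softmax_cross_entropy(logits, labels), 4))
    print(accuracy(logits, labels))  # 0.5
```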

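For step six, the prediction demo turns the trained model's output into a class name. Its code is not shown here; assuming the model emits 10 logits whose indices follow the alphabetical CIFAR-10 label order, the final lookup could be sketched as follows. The CLASSES ordering and the predict_class name are assumptions, and the order must match whatever indices were used when train.txt was generated.

```python
import math

# Assumed label order: must match the indices used when the list files were built.
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def softmax(scores):
    """Convert logits to probabilities."""
    m = max(scores)  # shift for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_class(logits):
    """Return (class name, probability) of the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CLASSES[best], probs[best]


if __name__ == "__main__":
    name, prob = predict_class([0.1, 0.2, 0.0, 3.0, 0.5, 1.0, 0.0, 0.0, 0.3, 0.1])
    print(name)  # cat
    print(round(prob, 3))
```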
The optimization process

  • Detailed information is available there.
  • tools/train_cifar10.py applies several training tricks: learning-rate adjustment, data augmentation, dropout, weight decay, and stacked 3×3 convolutions. Running it shows how these raise accuracy from 70%+ to 91%+, whereas adding depth through conv4_1 and conv4_2 does not improve validation accuracy.
  • tools/train_cifar10_v2.py adds batch normalization. In these runs it made training more unstable and may not improve validation accuracy, while stacking 3×3 convolutions improved validation accuracy noticeably.
  • tools/finetune_cifar10.py first loads ImageNet-pretrained weights and then fine-tunes ResNet-50.
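The exact learning-rate and weight-decay settings used by the training scripts are not given here. As an illustration only, a step (piecewise-constant) learning-rate schedule and an L2 weight-decay term, two common forms of the "adjust lr" and "weight decay" tricks, can be sketched as below. TensorFlow 1.x provides built-ins such as tf.train.piecewise_constant and tf.nn.l2_loss for these; this plain-Python version just makes the arithmetic explicit, and its boundary convention (step < boundary) may differ slightly from the built-in. The boundaries, rates, and decay coefficient are made-up example values.

```python
def piecewise_lr(step, boundaries, values):
    """Step schedule: values[i] applies while step < boundaries[i].

    len(values) must equal len(boundaries) + 1; the last value applies afterwards.
    """
    if len(values) != len(boundaries) + 1:
        raise ValueError("need exactly one more value than boundaries")
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]


def l2_penalty(weights, weight_decay):
    """Weight-decay term added to the loss: weight_decay * sum(w^2) / 2."""
    return weight_decay * sum(w * w for w in weights) / 2.0


if __name__ == "__main__":
    # Hypothetical schedule: divide the rate by 10 at steps 10000 and 20000.
    for step in (0, 9999, 10000, 25000):
        print(step, piecewise_lr(step, [10000, 20000], [0.1, 0.01, 0.001]))
    print(l2_penalty([1.0, -2.0, 0.5], 5e-4))
```

The penalty is added to the cross-entropy loss before computing gradients, which shrinks weights toward zero at a rate set by weight_decay.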

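Which augmentation operations train_cifar10.py uses is not stated here. A widely used recipe for CIFAR-10 is zero-padding, a random crop back to the original size, and a random horizontal flip, applied to training images only. The sketch below shows these on a single-channel image stored as a list of rows; the function names and the 4-pixel padding default are assumptions.

```python
import random


def pad(image, p):
    """Zero-pad a 2-D image (list of rows) by p pixels on every side."""
    width = len(image[0]) + 2 * p
    rows = [[0] * p + list(row) + [0] * p for row in image]
    return [[0] * width for _ in range(p)] + rows + [[0] * width for _ in range(p)]


def random_crop(image, size, rng):
    """Crop a size x size window at a random position (randint is inclusive)."""
    top = rng.randint(0, len(image) - size)
    left = rng.randint(0, len(image[0]) - size)
    return [row[left:left + size] for row in image[top:top + size]]


def random_flip(image, rng):
    """Mirror the image left-right with probability 0.5."""
    if rng.random() < 0.5:
        return [row[::-1] for row in image]
    return [list(row) for row in image]


def augment(image, p=4, rng=random):
    """Pad, randomly crop back to the original size, then randomly flip."""
    size = len(image)  # assumes a square image, e.g. 32x32
    return random_flip(random_crop(pad(image, p), size, rng), rng)


if __name__ == "__main__":
    rng = random.Random(0)
    img = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4x4 toy image
    out = augment(img, p=1, rng=rng)
    print(len(out), len(out[0]))  # 4 4
```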
Languages

Language:Python 100.0%