walsvid / VGG-TensorFlow

VGGNet with TensorFlow


VGG-TensorFlow

A re-implementation of VGGNet with TensorFlow.

The purpose of this project

This project focuses on the structure of the code base. The current code trains correctly. Next, I will continue to improve the readability of the code and the structure of the project, and will also work on improving the accuracy of VGG network training.

Current status

This repository uses object-oriented programming wherever possible to give the machine learning code a clearer structure. So far I have implemented a data loader class, a configuration-processing class, a base network class, and a VGG16 network class. Training can also be visualized with TensorBoard.

  • data loader
  • config file
  • base network class
  • vgg16 network class
  • tensorboard
  • test code
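The split between a base network class and a VGG16 subclass could look roughly like the framework-free sketch below. BaseNetwork, VGG16, layer_specs, num_weight_layers, and output_spatial_size are hypothetical names, not the repository's API, and the 4096-unit fc6/fc7 layers are assumed from the original ImageNet VGG16; the repository's CIFAR-10 model may differ. The sketch counts 13 convolutional plus 3 fully connected weight layers (hence "16") and shows that five 2x2 max-pools reduce a 32x32 CIFAR-10 input to 1x1.

```python
from typing import List, Tuple, Union

# Hypothetical sketch; names are illustrative, not the repository's API.
# VGG16 ("configuration D"): integers are 3x3 conv output channels,
# "M" is a 2x2 max-pool with stride 2.
VGG16_CFG: List[Union[int, str]] = [
    64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
    512, 512, 512, "M", 512, 512, 512, "M",
]


class BaseNetwork:
    """Shared behaviour; subclasses only describe their layers."""

    def layer_specs(self) -> List[Tuple[str, str, int]]:
        raise NotImplementedError

    def num_weight_layers(self) -> int:
        # Pooling layers have no weights, so they are not counted.
        return sum(1 for _, kind, _ in self.layer_specs() if kind != "maxpool")


class VGG16(BaseNetwork):
    def __init__(self, num_classes: int = 10):
        self.num_classes = num_classes

    def layer_specs(self) -> List[Tuple[str, str, int]]:
        specs: List[Tuple[str, str, int]] = []
        block, idx, channels = 1, 1, 3
        for v in VGG16_CFG:
            if v == "M":
                specs.append((f"pool{block}", "maxpool", channels))
                block, idx = block + 1, 1
            else:
                channels = int(v)
                specs.append((f"conv{block}_{idx}", "conv3x3", channels))
                idx += 1
        # Assumed ImageNet-style head; fc8 produces one logit per class.
        specs += [("fc6", "fc", 4096), ("fc7", "fc", 4096), ("fc8", "fc", self.num_classes)]
        return specs

    def output_spatial_size(self, input_size: int = 32) -> int:
        # Padded 3x3 convs keep height/width; each max-pool halves them.
        size = input_size
        for v in VGG16_CFG:
            if v == "M":
                size //= 2
        return size
```

Keeping shared logic (here, counting weight layers) in the base class means a new architecture only has to describe its own layers.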

Usage

Dataset

This repository is designed to support data in a variety of formats. To add a new dataset, subclass the Dataloader class and implement the loading logic for your dataset.
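Since the actual Dataloader interface is not documented here, the sketch below uses a hypothetical stand-in base class with an abstract load() method and shared batching in batches(). RandomDataset is a made-up example subclass that returns random 32x32 RGB images; only the pattern (dataset-specific loading in the subclass, shared logic in the base class) is meant to carry over.

```python
from abc import ABC, abstractmethod
from typing import Iterator, Tuple

import numpy as np


class Dataloader(ABC):
    """Hypothetical stand-in for the repository's Dataloader base class."""

    def __init__(self, batch_size: int):
        self.batch_size = batch_size

    @abstractmethod
    def load(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return (images, labels) for the whole dataset."""

    def batches(self) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        # Shared batching logic lives in the base class.
        images, labels = self.load()
        for start in range(0, len(images), self.batch_size):
            end = start + self.batch_size
            yield images[start:end], labels[start:end]


class RandomDataset(Dataloader):
    """Example subclass: only the dataset-specific load() is implemented."""

    def __init__(self, batch_size: int, num_samples: int = 10):
        super().__init__(batch_size)
        self.num_samples = num_samples

    def load(self) -> Tuple[np.ndarray, np.ndarray]:
        rng = np.random.default_rng(0)
        images = rng.random((self.num_samples, 32, 32, 3), dtype=np.float32)
        labels = rng.integers(0, 10, size=self.num_samples)
        return images, labels
```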

CIFAR-10 Binary Dataset

Currently, only the CIFAR-10 dataset in binary format is supported.
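For reference, the CIFAR-10 binary version stores each example as one label byte followed by 3072 pixel bytes: a 32x32 red plane, then green, then blue, each in row-major order; each data_batch_N.bin file holds 10000 such records. The sketch below decodes that format with NumPy. parse_cifar10_records and read_cifar10_file are hypothetical names, not the repository's loader.

```python
import numpy as np

# CIFAR-10 binary record: 1 label byte, then 3072 pixel bytes
# (1024 red, 1024 green, 1024 blue; each a 32x32 row-major plane).
LABEL_BYTES = 1
IMAGE_BYTES = 32 * 32 * 3
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES


def parse_cifar10_records(raw: bytes):
    """Decode raw bytes into (images [N, 32, 32, 3] uint8, labels [N] int64)."""
    data = np.frombuffer(raw, dtype=np.uint8)
    if data.size % RECORD_BYTES != 0:
        raise ValueError("input is not a whole number of CIFAR-10 records")
    records = data.reshape(-1, RECORD_BYTES)
    labels = records[:, 0].astype(np.int64)
    # Pixels are stored channel-first; transpose to height x width x channel.
    images = records[:, LABEL_BYTES:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels


def read_cifar10_file(path: str):
    """Read one batch file, e.g. data_batch_1.bin."""
    with open(path, "rb") as f:
        return parse_cifar10_records(f.read())
```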

To ensure training runs correctly, organize the data directory as follows:

data
├── datasets
│   └── cifar-10-batches-bin
└── imagenet_models
    └── vgg16.npy
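A quick check of this layout before training can catch a misplaced file early. The helper below is a hypothetical convenience, not part of the repository; the two required paths are taken directly from the tree above.

```python
from pathlib import Path

# Expected layout relative to the repository root, from the tree above.
REQUIRED_PATHS = [
    Path("data/datasets/cifar-10-batches-bin"),
    Path("data/imagenet_models/vgg16.npy"),
]


def missing_data_paths(root="."):
    """Return the required paths that do not exist under root."""
    base = Path(root)
    return [str(p) for p in REQUIRED_PATHS if not (base / p).exists()]
```

Running missing_data_paths() from the repository root should return an empty list when the data is in place.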

Training

All data and training settings can be changed in the configuration file, which uses the more readable YAML format.

For details, see experiments/configs/vgg16.yaml.
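The Test section refers to configuration keys such as TEST.MODEL_PATH, which suggests nested YAML sections accessed with dotted names. One common way to get that access in Python is a small dict wrapper, sketched here under the assumption that the file is parsed with PyYAML's yaml.safe_load. Config and load_config are hypothetical names, and any keys other than TEST.MODEL_PATH used with it are made up.

```python
class Config(dict):
    """Hypothetical nested dict with attribute access, e.g. cfg.TEST.MODEL_PATH."""

    def __init__(self, data=None):
        super().__init__()
        for key, value in (data or {}).items():
            # Wrap nested sections so attribute access works at every level.
            self[key] = Config(value) if isinstance(value, dict) else value

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


def load_config(path):
    """Load a YAML file such as experiments/configs/vgg16.yaml (requires PyYAML)."""
    import yaml  # imported here so Config itself works without PyYAML installed

    with open(path) as f:
        return Config(yaml.safe_load(f))
```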

Simply run this script:

python ./training.py

Training inside a Python virtual environment is recommended.

Test

Run the following command to see details of all arguments:

python ./test.py --help

To test an image, run:

python ./test.py --im <your_image_path>
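The command-line interface of test.py is documented here only through --im. As an illustration, an option like it can be declared with the standard argparse module; build_parser and the image_path destination are hypothetical, and the real script may accept more arguments (python ./test.py --help lists them).

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the --im option of test.py."""
    parser = argparse.ArgumentParser(
        description="Classify a single image with a trained VGG16 model."
    )
    parser.add_argument("--im", dest="image_path", required=True,
                        help="path to the image to classify")
    return parser
```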

You can change the checkpoint path in experiments/configs/vgg16.yaml; see TEST.MODEL_PATH for details.
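TEST.MODEL_PATH is presumably a dotted path into the nested YAML config, that is, a MODEL_PATH key inside a TEST section. Assuming the YAML is parsed into plain nested dicts, the hypothetical helpers get_option and set_option below show how such a key maps onto that structure, for example to point the test script at a different checkpoint. The checkpoint file names used with them are invented.

```python
from functools import reduce


def get_option(cfg: dict, dotted_key: str):
    """Look up a nested option such as "TEST.MODEL_PATH" in a parsed YAML dict."""
    return reduce(lambda node, key: node[key], dotted_key.split("."), cfg)


def set_option(cfg: dict, dotted_key: str, value) -> None:
    """Override a nested option in place, e.g. TEST.MODEL_PATH."""
    *parents, last = dotted_key.split(".")
    # Walk down to the section that owns the final key.
    node = reduce(lambda n, key: n[key], parents, cfg)
    node[last] = value
```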

Images

loss and accuracy


Test image. Please resize the image to the CIFAR-10 size (32x32 pixels).
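CIFAR-10 images are 32x32 pixels with 3 color channels. The sketch below is a dependency-free nearest-neighbour resize in NumPy; an image library (for example Pillow's Image.resize) usually gives better quality. resize_nearest is a hypothetical helper, and any further preprocessing test.py may apply is not covered here.

```python
import numpy as np


def resize_nearest(image: np.ndarray, size: int = 32) -> np.ndarray:
    """Nearest-neighbour resize of an H x W (x C) image to size x size."""
    h, w = image.shape[:2]
    # Map each output row/column back to a source row/column index.
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return image[rows[:, None], cols[None, :]]
```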

Test result for the image above: it is classified as automobile with a probability of 99.96%.
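The reported probability is presumably a softmax over the network's 10 output logits. Assuming that, the sketch below converts one logit vector into a class name and probability using the standard CIFAR-10 label order; top_prediction is a hypothetical helper, not the repository's code.

```python
import numpy as np

# CIFAR-10 class names in label-index order (0-9).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def top_prediction(logits):
    """Convert 10 logits into (class name, probability) via softmax."""
    logits = np.asarray(logits, dtype=np.float64)
    exp = np.exp(logits - logits.max())  # subtract the max for numerical stability
    probs = exp / exp.sum()
    idx = int(np.argmax(probs))
    return CIFAR10_CLASSES[idx], float(probs[idx])
```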

graph

Note that the graph currently shown in TensorBoard is not tidy. With older TensorFlow versions, re-entering a variable_scope when loading pre-trained weights creates a duplicate scope in the graph. For a cleaner graph, use TensorFlow >= 1.6.0, or apply the following temporary hack:

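import numpy as np
import tensorflow as tf

def load_with_skip(self, data_path, session, skip_layer):
    # vgg16.npy holds a pickled dict: layer name -> [weights, biases]
    data_dict = np.load(data_path, encoding='latin1', allow_pickle=True).item()  # type: dict
    for key in data_dict:
        if key not in skip_layer:
            # Re-enter the layer's existing variable scope (self.scope maps layer names to scopes)
            with tf.variable_scope(self.scope[key], reuse=True) as scope:
                # Re-open the original name scope instead of creating a duplicate one
                with tf.name_scope(scope.original_name_scope):
                    for subkey, data in zip(('weights', 'biases'), data_dict[key]):
                        session.run(tf.get_variable(subkey).assign(data))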

If you use TensorFlow >= 1.6.0, see the solution in the current vgg16.py file, which passes auxiliary_name_scope=False by default.
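The loading code assumes vgg16.npy holds a pickled dict mapping each layer name to [weights, biases], with skip_layer listing layers that should not be restored. The NumPy-only sketch below isolates that selection step. Since NumPy 1.16.3, np.load refuses pickled objects unless allow_pickle=True is passed. select_pretrained_layers is a hypothetical name, and skipping fc8 is only an example (its ImageNet output size would not match the 10 CIFAR-10 classes).

```python
import numpy as np


def select_pretrained_layers(data_path, skip_layer):
    """Load a pickled {layer: [weights, biases]} .npy file, dropping skipped layers."""
    # allow_pickle=True is required for dict payloads on NumPy >= 1.16.3.
    data_dict = np.load(data_path, encoding="latin1", allow_pickle=True).item()
    return {name: params for name, params in data_dict.items() if name not in skip_layer}
```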

The VGG-16 network graph, visualized with TensorBoard.

Release

v0.0.5

About

License: MIT License