abdelrahman-gaber / Classification-AutoEncoder

The aim of this project is to train an autoencoder and use the trained weights as initialization to improve classification accuracy on the cifar10 dataset.

Classification-AutoEncoder

The aim of this project is to train an autoencoder network, then use its trained weights as initialization to improve classification accuracy on the cifar10 dataset. This is a form of transfer learning in which the pretrained model comes from unsupervised training of an autoencoder. The final classification model achieved an accuracy of 87.33%.

How to run

Download and prepare the cifar10 dataset:

  cd data/
  ./gen_cifar10.sh

This script downloads the cifar10 dataset and splits it into folders as required by the Keras flow_from_directory() function. The autoencoder and classification scripts can work with any dataset; just remember to set the input image size, number of classes, etc. to match your data.
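
As a rough illustration of the folder layout that flow_from_directory() expects (one sub-folder per class inside each split folder), here is a minimal sketch in Python. It is not the code of gen_cifar10.sh; the function name arrange_by_class, the (image_path, class_name) input format, and the split name used in the comments are assumptions made for the example.

```python
import shutil
from pathlib import Path


def arrange_by_class(samples, out_dir, split):
    """Copy images into out_dir/<split>/<class_name>/, one folder per class.

    samples: iterable of (image_path, class_name) pairs (hypothetical format).
    split:   name of the split folder, e.g. "train" or "test" (assumed names).
    Returns the path of the split folder that was filled.
    """
    split_dir = Path(out_dir) / split
    for image_path, class_name in samples:
        class_dir = split_dir / class_name
        # Create the class folder on first use; later images reuse it.
        class_dir.mkdir(parents=True, exist_ok=True)
        target = class_dir / Path(image_path).name
        shutil.copy2(str(image_path), str(target))
    return split_dir
```

A folder arranged this way could then be passed to flow_from_directory() together with a target_size that matches your input image size.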

Train Auto-encoder network

To train the autoencoder, run python Train_Autoencoder.py, or download the trained model from here.

The autoencoder network in this project is trained to serve as an initialization for the classification network, not to produce the best possible reconstructions. If you want better output images, consider removing the fully connected layer.
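
The idea of using autoencoder weights as an initialization can be sketched as follows. This is a minimal illustration with tf.keras, not the architecture used in Train_Autoencoder.py or Train_Classifier.py: the layer sizes, the layer names (enc_conv1, enc_conv2), and all function names are assumptions, and the fully connected layer mentioned above is left out for brevity.

```python
import numpy as np
from tensorflow.keras import layers, models

IMG_SHAPE = (32, 32, 3)  # cifar10 image size
NUM_CLASSES = 10         # cifar10 has 10 classes
ENCODER_LAYER_NAMES = ["enc_conv1", "enc_conv2"]  # hypothetical layer names


def encode(x):
    """Apply the encoder layers shared (by structure) between both models."""
    x = layers.Conv2D(16, 3, activation="relu", padding="same", name="enc_conv1")(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(32, 3, activation="relu", padding="same", name="enc_conv2")(x)
    x = layers.MaxPooling2D(2)(x)
    return x


def build_autoencoder():
    """Encoder followed by a small decoder that reconstructs the input image."""
    inputs = layers.Input(shape=IMG_SHAPE)
    x = encode(inputs)
    x = layers.UpSampling2D(2)(x)
    x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)
    x = layers.UpSampling2D(2)(x)
    outputs = layers.Conv2D(3, 3, activation="sigmoid", padding="same")(x)
    return models.Model(inputs, outputs)


def build_classifier():
    """Same encoder structure, followed by a softmax classification head."""
    inputs = layers.Input(shape=IMG_SHAPE)
    x = encode(inputs)
    x = layers.Flatten()(x)
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
    return models.Model(inputs, outputs)


def copy_encoder_weights(autoencoder, classifier):
    """Initialize the classifier's encoder layers from the autoencoder."""
    for name in ENCODER_LAYER_NAMES:
        weights = autoencoder.get_layer(name).get_weights()
        classifier.get_layer(name).set_weights(weights)


if __name__ == "__main__":
    autoencoder = build_autoencoder()
    # Training the autoencoder on images would happen here (omitted).
    classifier = build_classifier()
    copy_encoder_weights(autoencoder, classifier)
    classifier.summary()
```

After the copy, the classifier is trained as usual; only its starting point differs from a randomly initialized network.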

Train Classification network

To start training the classifier, run python Train_Classifier.py, or download the trained model from here.

Classification model evaluation

Finally, run python Evaluate_Classifier_cifar10.py to evaluate the trained classification model.

The model accuracy is 87.33%, and the confusion matrix is shown below.

[Figure: confusion matrix of the classification model on cifar10]

The evaluation script is written for the cifar10 dataset, but it can easily be modified to work with other datasets.
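
For readers adapting the evaluation to another dataset, accuracy and a confusion matrix can be computed from true and predicted labels roughly as below. This is a plain-Python sketch under the assumption that labels are integer class indices; it is not the code of Evaluate_Classifier_cifar10.py, and the function names and sample labels are made up for the example.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Return a num_classes x num_classes matrix.

    Rows are true classes, columns are predicted classes.
    """
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, pred_label in zip(y_true, y_pred):
        matrix[true_label][pred_label] += 1
    return matrix


def accuracy(matrix):
    """Fraction of samples on the diagonal (correctly classified)."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


# Toy labels for three classes (illustrative data, not cifar10 results).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
for row in cm:
    print(row)
print("accuracy:", accuracy(cm))
```

Off-diagonal entries show which classes are confused with each other, which a single accuracy number does not reveal.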

Requirements

  1. Python 3.5 or higher
  2. Keras with tensorflow version 1.9 or higher (tested with 1.9, 1.12, and 1.15)

Have fun ^_^
