
Training and Lightweighting Cookbook in JAX/FLAX

Introduction

  • This project attempts to build a neural network training and lightweighting cookbook covering three kinds of lightweighting solutions, i.e., knowledge distillation, filter pruning, and quantization.
  • This will be quite a long-term project, so please be patient and keep watching this repository 🤗.

Requirements

  • jax
  • flax
  • tensorflow (to download the CIFAR datasets)

Key features

  • Basic training framework in JAX/FLAX (see the sketch below)
  • Knowledge distillation
  • Filter pruning
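
The sketch below shows what a minimal training step in a JAX/FLAX framework like this one can look like. Everything in it (the `TinyCNN` stand-in model, the optimizer choice, and the function names) is an illustrative assumption, not the repository's actual API.

  import jax
  import jax.numpy as jnp
  import flax.linen as nn
  import optax

  class TinyCNN(nn.Module):
      """Stand-in for ResNet-56; purely illustrative."""
      @nn.compact
      def __call__(self, x):
          x = nn.Conv(features=16, kernel_size=(3, 3))(x)
          x = nn.relu(x)
          x = x.reshape((x.shape[0], -1))   # flatten
          return nn.Dense(features=10)(x)   # CIFAR-10 logits

  model = TinyCNN()
  optimizer = optax.sgd(learning_rate=0.1, momentum=0.9)

  @jax.jit  # compile the whole step with XLA
  def train_step(params, opt_state, images, labels):
      def loss_fn(p):
          logits = model.apply({"params": p}, images)
          return optax.softmax_cross_entropy_with_integer_labels(
              logits, labels).mean()
      loss, grads = jax.value_and_grad(loss_fn)(params)
      updates, opt_state = optimizer.update(grads, opt_state, params)
      return optax.apply_updates(params, updates), opt_state, loss

  # One dummy step on random CIFAR-shaped data:
  rng = jax.random.PRNGKey(0)
  params = model.init(rng, jnp.ones((1, 32, 32, 3)))["params"]
  opt_state = optimizer.init(params)
  images = jax.random.normal(rng, (8, 32, 32, 3))
  labels = jnp.zeros((8,), dtype=jnp.int32)
  params, opt_state, loss = train_step(params, opt_state, images, labels)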

How to use

  1. Move into the codebase directory.
  2. Train and evaluate the model with the commands below.
  # ResNet-56 on CIFAR10
  python train.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --train_path ~/test
  python test.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --trained_param pretrained/res56_c10

Experimental comparison with other common deep learning libraries, i.e., TensorFlow 2 and PyTorch

  • Hardware: GTX 1080

  • TensorFlow implementation [link]

  • PyTorch implementation [link]

  • To measure training time alone, excluding model and data preparation, time is accumulated from the second epoch through the last epoch.

  • Note that accuracy on the CIFAR datasets has quite a large variance,
    so focus on the other metric, i.e., training time.

  • As the table below shows, JAX and TF are much faster than PyTorch because of JIT compilation (a toy sketch follows the table).

Library   Accuracy (%)   Time (m)
JAX       93.98          54
TF        93.91          53
PyTorch   93.80          69
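
As a toy illustration of where that speedup comes from (this is not a benchmark from this repository): `jax.jit` traces a Python function once, compiles it with XLA, and then reuses the compiled kernel on every later call with the same input shapes.

  import jax
  import jax.numpy as jnp

  def step(w, x):
      return jnp.tanh(x @ w).sum()

  fast_step = jax.jit(step)

  w = jnp.ones((512, 512))
  x = jnp.ones((128, 512))
  fast_step(w, x).block_until_ready()  # first call: trace + XLA compile
  fast_step(w, x).block_until_ready()  # later calls: compiled kernel only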


TO DO

  • Basic training and test framework

    • Data provider in JAX
    • Naive training framework
    • Monitoring by Tensorboard
    • Profiling addons
    • Enlarge the model zoo, including HuggingFace pre-trained models
  • Knowledge distillation framework

    • Basic framework
    • Off-line distillation (see the loss sketch after this list)
    • On-line distillation
    • Self distillation
    • Enlarge the distillation algorithm zoo
  • Filter pruning framework

    • Basic framework
    • Criterion-based pruning
    • Search-based pruning
    • Enlarge the filter pruning algorithm zoo
  • Quantization framework

    • Basic framework
    • Quantization-aware training
    • Post-training quantization
    • Enlarge quantization algorithm zoo
  • Tools for handy usage
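
As a preview of the off-line distillation item above, here is a minimal sketch of the classic soft-target loss (Hinton et al., 2015) in JAX. The function name, the default temperature and alpha, and the interface are illustrative assumptions rather than the framework's actual API.

  import jax
  import jax.numpy as jnp
  import optax

  def distillation_loss(student_logits, teacher_logits, labels,
                        temperature=4.0, alpha=0.9):
      """Illustrative KD loss: soft-target term + hard-label cross-entropy."""
      t = temperature
      # Soften both distributions with the temperature; scale by t**2
      # to keep gradient magnitudes comparable across temperatures.
      soft_targets = jax.nn.softmax(teacher_logits / t)
      log_student = jax.nn.log_softmax(student_logits / t)
      # Soft-target cross-entropy (equal to the KL term up to a constant).
      kd = -(soft_targets * log_student).sum(axis=-1).mean() * (t * t)
      # Hard-label cross-entropy anchors the student to ground truth.
      ce = optax.softmax_cross_entropy_with_integer_labels(
          student_logits, labels).mean()
      return alpha * kd + (1.0 - alpha) * ce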
