wnov / nn-compression-simple

a simple NN compression tool using ADMM


nn-compression-simple

A simple NN compression tool using ADMM.

Supports weight pruning, weight quantization, and custom compression operators.

Usage

Import the compression tool you need:

from admm import YOUR_COMPRESSION_TOOL

Weight Pruning Example

Import:

from admm import ADMM_pruning

Instantiating the class:

admm = ADMM_pruning(model, update_interval=args.admm_update_interval, l1=args.admm_l1)

After loss.backward(), call:

admm.loss_update(loss)

To mask gradients while fine-tuning, call:

admm.grad_mask()
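
As a rough sketch, the calls above might sit in a PyTorch-style training step as follows. The function name train_step and the finetune flag are hypothetical, and placing grad_mask() before optimizer.step() is an assumption; only the loss_update() call after loss.backward() comes from the description above.

```python
def train_step(model, optimizer, loss_fn, inputs, targets, admm, finetune=False):
    """Run one optimization step with the ADMM hooks in the order described above.

    `admm` is expected to expose loss_update() and grad_mask(); everything else
    follows the usual PyTorch conventions (zero_grad, backward, step).
    """
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    admm.loss_update(loss)   # called after loss.backward(), as stated above
    if finetune:
        admm.grad_mask()     # assumption: mask gradients before the optimizer step
    optimizer.step()
    return loss
```

Keeping the hooks in one helper makes it harder to forget loss_update() when the same loop is reused for pruning and fine-tuning.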

Call admm.apply_projW() before each model evaluation and admm.restoreW() after it to evaluate the pruned model, for example:

admm.apply_projW()
# Evaluate your model here
admm.restoreW()
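
One way to make sure restoreW() always runs, even if evaluation raises an exception, is to wrap the pair in a context manager. This is a minimal sketch; the helper name projected_weights is hypothetical and not part of the library, and it only assumes that the object has the apply_projW() and restoreW() methods shown above.

```python
from contextlib import contextmanager

@contextmanager
def projected_weights(admm):
    """Temporarily switch to the projected (pruned) weights for evaluation."""
    admm.apply_projW()
    try:
        yield admm
    finally:
        admm.restoreW()  # runs even if the evaluation code raises

# Intended usage (evaluate and val_loader are placeholders for your own code):
#
#     with projected_weights(admm):
#         accuracy = evaluate(model, val_loader)
```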

When the pruning iterations are finished, or before you start fine-tuning, call:

admm.apply_projW()

to prune the model completely.

Custom compression operator

Implement a class that inherits from the ADMM class and define your compression operator in its update() function. In brief, update() must project the weight parameters (or any other parameters you want to compress) onto your constraint set.
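
To illustrate where such a projection fits, here is a sketch of the textbook scaled-form ADMM update of the auxiliary and dual variables, with the projection passed in as a function. It uses plain Python lists rather than tensors, the names admm_aux_update and keep_largest are hypothetical, and it is not taken from this library's code; whether the library's update() is organized this way is an assumption.

```python
def admm_aux_update(w, u, project):
    """One scaled-form ADMM update of the auxiliary variable z and dual variable u.

    w: current weights (list of floats)
    u: scaled dual variable (same length as w)
    project: function mapping a list of floats onto the constraint set
    """
    z = project([wi + ui for wi, ui in zip(w, u)])         # z = Proj(w + u)
    u_new = [ui + wi - zi for ui, wi, zi in zip(u, w, z)]  # u = u + w - z
    return z, u_new

def keep_largest(values, k=2):
    """Example projection: keep the k largest-magnitude entries, zero the rest."""
    ranked = sorted(range(len(values)), key=lambda i: abs(values[i]), reverse=True)
    keep = set(ranked[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(values)]

z, u = admm_aux_update([0.9, -0.1, 0.4, 0.05], [0.0] * 4, keep_largest)
print(z)  # [0.9, 0.0, 0.4, 0.0]
print(u)  # [0.0, -0.1, 0.0, 0.05]
```

Swapping keep_largest for another projection changes the constraint without touching the update itself, which is the role a custom update() plays.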

For example, to prune and quantize at the same time, call both update functions one after the other, which projects the weights onto the intersection of the two constraint sets.
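
The idea of chaining two projections can be sketched on a plain list of weights. The functions prune_smallest and quantize below are hypothetical stand-ins, not the library's update functions, and uniform rounding to a fixed step is just one assumed form of quantization. In this sketch the result satisfies both constraints because rounding keeps pruned zeros at zero.

```python
def prune_smallest(weights, sparsity):
    """Zero out the given fraction of weights with the smallest magnitude."""
    n_prune = int(len(weights) * sparsity)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]

def quantize(weights, step):
    """Round every weight to the nearest multiple of `step` (zeros stay zero)."""
    return [round(w / step) * step for w in weights]

weights = [0.9, -0.1, 0.4, 0.05, -0.62, 0.3]
# Prune first, then quantize the surviving weights.
compressed = quantize(prune_smallest(weights, sparsity=0.5), step=0.25)
print(compressed)  # [1.0, 0.0, 0.5, 0.0, -0.5, 0.0]
```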

How it works



Languages

Language:Python 100.0%