MPolaris / onnx2tflite

Tool for onnx->keras or onnx->tflite conversion. If this tool is useful to you, please star it.

ONNX->Keras and ONNX->TFLite tools

Welcome

If you have good ideas, feel free to start a discussion or open a PR.

How to use

pip install -r requirements.txt
# basic usage
python converter.py --weights "./your_model.onnx"

# specify the save path
python converter.py --weights "./your_model.onnx" --outpath "./save_path"

# save tflite model
python converter.py --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite"

# save keras and tflite model
python converter.py --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" "keras"

# cut off the model: redefine its inputs and outputs (intermediate layers are supported)
python converter.py --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" --input-node-names "layer_inputname" --output-node-names "layer_outname1" "layer_outname2"

# quantize the model weights only
python converter.py --weights "./your_model.onnx" --formats "tflite" --weigthquant

# int8-quantize the model, including its inputs and outputs
## recommended
python converter.py --weights "./your_model.onnx" --formats "tflite" --int8 --imgroot "./dataset_path" --int8mean 0 0 0 --int8std 255 255 255
## generate random calibration data instead of reading image files
python converter.py --weights "./your_model.onnx" --formats "tflite" --int8
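With --int8, calibration data is run through the model so that activation ranges can be estimated. The NumPy sketch below shows what a calibration feed could look like. It assumes the tool normalizes each image as (pixel - int8mean) / int8std per channel (the usual meaning of these flags; with --int8mean 0 0 0 --int8std 255 255 255 this maps pixels to [0, 1]) and that the random fallback draws values uniformly from [0, 1). Neither detail is confirmed by this README, and the function names are hypothetical. Images are assumed to be already decoded into HxWx3 uint8 arrays.

```python
import numpy as np

def normalize_image(img_hwc, mean, std):
    """Per-channel (x - mean) / std on an HxWx3 uint8 image, returned as float32."""
    img = img_hwc.astype(np.float32)
    mean = np.asarray(mean, dtype=np.float32).reshape(1, 1, -1)
    std = np.asarray(std, dtype=np.float32).reshape(1, 1, -1)
    return (img - mean) / std

def representative_dataset(images, mean, std):
    """Yield one sample at a time as a list holding a 1xHxWx3 (NHWC) float32 array."""
    for img in images:
        yield [normalize_image(img, mean, std)[np.newaxis, ...]]

def random_representative_dataset(num_samples, shape=(1, 224, 224, 3), seed=0):
    """Hypothetical fallback when no image folder is given: uniform [0, 1) samples."""
    rng = np.random.default_rng(seed)
    for _ in range(num_samples):
        yield [rng.random(shape, dtype=np.float32)]
```

If you drive TensorFlow's converter yourself, a generator like this is the shape of data expected by `tf.lite.TFLiteConverter.representative_dataset`, e.g. `converter.representative_dataset = lambda: representative_dataset(images, [0, 0, 0], [255, 255, 255])`.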

Features

  • High Consistency. Compared with the ONNX outputs, the average error is less than 1e-5 per element.
  • Faster. The output TensorFlow Lite model is 30% faster than one produced by onnx_tf.
  • Auto Channel Alignment. Automatically converts the PyTorch layout (NCHW) to the TensorFlow layout (NHWC).
  • Deployment Support. Supports exporting quantized models, including fp16 and uint8 quantization.
  • Code Friendly. I've been trying to keep the code structure simple and clear.
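Channel alignment involves two kinds of permutation: activations move from PyTorch's NCHW layout to TensorFlow's NHWC layout, and convolution kernels move from PyTorch's (out, in, kH, kW) layout to the (kH, kW, in, out) layout used by Keras Conv2D. A minimal NumPy sketch of both (illustrative helper names, not the project's dimension_utils code):

```python
import numpy as np

def nchw_to_nhwc(x):
    """Move the channel axis from position 1 to the end: (N, C, H, W) -> (N, H, W, C)."""
    return np.transpose(x, (0, 2, 3, 1))

def nhwc_to_nchw(x):
    """Inverse permutation: (N, H, W, C) -> (N, C, H, W)."""
    return np.transpose(x, (0, 3, 1, 2))

def conv_weight_oihw_to_hwio(w):
    """PyTorch Conv2d weight (out, in, kH, kW) -> Keras Conv2D kernel (kH, kW, in, out)."""
    return np.transpose(w, (2, 3, 1, 0))

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = nchw_to_nhwc(x)
print(y.shape)                                               # (2, 4, 5, 3)
print(np.array_equal(nhwc_to_nchw(y), x))                    # True
print(conv_weight_oihw_to_hwio(np.zeros((64, 3, 5, 7))).shape)  # (5, 7, 3, 64)
```

Both functions only reorder axes; element `y[n, h, w, c]` is the same value as `x[n, c, h, w]`.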

Pytorch -> ONNX -> Tensorflow-Keras -> Tensorflow-Lite

  • From torchvision to tensorflow-lite

import torch
import torchvision
_input = torch.randn(1, 3, 224, 224)
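model = torchvision.models.mobilenet_v2(weights="IMAGENET1K_V1")  # torchvision < 0.13: pretrained=True
model.eval()  # export in inference mode (BatchNorm/Dropout behave as at test time)
# the default export settings are fine
torch.onnx.export(model, _input, './mobilenetV2.onnx', opset_version=11)  # or opset_version=13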

from converter import onnx_converter
onnx_converter(
    onnx_model_path = "./mobilenetV2.onnx",
    need_simplify = True,
    output_path = "./",
    target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
    weight_quant = False,
    int8_model = False,
    int8_mean = None,
    int8_std = None,
    image_root = None
)
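The converter is described as running its own output check automatically. If you also want to compare results by hand, the sketch below computes the per-element error between one ONNX output and the matching TFLite output. It assumes that 4-D ONNX outputs are NCHW while the TFLite ones are NHWC; `compare_outputs` and the default tolerance are illustrative choices, not part of the tool.

```python
import numpy as np

def compare_outputs(onnx_out, tflite_out, tol=1e-5):
    """Return (mean_abs_error, max_abs_error, ok) for one pair of outputs.

    4-D ONNX outputs are assumed to be NCHW and are transposed to NHWC
    before comparison; outputs of other ranks are compared as-is."""
    onnx_out = np.asarray(onnx_out, dtype=np.float32)
    tflite_out = np.asarray(tflite_out, dtype=np.float32)
    if onnx_out.ndim == 4:
        onnx_out = np.transpose(onnx_out, (0, 2, 3, 1))
    if onnx_out.shape != tflite_out.shape:
        raise ValueError(f"shape mismatch: {onnx_out.shape} vs {tflite_out.shape}")
    diff = np.abs(onnx_out - tflite_out)
    mean_err = float(diff.mean())
    return mean_err, float(diff.max()), mean_err < tol
```

The arrays themselves can come from `onnxruntime.InferenceSession(path).run(None, {input_name: x_nchw})` on the ONNX side and from `tf.lite.Interpreter(model_path=...)` (`allocate_tensors()`, `set_tensor()`, `invoke()`, `get_tensor()`) on the TFLite side, feeding the same input transposed to NHWC.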
  • From custom pytorch model to tensorflow-lite-int8

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x):
        return self.conv(x)

model = MyModel()
model.load_state_dict(torch.load("model_checkpoint.pth", map_location="cpu"))
model.eval()  # export in inference mode

_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, _input, './mymodel.onnx', opset_version=11)# or opset_version=13

from converter import onnx_converter
onnx_converter(
    onnx_model_path = "./mymodel.onnx",
    need_simplify = True,
    output_path = "./",
    target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
    weight_quant = False,
    int8_model = True, # enable int8 quantization
    int8_mean = [123.675, 116.28, 103.53], # mean used in image preprocessing
    int8_std = [58.395, 57.12, 57.375], # std used in image preprocessing
    image_root = "./dataset/train" # folder of training images used for calibration
)
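(The int8_mean and int8_std values above are the common ImageNet statistics (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225) multiplied by 255.) Because int8_model=True quantizes the inputs and outputs as well, a float input has to be mapped to integers with the input tensor's scale and zero point before inference, and integer outputs mapped back afterwards. TFLite reports these as the "quantization" entry of `Interpreter.get_input_details()` / `get_output_details()`. A NumPy sketch of the standard affine mapping q = round(x / scale) + zero_point; the scale and zero point below are made up, and you should check the dtype the interpreter reports (int8 or uint8) rather than assume it.

```python
import numpy as np

def quantize(x, scale, zero_point, dtype=np.int8):
    """Affine quantization: round(x / scale) + zero_point, clipped to the dtype's range."""
    info = np.iinfo(dtype)
    q = np.round(np.asarray(x, dtype=np.float32) / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)

def dequantize(q, scale, zero_point):
    """Inverse mapping: x is approximately (q - zero_point) * scale."""
    return (q.astype(np.float32) - zero_point) * scale

# Made-up parameters mapping roughly [0, 1) onto the int8 range.
scale, zero_point = 1.0 / 256.0, -128
q = quantize(np.array([0.0, 0.5, 1.0]), scale, zero_point)
print(q.tolist())  # [-128, 0, 127]  (1.0 saturates at the int8 maximum)
print(dequantize(q, scale, zero_point).tolist())
```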

Validated models


Add operators yourself

When you encounter an unsupported operator, you can add it yourself or open an issue.
Implementing a new operator parser is simple if you follow the steps below.
Step 0: Select the corresponding layer code file in the layers folder, such as activations_layers.py for 'HardSigmoid'.
Step 1: Open the file and add a parser class for the operator:

# All operators are registered through the OPERATOR registry
# (tf and OPERATOR are assumed to be imported at the top of the layer file).
# The registered name must be the ONNX operator name.
@OPERATOR.register_operator("HardSigmoid")
class TFHardSigmoid():
    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
        '''
        :param tensor_grap: dict, key is node name, value is tensorflow-keras node output tensor.
        :param node_weights: dict, key is node name, value is static data such as weight/bias/constant; weights usually need to be transformed with dimension_utils.tensor_NCD_to_NDC_format.
        :param node_inputs: List[str], names of the node's inputs; each name refers to an entry in either tensor_grap or node_weights.
        :param node_attribute: dict, key is attribute name, such as 'axis' or 'perm'; the value type varies (List[int], int, float, ...). Note that 'axis' values must be adjusted from NCHW to NHWC with dimension_utils.channel_to_last_dimension or dimension_utils.shape_NCD_to_NDC_format.
        '''
        super().__init__()
        self.alpha = node_attribute.get("alpha", 0.2)
        self.beta = node_attribute.get("beta", 0.5)

    def __call__(self, inputs):
        return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1)

Step 2: Make sure it runs without errors.
Step 3: Convert the model to tflite without any quantization.
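Registering a parser with @OPERATOR.register_operator("HardSigmoid") amounts to a lookup table from ONNX op names to parser classes that the converter consults for each node. The sketch below is a hypothetical, stripped-down stand-in (ToyRegistry and the NumPy-based NPHardSigmoid are illustrative, not the project's code). It also shows a quick way to check a new operator's math against the ONNX HardSigmoid definition, y = max(0, min(1, alpha * x + beta)), before converting a whole model.

```python
import numpy as np

class ToyRegistry:
    """Hypothetical minimal registry mapping ONNX op names to parser classes."""
    def __init__(self):
        self._ops = {}

    def register_operator(self, name):
        def decorator(cls):
            if name in self._ops:
                raise KeyError(f"operator {name!r} already registered")
            self._ops[name] = cls
            return cls
        return decorator

    def get(self, name):
        if name not in self._ops:
            raise NotImplementedError(f"unsupported ONNX operator: {name}")
        return self._ops[name]

OPERATOR = ToyRegistry()  # stand-in for the project's OPERATOR, for illustration only

@OPERATOR.register_operator("HardSigmoid")
class NPHardSigmoid:
    # Same attribute defaults as ONNX HardSigmoid: alpha=0.2, beta=0.5.
    def __init__(self, node_attribute):
        self.alpha = node_attribute.get("alpha", 0.2)
        self.beta = node_attribute.get("beta", 0.5)

    def __call__(self, x):
        return np.clip(self.alpha * x + self.beta, 0.0, 1.0)

def onnx_hardsigmoid_reference(x, alpha=0.2, beta=0.5):
    """Direct transcription of the ONNX formula, used as ground truth."""
    return np.maximum(0.0, np.minimum(1.0, alpha * x + beta))

# Look the parser up by ONNX op name, as a converter loop would, then check it.
op = OPERATOR.get("HardSigmoid")({"alpha": 0.25})
x = np.linspace(-5, 5, 101)
print(np.max(np.abs(op(x) - onnx_hardsigmoid_reference(x, alpha=0.25))))  # 0.0
```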

TODO

  • Support Transformers (ViT, Swin Transformer, etc.)
  • Support cutting off the ONNX model and specifying output layers
  • Optimize comfirm_acc.py (removed; the output checker now runs automatically)

Limitation

  • The supported operators do not cover all models.
  • Works well with 1D/2D vision CNNs; 3D CNNs are not supported.
  • Support is poor for some math or channel-changing operators (such as Squeeze and MatMul).

Emmmmmmm

Changes along the first (batch) or second (channel) axis are really painful to handle: there are always cases that have not been taken into account.
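A small NumPy illustration of why these axis changes are hard: once activations are stored as NHWC, every axis attribute has to be remapped, and operations that depend on memory order, such as Flatten or Reshape, produce a different element order unless the tensor is transposed back to NCHW first. The `nchw_axis_to_nhwc` helper is a hypothetical analogue of the axis adjustment the parser docstring attributes to dimension_utils.channel_to_last_dimension, not its actual code.

```python
import numpy as np

def nchw_axis_to_nhwc(axis, rank=4):
    """Map an ONNX axis (channel-first layout) to the matching channel-last axis.

    Batch stays at 0, channel (1) moves to the last position,
    and each spatial axis shifts one place to the left."""
    axis = axis % rank  # normalise negative axes
    if axis == 0:
        return 0
    if axis == 1:
        return rank - 1
    return axis - 1

x_nchw = np.arange(12).reshape(1, 3, 2, 2).astype(np.float32)
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

# Reductions agree once the axis is remapped.
same_sum = np.array_equal(x_nchw.sum(axis=1), x_nhwc.sum(axis=nchw_axis_to_nhwc(1)))

# Flatten does NOT agree: the element order depends on the layout.
flat_onnx = x_nchw.reshape(1, -1)
flat_naive = x_nhwc.reshape(1, -1)
flat_fixed = np.transpose(x_nhwc, (0, 3, 1, 2)).reshape(1, -1)

print(same_sum)                               # True
print(np.array_equal(flat_onnx, flat_naive))  # False
print(np.array_equal(flat_onnx, flat_fixed))  # True
```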

License

This software is covered by Apache-2.0 license.

Languages

Language:Python 100.0%