MPolaris / onnx2tflite

Tool for onnx->keras or onnx->tflite. If this tool is useful to you, please star it.


splitted group convolution

lkdci opened this issue


Hi, thanks for the great work, really cool repository.
I have a question about the grouped convolution implementation. I've noticed in the TFGroupConv class that when group != 1 and group is not equal to out_channels, the groups are explicitly split and later concatenated, which might be less efficient at runtime. What is the reason for this approach? Is there something problematic with group-conv conversion otherwise?

As for why the groups are "explicitly split": this work is mainly built for tflite, and only normal conv and depthwise conv are supported in tflite.
As for "problematic": the group-conv conversion is correct; the output of each operator is checked.
As for "this approach": this is how group convolution works; you can find papers/blogs about it.
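A minimal numpy sketch of why the split-and-concat form is exact: slicing the input channels and the filters into `groups` pieces, running an ordinary conv on each pair and concatenating gives the same result as one ordinary conv whose kernel is block-diagonal across groups. It assumes stride 1, "VALID" padding, NHWC tensors and a kernel of shape (kh, kw, Cin // groups, Cout), which is the kernel shape Keras uses for `Conv2D(groups=...)`. The helper names are made up for this illustration and are not onnx2tflite code.

```python
import numpy as np

def conv2d_valid(x, w):
    """Ordinary stride-1 'VALID' conv (cross-correlation) in NHWC layout.
    x: (N, H, W, Cin), w: (kh, kw, Cin, Cout) -> (N, H-kh+1, W-kw+1, Cout)."""
    kh, kw, _, cout = w.shape
    oh, ow = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((x.shape[0], oh, ow, cout), dtype=np.result_type(x, w))
    for i in range(kh):
        for j in range(kw):
            # accumulate the contribution of kernel tap (i, j) for all output pixels
            out += x[:, i:i + oh, j:j + ow, :] @ w[i, j]
    return out

def group_conv_split_concat(x, w, groups):
    """Grouped conv from ordinary convs only: split the input channels and the
    output filters into `groups` slices, convolve each pair, concatenate."""
    x_parts = np.split(x, groups, axis=-1)  # groups x (N, H, W, Cin/groups)
    w_parts = np.split(w, groups, axis=-1)  # groups x (kh, kw, Cin/groups, Cout/groups)
    outs = [conv2d_valid(xp, wp) for xp, wp in zip(x_parts, w_parts)]
    return np.concatenate(outs, axis=-1)

def group_kernel_as_dense(w, groups):
    """The same grouped conv as ONE ordinary conv with a block-diagonal kernel."""
    kh, kw, cin_g, cout = w.shape
    cout_g = cout // groups
    dense = np.zeros((kh, kw, cin_g * groups, cout), dtype=w.dtype)
    for g in range(groups):
        dense[:, :, g * cin_g:(g + 1) * cin_g, g * cout_g:(g + 1) * cout_g] = \
            w[:, :, :, g * cout_g:(g + 1) * cout_g]
    return dense

rng = np.random.default_rng(0)
groups = 16
x = rng.standard_normal((1, 8, 8, 64))
w = rng.standard_normal((3, 3, 64 // groups, 64))  # Keras-style grouped kernel
y_split = group_conv_split_concat(x, w, groups)
y_dense = conv2d_valid(x, group_kernel_as_dense(w, groups))
print(y_split.shape, np.allclose(y_split, y_dense))  # (1, 6, 6, 64) True
```

Only ordinary convs and a concat appear in `group_conv_split_concat`, which is why this form maps onto ops that tflite already supported.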


Hi @MPolaris, you are right, the current implementation is indeed equivalent to the torch convolution. However, the same operator can be converted differently in the final tflite graph.
tf.keras.layers.Conv2D does support grouped convs that are not depthwise, simply by passing the groups argument, see doc.
I've tried to build a model by passing the groups argument as follows:

from tensorflow import keras

# Inside the TFGroupConv layer: one Conv2D with groups=group
# instead of per-group convs followed by a concat.
self.conv = keras.layers.Conv2D(
                out_channel_num, kernel_size, strides, "VALID",
                use_bias=bias is not None,
                kernel_initializer=keras.initializers.Constant(weights),
                bias_initializer='zeros' if bias is None else keras.initializers.Constant(bias),
                dilation_rate=dilations,
                groups=group)
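The `weights` passed to `keras.initializers.Constant` in that snippet must already be in Keras kernel layout. A minimal numpy sketch of that conversion, assuming the source is an ONNX Conv weight of shape (M, C/group, kH, kW), as defined by the ONNX Conv spec, and the target is the Keras grouped kernel shape (kh, kw, Cin // groups, filters). `onnx_conv_weight_to_keras` is a hypothetical helper name, not onnx2tflite's API.

```python
import numpy as np

def onnx_conv_weight_to_keras(w_onnx):
    """ONNX/PyTorch conv weight (Cout, Cin/groups, kh, kw) -> Keras kernel
    (kh, kw, Cin/groups, Cout). Grouping is untouched: in both layouts group g
    owns the contiguous block of output channels [g*Cout/groups, (g+1)*Cout/groups)."""
    return np.transpose(w_onnx, (2, 3, 1, 0))

groups, c_in, c_out = 16, 64, 64
w_onnx = np.arange(c_out * (c_in // groups) * 3 * 3, dtype=np.float32).reshape(
    c_out, c_in // groups, 3, 3)
w_keras = onnx_conv_weight_to_keras(w_onnx)
print(w_keras.shape)  # (3, 3, 4, 64)

# same value, different index order: w_keras[kh, kw, ci, co] == w_onnx[co, ci, kh, kw]
assert w_keras[1, 2, 3, 5] == w_onnx[5, 3, 1, 2]
```

Because only the axis order changes, the same transposed kernel works whether it is fed to a single grouped Conv2D or sliced along the last axis for the split version.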

The resulting block looks different in the tflite graph, e.g. for in_channels=64, group size=16:

With split and concat (original):
[image: tflite graph of the split/concat block]

With the groups argument passed to Conv2D:
[image: tflite graph of the single grouped Conv2D]

I've compared the outputs of the torch and tflite models; they match.

Although the outputs are the same and the number of FLOPs is the same, a runtime compiler can still optimize a given operator or layers in many ways, such as fused kernels, number of threads, efficient memory access and more, so the way the graph is defined can affect performance on the deployed hardware.
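To make the "same FLOPs" point concrete, a small sketch counting multiply-accumulates (MACs) for a grouped conv versus its split-and-concat form. The 64-channel, 3x3, 56x56-output, 16-group layer is a made-up example, not a layer taken from regnet_x_400mf.

```python
def conv_macs(h_out, w_out, c_in, c_out, k_h, k_w, groups=1):
    """Multiply-accumulates of a 2D conv: each output element reads
    k_h * k_w * (c_in // groups) inputs."""
    return h_out * w_out * c_out * k_h * k_w * (c_in // groups)

# Hypothetical layer: 64 -> 64 channels, 3x3 kernel, 56x56 output, 16 groups
groups = 16
grouped = conv_macs(56, 56, 64, 64, 3, 3, groups=groups)
# split & concat: 16 ordinary convs, each 4 -> 4 channels
split_concat = groups * conv_macs(56, 56, 64 // groups, 64 // groups, 3, 3)
dense = conv_macs(56, 56, 64, 64, 3, 3)  # same layer without grouping

print(grouped, split_concat, dense)  # 7225344 7225344 115605504
```

The arithmetic is identical, so any speed difference between the two graphs has to come from how the runtime schedules and fuses the ops.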

To check whether there is a difference between the models, I converted regnet_x_400mf from torchvision models, which uses grouped convolution. I benchmarked both tflite models on a Samsung S22 phone using the official tflite benchmark binary.

adb shell /data/local/tmp/android_aarch64_benchmark_model --graph=/data/local/tmp/merged.tflite --num_threads=4 --use_gpu=true
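Comparing the two converted models amounts to running this command twice with different --graph paths. A hedged Python sketch that wraps it with subprocess and extracts the average, assuming `adb` is on the PATH, the binary sits at the path used above, and its output contains an `Inference (avg): <value>` line as quoted in this thread. `run_benchmark` and the other helpers are hypothetical names.

```python
import re
import subprocess

# Path used in the command above; adjust for your device.
BENCHMARK_BIN = "/data/local/tmp/android_aarch64_benchmark_model"

def benchmark_command(model_path, num_threads=4, use_gpu=True):
    """Build the same adb command as above for a given .tflite file on the device."""
    return [
        "adb", "shell", BENCHMARK_BIN,
        f"--graph={model_path}",
        f"--num_threads={num_threads}",
        f"--use_gpu={'true' if use_gpu else 'false'}",
    ]

def parse_inference_avg(output):
    """Extract the number after 'Inference (avg):' from the benchmark log."""
    match = re.search(r"Inference \(avg\): ([0-9.]+)", output)
    if match is None:
        raise ValueError("no 'Inference (avg)' entry in benchmark output")
    return float(match.group(1))

def run_benchmark(model_path):
    """Run on a connected device; stdout and stderr are both searched,
    since the log may land on either stream."""
    result = subprocess.run(benchmark_command(model_path),
                            capture_output=True, text=True, check=True)
    return parse_inference_avg(result.stdout + result.stderr)
```

Only `run_benchmark` needs a connected device; the command builder and the parser can be checked offline.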

I found that the new simplified version is a bit faster:

Split & Concat (original) -> Inference (avg): 7729.22 µs
Conv2D with groups -> Inference (avg): 6445.25 µs
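Put as a ratio, the two averages work out as below. This assumes both numbers are in microseconds, the unit the TFLite benchmark tool uses for its timing summary.

```python
split_concat_us = 7729.22  # Split & Concat (original)
grouped_us = 6445.25       # Conv2D with groups

speedup = split_concat_us / grouped_us
latency_reduction = 1 - grouped_us / split_concat_us
saved_ms = (split_concat_us - grouped_us) / 1000
print(f"{speedup:.2f}x faster, {latency_reduction:.1%} lower latency, "
      f"{saved_ms:.2f} ms saved per inference")
# 1.20x faster, 16.6% lower latency, 1.28 ms saved per inference
```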

What do you think? I can push a PR if you are open to contributions :)

@lkdci Thanks for your work! Maybe I made a mistake with group conv.
I will check your proposal soon, and a PR is welcome!

@lkdci Hi, I have checked your method: the group version fails to load, while the split version loads successfully on my device.
The group version throws the error "Cannot create interpreter: Didn't find op for builtin opcode 'CONV_2D' version '6'".
My runtime environment is special, so I think users should be able to choose which way (split/group) in your PR, and the split version should be the default (more robust).
Also, can you share details of your runtime environment?


@MPolaris, I was able to reproduce your error with an older tflite version.
My compilation env is:

  • ubuntu 18.04
  • torch==1.11.0
  • tensorflow==2.10.0
  • keras==2.10.0
  • onnx==1.12.0

My benchmark env:

  • Galaxy S22, Android OS 12
  • nightly version of the native benchmark binary from here, so tflite is nightly.

As you can see, my env is pretty up to date. I'm not sure when they added support for this convolution version; I'm looking for older versions of the benchmark binaries and I've asked the TF team for support here.
Once we know the version where this op is supported, I'll add clear documentation stating from which version it is supported. I agree that the default should be the split version of grouped convolution.


The merged version is supported only for tflite version >= 2.9.
I tried to open a PR, but it looks like I don't have permissions:

remote: Permission to MPolaris/onnx2tflite.git denied to lkdci.
unable to access 'https://github.com/MPolaris/onnx2tflite.git/': The requested URL returned error: 403
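Putting the two conclusions of this thread together (split stays the default; the merged Conv2D-with-groups form needs tflite >= 2.9), a user-facing switch could look like the sketch below. `choose_group_conv_impl` and `prefer_grouped` are hypothetical names, not the option actually added in PR #21.

```python
# Minimum tflite version for the merged grouped CONV_2D, as reported in this thread
MIN_GROUPED_CONV_VERSION = (2, 9, 0)

def parse_version(version: str) -> tuple:
    """'2.10.0' -> (2, 10, 0); trailing suffixes like '-nightly' or 'rc1' are ignored."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)

def choose_group_conv_impl(target_tflite_version: str, prefer_grouped: bool = False) -> str:
    """Return 'grouped' only when the user opts in AND the target runtime is new
    enough; otherwise fall back to the more robust split-and-concat form."""
    if prefer_grouped and parse_version(target_tflite_version) >= MIN_GROUPED_CONV_VERSION:
        return "grouped"
    return "split"

print(choose_group_conv_impl("2.10.0", prefer_grouped=True))  # grouped
print(choose_group_conv_impl("2.8.0", prefer_grouped=True))   # split
```

Keeping "split" as the fallback means a user who opts in but targets an older runtime still gets a model that loads, instead of the "Didn't find op for builtin opcode 'CONV_2D' version '6'" error.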

@lkdci Hi, thanks again for your issue and work.
The PR workflow is: first fork this repository to your GitHub account, then change the code and push it to your fork, and finally create a pull request, rather than pushing code to the master branch directly.
Maybe this blog can help.
Embarrassingly, I'm not familiar with the PR workflow either, so we can learn it together.


Thanks again, I will check it carefully. It's unbelievable that group convolution was not supported in tflite, especially after version 2.0.


Yeah, it's a basic layer that should have been supported long before. I'm happy they did eventually.
Thanks for the instructions, I'm not an expert either. I opened PR #21, and again, thanks for your contribution. I mostly develop with pytorch, and deploying models to tflite was always a mess. From my experience with other libraries like pytorch2keras and onnx_tf, as far as I know this repo solves those conversion problems the best.


Thanks for your nice try, the PR is merged.
This repo was built to solve some efficiency problems of onnx_tf and mainly targets 2D-CNN models; models from onnx_tf are slower and weird (the input of a conv layer is NCHW and is then transposed to NHWC).
I'm not an expert either, and the code of this repo is also 'simple' (immature?). Lastly, I'm very happy to see that this repo can help people like you!
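The layout issue mentioned for onnx_tf is the NCHW (ONNX/PyTorch) versus NHWC (TFLite) tensor order. A small numpy sketch of the two permutations; converting once at the model boundary, rather than transposing around every conv layer, presumably avoids extra transpose ops in the graph. The helper names are illustrative only.

```python
import numpy as np

def nchw_to_nhwc(x):
    """(N, C, H, W) -> (N, H, W, C): channels move to the last axis."""
    return np.transpose(x, (0, 2, 3, 1))

def nhwc_to_nchw(x):
    """(N, H, W, C) -> (N, C, H, W): the inverse permutation."""
    return np.transpose(x, (0, 3, 1, 2))

x_nchw = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # ONNX/PyTorch layout
x_nhwc = nchw_to_nhwc(x_nchw)                           # TFLite layout
print(x_nhwc.shape)  # (2, 4, 5, 3)

# same element, different index order: x_nchw[n, c, h, w] == x_nhwc[n, h, w, c]
assert x_nhwc[1, 2, 3, 0] == x_nchw[1, 0, 2, 3]
assert np.array_equal(nhwc_to_nchw(x_nhwc), x_nchw)
```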