nnstreamer / nntrainer

NNtrainer is a software framework for training neural network models on devices.

Select ops are not supported in TFlite interpreter

KirillP2323 opened this issue

I'm running a SimpleShot app with our custom ViT-based tflite backbone. I get the following error:

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
ERROR: Select TensorFlow op(s), included in the given model, is(are) not supported by this interpreter. Make sure you apply/link the Flex delegate before inference. For the Android, it can be resolved by adding "org.tensorflow:tensorflow-lite-select-tf-ops" dependency. See instructions: https://www.tensorflow.org/lite/guide/ops_select
ERROR: Node number 1108 (FlexErf) failed to prepare.
ERROR: Select TensorFlow op(s), included in the given model, is(are) not supported by this interpreter. Make sure you apply/link the Flex delegate before inference. For the Android, it can be resolved by adding "org.tensorflow:tensorflow-lite-select-tf-ops" dependency. See instructions: https://www.tensorflow.org/lite/guide/ops_select
ERROR: Node number 1108 (FlexErf) failed to prepare.
terminate called after throwing an instance of 'std::runtime_error'
  what():  Failed to allocate tensors!
Aborted (core dumped)
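
When a log like this gets long, the unsupported ops can be pulled out mechanically. This is a minimal sketch, assuming the interpreter keeps the "Node number N (OpName) failed to prepare." wording shown above; `failed_nodes` is a hypothetical helper name, not part of TensorFlow Lite or NNTrainer.

```python
import re

# Matches lines such as: "ERROR: Node number 1108 (FlexErf) failed to prepare."
NODE_ERROR = re.compile(r"Node number (\d+) \((\w+)\) failed to prepare")


def failed_nodes(log_text):
    """Return unique (node_index, op_name) pairs in first-seen order."""
    seen = []
    for match in NODE_ERROR.finditer(log_text):
        pair = (int(match.group(1)), match.group(2))
        if pair not in seen:
            seen.append(pair)
    return seen


sample_log = """\
ERROR: Node number 1108 (FlexErf) failed to prepare.
ERROR: Node number 1108 (FlexErf) failed to prepare.
"""
print(failed_nodes(sample_log))  # [(1108, 'FlexErf')]
```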

Here is our code where we transform our model to tflite with select ops:

import tensorflow as tf

# Convert the SavedModel to TFLite, allowing Select TF ops as a fallback.
converter = tf.lite.TFLiteConverter.from_saved_model('onnx_model_tf/')
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_model = converter.convert()

:octocat: cibot: Thank you for posting issue #2133. The person in charge will reply soon.

@KirillP2323

Regarding the error above:

ERROR: Select TensorFlow op(s), included in the given model, is(are) not supported by this interpreter. Make sure you apply/link the Flex delegate before inference. For the Android, it can be resolved by adding "org.tensorflow:tensorflow-lite-select-tf-ops" dependency. See instructions: https://www.tensorflow.org/lite/guide/ops_select
ERROR: Node number 1108 (FlexErf) failed to prepare.
ERROR: Select TensorFlow op(s), included in the given model, is(are) not supported by this interpreter. Make sure you apply/link the Flex delegate before inference. For the Android, it can be resolved by adding "org.tensorflow:tensorflow-lite-select-tf-ops" dependency. See instructions: https://www.tensorflow.org/lite/guide/ops_select
ERROR: Node number 1108 (FlexErf) failed to prepare.

This looks like a TensorFlow error, not an NNTrainer error.
You can follow the build and inference method described here:
https://www.tensorflow.org/lite/guide/ops_select

Does it work normally on your local computer?
(That is, running inference in native C++ with TensorFlow Lite custom ops, using only TensorFlow and not NNTrainer.)
--> To use custom ops, you need to build with custom ops enabled.

@DonghakPark Thanks for the pointers.
Currently I use my .tflite model as it is, without rebuilding on the C++ side. Can you let me know how to do the rebuild so that it is interpreted properly by NNtrainer?
Also, from the guide https://www.tensorflow.org/lite/guide/ops_select, it looks like I need to link the Flex library to enable custom ops for the interpreter, which I assume is built on the NNtrainer side. Is there a way to pass some parameters to the interpreter build?

I will test this in my local environment as soon as possible and share the results with you.

Are you using tf.function (with a signature) on the custom ops?

@DonghakPark
We are using an architecture based on this: https://github.com/hushell/pmf_cvpr22/blob/main/models/utils.py#L136
The python TF version already has all custom ops, so other than

converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]

we don't do anything.

@DonghakPark
Actually, we might be able to use the model without erf function. Let me check and I'll come back to you.
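
For context on why dropping Erf can be feasible: in ViT-style backbones the Erf op typically comes from the exact GELU activation, `0.5 * x * (1 + erf(x / sqrt(2)))`. That the Erf in this model comes from GELU is an assumption (the thread only later refers to a "Gelu model"). The widely used tanh approximation needs no erf. Below is a minimal sketch comparing the two with the Python standard library; the function names are illustrative only.

```python
import math


def gelu_exact(x):
    """Exact GELU; this is the form that needs an Erf op."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU; uses only tanh, multiply and add."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Largest absolute gap between the two on a grid over [-6, 6].
grid = [i / 100.0 for i in range(-600, 601)]
max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
print(f"max |exact - tanh| on [-6, 6]: {max_gap:.2e}")
```

The two functions are close but not identical, so a model trained with one and run with the other should have its accuracy re-checked.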

@KirillP2323
OK. If anything changes, please let me know.

We changed our backbone so it doesn't use the Erf function, and no custom ops are required. However, we encountered another problem while running SimpleShot with our backbone:

$ ./simple_shot vit_dino_no_erf UN dog_5way_5shot_train.dat dog_5way_5shot_test.dat
...
terminate called after throwing an instance of 'std::invalid_argument'
  what():  Input dimensions mismatch
Aborted (core dumped)

After some digging, I found that the error is thrown during this step of model initialization: https://github.com/nnstreamer/nntrainer/blob/main/nntrainer/graph/network_graph.cpp#L878 ->
https://github.com/nnstreamer/nntrainer/blob/main/nntrainer/graph/network_graph.cpp#L711 ->
https://github.com/nnstreamer/nntrainer/blob/main/nntrainer/layers/layer_node.cpp#L570

So it's getting stuck in the layer->finalize() function. I saw that there are many checks on the input dimension before that line, but they all seem to pass. Can you suggest why I get a dimension mismatch error there?

@KirillP2323

As the error message says, this looks like a dimension issue,
so here are some points to check:

  1. NNTrainer native applications use NCHW dimensions.
    -> Please check your .dat file.
  2. Our data loader expects the following format:
    --> For example, take the MNIST dataset with a 28x28 input image size.
    --> The MNIST dimension is then 1x1x28x28 (NCHW).
    --> In the .dat file, each sample is stored sequentially as [28x28 image pixels in NCHW format][label].
    --> That is, [values in 0~255, size 28x28][label data, size 10].

Please check this, and if you still have a problem with the example, let me know.
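
The layout described above can be sketched as follows. This assumes each value is stored as a little-endian 32-bit float and that the label is one-hot encoded; neither detail is confirmed in this thread, so check it against the data loader you actually use. `pack_sample` is a hypothetical helper.

```python
import struct


def pack_sample(pixels_chw, label, num_classes):
    """Serialize one sample as [features in CHW order][one-hot label].

    Assumption: every value is written as a little-endian float32.
    """
    one_hot = [1.0 if i == label else 0.0 for i in range(num_classes)]
    values = list(pixels_chw) + one_hot
    return struct.pack(f"<{len(values)}f", *values)


# One 1x28x28 MNIST-sized image (all zeros here) with label 3 of 10 classes.
image = [0.0] * (1 * 28 * 28)
record = pack_sample(image, label=3, num_classes=10)
print(len(record))  # (784 + 10) values * 4 bytes = 3176
```

Samples packed this way can be appended one after another to build a file in the [image][label] order described above.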

@DonghakPark Thanks for the suggestions!
My data is in 3x228x228 format, I double-checked that it is in the right format you described above.
The issue still persists, and I was thinking that's because of the mismatched input shape or batch size here: https://github.com/nnstreamer/nntrainer/blob/main/Applications/SimpleShot/task_runner.cpp#L131, I changed it and that didn't help.
Let me know what else I can try.

Yes. If you have to change the input to 3x228x228, then the input_shape needs to be 228:228:3. It needs to be NHWC because tflite uses the NHWC format. If your backbone model was tested with NCHW, then you can define input_shape as 3:228:228 in NCHW format. NNTrainer compares the input dimension against the one provided by the tflite API, interpreter->inputs().

void TfLiteLayer::setDimensions(const std::vector<int> &tensor_idx_list,

The error message is due to the mismatch between the tflite dimension and the nntrainer input dimension.
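
To make the two layouts concrete, here is a minimal sketch that reorders a flat NHWC buffer into NCHW using only index arithmetic. `nhwc_to_nchw` is an illustrative helper, not an NNTrainer or TensorFlow Lite function.

```python
def nhwc_to_nchw(data, n, h, w, c):
    """Reorder a flat list from NHWC layout to NCHW layout."""
    out = [0] * (n * c * h * w)
    for b in range(n):
        for y in range(h):
            for x in range(w):
                for ch in range(c):
                    src = ((b * h + y) * w + x) * c + ch  # NHWC offset
                    dst = ((b * c + ch) * h + y) * w + x  # NCHW offset
                    out[dst] = data[src]
    return out


# One 2x2 image with 3 channels; pixel i stores (10 + i, 20 + i, 30 + i).
nhwc = [10, 20, 30, 11, 21, 31, 12, 22, 32, 13, 23, 33]
print(nhwc_to_nchw(nhwc, n=1, h=2, w=2, c=3))
# [10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33]
```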

What is the error message? Could you share it with us? Does it produce the same error?

The error message is the same, unfortunately:

$ ./simple_shot vit_dino_no_erf UN dog_5way_5shot_train.dat dog_5way_5shot_test.dat
================================================================================
          Layer name          Layer type     Input dimension         Input layer
================================================================================
            backbone     backbone_tflite                                        
--------------------------------------------------------------------------------
                 knn        centroid_knn                                        
================================================================================
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
terminate called after throwing an instance of 'std::invalid_argument'
  what():  Input dimensions mismatch
Aborted

There is a log from log_nntrainer_***.out, in case it's useful:

...
[NNTRAINER DEBUG 2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/app_context.cpp:registerFactory:609) factory has registered with key: time_dist, int_key: 503
[NNTRAINER DEBUG 2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/app_context.cpp:registerFactory:609) factory has registered with key: unknown, int_key: 999
[NNTRAINER DEBUG 2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/app_context.cpp:getConfig:112) [AppContext]  conf path: /usr/local/etc/nntrainer.ini
[NNTRAINER DEBUG 2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/app_context.cpp:getPluginPaths:175) DEFAULT CONF PATH, path: /usr/local/lib/x86_64-linux-gnu/nntrainer/layers
[NNTRAINER WARN  2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/app_context.cpp:add_extension_object:360) tried to register extension from /usr/local/lib/x86_64-linux-gnu/nntrainer/layers but failed, reason: [AppContext] failed to open the directory: /usr/local/lib/x86_64-linux-gnu/nntrainer/layers
[NNTRAINER DEBUG 2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/app_context.cpp:registerFactory:609) factory has registered with key: centering, int_key: 42
[NNTRAINER INFO  2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/compiler/previous_input_realizer.cpp:realize:63) knn is identified as a non-input node and default input layer(backbone) is being set 
[NNTRAINER DEBUG 2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/models/neuralnet.cpp:initialize:208) initializing neural network, layer size: 2
[NNTRAINER DEBUG 2023-03-20 10:52:14] (/home/CORP/k.paramonov/projects/nnbuilder/nntrainer/nntrainer/graph/network_graph.cpp:initialize:858) layer name : backbone

We added code that prints more information about the input dimension mismatch in #2151. Could you run your application once again to see which dimension does not match?

@jijoongmoon Great, that helped, thank you! I changed the input_shape to be 3:224:224, and changed the batch size of my tflite model to 1, and it successfully ran.
One problem I noticed while debugging is that changing batch_size on https://github.com/nnstreamer/nntrainer/blob/main/Applications/SimpleShot/task_runner.cpp#L127 doesn't fix the batch size mismatch issue, since I'm getting an error:

terminate called after throwing an instance of 'std::invalid_argument'
  what():  Input dimensions mismatch -> 0:Shape: 16:3:224:224
 Shape: 1:3:224:224

Aborted

So the only option to match the dimension is to change the batch size of my tflite model from 16 to 1.
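
The mismatch in that message can be read axis by axis. Here is a small sketch, assuming the colon-separated N:C:H:W notation printed in the error; `diff_shapes` is a hypothetical helper and not NNTrainer's own check.

```python
AXES = ("batch", "channel", "height", "width")


def parse_shape(text):
    """Turn '16:3:224:224' into (16, 3, 224, 224)."""
    return tuple(int(part) for part in text.split(":"))


def diff_shapes(left, right):
    """Return (axis, left_value, right_value) for every axis that differs."""
    a = parse_shape(left)
    b = parse_shape(right)
    if len(a) != len(AXES) or len(b) != len(AXES):
        raise ValueError("expected four colon-separated dimensions")
    return [(AXES[i], x, y) for i, (x, y) in enumerate(zip(a, b)) if x != y]


print(diff_shapes("16:3:224:224", "1:3:224:224"))
# [('batch', 16, 1)]
```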

Now I'm having another problem: the accuracy of the model is 20%, which is much lower than expected. I'll work on debugging the issue. Closing the bug, since initial issue is resolved.

Have you tried changing the batch size in {"batch_size=1", "epochs=1"}); in task_runner.cpp?

@jijoongmoon Yes, I tried that, and it didn't change the 1:3:224:224 dimensions for the input.

Hello, I'm opening the issue because our model needs the tf.Erf function to perform with good accuracy.
Could you add that function in TFlite interpreter for NNtrainer?

For NNtrainer logs, you can check comment 1, and here are logs from my tflite conversion:

2023-04-13 10:42:54.753849: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:1901] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexErf
Details:
        tf.Erf(tensor<1x197x1536xf32>) -> (tensor<1x197x1536xf32>) : {device = ""}
See instructions: https://www.tensorflow.org/lite/guide/ops_select

@KirillP2323 OK, then I have some questions:

  1. Are you using an x86_64 computer?
  2. Are you using a Linux-based OS?

@DonghakPark Yes, x86_64 and Ubuntu 20.04

@KirillP2323
If you have any questions, please check #2193.

Great, the application runs with Gelu model after following the instructions in #2193. Is it possible to add support for this to the installation process on the main branch?

Yes. I will make some scripts with a meson option.

I will close this issue.