tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone

Home Page: https://tensorflow.org


Android: No OpKernel was registered to support Op 'TensorArray' (DynamicRNN)

lingz opened this issue · comments

I'm trying to load a graph I trained in Python, which uses only dynamic_rnn and the most basic RNN cell.

When I try to execute the graph on Android, I get the following error message:

Invalid argument: No OpKernel was registered to support Op 'TensorArray' with these attrs
     [[Node: c96ce52cccd5408aad2fb356a8246023/RNN/TensorArray_1 = TensorArray[clear_after_read=true, dtype=DT_FLOAT, dynamic_size=false, tensor_array_name="c96ce52cccd5408aad2fb356a8246023/RNN/dynamic_rnn/input"](c96ce52cccd5408aad2fb356a8246023/RNN/unpack)]]

Environment info

Operating System:
Ubuntu x64 14.04
Android 6.0

Installed version of CUDA and cuDNN:
(please attach the output of ls -l /path/to/cuda/lib/libcud*):

If installed from binary pip package, provide:

  1. Which pip package you installed.
  2. The output from python -c "import tensorflow; print(tensorflow.__version__)".

If installed from source, provide

  1. The commit hash (git rev-parse HEAD)
    tag v0.9.0
  2. The output of bazel version
    0.3.0

Steps to reproduce

  1. Create RNN model in python
  2. Compile it with tensorflow_android_lib targeting armv7
  3. Try run the model on android

What have you tried?

Logs or other output that would be helpful

(If logs are large, please upload as attachment).

It seems that the RNN ops are not included by default in the android_extended_ops_group under /tensorflow/cc/kernel/BUILD.

A hack is to add the operations manually, but the more correct way seems to be to use selective_registration:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/selective_registration.h

EDIT: Easy regex to automatically list all operations in your graph, as long as you have it dumped in text format:

grep "op: " PATH/TO/mygraph.txt | sort | uniq | sed -E 's/^.+"(.+)".?$/\1/g'
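For anyone who would rather do this from Python, here is a rough equivalent of the pipeline above (a sketch only — the inline graph text is a made-up stand-in for a real text-format GraphDef):

```python
import re

# Stand-in for the contents of PATH/TO/mygraph.txt (a text-format GraphDef);
# in practice you would read it with open("PATH/TO/mygraph.txt").read().
graph_text = '''
node { name: "x" op: "Placeholder" }
node { name: "w" op: "Const" }
node { name: "y" op: "MatMul" }
node { name: "y2" op: "MatMul" }
'''

# Same idea as the grep | sort | uniq | sed pipeline:
# pull out every op: "..." attribute and de-duplicate.
ops = sorted({m.group(1) for m in re.finditer(r'op: "([^"]+)"', graph_text)})
print(ops)
```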

And on the same note, is there an example of how to use selective_registration, in particular how to differentiate between SHOULD_REGISTER_OP_KERNEL, SHOULD_REGISTER_OP, and SHOULD_REGISTER_OP_GRADIENT?

So long as the op builds and runs correctly, it's fine to add it to the default ops file group (we've just been adding them on an as-needed basis now that we have a way to manage binary size).

You don't need to use selective registration yourself unless you are prioritizing binary size -- it will cut the final .so size essentially in half, resulting in < 1 MB compressed for a typical graph.

@cwhipkey do you know if there is a public example of this yet?

Okay, but it would be nice to be able to do selective registration eventually, so I can get the minimal size for my own graph right.

I tried to do it on my own and changed the cc_library to depend on android_all_ops instead, like this:

cc_library(
    name = "android_tensorflow_kernels",
    srcs = select({
        "//tensorflow:android": [
            "//tensorflow/core/kernels:android_all_ops",
        ],
        "//conditions:default": [],
    }),
    copts = tf_copts(),
    tags = [
        "manual",
        "notap",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:android_tensorflow_lib_lite",
        "//tensorflow/core:protos_cc",
        "//third_party/eigen3",
    ],
    alwayslink = 1,
)

But then I got the following error during compilation:

external/eigen_archive/eigen-eigen-d02e6a705c30/unsupported/Eigen/CXX11/../../../Eigen/src/Core/util/StaticAssert.h:32:40: error: static assertion failed: THIS_TYPE_IS_NOT_SUPPORTED
     #define EIGEN_STATIC_ASSERT(X,MSG) static_assert(X,#MSG);
                                        ^
external/eigen_archive/eigen-eigen-d02e6a705c30/unsupported/Eigen/CXX11/../../../Eigen/src/Core/SpecialFunctions.h:976:9: note: in expansion of macro 'EIGEN_STATIC_ASSERT'
         EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),

This is my main build file:

cc_binary(
    name = "myapp",
    srcs = [
        "main.cc",
        "ops_to_register.h"
    ],
    copts = [
        "-fno-exceptions",
        "-DEIGEN_AVOID_STL_ARRAY",
        "-DSELECTIVE_REGISTRATION",
        "-mfpu=neon", # This one for ARM only
        "-std=c++11",
        "-DMIN_LOG_LEVEL=0",
        "-DTF_LEAN_BINARY",
        "-O2",
        "-fPIE",
    ],
    linkopts = [
        "-landroid",
        "-llog",
        "-lm",
        "-z defs",
        "-s",
        "-Wl,--icf=all",  # Identical Code Folding
        "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export
        "-pie",
    ],
    deps = [
        "@org_tensorflow//tensorflow/core:android_tensorflow_lib",
    ],
)

And finally my ops_to_register.h looks like this:

#ifndef METALANG_OPS_TO_REGISTER
#define METALANG_OPS_TO_REGISTER

#include "tensorflow/core/framework/selective_registration.h"

SHOULD_REGISTER_OP("Add")
SHOULD_REGISTER_OP("AddN")
SHOULD_REGISTER_OP("All")
SHOULD_REGISTER_OP("ApplyAdam")
SHOULD_REGISTER_OP("ArgMax")
....
....

#endif

Any ideas as to why it might be failing to compile with EIGEN?


In my case op 'Div' is already included in tensorflow/contrib/makefile/tf_op_files.txt but I'm getting the error in https://github.com/tensorflow/tensorflow/issues/3546. I think it probably depends on context.

Adding links to #3543 and #3546 in case they all have a related root cause.

Regarding the question "@cwhipkey do you know if there is a public example of this yet?": no, there is no public example yet of creating a selective-registration header.


Re: @lingz : I was able to compile with selective registration, but it didn't help reduce the file size for me. Instead, I ended up using the makefile from tensorflow/contrib/makefile, removing all the unused ops classes from tf_op_files.txt.

For what it's worth, this is what I had in my ops_to_register.h:

const bool kRequiresSymbolicGradients = false;
const char* kNecessaryOpKernelClasses = 
    ",Add,"
    ",Const,"
    ...
    ",Softmax,";


constexpr bool ShouldRegisterOp(const char* name) {
    return strcmp(name,"Add") ? true :
        strcmp(name,"Const") ? true :
        ...
        strcmp(name,"Softmax") ? true :
        false;
}

The SHOULD_REGISTER_OP macro from tensorflow/core/framework/selective_registration.h calls ShouldRegisterOp, which returns true if the op should be registered.
The SHOULD_REGISTER_OP_KERNEL macro does a string search in kNecessaryOpKernelClasses.
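To make the mechanics concrete, here is a rough Python model of that comma-delimited lookup (just an illustration of the matching logic, not the real C++ macro):

```python
# Illustration only: mimics the substring search that
# SHOULD_REGISTER_OP_KERNEL performs on kNecessaryOpKernelClasses.
kNecessaryOpKernelClasses = ",Add,Const,Softmax,"

def should_register_op_kernel(class_name):
    # The surrounding commas keep "Con" from matching "Const".
    return ("," + class_name + ",") in kNecessaryOpKernelClasses

print(should_register_op_kernel("Const"))  # True
print(should_register_op_kernel("Con"))    # False
```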

For why the header didn't make the size smaller, I think there are a few possibilities to check:

  1. Selective registration can only cause code to be deleted if the expressions can
    be evaluated as constexprs. Using const char* instead of const char[] in
    the definition of kNecessaryOpKernelClasses may be enough to prevent this.
    But even with that, I would have expected some drop from the const bool
    kRequiresSymbolicGradients = false; line. Was there no change at all in
    size, or a small change?
  2. Double-check that the SELECTIVE_REGISTRATION macro is defined for the calls that
    compile the ops and kernels.
  3. What compiler are you using? Maybe compiler differences could explain
    it as well.

Here's an example ops_to_register.h generated by the tool that will be open-sourced soon:

#ifndef OPS_TO_REGISTER
#define OPS_TO_REGISTER
constexpr inline bool ShouldRegisterOp(const char op[]) {
return false
|| (strcmp(op, "AvgPool") == 0)
|| (strcmp(op, "BiasAdd") == 0)
|| (strcmp(op, "Concat") == 0)
|| (strcmp(op, "Const") == 0)
|| (strcmp(op, "Conv2D") == 0)
|| (strcmp(op, "Identity") == 0)
|| (strcmp(op, "LRN") == 0)
|| (strcmp(op, "MatMul") == 0)
|| (strcmp(op, "MaxPool") == 0)
|| (strcmp(op, "NoOp") == 0)
|| (strcmp(op, "Placeholder") == 0)
|| (strcmp(op, "Relu") == 0)
|| (strcmp(op, "Reshape") == 0)
|| (strcmp(op, "Softmax") == 0)
|| (strcmp(op, "_Recv") == 0)
|| (strcmp(op, "_Send") == 0)
;
}
const char kNecessaryOpKernelClasses[] = ","
"AvgPoolingOp<CPUDevice, float>,"
"BiasOp<CPUDevice, float>,"
"ConcatOp<CPUDevice, float>,"
"ConstantOp,"
"Conv2DOp<CPUDevice, float>,"
"IdentityOp,"
"LRNOp<CPUDevice, float>,"
"MatMulOp<CPUDevice, float, false >,"
"MaxPoolingOp<CPUDevice, float>,"
"NoOp,"
"PlaceholderOp,"
"ReluOp<CPUDevice, float>,"
"ReshapeOp,"
"SoftmaxOp<CPUDevice, float>,"
"RecvOp,"
"SendOp,"
;
const bool kRequiresSymbolicGradients = false;
#endif
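For anyone who wants to script this before the tool lands, here is a rough Python sketch of how the ShouldRegisterOp part of such a header could be emitted from an op list (a hypothetical helper, not the actual generator):

```python
# Hypothetical generator for the ShouldRegisterOp section of
# ops_to_register.h; the real tool may differ in detail.
def make_should_register_op(ops):
    # One strcmp clause per op, sorted for stable output.
    clauses = "\n".join('|| (strcmp(op, "%s") == 0)' % op for op in sorted(ops))
    return ("constexpr inline bool ShouldRegisterOp(const char op[]) {\n"
            "return false\n" + clauses + "\n;\n}\n")

header = make_should_register_op(["MatMul", "Const"])
print(header)
```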



hi @andrewharp

I tried to use print_selective_registration_header.py to write the ops to ops_to_register.h. I followed the instructions and ran:

bazel build tensorflow/python/tools:print_selective_registration_header && bazel-bin/tensorflow/python/tools/print_selective_registration_header --graphs=path/to/graph.pb > ops_to_register.h

Then

bazel build -c opt --copt="-DSELECTIVE_REGISTRATION" \
    --copt="-DSUPPORT_SELECTIVE_REGISTRATION" \
    //tensorflow/contrib/android:libtensorflow_inference.so \
    --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
    --crosstool_top=//external:android/crosstool --cpu=armeabi-v7a

Everything went well, so I used the libtensorflow_inference.so in my Android project. Still, when I run the model resnet_101_coco from the model zoo, I get the error:
java.lang.IllegalArgumentException: No OpKernel was registered to support Op 'Round' with these attrs. Registered devices: [CPU], Registered kernels: <no registered kernels>

How to overcome this? Thanks.

@bobeo: What does your ops_to_register.h look like?


hi @andrewharp

I attached the ops_to_register.h. I can see it has the 'Round' op ( || isequal(op, "Round") ). The final libtensorflow_inference.so is 4 MB.

The model I use is faster_rcnn_resnet101_coco from
http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_11_06_2017.tar.gz

Thanks.

ops_to_register.h.zip


@andrewharp Have you had time to check on this? Thanks

There are at least a couple of reasons this could happen:

  1. The kernel for Round isn't being compiled on Android. You could add a compilation error in tensorflow/core/kernels/cwise_op_round.cc and then run the Android build again, to see if it actually compiles that file.

  2. The kernel is being compiled, but for some reason ops_to_register.h is causing the Round op not to be registered. ops_to_register.h works by comparing the string name of the class to the string in ops_to_register.h. It's possible the compiler for the device is using a different name than the compiler for the host (used to make ops_to_register.h). It can be tricky to get the names out of the device compilation -- one way could be to change the cwise_op_round.cc kernel file and add:

#define CLAZ_NAME(x) #x
static_assert(false, CLAZ_NAME(UnaryOp));

then build and look for the class name printed by the assertion failure.

thanks @bobeo,
I did this process with a few pb graphs, including the ones from the example app,
and each time I run the print_selective_registration_header process I see that the ops_to_register.h file is updated according to the graph, but the resulting libtensorflow_inference.so is always the same size.
I tried all the bazel clean / bazel dump options that I could find, but nothing helped,
so it is either a bazel caching issue, or the build from ops_to_register.h does not work properly.

OK,
just realized that I had 2 ops_to_register.h files...
one in the core/framework folder,
and one in the project root folder where I ran the print_selective_registration_header build.
I think this is worth clarifying in the print_selective_registration_header.py comments,
or changing it so that the ops_to_register.h file is created in only one place no matter where the process is called from.
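A quick way to catch this pitfall is to scan the tree for stray copies of the header. A small self-contained sketch (it builds a throwaway directory just to demonstrate the situation above):

```python
import tempfile
from pathlib import Path

# Build a throwaway tree that reproduces the situation above:
# one stale copy under core/framework, one fresh copy at the root.
root = Path(tempfile.mkdtemp())
(root / "tensorflow/core/framework").mkdir(parents=True)
(root / "tensorflow/core/framework/ops_to_register.h").write_text("// stale\n")
(root / "ops_to_register.h").write_text("// fresh\n")

# rglob finds every copy, so duplicates are easy to spot.
copies = sorted(str(p.relative_to(root)) for p in root.rglob("ops_to_register.h"))
print(copies)
```

In a real checkout you would point `root` at the repository root instead of a temp directory.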


@eli99999 any luck? Can you try this model?
I'm trying all over again; if it succeeds I will let you know. Cheers!


@cwhipkey @andrewharp hi guys, do you know if I should compile libtensorflow_inference.so with GPU enabled, or does it not make any difference? Currently I compiled it without GPU because otherwise I got this error

When I run it on Android, it takes about 1 minute to process each frame. Is that normal?