tfp.stats.histogram cannot be compiled by XLA?
yellowdolphin opened this issue
Summary of problem
Applying tfp.stats.histogram to the data inside a tf.keras model breaks XLA compilation. The example code (linked below) works on CPU and GPU, but under a TPU strategy it raises:
InvalidArgumentError: 9 root error(s) found.
(0) INVALID_ARGUMENT: {{function_node __inference_train_function_3630}} Input 1 to node `sequential/lambda/histogram/count_integers/map/while/bincount/Bincount` with op Bincount must be a compile-time constant.
XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.
[[{{node sequential/lambda/histogram/count_integers/map/while/bincount/Bincount}}]]
[[sequential/lambda/histogram/count_integers/map/while]]
[[TPUReplicate/_compile/_10395939635067805685/_4]]
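For context on the message above: XLA needs the size input of `Bincount` to be a constant. As far as I know, `tf.math.bincount` derives that size from `reduce_max` of the data and then clamps it with `minlength`/`maxlength`, so the size stays data-dependent even when both bounds are given. A histogram whose bin count is a plain Python int sidesteps this. The following is a minimal sketch under these assumptions: equal-width bins over a known range `[lo, hi]`. `static_histogram` and `NUM_BINS` are hypothetical names and are not part of the TFP API. Unlike `tfp.stats.histogram`, this sketch clips out-of-range values into the edge bins instead of dropping them.

```python
import tensorflow as tf

NUM_BINS = 10  # hypothetical: a Python int, so the output shape is static


@tf.function(jit_compile=True)
def static_histogram(x, lo, hi):
    """Count values of x in NUM_BINS equal-width bins spanning [lo, hi].

    The result always has shape [NUM_BINS], independent of the data,
    which is what XLA requires. Values outside [lo, hi] are clipped
    into the first or last bin (an assumption of this sketch).
    """
    x = tf.reshape(tf.cast(x, tf.float32), [-1])
    width = (hi - lo) / NUM_BINS
    idx = tf.cast(tf.floor((x - lo) / width), tf.int32)
    idx = tf.clip_by_value(idx, 0, NUM_BINS - 1)
    # one_hot with a constant depth gives a static [N, NUM_BINS] tensor;
    # summing over N yields the per-bin counts.
    return tf.reduce_sum(tf.one_hot(idx, NUM_BINS, dtype=tf.int32), axis=0)


x = tf.constant([0.05, 0.15, 0.15, 0.95])
counts = static_histogram(x, tf.constant(0.0), tf.constant(1.0))
print(counts.numpy())  # [1 2 0 0 0 0 0 0 0 1]
```

The `one_hot` plus `reduce_sum` formulation builds an `[N, NUM_BINS]` intermediate. For very large inputs or many bins, that memory cost may matter more than the convenience of a fully static shape.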
Reproducible example
https://colab.research.google.com/drive/1g9yHihhmcAcwEeE80wWPwyI8W6D6BGfx?usp=sharing
Hi @yellowdolphin, did you try to dig into this a bit? That would be very helpful! As the error suggests, it looks to me like count_integers, and bincount inside it, is causing the trouble. Can you compile these separately?