tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow

Home page: https://www.tensorflow.org/probability/



tfp.stats.histogram cannot be compiled by XLA?

yellowdolphin opened this issue.

Summary of problem

Applying tfp.stats.histogram to the data inside a tf.keras model breaks XLA compilation. The example code (see below) works on CPU and GPU, but under a TPU strategy it raises:

InvalidArgumentError: 9 root error(s) found.
  (0) INVALID_ARGUMENT: {{function_node __inference_train_function_3630}} Input 1 to node `sequential/lambda/histogram/count_integers/map/while/bincount/Bincount` with op Bincount must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

	 [[{{node sequential/lambda/histogram/count_integers/map/while/bincount/Bincount}}]]

	 [[sequential/lambda/histogram/count_integers/map/while]]
	 [[TPUReplicate/_compile/_10395939635067805685/_4]]
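The log says the size input of the Bincount op must be a compile-time constant, i.e. XLA has to know the histogram's output shape before it sees any data. One possible workaround, sketched here under assumptions and not taken from tfp itself, is to avoid Bincount entirely and build the counts from a one-hot encoding whose depth is a static Python integer. The sketch assumes equal-width bins over a known range [lo, hi] and clips out-of-range values into the first and last bins; `uniform_histogram` and `NUM_BINS` are hypothetical names, and unlike `tfp.stats.histogram` it does not accept arbitrary `edges`.

```python
import tensorflow as tf

NUM_BINS = 10  # hypothetical: a static Python int, so XLA sees a fixed output shape


@tf.function(jit_compile=True)
def uniform_histogram(x, lo, hi):
    """Hypothetical helper: counts of x in NUM_BINS equal-width bins over [lo, hi].

    Values below lo are clipped into bin 0 and values at or above hi into the
    last bin, which differs from tfp.stats.histogram's default behavior.
    """
    x = tf.reshape(tf.cast(x, tf.float32), [-1])
    lo = tf.cast(lo, tf.float32)
    hi = tf.cast(hi, tf.float32)
    # Map each value to a (float) bin index, then clip into [0, NUM_BINS - 1].
    idx = tf.floor((x - lo) / (hi - lo) * NUM_BINS)
    idx = tf.clip_by_value(idx, 0.0, NUM_BINS - 1.0)
    idx = tf.cast(idx, tf.int32)
    # one_hot with a static depth yields a [N, NUM_BINS] tensor; summing over N
    # gives per-bin counts without Bincount's data-dependent output size.
    counts = tf.reduce_sum(tf.one_hot(idx, depth=NUM_BINS), axis=0)
    return tf.cast(counts, tf.int32)
```

The trade-off is memory: the one-hot intermediate has shape [N, NUM_BINS], which is usually acceptable for modest bin counts but grows with both the number of values and the number of bins.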

Reproducible example

https://colab.research.google.com/drive/1g9yHihhmcAcwEeE80wWPwyI8W6D6BGfx?usp=sharing

Hi @yellowdolphin, did you try to dig into this a bit? That would be very helpful! As the error suggests, it looks to me like count_integers, and the bincount call inside it, is causing the trouble. Can you compile these separately?
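A rough sketch of how the pieces could be compiled separately, assuming that wrapping one function at a time with `tf.function(..., jit_compile=True)` on CPU reproduces the same XLA constraint as the TPU run (this is an assumption, not something confirmed in the thread). `compiles_with_xla` is a hypothetical helper name.

```python
import tensorflow as tf


def compiles_with_xla(fn, *args):
    """Hypothetical helper: try running fn(*args) under XLA.

    Returns (True, None) on success, or (False, error_text) if TensorFlow
    raises an op error, e.g. an XLA "must be a compile-time constant" failure.
    Compiling one function at a time narrows down which op XLA rejects.
    """
    jitted = tf.function(fn, jit_compile=True)
    try:
        jitted(*args)
        return True, None
    except tf.errors.OpError as err:
        return False, str(err)
```

With `import tensorflow_probability as tfp`, one could then compare, for example, `compiles_with_xla(tfp.stats.count_integers, x)` against `compiles_with_xla(tfp.stats.histogram, x, edges)` for some integer tensor `x` and float tensor `edges`. The outcome depends on the TF/TFP versions installed, so it is not asserted here.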