Broken ExtensionType interoperability with Keras in 2.16
karelhorak-gen opened this issue
Karel Horak commented
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
No
Source
binary
TensorFlow version
v2.16.1-0-g5bc9d26649c 2.16.1
Custom code
No
OS platform and distribution
macOS 14.4.1 (Arm)
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current behavior?
The ExtensionType guide suggests that you can create a tf.keras.layers.Input for an ExtensionType by passing its spec through the type_spec argument. In TF 2.16, however, tf.keras resolves to Keras 3 by default, and its Input no longer accepts type_spec. It is not clear how ExtensionTypes are supposed to integrate with Keras in 2.16, or whether this is possible at all without that argument.
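One quick way to confirm which Input signature tf.keras exposes in a given environment is to inspect it directly. The sketch below is a small diagnostic. The helper names accepts_kwarg and keras_input_supports_type_spec are hypothetical and not part of TensorFlow. It uses only the standard inspect module and imports TensorFlow lazily.

```python
import inspect


def accepts_kwarg(fn, name):
    """Return True if `fn`'s signature declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters


def keras_input_supports_type_spec():
    """Check whether the Input that tf.keras resolves to declares `type_spec`."""
    # Imported lazily so accepts_kwarg() can be used without TensorFlow installed.
    import tensorflow as tf

    return accepts_kwarg(tf.keras.Input, "type_spec")
```

On TF 2.16.1 with the default Keras 3, keras_input_supports_type_spec() would be expected to return False, consistent with the TypeError reported in this issue. When tf.keras points at Keras 2, it should return True.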
Standalone code to reproduce the issue
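# Taken from the guide for reference (https://www.tensorflow.org/guide/extension_type#keras):
import tensorflow as tf


def network_repr(network):
    # Helper used by Network.__repr__ (also defined in the guide).
    work, bandwidth = network.work, network.bandwidth
    if hasattr(work, "numpy"):
        work = " ".join(str(work.numpy()).split())
    if hasattr(bandwidth, "numpy"):
        bandwidth = " ".join(str(bandwidth.numpy()).split())
    return f"<Network shape={network.shape} work={work} bandwidth={bandwidth}>"


class Network(tf.experimental.BatchableExtensionType):
    shape: tf.TensorShape  # batch shape. A single network has shape=[].
    work: tf.Tensor  # work[*shape, n] = work left to do at node n
    bandwidth: tf.Tensor  # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

    def __init__(self, work, bandwidth):
        self.work = tf.convert_to_tensor(work)
        self.bandwidth = tf.convert_to_tensor(bandwidth)
        work_batch_shape = self.work.shape[:-1]
        bandwidth_batch_shape = self.bandwidth.shape[:-2]
        self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

    def __repr__(self):
        return network_repr(self)


input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))

# Fails on TF 2.16.1, where tf.keras resolves to Keras 3:
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    # BalanceNetworkLayer(),
])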
Relevant log output
TypeError: Input() got an unexpected keyword argument 'type_spec'
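A possible stopgap, not an answer to how ExtensionTypes are meant to work with Keras 3: the TF 2.16 release notes describe switching tf.keras back to Keras 2 by installing the tf_keras package and setting the environment variable TF_USE_LEGACY_KERAS=1. Keras 2's Input still has the type_spec argument that the guide relies on. This sketch assumes a tf_keras version matching the installed TensorFlow. The helper name enable_legacy_keras is hypothetical.

```python
import os
import sys
import warnings


def enable_legacy_keras():
    """Ask TensorFlow 2.16+ to route tf.keras to Keras 2 (the tf_keras package).

    Assumes `pip install tf_keras` has been run. The TF 2.16 release notes
    recommend setting TF_USE_LEGACY_KERAS before TensorFlow is imported.
    The variable applies to every package in the current Python process.
    """
    if "tensorflow" in sys.modules:
        # Setting it late may not take effect once tf.keras has been resolved.
        warnings.warn(
            "tensorflow is already imported; set TF_USE_LEGACY_KERAS earlier "
            "if tf.keras still resolves to Keras 3."
        )
    os.environ["TF_USE_LEGACY_KERAS"] = "1"
```

Call enable_legacy_keras() at the very top of the reproducer, before `import tensorflow as tf`. The `tf.keras.layers.Input(type_spec=input_spec)` line should then go to the Keras 2 implementation. Because the setting is process-wide, this only sidesteps Keras 3 and does not fix the missing argument.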