tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone

Home Page: https://tensorflow.org

Broken ExtensionType interoperability with Keras in 2.16

karelhorak-gen opened this issue

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

binary

TensorFlow version

v2.16.1-0-g5bc9d26649c 2.16.1

Custom code

No

OS platform and distribution

MacOS 14.4.1 (Arm)

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

The ExtensionType guide shows that a tf.keras.layers.Input for an ExtensionType can be created by passing the type_spec argument. That argument, however, appears to be missing in Keras 3 / TF 2.16. It is unclear how ExtensionType is meant to integrate with Keras in 2.16, or whether it is possible at all without this argument.

Standalone code to reproduce the issue

# Taken from the guide for reference (https://www.tensorflow.org/guide/extension_type#keras):

import tensorflow as tf

class Network(tf.experimental.BatchableExtensionType):
  shape: tf.TensorShape  # batch shape. A single network has shape=[].
  work: tf.Tensor        # work[*shape, n] = work left to do at node n
  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2

  def __init__(self, work, bandwidth):
    self.work = tf.convert_to_tensor(work)
    self.bandwidth = tf.convert_to_tensor(bandwidth)
    work_batch_shape = self.work.shape[:-1]
    bandwidth_batch_shape = self.bandwidth.shape[:-2]
    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)

  def __repr__(self):
    # network_repr is a helper defined in the guide; it is not called in this repro.
    return network_repr(self)


input_spec = Network.Spec(shape=None,
                          work=tf.TensorSpec(None, tf.float32),
                          bandwidth=tf.TensorSpec(None, tf.float32))
model = tf.keras.Sequential([
    tf.keras.layers.Input(type_spec=input_spec),
    # BalanceNetworkLayer(),
    ])

Relevant log output

TypeError: Input() got an unexpected keyword argument 'type_spec'
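
To check which Input signature is installed, one option is to inspect it directly. The following is a minimal sketch, not part of the guide: `declares_param` and the two stand-in functions are hypothetical helpers written for illustration, and the stand-ins are simplified signatures, not the real Keras ones. The TensorFlow check at the end only runs if TensorFlow is importable.

```python
import inspect


def declares_param(fn, name):
    """Return True if `fn` explicitly declares a parameter called `name`."""
    return name in inspect.signature(fn).parameters


# Hypothetical stand-ins (simplified, NOT the real Keras signatures):
# one Input-like function that declares type_spec and one that does not.
def input_with_type_spec(shape=None, type_spec=None):
    return None


def input_without_type_spec(shape=None, batch_size=None):
    return None


print(declares_param(input_with_type_spec, "type_spec"))
print(declares_param(input_without_type_spec, "type_spec"))

# Optional check against the installed TensorFlow, if there is one.
try:
    import tensorflow as tf
except ImportError:
    tf = None

if tf is not None:
    print("TF version:", tf.__version__)
    print("Input declares type_spec:",
          declares_param(tf.keras.layers.Input, "type_spec"))
```

On the setup reported above (2.16.1), the last line would be expected to report that type_spec is not declared, consistent with the TypeError in the log.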