CyberZHG / keras-self-attention

Attention mechanism for processing sequential data that considers the context for each timestamp.

Home Page: https://pypi.org/project/keras-self-attention/


Compatibility with `tf.keras`

nshaud opened this issue

I have been looking into self-attention with TensorFlow. More specifically, I use the Keras API that is integrated into TensorFlow as the `tf.keras` module.

I have tried both the Sequential and Functional APIs, to no avail:

import tensorflow as tf
from keras_self_attention import SeqSelfAttention  # plain-Keras layer, not tf.keras

text_inputs = tf.keras.layers.Input(shape=(None,))
embd_layer = tf.keras.layers.Embedding(input_dim=VOCAB_SIZE,
                                       output_dim=EMBEDDING_DIM,
                                       mask_zero=True,
                                       trainable=True,
                                       name='Embedding')(text_inputs)
lstm_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=512,
                                                                recurrent_dropout=0.4,
                                                                return_sequences=True),
                                           name='Bi-LSTM')(embd_layer)
attention_layer = SeqSelfAttention(attention_activation='sigmoid',
                                   attention_width=9,
                                   return_attention=False,
                                   name='Attention')(lstm_layer)

returns TypeError: The added layer must be an instance of class Layer. Found: <keras_self_attention.seq_self_attention.SeqSelfAttention object at 0x7f87ee16bd30> (I think because tf.keras expects a `tf.keras.layers.Layer` instance).

And using the Functional API:

text_inputs = tf.keras.layers.Input(shape=(SEQ_LENGTH,))
x = tf.keras.layers.Embedding(VOCAB_SIZE, EMBEDDING_DIM, input_length=SEQ_LENGTH)(text_inputs)
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(512, return_sequences=True))(x)
x = SeqSelfAttention(attention_activation='sigmoid')(x)

returns ValueError: Layer Attention was called with an input that isn't a symbolic tensor. Received type: <class 'tensorflow.python.keras.engine.base_layer.DeferredTensor'>. Full input: [<DeferredTensor 'None' shape=(?, ?, 1024) dtype=float32>]. All inputs to the layer should be tensors

Any clue? Is it because I am using tf.keras instead of standalone Keras?

Add `TF_KERAS=1` to your environment variables.
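
For anyone finding this later, a minimal sketch of one way to do that from inside a script; the variable has to be set before keras_self_attention is first imported, since the package checks it at import time to choose between keras and tf.keras:

import os
os.environ['TF_KERAS'] = '1'  # must be set before the import below

from keras_self_attention import SeqSelfAttention  # now backed by tf.keras layers

Equivalently, set it in the shell when launching, e.g. TF_KERAS=1 python your_script.py (your_script.py being a hypothetical name for whatever script builds the model).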

This does the trick, thanks.
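
With the variable set, the functional-API snippet above builds cleanly. For completeness, a self-contained sketch under that assumption; VOCAB_SIZE, EMBEDDING_DIM and SEQ_LENGTH are placeholder values, not taken from the issue:

import os
os.environ['TF_KERAS'] = '1'  # before importing keras_self_attention

import tensorflow as tf
from keras_self_attention import SeqSelfAttention

VOCAB_SIZE = 10000     # placeholder vocabulary size
EMBEDDING_DIM = 128    # placeholder embedding width
SEQ_LENGTH = 100       # placeholder fixed sequence length

text_inputs = tf.keras.layers.Input(shape=(SEQ_LENGTH,))
x = tf.keras.layers.Embedding(VOCAB_SIZE, EMBEDDING_DIM, input_length=SEQ_LENGTH)(text_inputs)
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(512, return_sequences=True))(x)
x = SeqSelfAttention(attention_activation='sigmoid')(x)
model = tf.keras.Model(inputs=text_inputs, outputs=x)
model.summary()  # the attention layer should now appear as a regular tf.keras layer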