CyberZHG / keras-self-attention

Attention mechanism for processing sequential data that considers the context for each timestamp.

Home Page: https://pypi.org/project/keras-self-attention/


"Tuple index out of range" when using SeqWeightedAttention

Hellisotherpeople opened this issue · comments

elif keras_mode == "RNN":
    model.add(Reshape((1, list_of_embeddings[1].size), input_shape=Emb_train.shape[1:]))
    model.add(Bidirectional(GRU(list_of_embeddings[1].size, activation='relu')))  # this works too - seems to be better for smaller datasets too!
    model.add(SeqWeightedAttention())
    model.add(Dense(len(np.unique(Y_val)), activation='softmax', kernel_initializer=kernel_initializer, use_bias=False))
Traceback (most recent call last):
  File "classification.py", line 182, in <module>
    pipe.fit(X_train, Y_train)
  File "/usr/lib/python3.7/site-packages/sklearn/pipeline.py", line 267, in fit
    self._final_estimator.fit(Xt, y, **fit_params)
  File "/usr/lib/python3.7/site-packages/keras/wrappers/scikit_learn.py", line 210, in fit
    return super(KerasClassifier, self).fit(x, y, **kwargs)
  File "/usr/lib/python3.7/site-packages/keras/wrappers/scikit_learn.py", line 141, in fit
    self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
  File "classification.py", line 144, in create_model
    model.add(SeqWeightedAttention())
  File "/usr/lib/python3.7/site-packages/keras/engine/sequential.py", line 181, in add
    output_tensor = layer(self.outputs[0])
  File "/usr/lib/python3.7/site-packages/keras/engine/base_layer.py", line 431, in __call__
    self.build(unpack_singleton(input_shapes))
  File "/usr/lib/python3.7/site-packages/keras_self_attention/seq_weighted_attention.py", line 27, in build
    self.W = self.add_weight(shape=(int(input_shape[2]), 1),
IndexError: tuple index out of range

I had the same problem with SeqSelfAttention. I tried the fix suggested in your issue tracker instead, but it did not resolve the error.

See #22.

return_sequences=True
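
For reference, a minimal sketch of what the suggested change looks like (the unit count, input dimensions, and class count below are placeholders, not the values from the script above). The error happens because a GRU without return_sequences=True emits only its final state, a 2D tensor of shape (batch, units), so SeqWeightedAttention's build method fails when it reads input_shape[2]; with return_sequences=True the Bidirectional GRU emits a 3D tensor of shape (batch, timesteps, 2 * units) that the attention layer can consume.

from keras.models import Sequential
from keras.layers import Bidirectional, GRU, Dense
from keras_self_attention import SeqWeightedAttention

timesteps, embedding_dim, num_classes = 10, 300, 5  # placeholder dimensions

model = Sequential()
# return_sequences=True keeps the full sequence of hidden states, so the output stays 3D
model.add(Bidirectional(GRU(128, activation='relu', return_sequences=True),
                        input_shape=(timesteps, embedding_dim)))
# SeqWeightedAttention pools (batch, timesteps, 2 * 128) down to (batch, 2 * 128)
model.add(SeqWeightedAttention())
model.add(Dense(num_classes, activation='softmax', use_bias=False))
model.compile(optimizer='adam', loss='categorical_crossentropy')
model.summary()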

@CyberZHG This solution does not work either; what might be the problem?

@Hellisotherpeople what was the solution?