Masking implementation
tomasmenezes opened this issue
Hi @CyberZHG, I'm using self-attention on top of an RNN for a classification problem, but I'm a bit confused by the masking implementation and how it differs among the provided attention types. I apologize in advance for the length of the post.
To test the masking, I created a placeholder tensor to represent the output hidden states from an RNN with T=6 timesteps [t0,...,t5] and D=3 units, where timesteps t2, t4 and t5 are masked:
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Masking

h_states = tf.convert_to_tensor(
    np.array([[[0.5, 0.2, 0.1],
               [0.4, 0.9, 0.3],
               [-1, -1, -1],
               [0.1, 0.2, 0.1],
               [-1, -1, -1],
               [-1, -1, -1]]]), dtype='float32')
masked_states = Masking(mask_value=-1)(h_states)
```
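For reference, the boolean mask that Masking(mask_value=-1) produces can be reproduced in plain NumPy (a sketch of the Keras semantics: a timestep is masked only when all of its features equal mask_value):

```python
import numpy as np

# Hidden states: 1 batch, T=6 timesteps, D=3 units; all-(-1) rows are padding.
h_states = np.array([[[0.5, 0.2, 0.1],
                      [0.4, 0.9, 0.3],
                      [-1., -1., -1.],
                      [0.1, 0.2, 0.1],
                      [-1., -1., -1.],
                      [-1., -1., -1.]]], dtype=np.float32)

# A timestep stays valid if ANY feature differs from the mask value;
# rows equal to [-1, -1, -1] everywhere become False.
mask = np.any(h_states != -1.0, axis=-1)
print(mask)  # [[ True  True False  True False False]]
```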
SeqSelfAttention
When calling the additive or dot attention, I was surprised to find that only the entries a_{i,j} of the [T×T] attention matrix with both i and j in {2, 4, 5} were masked:
SeqSelfAttention(return_attention=True)(masked_states)
[<tf.Tensor: shape=(1, 6, 3), dtype=float32, numpy=
array([[[0.13978598, 0.16256869, 0.06499083],
[0.1412907 , 0.16534805, 0.06593135],
[0.32337117, 0.3761466 , 0.15032762],
[0.14026345, 0.163282 , 0.06523413],
[0.32337117, 0.3761466 , 0.15032762],
[0.32337117, 0.3761466 , 0.15032762]]], dtype=float32)>,
<tf.Tensor: shape=(1, 6, 6), dtype=float32, numpy=
array([[[0.159832 , 0.10862342, 0.18911283, 0.16420609, 0.18911283, 0.18911283],
[0.16049388, 0.11161783, 0.18797404, 0.16396616, 0.18797404, 0.18797404],
[0.36969936, 0.25163805, 0. , 0.37866256, 0. , 0. ],
[0.16022852, 0.10937916, 0.18880568, 0.16397531, 0.18880568, 0.18880568],
[0.36969936, 0.25163805, 0. , 0.37866256, 0. , 0. ],
[0.36969936, 0.25163805, 0. , 0.37866256, 0. , 0. ]]], dtype=float32)>]
- Q1: Shouldn't rows and columns 2, 4 and 5 be masked entirely instead, since those values result from alignments with masked timesteps?
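For context, the behavior I'd expect can be sketched in NumPy by zeroing both the masked rows (queries) and masked columns (keys) of the returned attention matrix (values copied from the output above) and renormalizing the surviving rows. This is my own workaround, not something the library does:

```python
import numpy as np

# Attention weights returned by SeqSelfAttention above.
att = np.array([[0.159832, 0.108623, 0.189113, 0.164206, 0.189113, 0.189113],
                [0.160494, 0.111618, 0.187974, 0.163966, 0.187974, 0.187974],
                [0.369699, 0.251638, 0.      , 0.378663, 0.      , 0.      ],
                [0.160229, 0.109379, 0.188806, 0.163975, 0.188806, 0.188806],
                [0.369699, 0.251638, 0.      , 0.378663, 0.      , 0.      ],
                [0.369699, 0.251638, 0.      , 0.378663, 0.      , 0.      ]],
               dtype=np.float32)
mask = np.array([True, True, False, True, False, False])

# Zero masked rows and columns, then renormalize surviving rows to sum to 1.
att_masked = att * mask[:, None] * mask[None, :]
row_sums = att_masked.sum(axis=-1, keepdims=True)
att_masked = np.where(row_sums > 0, att_masked / np.maximum(row_sums, 1e-9), 0.0)
```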
SeqWeightedAttention
SeqWeightedAttention seems to mask the padding timesteps completely:
SeqWeightedAttention(return_attention=True)(masked_states)
[<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[0.33272028, 0.43565503, 0.16733001]], dtype=float32)>,
<tf.Tensor: shape=(1, 6), dtype=float32, numpy=array([[0.32931313, 0.33665004, 0. , 0.33403683, 0. , 0. ]], dtype=float32)>]
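As a sanity check (my own verification, using the numbers printed above), the returned vector is exactly the weighted sum of the unmasked hidden states, so the -1 padding rows contribute nothing:

```python
import numpy as np

h_states = np.array([[[0.5, 0.2, 0.1],
                      [0.4, 0.9, 0.3],
                      [-1., -1., -1.],
                      [0.1, 0.2, 0.1],
                      [-1., -1., -1.],
                      [-1., -1., -1.]]], dtype=np.float32)
# Attention weights returned above; the padding timesteps got exactly 0.
w = np.array([[0.32931313, 0.33665004, 0., 0.33403683, 0., 0.]],
             dtype=np.float32)

# Weighted sum over timesteps: (1, 6) x (1, 6, 3) -> (1, 3)
out = np.einsum('bt,btd->bd', w, h_states)
print(out)  # approx. [[0.33272 0.43566 0.16733]]
```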
ScaledDotProductAttention
ScaledDotProductAttention returned, as expected, values similar to Keras' tf.keras.layers.Attention(use_scale=True), except at the masked timesteps:
ScaledDotProductAttention(return_attention=True)(masked_states)
[<tf.Tensor: shape=(1, 6, 3), dtype=float32, numpy=
array([[[0.34341848, 0.4522895 , 0.17208272],
[0.3484643 , 0.5025628 , 0.18644652],
[0.33333334, 0.43333334, 0.16666667],
[0.33703578, 0.4488316 , 0.17109475],
[0.33333334, 0.43333334, 0.16666667],
[0.33333334, 0.43333334, 0.16666667]]], dtype=float32)>,
<tf.Tensor: shape=(1, 6, 6), dtype=float32, numpy=
array([[[0.33823597, 0.3604136 , 0. , 0.3013504 , 0. , 0. ],
[0.29698637, 0.43223262, 0. , 0.27078095, 0. , 0. ],
[0.33333334, 0.33333334, 0. , 0.33333334, 0. , 0. ],
[0.32598415, 0.3554737 , 0. , 0.31854212, 0. , 0. ],
[0.33333334, 0.33333334, 0. , 0.33333334, 0. , 0. ],
[0.33333334, 0.33333334, 0. , 0.33333334, 0. , 0. ]]], dtype=float32)>]
Here the mask propagates over the columns but not the rows.
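Zeroing the masked rows after the fact (a manual workaround, sketched with the values printed above) reproduces the row behavior I'd expect:

```python
import numpy as np

# Context vectors returned by ScaledDotProductAttention above; rows
# 2, 4 and 5 still hold uniform averages of the unmasked states.
ctx = np.array([[[0.34341848, 0.4522895 , 0.17208272],
                 [0.3484643 , 0.5025628 , 0.18644652],
                 [0.33333334, 0.43333334, 0.16666667],
                 [0.33703578, 0.4488316 , 0.17109475],
                 [0.33333334, 0.43333334, 0.16666667],
                 [0.33333334, 0.43333334, 0.16666667]]], dtype=np.float32)
mask = np.array([[True, True, False, True, False, False]])

# Broadcast the timestep mask over the feature axis to zero padded rows.
ctx_masked = ctx * mask[..., None]
```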
Keras Dot Attention
Finally, even though its implementation is supposedly not supported for RNNs (per the code documentation), the result is closest to the behavior I expected: the values at the masked timesteps are zeroed out entirely:
Attention(use_scale=True)([masked_states,masked_states])
<tf.Tensor: shape=(1, 6, 3), dtype=float32, numpy=
array([[[0.35038543, 0.46623248, 0.17606643],
[0.35869 , 0.5558893 , 0.20168266],
[0. , 0. , 0. ],
[0.3397184 , 0.46044892, 0.17441398],
[0. , 0. , 0. ],
[0. , 0. , 0. ]]], dtype=float32)>
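This behavior can be reproduced with a small NumPy sketch of plain dot-product self-attention that masks both keys (before the softmax) and queries (after), assuming the untrained scale factor is 1:

```python
import numpy as np

def masked_dot_attention(x, mask):
    """Dot-product self-attention with query and key masking,
    mimicking an untrained tf.keras.layers.Attention (scale == 1)."""
    scores = x @ np.swapaxes(x, -1, -2)                 # (B, T, T)
    scores = np.where(mask[:, None, :], scores, -1e9)   # mask keys
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    att = e / e.sum(axis=-1, keepdims=True)
    ctx = att @ x                                       # (B, T, D)
    return ctx * mask[..., None]                        # mask queries

h_states = np.array([[[0.5, 0.2, 0.1],
                      [0.4, 0.9, 0.3],
                      [-1., -1., -1.],
                      [0.1, 0.2, 0.1],
                      [-1., -1., -1.],
                      [-1., -1., -1.]]], dtype=np.float32)
mask = np.any(h_states != -1.0, axis=-1)
ctx = masked_dot_attention(h_states, mask)
print(ctx[0, 0])  # approx. [0.3504 0.4662 0.1761], matching the Keras output
```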
- Q2: Is there a need to multiply the output of SeqSelfAttention or ScaledDotProductAttention by the initial mask before summing over the timestep dimension to obtain a final vector?
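Regarding Q2, the defensive pattern I have in mind (my own sketch, not something from the library) is to apply the mask explicitly before pooling, e.g. a masked mean over timesteps:

```python
import numpy as np

def masked_mean_pool(seq_out, mask):
    """Average attention outputs over time, ignoring masked timesteps."""
    m = mask[..., None].astype(seq_out.dtype)   # (B, T, 1)
    summed = (seq_out * m).sum(axis=1)          # (B, D)
    counts = m.sum(axis=1)                      # (B, 1) valid-step counts
    return summed / np.maximum(counts, 1.0)     # avoid division by zero

# Toy check: the masked row carries garbage that must not leak into the pool.
out = np.array([[[1., 2.], [3., 4.], [9., 9.]]])
mask = np.array([[True, True, False]])
print(masked_mean_pool(out, mask))  # [[2. 3.]]
```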
[edit: question wording, example removed]