Transformer decoder target mask wrong shape error
kashif opened this issue
Hello, so I have an encoder-decoder setup with a `tgt_mask` in the decoder as follows:

```python
# enc input [B, C, E]
encoder_association = HopfieldLayer(input_size=E, num_heads=num_heads)
encoder_layer = HopfieldEncoderLayer(
    encoder_association,
    dim_feedforward=E * 2,
    dropout=dropout_rate,
    activation=act_type,
)
transformer_encoder = nn.TransformerEncoder(
    encoder_layer, num_layers=num_encoder_layers
)

# dec input [B, P, E]
decoder_association_self = HopfieldLayer(
    input_size=E, num_heads=num_heads
)
decoder_association_cross = HopfieldLayer(
    input_size=P, num_heads=num_heads
)
decoder_layer = HopfieldDecoderLayer(
    hopfield_association_self=decoder_association_self,
    hopfield_association_cross=decoder_association_cross,
    dim_feedforward=E * 2,
    dropout=dropout_rate,
    activation=act_type,
)
transformer_decoder = nn.TransformerDecoder(
    decoder_layer, num_layers=num_decoder_layers
)

# Transformer
transformer = nn.Transformer(
    d_model=E,
    nhead=num_heads,
    custom_encoder=transformer_encoder,
    custom_decoder=transformer_decoder,
    batch_first=True,
)
```
I create the mask via:

```python
tgt_mask = transformer.generate_square_subsequent_mask(P)
```
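For reference, `generate_square_subsequent_mask(P)` returns a `[P, P]` additive causal mask: 0 on and below the diagonal, `-inf` above it. A torch-free sketch of the same pattern (the helper name `square_subsequent_mask` is mine, not the library's):

```python
def square_subsequent_mask(p):
    """Build a p x p additive causal mask as nested lists:
    0.0 where position i may attend to position j (j <= i),
    -inf where j lies in the future (j > i)."""
    return [
        [0.0 if j <= i else float("-inf") for j in range(p)]
        for i in range(p)
    ]

mask = square_subsequent_mask(3)
# Row 0 can only attend to position 0; row 2 can attend to all three.
```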
And when I run it I get:

```
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
--> 314 raise RuntimeError('The size of the 2D attn_mask is not correct.')
```
So, for example, for P=28 the shapes are:

```python
attn_mask.shape  # torch.Size([1, 28, 28])
query.shape      # torch.Size([28, B, E])
key.shape        # torch.Size([1, B, E])
```
for some reason, even though the inputs to the decoder have the expected shapes:

```python
dec_output = transformer.decoder(
    dec_input,          # [B, P, E]
    enc_out,            # [B, C, E]
    tgt_mask=tgt_mask,  # [P, P]
)
```
Would you know what I am missing? Thanks!
Hi @kashif,

for the default encoder-decoder Transformer setting, one has to use `Hopfield` instead of `HopfieldLayer`, as the latter uses learnable parameters as the inputs for the key and query. See

hopfield-layers/hflayers/__init__.py, Lines 12 to 15 in 1497a4d

and

hopfield-layers/hflayers/__init__.py, Lines 619 to 623 in 1497a4d

for more information. Moreover, the input size of `decoder_association_cross` needs to be equal to the number of features of a single instance/token, which is E in your case:

```python
decoder_association_cross = Hopfield(input_size=E, num_heads=num_heads)
```

Please let me know if the issue is resolved.
Thanks @bschaefl, let me check and get back to you!
@bschaefl yes, sorry, it works now after replacing all the `HopfieldLayer`s with `Hopfield` and applying the fix. Thanks!