ml-jku / hopfield-layers

Hopfield Networks is All You Need

Home Page: https://ml-jku.github.io/hopfield-layers/

Transformer decoder target mask wrong shape error

kashif opened this issue

Hello, so I have an encoder-decoder setup with a tgt_mask in the decoder as follows:

# Imports added for completeness (paths assume the hflayers package of this repository).
import torch.nn as nn
from hflayers import HopfieldLayer
from hflayers.transformer import HopfieldEncoderLayer, HopfieldDecoderLayer

# Encoder: input of shape [B, C, E]
encoder_association = HopfieldLayer(input_size=E, num_heads=num_heads)
encoder_layer = HopfieldEncoderLayer(
    encoder_association,
    dim_feedforward=E * 2,
    dropout=dropout_rate,
    activation=act_type,
)
transformer_encoder = nn.TransformerEncoder(
    encoder_layer, num_layers=num_encoder_layers
)

# Decoder: input of shape [B, P, E]
decoder_association_self = HopfieldLayer(input_size=E, num_heads=num_heads)
decoder_association_cross = HopfieldLayer(input_size=P, num_heads=num_heads)
decoder_layer = HopfieldDecoderLayer(
    hopfield_association_self=decoder_association_self,
    hopfield_association_cross=decoder_association_cross,
    dim_feedforward=E * 2,
    dropout=dropout_rate,
    activation=act_type,
)
transformer_decoder = nn.TransformerDecoder(
    decoder_layer, num_layers=num_decoder_layers
)

# Transformer combining the custom encoder and decoder
transformer = nn.Transformer(
    d_model=E,
    nhead=num_heads,
    custom_encoder=transformer_encoder,
    custom_decoder=transformer_decoder,
    batch_first=True,
)

I create the mask via:

tgt_mask = transformer.generate_square_subsequent_mask(P)
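For reference, generate_square_subsequent_mask returns a square float mask with zeros on and below the diagonal and -inf above it. A minimal sketch of an equivalent construction, using a small P purely for illustration:

import torch

P = 4  # small value for illustration; the issue uses P = 28
# Equivalent to nn.Transformer.generate_square_subsequent_mask(P):
# zeros on and below the diagonal, -inf above it.
tgt_mask = torch.triu(torch.full((P, P), float('-inf')), diagonal=1)
print(tgt_mask.shape)  # torch.Size([4, 4]), i.e. [P, P]

Before the shape check, the attention code unsqueezes this 2D mask to a 3D one, which matches the torch.Size([1, 28, 28]) reported in the error below.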

And when I run it I get:

 if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
--> 314                         raise RuntimeError('The size of the 2D attn_mask is not correct.')

So, for example, with P=28 I have:

attn_mask.shape
torch.Size([1, 28, 28])

and query has shape:

query.shape
torch.Size([28, B, E])

and key is:

key.shape
torch.Size([1, B, E])

For some reason this happens even though the inputs to the decoder have the expected shapes:

dec_output = transformer.decoder(
    dec_input,          # [B, P, E]
    enc_out,            # [B, C, E]
    tgt_mask=tgt_mask,  # [P, P]
)

Would you know what I am missing? Thanks!

Hi @kashif,

for the default encoder-decoder Transformer setting, one has to use Hopfield instead of HopfieldLayer, as the latter uses trainable parameters as the stored patterns (keys) and pattern projections (values) instead of taking them from a second input. See

class Hopfield(Module):
    """
    Module with underlying Hopfield association.
    """

and

class HopfieldLayer(Module):
    """
    Wrapper class encapsulating a trainable but fixed stored pattern, pattern projection and "Hopfield" in
    one combined module to be used as a Hopfield-based pooling layer.
    """

for more information. Moreover, the input size of decoder_association_cross needs to be equal to the number of features of a single instance/token, which is E in your case:

decoder_association_cross = Hopfield(input_size=E, num_heads=num_heads)
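For completeness, a minimal sketch of how the corrected decoder setup might look with this change (both associations built with Hopfield and input_size=E; the import paths assume the hflayers package of this repository, and the remaining names are the placeholders from the snippet in the issue; the encoder association changes analogously):

from hflayers import Hopfield
from hflayers.transformer import HopfieldDecoderLayer

# Both the self- and the cross-association operate on token features of size E.
decoder_association_self = Hopfield(input_size=E, num_heads=num_heads)
decoder_association_cross = Hopfield(input_size=E, num_heads=num_heads)
decoder_layer = HopfieldDecoderLayer(
    hopfield_association_self=decoder_association_self,
    hopfield_association_cross=decoder_association_cross,
    dim_feedforward=E * 2,
    dropout=dropout_rate,
    activation=act_type,
)
transformer_decoder = nn.TransformerDecoder(
    decoder_layer, num_layers=num_decoder_layers
)
# tgt_mask = transformer.generate_square_subsequent_mask(P), of shape [P, P],
# is then accepted by the decoder's self-association.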

Please let me know if the issue is resolved.

Thanks @bschaefl, let me check and get back to you!

@bschaefl yes, sorry, it works now after replacing all the HopfieldLayer modules with Hopfield and applying the fix. Thanks!