lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

How can I add custom attention masks to a Decoder?

DerEchteFeuerpfeil opened this issue · comments

Hey there, I am in a position where I need to mask out certain positions of my input manually, so the usual causal mask does not suffice. I am using the Decoder class and not the TransformerWrapper, since I need to feed in the embeddings manually.
```python
self.transformer = Decoder(
    dim = self.dim,
    depth = self.depth,
    heads = self.heads,
    attn_flash = True
)
```

I tried to manually create an attention mask following the usual procedure, but when I pass it into the forward function, my outputs all become NaN:
[screenshot: the outputs are all NaN]

There is also the mask argument besides the attention mask; can I just use that instead?
E.g. pass a torch.ones(1, 256).bool().to(device) mask for my (1, 256, 512) embedding input, which would then internally be converted to the attention mask?

I would greatly appreciate it if you could provide code, or the shape I need to create, for the default causal attention mask that I can then modify. I am also not quite sure whether it needs to be boolean, and whether True means attend and False means do not attend.

@DerEchteFeuerpfeil hey Moritz! i had the same confusion as you when starting out in the field, about how the mask is represented

so i decided i would not perpetuate this confusion. for all my repos, without exception, masking is always True for attend and False for not. this makes sense to me because, if one casts it to a float, one can also multiply it with the input to effectively mask in non-attention scenarios
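As a toy illustration of that convention (made-up tensors and shapes, not code from this repo): because True casts to 1. and False to 0., the very same boolean mask can also be used multiplicatively outside of attention.

```python
import torch

# toy illustration of the convention: True = attend / keep, False = mask out
mask = torch.tensor([True, True, False, True])   # (seq_len,) boolean mask
x = torch.randn(4, 8)                            # (seq_len, dim) toy features

# True casts to 1. and False to 0., so the same mask can zero out
# masked positions by multiplication, outside of any attention computation
x_masked = x * mask.unsqueeze(-1).float()
```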

in your example, it should work if you just invert your mask with `~`, but let me know if it does not
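For reference, here is a minimal sketch of what the corrected call could look like. It assumes the Decoder forward accepts a padding mask via mask and a boolean attention mask via attn_mask with True = attend, as discussed above; the shapes and the blocked positions are placeholders, and the exact way the custom mask is combined with the internal causal mask is an assumption worth verifying against the repo.

```python
import torch
from x_transformers import Decoder

dim, depth, heads = 512, 6, 8
batch, seq_len = 1, 256

decoder = Decoder(dim = dim, depth = depth, heads = heads, attn_flash = True)

x = torch.randn(batch, seq_len, dim)  # manually computed embeddings

# boolean attention mask: True = attend, False = do not attend
# the Decoder is already causal, so this only needs to express the *extra*
# restrictions (assumption: it is AND-ed with the internal causal mask)
attn_mask = torch.ones(seq_len, seq_len, dtype = torch.bool)
attn_mask[:, 10:20] = False  # e.g. block attention to positions 10..19 (placeholder)

# padding mask over the sequence: True = real token, shape (batch, seq_len)
padding_mask = torch.ones(batch, seq_len, dtype = torch.bool)

out = decoder(x, mask = padding_mask, attn_mask = attn_mask)  # (1, 256, 512)
```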

Hey @lucidrains , thanks for the quick reply!

Looks like I was so close to getting it right 😉 This implementation also makes the most sense to me.
Works like a charm now 👍

@DerEchteFeuerpfeil nice! go train something amazing :)