XL-recurrence with RotaryEmbedding and mems not working correctly.
pfeatherstone opened this issue
Note, this follows on from #216
I am trying to do XL-recurrence with RotaryEmbedding, attn_num_mem_kv > 0, mems, and return_mems.
I'm doing a test which checks that the outputs when passing mems=None and mems=torch.zeros(...) are the same. They are not.
I'm using the code below:
    import torch
    from x_transformers import ContinuousTransformerWrapper, Encoder

    lm = ContinuousTransformerWrapper(
        dim_in = 2,
        dim_out = 36,
        max_seq_len = 0,
        max_mem_len = 100,
        attn_layers = Encoder(
            dim = 512,
            depth = 4,
            heads = 4,
            rotary_pos_emb = True,
            attn_flash = True,
            attn_num_mem_kv = 20
        )
    )

    B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

    x = torch.randn(B, 1024, 2)
    length = torch.randint(100, x.shape[1], size=(x.shape[0],))
    mask = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)
    mems = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]

    # Run once without memories and once with all-zero memories
    out1, mems1 = lm(x, mask=mask, return_mems=True)
    out2, mems2 = lm(x, mask=mask, mems=mems, return_mems=True)

    # Both runs are expected to produce the same outputs and memories
    torch.testing.assert_close(out1, out2)
    for m1, m2 in zip(mems1, mems2):
        torch.testing.assert_close(m1, m2)
I also tried changing x-transformers/x_transformers/x_transformers.py, lines 882 to 884 in 583c19d, to
    if exists(input_mask) and exists(mem):
        attend = torch.any(mem)
        input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)
but that doesn't help.
Any ideas?
I also tried changing:
to
freqs = freqs[:seq_len, :]
That made more sense to me. I think this makes the results match a bit better but not perfectly.
If I set
    use_abs_pos_emb=True,
    rotary_pos_emb=False
and keep the suggested change to x-transformers/x_transformers/x_transformers.py (lines 882 to 884 in 583c19d):
    if exists(input_mask) and exists(mem):
        attend = torch.any(mem)
        input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)
then it works.
My understanding was that RotaryEmbedding should work in this case. Maybe not.
@lucidrains can you confirm?
@pfeatherstone hey, does the equality work if you turn off rotary embeddings?
If I use
    use_abs_pos_emb=True,
    rotary_pos_emb=False
with the suggested change, it works.
If I use
    rotary_pos_emb=False
it attempts to use AbsolutePositionalEmbedding, which I don't really want.
nice yea, i think i may know what's up. will look into it when i find a stretch of free time
Can you give me a hint? I can try figure out the details
@pfeatherstone i think the memories should be kept at negative positions, so say you have 2 memory tokens and 5 main tokens, the positions should be [-2, -1, 0, 1, 2, 3, 4] instead of [0..7). could be wrong, need to reread my code
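The suggestion above can be illustrated with a small standard-library sketch. It is not the x-transformers implementation: `rotate`, `rotary_score`, `positions_with_memory`, and the frequency `THETA` are hypothetical names chosen for illustration, and it assumes positions are laid out in ascending order, as `torch.arange(-M, T)` would give. It shows two things: with memories at negative positions the main tokens keep positions `0..T-1` whether or not memories are present, and a rotary attention score depends only on the offset between query and key positions.

```python
import math

THETA = 0.5  # arbitrary rotation frequency chosen for this illustration


def rotate(vec, pos, theta=THETA):
    """Rotate a 2-D vector by the angle pos * theta (one rotary frequency pair)."""
    x, y = vec
    angle = pos * theta
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def rotary_score(q, q_pos, k, k_pos):
    """Dot product of a rotated query with a rotated key."""
    rq = rotate(q, q_pos)
    rk = rotate(k, k_pos)
    return rq[0] * rk[0] + rq[1] * rk[1]


def positions_with_memory(mem_len, seq_len):
    """Memories take negative positions; main tokens always start at 0."""
    return list(range(-mem_len, seq_len))


q, k = (1.0, 2.0), (0.5, -1.5)

print(positions_with_memory(2, 5))  # [-2, -1, 0, 1, 2, 3, 4]
print(positions_with_memory(0, 5))  # [0, 1, 2, 3, 4]

# The query sits 3 steps after the key in both cases, so the scores match
# even though the absolute positions differ.
print(math.isclose(rotary_score(q, 3, k, 0), rotary_score(q, 1, k, -2)))  # True
```

Because the main tokens are rotated by the same angles in both runs, a comparison between `mems=None` and masked-out zero memories no longer depends on how many memory slots are prepended.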
I will give it a go
@pfeatherstone what is the magnitude of the error?
The absolute error is around 0.008 on average
@pfeatherstone ok, it is likely what i said then, if you meant 'max' instead of 'average'
sorry, it's actually larger. More like 0.4 max absolute difference.
i'll still try what you suggested
@lucidrains Yes it worked!
So the total changes are:
    if exists(input_mask) and exists(mem):
        attend = torch.any(mem)
        input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)
at line 882 of x_transformers.py, and
    if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
        M = max(list(map(lambda m: m.shape[1] if exists(m) else 0, mems)))
        T = x.shape[1]
        t = torch.arange(-M, T)
        rotary_pos_emb = self.rotary_pos_emb.forward(t)
at line 1257 of x_transformers.py.
@pfeatherstone 👏 💯 you mvp
want to try submitting a PR?
I can do. Without unit tests, PRs are easy ;)
Only thing is that some of the changes aren't ONNX-export friendly...
Also, the line
    attend = torch.any(mem)
doesn't work when only some of the batch items have non-zero mems, because it reduces over the whole batch. You would need to pad differently for each batch item. I'm looking into a fix.
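One way to picture the per-batch-item padding described above is the framework-free sketch below. It uses nested Python lists instead of tensors so the logic is easy to check, and `pad_mask_with_mems` is a hypothetical helper, not a function from x-transformers. It keeps the thread's convention that an all-zero memory should not be attended to.

```python
def pad_mask_with_mems(input_mask, mems):
    """Prepend memory slots to each row of a boolean attention mask.

    input_mask: list of B rows, each a list of T bools.
    mems:       list of B memories, each a list of M feature vectors.
    The memory slots of a batch item are marked True only if that item's
    own memory contains a non-zero value.
    """
    padded = []
    for row, mem in zip(input_mask, mems):
        # The decision is made per batch item, not once for the whole batch.
        attend = any(value != 0 for slot in mem for value in slot)
        padded.append([attend] * len(mem) + list(row))
    return padded


mask = [[True, False], [True, True]]
mems = [
    [[0.0, 0.0]],  # batch item 0: all-zero memory, should be ignored
    [[0.0, 1.0]],  # batch item 1: real memory, should be attended to
]

print(pad_mask_with_mems(mask, mems))
# [[False, True, False], [True, True, True]]
```

With a single batch-wide flag, item 0 would also attend to its zero memory as soon as item 1 carried a real one, which is the failure mode described above.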
ok, at the very least you got it working for your case
this isn't really that big of a deal
i'll make the correction for rotary when i find some time
thanks for taking the initiative and working it out
So I've fixed two things: zero mems now behave the same as not attending to mems at all, and the rotary embeddings are correct.
The second issue I've come across is that mems are recorded before the pre-norm layer normalization, yet on the next iteration they are prepended after it.
I tested it, and I was getting gibberish. I've fixed the issue by recording new mems exactly where old mems are prepended. Now I get sensible results. FYI, I'm using sandwich norm, which uses pre-LN.
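The mismatch described above can be reproduced with a toy pre-norm block. This is only a sketch under simplifying assumptions (no learned scale or bias, no attention, plain lists instead of tensors), and `layer_norm`, `block_inputs`, and `mean_of` are hypothetical names rather than x-transformers code. It compares memories recorded before the norm with memories recorded at the point where they are later prepended.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Normalize one token vector to zero mean and (almost) unit variance."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def block_inputs(hidden, mems):
    """Toy pre-norm block: normalize the current tokens, then prepend memories.

    Returns the key/value inputs and the normalized tokens.
    """
    normed = [layer_norm(token) for token in hidden]
    return mems + normed, normed


def mean_of(vec):
    return sum(vec) / len(vec)


segment_1 = [[10.0, 12.0, 17.0], [-4.0, 0.0, 8.0]]
segment_2 = [[3.0, 5.0, 1.0]]

_, normed_1 = block_inputs(segment_1, mems=[])

# Mismatched: memories recorded before the norm, prepended after it.
kv_mismatched, _ = block_inputs(segment_2, mems=segment_1)

# Consistent: memories recorded exactly where they are later prepended.
kv_consistent, _ = block_inputs(segment_2, mems=normed_1)

print(all(abs(mean_of(t)) < 1e-6 for t in kv_mismatched))  # False
print(all(abs(mean_of(t)) < 1e-6 for t in kv_consistent))  # True
```

In the mismatched case the attention layer sees un-normalized memory tokens next to normalized current tokens, so the two groups live on different scales.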
@pfeatherstone ah interesting, you find sandwich norms helpful? i tried it for a contracting project and saw worse results
@pfeatherstone yea i saw your PR, but i think it may need to be broken up. i think the zero mems is better dealt with with a mem_mask input
To be honest i don't know anymore. The first time I tried it I thought it helped. But it could have been a coincidence.
Ok cool. Though the code does create an appropriate mask, it assumes that all-zero mems shouldn't be attended to. I think that's a sensible default. Would someone want to explicitly attend to zeros?
@pfeatherstone i don't think there would be any issue, just that a mem_mask would lead to more flexibility, and solve your problem with needing an initial zero mems, which i assume is onnx related
just rerun twice with a change of a boolean and you'll have your answer
Yeah, it takes a couple of days for my models to train. There is a lot of augmentation and therefore randomness. Every time I run an experiment, without changing any parameters, convergence happens at different times and of course I get wildly different results.
So when i'm looking at convergence, it's hard to know if an improvement was sheer luck or a model enhancement. Stability on the other hand is pretty tied to the architecture. In my case, with or without sandwich norm, stability is the same.
could you try the latest version? below runs fine for me now
    import torch
    from x_transformers import ContinuousTransformerWrapper, Encoder
    from x_transformers import ContinuousAutoregressiveWrapper

    lm = ContinuousTransformerWrapper(
        dim_in = 2,
        dim_out = 36,
        max_seq_len = 0,
        max_mem_len = 100,
        attn_layers = Encoder(
            dim = 512,
            depth = 4,
            heads = 4,
            rotary_pos_emb = True,
            attn_flash = True,
            attn_num_mem_kv = 20
        )
    )

    B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

    x = torch.randn(B, 1024, 2)
    length = torch.randint(100, x.shape[1], size=(x.shape[0],))
    mask = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)
    mems = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]
    mem_masks = [torch.zeros(x.shape[0], M, dtype = torch.bool) for _ in range(depth)] # memory mask

    out1, mems1 = lm(x, mask=mask, return_mems=True)
    out2, mems2 = lm(x, mask=mask, mems=mems, mem_masks = mem_masks, return_mems=True)

    torch.testing.assert_close(out1, out2)
    for m1, m2 in zip(mems1, mems2):
        torch.testing.assert_close(m1, m2)
@pfeatherstone you let me know what you see when you rerun the sandwich norm experiments. thinking about removing it
ok, i'm going to close this issue, i think it is good now
@pfeatherstone noticed you are using an Encoder instead of a Decoder in your example code. you have a working model based on this idea?
I'm actually using a Decoder. I used Encoder for the repro to make things simpler
@pfeatherstone ahh got it, you are using it correctly then, just checking
Out of interest, why would it not be ok to use this with Encoder? The only difference between Encoder and Decoder is whether the mask is causal (triangular) or not. I use Decoder mainly because I don't want to attend to "future" tokens, which is desirable in a streaming architecture.
@pfeatherstone depends on how you are sampling it
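The Encoder/Decoder distinction discussed above can be made concrete with a small mask sketch. It assumes the usual XL-style layout in which memory keys are prepended to the keys of the current segment; `attention_mask` is a hypothetical helper written for illustration, not the library's masking code.

```python
def attention_mask(seq_len, mem_len, causal):
    """Return a seq_len x (mem_len + seq_len) grid; True means "may attend".

    Key columns 0..mem_len-1 are memories, the remaining columns are the
    current segment. With causal=True a query at position i sees every
    memory slot and current tokens 0..i, but no "future" tokens.
    """
    rows = []
    for i in range(seq_len):
        row = []
        for j in range(mem_len + seq_len):
            key_pos = j - mem_len  # negative for memory slots
            row.append((not causal) or key_pos <= i)
        rows.append(row)
    return rows


for row in attention_mask(seq_len=3, mem_len=2, causal=True):
    print("".join("x" if allowed else "." for allowed in row))
# xxx..
# xxxx.
# xxxxx
```

With `causal=False` every entry is `True`, so each token also sees later tokens of the same segment. That difference matters when a model is sampled or streamed step by step, since later tokens are not available at that point.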