[Feature request] support num_memory_tokens in ContinuousTransformerWrapper
pfeatherstone opened this issue
pfeatherstone commented
Can we support either num_memory_tokens or null key/value in ContinuousTransformerWrapper, please?
Hugo Flores García commented
+1!
Phil Wang commented
@pfeatherstone @hugofloresgarcia you can already use null key/values by setting attn_num_mem_kv = {num null k/v} on either the Encoder or the Decoder
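For example, something like this (a minimal sketch; the value 16 is arbitrary, not from this thread):

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in  = 4,
    dim_out = 4,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        attn_num_mem_kv = 16   # 16 learned null key/value pairs added to each attention layer
    )
)

x = torch.randn(2, 1024, 4)
out = model(x)   # (2, 1024, 4)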
Phil Wang commented
yup, I can add it
wow, the continuous wrapper is very popular! I had no idea
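For reference, memory tokens already work in the token-based TransformerWrapper, roughly like this (values are illustrative):

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    num_memory_tokens = 20,   # 20 learned memory tokens prepended to every sequence
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4
    )
)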
pfeatherstone commented
I think there is a bug. I'll knock up a quick repro.
pfeatherstone commented
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in  = 4,
    dim_out = 256+3,
    max_seq_len = 0,           # no absolute positional embedding; rotary is used instead
    num_memory_tokens = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb = True,
        attn_flash = True,
        use_scalenorm = True,
        attn_onnxable = True,
        shift_tokens = 1
    )
)

x = torch.randn(2, 1024, 4)
# per-sample lengths and the corresponding boolean padding mask
l = torch.randint(100, x.shape[1], size=(x.shape[0],))
m = torch.arange(x.shape[1]).unsqueeze(0) < l.unsqueeze(-1)
x = lm(x, mask=m)
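If the failure is the user-supplied mask not covering the prepended memory tokens, the fix inside the wrapper would presumably need to pad the mask, something like this (a hypothetical sketch, not the actual patch):

import torch.nn.functional as F

num_mem = 20                                      # num_memory_tokens from the repro above
# hypothetical: left-pad the mask so the memory token positions are always attended to
m_padded = F.pad(m, (num_mem, 0), value = True)   # (batch, num_mem + seq_len)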
pfeatherstone commented
I'll file a new bug
Phil Wang commented
@pfeatherstone oh oops, yup, should be fixed in 1.23.4