lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

[Feature request] support num_memory_tokens in ContinuousTransformerWrapper

pfeatherstone opened this issue

Can we support either num_memory_tokens or null key/value in ContinuousTransformerWrapper please?

@pfeatherstone @hugofloresgarcia you can already use null key values by setting attn_num_mem_kv = {num null k/v} on either the Encoder or Decoder
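For reference, the null key/value route looks roughly like this (a minimal sketch; the dims and the attn_num_mem_kv count are illustrative, not taken from this thread):

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in      = 4,
    dim_out     = 4,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        attn_num_mem_kv = 2     # learned memory (null) key/value pairs added to every attention layer
    )
)

x   = torch.randn(1, 1024, 4)   # (batch, seq, dim_in)
out = model(x)                  # (1, 1024, 4)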

yup i can add it

wow, the continuous wrapper is very popular! had no idea

I think there is a bug. I'll knock up a quick repro:

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in              = 4,
    dim_out             = 256 + 3,
    max_seq_len         = 0,    # rotary positions are used, so no absolute positional embedding is needed
    num_memory_tokens   = 20,   # the option this issue asks for; triggers the bug below
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        use_scalenorm   = True,
        attn_onnxable   = True,
        shift_tokens    = 1
    )
)

x   = torch.randn(2, 1024, 4)                                   # (batch, seq, dim_in)
l   = torch.randint(100, x.shape[1], size=(x.shape[0],))        # random valid length per sample
m   = torch.arange(x.shape[1]).unsqueeze(0) < l.unsqueeze(-1)   # boolean padding mask, (batch, seq)
out = lm(x, mask = m)

I'll file a new bug

@pfeatherstone oh oops, yup, should be fixed in 1.23.4