lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

Question: How to implement rel_pos_bias in cross_attention?

alexdemartos opened this issue

Hi,

I am planning to implement relative positional encoding for the 'c' (cross-attention) AttentionLayer.

In my case, the target and context sequences have the same length and are synchronous, so I hope the relative positional encoding will help the attention focus on the corresponding part of the context.

I tried passing rel_pos to the 'c' block, changing the call from

out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)

to:

elif layer_type == 'c':
    out, inter = block(x, context = context, mask = mask, context_mask = context_mask, rel_pos = self.rel_pos, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)

This works well for training, but the inference results are garbage. Could it be a caching issue? Any help is much appreciated.

Thanks in advance!

@alexdemartos oh that's interesting, you are the second researcher within the last week that's asked about positions for cross attention

could you expand on what you are doing? i assume this is audio related, given we last interacted when you were using soundstream over at audiolm-pytorch!

are the source and target sequences aligned token by token? also, what type of relative positions are you using?

@alexdemartos you can rule out caching issues by disabling it with this flag

Hi @lucidrains, you have an excellent memory! :)

Thanks for your quick response. Let me be more precise:

I am using a Transformer Decoder to auto-regressively predict real-valued phoneme-level duration, pitch and energy. I found your ContinuousAutoregressiveWrapper pretty handy for the task, with just minor mods.

The context C (encoded phonemes) and the targets (duration/pitch/energy) both have length L. First I tried just setting rel_pos=True, but the inference predictions were noisy. I realized that no positional information is added to the context, suspected this might be the issue, and tried adding rel_pos to the cross-attention block. This seemed like a good option given that C and the targets are time-synchronous.

However, I see the same behaviour as before adding rel_pos to the cross-attention block: training works well, but inference seems broken.

I will try disabling caching as you suggested. Thanks for your time!

P.S.: Training vs. inference pitch contours

https://ibb.co/7jJr8gr
https://ibb.co/vzyX9H3

@alexdemartos that's an interesting use case for ContinuousAutoregressiveWrapper! yes, do let me know if disabling cache fixes it or not, and i'll throw some brain cycles into this issue. no guarantees though, as your use case is a bit off the beaten path

@alexdemartos how did turning off the caching go? i thought of a way to generalize relative positions within the attention blocks, so just let me know

Hi @lucidrains. Thanks for following up. I didn't manage to turn off the flag you mentioned, as it looks like it is not available for the ContinuousAutoregressiveWrapper:

https://github.com/lucidrains/x-transformers/blob/b2979195ba496532eb0b7f52616eed178848d8af/x_transformers/continuous.py#L153C28-L153C28

I still haven't managed to get inference working. I tried implementing mask_prob with a large masking probability (0.5) to mitigate the exposure bias caused by teacher forcing, but it didn't seem to help significantly.
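The idea of masking teacher-forced inputs can be sketched as follows: during training, a random fraction of the decoder's input positions is hidden from attention, so the model cannot rely on every ground-truth prefix value. This is a minimal PyTorch sketch assuming a boolean keep-mask over the decoder inputs (True = visible); `random_keep_mask` is a hypothetical helper, not the wrapper's actual `mask_prob` code, and how the mask is fed to the decoder is left open.

```python
import torch


def random_keep_mask(batch, seq_len, mask_prob, device=None):
    # Hypothetical helper (not the library's code): True = position stays visible.
    keep = torch.ones(batch, seq_len, dtype=torch.bool, device=device)
    # Never hide every position: the first one is always kept.
    num_mask = min(int(seq_len * mask_prob), seq_len - 1)
    if num_mask <= 0:
        return keep
    scores = torch.rand(batch, seq_len, device=device)
    scores[:, 0] = -1.0  # rand() is in [0, 1), so position 0 is never in the top-k
    masked_idx = scores.topk(num_mask, dim=-1).indices  # distinct indices per row
    rows = torch.arange(batch, device=device).unsqueeze(-1)
    keep[rows, masked_idx] = False
    return keep


keep = random_keep_mask(batch=2, seq_len=10, mask_prob=0.5)
```

With `mask_prob=0.5` and 10 positions, each row hides 5 positions chosen at random, always keeping the first. Such a mask would only be applied during training; at inference every generated value stays visible.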

@alexdemartos oh that's right, continuous doesn't have kv cache just yet

ok, so the issue must be unrelated then
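Without a KV cache, each generation step simply re-runs the network over the whole prefix generated so far, so there is no cached state that could go stale. A minimal sketch of such a cache-free loop for continuous outputs, assuming a hypothetical `net(x, **kwargs)` that maps a `(batch, seq, dim)` tensor to predictions of the same shape (this is an illustration, not the wrapper's actual `generate`):

```python
import torch


@torch.no_grad()
def generate_without_cache(net, start, steps, **kwargs):
    # start: (batch, t0, dim) prompt values; net is a hypothetical callable.
    out = start
    for _ in range(steps):
        # Full forward pass over the current prefix; keep only the last prediction.
        pred = net(out, **kwargs)[:, -1:]
        out = torch.cat((out, pred), dim=-2)
    # Return only the newly generated positions.
    return out[:, start.shape[1]:]


# Toy "network" that predicts input + 1 at every position.
result = generate_without_cache(lambda x: x + 1, torch.zeros(1, 1, 1), steps=3)
```

Note that in such a loop the target prefix is shorter than a full-length context at every step except the last, whereas during training both have length L; any positional scheme used inside cross-attention has to handle that mismatch.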

@alexdemartos what happens if you remove the relative positions altogether? perhaps give the source and target weight tied absolute positional embedding?
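One way to read the weight-tied suggestion: a single learned position table is added to both the encoded phonemes (context) and the decoder inputs, so target step i and context step i receive the same positional vector. A minimal PyTorch sketch; `SharedAbsolutePositions` is a hypothetical module, not an x-transformers class.

```python
import torch
from torch import nn


class SharedAbsolutePositions(nn.Module):
    # Hypothetical module: one position table shared by context and targets.
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        # x: (batch, seq_len, dim); add the vectors for positions 0..seq_len-1.
        pos = torch.arange(x.shape[1], device=x.device)
        return x + self.emb(pos)


pos_emb = SharedAbsolutePositions(dim=8, max_seq_len=16)
context = torch.randn(2, 5, 8)  # e.g. encoded phonemes
targets = torch.randn(2, 5, 8)  # e.g. projected duration/pitch/energy inputs
context_with_pos = pos_emb(context)
targets_with_pos = pos_emb(targets)
```

Because positions are counted from the start of each sequence, a partial target prefix of length t during generation still lines up with context positions 0..t-1.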

Testing this next :)

Update: Unfortunately no luck disabling rel_pos_bias either. The results look slightly different, but still garbage.

Training: https://ibb.co/h74wysQ
Inference: https://ibb.co/hKjDLcs

@alexdemartos oh, so it is unrelated to positioning then

Hi. It's been a long time, but I finally found the root of the issue. It isn't related to the library implementation; it was a bug in my own implementation of rel_pos_bias in the cross-attention layer of the Transformer Decoder. I wanted to post it here for completeness, in case anyone is interested in a similar implementation.

I was passing self.rel_pos to both the self-attention and cross-attention layers. However, the RelativePositionBias used for self-attention is created with causal=True (taken from the decoder parameters), whereas causal should be False for the cross-attention layer. I fixed the issue by creating a separate self.rel_pos_cross = RelativePositionBias(..., causal = False) and passing it to the 'c' layer:

def __init__(
    self,
    ...
):
    ...
    # separate, non-causal bias for cross-attention (self.rel_pos is causal in the decoder)
    self.rel_pos_cross = RelativePositionBias(scale = dim_head ** 0.5, causal = False, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)

def forward(
    self,
    ...
):
    ...
    elif layer_type == 'c':
        # the cross-attention block gets the non-causal bias instead of self.rel_pos
        out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, rel_pos = self.rel_pos_cross, cache = next(iter_attn_cache, None), return_intermediates = True)
    ...
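To see why the causal flag matters, here is a scalar sketch of T5-style relative position bucketing, the kind of bucketing RelativePositionBias performs (a simplified approximation, not the library's exact code; `relative_position_bucket` is a hypothetical name). The offset is `key_pos - query_pos`. With `causal=True`, every key after the query lands in the same bucket as the aligned position, which is harmless in masked self-attention but means an unmasked cross-attention layer cannot tell the aligned context step from later ones.

```python
import math


def relative_position_bucket(rel_pos, causal, num_buckets=32, max_distance=128):
    # rel_pos = key_pos - query_pos for a single query/key pair.
    ret = 0
    n = -rel_pos  # positive when the key lies before the query
    if not causal:
        # Bidirectional: half the buckets for keys before the query, half for after.
        num_buckets //= 2
        if n < 0:
            ret += num_buckets
        n = abs(n)
    else:
        # Causal: every key after the query collapses onto distance 0.
        n = max(n, 0)

    max_exact = num_buckets // 2
    if n < max_exact:
        return ret + n  # one bucket per exact small distance
    # Log-spaced buckets for larger distances, capped at the last bucket.
    val = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return ret + min(val, num_buckets - 1)


query_pos = 2
offsets = [key_pos - query_pos for key_pos in range(6)]  # [-2, -1, 0, 1, 2, 3]
causal_buckets = [relative_position_bucket(r, causal=True) for r in offsets]
bidir_buckets = [relative_position_bucket(r, causal=False) for r in offsets]
# causal_buckets -> [2, 1, 0, 0, 0, 0]
# bidir_buckets  -> [2, 1, 0, 17, 18, 19]
```

With `causal=False`, keys on either side of the query get distinct buckets, which is what a separate non-causal bias such as `self.rel_pos_cross` gives the 'c' layer.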

@alexdemartos nice! hope you trained a cool model in the end 😄