feifeibear / long-context-attention

USP: Unified (a.k.a. Hybrid, 2D) Sequence Parallel Attention for Long Context Transformers Model Training and Inference

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Hybrid中zigzag ring attention 的数据分片问题

YouYouCoding opened this issue · comments

你好,关于Hybrid Long Attention的地方有个疑惑,
按照zigzag的方法,数据按照gput切分序列的时候,为了各个GPU计算均衡,并不是按照顺序切分的,而是折叠的方法来切分数据:

image

以及对应的代码:https://github.com/zhuzilin/ring-flash-attention/blob/main/test/test_zigzag_ring_flash_attn_func.py#L43

但是貌似在Hybrid中假设use_ulysess_low的情况下,在QKV过完Ulysses Process Group的all to all之后,到ring attention时,这里的QKV似乎还是按照顺序来计算的,也就是实现的并非zigzag ring attentioin?

比如GPU=8卡,数据分成了8个part; ulysses=4(groups=[0,1,2,3], [4,5,6,7]),ring=2,(groups=[0,4],[1,5], [2,6], [3,7]), 在计算ring attetnion的时候,变成了part0和part4的数据在一个group里面算ring attention。

这里是否有问题? 还是我哪里理解的不对,辛苦看下,多谢

https://github.com/feifeibear/long-context-attention/blob/main/test/test_hybrid_qkvpacked_attn.py#L76
具体实现在这
https://github.com/feifeibear/long-context-attention/blob/main/yunchang/comm/extract_local.py#L61

你需要的是这个吧,需要在attention计算前把输入序列按照ulysses和ring degree设置reorder一下。

哦哦,我看的是非qkvpack的版本:
https://github.com/feifeibear/long-context-attention/blob/main/test/test_hybrid_attn.py
这个里面的输入没有做reorder ,所以才有此疑问,原来是已经实现了

哦哦,我看的是非qkvpack的版本: https://github.com/feifeibear/long-context-attention/blob/main/test/test_hybrid_attn.py 这个里面的输入没有做reorder ,所以才有此疑问,原来是已经实现了

这个偷懒了,你如果方便的话,可以改一下,让这个单测可以通过。