triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

flash-attention gradient calculation fails due to numerical error

hgl71964 opened this issue

Adding causal=False to the pytest case in python/tutorials/06-fused-attention.py causes a numerical error when computing the backward gradients.

Environment: RTX 4090, CUDA 12.4, Triton 3.0.0 (107fed4), LLVM (765206e050453018e861637a08a4520f29238074)
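For reference, here is a minimal sketch of the kind of check being described: it parametrizes over causal and compares the fused kernel's output and dq/dk/dv against a plain PyTorch softmax-attention reference. The file path, the attention(q, k, v, causal, sm_scale) signature, and the tolerances are assumptions based on the upstream tutorial and may need adjusting to your checkout.

```python
# Reproduction sketch (assumed paths/signatures, not the official test).
import importlib.util

import pytest
import torch

# The tutorial filename contains dashes, so load it by path instead of importing it.
# Path is an assumption; point it at 06-fused-attention.py in your checkout.
_spec = importlib.util.spec_from_file_location(
    "fused_attention_tutorial", "python/tutorials/06-fused-attention.py")
fused_attention = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(fused_attention)


@pytest.mark.parametrize("causal", [False, True])
def test_backward_matches_reference(causal, dtype=torch.float16):
    torch.manual_seed(0)
    Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 64
    sm_scale = 0.5
    q, k, v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda")
               .normal_(mean=0.0, std=0.5).requires_grad_() for _ in range(3))
    dout = torch.randn_like(q)

    # Reference: unfused softmax attention in PyTorch with the same masking convention.
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        mask = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
        p[:, :, mask == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).to(dtype)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dq, ref_dk, ref_dv = q.grad.clone(), k.grad.clone(), v.grad.clone()
    q.grad, k.grad, v.grad = None, None, None

    # Triton fused attention; the (q, k, v, causal, sm_scale) signature is assumed.
    tri_out = fused_attention.attention(q, k, v, causal, sm_scale)
    tri_out.backward(dout)
    tri_dq, tri_dk, tri_dv = q.grad.clone(), k.grad.clone(), v.grad.clone()

    # fp16 accumulation leaves some slack; these tolerances are a guess, not a spec.
    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
    torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=0)
    torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=0)
    torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=0)
```

If the non-causal case drifts beyond these tolerances while the causal case passes, that points at the backward kernel's non-causal path rather than at the test harness itself.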

Or is it simply that no one actually back-propagates when causal=False, so this path is not exercised?