flash-attention gradient calculation fails due to numerical error
hgl71964 opened this issue
Guoliang He commented
Adding causal=False to the pytest case in python/tutorials/06-fused-attention.py causes a numerical error when computing the backward gradients (repro sketch below).
Environment: RTX 4090, CUDA 12.4, Triton 3.0.0 (107fed4), LLVM (765206e050453018e861637a08a4520f29238074)
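A minimal standalone sketch of the failing check, assuming the tutorial exposes an `attention(q, k, v, causal, sm_scale)` entry point and using the test's default shapes and tolerances from around this Triton version; the exact signature and tolerances may differ, and the `fused_attention` module name is hypothetical:

```python
# Repro sketch (not verbatim from the tutorial): compare the fused kernel's
# gradients against a naive PyTorch reference with causal=False.
import torch

from fused_attention import attention  # hypothetical import of
# python/tutorials/06-fused-attention.py; adjust to how you load the tutorial

Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 64
sm_scale = 0.5
torch.manual_seed(20)
q = torch.randn((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16,
                device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
dout = torch.randn_like(q)

# Naive reference attention, no causal mask since causal=False.
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dq, ref_dk, ref_dv = q.grad.clone(), k.grad.clone(), v.grad.clone()
q.grad = k.grad = v.grad = None

# Triton fused attention with causal=False.
tri_out = attention(q, k, v, False, sm_scale)
tri_out.backward(dout)

torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)  # forward matches
torch.testing.assert_close(v.grad, ref_dv, atol=1e-2, rtol=0)
torch.testing.assert_close(k.grad, ref_dk, atol=1e-2, rtol=0)
torch.testing.assert_close(q.grad, ref_dq, atol=1e-2, rtol=0)    # trips per this report
```

Per the report, the forward comparison passes while at least one of the gradient comparisons exceeds the tolerance.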
Guoliang He commented
Or is it just that no one actually back-props when causal=False, so this path isn't exercised?