Modify FLOPs in MFU calculation for causal mask when using FlashAttention.
Yuxin-CV opened this issue
Hi, I suggest we modify the FLOPs calculation used for MFU so that it matches the FlashAttention benchmark script.
Specifically, with the current calculation the MFU for the causal mask can exceed 100% at seq_len = 16k (189 * 2 / 312 = 1.21), which cannot be correct. When using FlashAttention, the attention FLOPs for the causal mask setting should be divided by 2.
![flash2_a100_fwd_bwd_benchmark](https://private-user-images.githubusercontent.com/57927171/331480542-56465aed-eafd-4618-943d-0ee39baac294.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjEzOTQ4NzQsIm5iZiI6MTcyMTM5NDU3NCwicGF0aCI6Ii81NzkyNzE3MS8zMzE0ODA1NDItNTY0NjVhZWQtZWFmZC00NjE4LTk0M2QtMGVlMzliYWFjMjk0LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA3MTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNzE5VDEzMDkzNFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWVmYjQxZTEwMDc2ZDkwMzc5OGY2M2JjYzBjOWJhMDIyNmM0ZTJhN2I3NzNiNDU0OTkwNDA2YjdjOWU2NGJmNDgmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.gPjWYEse4Nbhe9Eq8y_djzHMP2e6yFOKrkQWEzC68RY)
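
To make the proposal concrete, here is a minimal sketch of the counting convention, written in the style of the FlashAttention benchmark script rather than copied from it. The function name `attention_flops`, the helper `mfu`, and the batch/head shapes are hypothetical; 189 and 312 are the figures quoted above, read here as achieved and peak TFLOPs/s on an A100, which is an assumption about their units.

```python
def attention_flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Attention matmul FLOPs (hypothetical helper, benchmark-script style)."""
    assert mode in ("fwd", "bwd", "fwd_bwd")
    # Two matmuls (Q @ K^T and P @ V), each 2 * seqlen^2 * headdim FLOPs per head.
    f = 4 * batch * seqlen**2 * nheads * headdim
    if causal:
        # A causal mask skips roughly half of the score matrix.
        f //= 2
    # Backward is counted as 2.5x forward, so forward + backward is 3.5x.
    scale = {"fwd": 1, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return scale * f


def mfu(achieved_tflops, peak_tflops):
    """Model FLOPs utilization as a fraction of peak throughput."""
    return achieved_tflops / peak_tflops


# Hypothetical shapes, chosen only to illustrate the factor of 2.
shape = dict(batch=1, seqlen=16 * 1024, headdim=128, nheads=16)
full = attention_flops(causal=False, **shape)
halved = attention_flops(causal=True, **shape)

PEAK_TFLOPS = 312      # figure quoted in the issue
REPORTED_TFLOPS = 189  # figure quoted in the issue, counted with halved FLOPs

mfu_halved = mfu(REPORTED_TFLOPS, PEAK_TFLOPS)
mfu_unhalved = mfu(REPORTED_TFLOPS * full / halved, PEAK_TFLOPS)

print(f"causal FLOPs halved:     MFU = {mfu_halved:.2f}")
print(f"causal FLOPs not halved: MFU = {mfu_unhalved:.2f}")
```

With the same measured runtime, counting the full FLOPs for a causal run doubles the apparent throughput, which is how the ratio ends up above 1.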