Openclip-flashattn

Integrate FlashAttention into the CLIP model.
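
Conceptually, the integration replaces the standard softmax attention inside CLIP's transformer blocks with a FlashAttention kernel. The following is a minimal sketch of that idea, not this repository's actual code: it assumes the `flash-attn` package (FlashAttention 2) and its `flash_attn_func`, and the module name is made up for illustration.

```python
import torch
import torch.nn as nn
from flash_attn import flash_attn_func


class FlashMultiheadSelfAttention(nn.Module):
    """Illustrative self-attention module that calls the FlashAttention kernel."""

    def __init__(self, embed_dim: int, num_heads: int, causal: bool = False):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.causal = causal
        # Same parameterization as nn.MultiheadAttention: fused QKV projection
        # followed by an output projection.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, embed_dim), fp16/bf16, on a CUDA device.
        b, s, d = x.shape
        qkv = self.qkv(x).view(b, s, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each (batch, seq_len, num_heads, head_dim)
        # flash_attn_func expects (batch, seqlen, nheads, headdim) tensors and
        # fuses softmax(QK^T / sqrt(head_dim)) V into a single kernel.
        out = flash_attn_func(q, k, v, causal=self.causal)
        return self.out_proj(out.reshape(b, s, d))
```

CLIP's text transformer applies a causal mask, so that branch corresponds to `causal=True` here.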

Benchmarks were run on a Tesla T4 and an RTX 3080 (the T4 and 3080 column groups in the tables below); each input shape is run 100 times (`Times = 100`), with a causal mask applied. Shapes of the form `(N, 77)` are tokenized text inputs and `(N, 3, 224, 224)` are image inputs; the `ViT-g-14` and `ViT-H-14` tables report a single set of results. A sketch of the assumed measurement loop follows.

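The sketch below shows the assumed measurement loop, not the repository's exact benchmark script: each encoder runs the same input `TIMES` times, speed-up is the ratio of total wall-clock times, and mean diff is taken to be the mean absolute difference between baseline and FlashAttention outputs. `baseline_model`, `flash_model`, and `x` are hypothetical placeholders.

```python
import time
import torch

TIMES = 100


@torch.no_grad()
def timed_run(model, x, times=TIMES):
    # Run the model `times` times; return total wall-clock seconds and the last output.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(times):
        out = model(x)
    torch.cuda.synchronize()
    return time.perf_counter() - start, out


def compare(baseline_model, flash_model, x):
    t_base, out_base = timed_run(baseline_model, x)
    t_flash, out_flash = timed_run(flash_model, x)
    speedup = t_base / t_flash
    mean_diff = (out_base.float() - out_flash.float()).abs().mean().item()
    return t_base, t_flash, speedup, mean_diff
```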

Model: `ViT-L-14::laion2b-s32b-b82k`

| shape | T4 baseline (s) | T4 flash_attn (s) | T4 speedup (x) | T4 mean diff | 3080 baseline (s) | 3080 flash_attn (s) | 3080 speedup (x) | 3080 mean diff |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| (1, 77) | 0.81193 | 0.66437 | 1.22211 | 0.00066 | 0.66602 | 0.55968 | 1.1899 | 0.000599 |
| (2, 77) | 0.8264 | 0.70035 | 1.17998 | 0.0007 | 0.69619 | 0.60348 | 1.1536 | 0.000570 |
| (4, 77) | 0.81998 | 0.69887 | 1.1733 | 0.00064 | 0.75353 | 0.65664 | 1.1225 | 0.000573 |
| (8, 77) | 1.05975 | 0.85742 | 1.23597 | 0.00054 | 0.73511 | 0.64203 | 1.1449 | 0.000530 |
| (16, 77) | 2.12992 | 1.68367 | 1.26504 | 0.00057 | 0.78381 | 0.72705 | 1.0780 | 0.000545 |
| (1, 3, 224, 224) | 2.00593 | 1.34507 | 1.49131 | 0.00155 | 1.25425 | 1.10564 | 1.1344 | 0.001713 |
| (2, 3, 224, 224) | 3.74493 | 2.29818 | 1.62952 | 0.00122 | 1.32202 | 1.22698 | 1.0775 | 0.001739 |
| (4, 3, 224, 224) | 7.35365 | 4.44447 | 1.65456 | 0.00159 | 2.36880 | 2.11119 | 1.0926 | 0.001227 |
| (8, 3, 224, 224) | 14.63006 | 9.05604 | 1.6155 | 0.00135 | 3.88597 | 3.76666 | 1.0316 | 0.001441 |
| (16, 3, 224, 224) | 29.63732 | 18.29142 | 1.62028 | 0.00155 | 7.84954 | 7.20817 | 1.0889 | 0.001533 |

Model: `ViT-B-32::laion2b-s34b-b79k`

| shape | T4 baseline (s) | T4 flash_attn (s) | T4 speedup (x) | T4 mean diff | 3080 baseline (s) | 3080 flash_attn (s) | 3080 speedup (x) | 3080 mean diff |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| (1, 77) | 0.42692 | 0.37867 | 1.1274 | 0.00076 | 0.73374 | 0.63919 | 1.0914 | 0.000599 |
| (2, 77) | 0.46548 | 0.42958 | 1.08358 | 0.00077 | 0.72493 | 0.66131 | 1.0967 | 0.000570 |
| (4, 77) | 0.49818 | 0.46378 | 1.07417 | 0.00079 | 0.70449 | 0.64658 | 1.0946 | 0.000573 |
| (8, 77) | 0.48738 | 0.45324 | 1.0753 | 0.00068 | 0.71474 | 0.68031 | 1.0902 | 0.000530 |
| (16, 77) | 0.4764 | 0.44315 | 1.07502 | 0.00067 | 0.71016 | 0.70077 | 1.0960 | 0.000545 |
| (1, 3, 224, 224) | 0.4349 | 0.40392 | 1.0767 | 0.00054 | 0.65047 | 0.59596 | 1.0914 | 0.000601 |
| (2, 3, 224, 224) | 0.4445 | 0.42513 | 1.04555 | 0.00058 | 0.67756 | 0.63876 | 1.0607 | 0.000556 |
| (4, 3, 224, 224) | 0.43605 | 0.41524 | 1.05009 | 0.00054 | 0.71534 | 0.67424 | 1.0609 | 0.000552 |
| (8, 3, 224, 224) | 0.47367 | 0.45316 | 1.04527 | 0.00064 | 0.66634 | 0.62738 | 1.0620 | 0.000597 |
| (16, 3, 224, 224) | 0.51586 | 0.50555 | 1.0204 | 0.00058 | 0.6718 | 0.63712 | 1.0545 | 0.000578 |

Model: `ViT-B-16::laion400m_e31`

| shape | T4 baseline (s) | T4 flash_attn (s) | T4 speedup (x) | T4 mean diff | 3080 baseline (s) | 3080 flash_attn (s) | 3080 speedup (x) | 3080 mean diff |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| (1, 77) | 0.42736 | 0.38143 | 1.12042 | 0.00051 | 0.69159 | 0.61381 | 1.1267 | 0.000541 |
| (2, 77) | 0.46499 | 0.432 | 1.07635 | 0.00025 | 0.77649 | 0.70859 | 1.0958 | 0.000248 |
| (4, 77) | 0.50052 | 0.46798 | 1.06953 | 0.00045 | 0.78205 | 0.71427 | 1.0948 | 0.000425 |
| (8, 77) | 0.48717 | 0.45675 | 1.06659 | 0.00043 | 0.72512 | 0.66471 | 1.0908 | 0.000419 |
| (16, 77) | 0.47855 | 0.44717 | 1.07017 | 0.0005 | 0.72748 | 0.66784 | 1.0893 | 0.000457 |
| (1, 3, 224, 224) | 0.45754 | 0.38839 | 1.17803 | 0.00044 | 0.70322 | 0.59588 | 1.1801 | 0.000410 |
| (2, 3, 224, 224) | 0.50152 | 0.44523 | 1.12641 | 0.00042 | 0.73814 | 0.64270 | 1.1484 | 0.000412 |
| (4, 3, 224, 224) | 0.56444 | 0.51215 | 1.1021 | 0.00044 | 0.75904 | 0.66596 | 1.1397 | 0.000395 |
| (8, 3, 224, 224) | 0.93341 | 0.82501 | 1.13139 | 0.00045 | 1.18015 | 1.03530 | 1.1399 | 0.000427 |
| (16, 3, 224, 224) | 1.77167 | 1.58981 | 1.11439 | 0.00044 | 2.22424 | 1.97961 | 1.1235 | 0.000429 |

Model: `ViT-g-14::laion2b-s12b-b42k`

| shape | baseline (s) | flash_attn (s) | speedup (x) | mean diff |
|:---:|:---:|:---:|:---:|:---:|
| (1, 77)           |      0.8894 |        0.7984 |      1.11397 |   0.00093 |
| (2, 77)           |     0.89632 |       0.83137 |      1.07812 |   0.00079 |
| (4, 77)           |     0.93494 |       0.87282 |      1.07117 |    0.0008 |
| (8, 77)           |     1.15918 |       1.08828 |      1.06514 |   0.00073 |
| (16, 77)          |     2.05242 |       1.92306 |      1.06727 |   0.00058 |
| (1, 3, 224, 224)  |     1.94179 |       1.85599 |      1.04623 |   0.00115 |
| (2, 3, 224, 224)  |     3.29228 |       3.10451 |      1.06048 |   0.00109 |
| (4, 3, 224, 224)  |     5.87704 |       5.55928 |      1.05715 |   0.00115 |
| (8, 3, 224, 224)  |    10.37121 |        9.8581 |      1.05204 |   0.00101 |
| (16, 3, 224, 224) |    20.85575 |      19.86552 |      1.04984 |   0.00112 |

Model: `ViT-H-14::laion2b-s32b-b79k`

| shape | baseline (s) | flash_attn (s) | speedup (x) | mean diff |
|:---:|:---:|:---:|:---:|:---:|
| (1, 77)           |     0.90747 |       0.81617 |      1.11186 |    0.0008 |
| (2, 77)           |     0.90999 |       0.84415 |      1.07799 |   0.00077 |
| (4, 77)           |     0.97304 |        0.9068 |      1.07305 |   0.00053 |
| (8, 77)           |     2.55404 |       2.48595 |      1.02738 |   0.00067 |
| (16, 77)          |     4.36827 |       4.11261 |      1.06216 |   0.00067 |
| (1, 3, 224, 224)  |     1.33663 |       1.26552 |      1.05619 |    0.0009 |
| (2, 3, 224, 224)  |       2.273 |       2.12032 |        1.072 |    0.0008 |
| (4, 3, 224, 224)  |     4.04465 |       3.78626 |      1.06824 |   0.00082 |
| (8, 3, 224, 224)  |     7.23102 |       6.80665 |      1.06234 |    0.0008 |
| (16, 3, 224, 224) |    13.64971 |       12.8397 |      1.06308 |   0.00087 |

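For reference, the `model::pretrained` names above appear to follow `open_clip`'s architecture plus pretrained-tag convention (with hyphens where `open_clip` uses underscores). The sketch below, under that assumption, loads one of the benchmarked checkpoints and builds inputs of the benchmarked shapes:

```python
import torch
import open_clip

# Assumed mapping: "ViT-L-14::laion2b-s32b-b82k" -> architecture "ViT-L-14",
# open_clip pretrained tag "laion2b_s32b_b82k".
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-L-14", pretrained="laion2b_s32b_b82k"
)
tokenizer = open_clip.get_tokenizer("ViT-L-14")

model = model.half().cuda().eval()  # FlashAttention kernels require fp16/bf16

tokens = tokenizer(["a photo of a cat"]).cuda()  # text input, shape (1, 77)
images = torch.randn(1, 3, 224, 224, dtype=torch.float16, device="cuda")

with torch.no_grad():
    text_features = model.encode_text(tokens)
    image_features = model.encode_image(images)
```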