A block-sparse matmul kernel, based on OpenAI Triton's matmul, that works in PyTorch.
Currently, Triton 2.1.0 must be built from source to use the new block pointer API (`tl.make_block_ptr`).
Pruned blocks of B along the K axis are skipped when loading. To avoid conditionals inside the kernel, we precompute the skipping strides from the B mask outside the kernel; they can be reused for different input matrices A.
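The skipping scheme can be sketched in pure Python, which is handy for checking the index logic on a CPU. Everything here is a hypothetical illustration rather than the kernel's actual format: the mask layout (`mask[kb][nb]`, one entry per `BLOCK_SIZE_K x BLOCK_SIZE_N` block of B), the zero-padded `deltas` table, and the names `precompute_skip_strides` and `block_sparse_matmul` are assumptions. A real Triton kernel would load such a table and advance a block pointer instead of indexing Python lists.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def precompute_skip_strides(mask: List[List[int]]) -> Tuple[List[int], List[List[int]]]:
    """Turn a block mask of B into per-column skip strides (hypothetical layout).

    mask[kb][nb] == 1 means block (kb, nb) of B is kept, 0 means it was pruned.
    Returns (counts, deltas): counts[nb] is the number of kept K-blocks in block
    column nb, and deltas[nb][i] is how many K-blocks to advance before the i-th
    load (the first delta is measured from K-block 0). Rows are zero-padded to a
    common width so they could be stored as one dense integer tensor.
    """
    n_kb, n_nb = len(mask), len(mask[0])
    counts: List[int] = []
    deltas: List[List[int]] = []
    for nb in range(n_nb):
        kept = [kb for kb in range(n_kb) if mask[kb][nb]]
        row, prev = [], 0
        for kb in kept:
            row.append(kb - prev)  # jump from the previous kept block to this one
            prev = kb
        counts.append(len(kept))
        deltas.append(row)
    width = max(counts, default=0)
    return counts, [row + [0] * (width - len(row)) for row in deltas]


def block_sparse_matmul(a: Matrix, b: Matrix, counts: List[int],
                        deltas: List[List[int]], bk: int, bn: int) -> Matrix:
    """Reference C = A @ B that only reads the kept blocks of B.

    The loop over kept blocks never checks the mask: the K offset is advanced
    by the precomputed deltas, mimicking a pointer advance inside a kernel.
    Assumes K % bk == 0 and N % bn == 0.
    """
    m_rows, n_cols = len(a), len(b[0])
    c = [[0.0] * n_cols for _ in range(m_rows)]
    for nb in range(n_cols // bn):
        for m in range(m_rows):  # a real kernel would also tile M (BLOCK_SIZE_M)
            acc = [0.0] * bn
            k0 = 0  # first row of the current K-block of B (and column of A)
            for i in range(counts[nb]):
                k0 += deltas[nb][i] * bk  # skip pruned blocks without branching
                for k in range(k0, k0 + bk):
                    for j in range(bn):
                        acc[j] += a[m][k] * b[k][nb * bn + j]
            c[m][nb * bn:(nb + 1) * bn] = acc
    return c


# Usage: a 3 x 2 block mask (K-blocks x N-blocks) for B.
mask = [[1, 0],
        [0, 0],
        [1, 1]]
counts, deltas = precompute_skip_strides(mask)
print(counts, deltas)  # [2, 1] [[0, 2], [2, 0]]
```

Because `counts` and `deltas` depend only on the mask of B, they are computed once and passed unchanged to every call with a new A.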
On an A100, this implementation is faster than PyTorch's native cuBLAS-based matmul at block sparsity above 50%, with `BLOCK_SIZE_M=128`, `BLOCK_SIZE_K=32`, and `BLOCK_SIZE_N=64`.
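A back-of-the-envelope FLOP count puts the 50% threshold in context. Assuming only B is block-sparse with `BLOCK_SIZE_K x BLOCK_SIZE_N` blocks, the ideal speedup is simply the inverse of the kept fraction; `BLOCK_SIZE_M` tiles A and C but does not change the FLOP count. The problem size below and the helper name `matmul_flops` are hypothetical. Since the kernel only beats cuBLAS above 50% sparsity, where the ideal ratio is already 2x, the achieved speedup presumably falls well short of these ideal numbers.

```python
def matmul_flops(m: int, n: int, k: int, block_k: int, block_n: int, sparsity: float):
    """Multiply-add FLOPs for dense vs. block-sparse C = A @ B (A: m x k, B: k x n).

    A fraction `sparsity` of B's block_k x block_n blocks is assumed pruned.
    """
    dense = 2 * m * n * k
    total_blocks = (k // block_k) * (n // block_n)
    kept = round(total_blocks * (1.0 - sparsity))
    sparse = 2 * m * kept * block_k * block_n
    return dense, sparse


# Hypothetical 4096^3 problem; block sizes match the configuration above.
for s in (0.5, 0.75, 0.9):
    d, sp = matmul_flops(4096, 4096, 4096, block_k=32, block_n=64, sparsity=s)
    print(f"sparsity={s:.2f}: ideal speedup {d / sp:.1f}x")
```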
- HuggingFace previously implemented a block-sparse GEMM kernel based on CUTLASS, but its speedup at 50% sparsity is not yet satisfactory.
- OpenAI also implemented one for TensorFlow, but unfortunately it does not support PyTorch.
As of June 2023, Triton GEMM kernels still seem to suffer from high CPU overhead. Although the measured GPU time looks good, the end-to-end performance in real batched inference may still be worse than cuBLAS.