Akimoto-Cris / TritonBlockSparseMatmul

TritonBlockSparseMatmul

A matmul kernel with block sparsity, based on OpenAI Triton's matmul tutorial, that works in PyTorch.

Currently, you need to build triton-2.1.0 from source to use the newest block pointer API.

Basic Idea

Skip the pruned blocks along the K-axis of B when loading. To avoid conditionals inside the kernel, we precompute the skipping strides from the B-mask outside the kernel; these strides can then be reused for different input A matrices, as sketched below.
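
The repository computes these strides in its own launcher code; the following is only a minimal sketch, assuming B is pruned in (BLOCK_SIZE_K × BLOCK_SIZE_N) tiles and that the kernel advances its block pointers by a per-step number of K-blocks. The function name and exact layout are illustrative, not the repo's actual code:

```python
import torch

def precompute_skip_strides(b_block_mask: torch.Tensor):
    """Illustrative sketch: derive per-column K-block increments from a block mask of B.

    b_block_mask: bool tensor of shape (K // BLOCK_SIZE_K, N // BLOCK_SIZE_N),
                  True = block kept, False = block pruned.
    Returns:
        kept_counts:  (num_n_blocks,) number of kept K-blocks per column of N-blocks.
        skip_strides: (num_n_blocks, max_kept) padded increments in units of K-blocks;
                      at loop step i the kernel would advance the B (and A) block
                      pointers by skip_strides[pid_n, i] * BLOCK_SIZE_K instead of a
                      constant BLOCK_SIZE_K, so pruned blocks are skipped without
                      branching inside the kernel.
    """
    num_k_blocks, num_n_blocks = b_block_mask.shape
    kept_counts = b_block_mask.sum(dim=0).to(torch.int32)
    max_kept = int(kept_counts.max())
    skip_strides = torch.zeros(
        (num_n_blocks, max_kept), dtype=torch.int32, device=b_block_mask.device
    )
    for n in range(num_n_blocks):
        kept = torch.nonzero(b_block_mask[:, n], as_tuple=False).flatten()
        if kept.numel() == 0:
            continue
        # Increment from the previously visited K-block (the pointer starts at block 0).
        prev = torch.cat([kept.new_zeros(1), kept[:-1]])
        skip_strides[n, : kept.numel()] = (kept - prev).to(torch.int32)
    return kept_counts, skip_strides
```

Because the block mask of B is fixed for a given pruned weight, this table only needs to be built once and can be reused for every input A.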

Benchmarking

On an A100, this implementation is faster than PyTorch's native cuBLAS-based matmul at block sparsity above 50%, with BLOCK_SIZE_M=128, BLOCK_SIZE_K=32, and BLOCK_SIZE_N=64.

At 70% sparsity, the speedup over cuBLAS approaches 2× for large matrix dimensions.
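
For reference, a comparison like the one above could be timed with Triton's built-in benchmarking helper. This is only a sketch; the block-sparse entry point shown in the comment is a placeholder, not the repository's actual API:

```python
import torch
import triton

M = K = N = 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)

# Prune ~70% of the (BLOCK_SIZE_K x BLOCK_SIZE_N) blocks of B at random.
BLOCK_SIZE_K, BLOCK_SIZE_N = 32, 64
b_block_mask = torch.rand(K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, device="cuda") > 0.7

# Median/mean kernel time in milliseconds, measured by Triton's benchmark utility.
ms_cublas = triton.testing.do_bench(lambda: torch.matmul(a, b))
# ms_sparse = triton.testing.do_bench(lambda: block_sparse_matmul(a, b, b_block_mask))  # placeholder name
print("cuBLAS (dense) matmul [ms]:", ms_cublas)
```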

Related Work

Limitations

At the moment (Jun 23), Triton GEMM kernels still seem to suffer from high CPU overhead. Although GPU time shows good results, real performance in batched inference applications may still be sub-optimal compared to cuBLAS.

About

License: MIT License


Languages

Language: Python 100.0%