deepseek-ai / DeepGEMM

DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling

Repository from GitHub https://github.com/deepseek-ai/DeepGEMM

TMA selection of get_best_configs in gemm.py

joyeamd opened this issue

    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
        if m >= 512 and is_multicast_legal[i]:
            best_tma_multicast_config = (2, i == 'A')
            break

When best_block_m <= best_block_n, I think the size check should use n instead of m. I suggest changing the loop to the following:

    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
        dim_size = m if best_block_m > best_block_n else n
        if dim_size >= 512 and is_multicast_legal[i]:
            best_tma_multicast_config = (2, i == 'A')
            break

Sorry for the late reply. m >= 512 is just a heuristic (tested empirically, not derived theoretically). For now, I am not sure whether your change brings a speedup. If you test it and verify that it performs better, a PR is welcome!
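
Before benchmarking, it may help to see which shapes the two versions actually treat differently. The following is a minimal sketch, not the code from gemm.py: the function names pick_original and pick_proposed are hypothetical, the fallback value (1, True) is an assumption standing in for "no multicast", and is_multicast_legal is passed in as a plain dict instead of being computed. Only the loop bodies are taken from the snippets quoted in this thread.

```python
# Hypothetical harness; only the loop bodies come from the thread.
DEFAULT_CONFIG = (1, True)  # assumed "no multicast" fallback


def pick_original(m, n, best_block_m, best_block_n, is_multicast_legal):
    """Current heuristic: always tests m against the 512 threshold."""
    config = DEFAULT_CONFIG
    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
        if m >= 512 and is_multicast_legal[i]:
            config = (2, i == 'A')
            break
    return config


def pick_proposed(m, n, best_block_m, best_block_n, is_multicast_legal):
    """Proposed heuristic: tests n when best_block_m <= best_block_n."""
    config = DEFAULT_CONFIG
    dim_size = m if best_block_m > best_block_n else n
    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
        if dim_size >= 512 and is_multicast_legal[i]:
            config = (2, i == 'A')
            break
    return config


legal = {'A': True, 'B': True}

# Small m, large n, block_m <= block_n: only the proposed version multicasts.
print(pick_original(256, 4096, 64, 128, legal))   # (1, True)
print(pick_proposed(256, 4096, 64, 128, legal))   # (2, False)

# Large m, small n, block_m <= block_n: only the original version multicasts.
print(pick_original(4096, 256, 64, 128, legal))   # (2, False)
print(pick_proposed(4096, 256, 64, 128, legal))   # (1, True)

# block_m > block_n: both versions test m, so they agree.
print(pick_original(4096, 256, 128, 64, legal))   # (2, True)
print(pick_proposed(4096, 256, 128, 64, legal))   # (2, True)
```

Under these assumptions, the two versions differ only when best_block_m <= best_block_n and exactly one of m and n reaches 512, so those are the shapes a benchmark would need to cover to verify a speedup.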