deepseek-ai / DeepGEMM

DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling

Repository from GitHub: https://github.com/deepseek-ai/DeepGEMM

[bug] Unspecified Launch Failure of Normal GEMM when n=16

Chtholly-Boss opened this issue

The refactoring of the repo seems to have introduced a bug when n=16. Here are the minimal steps to reproduce it:

  1. In tests/generators.py, modify the normal generator as follows:
def enumerate_normal() -> Generator:
    for kernel_type in get_kernel_types():
        # for m in (128, 4096):
        #     for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
        for m in (576, ):
            for n, k in [(16, 7168), ]:
                for major_a, major_b in get_major_ab(False):
                    for out_dtype in get_out_dtype():
                        for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
                            yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
  2. In tests/test_core.py, enable only test_gemm:
if __name__ == '__main__':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(0)
    random.seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    test_gemm()
    # test_m_grouped_gemm_contiguous()
    # test_m_grouped_gemm_masked()
    # test_k_grouped_gemm_contiguous()

Then, from the project root, run PYTHONPATH=. python3 tests/test_core.py, which produces the following output:

Library path:
 > ['/root/DeepGEMM/deep_gemm']

Testing GEMM:
Traceback (most recent call last):
  File "/root/DeepGEMM/tests/test_core.py", line 171, in <module>
    test_gemm()
  File "/root/DeepGEMM/tests/test_core.py", line 38, in test_gemm
    diff = calc_diff(d, ref_d)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/DeepGEMM/deep_gemm/testing/numeric.py", line 7, in calc_diff
    denominator = (x * x + y * y).sum()
                  ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
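
The traceback above lands in calc_diff only because CUDA errors are reported asynchronously, as the message itself notes. A minimal sketch of how the reproduction could be rerun with synchronous launches and JIT debug output follows. The helper names build_debug_env and repro_command are hypothetical, and setting DG_JIT_DEBUG to "1" is an assumption about how that variable is enabled.

```python
import os
import sys


def build_debug_env(base_env=None):
    """Return a copy of the environment with the debugging switches set."""
    env = dict(os.environ if base_env is None else base_env)
    # Same PYTHONPATH as in the reproduction command from this report.
    env["PYTHONPATH"] = "."
    # Suggested by the error message: make kernel launches synchronous so
    # the failure is reported at the launch that caused it.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    # Assumption: the value "1" turns on the DG_JIT_DEBUG messages.
    env["DG_JIT_DEBUG"] = "1"
    return env


def repro_command():
    """Command line for the test script, relative to the project root."""
    return [sys.executable, "tests/test_core.py"]
```

From the project root, this could be used as subprocess.run(repro_command(), env=build_debug_env()).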

The DG_JIT_DEBUG messages are as follows:

Testing GEMM:
Gemm type: 0, kernel type: 1, M: 576, N: 16, K: 7168, groups: 1, A major: 0, B major: 0, AB dtype: Float8_e4m3fn, CD dtype: BFloat16, accumulation: 0, SM limit: 114 -> block M: 64, block N: 16, block K: 128, stages: 12, last stages: 8, SMs: 10, multicast: 2, multicast on A: 1, shared memory: 128432 bytes, swizzle A: 128, swizzle B: 128, swizzle CD: 32, threads: 256
Making TMA desc: global memory: 7168 576, shared memory: 128 64, outer stride: 7168, swizzle: 128, elem size: 1
Making TMA desc: global memory: 7168 16, shared memory: 128 16, outer stride: 7168, swizzle: 128, elem size: 1
Making TMA desc: global memory: 16 576, shared memory: 16 64, outer stride: 16, swizzle: 32, elem size: 2
Making TMA desc: global memory: 576 56, shared memory: 64 1, outer stride: 576, swizzle: 0, elem size: 4
Generated kernel code:
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif

#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {
    auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
        0, 16, 7168,
        1,
        64, 16, 128,
        32,
        12, 8,
        128, 128,
        2, true,
        GemmType::Normal
    >);
};

Loading CUBIN: /root/.deep_gemm/cache/kernel.sm90_fp8_gemm_1d2d.d15b11aa7f94a7fd2939a8265515d60e/kernel.cubin
Symbol names: _ZN9deep_gemm23sm90_fp8_gemm_1d2d_implILj0ELj16ELj7168ELj1ELj64ELj16ELj128ELj32ELj12ELj8ELj128ELj128ELj2ELb1ELNS_8GemmTypeE0EEEvPfPijjj14CUtensorMap_stS4_S4_S4_,
Launch kernel with {10, 1} x 256, shared memory: 128432 bytes, cluster: 2, stream: 0

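I suspect there is a bug in the TMA multicast logic in the implementation. In this case, we launch 10 SMs (blocks B0 to B9, in clusters of 2), but only 9 have work to do: with M=576 and block M=64 there are 576 / 64 = 9 blocks along M, and with N=16 and block N=16 there is a single block along N. However, once TMA loading is done, B8 waits on the empty barriers, which need B9 to arrive: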

            // To safely deconstruct distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                #pragma unroll
                for (uint32_t s = 0; s < kNumStages; ++ s)
                    empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
            }

But B9 has no work to do, so it never arrives at the empty barrier.
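
The arithmetic behind this suspicion can be sketched in a few lines of Python. This is only an illustration under stated assumptions, not the actual DeepGEMM scheduler: the function name is hypothetical, and it assumes that block b processes tiles b, b + num_blocks, and so on, and that consecutive block indices form a cluster.

```python
import math


def find_stuck_blocks(m, n, block_m, block_n, num_blocks, cluster_size):
    """Return (num_tiles, stuck), where stuck lists (block, idle_peers) pairs.

    Assumption: block b handles tiles b, b + num_blocks, ... and blocks
    [c * cluster_size, (c + 1) * cluster_size) share a cluster.
    """
    num_tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    tiles_per_block = [
        len(range(b, num_tiles, num_blocks)) for b in range(num_blocks)
    ]

    stuck = []
    for b, count in enumerate(tiles_per_block):
        if count == 0:
            continue  # an idle block issues no loads and waits on nothing
        cluster_start = (b // cluster_size) * cluster_size
        peers = range(cluster_start, cluster_start + cluster_size)
        idle_peers = [p for p in peers if p != b and tiles_per_block[p] == 0]
        if idle_peers:
            # A working block whose cluster peer never arrives on the barrier.
            stuck.append((b, idle_peers))
    return num_tiles, stuck


# Values taken from the DG_JIT_DEBUG line: M=576, N=16, block M=64,
# block N=16, SMs=10, multicast (cluster size)=2.
num_tiles, stuck = find_stuck_blocks(576, 16, 64, 16, 10, 2)
print(num_tiles, stuck)
```

Under these assumptions the only affected pair is B8 waiting on B9, which matches the description above.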

However, when I tried to reproduce this on the repo before the refactoring, I could not.
I am confused about what happened...
BTW, this also happens when m=2112.

Thanks! @yukuai26 please follow this up.

@Chtholly-Boss Hello, thank you for your bug report. We have fixed this issue in our latest PR #149. Please give it another try. Since this bug has been resolved, I will close this issue.