[bug] Unspecified Launch Failure of Normal GEMM when n=16
Chtholly-Boss opened this issue
The refactoring of the repo seems to have introduced a bug when n=16. Here are the minimal steps to reproduce:
- In tests/generators.py, modify the normal generator as follows:
def enumerate_normal() -> Generator:
    for kernel_type in get_kernel_types():
        # for m in (128, 4096):
        #     for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
        for m in (576, ):
            for n, k in [(16, 7168), ]:
                for major_a, major_b in get_major_ab(False):
                    for out_dtype in get_out_dtype():
                        for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
                            yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
- In tests/test_core.py, only enable test_gemm:
if __name__ == '__main__':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(0)
    random.seed(0)
    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')
    test_gemm()
    # test_m_grouped_gemm_contiguous()
    # test_m_grouped_gemm_masked()
    # test_k_grouped_gemm_contiguous()
Then, in the project root, run PYTHONPATH=. python3 tests/test_core.py, and we get the following output:
Library path:
> ['/root/DeepGEMM/deep_gemm']
Testing GEMM:
Traceback (most recent call last):
File "/root/DeepGEMM/tests/test_core.py", line 171, in <module>
test_gemm()
File "/root/DeepGEMM/tests/test_core.py", line 38, in test_gemm
diff = calc_diff(d, ref_d)
^^^^^^^^^^^^^^^^^^^
File "/root/DeepGEMM/deep_gemm/testing/numeric.py", line 7, in calc_diff
denominator = (x * x + y * y).sum()
^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: unspecified launch failure
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
With DG_JIT_DEBUG enabled, we get messages like the following:
Testing GEMM:
Gemm type: 0, kernel type: 1, M: 576, N: 16, K: 7168, groups: 1, A major: 0, B major: 0, AB dtype: Float8_e4m3fn, CD dtype: BFloat16, accumulation: 0, SM limit: 114 -> block M: 64, block N: 16, block K: 128, stages: 12, last stages: 8, SMs: 10, multicast: 2, multicast on A: 1, shared memory: 128432 bytes, swizzle A: 128, swizzle B: 128, swizzle CD: 32, threads: 256
Making TMA desc: global memory: 7168 576, shared memory: 128 64, outer stride: 7168, swizzle: 128, elem size: 1
Making TMA desc: global memory: 7168 16, shared memory: 128 16, outer stride: 7168, swizzle: 128, elem size: 1
Making TMA desc: global memory: 16 576, shared memory: 16 64, outer stride: 16, swizzle: 32, elem size: 2
Making TMA desc: global memory: 576 56, shared memory: 64 1, outer stride: 576, swizzle: 0, elem size: 4
Generated kernel code:
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {
    auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
        0, 16, 7168,
        1,
        64, 16, 128,
        32,
        12, 8,
        128, 128,
        2, true,
        GemmType::Normal
    >);
};
Loading CUBIN: /root/.deep_gemm/cache/kernel.sm90_fp8_gemm_1d2d.d15b11aa7f94a7fd2939a8265515d60e/kernel.cubin
Symbol names: _ZN9deep_gemm23sm90_fp8_gemm_1d2d_implILj0ELj16ELj7168ELj1ELj64ELj16ELj128ELj32ELj12ELj8ELj128ELj128ELj2ELb1ELNS_8GemmTypeE0EEEvPfPijjj14CUtensorMap_stS4_S4_S4_,
Launch kernel with {10, 1} x 256, shared memory: 128432 bytes, cluster: 2, stream: 0
I suspect there is a bug related to the TMA multicast logic in the impl. In this case, we launch 10 SMs but only 9 have work to do. Once TMA loading is done, B8 will wait on the empty barrier, which needs B9 to arrive:
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
    #pragma unroll
    for (uint32_t s = 0; s < kNumStages; ++ s)
        empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
}
But B9 has no work to do, so it will never "arrive" at the empty barrier.
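To make the mismatch concrete, here is a rough sanity check of the launch arithmetic from the debug log above (a sketch only; ceil_div and the round-up-to-cluster-size behaviour are my reading of the log, not the actual scheduler code):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Values taken from the DG_JIT_DEBUG line above
m, n = 576, 16
block_m, block_n = 64, 16
num_multicast = 2  # cluster size 2

# Useful tiles: 9 along M, 1 along N -> 9 blocks in total
num_blocks = ceil_div(m, block_m) * ceil_div(n, block_n)

# Launched SMs appear to be rounded up to a multiple of the cluster size -> 10
num_sms = ceil_div(num_blocks, num_multicast) * num_multicast

print(num_blocks, num_sms)  # 9 10: the 10th CTA gets no tile, so it never arrives
                            # at the distributed empty barriers its cluster peer waits on

If this reading is right, the hang is a corner case where the number of useful blocks is odd while 2-way multicast forces an even number of CTAs per cluster.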
However, when I try to reproduce this in the repo before refactoring, I failed.
I am confused about what happened...
BTW, when m=2112, this also happens.
Thanks! @yukuai26 please follow this up.
@Chtholly-Boss Hello, thank you for your bug report. We have fixed this issue in our latest PR #149. Please give it another try. Since this bug has been resolved, I will close this issue.