Triton Matrix Multiplication example produces invalid results (returns zeros) on Volta
RobertCsordas opened this issue · comments
The Triton Matrix Multiplication example kernel (https://triton-lang.org/main/_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py, with the non-existent HIP-specific primitives removed: https://pastebin.com/GxeharnC) returns zeros on Volta GPUs starting from Triton 2.2. The same kernel works perfectly on A-series GPUs on the same machine. It also works on Volta with Triton 2.1 (after replacing `accumulator = tl.dot(..., accumulator)` with `accumulator += tl.dot(...)`).
The bug only happens with float16. Float32 works well.
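To tell this failure mode apart from ordinary float16 rounding noise, a small host-side check against a float32 reference is useful. This is a hedged sketch (the `classify_output` helper and tolerances are my own, not from the example); with NumPy it looks like:

```python
import numpy as np

def classify_output(out, ref, rtol=1e-2, atol=1e-2):
    """Distinguish the all-zeros failure reported here from fp16 noise."""
    if not np.any(out):
        return "zeros"  # the Volta / Triton >= 2.2 failure mode
    if np.allclose(out.astype(np.float32), ref.astype(np.float32),
                   rtol=rtol, atol=atol):
        return "ok"
    return "mismatch"

# Reference matmul accumulated in float32, as the Triton kernel does.
a = np.random.randn(64, 64).astype(np.float16)
b = np.random.randn(64, 64).astype(np.float16)
ref = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)

print(classify_output(np.zeros_like(ref), ref))  # zeros
print(classify_output(ref, ref))                 # ok
```

On a working setup the kernel output should classify as "ok"; on Volta with Triton 2.2+ and float16 inputs it classifies as "zeros".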
Hi @RobertCsordas thank you for the update regarding dtype, this is very helpful, let me check what is going on with float16 on Volta.
Regarding the Triton version: PyTorch indeed pins it for compatibility reasons, but if there is a problem that was fixed in upstream Triton, it's very easy to update the pin.
I also hit the same problem. My environment: CUDA 11.6, Triton 2.3, torch 2.3.1, V100 GPU.
I simplified the matmul code a bit (removed the leaky ReLU) and left just one config to guarantee equivalence, then dumped the PTX with the working 2.1 and the broken 2.2 Triton. Maybe this can help with debugging. Code: https://pastebin.com/FAL22dH1
Tirton 2.1 ptx (working): https://pastebin.com/6E0wiVbb
Triton 2.2 ptx (broken): https://pastebin.com/XMNJgZYB
I don't speak PTX, but to me it looks like the Triton 2.2 PTX is completely missing the code that should invoke the tensor cores (the 2.1 PTX has a bunch of mma.sync.aligned.m8n8k4 instructions, while the 2.2 one has none). The invalid 2.2 code is also significantly shorter.
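You don't need to read PTX to confirm this observation; counting the tensor-core instructions in each dump is enough. A minimal sketch (the `count_mma` helper and the sample strings are illustrative, not from the dumps):

```python
import re

def count_mma(ptx_text):
    # Count tensor-core mma.sync instructions in a PTX dump.
    return len(re.findall(r"\bmma\.sync\.aligned\.\S+", ptx_text))

# Stand-ins for the real pastebin dumps:
working = "mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%f0}, ...;\n" * 3
broken = "ld.global.v4.b32 {%r0, %r1, %r2, %r3}, [%rd0];\n"

print(count_mma(working), count_mma(broken))  # 3 0
```

Running this over the two PTX files above should show a nonzero count for 2.1 and zero for 2.2.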
The Triton 2.3.1 PTX is identical to that of 2.2.
EDIT: updated code to dump the TTIR and TTGIR as well: https://pastebin.com/3TtEEPiG
2.1 TTIR: https://pastebin.com/Lz6r02Ft
2.1 TTGIR: https://pastebin.com/0ukha35H
2.1 LLIR: https://pastebin.com/xzZSDs06
2.2 TTIR: https://pastebin.com/06mP1j1j
2.2 TTGIR: https://pastebin.com/k5FfpT7K
2.2 LLIR: https://pastebin.com/FAL22dH1
The TTIRs seem identical except for the register numbers and two instructions:
2.2 has
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = arith.select %9, %8, %c8_i32 : i32
while 2.1:
%9 = arith.minsi %8, %c8_i32 : i32
The ordering and register numbers in the TTGIR are different, but the general gist seems similar. Whatever drops the tt.dot seems to happen after these IRs.
The 2.2 LLIR has no mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 instructions, while the 2.1 LLIR has them.
EDIT 2:
I ran the unit tests (branch release/2.2.x from GitHub); they eventually fail with a segmentation fault, but the results up to that point are here: https://pastebin.com/PSHDbUs0
GEMM tends to fail.
The results of the lit tests: https://pastebin.com/MrNRwyqT
I can't run the Ninja tests because the build looks for cmake in /tmp and can't find it, and I have not yet figured out how to fix that.
EDIT 3: added missing LLIRs
> I also hit the same problem. My environment: CUDA 11.6, Triton 2.3, torch 2.3.1, V100 GPU.
Download the source from https://github.com/triton-lang/triton, check out the right version (e.g. `git checkout release/2.3.x`), find the file called MMAv1.cpp (its location varies with the version), find the function `convertMMA884`, and locate the line
unsigned numN = mmaLayout.getMMAv1NumOuter(BShape, ALayout.getOpIdx());
and change it to
unsigned numN = mmaLayout.getMMAv1NumOuter(BShape, BLayout.getOpIdx());
Then build the whole thing with `pip3 wheel -e python`, install the result with `pip3 install`, and it will work just fine!