triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/


Triton Matrix Multiplication example returns invalid results (zeros) on Volta

RobertCsordas opened this issue · comments

The Triton Matrix Multiplication example kernel (https://triton-lang.org/main/_downloads/d5fee5b55a64e47f1b5724ec39adf171/03-matrix-multiplication.py, with the non-existent HIP-specific primitives removed: https://pastebin.com/GxeharnC) returns zeros on Volta GPUs starting from Triton 2.2. The same kernel works perfectly on A-series GPUs on the same machine. It also works on Volta with Triton 2.1 (after replacing accumulator = tl.dot(..., accumulator) with accumulator += tl.dot(...)).

The bug only happens with float16; float32 works correctly.
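For reference, this is roughly the check I'm running (a minimal sketch; it assumes the linked tutorial script is saved locally as matmul_example.py and that its matmul(a, b) wrapper is used unchanged, so the module name is illustrative):

```
# Minimal repro sketch (assumes the tutorial script is saved as matmul_example.py
# and exposes its matmul(a, b) wrapper; the module name is illustrative).
import torch
from matmul_example import matmul  # hypothetical local copy of the tutorial

torch.manual_seed(0)
for dtype in (torch.float16, torch.float32):
    a = torch.randn(512, 512, device="cuda", dtype=dtype)
    b = torch.randn(512, 512, device="cuda", dtype=dtype)
    c_triton = matmul(a, b)
    c_torch = torch.matmul(a, b)
    # On Volta with Triton >= 2.2 the float16 result comes back as all zeros,
    # while float32 matches the torch reference.
    print(dtype, torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2),
          float(c_triton.abs().max()))
```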

Hi @RobertCsordas, thank you for the update regarding the dtype, this is very helpful; let me check what is going on with float16 on Volta.
Regarding the Triton version: PyTorch indeed pins the version for compatibility reasons, but if there is a problem that was fixed in upstream Triton, it's very easy to update the pin.

I also see the same problem.
My env is CUDA 11.6, Triton 2.3, torch 2.3.1, V100 GPU.

I simplified the matmul code a bit (removed the leaky ReLU) and left just one config to guarantee equivalence, and dumped the PTX with the working 2.1 and the broken 2.2 Triton. Maybe this can help with debugging. Code: https://pastebin.com/FAL22dH1

Triton 2.1 PTX (working): https://pastebin.com/6E0wiVbb
Triton 2.2 PTX (broken): https://pastebin.com/XMNJgZYB

I don't speak PTX, but it looks to me like the Triton 2.2 PTX is completely missing the code that should invoke the tensor cores (the 2.1 PTX has a bunch of mma.sync.aligned.m8n8k4 instructions, while the 2.2 one has none). The invalid 2.2 code is also significantly shorter.

The Triton 2.3.1 PTX is identical to that of 2.2.
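A quick way to reproduce that count (a rough sketch, not necessarily how I checked; the filenames are made up, just save the two pastes locally):

```
# Crude sketch: count tensor-core mma instructions in the two PTX dumps
# (assumes the pastes were saved locally as triton21.ptx / triton22.ptx).
for name in ("triton21.ptx", "triton22.ptx"):
    with open(name) as f:
        print(name, f.read().count("mma.sync.aligned.m8n8k4"))
# 2.1 should report a non-zero count; 2.2 reports zero.
```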

EDIT: updated the code to dump the TTIR and TTGIR as well (see the sketch after the links for one way such dumps can be collected): https://pastebin.com/3TtEEPiG
2.1 TTIR: https://pastebin.com/Lz6r02Ft
2.1 TTGIR: https://pastebin.com/0ukha35H
2.1 LLIR: https://pastebin.com/xzZSDs06
2.2 TTIR: https://pastebin.com/06mP1j1j
2.2 TTGIR: https://pastebin.com/k5FfpT7K
2.2 LLIR: https://pastebin.com/FAL22dH1
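For completeness, one way to collect such dumps (not necessarily what the pastebin script above does) is to copy the intermediate files Triton 2.x leaves in its on-disk cache after the kernel has run once; the default cache path and the output directory name below are assumptions:

```
# Sketch: after running the matmul once, copy the .ttir/.ttgir/.llir/.ptx
# artifacts out of Triton's on-disk cache (default location is an assumption;
# TRITON_CACHE_DIR overrides it in Triton 2.x).
import glob, os, shutil
import triton

cache_dir = os.environ.get("TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))
out_dir = f"ir_dumps_triton{triton.__version__}"  # illustrative output directory
os.makedirs(out_dir, exist_ok=True)

for ext in ("ttir", "ttgir", "llir", "ptx"):
    for path in glob.glob(os.path.join(cache_dir, "**", f"*.{ext}"), recursive=True):
        shutil.copy(path, out_dir)
```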

The TTIRs seem identical, except for the register numbers and two instructions:
2.2 has
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = arith.select %9, %8, %c8_i32 : i32
while 2.1 has
%9 = arith.minsi %8, %c8_i32 : i32
(both compute min(%8, 8), so this difference should be harmless).

The ordering and register numbers in the TTGIR are different, but the general gist seems to be similar. Whatever drops the tt.dot lowering appears to happen after these IRs.

The 2.2 LLIR doesn't have any mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 instructions, while the 2.1 LLIR does.

EDIT 2:

I ran the unit tests (branch release/2.2.x from GitHub); they eventually fail with a segmentation fault, but the results up to that point are here: https://pastebin.com/PSHDbUs0
The GEMM tests tend to fail.

The results of the lit tests: https://pastebin.com/MrNRwyqT

I can't run the Ninja tests because they look for cmake in /tmp and can't find it; I haven't yet figured out how to fix that.

EDIT 3: added missing LLIRs

I also see the same problem. My env is CUDA 11.6, Triton 2.3, torch 2.3.1, V100 GPU.

Download the source from https://github.com/triton-lang/triton, check out the right version (e.g. git checkout 'release/2.3.x'), find the file called MMAv1.cpp (its location varies with the version), and in the function convertMMA884 change the line
unsigned numN = mmaLayout.getMMAv1NumOuter(BShape, ALayout.getOpIdx());
to
unsigned numN = mmaLayout.getMMAv1NumOuter(BShape, BLayout.getOpIdx());
Then build the whole thing with pip3 wheel -e python, install the resulting wheel with pip3 install, and it will work just fine!