triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

Print statements inside kernel print incorrect value of int64 tensors

georg-wolflein opened this issue

I came across a bug using int64 tensors. Here's a minimal reproduction.

MWE:

import torch
import triton
import triton.language as tl


@triton.jit
def ndscore_kernel(ptr):
    value = tl.load(ptr)
    print("value in kernel", value)
    tl.store(ptr, value + 1)


ptr = torch.tensor(42, dtype=torch.int64).cuda()
print("value before kernel", ptr.item())
ndscore_kernel[(1,)](ptr)
print("value after kernel", ptr.item())

Output:

value before kernel 42
pid (0, 0, 0) idx () value in kernel: 0
[...]
pid (0, 0, 0) idx () value in kernel: 0
value after kernel 43

Why does the kernel print 0 instead of 42?

Observations:

  • Changing the dtype of ptr to torch.int32 correctly prints pid (0, 0, 0) idx () value in kernel: 42

Thank you for the bug report. This looks like a real bug; I will see if I can take a look.

I am able to reproduce this and will work on a fix.

My initial analysis:
The issue seems to be the alignment of the store that sets up 'value' for the vprintf call.

The LLVM IR (llir) for this store has an alignment of 4:
%2 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %0, i1 true) #2, !dbg !12
...
store i64 %2, ptr %9, align 4

This results in the value being split across two 32-bit PTX stores, at offsets 12 and 16:
@%p1 ld.global.b64 { %rd1 }, [ %rd2 + 0 ];
...
st.local.u32 [%r4+12], %rd1;
shr.u64 %rd6, %rd1, 32;
st.local.u32 [%r4+16], %rd6;
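As a rough illustration of the split above, the sketch below mimics the two 32-bit stores in Python. The helper name `split_u64` is hypothetical and only models the arithmetic of the `shr.u64` and the two `st.local.u32` instructions; it is not Triton or PTX code.

```python
MASK32 = 0xFFFFFFFF


def split_u64(value):
    """Return the (low, high) 32-bit words of a 64-bit value.

    The low word corresponds to the store at offset 12, the high word
    (value shifted right by 32) to the store at offset 16.
    """
    low = value & MASK32
    high = (value >> 32) & MASK32
    return low, high


low, high = split_u64(42)
print("word at offset 12:", low)
print("word at offset 16:", high)
```

For 42 the low word is 42 and the high word is 0, so the split by itself preserves the value; what matters is which offset the reader starts from.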

I am still not clear on how this split results in 0 being printed.
With NVCC, the alignment of the equivalent store is 8:
https://godbolt.org/z/Y81rafYPo

llir/ptx files
4060_i64.ptx.txt
4060_i64.llir.txt

Thanks for the detailed analysis! I suspect the alignment of the struct variable passed into vprintf might be wrong.

In your PTX code, the print format (variable printfFormat_0) has the value of

"pid (%u, %u, %u) idx () value in kernel: %llu\n"

which is passed as the first parameter of vprintf. The second parameter is the address of the struct object, whose last field holds the int64 value. That field has an alignment of 4, but I think vprintf expects 8, so the lower half of the value is treated as padding and only the upper half is printed.
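A minimal sketch of this hypothesis, using Python's `struct` module to stand in for the argument buffer. It assumes little-endian byte order, that the reader expects the 64-bit field at offset 16, and that the bytes after the packed struct are zero (real local memory may contain anything). The helper names `write_packed` and `read_as_aligned` are made up for illustration.

```python
import struct


def write_packed(value):
    """Write three u32 pid fields and a u64 with 4-byte alignment.

    With "<" the struct module adds no padding, so the u64 lands at offset 12.
    """
    buf = bytearray(24)  # zero-filled by assumption
    struct.pack_into("<IIIQ", buf, 0, 0, 0, 0, value)
    return buf


def read_as_aligned(buf):
    """Read the u64 where an 8-byte-aligned layout would place it (offset 16)."""
    return struct.unpack_from("<Q", buf, 16)[0]


print(read_as_aligned(write_packed(42)))
print(read_as_aligned(write_packed((7 << 32) | 42)))
```

Under these assumptions the first call reads 0 and the second reads 7: the reader sees only the upper 32 bits of the original value, which matches the 0 printed for 42.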

The kernel prints the right value if I run with DISABLE_LLVM_OPT=1.

I see that a GEP rewrite optimization (https://github.com/llvm/llvm-project/blob/90ba33099cbb17e7c159e9ebc5a512037db99d6d/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L2456 ) converts the GEP to 'value', which uses index 3 into the argument struct passed to vprintf, into a GEP with a byte offset of 12 into that struct. When I checked whether the DataLayout passed to the optimization is incorrect, I saw that the DataLayout used while translating to LLVM IR seems to have the wrong alignment for int64:
If I print llvmModule.getDataLayout().getABIIntegerTypeAlignment(64) inside mlir::translateModuleToLLVMIR(...), I see (llvm::Align) $2 = (ShiftValue = '\x02'). llvm::Align stores the log2 of the alignment, so the int64 alignment specified in the DataLayout is 1 << 2 = 4.
I am tracking down the cause for this.
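The offsets discussed above can be checked with a small Python sketch. It assumes the vprintf argument struct is three 4-byte integers followed by one 8-byte integer, laid out with the usual "round up to the field's alignment" rule; `field_offsets` and `align_from_shift` are hypothetical helpers, not LLVM APIs.

```python
def align_from_shift(shift_value):
    """Decode an llvm::Align-style ShiftValue (log2 of the alignment in bytes)."""
    return 1 << shift_value


def field_offsets(fields):
    """Compute byte offsets for a list of (size, align) fields."""
    offsets = []
    offset = 0
    for size, align in fields:
        # Round the running offset up to the next multiple of this field's alignment.
        offset = (offset + align - 1) // align * align
        offsets.append(offset)
        offset += size
    return offsets


pids = [(4, 4)] * 3
i64_align = align_from_shift(0x02)  # ShiftValue seen in the debugger
print(field_offsets(pids + [(8, i64_align)]))  # layout with the observed DataLayout
print(field_offsets(pids + [(8, 8)]))          # layout with 8-byte int64 alignment
```

With an int64 alignment of 4, field 3 sits at offset 12, which matches the rewritten GEP; with an alignment of 8 it would sit at offset 16.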