Print statements inside kernel print incorrect value of int64 tensors
georg-wolflein opened this issue.
I came across a bug using int64 tensors. Here's a minimal reproduction.
MWE:
```python
import torch
import triton
import triton.language as tl

@triton.jit
def ndscore_kernel(ptr):
    value = tl.load(ptr)             # load the scalar int64 value
    print("value in kernel", value)  # device-side print: shows 0 instead of 42
    tl.store(ptr, value + 1)         # write back value + 1

ptr = torch.tensor(42, dtype=torch.int64).cuda()
print("value before kernel", ptr.item())
ndscore_kernel[(1,)](ptr)
print("value after kernel", ptr.item())
```
Output:
```
value before kernel 42
pid (0, 0, 0) idx () value in kernel: 0
[...]
pid (0, 0, 0) idx () value in kernel: 0
value after kernel 43
```
Why does the kernel print `0` instead of `42`?
Observations:
- Changing the dtype of `ptr` to `torch.int32` correctly prints `pid (0, 0, 0) idx () value in kernel: 42`.
Thank you for the bug report. This looks real; I will see if I can take a look.
I am able to repro this and will work on fixing this.
My initial analysis:
The issue seems to be the alignment of the store that sets up `value` for the `vprintf` call. In the LLIR, this store has an alignment of 4:
```llvm
%2 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %0, i1 true) #2, !dbg !12
...
store i64 %2, ptr %9, align 4
```
This causes the 64-bit value to be split across two 32-bit PTX stores:
```
@%p1 ld.global.b64 { %rd1 }, [ %rd2 + 0 ];
...
st.local.u32 [%r4+12], %rd1;
shr.u64 %rd6, %rd1, 32;
st.local.u32 [%r4+16], %rd6;
```
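The two 32-bit stores can be modeled with Python's `struct` module to check whether the split by itself loses data. This is a sketch assuming a little-endian target and treating the local buffer as a plain byte array; `store_split_i64` is a hypothetical helper that mirrors the PTX sequence, not part of Triton.

```python
import struct

def store_split_i64(buf: bytearray, offset: int, value: int) -> None:
    """Hypothetical model of the PTX above: store the low 32 bits at
    `offset`, then shift right by 32 and store the high 32 bits at `offset + 4`."""
    lo = value & 0xFFFFFFFF          # st.local.u32 [%r4+12], %rd1 (low 32 bits)
    hi = (value >> 32) & 0xFFFFFFFF  # shr.u64 %rd6, %rd1, 32
    struct.pack_into("<I", buf, offset, lo)      # little-endian u32
    struct.pack_into("<I", buf, offset + 4, hi)  # st.local.u32 [%r4+16], %rd6

buf = bytearray(24)
store_split_i64(buf, 12, 42)

# Reading 8 bytes back from offset 12 recovers the original int64.
(recovered,) = struct.unpack_from("<q", buf, 12)
print(recovered)  # 42
```

Under these assumptions the split alone round-trips the value, which suggests the problem lies in where the value is read back rather than in how it is written.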
I am still not clear about how this split results in 0 being printed.
By contrast, I see that the alignment is 8 with NVCC: https://godbolt.org/z/Y81rafYPo
Attached LLIR/PTX files: 4060_i64.ptx.txt, 4060_i64.llir.txt
Thanks for the detailed analysis! I suspect the alignment of the struct variable passed into vprintf might be wrong.
In your PTX code, the print format (variable `printfFormat_0`) has the value `"pid (%u, %u, %u) idx () value in kernel: %llu\n"`, which is passed as the first parameter of `vprintf`. The second parameter is the address of the argument struct, whose last field holds the int64 value. But that field has an alignment of 4, and I think `vprintf` expects 8, so the lower half of the value is treated as padding and only the upper half is printed. For 42, the upper 32 bits are zero, which is consistent with the `0` in the output.
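The misread described above can be reproduced with Python's `struct` module. This sketch assumes the argument buffer holds three 32-bit `pid` values followed by the int64 (as the format string suggests), that the compiler placed the int64 at offset 12 (alignment 4), and that `vprintf` reads it at the next 8-byte boundary, offset 16. The 4 bytes past the end of the struct are unknown in practice; here they are hypothetically zero.

```python
import struct

def align_up(offset: int, alignment: int) -> int:
    """Round `offset` up to the next multiple of `alignment`."""
    return (offset + alignment - 1) // alignment * alignment

value = 42
# Writer side: three u32 pid fields, then the i64 packed right after them
# with no padding (alignment 4), i.e. at offset 12.
buf = bytearray(24)  # bytes 20..23 lie past the struct; assumed zero here
struct.pack_into("<3Iq", buf, 0, 0, 0, 0, value)

# Reader side (assumption): vprintf expects the i64 at an 8-byte boundary.
read_offset = align_up(12, 8)
(printed,) = struct.unpack_from("<Q", buf, read_offset)
print(read_offset, printed)  # 16 0
```

With `value = (7 << 32) | 42` the same read yields 7: only the upper half of the stored value reaches the reader.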
The kernel prints the right value if I run with DISABLE_LLVM_OPT=1.
I see that a GEP rewrite optimization in InstCombine (https://github.com/llvm/llvm-project/blob/90ba33099cbb17e7c159e9ebc5a512037db99d6d/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp#L2456) converts the GEP to `value`, which indexes field 3 of the argument struct passed to `vprintf`, into a GEP with a byte offset of 12 into that struct. When I checked whether the DataLayout passed to the optimization was incorrect, I saw that the DataLayout used while translating to LLVM IR seems to have the wrong alignment for int64:
If I print `llvmModule.getDataLayout().getABIIntegerTypeAlignment(64)` inside `mlir::translateModuleToLLVMIR(...)`, I see `(llvm::Align) $2 = (ShiftValue = '\x02')`. Because `llvm::Align` stores the log2 of the alignment, this means the int64 alignment specified in the DataLayout is 2^2 = 4 bytes.
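Both findings come down to struct-layout arithmetic. The sketch below is plain Python, not an LLVM API: the hypothetical helper `field_offsets` pads each field to its ABI alignment, assuming the argument struct is `{i32, i32, i32, <value>}` (three `pid` fields plus the printed value). Under that assumption, an i64 alignment of 4 places field 3 at offset 12, matching the rewritten GEP, while an alignment of 8 would place it at 16; an i32 value lands at offset 12 either way, which is consistent with the int32 case printing correctly. It also decodes the debugger output via `1 << ShiftValue`.

```python
def field_offsets(fields):
    """Hypothetical helper: byte offset of each (size, align) field in a
    non-packed struct, where each field starts at the next multiple of its
    ABI alignment."""
    offsets, offset = [], 0
    for size, align in fields:
        offset = (offset + align - 1) // align * align  # insert padding if needed
        offsets.append(offset)
        offset += size
    return offsets

def align_from_shift(shift_value):
    """llvm::Align stores log2(alignment); the alignment in bytes is 1 << shift."""
    return 1 << shift_value

PID = [(4, 4)] * 3  # three u32 pid fields (assumed from the format string)

for i64_align in (4, 8):
    offset = field_offsets(PID + [(8, i64_align)])[3]
    print(f"i64 with ABI alignment {i64_align}: field 3 at offset {offset}")
print("i32 value: field 3 at offset", field_offsets(PID + [(4, 4)])[3])
print("ShiftValue 2 ->", align_from_shift(0x02), "bytes")
# i64 with ABI alignment 4: field 3 at offset 12
# i64 with ABI alignment 8: field 3 at offset 16
# i32 value: field 3 at offset 12
# ShiftValue 2 -> 4 bytes
```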
I am tracking down the cause for this.