Question about memory coalescing
Kitsunetic opened this issue · comments
Hi, I have a question about memory coalescing when calling `tl.load`.

Here is a simple example that computes `x[:, 0] * x[:, 1] + x[:, 2] * x[:, 3]` from an `(N, 4)` input matrix:
```python
import torch as th
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
    ],
    key=["N"],
)
@triton.jit
def column_operation_kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
    nid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = nid < N
    # Four strided loads: each picks one column of the row-major (N, 4) matrix.
    x0 = tl.load(x_ptr + nid * 4 + 0, mask=mask)
    x1 = tl.load(x_ptr + nid * 4 + 1, mask=mask)
    x2 = tl.load(x_ptr + nid * 4 + 2, mask=mask)
    x3 = tl.load(x_ptr + nid * 4 + 3, mask=mask)
    y = x0 * x1 + x2 * x3
    tl.store(y_ptr + nid, y, mask=mask)


def column_operation(x: th.Tensor) -> th.Tensor:
    assert x.ndim == 2 and x.size(1) == 4
    N = x.size(0)
    y = x.new_empty(N)
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    column_operation_kernel[grid](x, y, N)
    return y


if __name__ == "__main__":
    x = th.rand(64, 4, device="cuda")
    y_triton = column_operation(x)
    y_torch = x[:, 0] * x[:, 1] + x[:, 2] * x[:, 3]
    assert th.allclose(y_triton, y_torch), y_triton - y_torch
```
if __name__ == "__main__":
x = th.rand(64, 4, device="cuda")
y_triton = column_operation(x)
y_torch = x[:, 0] * x[:, 1] + x[:, 2] * x[:, 3]
assert th.allclose(y_triton, y_torch), y_triton - y_torch
Since Triton does not support tensor indexing directly, I use sequential `tl.load` calls to process multiple elements per program. My concern is whether memory coalescing actually occurs with this approach. If not, what changes can I make to get coalesced memory access and improve the efficiency of my kernel?
For this example you may get better performance if you split your tensor into four pieces using `tl.split`. On the other hand, LLVM may vectorize these adjacent loads itself, outside of Triton.
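A hedged sketch of that suggestion, assuming a Triton version that provides `tl.split` (which divides a tensor along its last dimension, which must have size 2) — the kernel and wrapper names here are hypothetical, not from the original code:

```python
import torch as th
import triton
import triton.language as tl


@triton.jit
def column_split_kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
    nid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = nid < N
    # One contiguous 2D load of a (BLOCK_SIZE, 4) tile: consecutive lanes read
    # consecutive addresses, so the access pattern is coalesced.
    offs = nid[:, None] * 4 + tl.arange(0, 4)[None, :]
    x = tl.load(x_ptr + offs, mask=mask[:, None], other=0.0)
    # tl.split only splits a size-2 last dimension, so reshape
    # (BLOCK_SIZE, 4) -> (BLOCK_SIZE, 2, 2) and split twice.
    even, odd = tl.split(tl.reshape(x, (BLOCK_SIZE, 2, 2)))  # even=(x0,x2), odd=(x1,x3)
    x0, x2 = tl.split(even)
    x1, x3 = tl.split(odd)
    y = x0 * x1 + x2 * x3
    tl.store(y_ptr + nid, y, mask=mask)


def column_operation_split(x: th.Tensor) -> th.Tensor:
    assert x.ndim == 2 and x.size(1) == 4 and x.is_contiguous()
    N = x.size(0)
    y = x.new_empty(N)
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    column_split_kernel[grid](x, y, N, BLOCK_SIZE=256)
    return y
```

Note the reshape interleaves the columns (`even` holds columns 0 and 2, `odd` holds 1 and 3), which happens to pair them exactly as the `x0 * x1 + x2 * x3` expression needs.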
In general the best way to debug Triton performance is to look at the LLVM IR, PTX, and SASS. The developers are not generally available to assist.
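As a starting point for that kind of inspection, a sketch (assuming a recent Triton where launching a `@triton.jit` kernel returns a compiled-kernel handle whose `.asm` dict maps stage names to generated text; `copy_kernel` and `dump_ptx` are hypothetical names):

```python
import torch as th
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
    nid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = nid < N
    tl.store(y_ptr + nid, tl.load(x_ptr + nid, mask=mask), mask=mask)


def dump_ptx() -> str:
    x = th.rand(1024, device="cuda")
    y = th.empty_like(x)
    # The launch returns a handle to the compiled kernel; its .asm dict
    # typically contains keys such as "ttir", "llir", "ptx", and "cubin".
    h = copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK_SIZE=256)
    return h.asm["ptx"]
```

In the PTX, look for whether the loads were vectorized (e.g. `ld.global.v4` instructions); for SASS, disassemble the cubin with a tool like `nvdisasm`.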