triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

Question about memory coalescing

Kitsunetic opened this issue

Hi, I have a question about memory coalescing while calling tl.load.

Here is a simple example that calculates x[:, 0] * x[:, 1] + x[:, 2] * x[:, 3] from an input (N, 4) matrix:

import torch as th
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
    ],
    key=["N"],
)
@triton.jit
def column_operation_kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
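    # flat row indices handled by this program instance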
    nid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = nid < N

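    # four separate loads: within each load, consecutive nid values are 4 elements apart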
    x0 = tl.load(x_ptr + nid * 4 + 0, mask=mask)
    x1 = tl.load(x_ptr + nid * 4 + 1, mask=mask)
    x2 = tl.load(x_ptr + nid * 4 + 2, mask=mask)
    x3 = tl.load(x_ptr + nid * 4 + 3, mask=mask)

    y = x0 * x1 + x2 * x3

    tl.store(y_ptr + nid, y, mask=mask)


def column_operation(x: th.Tensor) -> th.Tensor:
    assert x.ndim == 2 and x.size(1) == 4
    N = x.size(0)
    y = x.new_empty(N)
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    column_operation_kernel[grid](x, y, N)
    return y


if __name__ == "__main__":
    x = th.rand(64, 4, device="cuda")
    y_triton = column_operation(x)
    y_torch = x[:, 0] * x[:, 1] + x[:, 2] * x[:, 3]
    assert th.allclose(y_triton, y_torch), y_triton - y_torch

Since tensor indexing is not supported directly in Triton, I use four separate tl.load calls so that each program instance processes several elements. My concern is whether memory coalescing occurs with this approach. If not, what changes can I make to ensure coalesced accesses and improve the efficiency of memory access in my Triton kernel?

For this example you may get better performance if you split your tensor into 4 pieces using tl.split (see the sketch below). On the other hand, LLVM may vectorize these loads itself, outside of Triton.
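
If it helps, here is a minimal sketch of that approach. It assumes a Triton version that provides tl.split and tl.reshape; the kernel name and the reshape-then-split trick are illustrative, not part of the original code. The idea is to load the whole (BLOCK_SIZE, 4) tile in one contiguous load and then peel the four columns apart in registers:

import triton
import triton.language as tl


@triton.jit
def column_operation_kernel_split(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
    nid = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = nid < N

    # One contiguous (BLOCK_SIZE, 4) tile: adjacent lanes read adjacent
    # addresses, so the load can be coalesced and vectorized.
    offs = nid[:, None] * 4 + tl.arange(0, 4)[None, :]
    x = tl.load(x_ptr + offs, mask=mask[:, None])

    # tl.split only splits a trailing dimension of size 2, so view the tile as
    # (BLOCK_SIZE, 2, 2) and split twice. After the reshape, position [i, j, k]
    # holds column 2*j + k, so the first split separates the even columns
    # (0, 2) from the odd columns (1, 3).
    x = tl.reshape(x, (BLOCK_SIZE, 2, 2))
    x02, x13 = tl.split(x)   # each (BLOCK_SIZE, 2)
    x0, x2 = tl.split(x02)   # each (BLOCK_SIZE,)
    x1, x3 = tl.split(x13)

    y = x0 * x1 + x2 * x3
    tl.store(y_ptr + nid, y, mask=mask)

The wrapper from the question can launch this kernel unchanged, since the signature is the same.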

In general the best way to debug Triton performance is to look at the LLVM IR, PTX, and SASS. The developers are not generally available to assist.
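
For reference, a rough sketch of getting at the PTX from Python follows. It assumes a recent Triton where the launch returns a compiled-kernel handle exposing an .asm dictionary; the attribute and key names vary between releases, so treat it as a starting point rather than a stable API:

x = th.rand(4096, 4, device="cuda")
y = x.new_empty(x.size(0))
grid = lambda meta: (triton.cdiv(x.size(0), meta["BLOCK_SIZE"]),)
# The handle returned by the launch carries the intermediate representations
# (assumed API; names may differ across Triton versions).
compiled = column_operation_kernel[grid](x, y, x.size(0))
print(list(compiled.asm.keys()))  # typically ['ttir', 'ttgir', 'llir', 'ptx', 'cubin']
print(compiled.asm["ptx"])        # check whether the loads were vectorized (e.g. ld.global.v4)

For SASS, disassemble the cubin with cuobjdump/nvdisasm or inspect the kernel in Nsight Compute.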