triton-lang / triton

Development repository for the Triton language and compiler

Home Page: https://triton-lang.org/

question about swizzle2d traversal order

scturtle opened this issue

In swizzle2d, should this line:

new_i = off_i + (ij % size_g)

instead be the following?

new_i = off_i + ((ij % size_gj) % size_g)
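Why the change can only affect the last group: size_gj is computed from the requested group size before the clamp size_g = min(size_i - off_i, size_g). For every full group the clamp leaves size_g unchanged, so size_g divides size_gj and ij % size_g equals (ij % size_gj) % size_g. Only a partial last group can end up with a row count that does not divide size_gj. A brute-force check of that identity in plain Python (same_row is a hypothetical helper for illustration):

```python
def same_row(ij: int, size_gj: int, size_g: int) -> bool:
    """True when the original and the proposed row offsets agree for this ij."""
    return ij % size_g == (ij % size_gj) % size_g

# Full groups: size_g divides size_gj, so both expressions agree for every ij.
for size_g in range(1, 6):
    for size_j in range(1, 8):
        size_gj = size_g * size_j
        assert all(same_row(ij, size_gj, size_g) for ij in range(200))

# Last group of test(5, 7, 3): clamped to 2 rows, but size_gj = 3 * 7 = 21.
print([same_row(ij, 21, 2) for ij in range(21, 25)])  # [False, False, False, False]
```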

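Test script (a plain-Python copy of swizzle2d with the proposed line; the original line is commented out). It prints, at each position (new_i, new_j), the row-major index i * n + j of the element mapped there: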

def test(m, n, g):

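    def swizzle2d(i, j, size_i, size_j, size_g):
        ij = i * size_j + j                   # row-major linear index of (i, j)
        size_gj = size_g * size_j             # elements in one full group of size_g rows
        group_id = ij // size_gj              # group that (i, j) belongs to
        off_i = group_id * size_g             # first row of that group
        size_g = min(size_i - off_i, size_g)  # the last group may have fewer rows
        # original: new_i = off_i + (ij % size_g)
        new_i = off_i + ((ij % size_gj) % size_g)  # proposed
        new_j = (ij % size_gj) // size_g
        return new_i, new_j

    # order[r][c] = linear index of the element that swizzle2d sends to (r, c)
    order = [[None] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            new_i, new_j = swizzle2d(i, j, m, n, g)
            order[new_i][new_j] = i * n + j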

    for i in range(m):
        for j in range(n):
            print(order[i][j], end=' ')
        print()

test(5, 7, 3)

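With the original line, the output is as follows. The last two rows look wrong: 21 lands in row 4 and 22 in row 3, so the last (partial) group is traversed bottom row first, unlike the full group above it: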

0 3 6 9 12 15 18 
1 4 7 10 13 16 19 
2 5 8 11 14 17 20 
22 24 26 28 30 32 34 
21 23 25 27 29 31 33
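Why the last two rows come out swapped: write ij = group_id * size_gj + k, where k = ij % size_gj is the offset inside the group and r is the clamped row count. The original formula gives (group_id * size_gj + k) % r, which is k % r shifted cyclically by c = (group_id * size_gj) % r. For a full group r divides size_gj, so c = 0; here the last group starts at 21 and has r = 2 rows, so c = 1 and its two rows trade places. A sketch that computes this shift (row_rotation is a hypothetical helper, not part of Triton):

```python
def row_rotation(m: int, n: int, g: int) -> int:
    """Cyclic row shift that the original `ij % size_g` applies in the last group.

    size_gj = g * n elements per full group; the last group starts at linear
    index last_group * size_gj and has r rows after the clamp.
    """
    size_gj = g * n
    last_group = (m * n - 1) // size_gj  # group id of the final element
    r = m - last_group * g               # rows actually present in the last group
    return (last_group * size_gj) % r

print(row_rotation(5, 7, 3))  # 1: the two rows of the last group are swapped
print(row_rotation(6, 7, 3))  # 0: the last group is full, nothing is rotated
```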

With the proposed line, the last group is traversed top to bottom like the others:

0 3 6 9 12 15 18 
1 4 7 10 13 16 19 
2 5 8 11 14 17 20 
21 23 25 27 29 31 33 
22 24 26 28 30 32 34
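A brute-force sketch for checking the proposed formula on many shapes: within each group of rows, reading the grid column by column should give consecutive linear indices, as in the output above. fixed_swizzle2d and is_column_major_per_group are hypothetical names; fixed_swizzle2d is a plain-Python transcription of swizzle2d with the proposed line, not Triton's jitted function.

```python
def fixed_swizzle2d(i, j, size_i, size_j, size_g):
    ij = i * size_j + j
    size_gj = size_g * size_j
    group_id = ij // size_gj
    off_i = group_id * size_g
    size_g = min(size_i - off_i, size_g)  # last group may be shorter
    k = ij % size_gj                      # offset of ij inside its group
    return off_i + k % size_g, k // size_g

def is_column_major_per_group(m, n, g):
    """True if every group of g rows is filled column by column, top to bottom."""
    order = [[None] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            r, c = fixed_swizzle2d(i, j, m, n, g)
            order[r][c] = i * n + j
    for top in range(0, m, g):  # first row of each group
        rows = range(top, min(top + g, m))
        seq = [order[r][c] for c in range(n) for r in rows]
        if seq != list(range(top * n, top * n + len(seq))):
            return False
    return True

print(all(is_column_major_per_group(m, n, g)
          for m in range(1, 10) for n in range(1, 10) for g in range(1, 6)))  # True
```

Swapping the original new_i expression into fixed_swizzle2d makes the (5, 7, 3) case fail, because the last group then starts with 22 instead of 21.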