Question about `swizzle2d` traversal order
scturtle opened this issue:
Should the line in `triton/python/triton/language/standard.py` at line 110 (commit 9d9ec14), which computes `new_i = off_i + (ij % size_g)`, be changed to `new_i = off_i + ((ij % size_gj) % size_g)`?
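Here is a short derivation, offered as a sketch, of why the two forms can only differ in a shorter last group. Write $s_i, s_j, s_g$ for `size_i`, `size_j`, `size_g`, $k$ for `group_id`, and $ij = k\,s_g s_j + r$ with $0 \le r < s_g s_j$, so that $r$ is `ij % size_gj`. Let $s' = \min(s_i - k\,s_g,\ s_g)$ be the clamped group height. Then

```math
ij \bmod s' = \bigl(k\,s_g s_j + r\bigr) \bmod s' = \bigl((k\,s_g s_j \bmod s') + r\bigr) \bmod s'
```

This equals $r \bmod s'$ for every $r$ in the group exactly when $s'$ divides $k\,s_g s_j$. For full groups $s' = s_g$, so the condition always holds and both formulas agree. For a clamped last group it can fail. With $s_i = 5$, $s_j = 7$, $s_g = 3$, the last group has $k = 1$, $s' = 2$ and $k\,s_g s_j = 21$, and $21 \bmod 2 = 1$. As a result, `ij % size_g` starts that group one row down (wrapping around). The proposed `(ij % size_gj) % size_g` computes $r \bmod s'$ directly.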
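Here is a pure-Python test that reimplements `swizzle2d` and prints, for every cell of an `m` x `n` grid, the linear index `i * n + j` that is mapped to it: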
```python
def test(m, n, g):
    def swizzle2d(i, j, size_i, size_j, size_g):
        ij = i * size_j + j                   # linear (row-major) index
        size_gj = size_g * size_j             # elements in one full group of size_g rows
        group_id = ij // size_gj
        off_i = group_id * size_g             # first row of this group
        size_g = min(size_i - off_i, size_g)  # last group may have fewer rows
        # new_i = off_i + (ij % size_g)       # original line
        new_i = off_i + ((ij % size_gj) % size_g)  # proposed fix
        new_j = (ij % size_gj) // size_g
        return new_i, new_j

    # order[r][c] = linear index of the (i, j) that is mapped to cell (r, c)
    order = [[None] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            new_i, new_j = swizzle2d(i, j, m, n, g)
            order[new_i][new_j] = i * n + j
    for i in range(m):
        for j in range(n):
            print(order[i][j], end=' ')
        print()


test(5, 7, 3)
```
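As a sanity check, the sketch below verifies that both formulas map every cell to a distinct in-bounds cell for all grid and group sizes up to 8. This suggests that the question is purely about the order inside the last group, not about tiles being skipped or visited twice. `swizzle2d` here is a pure-Python model with a hypothetical `fixed` flag that selects the variant. It is not Triton's own function.

```python
from itertools import product


def swizzle2d(i, j, size_i, size_j, size_g, fixed):
    """Pure-Python model of both row formulas.

    `fixed` is a hypothetical flag for comparison only:
    False -> new_i = off_i + (ij % size_g)            (original)
    True  -> new_i = off_i + ((ij % size_gj) % size_g) (proposed)
    """
    ij = i * size_j + j
    size_gj = size_g * size_j
    group_id = ij // size_gj
    off_i = group_id * size_g
    size_g = min(size_i - off_i, size_g)  # last group may have fewer rows
    row_src = (ij % size_gj) if fixed else ij
    new_i = off_i + (row_src % size_g)
    new_j = (ij % size_gj) // size_g
    return new_i, new_j


def is_permutation(size_i, size_j, size_g, fixed):
    """True if every (i, j) lands on a distinct in-bounds cell."""
    seen = set()
    for i, j in product(range(size_i), range(size_j)):
        new_i, new_j = swizzle2d(i, j, size_i, size_j, size_g, fixed)
        if not (0 <= new_i < size_i and 0 <= new_j < size_j):
            return False
        seen.add((new_i, new_j))
    return len(seen) == size_i * size_j


# Both variants are permutations for every tested (m, n, g).
for m, n, g in product(range(1, 9), repeat=3):
    assert is_permutation(m, n, g, fixed=False)
    assert is_permutation(m, n, g, fixed=True)
print("both variants are permutations for all tested sizes")
```

The two variants disagree only on where each element of the last group lands. For example, `(3, 0)` goes to `(4, 0)` with the original line and to `(3, 0)` with the proposed one.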
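With the original line, the test prints the following. The last two rows look wrong. In the full group (rows 0 to 2), each column increases from top to bottom. In the final two-row group (rows 3 and 4), the order is reversed, with 22 above 21: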
```
0 3 6 9 12 15 18
1 4 7 10 13 16 19
2 5 8 11 14 17 20
22 24 26 28 30 32 34
21 23 25 27 29 31 33
```
With the proposed change, the last group follows the same top-to-bottom order as the full group:
```
0 3 6 9 12 15 18
1 4 7 10 13 16 19
2 5 8 11 14 17 20
21 23 25 27 29 31 33
22 24 26 28 30 32 34
```
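The sketch below checks that the proposed line gives the intended order for every size, not just 5x7 with `g = 3`. It compares the swizzled grid against a direct, loop-based construction of a "grouped column-major" order. That description of the intended traversal is an assumption, inferred from the full groups in the output above. `swizzled_order`, `grouped_column_major` and the `fixed` flag are hypothetical helpers.

```python
from itertools import product


def swizzled_order(size_i, size_j, size_g, fixed):
    """Rebuild the test's `order` grid for one variant of the row formula.

    `fixed` is a hypothetical flag: False uses `ij % size_g` (original),
    True uses `(ij % size_gj) % size_g` (proposed).
    """
    order = [[None] * size_j for _ in range(size_i)]
    size_gj = size_g * size_j
    for ij in range(size_i * size_j):          # ij = i * size_j + j
        group_id = ij // size_gj
        off_i = group_id * size_g
        g = min(size_i - off_i, size_g)         # clamped height of this group
        row_src = (ij % size_gj) if fixed else ij
        order[off_i + row_src % g][(ij % size_gj) // g] = ij
    return order


def grouped_column_major(size_i, size_j, size_g):
    """Reference: groups of size_g rows from top to bottom (the last may be
    shorter); within a group, columns left to right, each filled top to bottom."""
    order = [[None] * size_j for _ in range(size_i)]
    k = 0
    for off_i in range(0, size_i, size_g):
        rows = min(size_g, size_i - off_i)
        for j in range(size_j):
            for di in range(rows):
                order[off_i + di][j] = k
                k += 1
    return order


# The proposed formula matches the reference for every tested size;
# the original formula differs on the 5x7, g=3 case from this issue.
for m, n, g in product(range(1, 9), repeat=3):
    assert swizzled_order(m, n, g, fixed=True) == grouped_column_major(m, n, g)
assert swizzled_order(5, 7, 3, fixed=False) != grouped_column_major(5, 7, 3)

for row in grouped_column_major(5, 7, 3):
    print(*row)
```

For `(5, 7, 3)` the reference prints the same grid as the output after the modification.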