[BUG] rfactor on a split axis results in an incorrect schedule
vinx13 opened this issue · comments
When a reduction axis is split by its full length (so the outer loop has extent 1) and the inner axis is then rfactored, the result is wrong.
import tvm
from tvm import tir
from tvm.script import ty
# Reproduction workload: C[i, j] = sum_k A[i, k] * B[j, k] on 128x128 buffers,
# expressed as a single TensorIR block "update" with one reduction iterator (vk).
# (Indentation restored: the flattened paste was not valid Python.)
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # vi, vj are spatial iterators; vk is the reduction iterator over [0, 128).
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        # Initial value of the reduction for each (vi, vj).
        with tir.init():
            C[vi, vj] = 0.0
        # Note that B is indexed [vj, vk], i.e. this computes A @ B^T.
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Schedule `matmul`: split the reduction loop k by its full extent (128), which
# leaves an outer loop of extent 1. Then rfactor the inner loop and print the result.
sch = tir.create_schedule(matmul)
update_block = sch.get_block("update")
i_loop, j_loop, k_loop = sch.get_axes(update_block)
k_outer, k_inner = sch.split(k_loop, 128)  # printed as i2_outer (extent 1) / i2_inner (extent 128)
# The rfactored axis ends up as dimension 2 of C_rf in the printed output
# (C_rf[vi, vj, vi2_inner]); presumably that is what the `2` selects -- TODO confirm.
sch.rfactor(k_inner, 2)
print(tvm.script.asscript(sch.func))
## OUTPUT
# Printed by tvm.script.asscript(s.func) after split(k, 128) followed by rfactor(ki, 2).
# The first "update" block fills the rfactor buffer C_rf; the second reduces C_rf into C.
@tvm.script.tir
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
B = tir.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
# body
with tir.block([], "root") as []:
tir.reads([])
tir.writes([])
C_rf = tir.buffer_allocate([128, 128, 128], elem_offset=0, align=128, offset_factor=1)
# i2_outer (extent 1) and i2_inner (extent 128) are the two halves of the split k loop;
# the rfactored loop i2_inner has been moved outermost.
for i2_inner, i0, i1, i2_outer in tir.grid(128, 128, 128, 1):
# NOTE(review): all three block iters are declared spatial (no tir.reduce_axis),
# yet this block still carries a tir.init() below.
with tir.block([128, 128, 128], "update") as [vi, vj, vi2_inner]:
tir.bind(vi, i0)
tir.bind(vj, i1)
tir.bind(vi2_inner, i2_inner)
# NOTE(review): the A and B read regions are pinned to column 0, matching the wrong body below.
tir.reads([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)], A[vi:(vi + 1), 0:1], B[vj:(vj + 1), 0:1]])
tir.writes([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)]])
with tir.init():
C_rf[vi, vj, vi2_inner] = tir.float32(0)
# NOTE(review): the original reduction index vk was replaced by the constant 0 instead of
# an expression in vi2_inner (k should equal i2_inner here, since i2_outer only takes 0).
C_rf[vi, vj, vi2_inner] = (C_rf[vi, vj, vi2_inner] + (A[vi, 0]*B[vj, 0])) # SHOULD BE A[vi, vi2_inner] * B[vj, vi2_inner]
# Final reduction: vi2_inner_1 is a reduce_axis that sums C_rf over its dimension 2 into C.
for i0_1, i1_1, i2_inner_1 in tir.grid(128, 128, 128):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi_1, vj_1, vi2_inner_1]:
tir.bind(vi_1, i0_1)
tir.bind(vj_1, i1_1)
tir.bind(vi2_inner_1, i2_inner_1)
tir.reads([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)], C_rf[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1), vi2_inner_1:(vi2_inner_1 + 1)]])
tir.writes([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)]])
with tir.init():
C[vi_1, vj_1] = tir.float32(0)
C[vi_1, vj_1] = (C[vi_1, vj_1] + C_rf[vi_1, vj_1, vi2_inner_1])
These lines might be related: https://github.com/Hzfengsy/tvm-tensorir/blob/master/src/tir/schedule/schedule_reduction.cc#L421-L423
@junrushao1994 @spectrometerHBH
I'm not very familiar with the rfactor primitive, but I found another issue here.
In the first "update" block of the output above, there is:
# Excerpt: the block that writes C_rf. NOTE(review): its iters are all spatial, yet it has a tir.init().
with tir.block([128, 128, 128], "update") as [vi, vj, vi2_inner]:
tir.bind(vi, i0)
tir.bind(vj, i1)
tir.bind(vi2_inner, i2_inner)
tir.reads([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)], A[vi:(vi + 1), 0:1], B[vj:(vj + 1), 0:1]])
tir.writes([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)]])
with tir.init():
C_rf[vi, vj, vi2_inner] = tir.float32(0)
C_rf[vi, vj, vi2_inner] = (C_rf[vi, vj, vi2_inner] + (A[vi, 0]*B[vj, 0])) # SHOULD BE A[vi, vi2_inner] * B[vj, vi2_inner]
This block is not a reduction block (its body is not a reduction, and it has no kCommReduce
block var), but it contains an init,
which is not expected.
Moreover, the issue reported by @vinx13 also occurs when the reduction axis is not split at all.
See the following code.
import tvm
from tvm import tir
from tvm.script import ty
# Reproduction workload: C[i, j] = sum_k A[i, k] * B[j, k] on 128x128 buffers,
# expressed as a single TensorIR block "update" with one reduction iterator (vk).
# (Indentation restored: the flattened paste was not valid Python.)
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # vi, vj are spatial iterators; vk is the reduction iterator over [0, 128).
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        # Initial value of the reduction for each (vi, vj).
        with tir.init():
            C[vi, vj] = 0.0
        # Note that B is indexed [vj, vk], i.e. this computes A @ B^T.
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Same workload, but rfactor the reduction loop k directly (no split), then print.
sch = tir.create_schedule(matmul)
update_block = sch.get_block("update")
i_loop, j_loop, k_loop = sch.get_axes(update_block)
# The rfactored axis shows up as dimension 2 of C_rf in the printed output
# (C_rf[vi, vj, vi2]); presumably that is what the `2` selects -- TODO confirm.
sch.rfactor(k_loop, 2)
print(tvm.script.asscript(sch.func))
# OUTPUT
# Printed by tvm.script.asscript(s.func) after rfactor(k, 2) with no prior split.
# The first "update" block fills the rfactor buffer C_rf; the second reduces C_rf into C.
@tvm.script.tir
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
B = tir.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
# body
with tir.block([], "root") as []:
tir.reads([])
tir.writes([])
C_rf = tir.buffer_allocate([128, 128, 128], elem_offset=0, align=128, offset_factor=1)
# The rfactored loop i2 (the original k loop) has been moved outermost.
for i2, i0, i1 in tir.grid(128, 128, 128):
# NOTE(review): all three block iters are declared spatial (no tir.reduce_axis),
# yet this block still carries a tir.init() below.
with tir.block([128, 128, 128], "update") as [vi, vj, vi2]:
tir.bind(vi, i0)
tir.bind(vj, i1)
tir.bind(vi2, i2)
# NOTE(review): the A and B read regions are pinned to column 0, matching the wrong body below.
tir.reads([C_rf[vi:(vi + 1), vj:(vj + 1), vi2:(vi2 + 1)], A[vi:(vi + 1), 0:1], B[vj:(vj + 1), 0:1]])
tir.writes([C_rf[vi:(vi + 1), vj:(vj + 1), vi2:(vi2 + 1)]])
with tir.init():
C_rf[vi, vj, vi2] = tir.float32(0)
# NOTE(review): the reduction index vk was replaced by the constant 0; expected A[vi, vi2]*B[vj, vi2].
C_rf[vi, vj, vi2] = (C_rf[vi, vj, vi2] + (A[vi, 0]*B[vj, 0])) # <=== Here is the issue
# Final reduction: vi2_1 is a reduce_axis that sums C_rf over its dimension 2 into C.
for i0_1, i1_1, i2_1 in tir.grid(128, 128, 128):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi_1, vj_1, vi2_1]:
tir.bind(vi_1, i0_1)
tir.bind(vj_1, i1_1)
tir.bind(vi2_1, i2_1)
tir.reads([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)], C_rf[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1), vi2_1:(vi2_1 + 1)]])
tir.writes([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)]])
with tir.init():
C[vi_1, vj_1] = tir.float32(0)
C[vi_1, vj_1] = (C[vi_1, vj_1] + C_rf[vi_1, vj_1, vi2_1])
BTW, since I don't know what the integer argument of the rfactor primitive means (that is, what the `2`
in `s.rfactor(ki, 2)` means),
I don't know whether `s.rfactor(ki, 2)`
is a valid use of rfactor.
It's a known bug that rfactoring an inner axis produces an incorrect result; I will take care of it.
Fixed by #322.