tlc-pack / tvm-tensorir

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

[BUG] rfactor on a split axis results in an incorrect schedule

vinx13 opened this issue · comments

When a reduction axis is split by its full length (so the outer extent is 1) and the inner axis is then rfactored, the result is wrong.

import tvm
from tvm import tir
from tvm.script import ty
# Reproducer: split the reduction axis by its full length (outer extent 1),
# then rfactor the inner axis.
# NOTE(review): the names used inside the TVMScript function (vi, vj, vk, ...)
# appear in the printed IR, so keep its body unchanged when comparing against
# the OUTPUT below; use only '#' comments inside it.
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Three 128x128 buffers bound to the handles a, b, c.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # Spatial block vars vi, vj and one reduction var vk, each of extent 128.
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        # B is indexed [vj, vk], so this computes C = A @ B^T.
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
s = tir.create_schedule(matmul)
C = s.get_block("update")
i, j, k = s.get_axes(C)
# k has extent 128; in the printed output this split yields an outer loop of
# extent 1 and an inner loop of extent 128.
ko, ki = s.split(k, 128)
# NOTE(review): the integer argument is not explained here; in the output the
# rfactored axis becomes dimension 2 of C_rf -- confirm its meaning.
s.rfactor(ki, 2)
print(tvm.script.asscript(s.func))

## OUTPUT
# Printed IR for the first reproducer (split k by 128, then rfactor ki).
@tvm.script.tir
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = tir.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root") as []:
        tir.reads([])
        tir.writes([])
        # rfactor buffer: C's [128, 128] shape plus a trailing dim indexed by vi2_inner.
        C_rf = tir.buffer_allocate([128, 128, 128], elem_offset=0, align=128, offset_factor=1)
        # Stage 1 (rfactor block): the rfactored loop i2_inner is outermost; the
        # split's outer loop i2_outer (extent 1) is innermost and is not bound below.
        for i2_inner, i0, i1, i2_outer in tir.grid(128, 128, 128, 1):
            # All three block vars are plain (no tir.reduce_axis), yet the block
            # still carries a tir.init.
            with tir.block([128, 128, 128], "update") as [vi, vj, vi2_inner]:
                tir.bind(vi, i0)
                tir.bind(vj, i1)
                tir.bind(vi2_inner, i2_inner)
                # BUG: the A and B regions use column 0 instead of vi2_inner.
                tir.reads([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)], A[vi:(vi + 1), 0:1], B[vj:(vj + 1), 0:1]])
                tir.writes([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)]])
                with tir.init():
                    C_rf[vi, vj, vi2_inner] = tir.float32(0)
                C_rf[vi, vj, vi2_inner] = (C_rf[vi, vj, vi2_inner] + (A[vi, 0]*B[vj, 0]))  # SHOULD BE A[vi, vi2_inner] * B[vj, vi2_inner]
        # Stage 2: reduce C_rf over its trailing dim (vi2_inner_1) into C.
        for i0_1, i1_1, i2_inner_1 in tir.grid(128, 128, 128):
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi_1, vj_1, vi2_inner_1]:
                tir.bind(vi_1, i0_1)
                tir.bind(vj_1, i1_1)
                tir.bind(vi2_inner_1, i2_inner_1)
                tir.reads([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)], C_rf[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1), vi2_inner_1:(vi2_inner_1 + 1)]])
                tir.writes([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)]])
                with tir.init():
                    C[vi_1, vj_1] = tir.float32(0)
                C[vi_1, vj_1] = (C[vi_1, vj_1] + C_rf[vi_1, vj_1, vi2_inner_1])

These lines might be related: https://github.com/Hzfengsy/tvm-tensorir/blob/master/src/tir/schedule/schedule_reduction.cc#L421-L423

@junrushao1994 @spectrometerHBH

I'm not very familiar with the rfactor primitive, but I found another issue here.
In the first "update" block of the output above, there is:

# Excerpt: the first "update" block (stage 1) from the output above.
with tir.block([128, 128, 128], "update") as [vi, vj, vi2_inner]:
    tir.bind(vi, i0)
    tir.bind(vj, i1)
    tir.bind(vi2_inner, i2_inner)
    tir.reads([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)], A[vi:(vi + 1), 0:1], B[vj:(vj + 1), 0:1]])
    tir.writes([C_rf[vi:(vi + 1), vj:(vj + 1), vi2_inner:(vi2_inner + 1)]])
    # No block var above is declared with tir.reduce_axis, yet an init is present.
    with tir.init():
        C_rf[vi, vj, vi2_inner] = tir.float32(0)
    C_rf[vi, vj, vi2_inner] = (C_rf[vi, vj, vi2_inner] + (A[vi, 0]*B[vj, 0]))  # SHOULD BE A[vi, vi2_inner] * B[vj, vi2_inner]

This block is not a reduction block (its body is not a reduction, and it has no kCommReduce block var), yet it contains an init statement, which is unexpected.

Moreover, the issue reported by @vinx13 also occurs when the reduction axis is not split at all.
See the following code:

import tvm
from tvm import tir
from tvm.script import ty
# Reproducer without a split: rfactor the original reduction axis directly.
# NOTE(review): the names used inside the TVMScript function (vi, vj, vk, ...)
# appear in the printed IR, so keep its body unchanged when comparing against
# the OUTPUT below; use only '#' comments inside it.
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Three 128x128 buffers bound to the handles a, b, c.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # Spatial block vars vi, vj and one reduction var vk, each of extent 128.
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        # B is indexed [vj, vk], so this computes C = A @ B^T.
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
s = tir.create_schedule(matmul)
C = s.get_block("update")
i, j, k = s.get_axes(C)
# NOTE(review): the integer argument is not explained here; in the output the
# rfactored axis becomes dimension 2 of C_rf -- confirm its meaning.
s.rfactor(k, 2)
print(tvm.script.asscript(s.func))


# OUTPUT
# Printed IR for the second reproducer (rfactor k without a split).
@tvm.script.tir
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = tir.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root") as []:
        tir.reads([])
        tir.writes([])
        # rfactor buffer: C's [128, 128] shape plus a trailing dim indexed by vi2.
        C_rf = tir.buffer_allocate([128, 128, 128], elem_offset=0, align=128, offset_factor=1)
        # Stage 1 (rfactor block): the rfactored loop i2 is outermost.
        for i2, i0, i1 in tir.grid(128, 128, 128):
            # All three block vars are plain (no tir.reduce_axis), yet the block
            # still carries a tir.init.
            with tir.block([128, 128, 128], "update") as [vi, vj, vi2]:
                tir.bind(vi, i0)
                tir.bind(vj, i1)
                tir.bind(vi2, i2)
                # BUG: the A and B regions use column 0 instead of vi2.
                tir.reads([C_rf[vi:(vi + 1), vj:(vj + 1), vi2:(vi2 + 1)], A[vi:(vi + 1), 0:1], B[vj:(vj + 1), 0:1]])
                tir.writes([C_rf[vi:(vi + 1), vj:(vj + 1), vi2:(vi2 + 1)]])
                with tir.init():
                    C_rf[vi, vj, vi2] = tir.float32(0)
                C_rf[vi, vj, vi2] = (C_rf[vi, vj, vi2] + (A[vi, 0]*B[vj, 0]))  # <=== Here is the issue
        # Stage 2: reduce C_rf over its trailing dim (vi2_1) into C.
        for i0_1, i1_1, i2_1 in tir.grid(128, 128, 128):
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi_1, vj_1, vi2_1]:
                tir.bind(vi_1, i0_1)
                tir.bind(vj_1, i1_1)
                tir.bind(vi2_1, i2_1)
                tir.reads([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)], C_rf[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1), vi2_1:(vi2_1 + 1)]])
                tir.writes([C[vi_1:(vi_1 + 1), vj_1:(vj_1 + 1)]])
                with tir.init():
                    C[vi_1, vj_1] = tir.float32(0)
                C[vi_1, vj_1] = (C[vi_1, vj_1] + C_rf[vi_1, vj_1, vi2_1])

BTW, since I don't know what the integer argument of the rfactor primitive means (that is, what the 2 in s.rfactor(ki, 2) means), I can't tell whether s.rfactor(ki, 2) is a valid use of rfactor.

It's a known bug that rfactoring an inner axis causes an error; I will take care of it.

Fixed by #322.