[Bug] Tensorization Failure During Multilevel Tiling with Tensor Intrin
zxybazh opened this issue · comments
Xiyou Zhou commented
Expected behavior
MetaSchedule Tuning Works for the given Conv2d workload
Actual behavior
Triggers the error `ValueError: The block no longer exists in the IRModule`
during application of the schedule rule "Multi-level tiling with tensor intrin".
I noticed that `state->tensor_core_reindex_store` points to a block that has
already been merged into another block via `ComputeInline` during the
application of `TileWithTensorIntrin`.
Environment
Latest TVM Main
Steps to reproduce
import tvm
from tvm.script import tir as T
from tvm import meta_schedule as ms
# TVMScript kernel fusing a 3x3 NCHW conv2d (stride 1, zero padding 1) with a
# per-channel bias add, all in float16.  This is the workload that triggers
# the "block no longer exists in the IRModule" failure during MetaSchedule's
# multi-level tiling with tensor intrin.
@T.prim_func(private=True)
def fused_conv2d_add1(reshape3: T.Buffer((T.int64(50), T.int64(8), T.int64(72), T.int64(128)), "float16"), conv_in_weight: T.Buffer((T.int64(320), T.int64(8), T.int64(3), T.int64(3)), "float16"), lv23: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float16"), T_add_intermediate: T.Buffer((T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Scratch buffers: the zero-padded input (72x128 -> 74x130 spatially) and
    # the raw convolution output before the bias add.
    pad_temp = T.alloc_buffer((T.int64(50), T.int64(8), T.int64(74), T.int64(130)), "float16")
    conv2d_nchw_intermediate = T.alloc_buffer((T.int64(50), T.int64(8), T.int64(74), T.int64(130))[0:0] or (T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16")
    # Stage 1: pad the input by one pixel of zeros on each spatial border.
    for i0, i1, i2, i3 in T.grid(T.int64(50), T.int64(8), T.int64(74), T.int64(130)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            # Interior pixels copy the (shifted) input; the 1-pixel border is zero.
            pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(73) and T.int64(1) <= v_i3 and v_i3 < T.int64(129), reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float16(0))
    # Stage 2: direct NCHW convolution over the padded input
    # (rc/ry/rx are the reduction axes: input channel and 3x3 kernel window).
    for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128), T.int64(8), T.int64(3), T.int64(3)):
        with T.block("conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_in_weight[v_ff, v_rc, v_ry, v_rx])
            T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
            conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_in_weight[v_ff, v_rc, v_ry, v_rx]
    # Stage 3: add the bias lv23 (indexed only by the channel axis, so it is
    # broadcast over batch and the spatial dimensions) into the final output.
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128)):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
            T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
            T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
# Run MetaSchedule tuning of the fused conv2d+add workload for an NVIDIA A10G.
target = tvm.target.Target("nvidia/nvidia-a10g")
func = fused_conv2d_add1
ms.tune_tir(
    func,
    target=target,
    max_trials_global=100,
    work_dir="./temp",
)