apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators

Home Page: https://tvm.apache.org/


[Bug] Tensorization Failure During Multilevel Tiling with Tensor Intrin

zxybazh opened this issue

Expected behavior

MetaSchedule tuning works for the given Conv2d workload.

Actual behavior

Triggers the error ValueError: The block no longer exists in the IRModule during application of the schedule rule Multi-level tiling with tensor intrin. I noticed that state->tensor_core_reindex_store still points to a block that has already been merged into another block via ComputeInline during the application of TileWithTensorIntrin, so dereferencing that stale block reference raises the error.
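
For context, the mechanism behind this error message can be shown in isolation: once a block is removed by compute_inline, any previously captured block handle becomes stale, and dereferencing it raises exactly this ValueError. Below is a minimal sketch of my own (the two_blocks function, its block names, and buffer shapes are made up for illustration; it is not the failing rule itself), assuming the current tvm.tir.Schedule API:

import tvm
from tvm.script import tir as T

@T.prim_func
def two_blocks(A: T.Buffer((16,), "float32"), C: T.Buffer((16,), "float32")):
    # Producer block "B" feeds consumer block "C"; inlining "B" removes it from the module.
    B = T.alloc_buffer((16,), "float32")
    for i in T.serial(16):
        with T.block("B"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi] * T.float32(2)
    for i in T.serial(16):
        with T.block("C"):
            vi = T.axis.spatial(16, i)
            C[vi] = B[vi] + T.float32(1)

sch = tvm.tir.Schedule(two_blocks)
block_b = sch.get_block("B")   # capture a handle to block "B"
sch.compute_inline(block_b)    # "B" is merged into "C"; the handle goes stale
sch.get(block_b)               # expected to raise: ValueError: The block no longer exists in the IRModule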

Environment

Latest TVM main branch

Steps to reproduce

import tvm
from tvm.script import tir as T
from tvm import meta_schedule as ms

@T.prim_func(private=True)
def fused_conv2d_add1(reshape3: T.Buffer((T.int64(50), T.int64(8), T.int64(72), T.int64(128)), "float16"), conv_in_weight: T.Buffer((T.int64(320), T.int64(8), T.int64(3), T.int64(3)), "float16"), lv23: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float16"), T_add_intermediate: T.Buffer((T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    pad_temp = T.alloc_buffer((T.int64(50), T.int64(8), T.int64(74), T.int64(130)), "float16")
    conv2d_nchw_intermediate = T.alloc_buffer((T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16")
    for i0, i1, i2, i3 in T.grid(T.int64(50), T.int64(8), T.int64(74), T.int64(130)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(73) and T.int64(1) <= v_i3 and v_i3 < T.int64(129), reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float16(0))
    for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128), T.int64(8), T.int64(3), T.int64(3)):
        with T.block("conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_in_weight[v_ff, v_rc, v_ry, v_rx])
            T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
            conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_in_weight[v_ff, v_rc, v_ry, v_rx]
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128)):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
            T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
            T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)]

target = tvm.target.Target("nvidia/nvidia-a10g")
func = fused_conv2d_add1
ms.tune_tir(func, target=target, max_trials_global=100, work_dir="./temp")
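
If it helps triage, the failure should also be reproducible without running any measurements, because the schedule rule is applied while generating the design space. A minimal sketch under the assumption that the current TuneContext API accepts the "post-order-apply" space-generator literal (argument names may need adjusting):

# Generate the design space only; this applies the schedule rules
# (including multi-level tiling with tensor intrin) without tuning.
mod = tvm.IRModule({"main": func})
ctx = ms.TuneContext(
    mod=mod,
    target=target,
    space_generator="post-order-apply",  # assumed literal for the default design-space generator
    task_name="fused_conv2d_add1",
)
spaces = ctx.generate_design_space()
print(len(spaces), "design space(s) generated")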

Hi @zxybazh, I am facing the same issue when I try to run MetaSchedule tuning on a ResNet-50 Relay workload.
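
For reference, a minimal sketch of the Relay path described above (the ResNet-50 workload from tvm.relay.testing, the float16 dtype, and the tune_relay arguments are my assumptions about the setup, not the reporter's exact script):

import tvm
from tvm import meta_schedule as ms
from tvm.relay import testing

# Build a ResNet-50 Relay workload; float16 is assumed so the tensor-core rules apply.
mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float16")
target = tvm.target.Target("nvidia/nvidia-a10g")

# End-to-end MetaSchedule tuning of the Relay module; the conv2d tasks should reach
# the same multi-level tiling with tensor intrin rule.
ms.relay_integration.tune_relay(
    mod=mod,
    params=params,
    target=target,
    work_dir="./temp_resnet",
    max_trials_global=100,
)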