[Bug] [MetaSchedule] Failed Verification When Applying Schedule Rules #9

@zxybazh

Description

During design space generation, specifically in PostOrderApply, the following error is encountered when applying certain schedule rules. The error trace and the offending IR are included below. To reproduce, simply run test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070") in tests/python/unittest/test_meta_schedule_tune_relay.py.
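
For convenience, the failing call can also be made directly from a Python session. A minimal repro sketch, assuming it is run from the root of a TVM checkout (the sys.path tweak is an assumption about the local setup, not part of the test):

import sys

# Assumption: running from the repository root, so the unittest
# directory must be made importable first.
sys.path.append("tests/python/unittest")

from test_meta_schedule_tune_relay import test_meta_schedule_tune_relay

# Arguments taken verbatim from the trace below.
test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")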

Traceback (most recent call last):
  File "/home/zxybazh/tvm-tensorir/tests/python/unittest/test_meta_schedule_tune_relay.py", line 84, in <module>
    test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070")
  File "/home/zxybazh/tvm-tensorir/tests/python/unittest/test_meta_schedule_tune_relay.py", line 63, in test_meta_schedule_tune_relay
    schs: List[Schedule] = tune_relay(
  File "/home/zxybazh/tvm-tensorir/python/tvm/meta_schedule/tune.py", line 706, in tune_relay
    task_scheduler.tune()
  File "/home/zxybazh/tvm-tensorir/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 61, in tune
    _ffi_api.TaskSchedulerTune(self)  # type: ignore # pylint: disable=no-member
  File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  20: TVMFuncCall
        at /home/zxybazh/tvm-tensorir/src/runtime/c_runtime_api.cc:475
  19: tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1151
  18: std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /usr/include/c++/10/bits/std_function.h:622
  17: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/std_function.h:291
  16: std::enable_if<std::__and_<std::is_void<void>, std::__is_invocable<tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*> >::value, void>::type std::__invoke_r<void, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/invoke.h:153
  15: void std::__invoke_impl<void, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(std::__invoke_other, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/invoke.h:60
  14: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1480
  13: void tvm::runtime::detail::unpack_call<void, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1421
  12: void tvm::runtime::detail::unpack_call_dispatcher<void, 1, 0, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1382
  11: void tvm::runtime::detail::unpack_call_dispatcher<void, 0, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<tvm::runtime::TVMMovableArgValueWithContext_>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMMovableArgValueWithContext_&&)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1410
  10: tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}::operator()(tvm::meta_schedule::TaskScheduler) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/registry.h:239
  9: tvm::meta_schedule::TaskSchedulerNode::Tune()
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/task_scheduler/task_scheduler.cc:112
  8: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/space_generator/post_order_apply.cc:130
  7: tvm::meta_schedule::MultiLevelTilingNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/schedule_rule/multi_level_tiling.cc:302
  6: std::vector<tvm::meta_schedule::State, std::allocator<tvm::meta_schedule::State> > tvm::meta_schedule::SubRule<tvm::meta_schedule::MultiLevelTilingNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)::{lambda(tvm::meta_schedule::State)#4}>(std::vector<tvm::meta_schedule::State, std::allocator<tvm::meta_schedule::State> >, tvm::meta_schedule::MultiLevelTilingNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)::{lambda(tvm::meta_schedule::State)#4})
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/schedule_rule/multi_level_tiling.cc:214
  5: tvm::meta_schedule::MultiLevelTilingNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)::{lambda(tvm::meta_schedule::State)#4}::operator()(tvm::meta_schedule::State) const
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/schedule_rule/multi_level_tiling.cc:302
  4: tvm::meta_schedule::MultiLevelTilingNode::AddReadReuse(tvm::meta_schedule::State) const
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/schedule_rule/multi_level_tiling.cc:492
  3: tvm::tir::TracedScheduleNode::Fuse(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
        at /home/zxybazh/tvm-tensorir/src/tir/schedule/traced_schedule.cc:161
  2: tvm::tir::ConcreteScheduleNode::Fuse(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
        at /home/zxybazh/tvm-tensorir/src/tir/schedule/concrete_schedule.cc:343
  1: tvm::tir::ScheduleStateNode::DebugVerify() const
        at /home/zxybazh/tvm-tensorir/src/tir/schedule/state.cc:1077
  0: tvm::tir::VerifyCachedFlags(tvm::tir::ScheduleState const&)
        at /home/zxybazh/tvm-tensorir/src/tir/schedule/analysis/verify.cc:236
  File "/home/zxybazh/tvm-tensorir/src/tir/schedule/analysis/verify.cc", line 236
TVMError: Schedule verification failed. The IR is:
# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def vm_mod_fused_nn_conv2d_add_1(placeholder: T.Buffer[(1, 256, 1, 1), "float32"], placeholder_1: T.Buffer[(256, 128, 1, 1), "float32"], placeholder_2: T.Buffer[(1, 128, 38, 38), "float32"], T_add: T.Buffer[(1, 256, 19, 19), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        pad_temp = T.alloc_buffer([1, 128, 38, 38], dtype="float32")
        compute = T.alloc_buffer([1, 256, 19, 19], dtype="float32")
        compute_local = T.alloc_buffer([1, 256, 19, 19], dtype="float32", scope="local")
        pad_temp_shared = T.alloc_buffer([1, 128, 38, 38], dtype="float32", scope="shared")
        for i0, i1, i2, i3 in T.grid(1, 128, 38, 38):
            with T.block("pad_temp"):
                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(placeholder_2[i0_1, i1_1, i2_1, i3_1])
                T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
                pad_temp[i0_1, i1_1, i2_1, i3_1] = placeholder_2[i0_1, i1_1, i2_1, i3_1]
        for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(76, thread="blockIdx.x"):
            for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"):
                for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(16, thread="threadIdx.x"):
                    for i4_0, i5_0, i6_0 in T.grid(1, 1, 1):
                        for ax0_ax1_ax2_ax3_fused in T.serial(4736):
                            with T.block("pad_temp_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused // 37)
                                v2 = T.axis.spatial(38, ax0_ax1_ax2_ax3_fused % 37)
                                v3 = T.axis.spatial(38, i0_0_i1_0_i2_0_i3_0_fused % 19 * 2 * 2)
                                T.reads(pad_temp[v0, v1, v2, v3])
                                T.writes(pad_temp_shared[v0, v1, v2, v3])
                                T.block_attr({"meta_schedule.cache_type":0})
                                pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
                        for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 19, 1, 4, 1, 1, 1, 2, 1, 1):
                            with T.block("compute"):
                                nn = T.axis.spatial(1, 0)
                                ff = T.axis.spatial(256, i0_0_i1_0_i2_0_i3_0_fused // 19 * 64 + i0_1_i1_1_i2_1_i3_1_fused * 32 + i0_2_i1_2_i2_2_i3_2_fused * 2 + i1_4)
                                yy = T.axis.spatial(19, i2_3)
                                xx = T.axis.spatial(19, i0_0_i1_0_i2_0_i3_0_fused % 19)
                                rc = T.axis.reduce(128, i4_1 * 4 + i4_2)
                                ry = T.axis.reduce(1, 0)
                                rx = T.axis.reduce(1, 0)
                                T.reads(compute_local[nn, ff, yy, xx], pad_temp_shared[nn, rc, yy * 2 + ry, xx * 2 + rx], placeholder_1[ff, rc, ry, rx])
                                T.writes(compute_local[nn, ff, yy, xx])
                                T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS", "workload":["conv2d_nchw.cuda", ["TENSOR", [1, 128, 38, 38], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], [2, 2], [0, 0, 0, 0], [1, 1], "float32"]})
                                with T.init():
                                    compute_local[nn, ff, yy, xx] = T.float32(0)
                                compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy * 2 + ry, xx * 2 + rx] * placeholder_1[ff, rc, ry, rx]
        for ax0, ax1, ax2, ax3 in T.grid(1, 256, 19, 19):
            with T.block("compute_local"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(compute_local[v0, v1, v2, v3])
                T.writes(compute[v0, v1, v2, v3])
                T.block_attr({"meta_schedule.cache_type":1})
                compute[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3]
        for i0, i1, i2, i3 in T.grid(1, 256, 19, 19):
            with T.block("T_add"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(compute[ax0, ax1, ax2, ax3], placeholder[ax0, ax1, 0, 0])
                T.writes(T_add[ax0, ax1, ax2, ax3])
                T_add[ax0, ax1, ax2, ax3] = compute[ax0, ax1, ax2, ax3] + placeholder[ax0, ax1, 0, 0]
    

The errors are:
- Wrong region_cover:  (compute, expected=0, actual=1)
- Wrong stage_pipeline:  (root, expected=0, actual=1)
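
For context, region_cover and stage_pipeline are per-block flags cached in ScheduleState. When a schedule is created with debug checks enabled (evidently the case here, given the DebugVerify frame in the trace), every primitive such as Fuse re-runs VerifyCachedFlags, which re-derives the flags from scratch and aborts on any mismatch with the cache. Below is a minimal sketch of that verification path on a toy workload; the module is a stand-in for illustration, not the failing IR above:

import tvm
from tvm import tir
from tvm.script import tir as T

@tvm.script.ir_module
class Toy:
    @T.prim_func
    def main(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None:
        for i, j in T.grid(16, 16):
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

# debug_mask="all" makes every schedule primitive call DebugVerify,
# which in turn runs VerifyCachedFlags (the check that fails above).
sch = tir.Schedule(Toy, debug_mask="all")
i, j = sch.get_loops(sch.get_block("copy"))
sch.fuse(i, j)  # verification re-runs here after the primitive

On this toy module the fuse succeeds; the errors reported above indicate that, after AddReadReuse in MultiLevelTiling, the cached region_cover flag of the compute block and the root's stage_pipeline flag disagree with the freshly derived values, so the same check trips inside the Fuse issued by the rule.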
