
[Bug][MetaSchedule] Failed Post Processor During Bert Tuning For vm_mod_fused_nn_softmax #10

Closed
zxybazh opened this issue Jan 13, 2022 · 5 comments

Comments

@zxybazh
Collaborator

zxybazh commented Jan 13, 2022

While post-processing the design space generated for vm_mod_fused_nn_softmax during BERT-base tuning, I encountered the following error. It appears to be a failing case of the rewrite_cooperative_fetch post processor.

Will also add the TIR script below.
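For context, a postprocessor runs over each generated schedule candidate, and a candidate whose postprocessing fails is discarded. A minimal plain-Python sketch of that flow (a toy model, not the TVM API; `rewrite_cooperative_fetch_like` and the dict-based candidate are hypothetical stand-ins):

```python
def apply_postprocs(candidate, postprocs):
    """Toy model of ThreadedTraceApply: run each postprocessor on a
    candidate; if any fails, the candidate is discarded (None)."""
    for postproc in postprocs:
        if not postproc(candidate):
            return None
    return candidate

# Hypothetical postprocessor: succeeds only if a thread binding is present.
def rewrite_cooperative_fetch_like(candidate):
    return "threadIdx.x" in candidate.get("bindings", [])

ok = apply_postprocs({"bindings": ["threadIdx.x"]}, [rewrite_cooperative_fetch_like])
bad = apply_postprocs({"bindings": []}, [rewrite_cooperative_fetch_like])
print(ok is not None, bad is None)  # True True
```

In the real code path below, the postprocessor does not fail gracefully but raises, which aborts the whole tuning run.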

[16:25:09] /home/zxybazh/tvm-tensorir/src/meta_schedule/task_scheduler/task_scheduler.cc:127: Scheduler picks Task #10: vm_mod_fused_nn_softmax
Traceback (most recent call last):
  File "/home/zxybazh/tvm-tensorir/tests/python/unittest/test_meta_schedule_tune_relay.py", line 88, in <module>
    test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070")
  File "/home/zxybazh/tvm-tensorir/tests/python/unittest/test_meta_schedule_tune_relay.py", line 66, in test_meta_schedule_tune_relay
    schs: List[Schedule] = tune_relay(
  File "/home/zxybazh/tvm-tensorir/python/tvm/meta_schedule/tune.py", line 706, in tune_relay
    task_scheduler.tune()
  File "/home/zxybazh/tvm-tensorir/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 61, in tune
    _ffi_api.TaskSchedulerTune(self)  # type: ignore # pylint: disable=no-member
  File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  14: TVMFuncCall
        at /home/zxybazh/tvm-tensorir/src/runtime/c_runtime_api.cc:475
  13: tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1151
  12: std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /usr/include/c++/10/bits/std_function.h:622
  11: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/std_function.h:291
  10: std::enable_if<std::__and_<std::is_void<void>, std::__is_invocable<tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*> >::value, void>::type std::__invoke_r<void, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void 
(tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/invoke.h:153
  9: void std::__invoke_impl<void, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(std::__invoke_other, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/invoke.h:60
  8: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1480
  7: void tvm::runtime::detail::unpack_call<void, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1421
  6: void tvm::runtime::detail::unpack_call_dispatcher<void, 1, 0, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1382
  5: void tvm::runtime::detail::unpack_call_dispatcher<void, 0, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<tvm::runtime::TVMMovableArgValueWithContext_>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMMovableArgValueWithContext_&&)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1410
  4: tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}::operator()(tvm::meta_schedule::TaskScheduler) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/registry.h:239
  3: tvm::meta_schedule::TaskSchedulerNode::Tune()
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/task_scheduler/task_scheduler.cc:132
  2: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates()
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/replay_trace.cc:111
  1: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/replay_trace.cc:145
  0: tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
        at /home/zxybazh/tvm-tensorir/src/support/parallel_for.cc:128
  22: TVMFuncCall
        at /home/zxybazh/tvm-tensorir/src/runtime/c_runtime_api.cc:475
  21: tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1151
  20: std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /usr/include/c++/10/bits/std_function.h:622
  19: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/std_function.h:291
  18: std::enable_if<std::__and_<std::is_void<void>, std::__is_invocable<tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*> >::value, void>::type std::__invoke_r<void, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void 
(tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/invoke.h:153
  17: void std::__invoke_impl<void, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*>(std::__invoke_other, tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
        at /usr/include/c++/10/bits/invoke.h:60
  16: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1480
  15: void tvm::runtime::detail::unpack_call<void, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1421
  14: void tvm::runtime::detail::unpack_call_dispatcher<void, 1, 0, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1382
  13: void tvm::runtime::detail::unpack_call_dispatcher<void, 0, 1, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}>::run<tvm::runtime::TVMMovableArgValueWithContext_>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1} const&, tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMMovableArgValueWithContext_&&)
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/packed_func.h:1410
  12: tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, , void>(void (tvm::meta_schedule::TaskSchedulerNode::*)())::{lambda(tvm::meta_schedule::TaskScheduler)#1}::operator()(tvm::meta_schedule::TaskScheduler) const
        at /home/zxybazh/tvm-tensorir/include/tvm/runtime/registry.h:239
  11: tvm::meta_schedule::TaskSchedulerNode::Tune()
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/task_scheduler/task_scheduler.cc:132
  10: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates()
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/replay_trace.cc:111
  9: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/replay_trace.cc:145
  8: tvm::support::parallel_for_dynamic(int, int, int, std::function<void (int, int)> const&)
        at /home/zxybazh/tvm-tensorir/src/support/parallel_for.cc:123
  7: operator()
        at /home/zxybazh/tvm-tensorir/src/support/parallel_for.cc:113
  6: std::function<void (int, int)>::operator()(int, int) const
        at /usr/include/c++/10/bits/std_function.h:622
  5: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates()::{lambda(int, int)#1}::operator()(int, int) const
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/replay_trace.cc:139
  4: tvm::meta_schedule::ThreadedTraceApply::Apply(tvm::IRModule const&, tvm::tir::Trace const&, long*)
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/search_strategy/../utils.h:316
  3: tvm::meta_schedule::RewriteCooperativeFetchNode::Apply(tvm::tir::Schedule const&)
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc:122
  2: tvm::tir::ParseThreadBinding(tvm::tir::Schedule const&, tvm::tir::Instruction const&, tvm::runtime::String)
        at /home/zxybazh/tvm-tensorir/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc:42
  1: tvm::tir::ConcreteScheduleNode::Get(tvm::tir::LoopRV const&) const
        at /home/zxybazh/tvm-tensorir/src/tir/schedule/./concrete_schedule.h:204
  0: tvm::tir::ConcreteScheduleNode::GetSRef(tvm::tir::LoopRV const&) const
        at /home/zxybazh/tvm-tensorir/src/tir/schedule/././concrete_schedule.h:272
  File "/home/zxybazh/tvm-tensorir/src/support/parallel_for.cc", line 128
RuntimeError: parallel_for_dynamic error with [16:25:09] /home/zxybazh/tvm-tensorir/src/tir/schedule/./concrete_schedule.h:272: ValueError: The loop no longer exists in the IRModule
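The failure mode ("the loop no longer exists in the IRModule") looks like a stale-handle lookup: a LoopRV recorded in the trace refers to a loop that a later transformation removed. A minimal plain-Python model of that behavior (a toy sketch, not TVM code; `ToySchedule` and its methods are hypothetical):

```python
class ToySchedule:
    """Toy model of a schedule mapping loop handles (LoopRVs) to loops."""
    def __init__(self):
        self._loops = {}
        self._next = 0

    def add_loop(self, name):
        handle = self._next
        self._next += 1
        self._loops[handle] = name
        return handle

    def fuse(self, *handles):
        # Fusing removes the old loops; their handles become stale.
        name = "_".join(self._loops.pop(h) for h in handles)
        return self.add_loop(name)

    def get(self, handle):
        if handle not in self._loops:
            raise ValueError("The loop no longer exists in the IRModule")
        return self._loops[handle]

sch = ToySchedule()
i, j = sch.add_loop("i"), sch.add_loop("j")
fused = sch.fuse(i, j)
print(sch.get(fused))   # i_j
try:
    sch.get(i)          # stale handle recorded before the fuse
except ValueError as e:
    print(e)
```

The fix therefore belongs in the postprocessor (or the schedule that produced the trace), which must not dereference handles invalidated by earlier rewrites.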
@zxybazh
Collaborator Author

zxybazh commented Jan 13, 2022

The TIR script is as follows:

# from tvm.script import tir as T
@tvm.script.ir_module
class Module:
    @T.prim_func
    def vm_mod_fused_nn_softmax(placeholder: T.Buffer[(1, 12, 128, 128), "float32"], T_softmax_norm: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.unroll_explicit":512})
            T_softmax_maxelem_shared = T.alloc_buffer([1, 12, 128], dtype="float32", scope="shared")
            T_softmax_expsum_shared = T.alloc_buffer([1, 12, 128], dtype="float32", scope="shared")
            for i0, i1, i2, ax0, ax1, ax2, ax3_0 in T.grid(1, 12, 128, 1, 1, 1, 2):
                for ax3_1 in T.thread_binding(64, thread="threadIdx.x"):
                    with T.block("T_softmax_maxelem"):
                        i0_1 = T.axis.spatial(1, 0)
                        i1_1, i2_1 = T.axis.remap("SS", [i1, i2])
                        k = T.axis.reduce(128, ax3_0 * 64 + ax3_1)
                        T.reads(T_softmax_maxelem_shared[i0_1, i1_1, i2_1], placeholder[i0_1, i1_1, i2_1, k])
                        T.writes(T_softmax_maxelem_shared[i0_1, i1_1, i2_1])
                        with T.init():
                            T_softmax_maxelem_shared[i0_1, i1_1, i2_1] = T.float32(-3.4028234663852886e+38)
                        T_softmax_maxelem_shared[i0_1, i1_1, i2_1] = T.max(T_softmax_maxelem_shared[i0_1, i1_1, i2_1], placeholder[i0_1, i1_1, i2_1, k])
            for i0, i1, i2 in T.grid(1, 12, 128):
                for ax0, ax1, ax2, ax3_0 in T.grid(1, 1, 1, 8):
                    for ax3_1 in T.thread_binding(16, thread="threadIdx.x"):
                        with T.block("T_softmax_expsum"):
                            i0_2 = T.axis.spatial(1, 0)
                            i1_2, i2_2 = T.axis.remap("SS", [i1, i2])
                            k = T.axis.reduce(128, ax3_0 * 16 + ax3_1)
                            T.reads(T_softmax_expsum_shared[i0_2, i1_2, i2_2], placeholder[i0_2, i1_2, i2_2, k], T_softmax_maxelem_shared[i0_2, i1_2, i2_2])
                            T.writes(T_softmax_expsum_shared[i0_2, i1_2, i2_2])
                            with T.init():
                                T_softmax_expsum_shared[i0_2, i1_2, i2_2] = T.float32(0)
                            T_softmax_expsum_shared[i0_2, i1_2, i2_2] = T_softmax_expsum_shared[i0_2, i1_2, i2_2] + T.exp(placeholder[i0_2, i1_2, i2_2, k] - T_softmax_maxelem_shared[i0_2, i1_2, i2_2], dtype="float32")
                for i3_0 in T.serial(8):
                    for i3_1 in T.thread_binding(16, thread="threadIdx.x"):
                        with T.block("T_softmax_norm"):
                            i0_3 = T.axis.spatial(1, 0)
                            i1_3, i2_3 = T.axis.remap("SS", [i1, i2])
                            i3 = T.axis.spatial(128, i3_0 * 16 + i3_1)
                            T.reads(placeholder[i0_3, i1_3, i2_3, i3], T_softmax_maxelem_shared[i0_3, i1_3, i2_3], T_softmax_expsum_shared[i0_3, i1_3, i2_3])
                            T.writes(T_softmax_norm[i0_3, i1_3, i2_3, i3])
                            T.block_attr({"axis":3})
                            T_softmax_norm[i0_3, i1_3, i2_3, i3] = T.exp(placeholder[i0_3, i1_3, i2_3, i3] - T_softmax_maxelem_shared[i0_3, i1_3, i2_3], dtype="float32") / T_softmax_expsum_shared[i0_3, i1_3, i2_3]
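The three blocks above implement the standard numerically stable softmax over the last axis: T_softmax_maxelem takes the row max, T_softmax_expsum accumulates exp(x - max), and T_softmax_norm divides through. A plain-Python reference of the same computation (a sketch for clarity, not the tuned kernel):

```python
import math

def softmax_last_axis(row):
    """Numerically stable softmax over one row, mirroring the three blocks."""
    m = max(row)                               # T_softmax_maxelem
    s = sum(math.exp(x - m) for x in row)      # T_softmax_expsum
    return [math.exp(x - m) / s for x in row]  # T_softmax_norm

out = softmax_last_axis([1.0, 2.0, 3.0])
print(sum(out))  # ~1.0
```

In the schedule above, both reduction blocks are split and bound to threadIdx.x (with different split factors, 64 and 16), which is the thread-binding mismatch the cooperative-fetch rewrite trips over.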

@junrushao
Owner

Looks like it's very relevant to @MasterJH5574's ongoing effort to make softmax work on CUDA.

@MasterJH5574
Collaborator

Yes, I think so. I will do my best to fix it as quickly as possible!

@junrushao
Owner

This should be fixed now. @zxybazh would you like to confirm?

@zxybazh
Collaborator Author

zxybazh commented Jan 21, 2022

Confirmed that this is fixed now, thanks @MasterJH5574!

@zxybazh zxybazh closed this as completed Jan 21, 2022
junrushao pushed a commit that referenced this issue Oct 18, 2022
* ExprVisitor/ExprMutator for relax nodes.

* Update Visitor & Mutator.

* Update Mutator.

* DataflowMutator interface.

* EwiseFMARewriter.

* Update fma rewrite and add test.

* Update test.

* Fix dataflow block dispatching.

* Construct new dataflow block with IRBuilder.

* VisitBinding return void and mutate internal IRBuilder.

* Simplify.

* Update emit dataflow output.

* Explicit memory allocation rewrite.

* LazyIRBuilder.

* Update ExplicitMemMutator.

* Overload IRBuilder::Emit to have 3 styles.

* Update IRBuilder/IRMutator interfaces and passes.

* Add MatchShape binding to IRBuilder.

* Improve IRMutator interface; add Normalize and CanProveShapeEqual to IRBuilder

* Update EmitMatchShape.

Co-authored-by: ZihengJiang <ziheng@apache.org>
junrushao pushed a commit that referenced this issue Feb 8, 2023