From 83bba000f3ab069f775b3d603e925a889a6b5ba5 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 6 Nov 2021 09:07:07 -0700 Subject: [PATCH] [TensorIR] GetProducer, GetConsumer (#506) (#9464) Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Wuwei Lin --- include/tvm/tir/schedule/schedule.h | 12 +++ python/tvm/tir/schedule/schedule.py | 30 +++++++ src/tir/schedule/concrete_schedule.cc | 14 ++++ src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 14 ++++ src/tir/schedule/primitive/get_block_loop.cc | 78 +++++++++++++++++++ src/tir/schedule/schedule.cc | 4 + src/tir/schedule/traced_schedule.cc | 22 ++++++ src/tir/schedule/traced_schedule.h | 2 + .../unittest/test_tir_schedule_utilities.py | 42 +++++++++- 10 files changed, 219 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 7bfe605623a2..ffd860d84cf3 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -232,6 +232,18 @@ class ScheduleNode : public runtime::Object { * \return A list of child blocks */ virtual Array GetChildBlocks(const LoopRV& loop_rv) = 0; + /*! + * \brief Get the producers of a specific block + * \param block_rv The block in the query + * \return A list of blocks, the producers of the given block + */ + virtual Array GetProducers(const BlockRV& block_rv) = 0; + /*! 
+ * \brief Get the consumers of a specific block + * \param block_rv The block to be queried + * \return A list of blocks, the consumers of the given block + */ + virtual Array GetConsumers(const BlockRV& block_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Fuse a list of consecutive loops into one. It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 0790e4fd37b3..884eeb7c612c 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -415,6 +415,36 @@ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockR """ return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # type: ignore # pylint: disable=no-member + def get_producers(self, block: BlockRV) -> List[BlockRV]: + """Get the producers of a specific block + + Parameters + ---------- + block : BlockRV + The block in the query + + Returns + ------- + producers : List[BlockRV] + A list of producers of the given block + """ + return _ffi_api.ScheduleGetProducers(self, block) # type: ignore # pylint: disable=no-member + + def get_consumers(self, block: BlockRV) -> List[BlockRV]: + """Get the consumers of a specific block + + Parameters + ---------- + block : BlockRV + The block in the query + + Returns + ------- + consumers : List[BlockRV] + A list of consumers of the given block + """ + return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore # pylint: disable=no-member + ########## Schedule: Transform loops ########## def fuse(self, *loops: List[LoopRV]) -> LoopRV: """Fuse a list of consecutive loops into one. 
It requires: diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 1c741fb22a76..4db4cd4ba1c8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -310,6 +310,20 @@ Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { return result; } +Array ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::GetProducers(state_, this->GetSRef(block_rv))); + TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_); + throw; +} + +Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::GetConsumers(state_, this->GetSRef(block_rv))); + TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_); + throw; +} + /******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 199faf8afc23..035c16f506cf 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -90,6 +90,8 @@ class ConcreteScheduleNode : public ScheduleNode { Array GetLoops(const BlockRV& block_rv) override; Array GetChildBlocks(const BlockRV& block_rv) override; Array GetChildBlocks(const LoopRV& loop_rv) override; + Array GetProducers(const BlockRV& block_rv) override; + Array GetConsumers(const BlockRV& block_rv) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 4e9d00f75e8a..cc7e44d4df9e 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -105,6 +105,20 @@ Array GetLoops(const StmtSRef& block_sref); * \return A list of leaf blocks inside a specific block/loop */ Array GetChildBlocks(const 
ScheduleState& self, const StmtSRef& parent_sref); +/*! + * \brief Get the producers of a specific block + * \param self The schedule state + * \param block_sref The block in the query + * \return A list of blocks, the producers of the given block + */ +Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the consumers of a specific block + * \param self The schedule state + * \param block_sref The block in the query + * \return A list of blocks, the consumers of the given block + */ +Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops. It requires: diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 2c4e23dadbbf..c044de3bc644 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -77,6 +77,34 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent return std::move(collector.result); } +Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_stage_pipeline=*/false); + Array edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref); + Array results; + results.reserve(edges.size()); + for (const Dependency& edge : edges) { + if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) { + results.push_back(edge->src); + } + } + return results; +} + +Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_stage_pipeline=*/false); + Array edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref); + Array results; + results.reserve(edges.size()); + for (const Dependency& edge : edges) { + if (edge->kind == DepKind::kRAW || 
edge->kind == DepKind::kWAW) { + results.push_back(edge->dst); + } + } + return results; +} + /******** InstructionKind Registration ********/ struct GetBlockTraits : public UnpackedInstTraits { @@ -159,9 +187,59 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct GetProducersTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetProducers"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetProducers(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("get_producers"); + py.Input("block", block_rv); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct GetConsumersTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetConsumers"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetConsumers(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("get_consumers"); + py.Input("block", block_rv); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc 
b/src/tir/schedule/schedule.cc index a1b582dbd787..a411e40b13b6 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -142,6 +142,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") << ". Its value is: " << rv; throw; }); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") + .set_body_method(&ScheduleNode::GetProducers); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") + .set_body_method(&ScheduleNode::GetConsumers); /******** (FFI) Transform loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index e05d187dccfd..4a028d1dad5c 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -119,6 +119,28 @@ Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { return results; } +Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetProducers(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetProducers"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetConsumers(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetConsumers"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + /******** Schedule: Transform loops ********/ LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ae726ad594e0..ac36b9ca06a9 100644 
--- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -56,6 +56,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { Array GetLoops(const BlockRV& block_rv) final; Array GetChildBlocks(const BlockRV& block_rv) final; Array GetChildBlocks(const LoopRV& loop_rv) final; + Array GetProducers(const BlockRV& block_rv) final; + Array GetConsumers(const BlockRV& block_rv) final; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 1596d08a1fb4..d75bc1461c5e 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -36,13 +36,31 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j in T.grid(128, 128): with T.block("init"): vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = T.float32(0) + C[vi, vj] = 0.0 for k in range(0, 128): with T.block("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] +@T.prim_func +def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (1024, 1024)) + B = T.match_buffer(b, (1024, 1024)) + C = T.alloc_buffer((1024, 1024)) + D = T.match_buffer(d, (1024, 1024)) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) + + # pylint: enable=no-member,invalid-name,unused-variable @@ -159,5 +177,27 @@ def test_get_child_blocks(): assert s.get(update) == s.get(blocks[1]) +def test_get_producers(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + block 
= sch.get_block("relu") + (producer,) = sch.get_producers(block) + assert tvm.ir.structural_equal( + sch.get_sref(producer).stmt, + sch.get_sref(sch.get_block("matmul")).stmt, + ) + verify_trace_roundtrip(sch, mod=matmul_relu) + + +def test_get_consumers(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + block = sch.get_block("matmul") + (consumer,) = sch.get_consumers(block) + assert tvm.ir.structural_equal( + sch.get_sref(consumer).stmt, + sch.get_sref(sch.get_block("relu")).stmt, + ) + verify_trace_roundtrip(sch, mod=matmul_relu) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))