From 08c911b98fa83aba286ef7ae0a89e3c6d6104876 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 9 Sep 2021 10:40:45 -0700 Subject: [PATCH] [TensorIR][M2a] Compute-At (#8943) This PR is part of the TensorIR upstreaming effort (#7527), which adds the following schedule primitives: * `compute-at` * `reverse-compute-at` Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- include/tvm/arith/int_set.h | 11 +- include/tvm/tir/schedule/schedule.h | 35 + include/tvm/tir/schedule/state.h | 5 - python/tvm/tir/schedule/schedule.py | 186 +++- src/arith/int_set.cc | 30 +- src/relay/transforms/fold_scale_axis.cc | 12 +- src/support/nd_int_set.h | 150 ++++ src/tir/schedule/analysis.h | 104 ++- src/tir/schedule/analysis/analysis.cc | 331 +++++-- src/tir/schedule/concrete_schedule.cc | 38 + src/tir/schedule/concrete_schedule.h | 3 + src/tir/schedule/primitive.h | 46 +- src/tir/schedule/primitive/block_annotate.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 27 +- src/tir/schedule/primitive/compute_at.cc | 589 +++++++++++++ src/tir/schedule/primitive/compute_inline.cc | 115 +-- src/tir/schedule/primitive/for_kind.cc | 10 +- src/tir/schedule/primitive/get_block_loop.cc | 2 +- .../schedule/primitive/loop_transformation.cc | 2 +- src/tir/schedule/primitive/reduction.cc | 8 +- src/tir/schedule/primitive/sampling.cc | 6 +- src/tir/schedule/schedule.cc | 4 + src/tir/schedule/state.cc | 98 +-- src/tir/schedule/traced_schedule.cc | 22 + src/tir/schedule/traced_schedule.h | 3 + src/tir/schedule/transform.cc | 69 +- src/tir/schedule/transform.h | 40 + src/tir/schedule/utils.h | 17 + src/tir/transforms/compact_buffer_region.cc | 72 +- .../unittest/test_tir_schedule_compute_at.py | 832 ++++++++++++++++++ 30 files changed, 2526 insertions(+), 343 deletions(-) create mode 100644 src/support/nd_int_set.h create mode 100644 src/tir/schedule/primitive/compute_at.cc create mode 100644 tests/python/unittest/test_tir_schedule_compute_at.py diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index b9e81c0a5533..6b350e25e167 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -121,17 +121,24 @@ class IntSet : public ObjectRef { * \return The result set containing the indices in the vector. */ static IntSet Vector(PrimExpr vec); + /*! + * \brief Construct a set representing a range [min, min + extent). + * \param min The minimum of the range range + * \param extent The extent of the range. + * \return The constructed set. + */ + static IntSet FromMinExtent(PrimExpr min, PrimExpr extent); /*! * \brief Construct a set representing a range. * \param r The range - * \return constructed set. + * \return The constructed set. */ static IntSet FromRange(tvm::Range r); /*! * \brief Construct a set representing a interval. * \param min The minimum value of the interval. * \param max The maximum value of the interval. - * \return constructed set. + * \return The constructed set. 
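To make the semantics of the new `IntSet::FromMinExtent` helper concrete: it describes the half-open range [min, min + extent), i.e. the same set as `Interval(min, min + extent - 1)`, degenerating to a single point when the extent is 1. A minimal pure-Python sketch of that relation (illustrative only, not TVM API):

.. code-block:: python

    def from_min_extent(min_v: int, extent: int) -> set:
        # Models IntSet::FromMinExtent(min, extent): the integers in the half-open
        # range [min, min + extent); extent == 1 degenerates to the single point {min}.
        return set(range(min_v, min_v + extent))

    def interval(lo: int, hi: int) -> set:
        # Models IntSet::Interval(min, max): the integers in the closed range [min, max].
        return set(range(lo, hi + 1))

    assert from_min_extent(4, 1) == {4}
    assert from_min_extent(4, 8) == interval(4, 11)  # FromMinExtent(m, n) == Interval(m, m + n - 1)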
*/ static IntSet Interval(PrimExpr min, PrimExpr max); diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 33776cbe1985..66dd5375eaf9 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -305,6 +305,41 @@ class ScheduleNode : public runtime::Object { virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; /******** Schedule: Compute location ********/ + /*! + * \brief Move a producer block under the specific loop, and regenerate the + * loops induced by the block so that the buffer region produced by the producer block could + * cover those regions consumed by its consumer blocks under the given loop. It requires: + * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` + * 2) The scope block has stage-pipeline property + * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow + * condition. i.e. all the blocks in the scope block's subtree must be either complete block or + * reduction block + * 4) The block is not an output block with regard to the scope block, i.e. the buffers written by + * the block are allocated under the scope block + * 5) All the consumers of the block are under the given loop + * \param block_rv The block to be moved + * \param loop_rv The loop where the block to be moved under + * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + */ + virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) = 0; + /*! + * \brief Move a consumer block under the specific loop, and regenerate the + * loops induced by the block so that the buffer region consumed by the consumer block could + * cover those regions produced by its producer blocks under the given loop. It requires: + * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` + * 2) The scope block has stage-pipeline property + * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow + * condition. i.e. all the blocks in the scope block's subtree must be either complete block or + * reduction block + * 4) All the producers of the block are under the given loop + * + * \param block_rv The block to be moved + * \param loop_rv The loop where the block to be moved under + * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + */ + virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) = 0; /*! * \brief Inline a block into its consumer(s). It requires: * 1) The block is a complete non-root block, which only produces one buffer diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 35299a3fa84b..7cd1b00c15ef 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -128,11 +128,6 @@ class ScheduleStateNode : public Object { */ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, const Map& block_sref_reuse); - /*! - * \brief Recalculate the `affine_binding` flag of the scope block info. - * \param scope_sref The sref to the interested scope block. - */ - TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref); /*! * \brief Trigger the verification according to the `debug_mask` bitmask. * 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. 
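The two new `ScheduleNode` entry points above are exposed to Python as `Schedule.compute_at` and `Schedule.reverse_compute_at`, documented with full examples in the schedule.py hunk that follows. As a small supplement, here is a hedged sketch of the `preserve_unit_loops=True` variant on the same two-block workload used in those docstrings; the exact printed script may differ in detail from the comment below:

.. code-block:: python

    import tvm
    from tvm import tir
    from tvm.script import ty

    @tvm.script.tir
    def workload(a: ty.handle, c: ty.handle) -> None:
        A = tir.match_buffer(a, (128, 128), "float32")
        B = tir.alloc_buffer((128, 128), "float32")
        C = tir.match_buffer(c, (128, 128), "float32")
        with tir.block([128, 128], "B") as [vi, vj]:
            B[vi, vj] = A[vi, vj] * 2.0
        with tir.block([128, 128], "C") as [vi, vj]:
            C[vi, vj] = B[vi, vj] + 1.0

    sch = tir.Schedule(workload)
    block_b = sch.get_block("B")
    loop_i, _ = sch.get_loops(sch.get_block("C"))
    sch.compute_at(block_b, loop_i, preserve_unit_loops=True)
    # With preserve_unit_loops=True, the regenerated nest around block "B" keeps the
    # trivial axis whose extent is 1 (the one bound to vi), so an extra unit-extent
    # serial loop appears compared to the preserve_unit_loops=False docstring example.
    print(tvm.script.asscript(sch.mod["main"]))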
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index ac09bdbb264d..7545c09b020d 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -927,6 +927,183 @@ def after_cache_write(a: ty.handle, b: ty.handle) -> None: ########## Schedule: Compute location ########## + def compute_at( + self, + block: BlockRV, + loop: LoopRV, + preserve_unit_loops: bool = False, + ) -> None: + """Compute-At. Move a producer block under the specific loop, and regenerate the + loops induced by the block so that the buffer region produced by the producer block could + cover those regions consumed by its consumer blocks under the given loop. It requires: + + 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` + + 2) The scope block has stage-pipeline property + + 3) The subtree of the scope block, where the given block is in, satisfies the compact + dataflow condition. i.e. all the blocks in the scope block's subtree must be either + complete block or reduction block + + 4) The block is not an output block with regard to the scope block, i.e. the buffers written + by the block are allocated under the scope block + + 5) All the consumers of the block are under the given loop + + Parameters + ---------- + block : BlockRV + The block to be moved + + loop: LoopRV + The loop where the block to be moved under + + preserve_unit_loops: bool + Whether to keep the trivial loops whose extents are 1 + + Examples + -------- + + Before compute-at, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do compute-at: + + .. code-block:: python + + sch = tir.Schedule(before_compute_at) + block = sch.get_block("B") + loop, _ = sch.get_loops(sch.get_block("C")) + sch.compute_at(block, loop, preserve_unit_loops=False) + print(tvm.script.asscript(sch.mod["main"])) + + After applying compute-at, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + for i in tir.serial(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + for j in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + C[vi, vj] = B[vi, vj] + 1.0 + + """ + _ffi_api.ScheduleComputeAt( # type: ignore # pylint: disable=no-member + self, + block, + loop, + preserve_unit_loops, + ) + + def reverse_compute_at( + self, + block: BlockRV, + loop: LoopRV, + preserve_unit_loops: bool = False, + ) -> None: + """Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the + loops induced by the block so that the buffer region consumed by the consumer block could + cover those regions produced by its producer blocks under the given loop. 
It requires: + + 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` + + 2) The scope block has stage-pipeline property + + 3) The subtree of the scope block, where the given block is in, satisfies the compact + dataflow condition. i.e. all the blocks in the scope block's subtree must be either + complete block or reduction block + + 4) All the producers of the block are under the given loop + + Parameters + ---------- + block : BlockRV + The block to be moved + + loop: LoopRV + The loop where the block to be moved under + + preserve_unit_loops: bool + Whether to keep the trivial loops whose extents are 1 + + Examples + -------- + + Before reverse-compute-at, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do reverse-compute-at: + + .. code-block:: python + + sch = tir.Schedule(before_reverse_compute_at) + block = sch.get_block("C") + loop, _ = sch.get_loops(sch.get_block("B")) + sch.reverse_compute_at(block, loop, preserve_unit_loops=False) + print(tvm.script.asscript(sch.mod["main"])) + + After applying reverse-compute-at, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + for i in tir.serial(0, 128): + for j in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + for j in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + C[vi, vj] = B[vi, vj] + 1.0 + + """ + _ffi_api.ScheduleReverseComputeAt( # type: ignore # pylint: disable=no-member + self, + block, + loop, + preserve_unit_loops, + ) + def compute_inline(self, block: BlockRV) -> None: """Inline a block into its consumer(s). It requires: @@ -1189,10 +1366,15 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None: """ return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member - ######## Schedule: Block annotatoin ######## + ######## Schedule: Block annotation ######## def storage_align( # pylint: disable=too-many-arguments - self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int + self, + block: BlockRV, + buffer_index: int, + axis: int, + factor: int, + offset: int, ) -> None: """Set alignment requirement for specific dimension such that stride[axis] == k * factor + offset for some k. 
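Beyond the success paths shown in the two docstrings above, the error reporting of the new primitives is worth a quick look. Below is a hedged sketch of requirement 4 of `compute_at` being violated; it assumes `before_compute_at` is defined exactly as in the docstring above, and the exact error text is up to the ScheduleError renderer:

.. code-block:: python

    import tvm
    from tvm import tir

    # Assumes `before_compute_at` is defined exactly as in the compute_at docstring above.
    sch = tir.Schedule(before_compute_at)
    block_c = sch.get_block("C")                   # "C" writes the output buffer `c`
    loop_i, _ = sch.get_loops(sch.get_block("B"))
    try:
        # Violates requirement 4 of compute_at: the moved block must not be an output
        # block of its scope, and "C" writes a buffer not allocated inside the scope.
        sch.compute_at(block_c, loop_i, preserve_unit_loops=False)
    except tvm.TVMError as err:
        print(err)  # expected to report that block "C" is an output block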
This is useful to set memory layout for more diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 7000de96dc99..a402212cf4ea 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -607,6 +607,13 @@ inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) { return is_zero(analyzer->Simplify(lhs - rhs)); } +IntSet IntSet::FromMinExtent(PrimExpr min, PrimExpr extent) { + if (is_one(extent)) { + return IntSet::SinglePoint(min); + } + return IntervalSet(min, extent + min - 1); +} + IntSet IntSet::FromRange(Range r) { // must make sure it can be matched back by MatchRange. if (is_one(r->extent)) { @@ -815,19 +822,18 @@ IntSet EvalSet(Range r, const Map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -Optional> EstimateRegionLowerBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, - arith::Analyzer* analyzer) { +Optional> EstimateRegionLowerBound(const Array& region, + const Map& var_dom, + const PrimExpr& predicate, Analyzer* analyzer) { int ndim = region.size(); - Array iter_sum_exprs{nullptr}; + Array iter_sum_exprs{nullptr}; { Array affine_indices; affine_indices.reserve(ndim); for (const Range& range : region) { affine_indices.push_back(range->min); } - iter_sum_exprs = arith::DetectIterMap( + iter_sum_exprs = DetectIterMap( /*indices=*/affine_indices, /*input_iters=*/var_dom, /*predicate=*/predicate, /*require_bijective=*/false, analyzer); } @@ -835,17 +841,17 @@ Optional> EstimateRegionLowerBound(const Array& regi return NullOpt; } ICHECK_EQ(iter_sum_exprs.size(), ndim); - Array result; + Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { - const arith::IterSumExpr& sum_expr = iter_sum_exprs[i]; + const IterSumExpr& sum_expr = iter_sum_exprs[i]; const Range& range = region[i]; if (sum_expr->args.empty()) { - result.push_back(arith::IntSet::Interval(sum_expr->base, sum_expr->base + range->extent)); + result.push_back(IntSet::FromMinExtent(sum_expr->base, range->extent)); continue; } ICHECK_EQ(sum_expr->args.size(), 1); - const arith::IterSplitExpr& split = sum_expr->args[0]; + const IterSplitExpr& split = sum_expr->args[0]; if (!analyzer->CanProve(range->extent >= split->scale)) { return NullOpt; } @@ -853,8 +859,8 @@ Optional> EstimateRegionLowerBound(const Array& regi // IterSplitExpr: (source // lower_factor) % extent * scale // where `(source // lower_factor) % extent` is within [0, extent - 1] // Therefore, the range of `region[i]->min` is `base + [0, (extent - 1) * scale]` - result.push_back(arith::IntSet::Interval( - base, split->extent * split->scale + base + (range->extent - split->scale) - 1)); + result.push_back( + IntSet::FromMinExtent(base, split->extent * split->scale + (range->extent - split->scale))); } return result; } diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 7056dfe79fee..7b3f2da716aa 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -243,7 +243,9 @@ class ForwardPrep : private MixedModeVisitor { } } // Visitor pattern override. 
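Returning to the `EstimateRegionLowerBound` hunk in int_set.cc above: for a dimension whose `min` is a single split term, the returned cover is now `FromMinExtent(base, split_extent * scale + (range_extent - scale))`. A small numeric sanity check of that formula (plain Python, illustrative only):

.. code-block:: python

    # One region dimension: min = 16 * i with i in [0, 8), extent = 16 (so scale = 16,
    # split_extent = 8, base = 0). The formula gives FromMinExtent(0, 8*16 + (16-16)) = [0, 128).
    covered = set()
    for i in range(8):
        covered |= set(range(16 * i, 16 * i + 16))    # the region accessed at iteration i
    assert covered == set(range(0, 8 * 16 + (16 - 16)))  # exactly the estimated cover [0, 128)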
- void VisitExpr_(const LetNode* op) { + void VisitExpr_(const TupleGetItemNode* op) final { MixedModeVisitor::VisitExpr_(op); } + + void VisitExpr_(const LetNode* op) final { ExprVisitor::VisitExpr_(op); // do pass through condition // by assigning NullValue @@ -256,13 +258,13 @@ class ForwardPrep : private MixedModeVisitor { flist_.push_back(flazy); } - void VisitExpr_(const FunctionNode* op) { + void VisitExpr_(const FunctionNode* op) final { ExprVisitor::VisitExpr_(op); auto flazy = [this, op] { this->Update(op->body, NullValue()); }; flist_.push_back(flazy); } - void VisitExpr_(const CallNode* call) { + void VisitExpr_(const CallNode* call) final { ExprVisitor::VisitExpr_(call); // function to be lazily invoked auto flazy = [this, call]() { @@ -292,7 +294,7 @@ class ForwardPrep : private MixedModeVisitor { flist_.push_back(flazy); } - void VisitExpr_(const TupleNode* op) { + void VisitExpr_(const TupleNode* op) final { ExprVisitor::VisitExpr_(op); // do not support pass scale through tuple for now. auto flazy = [this, op]() { @@ -303,7 +305,7 @@ class ForwardPrep : private MixedModeVisitor { flist_.push_back(flazy); } - void VisitExpr_(const IfNode* op) { + void VisitExpr_(const IfNode* op) final { ExprVisitor::VisitExpr_(op); // do pass through condition // by assigning NullValue diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h new file mode 100644 index 000000000000..ae4a0386d404 --- /dev/null +++ b/src/support/nd_int_set.h @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SUPPORT_ND_INT_SET_H_ +#define TVM_SUPPORT_ND_INT_SET_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace support { + +/*! \brief An N-dimensional integer set representing a rectangle region */ +using NDIntSet = std::vector; + +/*! + * \brief Construct an N-dimensional integer set representing a region. + * \param region The region. + * \return The constructed set. + */ +inline NDIntSet NDIntSetFromRegion(const tir::Region& region) { + NDIntSet result; + result.reserve(region.size()); + for (const Range& range : region) { + result.push_back(arith::IntSet::FromRange(range)); + } + return result; +} + +/*! + * \brief Construct an N-dimensional integer set representing a shape. + * \param shape The shape which is an array of the length of each dimension. + * \return The constructed set. + */ +inline NDIntSet NDIntSetFromShape(const Array& shape) { + PrimExpr zero = Integer(0); + NDIntSet result; + result.reserve(shape.size()); + for (const PrimExpr& extent : shape) { + result.push_back(arith::IntSet::FromMinExtent(zero, extent)); + } + return result; +} + +/*! + * \brief Construct an N-dimensional integer set representing a point. 
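The `NDIntSet` helpers introduced in nd_int_set.h above (together with the union and eval helpers that follow) model an N-dimensional region as one `IntSet` per dimension. A rough pure-Python model of the intended semantics, using closed integer intervals per dimension (illustrative only, not TVM API):

.. code-block:: python

    from typing import List, Tuple

    # Conceptual model of src/support/nd_int_set.h: an N-dimensional set is a list of
    # per-dimension closed integer intervals (min, max), both ends inclusive.
    NDIntSet = List[Tuple[int, int]]

    def nd_from_shape(shape: List[int]) -> NDIntSet:
        # Mirrors NDIntSetFromShape: dimension i covers [0, shape[i] - 1].
        return [(0, extent - 1) for extent in shape]

    def nd_from_point(indices: List[int]) -> NDIntSet:
        # Mirrors NDIntSetFromPoint: one degenerate interval per dimension.
        return [(i, i) for i in indices]

    def nd_union_with(lhs: NDIntSet, rhs: NDIntSet) -> NDIntSet:
        # Mirrors NDIntSetUnionWith: a per-dimension union, relaxed to one bounding interval.
        assert len(lhs) == len(rhs)
        return [(min(a, c), max(b, d)) for (a, b), (c, d) in zip(lhs, rhs)]

    # The union of the points (2, 3) and (5, 1) relaxes to the bounding box [2, 5] x [1, 3].
    assert nd_union_with(nd_from_point([2, 3]), nd_from_point([5, 1])) == [(2, 5), (1, 3)]
    assert nd_from_shape([4, 8]) == [(0, 3), (0, 7)]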
+ * \param indices The N-dimensional indices representing the point. + * \return The constructed set. + */ +inline NDIntSet NDIntSetFromPoint(const Array& indices) { + NDIntSet result; + result.reserve(indices.size()); + for (const PrimExpr& index : indices) { + result.push_back(arith::IntSet::SinglePoint(index)); + } + return result; +} + +/*! + * \brief Create a union set of two sets, possibly relaxed. The RHS set will be combined into the + * LHS set. + * \param lhs The first N-dimensional integer set + * \param rhs The second N-dimensional integer set + */ +inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) { + ICHECK_EQ(lhs->size(), rhs.size()); + int ndim = rhs.size(); + for (int i = 0; i < ndim; ++i) { + arith::IntSet& int_set = lhs->at(i); + int_set = arith::Union({int_set, rhs.at(i)}); + } +} + +/*! + * \brief Union a list of N-dimensional integer sets + * \param nd_int_sets The N-dimensional integer sets to be merged. + * \return The result of the union + */ +inline NDIntSet NDIntSetUnion(const std::vector& nd_int_sets) { + ICHECK(!nd_int_sets.empty()); + int n = nd_int_sets.size(); + if (n == 1) { + return nd_int_sets[0]; + } + int ndim = nd_int_sets[0].size(); + for (int i = 1; i < n; ++i) { + ICHECK_EQ(nd_int_sets[i].size(), ndim); + } + NDIntSet result; + result.reserve(ndim); + Array int_sets(n, arith::IntSet{nullptr}); + for (int dim = 0; dim < ndim; ++dim) { + for (int i = 0; i < n; ++i) { + int_sets.Set(i, nd_int_sets[i][dim]); + } + result.push_back(arith::Union(int_sets)); + } + return result; +} + +/*! + * \brief Create an empty N-dimensional integer set. + * \param ndim The number of dimensions. + * \return The constructed set. + */ +inline NDIntSet NDIntSetEmpty(int ndim) { + return std::vector(ndim, arith::IntSet::Nothing()); +} + +/*! + * \brief The N-dimensional version of EvalSet. + * \param nd_int_set The N-dimensional integer set to be evaluated. + * \param dom_map The domain of each variable. + * \return An N-dimensional integer set that can cover all the possible values of the N-dimensional + * integer set. + * \sa EvalSet + */ +inline NDIntSet NDIntSetEval( + const NDIntSet& nd_int_set, + const std::unordered_map& dom_map) { + NDIntSet ret; + ret.reserve(nd_int_set.size()); + for (const arith::IntSet& s : nd_int_set) { + ret.push_back(EvalSet(s, dom_map)); + } + return ret; +} + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_ND_INT_SET_H_ diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index d4e4728abfe0..5a2f46c910b4 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -21,6 +21,7 @@ #include +#include #include #include @@ -69,11 +70,20 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); * \param self The schedule state * \param sref The sref whose scope is to be checked * \param require_stage_pipeline A boolean indicating whether to check stage pipeline - * \throw ScheduleError if the sref has been the root of the AST (so it has no scope root), or its - * scope root is not a stage pipeline + * \param require_subtree_compact_dataflow A boolean indicating whether to check + * subtree compact dataflow property. The scope root may have one or more subtrees rooted at + * its direct children, and this property requires all the blocks of the subtree + * that the specified sref is in to be complete block or reduction block. 
+ * \throw ScheduleError if + * 1) the sref has been the root of the AST (so it has no scope root), or + * 2) require_stage_pipeline = true, but its scope root is not a stage pipeline + * 3) require_subtree_compact_dataflow = true, but the subtree that the sref is in doesn't satisfy + * the compact dataflow condition, i.e. a block in the subtree is neither complete block nor + * reduction block * \return The block sref to the scope root */ -StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline); +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline, + bool require_subtree_compact_dataflow); /*! * \brief Checks whether the block is a complete block under the scope @@ -128,18 +138,36 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); /*! - * \brief Check whether a subtree on SRef tree has compact data flow, and throw an exception if the - * subtree does not have compact data flow - * \details For a given StmtSRef, We say the subtree rooted from the StmtSRef has "compact data - * flow" property if: - * - the scope root of the input subtree root has stage-pipeline property, and - * - all its child blocks on SRef tree are complete blocks or reduction blocks. + * \brief Check if the block is a complete block or a reduction block under the scope * \param self The schedule state - * \param subtree_root_sref The root of the subtree to be checked in the SRef tree - * \throw ScheduleError If the subtree does not have compact data flow - * \sa IsCompleteBlock, IsReductionBlock + * \param block_sref The sref of the block to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError If the block is neither a complete block nor a reduction block + */ +void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + +/*! + * \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is + * not allocated under the current scope + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The scope root of the block + * \return A boolean flag indicating if the block is an output block + */ +bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + +/*! + * \brief Check if the block is not an output block, i.e. all the buffers the block writes to + * are allocated under the current scope + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError if the block is an output block */ -void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref); +void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); /******** Binding ********/ /*! @@ -224,6 +252,7 @@ Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); */ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref); + /*! 
* \brief Get the BlockRealize of the input block * \param self The schedule state @@ -232,6 +261,55 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/******** Producer-consumer relation ********/ + +/*! + * \brief Get the producer blocks to the given block under the given scope + * \param block_sref The block whose producers are to be retrieved + * \param scope The block scope where the given block is in + * \return The producer blocks of the specified block + */ +Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope); + +/*! + * \brief Get the consumer blocks to the given block under the given scope + * \param block_sref The block whose consumers are to be retrieved + * \param scope The block scope where the given block is in + * \return The consumer blocks of the specified block + */ +Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); + +/*! + * \brief A solution to split a ordered list of subtrees into two parts, + * where producers are on the LHS and consumers are on the RHS. + * For example, subtree[0, 3) are on the LHS, and subtree[3, 6) are on the RHS. + */ +struct ProducerConsumerSplit { + /*! \brief Indicates that all producers fall into `subtrees[0, last_producer_position]` */ + int last_producer_position; + /*! \brief Indicates that all consumers fall into `subtrees[first_consumer_position, ...)` */ + int first_consumer_position; + /*! \brief The number of given producers visited in `subtrees` */ + int n_producers_visited; + /*! \brief The number of given consumers visited in `subtrees` */ + int n_consumers_visited; + /*! + * \brief Find a split among the given `subtree` + * \param state The schedule state + * \param subtrees The ordered list of subtrees to be split + * \param producer_block_srefs The producers + * \param consumer_block_srefs The consumers + * \param block2realize If not null, the corresponding BlockRealize to each block in the scope + * will be saved in this map + * \return The valid split points are (last_producer_position, first_consumer_position] + * \throw ScheduleError is not valid split is found + */ + static ProducerConsumerSplit Find( + const ScheduleState& state, const Array& subtrees, + const Array& producer_block_srefs, const Array& consumer_block_srefs, + std::unordered_map* block2realize); +}; + /******** Block-buffer relation ********/ /*! diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3865781c5870..d14d64a4c787 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -47,22 +47,9 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl /******** Scope ********/ -/*! - * \brief Gets the sref to the scope root block, exclusive - * \param sref The block or loop sref to be retrieved - * \return The sref to the scope root block. 
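The `ProducerConsumerSplit` helper declared above is the core of the insertion-point search used by both primitives. A conceptual pure-Python model of `ProducerConsumerSplit::Find` (illustrative only; each subtree is modelled as the set of block names it contains):

.. code-block:: python

    from typing import List, Set, Tuple

    def find_split(subtrees: List[Set[str]], producers: Set[str],
                   consumers: Set[str]) -> Tuple[int, int]:
        """Conceptual model of ProducerConsumerSplit::Find (not TVM API). Returns
        (last_producer_position, first_consumer_position); a block can be inserted at
        any index in the half-open range (last_producer_position, first_consumer_position]."""
        n = len(subtrees)
        last_producer_position = -1
        first_consumer_position = n
        for i, blocks in enumerate(subtrees):
            if blocks & producers:
                last_producer_position = i
            if blocks & consumers and first_consumer_position == n:
                first_consumer_position = i
        if last_producer_position >= first_consumer_position:
            raise ValueError("no insertion point satisfies the producer-consumer constraint")
        return last_producer_position, first_consumer_position

    # Subtrees 0 and 1 hold producers, subtree 3 holds a consumer: valid insertion
    # positions are 2 and 3.
    assert find_split([{"P0"}, {"P1"}, {"X"}, {"C0"}], {"P0", "P1"}, {"C0"}) == (1, 3)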
NullOpt if `sref` is the root block of the IR - */ -Optional GetScopeRoot(const StmtSRef& sref) { - for (const StmtSRefNode* p = sref->parent; p != nullptr; p = p->parent) { - if (p->stmt->IsInstance()) { - return GetRef(p); - } - } - return NullOpt; -} - -StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, - bool require_stage_pipeline) { +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, // + bool require_stage_pipeline, // + bool require_subtree_compact_dataflow) { class RootBlockError : public ScheduleError { public: explicit RootBlockError(IRModule mod) : mod_(mod) {} @@ -98,16 +85,67 @@ Definition of a scope that is a stage pipeline: Block block_; }; + class NotCompactDataFlowError : public ScheduleError { + public: + explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) + : mod_(std::move(mod)), + subtree_root_(std::move(subtree_root)), + violate_block_(std::move(violate_block)) { + ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); + } + String FastErrorString() const final { + return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " + "because some of its child block on SRef tree is neither a complete block nor a " + "reduction block"; + } + String DetailRenderTemplate() const final { + return "The queried subtree root {0} in SRef tree does not have compact dataflow, because " + "its child block {1} on SRef tree is neither a complete block nor a reduction block"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + + IRModule mod_; + Stmt subtree_root_; + Block violate_block_; + }; + StmtSRef scope_root_sref{nullptr}; - if (Optional opt_scope_root_sref = GetScopeRoot(sref)) { - scope_root_sref = opt_scope_root_sref.value(); - } else { - throw RootBlockError(self->mod); + StmtSRef scope_root_subtree{nullptr}; + // Step 1. Find the scope root and the subtree that the given sref is in + { + const StmtSRefNode* p = sref->parent; + const StmtSRefNode* subtree = sref.get(); + for (; p != nullptr; subtree = p, p = p->parent) { + if (p->stmt->IsInstance()) { + scope_root_sref = GetRef(p); + scope_root_subtree = GetRef(subtree); + break; + } + } + if (p == nullptr) { + throw RootBlockError(self->mod); + } + } + // Step 2. Handle `require_stage_pipeline` + if (require_stage_pipeline) { + bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; + if (stage_pipeline == false) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); + throw NotStagePipelineError(self->mod, GetRef(block)); + } } - bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; - if (require_stage_pipeline && stage_pipeline == false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); - throw NotStagePipelineError(self->mod, GetRef(block)); + // Step 3. 
Handle `require_subtree_compact_dataflow` + if (require_subtree_compact_dataflow) { + Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); + for (const StmtSRef& block_sref : child_block_srefs) { + if (!IsCompleteBlock(self, block_sref, scope_root_sref) && + !IsReductionBlock(self, block_sref, scope_root_sref)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw NotCompactDataFlowError(self->mod, GetRef(scope_root_subtree->stmt), + GetRef(block)); + } + } } return scope_root_sref; } @@ -174,6 +212,18 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block return 0; } +static const char* kCompleteBlockDefinition = R"(Definition of a complete block: +1) All block vars are data parallel +2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +3) No overlap between the buffers the block reads and writes)"; + +static const char* kReductionBlockDefinition = R"(Definition of a reduction block: +1) The block has the `init` statement +2) All the block bindings are quasi-affine expressions +3) All block vars are either data parallel block vars or reduction block vars +4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers +5) The reduction block vars are not used to index the output buffers)"; + bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0; @@ -188,12 +238,8 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, String FastErrorString() const final { return "ScheduleError: Incomplete block"; } String DetailRenderTemplate() const final { std::ostringstream os; - os << "The block {0} is not a complete block - it violates condition #" << violated_cond_ - << ".\n" - << R"(Definition of a complete block: -1) All block vars are data parallel -2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers -3) No overlap between the buffers the block reads and writes)"; + os << "The block {0} is not a complete block - it violates condition #" << violated_cond_; + os << ".\n" << kCompleteBlockDefinition; return os.str(); } IRModule mod() const final { return mod_; } @@ -291,14 +337,8 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } String DetailRenderTemplate() const final { std::ostringstream os; - os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_ - << ".\n" - << R"(Definition of a reduction block: -1) The block has the `init` statement -2) All the block bindings are quasi-affine expressions -3) All block vars are either data parallel block vars or reduction block vars -4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers -5) The reduction block vars are not used to index the output buffers)"; + os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_; + os << ".\n" << kReductionBlockDefinition; return os.str(); } IRModule mod() const final { return mod_; } @@ -315,41 +355,89 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, } } -void CheckSRefSubtreeCompactDataFlow(const ScheduleState& self, const StmtSRef& subtree_root_sref) { - class NotCompactDataFlowError : 
public ScheduleError { +void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + class NotCompleteOrReductionBlockError : public ScheduleError { public: - explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) - : mod_(std::move(mod)), - subtree_root_(std::move(subtree_root)), - violate_block_(std::move(violate_block)) { - ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); - } + explicit NotCompleteOrReductionBlockError(IRModule mod, Block block, + int complete_block_error_code, + int reduction_block_error_code) + : mod_(mod), + block_(block), + complete_block_error_code_(complete_block_error_code), + reduction_block_error_code_(reduction_block_error_code) {} + String FastErrorString() const final { - return "ScheduleError: The queried subtree root in SRef tree does not have compact data " - "flow, because some of its child block on SRef tree is neither a complete block nor a " - "reduction block"; + return "ScheduleError: Not a complete or reduction block"; } String DetailRenderTemplate() const final { - return "The queried subtree root {0} in SRef tree does not have compact data flow, because " - "its child block {1} on SRef tree is neither a complete block nor a reduction block"; + std::ostringstream os; + os << "The block {0} is not a complete block - it violates condition #" + << complete_block_error_code_; + os << ".\n" << kCompleteBlockDefinition; + os << "\nThe block is not a reduction block either - it violates condition #" + << reduction_block_error_code_; + os << ".\n" << kReductionBlockDefinition; + return os.str(); } IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; - Stmt subtree_root_; - Block violate_block_; + Block block_; + int complete_block_error_code_; + int reduction_block_error_code_; }; - StmtSRef scope_root = GetScopeRoot(self, subtree_root_sref, /*require_stage_pipeline=*/true); - Array child_blocks = GetChildBlockSRefOnSRefTree(self, scope_root); - for (const StmtSRef& block : child_blocks) { - if (!IsCompleteBlock(self, block, scope_root) && !IsReductionBlock(self, block, scope_root)) { - const BlockNode* violate_block = TVM_SREF_TO_BLOCK(violate_block, block); - throw NotCompactDataFlowError(self->mod, GetRef(subtree_root_sref->stmt), - GetRef(violate_block)); + int complete_block_error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); + if (complete_block_error_code == 0) { + return; + } + int reduction_block_error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); + if (reduction_block_error_code == 0) { + return; + } + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw NotCompleteOrReductionBlockError(self->mod, GetRef(block), complete_block_error_code, + reduction_block_error_code); +} + +bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + std::unordered_set scope_allocated; + scope_allocated.reserve(scope_root->alloc_buffers.size()); + for (const Buffer& buffer : scope_root->alloc_buffers) { + scope_allocated.insert(buffer.get()); + } + for (const BufferRegion& buffer_region : block->writes) { + if 
(!scope_allocated.count(buffer_region->buffer.get())) { + return true; } } + return false; +} + +void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + class OutputBlockError : public ScheduleError { + public: + explicit OutputBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: Cannot operate on an output block"; + } + String DetailRenderTemplate() const final { return "The block {0} is an output block"; } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + }; + if (IsOutputBlock(self, block_sref, scope_root_sref)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw OutputBlockError(self->mod, GetRef(block)); + } } /******** Binding ********/ @@ -586,6 +674,125 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +/******** Producer-consumer relation ********/ + +Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { + Array deps = scope->GetDepsByDst(block_sref); + Array result; + result.reserve(deps.size()); + for (const Dependency& dep : deps) { + result.push_back(dep->src); + } + return result; +} + +Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) { + Array deps = scope->GetDepsBySrc(block_sref); + Array result; + result.reserve(deps.size()); + for (const Dependency& dep : deps) { + result.push_back(dep->dst); + } + return result; +} + +ProducerConsumerSplit ProducerConsumerSplit::Find( + const ScheduleState& self, const Array& subtrees, + const Array& producer_block_srefs, const Array& consumer_block_srefs, + std::unordered_map* block2realize) { + class InsertionPointNotFoundError : public ScheduleError { + public: + explicit InsertionPointNotFoundError(IRModule mod, int last_producer_position, + int first_consumer_position) + : mod_(mod), + last_producer_position_(last_producer_position), + first_consumer_position_(first_consumer_position) {} + + String FastErrorString() const final { + return "ScheduleError: Cannot find the insertion point that satisfies the producer-consumer " + "constraint"; + } + + String DetailRenderTemplate() const final { + return "Cannot find the insertion point that satisfies the producer-consumer constraint. 
In " + "0-based indexing, the last producer appears in subtree " + + std::to_string(last_producer_position_) + + ", and the first consumer appears in subtree " + + std::to_string(first_consumer_position_); + } + + IRModule mod() const final { return mod_; } + + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + int last_producer_position_; + int first_consumer_position_; + }; + + class Finder : public StmtVisitor { + public: + void VisitStmt_(const BlockRealizeNode* realize) final { + const BlockNode* block = realize->block.get(); + if (block2realize_) { + block2realize_->emplace(block, realize); + } + if (producer_blocks_.count(block)) { + ++this->n_producers_visited_; + } + if (consumer_blocks_.count(block)) { + ++this->n_consumers_visited_; + } + } + + std::unordered_map* block2realize_; + std::unordered_set producer_blocks_; + std::unordered_set consumer_blocks_; + int n_producers_visited_ = 0; + int n_consumers_visited_ = 0; + }; + + Finder finder; + finder.block2realize_ = block2realize; + // Set up the lookup table for producers + finder.producer_blocks_.reserve(producer_block_srefs.size()); + for (const StmtSRef& block_sref : producer_block_srefs) { + finder.producer_blocks_.insert(block_sref->stmt); + } + // Set up the lookup table for consumers + finder.consumer_blocks_.reserve(consumer_block_srefs.size()); + for (const StmtSRef& block_sref : consumer_block_srefs) { + finder.consumer_blocks_.insert(block_sref->stmt); + } + // Visit the subtrees + int n = subtrees.size(); + int last_producer_position = -1; + int first_consumer_position = n; + for (int i = 0; i < n; ++i) { + int n_producers_visited_before = finder.n_producers_visited_; + int n_consumers_visited_before = finder.n_consumers_visited_; + finder(subtrees[i]); + // Check if the subtree contains at least a producer + if (finder.n_producers_visited_ != n_producers_visited_before) { + last_producer_position = i; + } + // Check if the subtree contains at least a consumer + if (finder.n_consumers_visited_ != n_consumers_visited_before) { + if (first_consumer_position == n) { + first_consumer_position = i; + } + } + } + if (last_producer_position >= first_consumer_position) { + throw InsertionPointNotFoundError(self->mod, last_producer_position, first_consumer_position); + } + return ProducerConsumerSplit{last_producer_position, // + first_consumer_position, // + finder.n_producers_visited_, // + finder.n_consumers_visited_}; +} + /******** Block-buffer relation ********/ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) { @@ -957,11 +1164,13 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, } /******** SRef Tree Related ********/ + StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { const StmtSRefNode* p = sref.get(); for (; p->parent != nullptr; p = p->parent) { } return GetRef(p); } + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 86223e11c196..07af73ebabb6 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -439,6 +439,44 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff /******** Schedule: Compute location ********/ +void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) { + static StmtSRef inline_mark = StmtSRef::InlineMark(); + static StmtSRef root_mark = StmtSRef::RootMark(); + StmtSRef 
loop_sref = this->GetSRef(loop_rv); + if (loop_sref.same_as(root_mark)) { + // do nothing + } else if (loop_sref.same_as(inline_mark)) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ComputeInline(state_, this->GetSRef(block_rv)); + TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_); + } else { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops); + TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_); + } + this->state_->DebugVerify(); +} + +void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) { + static StmtSRef inline_mark = StmtSRef::InlineMark(); + static StmtSRef root_mark = StmtSRef::RootMark(); + StmtSRef loop_sref = this->GetSRef(loop_rv); + if (loop_sref.same_as(root_mark)) { + // do nothing + } else if (loop_sref.same_as(inline_mark)) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); + TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_); + } else { + TVM_TIR_SCHEDULE_BEGIN(); + tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops); + TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_); + } + this->state_->DebugVerify(); +} + void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { TVM_TIR_SCHEDULE_BEGIN(); tir::ComputeInline(state_, this->GetSRef(block_rv)); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index e756f9da41b2..c9a9402832f2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -108,6 +108,9 @@ class ConcreteScheduleNode : public ScheduleNode { BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; /******** Schedule: Compute location ********/ + void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; + void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) override; void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; /******** Schedule: Reduction ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 412611adf76d..05eefaca8a11 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -160,6 +160,44 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); /******** Schedule: Compute location ********/ +/*! + * \brief Move a producer block under the specific loop, and regenerate the + * loops induced by the block so that the buffer region produced by the producer block could + * cover those regions consumed by its consumer blocks under the given loop. It requires: + * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` + * 2) The scope block has stage-pipeline property + * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow + * condition. i.e. all the blocks in the scope block's subtree must be either complete block or + * reduction block + * 4) The block is not an output block with regard to the scope block, i.e. 
the buffers written by + * the block are allocated under the scope block + * 5) All the consumers of the block are under the given loop + * + * \param self The schedule state + * \param block_sref The block to be moved + * \param loop_sref The loop where the block to be moved to + * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + */ +TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, + bool preserve_unit_loops); +/*! + * \brief Move a consumer block under the specific loop, and regenerate the + * loops induced by the block so that the buffer region consumed by the consumer block could + * cover those regions produced by its producer blocks under the given loop. It requires: + * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` + * 2) The scope block has stage-pipeline property + * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow + * condition. i.e. all the blocks in the scope block's subtree must be either complete block or + * reduction block + * 4) All the producers of the block are under the given loop + * + * \param self The schedule state + * \param block_sref The block to be moved + * \param loop_sref The loop where the block to be moved to + * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + */ +TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, + const StmtSRef& loop_sref, bool preserve_unit_loops); /*! * \brief Inline a block into its consumer(s). It requires: * 1) The block is a complete non-root block, which only produces one buffer @@ -199,6 +237,10 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); /******** Schedule: Block annotation ********/ +/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ +using StorageAlignTuple = Array; +/*! \brief A list of StorageAlignTuple, used by StorageAlign */ +using StorageAlignAnnotation = Array; /*! * \brief Set alignment requirement for specific dimension such that * stride[axis] == k * factor + offset for some k. 
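For the `stride[axis] == k * factor + offset` relation stated in the `StorageAlign` documentation here, a tiny plain-Python check of which strides are admissible (illustrative only):

.. code-block:: python

    def satisfies_storage_align(stride: int, factor: int, offset: int) -> bool:
        # stride[axis] == k * factor + offset for some integer k
        return (stride - offset) % factor == 0

    # With factor=128 and offset=1 (a typical padding setting), a naturally contiguous
    # stride of 128 is not admissible, but padding the row to 129 elements is.
    assert not satisfies_storage_align(128, 128, 1)
    assert satisfies_storage_align(129, 128, 1)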
This is useful to set memory layout for @@ -214,10 +256,6 @@ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int fact TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset); -/******** Annotation types for StorageAlign ********/ -using StorageAlignTuple = Array; // (buffer_idx, axis, factor, offset) -using StorageAlignAnnotation = Array; // unordered array of StorageAlignTuple - /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 06f7ac3c1bc2..a96c8ca09f32 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -270,7 +270,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); } -/******** Instruction Registration ********/ +/******** InstructionKind Registration ********/ struct StorageAlignTraits : public UnpackedInstTraits { static constexpr const char* kName = "StorageAlign"; diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index df54c9652ece..8628cc3c0791 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -146,7 +146,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, /*annotations=*/{}); // Create the block realize node Stmt body = BlockRealize(/*values=*/iter_values, - /*predicate=*/Bool(true), + /*predicate=*/const_true(), /*block=*/block); // Create surrounding loops for (size_t i = loop_vars.size(); i >= 1; --i) { @@ -160,6 +160,21 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, return block; } +/*! + * \brief Recalculate the `affine_binding` flag of a specifc block + * \param block_sref The sref to the specific block + */ +bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) { + if (block_sref->parent == nullptr) { + return true; + } + arith::Analyzer analyzer; + StmtSRef parent_sref = GetRef(block_sref->parent); + return IsAffineBinding(/*realize=*/GetBlockRealize(self, block_sref), + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), + /*analyzer=*/&analyzer); +} + /*! * \brief Insert the cache_read/cache_write stage into the specific position * \param stmt A sequence of statements or a single statement that the new stage is inserted in @@ -613,7 +628,8 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); Buffer read_buffer = GetNthAccessBuffer(self, GetRef(block), read_buffer_index, /*is_write=*/false); - StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); // Step 2. Creat CacheStageInfo @@ -657,8 +673,8 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff // Step 5. Replacing and updating flags. 
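The cache stages created by `CacheRead`/`CacheWrite` are a natural target for the new primitive: a common pattern is to create the cache block and then `compute_at` it under a consumer loop. A hedged sketch follows, assuming `before_compute_at` is defined as in the schedule.py docstrings earlier in this patch; the `"global"` cache scope and the resulting script layout are illustrative assumptions:

.. code-block:: python

    import tvm
    from tvm import tir

    # Assumes `before_compute_at` is defined exactly as in the compute_at docstring above.
    sch = tir.Schedule(before_compute_at)
    # Stage buffer A (read index 0 of block "B") into a new cache block; the "global"
    # storage scope here is just an illustrative choice.
    cached = sch.cache_read(sch.get_block("B"), 0, "global")
    # Move the cache stage under the outer loop of "B": its only consumer ("B") lives
    # under that loop, so the compute-at requirements are satisfied.
    loop_i, _ = sch.get_loops(sch.get_block("B"))
    sch.compute_at(cached, loop_i, preserve_unit_loops=False)
    print(tvm.script.asscript(sch.mod["main"]))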
self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); - self->UpdateAffineFlag(result_block_sref); BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info.region_cover = true; block_info.scope->stage_pipeline = true; return result_block_sref; @@ -680,7 +696,8 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); Buffer write_buffer = GetNthAccessBuffer(self, GetRef(block), write_buffer_index, /*is_write=*/true); - StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); // Step 2. Creating CacheStageInfo CacheStageInfo info; @@ -710,8 +727,8 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu // Step 6. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); - self->UpdateAffineFlag(result_block_sref); BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info.region_cover = true; block_info.scope->stage_pipeline = true; return result_block_sref; diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc new file mode 100644 index 000000000000..0dae50abc05e --- /dev/null +++ b/src/tir/schedule/primitive/compute_at.cc @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +/******** Error Classes ********/ + +/*! + * \brief An error raised when not all required blocks are under the given loop. + * \tparam is_consumer Indicates if all the required blocks are consumers or producers + */ +template +class NotAllRequiredBlocksAreVisitedError : public ScheduleError { + public: + explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited, + const Array& required) + : mod_(mod), num_not_visited_(num_not_visited) { + required_.reserve(required.size()); + for (const StmtSRef& block_sref : required) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + required_.push_back(GetRef(block)); + } + } + + String FastErrorString() const final { + return "ScheduleError: Not all required blocks are under the loop scope"; + } + + String DetailRenderTemplate() const final { + String relation = is_consumer ? 
"consumer(s)" : "producer(s)"; + std::ostringstream os; + os << "The primitive requires all the " << relation + << " of the given block to be present under the target loop. However, there are " + << num_not_visited_ << " " << relation << " not satisfying the constraint. List of the " + << relation << ":"; + for (int i = 0, n = required_.size(); i < n; ++i) { + os << "{" << i << "}"; + } + return os.str(); + } + + IRModule mod() const final { return mod_; } + + Array LocationsOfInterest() const final { + return {required_.begin(), required_.end()}; + } + + private: + IRModule mod_; + int num_not_visited_; + Array required_; +}; + +/*! + * \brief An error raised when the given block is not in the same block scope as the given loop, + * or the given loop is the ancestor of the given block. + */ +class NotInSameScopeError : public ScheduleError { + public: + static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& loop_sref, const StmtSRef& scope_root_sref, + arith::Analyzer* analyzer) { + for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) { + if (const ForNode* loop = p->StmtAs()) { + analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } else if (p != scope_root_sref.get()) { + throw NotInSameScopeError(self->mod, block_sref, loop_sref); + } else { + break; + } + } + for (const StmtSRefNode* p = block_sref->parent; p != scope_root_sref.get(); p = p->parent) { + if (p == loop_sref.get()) { + throw NotInSameScopeError(self->mod, block_sref, loop_sref); + } + } + } + + String FastErrorString() const final { + return "ScheduleError: Expected the block and loop to be under the same block scope, and loop " + "not to be the ancestor of block"; + } + String DetailRenderTemplate() const final { + return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, " + "and loop not to be the ancestor of block"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_, loop_}; } + + private: + explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref) + : mod_(mod), + block_(GetRef(block_sref->StmtAs())), + loop_(GetRef(loop_sref->StmtAs())) {} + + IRModule mod_; + Block block_; + For loop_; +}; + +/******** Helper Functions/Classes ********/ + +/*! + * \brief Find a point where the block can be inserted under the loop + * \tparam require_all_producers_visited Requires all producer blocks to be present under the loop + * \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop + * \param self The schedule state + * \param subtrees The subtrees under the loop, among which the insertion points are sought + * \param producer_srefs The producer blocks + * \param consumer_srefs The consumer blocks + * \param block2realize A cache that maps a block to its realize + * \return The last position the new block can be inserted onto, and the + * producer-consumer-relationship is still satisfied. + * \throws ScheduleError if there is no such insertion point found + */ +template +int FindInsertionPoint( + const ScheduleState& self, const Array& subtrees, const Array& producer_srefs, + const Array& consumer_srefs, + std::unordered_map* block2realize) { + ProducerConsumerSplit split = + ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize); + // Step 1. 
Check if all the producers are visited in the subtrees, if required to + if (require_all_producers_visited) { + int num_producers = producer_srefs.size(); + if (split.n_producers_visited < num_producers) { + throw NotAllRequiredBlocksAreVisitedError( + self->mod, num_producers - split.n_producers_visited, producer_srefs); + } + } + // Step 2. Check if all the consumers are visited in the subtrees, if required to + if (require_all_consumers_visited) { + int num_consumers = consumer_srefs.size(); + if (split.n_consumers_visited < num_consumers) { + throw NotAllRequiredBlocksAreVisitedError( + self->mod, num_consumers - split.n_consumers_visited, consumer_srefs); + } + } + // Step 3. Check if there is at least one index of the position can be inserted into + // The valid indices are: (last_producer_position, first_consumer_position] + ICHECK(split.last_producer_position < split.first_consumer_position); + // Step 4. Return the last valid insertion point + return split.first_consumer_position; +} + +/*! + * \brief A helper to reconstruct the block scope where the given block is moved under the given + * loop, and the given block's induced loop nest is regenerated to satisfy the required region. + */ +class ScopeReconstructor : private StmtMutator { + public: + explicit ScopeReconstructor(Block scope_root, Block block, For loop) + : scope_root_(scope_root), block_(block), loop_(loop) {} + + using StmtMutator::operator(); + + /*! + * \brief Create the loop nest on top of the block, induced by the given block var's domain + * \param insert_position The position among the subtrees where the block and its induced loop + * nest is inserted + * \param iter_doms The domain of each block var + * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1 + */ + void MakeNewLoop(int insert_position, std::vector iter_doms, bool preserve_unit_loops) { + int n_iters = iter_doms.size(); + Array loop_vars; + Array loop_extents; + Array iter_values; + loop_vars.reserve(n_iters); + loop_extents.reserve(n_iters); + iter_values.reserve(n_iters); + for (int i = 0; i < n_iters; ++i) { + const Range& iter_dom = iter_doms[i]; + if (preserve_unit_loops || !is_one(iter_dom->extent)) { + Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32)); + loop_vars.push_back(var); + loop_extents.push_back(iter_dom->extent); + iter_values.push_back(iter_dom->min + var); + } else { + iter_values.push_back(iter_dom->min); + } + } + this->new_block_realize_ = + BlockRealize(std::move(iter_values), const_true(), std::move(block_)); + Stmt new_subtree = this->new_block_realize_; + for (int i = static_cast(loop_vars.size()) - 1; i >= 0; --i) { + const Var& loop_var = loop_vars[i]; + const PrimExpr& loop_extent = loop_extents[i]; + new_subtree = For(/*loop_var=*/loop_var, + /*min=*/Integer(0), + /*extent=*/loop_extent, + /*ForKind=*/ForKind::kSerial, + /*body=*/std::move(new_subtree)); + } + Array subtrees = AsArray(loop_->body); + subtrees.insert(subtrees.begin() + insert_position, std::move(new_subtree)); + ObjectPtr new_loop = make_object(*loop_.get()); + new_loop->body = SeqStmt(std::move(subtrees)); + this->new_loop_ = For(std::move(new_loop)); + } + + private: + Stmt VisitStmt_(const BlockNode* block) final { + if (block != scope_root_.get()) { + return GetRef(block); + } + if (block == rm_src_stmt_.get()) { + block = TVM_TYPE_AS(block, rm_tgt_stmt_, BlockNode); + } + return StmtMutator::VisitStmt_(block); + } + + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == rm_src_stmt_.get()) { 
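+      // `loop` is the source of the removal plan (`rm_src_stmt_`); swap in the planned target
+      // subtree (`rm_tgt_stmt_`), i.e. the same loop nest with the moved block removed, before
+      // continuing the mutation.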
+ loop = TVM_TYPE_AS(loop, rm_tgt_stmt_, ForNode); + } + if (loop == loop_.get()) { + return new_loop_; + } + return StmtMutator::VisitStmt_(loop); + } + + public: + /*! \brief The root block of the block scope */ + Block scope_root_; + /*! \brief The given block to be moved */ + Block block_; + /*! \brief The given loop the block and its loop nest to be put under */ + For loop_; + /*! \brief The new loop to replace the original loop */ + For new_loop_{nullptr}; + /*! \brief The new block realize to the moved block */ + BlockRealize new_block_realize_{nullptr}; + /*! \brief The plan to remove the given block by replacing this loop/block in the AST */ + Stmt rm_src_stmt_{nullptr}; + /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */ + Stmt rm_tgt_stmt_{nullptr}; +}; + +/*! + * \brief Calculate a list of accessed buffer regions under a path of loops + * \tparam relax_storage_scope Whether to relax beyond the path according to the storage and + * execution scope + * \param binding The block binding, used to unbind the buffer regions + * \param buffer_regions The buffer regions to be calculated + * \param relax_path_low_inclusive The lowest point in the loop path, inclusive + * \param relax_path_high_exclusive The highest point in the loop path, exclusive + * \param relaxed Where the calculation result is stored + */ +template +void RelaxBufferRegions(const Map& binding, + const Array& buffer_regions, + const StmtSRef& relax_path_low_inclusive, + const StmtSRef& relax_path_high_exclusive, + std::unordered_map>* relaxed) { + runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, ""}; + // We cache the variable domains + runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal; + Optional> var_dom = NullOpt; + // Enumerate every buffer region + for (const BufferRegion& buffer_region : buffer_regions) { + const Buffer& buffer = buffer_region->buffer; + const Array& region = buffer_region->region; + // Skip the buffer regions we are not interested in + auto it = relaxed->find(buffer.get()); + if (it == relaxed->end()) { + continue; + } + std::vector& relaxed_regions = it->second; + // Check and update the cached `var_dom` + runtime::StorageScope scope = + relax_storage_scope ? runtime::StorageScope::Create(buffer.scope()) : global_scope; + runtime::StorageRank rank = scope.rank; + if (rank != previous_rank || !var_dom.defined()) { + previous_rank = rank; + var_dom = AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/relax_path_low_inclusive, + /*high_exclusive=*/relax_path_high_exclusive, + /*extra_relax_scope=*/scope)); + } + // Relax the region + Array relaxed_region = + arith::EvalSet(Substitute(region, binding), var_dom.value()); + relaxed_regions.push_back({relaxed_region.begin(), relaxed_region.end()}); + } +} + +/*! 
+ * \brief Calculate the iteration domain of a provided integer set to fully cover the required + * domain + * \param provided The provided integer set to cover the required domain + * \param required The required domain to be covered + * \param iter_doms The result iteration domains to be updated + * \param analyzer The arithmetic analyzer + */ +void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required, + std::unordered_map>* iter_doms, + arith::Analyzer* analyzer) { + PrimExpr provided_min = analyzer->Simplify(provided.min()); + PrimExpr provided_extent = analyzer->Simplify(provided.max() - provided_min + 1); + PrimExpr required_min = analyzer->Simplify(required.min()); + PrimExpr required_extent = analyzer->Simplify(required.max() - required_min + 1); + PrimExpr dom_min{nullptr}, dom_extent{nullptr}; + Var dom_var{ObjectPtr{nullptr}}; + arith::PVar p_v; + arith::PVar p_e; + if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) { + PrimExpr e = p_e.Eval(); + dom_var = p_v.Eval(); + dom_min = floordiv(required_min, e); + dom_extent = analyzer->Simplify((required_extent + e - 1) / e); + } else if (analyzer->CanProveEqual(provided_extent, 1) && p_v.Match(provided_min)) { + dom_var = p_v.Eval(); + dom_min = required_min; + dom_extent = required_extent; + } else { + ICHECK(false) << "ValueError: BufferRegion pattern match failed"; + } + auto it = iter_doms->find(dom_var.get()); + if (it != iter_doms->end()) { + std::vector& doms = it->second; + doms.push_back(arith::IntSet::FromMinExtent(dom_min, dom_extent)); + } else { + ICHECK(analyzer->CanProveEqual(provided_min, required_min)); + ICHECK(analyzer->CanProveEqual(provided_extent, required_extent)); + } +} + +/*! + * \brief Calculate the domain of block vars to cover the required region + * \param iter_vars The list of block vars to cover the required region + * \param provided_regions The region provided by one iteration instance of the block vars + * \param required_regions The region required to be covered + * \param analyzer The arithmetic analyzer + * \return A list of iteration domain corresponding to the given list of block vars + */ +std::vector CalculateBlockVarDomain( + const Array& iter_vars, + std::unordered_map> provided_regions, + std::unordered_map> required_regions, + arith::Analyzer* analyzer) { + int n_iters = iter_vars.size(); + // Step 1. Construct the mapping from block var to their iteration domain (initialized to empty) + std::unordered_map> iter_doms; + iter_doms.reserve(n_iters); + for (const IterVar& iter_var : iter_vars) { + iter_doms[iter_var->var.get()] = {}; + } + // Step 2. 
For each buffer, update the domain according to the provided and required regions + for (const auto& kv : provided_regions) { + const BufferNode* buffer = kv.first; + const std::vector& many_provided_regions = kv.second; + // Calculate `provided_region` and `required_region` + auto it = required_regions.find(buffer); + if (it == required_regions.end() || it->second.empty()) { + continue; + } + NDIntSet required_region = support::NDIntSetUnion(it->second); + NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions); + ICHECK_EQ(provided_region.size(), buffer->shape.size()); + ICHECK_EQ(required_region.size(), buffer->shape.size()); + // For each dimension, update the iteration domain + int ndim = buffer->shape.size(); + for (int i = 0; i < ndim; ++i) { + arith::IntSet provided = provided_region[i]; + arith::IntSet required = required_region[i]; + required = arith::Intersect( + {std::move(required), arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i])}); + UpdateBlockVarDomain(provided, required, &iter_doms, analyzer); + } + } + // Union the iter var domains, put them in the same order of block vars, and return + std::vector result; + result.reserve(n_iters); + for (const IterVar& iter_var : iter_vars) { + const std::vector& doms = iter_doms.at(iter_var->var.get()); + arith::IntSet dom = arith::IntSet::FromRange(iter_var->dom); + if (!doms.empty()) { + dom = arith::Intersect({std::move(dom), arith::Union(doms)}); + } + PrimExpr min = analyzer->Simplify(dom.min()); + PrimExpr extent = analyzer->Simplify(dom.max() - min + 1); + result.push_back(Range::FromMinExtent(min, extent)); + } + return result; +} + +/*! + * \brief Calculate the provided region of the given block by one single of its execution instance, + * as well as the required buffer regions relaxed to the given loop + * \tparam is_compute_at Indicates if the operation is compute-at or reverse-compute-at + * \param block The given block that provides buffer regions + * \param loop_sref The given loop under which the block is going to be moved to + * \param block2realize Maps a block to its corresponding BlockRealize + * \param producer_srefs The producers of the given block + * \param consumer_srefs The consumers of the given block + * \param provided_regions The calculated regions provided by the block + * \param required_regions The calculated regions required by its consumers (in compute-at) or + * producers (in reverse-compute-at) + */ +template +void CalculateProvidedRequiredRegions( + const BlockNode* block, const StmtSRef& loop_sref, + std::unordered_map block2realize, + Array producer_srefs, Array consumer_srefs, + std::unordered_map>* provided_regions, + std::unordered_map>* required_regions) { + // Step 1. Calculate the region provided by a single execution instance of `block` + const Array& provided_buffers = is_compute_at ? block->writes : block->reads; + provided_regions->reserve(provided_buffers.size()); + required_regions->reserve(provided_buffers.size()); + for (const BufferRegion& provided_buffer_region : provided_buffers) { + const BufferNode* buffer = provided_buffer_region->buffer.get(); + const Array& region = provided_buffer_region->region; + (*provided_regions)[buffer].push_back(support::NDIntSetFromRegion(region)); + (*required_regions)[buffer].clear(); + } + // Step 2. Calculate the region required by dependent blocks under `loop` + for (const StmtSRef& required_block_sref : is_compute_at ? 
consumer_srefs : producer_srefs) { + const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block, required_block_sref); + ICHECK(block2realize.count(required_block)); + RelaxBufferRegions( + /*binding=*/GetBindings(GetRef(block2realize.at(required_block))), + /*buffer_regions=*/is_compute_at ? required_block->reads : required_block->writes, + /*relax_path_low_inclusive=*/GetRef(required_block_sref->parent), + /*relax_path_high_exclusive=*/loop_sref, /*relaxed=*/required_regions); + } +} + +/******** Main Implementation ********/ + +template +void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref, + const StmtSRef& loop_sref, bool preserve_unit_loops) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + // Step 1. Bunch of checks + // Check condition 1) and 2): stage pipeline and subtree compact dataflow + StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/true); + Block scope_root = GetRef(scope_root_sref->StmtAs()); + BlockScope scope = self->GetBlockScope(scope_root_sref); + Array producer_srefs = GetProducers(block_sref, scope); + Array consumer_srefs = GetConsumers(block_sref, scope); + arith::Analyzer analyzer; + // Check condition 3): `block` and `loop` are under the same scope, + // and `loop` is not the ancestor of `block` + NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref, + &analyzer); + // Check condition 4): `block` is not an output block + if (is_compute_at) { + CheckNotOutputBlock(self, block_sref, scope_root_sref); + } + // Step 2. Plan for the removal of `block` + ScopeReconstructor reconstructor(scope_root, GetRef(block), GetRef(loop)); + LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_); + // Step 3. Find the insertion point under `loop` + // Check condition 5): all the required block are under the given loop + std::unordered_map block2realize; + block2realize.reserve(self->block_info.size()); + int insert_position = FindInsertionPoint( + /*self=*/self, + /*subtrees=*/AsArray(loop->body), + /*producer_srefs=*/producer_srefs, + /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize); + // Step 4. Calculate the region provided by a single execution instance of `block`, + // as well as the region required by dependent blocks under `loop`. + // Here is the definition of `provide` and `require`: + // - In compute-at, `provide` means `produce`, and `require` means `consume` + // - In reverse-compute-at, `provide` means `consume`, and `require` means `produce` + std::unordered_map> provided_regions; + std::unordered_map> required_regions; + CalculateProvidedRequiredRegions( + /*block=*/block, /*loop_sref=*/loop_sref, /*block2realize=*/std::move(block2realize), + /*producer_srefs=*/std::move(producer_srefs), + /*consumer_srefs=*/std::move(consumer_srefs), + /*provided_regions=*/&provided_regions, /*required_regions=*/&required_regions); + // Step 5. Calculate the iteration domain for each block var + std::vector iter_doms = + CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars, + /*provided_regions=*/std::move(provided_regions), + /*required_regions=*/std::move(required_regions), + /*analyzer=*/&analyzer); + // Step 6. 
Create the new scope according to the iteration domain + reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms), + /*preserve_unit_loops=*/preserve_unit_loops); + Block new_scope_root = Downcast(reconstructor(scope_root)); + // Step 7. Do the actual replacement + self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); + // Step 8. Update the cached flags + BlockInfo& block_info = self->block_info[block_sref]; + block_info.affine_binding = IsAffineBinding( + /*realize=*/reconstructor.new_block_realize_, + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), + /*analyzer=*/&analyzer); +} + +void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, + bool preserve_unit_loops) { + ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops); +} + +void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, + bool preserve_unit_loops) { + ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops); +} + +/******** InstructionKind Registration ********/ + +struct ComputeAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ComputeAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv, + Bool preserve_unit_loops) { + return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool()); + } + + static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, + Bool preserve_unit_loops) { + PythonAPICall py("compute_at"); + py.Input("block", block_rv); + py.Input("loop", loop_rv); + py.Input("preserve_unit_loops", preserve_unit_loops.operator bool()); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct ReverseComputeAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReverseComputeAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv, + Bool preserve_unit_loops) { + return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool()); + } + + static String UnpackedAsPython(Array outputs, String block_rv, String loop_rv, + Bool preserve_unit_loops) { + PythonAPICall py("reverse_compute_at"); + py.Input("block", block_rv); + py.Input("loop", loop_rv); + py.Input("preserve_unit_loops", preserve_unit_loops.operator bool()); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ComputeAtTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeAtTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 9c88cc1e787a..c2de78863d79 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -97,31 +97,6 @@ class BodyAnalysisError : public ScheduleError { Block block_; }; -class OnlyLeafError : public ScheduleError { - public: - explicit OnlyLeafError(IRModule mod, Block leaf_block, StmtSRef scope_root_sref) - : mod_(mod), 
leaf_block_(std::move(leaf_block)), scope_root_(nullptr) { - const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); - this->scope_root_ = GetRef(scope_root); - } - - String FastErrorString() const final { - return "ScheduleError: Cannot remove the only leaf in the scope"; - } - - String DetailRenderTemplate() const final { - return "Block {0} is the only leaf in the scope {1}, which cannot be removed; Otherwise the " - "scope will be empty."; - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } - - IRModule mod_; - Block leaf_block_; - Block scope_root_; -}; - class NonSingleProducerError : public ScheduleError { public: explicit NonSingleProducerError(IRModule mod, Block block) @@ -188,76 +163,6 @@ class OpaqueAccessError : public ScheduleError { Block scope_root_; }; -/*! - * \brief Construct a new AST, with a specific sref tree leaf removed. - * The leaf's ancestors who have only a single child will be removed too. - * \param leaf_block_sref The block/loop sref to the sref tree leaf to be removed - * \param src_stmt The root of the subtree where the replacement begins - * \param tgt_stmt The root of the subtree after the replacement - * \return A boolean indicating if the leaf can be removed successfully - * \note Removal is not conducted beyond scope-level. - * - * An example of the removal plan, say we are removing the leaf block "B" from the AST. - * - * \code - * with block([], "scope_root"): - * ... - * with block([128, 128], "B") as [vi, vj]: - * B[vi, vj] = A[vi, vj] + 1.0 - * with block([128, 128], "C") as [vi, vj]: - * C[vi, vj] = B[vi, vj] * 2.0 - * \endcode - * - * Ths method does not mutate the AST, instead it returns the a `(src_stmt, tgt_stmt)` pair as a - * plan to substitute certain pieces of the IR. - * - * In our example, it returns block "scope_root" as `src_stmt`, and the result `tgt_stmt` is: - * - * \code - * with block([], "scope_root"): - * ... - * with block([128, 128], "C") as [vi, vj]: - * C[vi, vj] = B[vi, vj] * 2.0 - * \endcode - */ -bool LeafBlockRemovalPlan(const StmtSRef& leaf_block_sref, Stmt* src_stmt, Stmt* tgt_stmt) { - // Go upwards until find an ancestor with more than one child - const StmtNode* last_stmt = leaf_block_sref->stmt; - StmtSRefNode* sref = leaf_block_sref->parent; - for (;; last_stmt = sref->stmt, sref = sref->parent) { - if (const auto* loop = sref->StmtAs()) { - if (const auto* seq = loop->body.as()) { - if (seq->size() > 1) { - break; - } - } - } else { - // Removal is not done beyond scope-level. - // When encountering a block, i.e. the scope root, we simply stop - break; - } - } - if (const auto* block = sref->StmtAs()) { - if (const auto* seq = block->body.as()) { - ObjectPtr n = make_object(*block); - n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); - *src_stmt = GetRef(block); - *tgt_stmt = Stmt(std::move(n)); - return true; - } - } - if (const auto* loop = sref->StmtAs()) { - if (const auto* seq = loop->body.as()) { - ObjectPtr n = make_object(*loop); - n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); - *src_stmt = GetRef(loop); - *tgt_stmt = Stmt(std::move(n)); - return true; - } - } - return false; -} - /*! 
* \brief The base class of the inliner, which handles: * 1) Substitute a subtree with the specific block being inlined @@ -622,8 +527,9 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = - GetScopeRoot(self, producer_block_sref, /*require_stage_pipeline=*/true); + StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref, // + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); // Step 2. Check completeness CheckCompleteBlock(self, producer_block_sref, scope_root_sref); // Step 3. Analyze the block body @@ -632,9 +538,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { throw BodyAnalysisError(false, self->mod, producer_block); } // Step 4. Create a plan that removes the leaf block to be inlined - if (!LeafBlockRemovalPlan(producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt)) { - throw OnlyLeafError(self->mod, producer_block, scope_root_sref); - } + LeafBlockRemovalPlan(self, producer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); // Step 5. Create an AST where the leaf `producer_block_sref` points to is removed, // and update other blocks who read from the removed block Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); @@ -650,8 +554,9 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre Block consumer_block = GetRef(_consumer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = - GetScopeRoot(self, consumer_block_sref, /*require_stage_pipeline=*/true); + StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); // Step 3. Check if the consumer has a single complete producer @@ -662,9 +567,7 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre throw BodyAnalysisError(true, self->mod, consumer_block); } // Step 5. Create a plan that removes the leaf block to be inlined - if (!LeafBlockRemovalPlan(consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt)) { - throw OnlyLeafError(self->mod, consumer_block, scope_root_sref); - } + LeafBlockRemovalPlan(self, consumer_block_sref, &inliner.src_stmt, &inliner.tgt_stmt); // Step 6. 
Create an AST where the leaf `consumer_block_sref` points to is removed, // and update other blocks who read from the removed block Stmt tgt_stmt = inliner(GetRef(scope_root_sref->stmt)); @@ -675,7 +578,7 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); } -/******** Instruction Registration ********/ +/******** InstructionKind Registration ********/ struct ComputeInlineTraits : public UnpackedInstTraits { static constexpr const char* kName = "ComputeInline"; diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index a6056d607042..008d47792f69 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -27,7 +27,7 @@ class WrongBlockIterTypeError : public ScheduleError { : mod_(std::move(mod)), loop_var_(std::move(loop_var)), block_(std::move(block)) { op_str_ = for_kind == ForKind::kParallel ? "parallel" - : for_kind == ForKind::kVectorized ? "vectorize" : "bind"; + : (for_kind == ForKind::kVectorized ? "vectorize" : "bind"); } String FastErrorString() const final { std::ostringstream os; @@ -151,7 +151,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref * parallelized/vectorized/bound. */ // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow. - CheckSRefSubtreeCompactDataFlow(self, loop_sref); + GetScopeRoot(self, loop_sref, // + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/true); // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each // underlying block. @@ -187,7 +189,7 @@ void Unroll(ScheduleState self, const StmtSRef& loop_sref) { self->Replace(loop_sref, For(new_loop), {}); } -/******** Instruction Registration ********/ +/******** InstructionKind Registration ********/ struct ParallelTraits : public UnpackedInstTraits { static constexpr const char* kName = "Parallel"; @@ -251,7 +253,7 @@ struct BindTraits : public UnpackedInstTraits { static String UnpackedAsPython(Array outputs, String loop_rv, String thread) { PythonAPICall py("bind"); py.Input("loop", loop_rv); - py.Input("thread", thread); + py.Input("thread_axis", thread); return py.Str(); } diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index a8d9c5a69dc9..8b32a9c14f58 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -55,7 +55,7 @@ Array GetLoops(const StmtSRef& block_sref) { return {result.rbegin(), result.rend()}; } -/******** Instruction Registration ********/ +/******** InstructionKind Registration ********/ struct GetBlockTraits : public UnpackedInstTraits { static constexpr const char* kName = "GetBlock"; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 7c2b61344427..95c92aa0a322 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -687,7 +687,7 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { self->Replace(GetRef(top), new_loop, {}); } -/******** Instruction Registration ********/ +/******** InstructionKind Registration ********/ struct SplitTraits : public UnpackedInstTraits { static constexpr const char* kName = "Split"; diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 
af77e51e4d83..677b64311855 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -427,7 +427,7 @@ class BaseBlockCreator { CreateReadWriteRegions(); String new_block_name = old_block_realize_->block->name_hint; - PrimExpr predicate = Bool(true); + PrimExpr predicate = const_true(); if (is_rf_block_) { new_block_name = new_block_name + "_rf"; predicate = old_block_realize_->predicate; @@ -860,7 +860,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref); const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get()); const Block& block = block_realize->block; - StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + StmtSRef scope_root = GetScopeRoot(self, block_sref, // + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); CheckReductionBlock(self, block_sref, scope_root); const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref); if (rf_loop->kind != ForKind::kSerial) { @@ -954,7 +956,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax return new_block_srefs[0]; } -/******** Instruction Registration ********/ +/******** InstructionKind Registration ********/ struct RFactorTraits : public UnpackedInstTraits { static constexpr const char* kName = "RFactor"; diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index ac40d27c4bf3..8843ac613179 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -19,7 +19,6 @@ #include -#include "../primitive.h" #include "../utils.h" namespace tvm { @@ -51,6 +50,8 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +/******** InstructionKind Registration ********/ + struct SampleCategoricalTraits : public UnpackedInstTraits { static constexpr const char* kName = "SampleCategorical"; static constexpr bool kIsPure = true; @@ -79,7 +80,8 @@ struct SampleCategoricalTraits : public UnpackedInstTraits; + template + friend struct ::tvm::tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index fd30b02fc9dd..4262a099b59d 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -146,6 +146,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); /******** (FFI) Compute location ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") + .set_body_method(&ScheduleNode::ComputeAt); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt") + .set_body_method(&ScheduleNode::ReverseComputeAt); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 799806bef7b5..4604add3bdb4 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -35,15 +35,22 @@ using SMap = std::unordered_map; * \param dom_high_exclusive The highest node in the sref tree path * \return An n-dimensional integer set */ -Array AnalyzeRegionUpperBound(const BufferRegion& region, - const StmtSRef& dom_low_inclusive, - const StmtSRef& 
dom_high_exclusive) { - return arith::EvalSet( - region->region, - AsIntSet(LoopDomainOfSRefTreePath( - /*low_inclusive=*/dom_low_inclusive, - /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())))); +Array AnalyzeRegionUpperBound(const BufferRegion& region, // + const PrimExpr& predicate, // + const StmtSRef& dom_low_inclusive, // + const StmtSRef& dom_high_exclusive, // + arith::Analyzer* analyzer) { + Map var_dom = LoopDomainOfSRefTreePath( + /*low_inclusive=*/dom_low_inclusive, + /*high_exclusive=*/dom_high_exclusive, + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); + if (Optional> result = EstimateRegionLowerBound( + /*region=*/region->region, + /*var_dom=*/var_dom, + /*predicate=*/predicate, /*analyzer=*/analyzer)) { + return result.value(); + } + return arith::EvalSet(region->region, AsIntSet(var_dom)); } /*! @@ -56,19 +63,19 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, * \param analyzer The analyzer * \return An n-dimensional integer set */ -Array AnalyzeRegionLowerBound(const BlockRealize& realize, - const BufferRegion& region, - const StmtSRef& dom_low_inclusive, - const StmtSRef& dom_high_exclusive, +Array AnalyzeRegionLowerBound(const BufferRegion& region, // + const PrimExpr& predicate, // + const StmtSRef& dom_low_inclusive, // + const StmtSRef& dom_high_exclusive, // arith::Analyzer* analyzer) { + Map var_dom = LoopDomainOfSRefTreePath( + /*low_inclusive=*/dom_low_inclusive, + /*high_exclusive=*/dom_high_exclusive, + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); if (Optional> result = EstimateRegionLowerBound( /*region=*/region->region, - /*var_dom=*/ - LoopDomainOfSRefTreePath( - /*low_inclusive=*/dom_low_inclusive, - /*high_exclusive=*/dom_high_exclusive, - /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())), - /*predicate=*/realize->predicate, /*analyzer=*/analyzer)) { + /*var_dom=*/var_dom, + /*predicate=*/predicate, /*analyzer=*/analyzer)) { return result.value(); } return Array(region->buffer->shape.size(), arith::IntSet::Nothing()); @@ -90,16 +97,16 @@ bool ProducerCoversConsumer(const Array& buffer_shape, ICHECK_EQ(produced_region.size(), consumed_region.size()); int ndim = produced_region.size(); for (int i = 0; i < ndim; ++i) { - Range buffer_size = Range::FromMinExtent(0, buffer_shape[i]); + arith::IntSet buffer_size = arith::IntSet::FromMinExtent(0, buffer_shape[i]); if (produced_region[i].IsNothing()) { return false; } - Range produced = produced_region[i].CoverRange(buffer_size); - Range consumed = consumed_region[i].CoverRange(buffer_size); - PrimExpr produced_min = produced->min; - PrimExpr produced_max = produced->min + produced->extent; - PrimExpr consumed_min = consumed->min; - PrimExpr consumed_max = consumed->min + consumed->extent; + arith::IntSet produced = arith::Intersect({produced_region[i], buffer_size}); + arith::IntSet consumed = arith::Intersect({consumed_region[i], buffer_size}); + PrimExpr produced_min = analyzer->Simplify(produced.min()); + PrimExpr produced_max = analyzer->Simplify(produced.max() - produced_min + 1); + PrimExpr consumed_min = analyzer->Simplify(consumed.min()); + PrimExpr consumed_max = analyzer->Simplify(consumed.max() - consumed_min + 1); if (!analyzer->CanProve((produced_min <= consumed_min) && (consumed_max <= produced_max))) { return false; } @@ -276,6 +283,8 @@ class StateCreator : private StmtVisitor { for (const auto& kv : info.scope->dst2deps) { const 
StmtSRef& consumer_block_sref = kv.first; const Array& deps = kv.second; + const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref); + const BlockRealize& consumer_realize = block2realize_.at(consumer_block); bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; // Step 2.1. Extract the path to the scope root std::unordered_map> lca_loc; @@ -334,11 +343,12 @@ class StateCreator : private StmtVisitor { // and to make sure region cover property must be satisfied once the flag is on // Therefore, we use lower-bound analysis for producers and upper-bound analysis for // consumer, and require that the produced region can cover the consumed region - touched_region.push_back(AnalyzeRegionLowerBound(/*realize=*/producer_realize, - /*region=*/region, - /*dom_low_inclusive=*/parent_sref, - /*dom_high_exclusive=*/lca, - /*analyzer=*/&analyzer_)); + touched_region.push_back(AnalyzeRegionLowerBound( + /*region=*/region, + /*predicate=*/producer_realize->predicate, + /*dom_low_inclusive=*/parent_sref, + /*dom_high_exclusive=*/lca, + /*analyzer=*/&analyzer_)); } } } @@ -353,8 +363,10 @@ class StateCreator : private StmtVisitor { arith::UnionRegionLowerBound({touched_region.begin(), touched_region.end()}); Array consumed_region = AnalyzeRegionUpperBound( /*region=*/region, + /*predicate=*/consumer_realize->predicate, /*dom_low_inclusive=*/parent_sref, - /*dom_high_exclusive=*/lca); + /*dom_high_exclusive=*/lca, + /*analyzer=*/&analyzer_); if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, &analyzer_)) { region_cover = false; @@ -920,8 +932,8 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ // Before step `i`: // 1) `child_sref` is `src_sref` going up by `i` steps // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement - // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are - // correct 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet + // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are correct + // 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet // 5) `tgt_stmt` is of type Loop, Block or BlockRealize // // During step `i`: @@ -1029,24 +1041,6 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl Bool(info.scope->stage_pipeline)}; } -TVM_DLL void ScheduleStateNode::UpdateAffineFlag(const StmtSRef& scope_sref) { - auto it = this->block_info.find(scope_sref); - ICHECK(it != this->block_info.end()) << "Cannot find the block info of the given block."; - BlockInfo& info = it->second; - - bool is_root_block = scope_sref->parent == nullptr; - if (is_root_block) { - info.affine_binding = true; - } else { - BlockRealize realize = GetBlockRealize(GetRef(this), scope_sref); - arith::Analyzer analyzer; - StmtSRef parent_sref = GetRef(scope_sref->parent); - info.affine_binding = IsAffineBinding(/*realize=*/realize, - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), - /*analyzer=*/&analyzer); - } -} - /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index f429a917858b..6f679598c9d1 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -192,6 +192,28 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int 
write_buffer /******** Schedule: Compute location ********/ +void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) { + ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops); + + static const InstructionKind& kind = InstructionKind::Get("ComputeAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, loop_rv}, + /*attrs=*/{Integer(preserve_unit_loops)}, + /*outputs=*/{})); +} + +void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) { + ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops); + + static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, loop_rv}, + /*attrs=*/{Integer(preserve_unit_loops)}, + /*outputs=*/{})); +} + void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) { ConcreteScheduleNode::ComputeInline(block_rv); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index a6b5251a96a3..fb89783b6036 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -76,6 +76,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; /******** Schedule: Compute location ********/ + void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; + void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, + bool preserve_unit_loops) final; void ComputeInline(const BlockRV& block_rv) final; void ReverseComputeInline(const BlockRV& block_rv) final; /******** Schedule: Reduction ********/ diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index da376fdde90f..ffb6b2d52628 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -17,14 +17,13 @@ * under the License. 
*/ -#include "./transform.h" - #include "./utils.h" namespace tvm { namespace tir { /******** Annotation ********/ + Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { Map annotations = block->annotations; annotations.Set(attr_key, attr_value); @@ -71,5 +70,71 @@ Array ReplaceBuffer(Array match_buffers, c return match_buffers; } +/******** Block Removal ********/ + +void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, + Stmt* src_stmt, Stmt* tgt_stmt) { + class OnlyLeafError : public ScheduleError { + public: + explicit OnlyLeafError(IRModule mod, Block leaf_block, Block scope_root) + : mod_(mod), leaf_block_(leaf_block), scope_root_(scope_root) {} + + String FastErrorString() const final { + return "ScheduleError: Cannot remove the only leaf in the scope"; + } + + String DetailRenderTemplate() const final { + return "Block {0} is the only leaf in the scope {1}, which cannot be removed; Otherwise the " + "scope will be empty."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {leaf_block_, scope_root_}; } + + IRModule mod_; + Block leaf_block_; + Block scope_root_; + }; + + // Go upwards until find an ancestor with more than one child + const StmtNode* last_stmt = leaf_block_sref->stmt; + StmtSRefNode* sref = leaf_block_sref->parent; + for (;; last_stmt = sref->stmt, sref = sref->parent) { + if (const auto* loop = sref->StmtAs()) { + if (const auto* seq = loop->body.as()) { + if (seq->size() > 1) { + break; + } + } + } else { + // Removal is not done beyond scope-level. + // When encountering a block, i.e. the scope root, we simply stop + break; + } + } + if (const auto* block = sref->StmtAs()) { + if (const auto* seq = block->body.as()) { + ObjectPtr n = make_object(*block); + n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + *src_stmt = GetRef(block); + *tgt_stmt = Stmt(std::move(n)); + return; + } + } + if (const auto* loop = sref->StmtAs()) { + if (const auto* seq = loop->body.as()) { + ObjectPtr n = make_object(*loop); + n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + *src_stmt = GetRef(loop); + *tgt_stmt = Stmt(std::move(n)); + return; + } + } + ICHECK(sref != nullptr && sref->stmt != nullptr); + const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block, leaf_block_sref); + const auto* scope_block = TVM_SREF_TO_BLOCK(scope_block, sref); + throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 85cce9da216e..3932c4bdbd3d 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -64,6 +64,46 @@ Array ReplaceBuffer(Array regions, const Buffer& sou */ Array ReplaceBuffer(Array match_buffers, const Buffer& source, const Buffer& target); + +/******** Block Removal ********/ + +/*! + * \brief Construct a new AST, with a specific sref tree leaf removed. + * The leaf's ancestors who have only a single child will be removed too. + * \param leaf_block_sref The block/loop sref to the sref tree leaf to be removed + * \param src_stmt The root of the subtree where the replacement begins + * \param tgt_stmt The root of the subtree after the replacement + * \return A boolean indicating if the leaf can be removed successfully + * \note Read before use: + * 1) Removal is not conducted beyond scope-level. + * 2) This method only works properly when the scope root is a stage pipeline. 
+ * + * An example of the removal plan, say we are removing the leaf block "B" from the AST. + * + * \code + * with block([], "scope_root"): + * ... + * with block([128, 128], "B") as [vi, vj]: + * B[vi, vj] = A[vi, vj] + 1.0 + * with block([128, 128], "C") as [vi, vj]: + * C[vi, vj] = B[vi, vj] * 2.0 + * \endcode + * + * Ths method does not mutate the AST, instead it returns the a `(src_stmt, tgt_stmt)` pair as a + * plan to substitute certain pieces of the IR. + * + * In our example, it returns block "scope_root" as `src_stmt`, and the result `tgt_stmt` is: + * + * \code + * with block([], "scope_root"): + * ... + * with block([128, 128], "C") as [vi, vj]: + * C[vi, vj] = B[vi, vj] * 2.0 + * \endcode + */ +void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, + Stmt* src_stmt, Stmt* tgt_stmt); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index c2f430181664..a63a9f079617 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -34,10 +34,12 @@ #include #include +#include "../../arith/pattern_match.h" #include "../../node/attr_registry.h" #include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" #include "../../support/array.h" +#include "../../support/nd_int_set.h" #include "./analysis.h" #include "./error.h" #include "./instruction_traits.h" @@ -163,6 +165,21 @@ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { return SeqStmt::Flatten(new_stmts); } +/*! + * \brief Convert a Stmt to an Array. + * \param stmt The Stmt to be converted to + * \return If the Stmt is SeqStmt, then returns the sequence; + * Otherwise, returns a single-element Array with the Stmt inside. + */ +inline Array AsArray(const Stmt& stmt) { + if (const auto* seq_stmt = stmt.as()) { + return seq_stmt->seq; + } + return {stmt}; +} + +/******** IterVar ********/ + /*! 
* \brief Create a new IterVar for the input For loop, with specified name and type * \param loop The loop to be created from diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 961ea1721fa1..a1f488f386b3 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -30,6 +30,7 @@ #include #include "../../support/arena.h" +#include "../../support/nd_int_set.h" #include "../../support/utils.h" #include "../schedule/utils.h" #include "ir_utils.h" @@ -37,62 +38,7 @@ namespace tvm { namespace tir { -using NDIntSet = std::vector; - -arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { - return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); -} - -NDIntSet NDIntSetFromRegion(const Region& region) { - NDIntSet result; - result.reserve(region.size()); - for (const Range& range : region) { - result.push_back(arith::IntSet::FromRange(range)); - } - return result; -} - -NDIntSet NDIntSetFromShape(const Array& shape) { - PrimExpr zero = Integer(0); - NDIntSet result; - result.reserve(shape.size()); - for (const PrimExpr& extent : shape) { - result.push_back(IntSetFromMinExtent(zero, extent)); - } - return result; -} - -NDIntSet NDIntSetFromPoint(const Array& indices) { - NDIntSet result; - result.reserve(indices.size()); - for (const PrimExpr& index : indices) { - result.push_back(arith::IntSet::SinglePoint(index)); - } - return result; -} - -void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) { - ICHECK_EQ(lhs->size(), rhs.size()); - int ndim = rhs.size(); - for (int i = 0; i < ndim; ++i) { - arith::IntSet& int_set = lhs->at(i); - int_set = arith::Union({int_set, rhs.at(i)}); - } -} - -NDIntSet NDIntSetEmpty(int ndim) { - return std::vector(ndim, arith::IntSet::Nothing()); -} - -NDIntSet EvalNDIntSet(const NDIntSet& nd_int_set, - const std::unordered_map& dom_map) { - NDIntSet ret; - ret.reserve(nd_int_set.size()); - for (const arith::IntSet& s : nd_int_set) { - ret.push_back(arith::EvalSet(s, dom_map)); - } - return ret; -} +using support::NDIntSet; /*! * \brief return the region collected by NDIntSet. return the oroginal buffer shape if the @@ -164,7 +110,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { // The iter_dom_map is updated by post DFS order. // If the union point is under the for node, the loop var will not be relaxed. // If the union point is outer of the for loop, the loop var should be relaxed. 
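    // For example, an access A[i] collected under `for i in [0, 4)` stays expressed in terms of
    // `i` when the union is taken inside the loop, but is relaxed to the full interval [0, 4)
    // when the union is taken at a scope enclosing the loop.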
- iter_dom_map_on_post_order_[op->loop_var.get()] = IntSetFromMinExtent(op->min, op->extent); + iter_dom_map_on_post_order_[op->loop_var.get()] = + arith::IntSet::FromMinExtent(op->min, op->extent); } void VisitStmt_(const BlockNode* op) final { @@ -205,10 +152,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor { for (const ForNode* loop : ancestor_loops_) { const VarNode* loop_var = loop->loop_var.get(); if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer.scope()))) { - dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); + dom_map[loop_var] = arith::IntSet::FromMinExtent(loop->min, loop->extent); } } - NDIntSet int_set = EvalNDIntSet(nd_int_set, dom_map); + NDIntSet int_set = support::NDIntSetEval(nd_int_set, dom_map); buffer_access_region_[buffer] = NarrowBufferRegionFromNDIntSet(int_set, buffer->shape); } } @@ -221,7 +168,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { if (it != buffer_var_in_scope_.end()) { const Buffer& buffer = it->second; const BufferAccessInfo* info = - arena_.make(buffer, NDIntSetFromRegion(buffer_region->region)); + arena_.make(buffer, support::NDIntSetFromRegion(buffer_region->region)); buffer_access_stack_.push(info); } } @@ -246,10 +193,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor { while (buffer_access_stack_.size() > stack_top) { const BufferAccessInfo* info = buffer_access_stack_.top(); buffer_access_stack_.pop(); - NDIntSet nd_int_set = EvalNDIntSet(info->accessed_region, iter_dom_map_on_post_order_); + NDIntSet nd_int_set = + support::NDIntSetEval(info->accessed_region, iter_dom_map_on_post_order_); auto it = accesses.find(info->buffer); if (it != accesses.end()) { - NDIntSetUnionWith(&it->second, nd_int_set); + support::NDIntSetUnionWith(&it->second, nd_int_set); } else { accesses[info->buffer] = nd_int_set; } diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py new file mode 100644 index 000000000000..a4f8b2e77078 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -0,0 +1,832 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
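+#
+# Usage sketch (illustrative, not one of the tests below; the `debug_mask` argument is an
+# assumption): build a schedule from a TVMScript workload, move producer block "B" under the
+# first loop of its consumer "C", then compare against the expected workload and check that the
+# recorded trace replays correctly:
+#
+#   sch = tir.Schedule(two_elementwise, debug_mask="all")
+#   block = sch.get_block("B")
+#   loop, _ = sch.get_loops(sch.get_block("C"))
+#   sch.compute_at(block, loop, preserve_unit_loops=True)
+#   tvm.ir.assert_structural_equal(two_elementwise_after_compute_at, sch.mod["main"])
+#   verify_trace_roundtrip(sch=sch, mod=two_elementwise)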
+# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +@tvm.script.tir +def two_elementwise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def two_elementwise_after_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + for i in range(0, 128): + for ax0, ax1 in tir.grid(1, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i + ax0) + tir.bind(vj, ax1) + B[vi, vj] = A[vi, vj] * 2.0 + for j in range(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def blockized_1(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128], "float32") + B = tir.alloc_buffer([128, 128], "float32") + C = tir.match_buffer(c, [128, 128], "float32") + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([8, 8], "C_outer") as [vi_o, vj_o]: + tir.reads([B[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16, + ]]) + tir.writes([C[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16 + ]]) + for i_i, j_i in tir.grid(16, 16): + with tir.block([128, 128], "C_inner") as [vi, vj]: + tir.bind(vi, vi_o * 16 + i_i) + tir.bind(vj, vj_o * 16 + j_i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def blockized_after_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128], "float32") + B = tir.alloc_buffer([128, 128], "float32") + C = tir.match_buffer(c, [128, 128], "float32") + for i0_0, i1_0 in tir.grid(8, 8): + for ax0, ax1 in tir.grid(16, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0_0 * 16 + ax0) + tir.bind(vj, i1_0 * 16 + ax1) + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([8, 8], "C_outer") as [vi_o, vj_o]: + tir.bind(vi_o, i0_0) + tir.bind(vj_o, i1_0) + tir.reads([B[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16, + ]]) + tir.writes([C[ + vi_o * 16 : vi_o * 16 + 16, + vj_o * 16 : vj_o * 16 + 16 + ]]) + for i0_1, i1_1 in tir.grid(16, 16): + with tir.block([128, 128], "C_inner") as [vi, vj]: + tir.bind(vi, vi_o * 16 + i0_1) + tir.bind(vj, vj_o * 16 + i1_1) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def blockized_2(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128], "float32") + B = tir.alloc_buffer([128, 128], "float32") + C = tir.match_buffer(c, [128, 128], "float32") + for i_o, j_o in tir.grid(8, 8): + with tir.block([8, 8], "B_outer") as [vio, vjo]: + tir.bind(vio, i_o) + tir.bind(vjo, j_o) + tir.reads([A[ + vio * 16 : vio * 16 + 16, + vjo * 16 : vjo * 16 + 16, + ]]) + tir.writes([B[ + vio * 16 : vio * 16 + 16, + vjo * 16 : vjo * 16 + 16 + ]]) + for i_i, j_i in tir.grid(16, 16): + with tir.block([128, 128], "B_inner") as [vi, vj]: + tir.bind(vi, vio * 16 + i_i) + tir.bind(vj, vjo * 16 + j_i) + B[vi, vj] = A[vi, vj] * 2.0 + 
for i_o, j_o, i_i, j_i in tir.grid(4, 4, 32, 32): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i_o * 32 + i_i) + tir.bind(vj, j_o * 32 + j_i) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def blockized_2_after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128], "float32") + B = tir.alloc_buffer([128, 128], "float32") + C = tir.match_buffer(c, [128, 128], "float32") + for i_o, j_o in tir.grid(8, 8): + with tir.block([8, 8], "B_outer") as [vio, vjo]: + tir.bind(vio, i_o) + tir.bind(vjo, j_o) + tir.reads([A[ + vio * 16 : vio * 16 + 16, + vjo * 16 : vjo * 16 + 16, + ]]) + tir.writes([B[ + vio * 16 : vio * 16 + 16, + vjo * 16 : vjo * 16 + 16 + ]]) + for i_i, j_i in tir.grid(16, 16): + with tir.block([128, 128], "B_inner") as [vi, vj]: + tir.bind(vi, vio * 16 + i_i) + tir.bind(vj, vjo * 16 + j_i) + B[vi, vj] = A[vi, vj] * 2.0 + for ax0, ax1 in tir.grid(16, 16): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i_o * 16 + ax0) + tir.bind(vj, j_o * 16 + ax1) + tir.reads([B[vi, vj]]) + tir.writes([C[vi, vj]]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def blockized_2_after_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128], "float32") + B = tir.alloc_buffer([128, 128], "float32") + C = tir.match_buffer(c, [128, 128], "float32") + for i_o, j_o in tir.grid(4, 4): + for ax0, ax1 in tir.grid(2, 2): + with tir.block([8, 8], "blockized_B") as [vio, vjo]: + tir.bind(vio, i_o * 2 + ax0) + tir.bind(vjo, j_o * 2 + ax1) + tir.reads([A[ + vio * 16 : vio * 16 + 16, + vjo * 16 : vjo * 16 + 16, + ]]) + tir.writes([B[ + vio * 16 : vio * 16 + 16, + vjo * 16 : vjo * 16 + 16, + ]]) + for i_i, j_i in tir.grid(16, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, vio * 16 + i_i) + tir.bind(vj, vjo * 16 + j_i) + B[vi, vj] = A[vi, vj] * 2.0 + for i_i, j_i in tir.grid(32, 32): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i_o * 32 + i_i) + tir.bind(vj, j_o * 32 + j_i) + C[vi, vj] = B[vi, vj] + 1.0 + +@tvm.script.tir +def cuda_matmul_0(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable + A = tir.match_buffer(a, [2048, 2048], "float32") + B = tir.match_buffer(b, [2048, 2048], "float32") + C = tir.match_buffer(c, [2048, 2048], "float32") + A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + with tir.block([2048, 2048], "A_shared") as [v0, v1]: + A_shared[v0, v1] = A[v0, v1] + with tir.block([2048, 2048], "B_shared") as [v0, v1]: + B_shared[v0, v1] = B[v0, v1] + with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + A_shared_local[v0, v1] = A_shared[v0, v1] + with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + B_shared_local[v0, v1] = B_shared[v0, v1] + with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + with tir.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in tir.thread_binding(0, 2, thread = "vthread.y"): + for vx in tir.thread_binding(0, 2, thread = "vthread.x"): + for ty in 
tir.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): + for i, j in tir.grid(4, 4): + with tir.block([2048, 2048], "C_local") as [v0_4, v1_4]: + tir.bind(v0_4, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(v1_4, bx * 64 + vx * 32 + tx * 4 + j) + C[v0_4, v1_4] = C_local[v0_4, v1_4] + + +@tvm.script.tir +def cuda_matmul_0_after_compute_at(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable + A = tir.match_buffer(a, [2048, 2048], "float32") + B = tir.match_buffer(b, [2048, 2048], "float32") + C = tir.match_buffer(c, [2048, 2048], "float32") + A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + with tir.block([2048, 2048], "A_shared") as [v0, v1]: + A_shared[v0, v1] = A[v0, v1] + with tir.block([2048, 2048], "B_shared") as [v0, v1]: + B_shared[v0, v1] = B[v0, v1] + with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + A_shared_local[v0, v1] = A_shared[v0, v1] + with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + B_shared_local[v0, v1] = B_shared[v0, v1] + for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in tir.thread_binding(0, 2, thread = "vthread.y"): + for vx in tir.thread_binding(0, 2, thread = "vthread.x"): + for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): + for i, j, k in tir.grid(4, 4, 2048): + with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + tir.bind(vk, k) + with tir.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in tir.grid(4, 4): + with tir.block([2048, 2048], "C_local") as [vi, vj]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + C[vi, vj] = C_local[vi, vj] + + +@tvm.script.tir +def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable + A = tir.match_buffer(a, [2048, 2048], "float32") + B = tir.match_buffer(b, [2048, 2048], "float32") + C = tir.match_buffer(c, [2048, 2048], "float32") + A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + with tir.block([2048, 2048], "A_shared") as [v0, v1]: + A_shared[v0, v1] = A[v0, v1] + with tir.block([2048, 2048], "B_shared") as [v0, v1]: + B_shared[v0, v1] = B[v0, v1] + with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + A_shared_local[v0, v1] = A_shared[v0, v1] + with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + B_shared_local[v0, v1] = B_shared[v0, v1] + for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in tir.thread_binding(0, 2, thread = "vthread.y"): + for vx 
in tir.thread_binding(0, 2, thread = "vthread.x"): + for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): + for k_0 in tir.serial(0, 256): + for k_1 in tir.unroll(0, 8): + for _, i, j in tir.grid(1, 4, 4): + with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + tir.bind(vk, k_0 * 8 + k_1) + with tir.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in tir.grid(4, 4): + with tir.block([2048, 2048], "C_local") as [vi, vj]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + C[vi, vj] = C_local[vi, vj] + + +@tvm.script.tir +def cuda_matmul_2(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable + A = tir.match_buffer(a, [2048, 2048], "float32") + B = tir.match_buffer(b, [2048, 2048], "float32") + C = tir.match_buffer(c, [2048, 2048], "float32") + A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + with tir.block([2048, 2048], "A_shared") as [v0, v1]: + A_shared[v0, v1] = A[v0, v1] + with tir.block([2048, 2048], "B_shared") as [v0, v1]: + B_shared[v0, v1] = B[v0, v1] + with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + B_shared_local[v0, v1] = B_shared[v0, v1] + for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in tir.thread_binding(0, 2, thread = "vthread.y"): + for vx in tir.thread_binding(0, 2, thread = "vthread.x"): + for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): + for k_0 in tir.serial(0, 256): + for k_1 in tir.unroll(0, 8): + for i, j in tir.grid(1, 4): + with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + tir.bind(v0, k_0 * 8 + k_1 + i) + tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + A_shared_local[v0, v1] = A_shared[v0, v1] + for _, i, j in tir.grid(1, 4, 4): + with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + tir.bind(vk, k_0 * 8 + k_1) + with tir.init(): + C_local[vi, vj] = tir.float32(0) + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in tir.grid(4, 4): + with tir.block([2048, 2048], "C_local") as [v0, v1]: + tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.tir +def cuda_matmul_3(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable + A = tir.match_buffer(a, [2048, 2048], "float32") + B = tir.match_buffer(b, [2048, 2048], "float32") + C = tir.match_buffer(c, [2048, 2048], "float32") + A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = tir.alloc_buffer([2048, 2048], "float32", 
scope="local") + C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + with tir.block([2048, 2048], "A_shared") as [v0, v1]: + A_shared[v0, v1] = A[v0, v1] + with tir.block([2048, 2048], "B_shared") as [v0, v1]: + B_shared[v0, v1] = B[v0, v1] + for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in tir.thread_binding(0, 2, thread = "vthread.y"): + for vx in tir.thread_binding(0, 2, thread = "vthread.x"): + for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in tir.serial(0, 256): + for k1 in tir.unroll(0, 8): + for i, j in tir.grid(1, 4): + with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + tir.bind(v0, k0 * 8 + k1 + i) + tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in tir.grid(1, 4): + with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + tir.bind(v0, k0 * 8 + k1 + i) + tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + B_shared_local[v0, v1] = B_shared[v0, v1] + for _, i, j in tir.grid(1, 4, 4): + with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + tir.bind(vk, k0 * 8 + k1) + with tir.init(): + C_local[vi, vj] = tir.float32(0) + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in tir.grid(4, 4): + with tir.block([2048, 2048], "C_local") as [v0, v1]: + tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.tir +def cuda_matmul_4(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable + A = tir.match_buffer(a, [2048, 2048], "float32") + B = tir.match_buffer(b, [2048, 2048], "float32") + C = tir.match_buffer(c, [2048, 2048], "float32") + A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + with tir.block([2048, 2048], "B_shared") as [v0, v1]: + B_shared[v0, v1] = B[v0, v1] + for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in tir.thread_binding(0, 2, thread = "vthread.y"): + for vx in tir.thread_binding(0, 2, thread = "vthread.x"): + for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in tir.serial(0, 256): + for i, j in tir.grid(8, 64): + with tir.block([2048, 2048], "A_shared") as [v0, v1]: + tir.bind(v0, k0 * 8 + i) + tir.bind(v1, by * 64 + j) + A_shared[v0, v1] = A[v0, v1] + for k1 in tir.unroll(0, 8): + for i, j in tir.grid(1, 4): + with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + tir.bind(v0, k0 * 8 + k1 + i) + tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in tir.grid(1, 4): + with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + tir.bind(v0, k0 * 8 + k1 + i) + tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + B_shared_local[v0, v1] = B_shared[v0, v1] + for _, i, j in tir.grid(1, 4, 4): + with tir.block([2048, 2048, 
tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + tir.bind(vk, k0 * 8 + k1) + with tir.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in tir.grid(4, 4): + with tir.block([2048, 2048], "C_local") as [v0, v1]: + tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.tir +def cuda_matmul_5(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable + A = tir.match_buffer(a, [2048, 2048], "float32") + B = tir.match_buffer(b, [2048, 2048], "float32") + C = tir.match_buffer(c, [2048, 2048], "float32") + A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") + for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in tir.thread_binding(0, 2, thread = "vthread.y"): + for vx in tir.thread_binding(0, 2, thread = "vthread.x"): + for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in tir.serial(0, 256): + for i, j in tir.grid(8, 64): + with tir.block([2048, 2048], "A_shared") as [v0, v1]: + tir.bind(v0, k0 * 8 + i) + tir.bind(v1, by * 64 + j) + A_shared[v0, v1] = A[v0, v1] + for i, j in tir.grid(8, 64): + with tir.block([2048, 2048], "B_shared") as [v0, v1]: + tir.bind(v0, k0 * 8 + i) + tir.bind(v1, bx * 64 + j) + B_shared[v0, v1] = B[v0, v1] + for k1 in tir.unroll(0, 8): + for i, j in tir.grid(1, 4): + with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + tir.bind(v0, k0 * 8 + k1 + i) + tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + A_shared_local[v0, v1] = A_shared[v0, v1] + for i, j in tir.grid(1, 4): + with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + tir.bind(v0, k0 * 8 + k1 + i) + tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + B_shared_local[v0, v1] = B_shared[v0, v1] + for _, i, j in tir.grid(1, 4, 4): + with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + tir.bind(vk, k0 * 8 + k1) + with tir.init(): + C_local[vi, vj] = 0.0 + C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] + for i, j in tir.grid(4, 4): + with tir.block([2048, 2048], "C_local") as [v0, v1]: + tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.tir +def tiled(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128], "float32") + B = tir.alloc_buffer([128, 128], "float32") + C = tir.match_buffer(c, [128, 128], "float32") + for i_0, j_0, i_1, j_1 in tir.grid(8, 8, 16, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i_0 * 16 + i_1) + tir.bind(vj, j_0 * 16 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def tiled_after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128], 
"float32") + B = tir.alloc_buffer([128, 128], "float32") + C = tir.match_buffer(c, [128, 128], "float32") + for i_0, j_0, i_1 in tir.grid(8, 8, 16): + for j_1 in tir.serial(0, 16): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i_0 * 16 + i_1) + tir.bind(vj, j_0 * 16 + j_1) + B[vi, vj] = A[vi, vj] * 2.0 + for j_1 in tir.serial(0, 16): + with tir.block([128, 128], "C") as [vi, vj]: + tir.bind(vi, i_0 * 16 + i_1) + tir.bind(vj, j_0 * 16 + j_1) + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def factorized(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16, 16], "float32") + B = tir.match_buffer(b, [16], "float32") + B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local") + for j in tir.thread_binding(0, 16, thread = "blockIdx.x"): + for i_o in tir.thread_binding(0, 4, thread = "threadIdx.x"): + for i_i, k in tir.grid(4, 16): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: + tir.bind(vi, i_o * 4 + i_i) + tir.bind(vj, j) + tir.bind(vk, k) + with tir.init(): + B_rf_local[vi, vj] = 0.0 + B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] + for i, k in tir.grid(16, 16): + with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, k) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + B_rf_local[vk, vi] + + +@tvm.script.tir +def factorized_after_reverse_compute_at(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16, 16], "float32") + B = tir.match_buffer(b, [16], "float32") + B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local") + for j in tir.thread_binding(0, 16, thread = "blockIdx.x"): + for i_o in tir.thread_binding(0, 4, thread = "threadIdx.x"): + for i_i, k in tir.grid(4, 16): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: + tir.bind(vi, i_o * 4 + i_i) + tir.bind(vj, j) + tir.bind(vk, k) + with tir.init(): + B_rf_local[vi, vj] = 0.0 + B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] + for k in tir.serial(0, 4): + with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]: + tir.bind(vi, j) + tir.bind(vk, i_o * 4 + k) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + B_rf_local[vk, vi] + + +@tvm.script.tir +def fail_subtree_compact_dataflow(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + for i in range(0, 128): + for j in range(0, 64): + with tir.block([128, 128], "B_0") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] * 2.0 + for j in range(0, 64): + with tir.block([128, 128], "B_1") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j + 64) + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def fail_all_consumers_under_loop(a: ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + D = tir.match_buffer(d, (128, 128), "float32") + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.tir +def fail_all_producers_under_loop(a: ty.handle, d: ty.handle) -> 
None: + A = tir.match_buffer(a, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.alloc_buffer((128, 128), "float32") + D = tir.match_buffer(d, (128, 128), "float32") + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = A[vi, vj] + 1.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "D") as [vi, vj]: + D[vi, vj] = B[vi, vj] + C[vi, vj] + + +@tvm.script.tir +def read_out_of_bound(a: ty.handle, c:ty.handle) -> None: + A = tir.match_buffer(a, [16], "float32") + B = tir.alloc_buffer([16], "float32") + C = tir.match_buffer(c, [16], "float32") + for i in tir.serial(0, 16): + with tir.block([16], "B") as [v]: + B[v] = A[v] + for j in tir.serial(0, 16): + with tir.block([16], "C") as [v]: + tir.reads(B[v : v + 2]) + C[v] = tir.if_then_else(v < 15, tir.max(B[v], B[v + 1]), B[v], dtype="float32") + + +@tvm.script.tir +def read_out_of_bound_after_compute_at(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16], "float32") + B = tir.alloc_buffer([16], "float32") + C = tir.match_buffer(c, [16], "float32") + for j in tir.serial(0, 16): + for i in tir.serial(0, tir.min(1, 15 - j) + 1): + with tir.block([16], "B") as [v]: + tir.bind(v, j + i) + B[v] = A[v] + with tir.block([16], "C") as [v]: + tir.bind(v, j) + tir.reads([B[v : v + 2]]) + C[v] = tir.if_then_else(v < 15, tir.max(B[v], B[v + 1]), B[v], dtype="float32") + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks +# fmt: on + + +def test_compute_at_two_elementwise(): + sch = tir.Schedule(two_elementwise, debug_mask="all") + block = sch.get_block("B") + loop, _ = sch.get_loops(sch.get_block("C")) + sch.compute_at(block, loop, preserve_unit_loops=True) + tvm.ir.assert_structural_equal(two_elementwise_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise) + + +def test_compute_at_blockized_1(): + sch = tir.Schedule(blockized_1, debug_mask="all") + block = sch.get_block("B") + _, loop = sch.get_loops(sch.get_block("C_outer")) + sch.compute_at(block, loop, preserve_unit_loops=True) + tvm.ir.assert_structural_equal(blockized_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=blockized_1) + + +def test_compute_at_blockized_2(): + sch = tir.Schedule(blockized_2, debug_mask="all") + block = sch.get_block("B_outer") + _, loop, _, _ = sch.get_loops(sch.get_block("C")) + sch.compute_at(block, loop, preserve_unit_loops=True) + tvm.ir.assert_structural_equal(blockized_2_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=blockized_2) + + +def test_compute_at_cuda_matmul_0(): + sch = tir.Schedule(cuda_matmul_0, debug_mask="all") + block = sch.get_block("C") + _, _, _, _, _, loop, _, _ = sch.get_loops(sch.get_block("C_local")) + sch.compute_at(block, loop, preserve_unit_loops=True) + tvm.ir.assert_structural_equal(cuda_matmul_0_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=cuda_matmul_0) + + +def test_compute_at_cuda_matmul_1(): + sch = tir.Schedule(cuda_matmul_1, debug_mask="all") + block = sch.get_block("A_shared_local") + _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C")) + sch.compute_at(block, loop, preserve_unit_loops=True) + tvm.ir.assert_structural_equal(cuda_matmul_2, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=cuda_matmul_1) + + 
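+# Editorial addition, not part of the original test suite: the `cuda_matmul_*`
+# fixtures form a chain in which each `compute_at` step's expected program is the
+# next step's input.  The sketch below repeats the step exercised by
+# `test_compute_at_cuda_matmul_1`, but selects the target loop by index instead of
+# tuple unpacking; for block "C" of `cuda_matmul_1` the loop nest returned by
+# `get_loops` is (by, bx, vy, vx, ty, tx, k_0, k_1, _, i, j), so index 7 is `k_1`.
+def test_compute_at_cuda_matmul_1_indexed_loop():
+    sch = tir.Schedule(cuda_matmul_1, debug_mask="all")
+    block = sch.get_block("A_shared_local")
+    # Pick k_1 (the 8th loop, index 7) of block "C" as the compute-at target.
+    loop = sch.get_loops(sch.get_block("C"))[7]
+    sch.compute_at(block, loop, preserve_unit_loops=True)
+    tvm.ir.assert_structural_equal(cuda_matmul_2, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=cuda_matmul_1)
+
+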
+def test_compute_at_cuda_matmul_2():
+    sch = tir.Schedule(cuda_matmul_2, debug_mask="all")
+    block = sch.get_block("B_shared_local")
+    _, _, _, _, _, _, _, loop, _, _, _ = sch.get_loops(sch.get_block("C"))
+    sch.compute_at(block, loop, preserve_unit_loops=True)
+    tvm.ir.assert_structural_equal(cuda_matmul_3, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=cuda_matmul_2)
+
+
+def test_compute_at_cuda_matmul_3():
+    sch = tir.Schedule(cuda_matmul_3, debug_mask="all")
+    block = sch.get_block("A_shared")
+    _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C"))
+    sch.compute_at(block, loop, preserve_unit_loops=True)
+    tvm.ir.assert_structural_equal(cuda_matmul_4, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=cuda_matmul_3)
+
+
+def test_compute_at_cuda_matmul_4():
+    sch = tir.Schedule(cuda_matmul_4, debug_mask="all")
+    block = sch.get_block("B_shared")
+    _, _, _, _, _, _, loop, _, _, _, _ = sch.get_loops(sch.get_block("C"))
+    sch.compute_at(block, loop, preserve_unit_loops=True)
+    tvm.ir.assert_structural_equal(cuda_matmul_5, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=cuda_matmul_4)
+
+
+def test_reverse_compute_at_tiled():
+    sch = tir.Schedule(tiled, debug_mask="all")
+    block = sch.get_block("C")
+    _, _, loop, _ = sch.get_loops(sch.get_block("B"))
+    sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
+    tvm.ir.assert_structural_equal(tiled_after_reverse_compute_at, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=tiled)
+
+
+def test_reverse_compute_at_blockized_2():
+    sch = tir.Schedule(blockized_2, debug_mask="all")
+    block = sch.get_block("C")
+    _, loop = sch.get_loops(sch.get_block("B_outer"))
+    sch.reverse_compute_at(block, loop, preserve_unit_loops=True)
+    tvm.ir.assert_structural_equal(blockized_2_after_reverse_compute_at, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=blockized_2)
+
+
+def test_reverse_compute_at_factorized():
+    sch = tir.Schedule(factorized, debug_mask="all")
+    block = sch.get_block("B")
+    _, loop, _, _ = sch.get_loops(sch.get_block("B_rf"))
+    sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
+    tvm.ir.assert_structural_equal(factorized_after_reverse_compute_at, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=factorized)
+
+
+def test_read_out_of_bound():
+    sch = tir.Schedule(read_out_of_bound, debug_mask="all")
+    block = sch.get_block("B")
+    (loop,) = sch.get_loops(sch.get_block("C"))
+    sch.compute_at(block, loop)
+    tvm.ir.assert_structural_equal(read_out_of_bound_after_compute_at, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=read_out_of_bound)
+
+
+def test_fail_subtree_compact_dataflow():
+    sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all")
+    block = sch.get_block("B_0")
+    loop, _ = sch.get_loops(sch.get_block("C"))
+    with pytest.raises(tvm.tir.ScheduleError, match="compact dataflow"):
+        sch.compute_at(block, loop)
+
+
+def test_fail_not_in_same_scope():
+    sch = tir.Schedule(blockized_1, debug_mask="all")
+    block = sch.get_block("B")
+    loop, _ = sch.get_loops(sch.get_block("C_inner"))
+    with pytest.raises(tvm.tir.ScheduleError, match="same block scope"):
+        sch.compute_at(block, loop)
+
+
+def test_fail_loop_is_ancestor_of_block():
+    sch = tir.Schedule(two_elementwise, debug_mask="all")
+    block = sch.get_block("B")
+    loop, _ = sch.get_loops(sch.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError, match="ancestor of block"):
+        sch.compute_at(block, loop)
+
+
+def test_fail_output_block():
+    sch = tir.Schedule(tiled, debug_mask="all")
+    block = sch.get_block("C")
+    loop, _, _, _ = sch.get_loops(sch.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError, match="output block"):
+        sch.compute_at(block, loop)
+
+
+def test_fail_all_consumers_under_loop():
+    sch = tir.Schedule(fail_all_consumers_under_loop, debug_mask="all")
+    block = sch.get_block("B")
+    loop, _ = sch.get_loops(sch.get_block("C"))
+    with pytest.raises(tvm.tir.ScheduleError, match="requires all the consumer"):
+        sch.compute_at(block, loop)
+
+
+def test_fail_all_producers_under_loop():
+    sch = tir.Schedule(fail_all_producers_under_loop, debug_mask="all")
+    block = sch.get_block("D")
+    loop, _ = sch.get_loops(sch.get_block("C"))
+    with pytest.raises(tvm.tir.ScheduleError, match="requires all the producer"):
+        sch.reverse_compute_at(block, loop)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
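+
+
+# Editorial addition, not part of the original test suite: a minimal sketch showing
+# `reverse_compute_at` on the `two_elementwise` fixture defined above.  It moves the
+# consumer block "C" under the first loop of its producer "B" and only checks that
+# the transformation is accepted and that its trace round-trips (there is no
+# hand-written expected program for this case).  pytest still collects this function
+# even though it is defined after the `__main__` guard.
+def test_reverse_compute_at_two_elementwise_sketch():
+    sch = tir.Schedule(two_elementwise, debug_mask="all")
+    block = sch.get_block("C")
+    loop, _ = sch.get_loops(sch.get_block("B"))
+    sch.reverse_compute_at(block, loop, preserve_unit_loops=False)
+    verify_trace_roundtrip(sch=sch, mod=two_elementwise)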