Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add preserve-unit-iters #11585

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[TIR] Add preserve-unit-iters
Follow-up of #11578, which enforces structural stability of TIR by
avoiding over-simplification in affine analysis. On the other hand, it
is possible that over-simplification could be desirable behavior.
Therefore, following the precedent of `preserve-unit-loops` in
`Compute-At`, this PR introduces `preserve-unit-iters` in block binding
for cases where users don't need structural stability (which is
admittedly rare).

This PR does not affect any existing functionalities.

Example:

```python
for i in T.serial(2):
    with T.block("C"):
        k = T.axis.reduce(2, i)

Split(i, [1, 2], preserve-unit-iters=True/False)

for i_0, i_1 in T.grid(1, 2):
    with T.block("C"):
        k = T.axis.reduce(2, i_0 * 2 + i_1)

for i_0, i_1 in T.grid(1, 2):
    with T.block("C"):
        k = T.axis.reduce(2, i_1)
```
  • Loading branch information
junrushao committed Jun 16, 2022
commit 9b858365de235513f3b6e0c8cc1e2b9182d9acba
7 changes: 5 additions & 2 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,22 @@ class ScheduleNode : public runtime::Object {
* 3) All loops must start with 0.
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param loop_rvs The loops to be fused
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new loop after fusion
*/
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
/*!
* \brief Split a loop into a list of consecutive loops. It requires:
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
* \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
* that factor is inferred.
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters = true) = 0;
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
Expand Down
21 changes: 18 additions & 3 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,11 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:

########## Schedule: Transform loops ##########
@type_checked
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
def fuse(
self,
*loops: List[LoopRV],
preserve_unit_iters: bool = True,
) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
Expand Down Expand Up @@ -553,13 +557,14 @@ def after_fuse(a: T.handle, b: T.handle) -> None:
B[vi, vj] = A[vi, vj] * 2.0

"""
return _ffi_api.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleFuse(self, loops, preserve_unit_iters) # type: ignore # pylint: disable=no-member

@type_checked
def split(
self,
loop: LoopRV,
factors: List[Union[int, ExprRV, None]],
preserve_unit_iters: bool = True,
) -> List[LoopRV]:
"""Split a loop into a list of consecutive loops. It requires:
1) The loop can't have annotation or thread binding.
Expand All @@ -580,6 +585,9 @@ def split(
- ExprRV
- Positive constant integers

preserve_unit_iters : bool
Whether or not to preserve unit iterators in block bindings

Returns
-------
split_loops : List[LoopRV]
Expand Down Expand Up @@ -628,7 +636,14 @@ def after_split(a: T.handle, b: T.handle) -> None:
"""
# it will be checked later in C++ implementation
# that there is at most one None in `factors`
return list(_ffi_api.ScheduleSplit(self, loop, factors)) # type: ignore # pylint: disable=no-member
return list(
_ffi_api.ScheduleSplit( # type: ignore # pylint: disable=no-member
self,
loop,
factors,
preserve_unit_iters,
)
)

@type_checked
def reorder(self, *ordered_loops: List[LoopRV]) -> None:
Expand Down
9 changes: 5 additions & 4 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,20 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {

/******** Schedule: Transform loops ********/

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Fuse(state_, loop_srefs);
result = tir::Fuse(state_, loop_srefs, preserve_unit_iters);
TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(result);
}

Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
const Array<Optional<ExprRV>>& factor_rvs) {
const Array<Optional<ExprRV>>& factor_rvs,
bool preserve_unit_iters) {
class NotSingleInferFactorError : public ScheduleError {
public:
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
Expand Down Expand Up @@ -440,7 +441,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
} else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
}
results = tir::Split(state_, loop_sref, factors);
results = tir::Split(state_, loop_sref, factors, preserve_unit_iters);
TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(results);
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ class ConcreteScheduleNode : public ScheduleNode {
Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters) override;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
LoopRV AddUnitLoop(const BlockRV& block_rv) override;
LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
Expand Down
7 changes: 5 additions & 2 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,11 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
* \param self The state of the schedule
* \param loop_sref The sref to the loop being split
* \param factors The splitting factors
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return An array of srefs to the loops after splitting
*/
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors);
const Array<PrimExpr>& factors, bool preserve_unit_iters);
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
Expand All @@ -168,9 +169,11 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param self The state of the schedule
* \param loop_srefs An array of srefs to the loops to be fused
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The sref to the fused loop
*/
TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs,
bool preserve_unit_loops);
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
Expand Down
64 changes: 39 additions & 25 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,21 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */
class IterMapSimplifyBlockBinding : public StmtExprMutator {
public:
explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent)
: opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {}

static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs,
MapNode* opaque_blocks) {
explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent,
bool preserve_unit_iters)
: opaque_blocks_(opaque_blocks),
loop_var2extent_(loop_var2extent),
preserve_unit_iters_(preserve_unit_iters) {}

static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapNode* opaque_blocks,
bool preserve_unit_iters) {
Map<Var, Range> loop_var2extent;
for (const StmtSRef& sref : loop_srefs) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
return Downcast<For>(
IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt)));
return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent),
preserve_unit_iters)(std::move(stmt)));
}

private:
Expand All @@ -112,11 +115,12 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
}
return std::move(realize);
}
Array<PrimExpr> v = arith::IterMapSimplify(/*indices=*/op->iter_values,
/*input_iters=*/loop_var2extent_,
/*input_pred=*/op->predicate,
/*check_level=*/arith::IterMapLevel::Surjective,
/*simplify_trivial_iterators=*/false);
Array<PrimExpr> v =
arith::IterMapSimplify(/*indices=*/op->iter_values,
/*input_iters=*/loop_var2extent_,
/*input_pred=*/op->predicate,
/*check_level=*/arith::IterMapLevel::Surjective,
/*simplify_trivial_iterators=*/!preserve_unit_iters_);
if (v.same_as(op->iter_values)) {
return GetRef<Stmt>(op);
} else {
Expand All @@ -130,6 +134,8 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
MapNode* opaque_blocks_;
/*! \brief The range of loops */
Map<Var, Range> loop_var2extent_;
/*! \brief Whether or not to simplify unit iterators */
bool preserve_unit_iters_;
};

class BlockPropertyError : public ScheduleError {
Expand Down Expand Up @@ -376,8 +382,8 @@ class DependentLoopError : public ScheduleError {
PrimitiveKind kind_;
};

Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors) {
Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array<PrimExpr>& factors,
bool preserve_unit_iters) {
// Invariance
// - The total repeat number has not changed for each direct child block with updating predicate.
// - The execution order has not changed. (The block executes with the same args and the same
Expand Down Expand Up @@ -432,7 +438,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt);
}
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
opaque_block_reuse.CopyOnWrite());
opaque_block_reuse.CopyOnWrite(),
preserve_unit_iters);
self->Replace(loop_sref, new_stmt, opaque_block_reuse);
Array<StmtSRef> result_srefs;
result_srefs.reserve(n);
Expand All @@ -444,7 +451,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
return result_srefs;
}

StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
// Invariance
// - The total repeat number has not changed for each direct child block.
// - The execution order has not changed. (The block executes with the same
Expand Down Expand Up @@ -527,7 +534,8 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
fused_extent = analyzer.Simplify(fused_extent);
new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(
std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite());
std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite(),
preserve_unit_iters);
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
return self->stmt2ref.at(new_stmt.get());
}
Expand Down Expand Up @@ -755,7 +763,7 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumDecisions = 0;

template <size_t delta>
Expand All @@ -770,14 +778,17 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
}

static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
Array<Optional<ExprRV>> factors) {
return sch->Split(loop_rv, factors);
Array<Optional<ExprRV>> factors,
Bool preserve_unit_iters) {
return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors) {
static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors,
Bool preserve_unit_iters) {
PythonAPICall py("split");
py.Input("loop", loop_rv);
py.Input("factors", factors);
py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
py.OutputList(outputs);
return py.Str();
}
Expand All @@ -792,7 +803,7 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumDecisions = 0;

template <size_t delta>
Expand All @@ -801,15 +812,18 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
setter(delta, inputs);
}

static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
return sch->Fuse(loop_rvs);
static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs,
Bool preserve_unit_iters) {
return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs,
Bool preserve_unit_iters) {
PythonAPICall py("fuse");
for (const String& loop_rv : loop_rvs) {
py.Input("", loop_rv);
}
py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
py.SingleOutput(outputs);
return py.Str();
}
Expand Down
13 changes: 7 additions & 6 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,21 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {

/******** Schedule: Transform loops ********/

LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs);
LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);

static const InstructionKind& kind = InstructionKind::Get("Fuse");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
/*attrs=*/{},
/*attrs=*/{Integer(preserve_unit_loops)},
/*outputs=*/{result}));
return result;
}

Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
const Array<Optional<ExprRV>>& factor_rvs) {
Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs);
const Array<Optional<ExprRV>>& factor_rvs,
bool preserve_unit_iters) {
Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters);

std::vector<ObjectRef> inputs;
inputs.reserve(1 + factor_rvs.size());
Expand All @@ -183,7 +184,7 @@ Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
static const InstructionKind& kind = InstructionKind::Get("Split");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/inputs,
/*attrs=*/{},
/*attrs=*/{Integer(preserve_unit_iters)},
/*outputs=*/{results.begin(), results.end()}));
return results;
}
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
bool preserve_unit_iters) final;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
LoopRV AddUnitLoop(const BlockRV& block_rv) final;
LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
Expand Down
9 changes: 5 additions & 4 deletions tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
@requires_torch
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
def filter_func(args) -> bool:
from tvm.te import create_prim_func # pylint: disable=import-outside-toplevel

has_complex_op = False
visited = set()
Expand All @@ -205,16 +206,16 @@ def traverse(t):
if isinstance(t.op, te.PlaceholderOp):
pass
elif isinstance(t.op, te.ComputeOp):
has_complex_op = has_complex_op or any(
[isinstance(e, tir.Reduce) for e in t.op.body]
)
has_complex_op = has_complex_op or any(isinstance(e, tir.Reduce) for e in t.op.body)
for x in t.op.input_tensors:
traverse(x)
visited.add(t.handle.value)

for t in args:
traverse(t)
return has_complex_op
if not has_complex_op:
return None
return create_prim_func(args)

mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.extract_task_from_relay(
Expand Down
Loading