Skip to content

Commit

Permalink
[TIR] Add preserve-unit-iters
Browse files Browse the repository at this point in the history
Follow-up of #11578, which enforces structural stability of TIR by
avoiding over-simplification in affine analysis. On the other hand, it
is possible that over-simplification could be desirable behavior.
Therefore, following the precedent of `preserve-unit-loops` in
`Compute-At`, this PR introduces `preserve-unit-iters` in block binding
for cases where users don't need structural stability (which is
admittedly rare).

This PR does not affect any existing functionalities.
  • Loading branch information
junrushao committed Jun 16, 2022
1 parent ec91864 commit 0529a35
Show file tree
Hide file tree
Showing 15 changed files with 202 additions and 162 deletions.
7 changes: 5 additions & 2 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,22 @@ class ScheduleNode : public runtime::Object {
* 3) All loops must start with 0.
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param loop_rvs The loops to be fused
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new loop after fusion
*/
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
/*!
* \brief Split a loop into a list of consecutive loops. It requires:
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
* \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
* that factor is inferred.
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters = true) = 0;
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
Expand Down
21 changes: 18 additions & 3 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,11 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:

########## Schedule: Transform loops ##########
@type_checked
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
def fuse(
self,
*loops: List[LoopRV],
preserve_unit_iters: bool = True,
) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
Expand Down Expand Up @@ -553,13 +557,14 @@ def after_fuse(a: T.handle, b: T.handle) -> None:
B[vi, vj] = A[vi, vj] * 2.0
"""
return _ffi_api.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleFuse(self, loops, preserve_unit_iters) # type: ignore # pylint: disable=no-member

@type_checked
def split(
self,
loop: LoopRV,
factors: List[Union[int, ExprRV, None]],
preserve_unit_iters: bool = True,
) -> List[LoopRV]:
"""Split a loop into a list of consecutive loops. It requires:
1) The loop can't have annotation or thread binding.
Expand All @@ -580,6 +585,9 @@ def split(
- ExprRV
- Positive constant integers
preserve_unit_iters : bool
Whether or not to preserve unit iterators in block bindings
Returns
-------
split_loops : List[LoopRV]
Expand Down Expand Up @@ -628,7 +636,14 @@ def after_split(a: T.handle, b: T.handle) -> None:
"""
# it will be checked later in C++ implementation
# that there is at most one None in `factors`
return list(_ffi_api.ScheduleSplit(self, loop, factors)) # type: ignore # pylint: disable=no-member
return list(
_ffi_api.ScheduleSplit( # type: ignore # pylint: disable=no-member
self,
loop,
factors,
preserve_unit_iters,
)
)

@type_checked
def reorder(self, *ordered_loops: List[LoopRV]) -> None:
Expand Down
9 changes: 5 additions & 4 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,20 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {

/******** Schedule: Transform loops ********/

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Fuse(state_, loop_srefs);
result = tir::Fuse(state_, loop_srefs, preserve_unit_iters);
TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(result);
}

Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
const Array<Optional<ExprRV>>& factor_rvs) {
const Array<Optional<ExprRV>>& factor_rvs,
bool preserve_unit_iters) {
class NotSingleInferFactorError : public ScheduleError {
public:
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
Expand Down Expand Up @@ -440,7 +441,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
} else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
}
results = tir::Split(state_, loop_sref, factors);
results = tir::Split(state_, loop_sref, factors, preserve_unit_iters);
TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(results);
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ class ConcreteScheduleNode : public ScheduleNode {
Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters) override;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
LoopRV AddUnitLoop(const BlockRV& block_rv) override;
LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
Expand Down
7 changes: 5 additions & 2 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,11 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
* \param self The state of the schedule
* \param loop_sref The sref to the loop being split
* \param factors The splitting factors
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return An array of srefs to the loops after splitting
*/
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors);
const Array<PrimExpr>& factors, bool preserve_unit_iters);
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
Expand All @@ -168,9 +169,11 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
* 4) The domain of a loop to be fused cannot depend on another loop to be fused.
* \param self The state of the schedule
* \param loop_srefs An array of srefs to the loops to be fused
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The sref to the fused loop
*/
TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs,
bool preserve_unit_loops);
/*!
* \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
* It requires:
Expand Down
64 changes: 39 additions & 25 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,21 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */
class IterMapSimplifyBlockBinding : public StmtExprMutator {
public:
explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent)
: opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {}

static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs,
MapNode* opaque_blocks) {
explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent,
bool preserve_unit_iters)
: opaque_blocks_(opaque_blocks),
loop_var2extent_(loop_var2extent),
preserve_unit_iters_(preserve_unit_iters) {}

static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapNode* opaque_blocks,
bool preserve_unit_iters) {
Map<Var, Range> loop_var2extent;
for (const StmtSRef& sref : loop_srefs) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
return Downcast<For>(
IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt)));
return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent),
preserve_unit_iters)(std::move(stmt)));
}

private:
Expand All @@ -112,11 +115,12 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
}
return std::move(realize);
}
Array<PrimExpr> v = arith::IterMapSimplify(/*indices=*/op->iter_values,
/*input_iters=*/loop_var2extent_,
/*input_pred=*/op->predicate,
/*check_level=*/arith::IterMapLevel::Surjective,
/*simplify_trivial_iterators=*/false);
Array<PrimExpr> v =
arith::IterMapSimplify(/*indices=*/op->iter_values,
/*input_iters=*/loop_var2extent_,
/*input_pred=*/op->predicate,
/*check_level=*/arith::IterMapLevel::Surjective,
/*simplify_trivial_iterators=*/!preserve_unit_iters_);
if (v.same_as(op->iter_values)) {
return GetRef<Stmt>(op);
} else {
Expand All @@ -130,6 +134,8 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
MapNode* opaque_blocks_;
/*! \brief The range of loops */
Map<Var, Range> loop_var2extent_;
/*! \brief Whether or not to simplify unit iterators */
bool preserve_unit_iters_;
};

class BlockPropertyError : public ScheduleError {
Expand Down Expand Up @@ -376,8 +382,8 @@ class DependentLoopError : public ScheduleError {
PrimitiveKind kind_;
};

Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors) {
Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array<PrimExpr>& factors,
bool preserve_unit_iters) {
// Invariance
// - The total repeat number has not changed for each direct child block with updating predicate.
// - The execution order has not changed. (The block executes with the same args and the same
Expand Down Expand Up @@ -432,7 +438,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt);
}
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
opaque_block_reuse.CopyOnWrite());
opaque_block_reuse.CopyOnWrite(),
preserve_unit_iters);
self->Replace(loop_sref, new_stmt, opaque_block_reuse);
Array<StmtSRef> result_srefs;
result_srefs.reserve(n);
Expand All @@ -444,7 +451,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
return result_srefs;
}

StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
// Invariance
// - The total repeat number has not changed for each direct child block.
// - The execution order has not changed. (The block executes with the same
Expand Down Expand Up @@ -527,7 +534,8 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
fused_extent = analyzer.Simplify(fused_extent);
new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(
std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite());
std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite(),
preserve_unit_iters);
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
return self->stmt2ref.at(new_stmt.get());
}
Expand Down Expand Up @@ -755,7 +763,7 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumDecisions = 0;

template <size_t delta>
Expand All @@ -770,14 +778,17 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
}

static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
Array<Optional<ExprRV>> factors) {
return sch->Split(loop_rv, factors);
Array<Optional<ExprRV>> factors,
Bool preserve_unit_iters) {
return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors) {
static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors,
Bool preserve_unit_iters) {
PythonAPICall py("split");
py.Input("loop", loop_rv);
py.Input("factors", factors);
py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
py.OutputList(outputs);
return py.Str();
}
Expand All @@ -792,7 +803,7 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumDecisions = 0;

template <size_t delta>
Expand All @@ -801,15 +812,18 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
setter(delta, inputs);
}

static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
return sch->Fuse(loop_rvs);
static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs,
Bool preserve_unit_iters) {
return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs,
Bool preserve_unit_iters) {
PythonAPICall py("fuse");
for (const String& loop_rv : loop_rvs) {
py.Input("", loop_rv);
}
py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
py.SingleOutput(outputs);
return py.Str();
}
Expand Down
13 changes: 7 additions & 6 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,21 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {

/******** Schedule: Transform loops ********/

LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs);
LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);

static const InstructionKind& kind = InstructionKind::Get("Fuse");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
/*attrs=*/{},
/*attrs=*/{Integer(preserve_unit_loops)},
/*outputs=*/{result}));
return result;
}

Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
const Array<Optional<ExprRV>>& factor_rvs) {
Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs);
const Array<Optional<ExprRV>>& factor_rvs,
bool preserve_unit_iters) {
Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters);

std::vector<ObjectRef> inputs;
inputs.reserve(1 + factor_rvs.size());
Expand All @@ -183,7 +184,7 @@ Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
static const InstructionKind& kind = InstructionKind::Get("Split");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/inputs,
/*attrs=*/{},
/*attrs=*/{Integer(preserve_unit_iters)},
/*outputs=*/{results.begin(), results.end()}));
return results;
}
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
bool preserve_unit_iters) final;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
LoopRV AddUnitLoop(const BlockRV& block_rv) final;
LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
Expand Down
9 changes: 5 additions & 4 deletions tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
@requires_torch
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
def filter_func(args) -> bool:
from tvm.te import create_prim_func # pylint: disable=import-outside-toplevel

has_complex_op = False
visited = set()
Expand All @@ -205,16 +206,16 @@ def traverse(t):
if isinstance(t.op, te.PlaceholderOp):
pass
elif isinstance(t.op, te.ComputeOp):
has_complex_op = has_complex_op or any(
[isinstance(e, tir.Reduce) for e in t.op.body]
)
has_complex_op = has_complex_op or any(isinstance(e, tir.Reduce) for e in t.op.body)
for x in t.op.input_tensors:
traverse(x)
visited.add(t.handle.value)

for t in args:
traverse(t)
return has_complex_op
if not has_complex_op:
return None
return create_prim_func(args)

mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.extract_task_from_relay(
Expand Down
Loading

0 comments on commit 0529a35

Please sign in to comment.