[PerfAlign] NRM & SFM on Raspi Aligned #6

Merged
Commits (26)
db90652
Rewrite CollectComputeLocations
MasterJH5574 Dec 29, 2021
b7bbedd
rewrite SampleComputeLocation
MasterJH5574 Dec 29, 2021
cb51763
Fix CollectComputeLocation/SampleComputeLocation
MasterJH5574 Dec 29, 2021
a445316
Rewrite rule RandomComputeLocation; aligned
MasterJH5574 Dec 29, 2021
557bf4b
Fix AddRFactor to avoid unit-loop template
MasterJH5574 Dec 29, 2021
1327f7d
Change the API of SampleComputeLocation
MasterJH5574 Dec 29, 2021
5933e52
Enable inlining
MasterJH5574 Dec 30, 2021
f409d7a
Skip unit loops
MasterJH5574 Dec 30, 2021
50d0f3a
Take reduction block iterators into consideration
MasterJH5574 Dec 30, 2021
14c279f
[TIR] For-kind inheritance in decompose-reduction
MasterJH5574 Dec 30, 2021
70714c8
Complete MutatorComputeLocation with test
MasterJH5574 Dec 30, 2021
7c9aae5
Complete RandomComputeLocation with test
MasterJH5574 Dec 30, 2021
cfb2363
Complete SampleComputeLocation in sampling.cc
MasterJH5574 Dec 30, 2021
1a2883a
Do random-compute-location in AddRFactor
MasterJH5574 Jan 2, 2022
8f19717
Comment out the warning, and disable n_leading_iter
MasterJH5574 Jan 2, 2022
e4f4abf
Use the annotation trick
MasterJH5574 Jan 2, 2022
5d5aa53
Refactor and add docstring for CollectComputeLocation
MasterJH5574 Jan 2, 2022
945dc08
Minor
MasterJH5574 Jan 2, 2022
d6367eb
Test for SampleComputeLocation
MasterJH5574 Jan 2, 2022
93cac6d
Annotate the tiling structure
MasterJH5574 Jan 7, 2022
ee28c30
Skip tiled blocks in RandomComputeLocation
MasterJH5574 Jan 3, 2022
b5fda31
Minor updates
MasterJH5574 Jan 3, 2022
38d7712
Use bool instead of Bool object
MasterJH5574 Jan 3, 2022
9ae015c
Overload another HasAnn
MasterJH5574 Jan 7, 2022
db10385
Add prefix
MasterJH5574 Jan 7, 2022
c03d894
Update shell scripts
MasterJH5574 Jan 10, 2022
6 changes: 3 additions & 3 deletions include/tvm/tir/schedule/schedule.h
@@ -211,10 +211,10 @@ class ScheduleNode : public runtime::Object {
virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
Optional<Array<Integer>> decision = NullOpt) = 0;
/*!
* \brief Sample a compute-at location on a BlockRV so that its producer can compute at that loop
* \param block_rv The consumer block to be computed at
* \brief Sample a compute-at location of the given block
* \param block_rv The block whose compute-at location is to be sampled
* \param decision The sampling decision
* \return The sampled loop to be computed at
* \return The sampled loop where the input block is to be computed at
*/
virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
Optional<Integer> decision = NullOpt) = 0;
7 changes: 7 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1390,6 +1390,13 @@ constexpr const int meta_schedule_cache_type_read = 0;
/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_write = 1;

/*! \brief Mark the tiling structure of the blocks to which the rule Multi-Level-Tiling is applied */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*! \brief Mark the block whose producer needs to be handled by the rule Random-Compute-Location */
constexpr const char* meta_schedule_random_compute_producer =
"meta_schedule.random_compute_producer";

/*! \brief Mark auto-parallel setting on the block. */
constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";

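Note (illustrative, not part of the diff): both keys are ordinary block annotations, so they can also be set and queried from the Python schedule API. A minimal sketch, assuming a schedule `sch` with a block named "C":

# `sch` is a tvm.tir.Schedule; "C" is a hypothetical block name.
block = sch.get_block("C")
# Record the tiling structure chosen by Multi-Level-Tiling.
sch.annotate(block, "meta_schedule.tiling_structure", "SSRSRS")
# Ask Random-Compute-Location to place this block's producer.
sch.annotate(block, "meta_schedule.random_compute_producer", 1)  # the C++ side passes Bool(true)
# The key is removed again once the rule has consumed it.
sch.unannotate(block, "meta_schedule.random_compute_producer")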
5 changes: 2 additions & 3 deletions python/tvm/meta_schedule/mutator/mutate_compute_location.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mutator that mutates the outcome of SampleComputeLocation"""
"""A mutator that mutates the compute-at location decision of SampleComputeLocation"""
from tvm._ffi.registry import register_object

from .. import _ffi_api
@@ -23,10 +23,9 @@

@register_object("meta_schedule.MutateComputeLocation")
class MutateComputeLocation(Mutator):
"""Mutator thatmutates the outcome of SampleComputeLocation"""
"""A mutator that mutates the compute-at location decision of SampleComputeLocation"""

def __init__(self) -> None:
"""Mutator that mutates the outcome of SampleComputeLocation"""
self.__init_handle_by_constructor__(
_ffi_api.MutatorMutateComputeLocation, # type: ignore # pylint: disable=no-member
)
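Note (illustrative): a minimal sketch of how the mutator is used; the tuner setup is an assumption and not part of this file.

from tvm.meta_schedule.mutator import MutateComputeLocation

mutator = MutateComputeLocation()
# In practice the search strategy first initializes the mutator with a
# TuneContext (so it knows the workload) and then applies it to traces:
#   mutator.initialize_with_tune_context(context)
#   new_trace = mutator.apply(trace)  # None when no alternative location exists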
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -38,6 +38,7 @@ def get(target: Target) -> List[ScheduleRule]:
add_rfactor(target),
multi_level_tiling(target),
parallel_vectorize_unroll(target),
random_compute_location(target),
]
if target.kind.name == "cuda":
return [
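Note (illustrative): the helper added above is expected to return the built-in rule object; constructing it directly looks roughly like this, assuming the rule takes no arguments:

from tvm.meta_schedule.schedule_rule import RandomComputeLocation

rule = RandomComputeLocation()  # samples a compute-at location for eligible blocks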
3 changes: 2 additions & 1 deletion python/tvm/meta_schedule/tune.py
@@ -106,7 +106,7 @@ def _sch_rules() -> List[ScheduleRule]:
),
M.ParallelizeVectorizeUnroll(
max_jobs_per_core=16,
max_vectorize_extent=32,
max_vectorize_extent=64,
unroll_max_steps=[0, 16, 64, 512],
unroll_explicit=True,
),
@@ -133,6 +133,7 @@ def _mutator_probs() -> Dict[Mutator, float]:

return {
M.MutateTileSize(): 0.9,
M.MutateComputeLocation(): 0.05,
M.MutateUnroll(): 0.03,
M.MutateParallel(max_jobs_per_core=16): 0.02,
}
6 changes: 3 additions & 3 deletions python/tvm/tir/schedule/schedule.py
@@ -375,19 +375,19 @@ def sample_compute_location(
block: BlockRV,
decision: Optional[int] = None,
) -> LoopRV:
"""Sample a compute-at location on a BlockRV so that its producer can compute at that loop
"""Sample a compute-at location of the given block

Parameters
----------
block : BlockRV
The consumer block to be computed at
The block whose compute-at location is to be sampled
decision : Optional[int]
The sampling decision

Returns
-------
result : LoopRV
The sampled loop to be computed at
The sampled loop where the input block is to be computed at
"""
return _ffi_api.ScheduleSampleComputeLocation( # pylint: disable=no-member
self,
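Note (illustrative): a minimal usage sketch of the primitive with its updated semantics; the module and block name are assumptions, not part of this change.

from tvm import tir

sch = tir.Schedule(mod)                    # `mod` is some user TIR module
block = sch.get_block("B")                 # hypothetical block to relocate
loop = sch.sample_compute_location(block)  # the decision may also mean inline/root
# The sampled location is normally consumed by compute-at; invalid samples are
# caught and retried by the RandomComputeLocation rule further below.
sch.compute_at(block, loop, preserve_unit_loops=True)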
93 changes: 43 additions & 50 deletions src/meta_schedule/mutator/mutate_compute_location.cc
@@ -25,7 +25,7 @@ using tir::Instruction;
using tir::InstructionKind;
using tir::Trace;

/*! \brief Create a Mutator that mutates auto unroll step */
/*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */
class MutateComputeLocationNode : public MutatorNode {
public:
/*! \brief JSON representation of the workload */
@@ -36,7 +36,18 @@ class MutateComputeLocationNode : public MutatorNode {
TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode);

public:
struct Candidate;
struct Candidate {
/*! \brief The SampleComputeLocation instruction */
Instruction inst;
/*! \brief The candidate compute-at locations */
std::vector<int> locs;

explicit Candidate(Instruction inst, std::vector<int> locs)
: inst(std::move(inst)), locs(std::move(locs)) {}
};

std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state);

// Inherit from `MutatorNode`
void InitializeWithTuneContext(const TuneContext& context) final {
this->json_mod_ = SaveJSON(context->mod.value());
@@ -45,60 +56,47 @@ class MutateComputeLocationNode : public MutatorNode {
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
};

/*! \brief The candidate to be mutated */
struct MutateComputeLocationNode::Candidate {
/*! \brief The SampleComputeLocation instruction */
Instruction inst;
/*! \brief The candidate compute locations */
std::vector<int> locs;

explicit Candidate(Instruction inst, std::vector<int> locs)
: inst(std::move(inst)), locs(std::move(locs)) {}
};

/*!
* \brief Find instruction `SampleComputeLocation`
* \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated
* to at least one other value
* \param trace The trace from which to find the instructions
* \param workload The workload
* \return All the candidate instructions together with the candidate compute locations
* \return All the candidate instructions together with the candidate compute-at locations
*/
std::vector<MutateComputeLocationNode::Candidate> FindCandidates(const Trace& trace,
const tir::Schedule& sch) {
std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::FindCandidates(
const Trace& trace, TRandState* rand_state) {
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), //
/*rand_state=*/ForkSeed(rand_state), //
/*debug_mode=*/0, //
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
static InstructionKind inst_sample_compute_location =
InstructionKind::Get("SampleComputeLocation");
std::vector<MutateComputeLocationNode::Candidate> candidates;
auto f_provide_decision = [&](const tir::Instruction& inst,

auto f_provide_decision = [&](const tir::Instruction& inst, //
const Array<ObjectRef>& inputs, //
const Array<ObjectRef>& attrs,
const Array<ObjectRef>& attrs, //
const ObjectRef& decision) -> ObjectRef {
if (inst->kind.same_as(inst_sample_compute_location)) {
// The decision made
int decided = Downcast<Integer>(decision)->value;
// Extract the inputs
// Step 1. Extract the instruction input and the old decision.
ICHECK_EQ(inputs.size(), 1);
tir::BlockRV block_rv = Downcast<tir::BlockRV>(inputs[0]);
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
// Extract locations that can be computed at
Array<tir::StmtSRef> loop_srefs = CollectComputeLocation(sch->state(), block_sref);
std::vector<int> locs{-2, -1};
{
int i = 0;
for (const tir::StmtSRef& loop_sref : loop_srefs) {
int64_t extent = *tir::GetLoopIntExtent(loop_sref);
if (extent != 1 && extent != -1) {
locs.push_back(i);
}
++i;
}
tir::StmtSRef block_sref = sch->GetSRef(Downcast<tir::BlockRV>(inputs[0]));
int old_decision = Downcast<Integer>(decision)->value;
// Step 2. Collect all the compute-at locations.
Array<tir::StmtSRef> location_srefs;
std::vector<int> location_indices;
std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref);
// Step 3. Remove the old decision.
auto it = std::find(location_indices.begin(), location_indices.end(), old_decision);
if (it != location_indices.end()) {
location_srefs.erase(location_srefs.begin() + (it - location_indices.begin()));
location_indices.erase(it);
}
// Remove `decided`
std::vector<int>::iterator rm = std::find(locs.begin(), locs.end(), decided);
if (rm != locs.end()) {
locs.erase(rm);
ICHECK_EQ(location_srefs.size(), location_indices.size());
// Step 4. Add a new candidate if there is at least one remaining compute-at position.
if (!location_srefs.empty()) {
candidates.emplace_back(inst, std::move(location_indices));
}
// Add the candidate
ICHECK(!locs.empty());
candidates.emplace_back(inst, std::move(locs));
}
return decision;
};
@@ -107,12 +105,7 @@ std::vector<MutateComputeLocationNode::Candidate> FindCandidates(const Trace& tr
}

Optional<Trace> MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) {
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), //
/*rand_state=*/ForkSeed(rand_state), //
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
std::vector<Candidate> candidates = FindCandidates(trace, sch);
std::vector<Candidate> candidates = FindCandidates(trace, rand_state);
if (candidates.empty()) {
return NullOpt;
}
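Note (illustrative): Steps 2-4 above boil down to removing the current decision from the candidate set; in Python terms the filtering is roughly:

def remaining_locations(location_indices, old_decision):
    # A SampleComputeLocation instruction is mutable only if at least one
    # compute-at location other than the one already chosen remains.
    return [idx for idx in location_indices if idx != old_decision]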
4 changes: 4 additions & 0 deletions src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -100,6 +100,10 @@ Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::
const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
ICHECK_GT(axes.size(), num_spatial_loops);

// Annotate that the rfactor block, which is now the producer of the original block, needs to be
// considered by the rule Random-Compute-Location.
sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true));
res.push_back(sch_tmp);
}

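Note (illustrative): the rfactor-plus-annotation pattern above corresponds roughly to the following schedule steps; the block names and the split factor are assumptions.

# `sch` is a tvm.tir.Schedule over a hypothetical reduction workload (e.g. a matmul).
block = sch.get_block("C")                 # the reduction block
i, j, k = sch.get_loops(block)             # k is the reduction loop
ko, ki = sch.split(k, factors=[None, 16])
rf_block = sch.rfactor(ki, factor_axis=0)  # rf_block becomes a producer of `block`
# Mark the original block so that RandomComputeLocation later samples a
# compute-at location for the new producer block.
sch.annotate(block, "meta_schedule.random_compute_producer", 1)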
2 changes: 2 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -293,6 +293,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) {
return {sch};
}
sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);

std::vector<State> states{State(sch, block_rv)};
states = SubRule(std::move(states), [&](State state) { return DetectTensorCore(state); });
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
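Note (illustrative): the annotation records the tiling structure on the block before the tiling sub-rules run, so later rules (e.g. condition 5 of RandomComputeLocation below) can tell that the block has been tiled. The structure string shown is only an example.

# `sch` and `block` as in the sketches above.
sch.annotate(block, "meta_schedule.tiling_structure", "SSRSRS")
# Later rules can read the annotation back via sch.get(block).annotations.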
83 changes: 59 additions & 24 deletions src/meta_schedule/schedule_rule/random_compute_location.cc
@@ -23,30 +23,39 @@ namespace meta_schedule {

class RandomComputeLocationNode : public ScheduleRuleNode {
public:
bool IsFreeBlock(const tir::Schedule sch, const tir::StmtSRef& block_sref) const {
bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);

// Cond 1. The block is not the root block.
if (block_sref->parent == nullptr) {
return false;
}
if (!tir::IsSubrootBlock(sch->state(), block_sref)) {
// Cond 2. The block should be the direct child block of the root block.
if (GetScopeRoot(sch->state(), block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false)
->parent != nullptr) {
return false;
}
tir::ScheduleState state = sch->state();
if (!tir::IsCompleteBlock(state, block_sref,
tir::GetScopeRoot(state, block_sref, false, false))) {
// Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
// block.
Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
if (loop_srefs.empty()) {
return false;
}
Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
for (const tir::StmtSRef& loop_sref : loop_srefs) {
if (!tir::HasSingleChild(loop_sref)) {
return false;
}
if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
return false;
}
Array<PrimExpr> binds = tir::GetBlockRealize(state, block_sref)->iter_values;
for (const PrimExpr& bind : binds) {
if (!bind->IsInstance<IntImmNode>() && !bind->IsInstance<tir::VarNode>()) {
return false;
}
// Cond 5. The block is not tiled. We check this condition by examining the block's annotation.
if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
return false;
}
// Cond 6. The block has at least one consumer.
if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
return false;
}

return true;
}

@@ -55,18 +64,44 @@ class RandomComputeLocationNode : public ScheduleRuleNode {

// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
if (!IsFreeBlock(sch, block_sref)) {
if (!CheckConditions(sch, block_rv)) {
return {sch};
}
Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv);
if (consumers.size() != 1) {
return {sch};

// Step 1. If the producer of the input block needs a random compute-at location (specified by
// the annotation), we collect the producer first and transform the producer block later.
// - The reason we collect the producer before transforming the input block is that, if the
// decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
// access the input block. Hence we collect its producer ahead of time.
// - Note that only a single producer is allowed in this case.
Array<tir::BlockRV> producers{nullptr};
if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
true)) {
producers = sch->GetProducers(block_rv);
sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
ICHECK_EQ(producers.size(), 1);
}
tir::BlockRV consumer = consumers[0];
// Try to compute `block_rv` at `consumer`

// Step 2. Transform the input block.
tir::Schedule res = RandomlyComputeAt(sch, block_rv);

// Step 3. Transform the producer block if compute-location sampling is needed.
if (producers.defined()) {
res = RandomlyComputeAt(res, producers[0]);
}

return {res};
}

/*!
* \brief Keep sampling a compute-at location for the input block until success.
* \param sch The TIR schedule
* \param block_rv The block whose compute-at location is to be sampled
* \return The TIR schedule after transformation
*/
tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
for (;;) {
tir::LoopRV compute_at_loc = sch->SampleComputeLocation(consumer);
tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
try {
sch->ComputeAt(block_rv, compute_at_loc, true);
} catch (const dmlc::Error& e) {
@@ -79,7 +114,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
}
break;
}
return {sch};
return sch;
}

public:
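Note (illustrative): a rough Python rendering of the RandomlyComputeAt retry loop above; the body of the catch block is elided in this diff, so the cleanup on failure is only sketched.

def randomly_compute_at(sch, block):
    # Keep sampling until compute-at succeeds; an invalid sample is discarded
    # and a new location is drawn.
    while True:
        loc = sch.sample_compute_location(block)
        try:
            sch.compute_at(block, loc, preserve_unit_loops=True)
        except Exception:
            continue  # the exact cleanup is elided in the diff above
        break
    return sch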