[M3c][MetaScheduler] Add ScheduleRule class & PostOrderApply space generator. (apache#9761)

* Add ScheduleRule class & PostOrderApply space generator.

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

* Fix comments & docs.

* Fix for mypy.

* Retrigger CI.

* remove get_hex_address

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
7 people authored and baoxinqi committed Dec 27, 2021
1 parent c8d6aae commit 46c64b6
Showing 20 changed files with 961 additions and 42 deletions.
195 changes: 195 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_
#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_

#include <tvm/tir/schedule/schedule.h>

namespace tvm {
namespace meta_schedule {

class TuneContext;

/*! \brief Rules to modify a block in a schedule. */
class ScheduleRuleNode : public runtime::Object {
public:
/*! \brief Virtual destructor. */
virtual ~ScheduleRuleNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) {}

/*!
* \brief Initialize the schedule rule with the tuning context.
* \param context The tuning context for initialization.
* \note This method is supposed to be called only once before every other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Apply a schedule rule to the specific block in the given schedule.
* \param sch The schedule to be modified.
* \param block The specific block to apply the schedule rule.
* \return The list of schedules generated by applying the schedule rule.
*/
virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch,
const tir::BlockRV& block) = 0;

static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object);
};

/*! \brief The schedule rule with customized methods on the python-side. */
class PyScheduleRuleNode : public ScheduleRuleNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
* \brief The function type of `Apply` method.
* \param sch The schedule to be modified.
* \param block The specific block to apply the schedule rule.
* \return The list of schedules generated by applying the schedule rule.
*/
using FApply =
runtime::TypedPackedFunc<Array<tir::Schedule>(const tir::Schedule&, const tir::BlockRV&)>;
/*!
* \brief Get the schedule rule as a string, including its name.
* \return The string of the schedule rule.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyScheduleRule's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final {
ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!";
return this->f_apply(sch, block);
}

static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
};

/*!
* \brief Managed reference to ScheduleRuleNode
* \sa ScheduleRuleNode
*/
class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief Create an auto-inline rule that inlines spatial blocks if they satisfy certain conditions
* \param into_producer Whether to allow inlining a block into its producer
* \param into_consumer Whether to allow inlining a block into its consumer
* \param into_cache_only Whether to allow inlining only into blocks generated by cache_read/write
* \param inline_const_tensor Always inline constant tensors
* \param disallow_if_then_else Always disallow if-then-else-like constructs
* \param require_ordered Always require the read-to-write mapping to be ordered
* \param require_injective Always require the read-to-write mapping to be injective
* \param disallow_op The operators that are not allowed to be auto-inlined
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AutoInline(bool into_producer, //
bool into_consumer, //
bool into_cache_only, //
bool inline_const_tensor, //
bool disallow_if_then_else, //
bool require_injective, //
bool require_ordered, //
Optional<Array<String>> disallow_op);
/*!
* \brief Create a mega rule: multi-level tiling with data reuse
* \param structure The tiling structure. Recommended:
* - 'SSRSRS' on CPU
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - NullOpt on CPU
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
* \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_max_len The maximum length of the vector lane in vectorized cooperative
* fetching. NullOpt means vectorization is disabled
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTiling(String structure, //
Optional<Array<String>> tile_binds, //
bool use_tensor_core, //
Optional<Integer> max_innermost_factor, //
Optional<Integer> vector_load_max_len, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);
/*!
* \brief A rule that randomly selects a compute-at location for a free block
* \return The rule created
*/
TVM_DLL static ScheduleRule RandomComputeLocation();
/*!
* \brief Mark each block for parallelization, vectorization, and unrolling accordingly
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
* upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
* parallelism.
* \param max_vectorize_extent The maximum extent to be vectorized.
* It sets the upper limit of CPU vectorization. Use -1 to disable vectorization.
* \param unroll_max_steps The maximum number of unroll steps to be done.
* Use an empty array to disable unrolling.
* \param unroll_explicit Whether to explicitly unroll the loop, or just add an unroll pragma.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<Integer> unroll_max_steps, //
bool unroll_explicit);
/*!
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_as_string The packed function of `AsString`.
* \return The schedule rule created.
*/
TVM_DLL static ScheduleRule PyScheduleRule(
PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyScheduleRuleNode::FApply f_apply, //
PyScheduleRuleNode::FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_
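To make the intended call pattern concrete, here is a minimal C++ sketch based only on the declarations above; the wrapper functions, the concrete argument values, and the trivial identity rule are illustrative assumptions, not part of the commit.

#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/tir/schedule/schedule.h>

namespace example {

using namespace tvm;
using namespace tvm::meta_schedule;

// Construct a built-in rule, bind it to a tuning context once, then apply it
// to one block of an existing schedule. Apply may return several candidates.
Array<tir::Schedule> ApplyAutoInline(const TuneContext& context, const tir::Schedule& sch,
                                     const tir::BlockRV& block) {
  ScheduleRule rule = ScheduleRule::AutoInline(/*into_producer=*/false,
                                               /*into_consumer=*/true,
                                               /*into_cache_only=*/false,
                                               /*inline_const_tensor=*/true,
                                               /*disallow_if_then_else=*/true,
                                               /*require_injective=*/true,
                                               /*require_ordered=*/true,
                                               /*disallow_op=*/NullOpt);
  rule->InitializeWithTuneContext(context);  // call once before any Apply
  return rule->Apply(sch, block);
}

// A trivial customized rule assembled from packed functions; it returns the
// schedule unchanged for every block.
ScheduleRule MakeIdentityRule() {
  PyScheduleRuleNode::FInitializeWithTuneContext f_init = [](const TuneContext&) {};
  PyScheduleRuleNode::FApply f_apply = [](const tir::Schedule& sch,
                                          const tir::BlockRV& block) -> Array<tir::Schedule> {
    return {sch};  // leave the block untouched
  };
  PyScheduleRuleNode::FAsString f_as_string = []() -> String { return "IdentityRule"; };
  return ScheduleRule::PyScheduleRule(f_init, f_apply, f_as_string);
}

}  // namespace example

The same three packed functions are what the Python-side PyScheduleRule is expected to forward to, so rules written in either language plug into the tuning flow the same way.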
6 changes: 3 additions & 3 deletions include/tvm/meta_schedule/search_strategy.h
@@ -104,10 +104,10 @@ class SearchStrategyNode : public runtime::Object {

/*!
* \brief Initialize the search strategy with tuning context.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
* \note This method is supposed to be called only once before every other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Pre-tuning for the search strategy.
@@ -146,7 +146,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
16 changes: 11 additions & 5 deletions include/tvm/meta_schedule/space_generator.h
@@ -71,10 +71,10 @@ class SpaceGeneratorNode : public Object {

/*!
* \brief Initialize the design space generator with tuning context.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
* \note This method is supposed to be called only once before every other method.
*/
virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;

/*!
* \brief Generate design spaces given a module.
@@ -92,7 +92,7 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param tune_context The tuning context for initialization.
* \param context The tuning context for initialization.
*/
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
/*!
@@ -112,10 +112,10 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
// `f_generate_design_space` is not visited
}

void InitializeWithTuneContext(const TuneContext& tune_context) final {
void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySpaceGenerator's InitializeWithTuneContext !";
f_initialize_with_tune_context(tune_context);
f_initialize_with_tune_context(context);
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
@@ -153,6 +153,12 @@ class SpaceGenerator : public ObjectRef {
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array<SpaceGenerator, void> space_generators);
/*!
* \brief Create a design space generator that generates design spaces by applying schedule rules
* to blocks in post-DFS order.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PostOrderApply();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};

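A short sketch of how the new generator is meant to be driven, again based only on the declarations shown here; the wrapper function is an illustrative assumption, and the expectation that the rules come from the tuning context's new sch_rules field (see the tune_context.h change below) should be checked against the PostOrderApply implementation.

#include <tvm/ir/module.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/meta_schedule/tune_context.h>

namespace example {

using namespace tvm;
using namespace tvm::meta_schedule;

// Generate design spaces for a module by applying schedule rules to its blocks
// in post-DFS order; the rules are expected to be carried by the tuning context.
Array<tir::Schedule> GenerateSpaces(const TuneContext& context, const IRModule& mod) {
  SpaceGenerator generator = SpaceGenerator::PostOrderApply();
  generator->InitializeWithTuneContext(context);
  return generator->GenerateDesignSpace(mod);
}

}  // namespace example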
5 changes: 5 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
@@ -38,6 +38,8 @@ class TuneContextNode : public runtime::Object {
Optional<SpaceGenerator> space_generator;
/*! \brief The search strategy. */
Optional<SearchStrategy> search_strategy;
/*! \brief The schedule rules. */
Array<ScheduleRule> sch_rules;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
@@ -57,6 +59,7 @@ class TuneContextNode : public runtime::Object {
v->Visit("target", &target);
v->Visit("space_generator", &space_generator);
v->Visit("search_strategy", &search_strategy);
v->Visit("sch_rules", &sch_rules);
v->Visit("task_name", &task_name);
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
@@ -81,6 +84,7 @@ class TuneContext : public runtime::ObjectRef {
* \param target The target to be tuned for.
* \param space_generator The design space generator.
* \param search_strategy The search strategy.
* \param sch_rules The schedule rules.
* \param task_name The name of the tuning task.
* \param rand_state The random state.
* \param num_threads The number of threads to be used.
@@ -89,6 +93,7 @@ TuneContext(Optional<IRModule> mod, //
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Optional<Array<ScheduleRule>> sch_rules, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
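A hedged sketch of wiring the pieces together through the extended constructor; the chosen rule, target string, task name, and numeric values are placeholders, and treating rand_state = -1 as "seed randomly" is an assumption to verify against the implementation.

#include <tvm/ir/module.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/target/target.h>

namespace example {

using namespace tvm;
using namespace tvm::meta_schedule;

// Build a tuning context that carries schedule rules for PostOrderApply to consume.
TuneContext MakeContext(IRModule mod) {
  // One illustrative rule; a real pipeline would register several.
  Array<ScheduleRule> rules{ScheduleRule::ParallelizeVectorizeUnroll(
      /*max_jobs_per_core=*/16,
      /*max_vectorize_extent=*/32,
      /*unroll_max_steps=*/{Integer(0), Integer(16), Integer(64), Integer(512)},
      /*unroll_explicit=*/true)};
  return TuneContext(/*mod=*/mod,
                     /*target=*/Target("llvm"),
                     /*space_generator=*/SpaceGenerator::PostOrderApply(),
                     /*search_strategy=*/NullOpt,
                     /*sch_rules=*/rules,
                     /*task_name=*/String("example_task"),
                     /*rand_state=*/-1,
                     /*num_threads=*/1);
}

}  // namespace example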
6 changes: 6 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -155,6 +155,12 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding loop sref
*/
virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
/*!
* \brief Check the existence of a specific BlockRV
* \param block_rv The BlockRV to be looked up
* \return Whether the corresponding block exists
*/
virtual bool HasBlock(const BlockRV& block_rv) const = 0;
/*!
* \brief Get the block/loop sref corresponding to the specific statement
* \param stmt The statement to be looked up
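A small sketch of the guard this query enables, for instance when an earlier schedule rule has inlined a block away; the helper and its callback parameter are illustrative.

#include <functional>

#include <tvm/tir/schedule/schedule.h>

namespace example {

using namespace tvm;

// Apply a callback only if the block random variable still maps to a live block.
void ApplyIfPresent(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                    const std::function<void(const tir::BlockRV&)>& apply) {
  if (sch->HasBlock(block_rv)) {
    apply(block_rv);
  }
}

}  // namespace example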
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
@@ -21,5 +21,6 @@
from . import runner
from . import space_generator
from . import search_strategy
from . import schedule_rule
from . import integration
from .tune_context import TuneContext
19 changes: 19 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -0,0 +1,19 @@
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
The tvm.meta_schedule.schedule_rule package.
Meta Schedule schedule rules are used for modification of
blocks in a schedule. See also PostOrderApply.
"""
from .schedule_rule import PyScheduleRule, ScheduleRule