-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TECompiler] Decouple TE compute and schedule lowering in ScheduleBuilder #10561
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6e68fd9
Decouple TE compute and schedule lowering in ScheduleBuilder
masahi eb1bc7e
fixed merge conflict
masahi 4cd3a16
removed create_schedule stuff
masahi 0c6d4a6
add public, fix include path convention
masahi be6c258
Forgot visiting arg in ScheduleBuilder CallNode vsit
masahi 6f01901
fixed anchor impl selection
masahi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,11 +28,13 @@ | |
#include <tvm/relay/expr_functor.h> | ||
#include <tvm/relay/op.h> | ||
#include <tvm/relay/op_attr_types.h> | ||
#include <tvm/relay/op_strategy.h> | ||
#include <tvm/runtime/device_api.h> | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/te/operation.h> | ||
#include <tvm/te/schedule.h> | ||
#include <tvm/te/schedule_pass.h> | ||
#include <tvm/tir/function.h> | ||
#include <tvm/topi/tags.h> | ||
|
||
#include <functional> | ||
|
@@ -114,100 +116,40 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) { | |
return res; | ||
} | ||
|
||
// Construct a schedule for a given Relay primitive function and target. | ||
class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> { | ||
// Lowers Relay primitive Function to TE Compute | ||
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> { | ||
public: | ||
explicit ScheduleBuilder(Target target, bool create_schedule = true) | ||
: target_(target), | ||
device_copy_op_(Op::Get("device_copy")), | ||
create_schedule_(create_schedule) { | ||
// Whether to use auto_scheduler schedule. | ||
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); | ||
use_meta_schedule_ = backend::IsMetaScheduleEnabled(); | ||
} | ||
explicit LowerToTECompute(Target target) | ||
: target_(target), device_copy_op_(Op::Get("device_copy")) {} | ||
|
||
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) { | ||
Array<tvm::te::Tensor> fn_inputs; | ||
Array<te::Tensor> Lower(const Function& relay_func, | ||
std::function<std::string(std::string)> renamer) { | ||
for (Var param : relay_func->params) { | ||
Array<tvm::te::Tensor> inputs; | ||
for (const auto& ttype : FlattenTupleType(param->checked_type())) { | ||
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); | ||
fn_inputs.push_back(tensor); | ||
inputs.push_back(tensor); | ||
fn_inputs_.push_back(tensor); | ||
} | ||
memo_[param] = inputs; | ||
} | ||
readable_name_stream_ << "fused"; | ||
auto outputs = this->VisitExpr(relay_func->body); | ||
auto candidate_name = readable_name_stream_.str(); | ||
|
||
Array<te::Tensor> outputs = this->VisitExpr(relay_func->body); | ||
|
||
candidate_name_ = readable_name_stream_.str(); | ||
|
||
constexpr static size_t kMaxFuncNameLength = 80; | ||
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME | ||
// whenever the value of kMaxFuncNameLength changes | ||
if (candidate_name.size() > kMaxFuncNameLength) { | ||
if (candidate_name_.size() > kMaxFuncNameLength) { | ||
std::stringstream truncated_name; | ||
truncated_name << candidate_name.substr(0, kMaxFuncNameLength); | ||
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_"; | ||
candidate_name = truncated_name.str(); | ||
} | ||
|
||
// TODO(mbs): This should be the definitive global by which the PrimFunc is known and | ||
// no other GlobalVar ctors should appear inside the lowering machinery. | ||
auto prim_fn_var = GlobalVar(renamer(candidate_name)); | ||
prim_fn_var->checked_type_ = relay_func->checked_type(); | ||
|
||
// Fusion over tupled results may leave identity relationships | ||
// between inputs and outputs, and those should not be scheduled. | ||
// Hence schedule only non PlaceholderOp outputs. | ||
tvm::Array<te::Tensor> tensor_outs; | ||
for (const auto& tensor : outputs) { | ||
if (!tensor->op.as<te::PlaceholderOpNode>()) { | ||
tensor_outs.push_back(tensor); | ||
} | ||
} | ||
|
||
te::Schedule schedule{nullptr}; | ||
tir::PrimFunc prim_func{nullptr}; | ||
// No need to register schedule for device copy op. | ||
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) { | ||
if (use_auto_scheduler_) { | ||
const auto* fauto_schedule = | ||
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); | ||
ICHECK(fauto_schedule != nullptr) | ||
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; | ||
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); | ||
if (obj.defined()) { | ||
schedule = Downcast<te::Schedule>(obj); | ||
} | ||
} | ||
if (use_meta_schedule_) { | ||
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); | ||
Optional<ObjectRef> opt_mod_or_base_func = | ||
meta_schedule::MetaScheduleContext::QueryInsideWithScope( | ||
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, | ||
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})}); | ||
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) { | ||
prim_func = GetRef<tir::PrimFunc>(result); | ||
} else { | ||
prim_func = tir::PrimFunc(nullptr); | ||
} | ||
} | ||
|
||
// Use TOPI schedule if user specified, or the function has no auto_scheduler schedule. | ||
if (!schedule.defined() && !prim_func.defined()) { | ||
ICHECK(anchor_implementation_.defined()); | ||
schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); | ||
} | ||
if (schedule.defined()) { | ||
for (const auto& scalar : scalars_) { | ||
if (schedule->Contain(scalar)) { | ||
schedule[scalar].compute_inline(); | ||
} | ||
} | ||
} | ||
truncated_name << candidate_name_.substr(0, kMaxFuncNameLength); | ||
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_"; | ||
candidate_name_ = truncated_name.str(); | ||
} | ||
|
||
return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, | ||
IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_); | ||
return outputs; | ||
} | ||
|
||
Array<te::Tensor> VisitExpr_(const VarNode* op) final { | ||
|
@@ -254,7 +196,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor> | |
} | ||
|
||
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final { | ||
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern"); | ||
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); | ||
ICHECK(flower_call) << "relay.backend.lower_call is not registered."; | ||
|
||
|
@@ -278,28 +219,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor> | |
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops"; | ||
Op op = Downcast<Op>(call_node->op); | ||
|
||
Array<te::Tensor> outputs; | ||
OpImplementation impl; | ||
// TODO(mbs): device_copy cleanup | ||
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; | ||
|
||
LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_); | ||
outputs = lowered_out->outputs; | ||
impl = lowered_out->implementation; | ||
|
||
if (create_schedule_) { | ||
int op_pattern = fpattern[op]; | ||
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { | ||
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) | ||
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops" | ||
<< " anchor=" << anchor_op_ << " current=" << op; | ||
} | ||
if (op_pattern >= anchor_op_pattern_) { | ||
anchor_op_ = op; | ||
anchor_attrs_ = call_node->attrs; | ||
anchor_op_pattern_ = op_pattern; | ||
anchor_implementation_ = impl; | ||
} | ||
} | ||
Array<te::Tensor> outputs = lowered_out->outputs; | ||
op_implementations_[op.operator->()] = lowered_out->implementation; | ||
|
||
if (outputs.size() != 1) { | ||
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>(); | ||
ICHECK(tuple_type) << "Expected output to be a tuple type " | ||
|
@@ -308,8 +234,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor> | |
ICHECK_EQ(tuple_type->fields.size(), outputs.size()); | ||
} | ||
|
||
// TODO(mbs): device_copy cleanup | ||
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; | ||
readable_name_stream_ << '_' << op->name; | ||
return outputs; | ||
} | ||
|
@@ -347,26 +271,131 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor> | |
return {tuple[op->index]}; | ||
} | ||
|
||
public: | ||
// Additional outputs | ||
Array<tvm::te::Tensor> fn_inputs_; | ||
Array<te::Operation> scalars_; | ||
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_; | ||
std::unordered_map<const OpNode*, OpImplementation> op_implementations_; | ||
std::string candidate_name_; | ||
|
||
private: | ||
tvm::Target target_; | ||
Op anchor_op_; | ||
Attrs anchor_attrs_; | ||
int anchor_op_pattern_{0}; | ||
OpImplementation anchor_implementation_; | ||
std::ostringstream readable_name_stream_; | ||
Array<te::Operation> scalars_; | ||
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_; | ||
bool use_auto_scheduler_; | ||
bool use_meta_schedule_; | ||
// Index of the global constants | ||
static int const_index; | ||
// Cache device copy op for equivalence checking to reduce registry lookup | ||
// overhead for each invocation of call node when retrieving schedules. | ||
const Op& device_copy_op_; | ||
bool create_schedule_; | ||
// Index of the global constants | ||
static int const_index; | ||
}; | ||
|
||
int ScheduleBuilder::const_index = 0; | ||
int LowerToTECompute::const_index = 0; | ||
|
||
// Construct a schedule for a given Relay primitive function and target. | ||
class ScheduleBuilder : public ExprVisitor { | ||
public: | ||
explicit ScheduleBuilder(Target target) : target_(target) { | ||
// Whether to use auto_scheduler schedule. | ||
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); | ||
} | ||
|
||
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) { | ||
LowerToTECompute lower_te_compute(target_); | ||
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer); | ||
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_; | ||
VisitExpr(relay_func->body); | ||
|
||
// TODO(mbs): This should be the definitive global by which the PrimFunc is known and | ||
// no other GlobalVar ctors should appear inside the lowering machinery. | ||
auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_)); | ||
prim_fn_var->checked_type_ = relay_func->checked_type(); | ||
|
||
// Fusion over tupled results may leave identity relationships | ||
// between inputs and outputs, and those should not be scheduled. | ||
// Hence schedule only non PlaceholderOp outputs. | ||
tvm::Array<te::Tensor> tensor_outs; | ||
for (const auto& tensor : outputs) { | ||
if (!tensor->op.as<te::PlaceholderOpNode>()) { | ||
tensor_outs.push_back(tensor); | ||
} | ||
} | ||
|
||
te::Schedule schedule{nullptr}; | ||
tir::PrimFunc prim_func{nullptr}; | ||
// No need to register schedule for device copy op. | ||
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) { | ||
if (use_auto_scheduler_) { | ||
const auto* fauto_schedule = | ||
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); | ||
ICHECK(fauto_schedule != nullptr) | ||
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; | ||
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); | ||
if (obj.defined()) { | ||
schedule = Downcast<te::Schedule>(obj); | ||
} | ||
} | ||
if (backend::IsMetaScheduleEnabled()) { | ||
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); | ||
Optional<ObjectRef> opt_mod_or_base_func = | ||
meta_schedule::MetaScheduleContext::QueryInsideWithScope( | ||
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, | ||
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})}); | ||
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) { | ||
prim_func = GetRef<tir::PrimFunc>(result); | ||
} else { | ||
prim_func = tir::PrimFunc(nullptr); | ||
} | ||
} | ||
|
||
// Use TOPI schedule if user specified, or the function has no auto_scheduler schedule. | ||
if (!schedule.defined() && !prim_func.defined()) { | ||
auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->()); | ||
ICHECK(anchor_impl != lower_te_compute.op_implementations_.end()); | ||
schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_); | ||
} | ||
if (schedule.defined()) { | ||
for (const auto& scalar : lower_te_compute.scalars_) { | ||
if (schedule->Contain(scalar)) { | ||
schedule[scalar].compute_inline(); | ||
} | ||
} | ||
} | ||
} | ||
|
||
return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, | ||
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_); | ||
} | ||
|
||
void VisitExpr_(const CallNode* call_node) final { | ||
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern"); | ||
|
||
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops"; | ||
Op op = Downcast<Op>(call_node->op); | ||
|
||
for (Expr arg : call_node->args) { | ||
VisitExpr(arg); | ||
} | ||
|
||
int op_pattern = fpattern[op]; | ||
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { | ||
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) | ||
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops" | ||
<< " anchor=" << anchor_op_ << " current=" << op; | ||
} | ||
if (op_pattern >= anchor_op_pattern_) { | ||
anchor_op_ = op; | ||
anchor_attrs_ = call_node->attrs; | ||
anchor_op_pattern_ = op_pattern; | ||
} | ||
} | ||
|
||
private: | ||
tvm::Target target_; | ||
Op anchor_op_; | ||
Attrs anchor_attrs_; | ||
int anchor_op_pattern_{0}; | ||
bool use_auto_scheduler_; | ||
}; | ||
|
||
/*! | ||
* \brief Create schedule for target. | ||
|
@@ -750,9 +779,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int> | |
} | ||
|
||
TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { | ||
return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { | ||
return name; | ||
}); | ||
auto tgt = tvm::Target("ext_dev"); | ||
LowerToTECompute lower_te_compute(tgt); | ||
auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; }); | ||
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_, | ||
outputs, te::Schedule(), tir::PrimFunc(), {}, | ||
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @mbaret for this change |
||
}); | ||
|
||
} // namespace tec | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While you are refactoring this code it might be good to better annotate it with comments describing the pieces here as now this code (due to me and Mark refactoring) has quite a few branches.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this function is not only used for ordinary lowering purposes, but also abused for ansor / meta schedule task extraction or "apply history best" depending on where it is called...
Optional<ObjectRef> opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope
is very opaque since (1) we don't know if it returns anything and (2) we don't know what concrete value it returns. After my task extraction refactor,
QueryInsideWithScope
can always return a non-null, concrete Schedule
object, since the None
case is used only for task extraction currently. cc @junrushao1994
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Per discussion with Masa, we decided to just clean up
meta_schedule/integration.cc
and make it clearer