[TECompiler] Decouple TE compute and schedule lowering in ScheduleBuilder #10561

Merged: 6 commits, merged Mar 11, 2022
Changes from all commits
260 changes: 146 additions & 114 deletions src/relay/backend/te_compiler_cache.cc
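In outline, the patch splits the old single-pass ScheduleBuilder in two: LowerToTECompute translates a Relay primitive Function into TE compute only, and a slimmer ScheduleBuilder layers anchor-op selection and schedule creation on top of it. The driver below is a hypothetical sketch of the resulting flow, not part of the patch; it assumes it lives inside te_compiler_cache.cc (both classes are file-local there), and only the class and member names are taken from the diff.

// Hypothetical driver sketching the two phases (names from the diff below).
CachedFunc LowerAndScheduleSketch(const Function& relay_func, Target target) {
  auto renamer = [](std::string name) { return name; };

  // Phase 1: Relay -> TE compute only, now usable on its own (e.g. for task
  // extraction). Byproducts are exposed as public members: fn_inputs_,
  // constant_tensors_, op_implementations_, candidate_name_.
  LowerToTECompute lower_te_compute(target);
  Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
  (void)outputs;  // ScheduleBuilder::Create re-runs phase 1 internally

  // Phase 2: anchor-op selection and schedule creation on top of phase 1.
  // Callers that need a schedule make only this one call.
  return ScheduleBuilder(target).Create(relay_func, renamer);
}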
@@ -28,11 +28,13 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>
#include <tvm/topi/tags.h>

#include <functional>
@@ -114,100 +116,40 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
return res;
}

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
// Lowers Relay primitive Function to TE Compute
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleBuilder(Target target, bool create_schedule = true)
: target_(target),
device_copy_op_(Op::Get("device_copy")),
create_schedule_(create_schedule) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
}
explicit LowerToTECompute(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
Array<tvm::te::Tensor> fn_inputs;
Array<te::Tensor> Lower(const Function& relay_func,
std::function<std::string(std::string)> renamer) {
for (Var param : relay_func->params) {
Array<tvm::te::Tensor> inputs;
for (const auto& ttype : FlattenTupleType(param->checked_type())) {
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
fn_inputs.push_back(tensor);
inputs.push_back(tensor);
fn_inputs_.push_back(tensor);
}
memo_[param] = inputs;
}
readable_name_stream_ << "fused";
auto outputs = this->VisitExpr(relay_func->body);
auto candidate_name = readable_name_stream_.str();

Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);

candidate_name_ = readable_name_stream_.str();

constexpr static size_t kMaxFuncNameLength = 80;
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
// whenever the value of kMaxFuncNameLength changes
if (candidate_name.size() > kMaxFuncNameLength) {
if (candidate_name_.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}

// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
auto prim_fn_var = GlobalVar(renamer(candidate_name));
prim_fn_var->checked_type_ = relay_func->checked_type();

// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non PlaceholderOp outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
if (!tensor->op.as<te::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}

te::Schedule schedule{nullptr};
tir::PrimFunc prim_func{nullptr};
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
}
if (use_meta_schedule_) {
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
prim_func = tir::PrimFunc(nullptr);
}
}

// Use the TOPI schedule if the user specified one, or if the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
ICHECK(anchor_implementation_.defined());
schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
}
if (schedule.defined()) {
for (const auto& scalar : scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
candidate_name_ = truncated_name.str();
}

return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
return outputs;
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
@@ -254,7 +196,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
}

Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

@@ -278,28 +219,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

Array<te::Tensor> outputs;
OpImplementation impl;
// TODO(mbs): device_copy cleanup
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";

LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
outputs = lowered_out->outputs;
impl = lowered_out->implementation;

if (create_schedule_) {
int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
anchor_implementation_ = impl;
}
}
Array<te::Tensor> outputs = lowered_out->outputs;
op_implementations_[op.operator->()] = lowered_out->implementation;

if (outputs.size() != 1) {
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
ICHECK(tuple_type) << "Expected output to be a tuple type "
@@ -308,8 +234,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
ICHECK_EQ(tuple_type->fields.size(), outputs.size());
}

// TODO(mbs): device_copy cleanup
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
readable_name_stream_ << '_' << op->name;
return outputs;
}
@@ -347,26 +271,131 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
return {tuple[op->index]};
}

public:
// Additional outputs
Array<tvm::te::Tensor> fn_inputs_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
std::string candidate_name_;

private:
tvm::Target target_;
Op anchor_op_;
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
OpImplementation anchor_implementation_;
std::ostringstream readable_name_stream_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
bool use_auto_scheduler_;
bool use_meta_schedule_;
// Index of the global constants
static int const_index;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
bool create_schedule_;
// Index of the global constants
static int const_index;
};

int ScheduleBuilder::const_index = 0;
int LowerToTECompute::const_index = 0;

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : public ExprVisitor {
public:
explicit ScheduleBuilder(Target target) : target_(target) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
Member: While you are refactoring this code it might be good to better annotate it with comments describing the pieces here, as this code (due to me and Mark refactoring) now has quite a few branches.

masahi (Member, Author), Mar 10, 2022: Yeah, this function is not only used for ordinary lowering purposes, but is also abused for ansor / meta schedule task extraction or "apply history best", depending on where it is called...

Optional<ObjectRef> opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope is very opaque, since (1) we don't know if it returns anything and (2) we don't know what concrete value it returns.

After my task extraction refactor, QueryInsideWithScope can always return a non-null, concrete Schedule object, since the None case is used only for task extraction currently. cc @junrushao1994

Member: Per discussion with Masa, we decided to just clean up meta_schedule/integration.cc and make it clearer.
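To make the opacity concrete, here is the Optional handling in isolation; a minimal sketch that mirrors the meta schedule hunk further down, where task_name, relay_mod, and tir_mod are illustrative placeholders rather than names from the patch.

// Sketch of the pattern under discussion: the query may return nothing
// (task extraction) or an object whose concrete type the caller must probe.
Optional<ObjectRef> opt_mod_or_base_func =
    meta_schedule::MetaScheduleContext::QueryInsideWithScope(
        task_name, relay_mod, target_, Array<IRModule>{tir_mod});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
  prim_func = GetRef<tir::PrimFunc>(result);  // "apply history best" hit
} else {
  prim_func = tir::PrimFunc(nullptr);  // no result; fall back to TOPI below
}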

LowerToTECompute lower_te_compute(target_);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
VisitExpr(relay_func->body);

// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
prim_fn_var->checked_type_ = relay_func->checked_type();

// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non PlaceholderOp outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
if (!tensor->op.as<te::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}

te::Schedule schedule{nullptr};
tir::PrimFunc prim_func{nullptr};
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
}
if (backend::IsMetaScheduleEnabled()) {
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
prim_func = tir::PrimFunc(nullptr);
}
}

// Use the TOPI schedule if the user specified one, or if the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
ICHECK(anchor_impl != lower_te_compute.op_implementations_.end());
schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_);
}
if (schedule.defined()) {
for (const auto& scalar : lower_te_compute.scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
}

return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
}

void VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");

ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

for (Expr arg : call_node->args) {
VisitExpr(arg);
}

int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
}
}

private:
tvm::Target target_;
Op anchor_op_;
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
bool use_auto_scheduler_;
};

/*!
* \brief Create schedule for target.
@@ -750,9 +779,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) {
return name;
});
auto tgt = tvm::Target("ext_dev");
LowerToTECompute lower_te_compute(tgt);
auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; });
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
outputs, te::Schedule(), tir::PrimFunc(), {},
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
masahi (Member, Author): cc @mbaret for this change

});
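As a usage note, callers reach this compute-only path through the packed-function registry. A minimal sketch, assuming a Relay primitive Function relay_prim_func is in scope; only the global's name and its CachedFunc return value come from the registration above.

const runtime::PackedFunc* lower_to_te = runtime::Registry::Get("relay.backend.LowerToTE");
ICHECK(lower_to_te != nullptr) << "relay.backend.LowerToTE is not registered";
CachedFunc cfunc = (*lower_to_te)(relay_prim_func);
// cfunc carries the TE inputs/outputs and constant tensors, but its schedule
// and prim_func fields are left undefined: this path lowers compute only.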

} // namespace tec