Skip to content

Commit

Permalink
removed create_schedule stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 10, 2022
1 parent fe1ec9d commit 8e4f69e
Showing 1 changed file with 39 additions and 37 deletions.
76 changes: 39 additions & 37 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "../op/memory/memory.h"
#include "../transforms/pass_utils.h"
#include "tvm/relay/op_strategy.h"
#include "tvm/tir/function.h"
#include "utils.h"

namespace tvm {
Expand Down Expand Up @@ -115,7 +116,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
return res;
}

// Construct a schedule for a given Relay primitive function and target.
// Lowers Relay primitive Function to TE Compute
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit LowerToTECompute(Target target)
Expand All @@ -133,7 +134,21 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
memo_[param] = inputs;
}
readable_name_stream_ << "fused";
return this->VisitExpr(relay_func->body);

Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);

candidate_name_ = readable_name_stream_.str();
constexpr static size_t kMaxFuncNameLength = 80;
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
// whenever the value of kMaxFuncNameLength changes
if (candidate_name_.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
candidate_name_ = truncated_name.str();
}

return outputs;
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
Expand Down Expand Up @@ -260,11 +275,12 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
Array<tvm::te::Tensor> fn_inputs_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
std::ostringstream readable_name_stream_;
std::string candidate_name_;
OpImplementation anchor_implementation_;

private:
tvm::Target target_;
std::ostringstream readable_name_stream_;
// Index of the global constants
static int const_index;
// Cache device copy op for equivalence checking to reduce registry lookup
Expand All @@ -277,33 +293,20 @@ int LowerToTECompute::const_index = 0;
// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : ExprVisitor {
public:
explicit ScheduleBuilder(Target target, bool create_schedule = true)
: target_(target), create_schedule_(create_schedule) {
explicit ScheduleBuilder(Target target) : target_(target) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
LowerToTECompute lower_te_compute(target_);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
std::string candidate_name = lower_te_compute.readable_name_stream_.str();
VisitExpr(relay_func->body);

constexpr static size_t kMaxFuncNameLength = 80;
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
// whenever the value of kMaxFuncNameLength changes
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}

// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
auto prim_fn_var = GlobalVar(renamer(candidate_name));
auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
prim_fn_var->checked_type_ = relay_func->checked_type();

// Fusion over tupled results may leave identity relationships
Expand All @@ -319,7 +322,7 @@ class ScheduleBuilder : ExprVisitor {
te::Schedule schedule{nullptr};
tir::PrimFunc prim_func{nullptr};
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
Expand All @@ -330,7 +333,7 @@ class ScheduleBuilder : ExprVisitor {
schedule = Downcast<te::Schedule>(obj);
}
}
if (use_meta_schedule_) {
if (backend::IsMetaScheduleEnabled()) {
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
Expand Down Expand Up @@ -368,18 +371,16 @@ class ScheduleBuilder : ExprVisitor {
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

if (create_schedule_) {
int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
}
int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
}
}

Expand All @@ -389,8 +390,6 @@ class ScheduleBuilder : ExprVisitor {
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
bool use_auto_scheduler_;
bool use_meta_schedule_;
bool create_schedule_;
};

/*!
Expand Down Expand Up @@ -775,9 +774,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) {
return name;
});
auto tgt = tvm::Target("ext_dev");
LowerToTECompute lower_te_compute(tgt);
auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; });
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
outputs, te::Schedule(), tir::PrimFunc(), {},
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
});

} // namespace tec
Expand Down

0 comments on commit 8e4f69e

Please sign in to comment.