Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 23 additions & 38 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime.
*/

#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
Expand All @@ -40,31 +39,6 @@ namespace backend {

using TargetsMap = Map<tvm::Integer, tvm::Target>;

/*!
* \brief Context index to Target
*/
struct ContextTargetMap {
static const std::unordered_map<int, tvm::Target> mask2str;
static tvm::Target Mask2Str(int mask) {
CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
return mask2str.at(mask);
}
};

const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
{1, tvm::Target::create("llvm")},
{2, tvm::Target::create("cuda")},
{4, tvm::Target::create("opencl")},
{5, tvm::Target::create("aocl")},
{6, tvm::Target::create("sdaccel")},
{7, tvm::Target::create("vulkan")},
{8, tvm::Target::create("metal")},
{9, tvm::Target::create("vpi")},
{10, tvm::Target::create("rocm")},
{11, tvm::Target::create("opengl")},
{12, tvm::Target::create("ext_dev")}
};

/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
Expand Down Expand Up @@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \return Array<StringImm> names of params
*/
Array<HalideIR::Expr> ListParamNames() {
Array<HalideIR::Expr> ret;
Array<tvm::Expr> ListParamNames() {
Array<tvm::Expr> ret;
for (const auto& kv : params_) {
ret.push_back(ir::StringImm::make(kv.first));
}
Expand Down Expand Up @@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (cfg.pass_enabled("AlterOpLayout")) {
if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
auto enter_pf = GetPackedFunc("_EnterTargetScope");
auto exit_pf = GetPackedFunc("_ExitTargetScope");
for (const auto& kv : targets) {
(*enter_pf)(kv.second);
TargetContext tctx(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
(*exit_pf)();
}
} else {
LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
Expand All @@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode {
}
return func;
}

/*!
* \brief Create a default type.
* \param device_type The device type index.
* \return the default target for the device.
*/
Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target::create("llvm");
if (name == "gpu") return Target::create("cuda");
return Target::create(name);
}
/*!
* \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided.
Expand All @@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (tmp_map.count(cfg.fallback_device) == 0) {
device_target.Set(
cfg.fallback_device,
ContextTargetMap::Mask2Str(cfg.fallback_device));
CreateDefaultTarget(cfg.fallback_device));
}
return device_target;
}
Expand All @@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param targets_map_ptr
* \return Function
*/
Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
Function RunDeviceAnnotationPass(Function func,
const RelayBuildConfig& cfg,
TargetsMap* targets_map_ptr) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
Expand All @@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode {
"relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
if (annotation_map.size() == 0) {
targets_map_ptr->Set(
0, ContextTargetMap::Mask2Str(cfg.fallback_device));
0, CreateDefaultTarget(cfg.fallback_device));
} else {
int64_t dev_type = -1;
for (auto kv : annotation_map) {
Expand All @@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode {
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
}
}
return func;
Expand Down Expand Up @@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() {
return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RelayBuildCreate();
});

Expand Down