Skip to content

Commit 562e93a

Browse files
jroesch and electriclilies
authored and committed
Refactor the compile engine into a cleaner interface.
Duplicate the CompileEngine interface. Refactor the graph_runtime_codegen to invoke the new LowerTE pass More changes Things appear to be working Some tracing to get Relay code to flow through too. Disable some assertions as exp. Tweak printing for now Fix a few bugs: (#13) 1. Don't add relay main function to list of lowered TIR functions 2. Don't skip visiting call to relay function in graph runtime codegen Remove debug prints. Start refactoring Split out shared data structures Fix implicit duplicate decl of IsDynamic Clean up handling of name + global prim fn Clean up the code and debug issue introduced by previous hack Clean up the debugging Do C++ lint clean up Update src/relay/backend/graph_executor_codegen.cc Co-authored-by: Chris Sullivan <csullivan@octoml.ai> Clean up handling of external functions Add more error messages More clean up Update src/runtime/graph_executor/graph_executor.cc Co-authored-by: Chris Sullivan <csullivan@octoml.ai> Update src/runtime/graph_executor/graph_executor.cc Co-authored-by: Chris Sullivan <csullivan@octoml.ai> Update src/relay/backend/te_compiler.h Co-authored-by: Haichen Shen <shenhaichen@gmail.com> Update src/relay/backend/te_compiler.h Co-authored-by: Haichen Shen <shenhaichen@gmail.com> Fix CR More CR Format Fix lowering path for C++ Fix tests Remove uncessary change Clean up a few more things CI fix Fix the default context Fix Fix broken test cases Update Fix WIP Clean up storage data structures WIP WIP Fix build errors Remove TVMLower Fix lint Lint again fix black Move UpdateMainWorkspaceSize into te_compiler.cc Fix link errors Formatting Change UpdateMainWorkspaceSize to return Map<String, FunctionInfo> Workaround for GCC 5 error caused by enums in maps (GCC 5 is on i386 CI) Testing how functions should be named Lint Change how function metadata is updated Attempt to update aot_executor_codegen to use new StaticMemoryPlan instead of storage_device_map Pass memory plan through LowerTE into UpdateMainWorkspaceSize so 
that we don't need to run GraphPlanMemory an extra time Fix return in UpdateMainWorkspaceSize Lint Try to fix UpdateMainWorkspaceSize Fix construction of static memory plan Clean up code while debugging Adding UpdateWorkspaceSize back Add closure + call to UpdateFunctionMetadata (WIP) UpdateFunctionMetadata builds; weird error with device ctx map though. Not sure if it came from this change or something else Add some debugging of UpdateMainWorkspaceSize Starting to move UpdateFunctionMetadata call to use process_fn infra What target should be passed to UpdateFunctionMetadata? UpdateFunctionMetadata is not working Added some comments about UpdateFunctionMetadata for Jared Fix the creation of function metadata Try another stab at cleaning up the information Fix Port StorageInfo and StaticMemoryPlan data structure (apache#8297) Restoring reshape opt Fix tests Caught a nasty typo from Lily, Map::Set does not mutate Format Disable stupid Google style warning Rebase cleanup Formatting Add docstring for storage info Black Post rebase fix Remove prints Disable assert that doesn't make sense for now Fix lint Add copying attrs from relay node to graph node; still need to figure out how to do this in the case of global vars Work with Lily to fix graph attrs Try to figure out where extra arguments are coming from; fix merge passes the profiling test Clean up Fix profile test Remove debugging Add attributes for BYOC uTVM case Format Dumb typo Another fix for byoc Format Fix last 3 failing tests Format Fix final two test cases Format Fix lint Fix again Fix Fix auto scheduler code Fix issue Address CR comment Format
1 parent 8fb4cdf commit 562e93a

29 files changed

+2340
-1158
lines changed

include/tvm/relay/attrs/annotation.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
6767
}
6868
};
6969

70+
/*!
71+
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
72+
*/
73+
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
74+
/*! \brief The metadata attached to the call node. */
75+
Map<String, ObjectRef> metadata;
76+
77+
TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
78+
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
79+
}
80+
};
81+
7082
} // namespace relay
7183
} // namespace tvm
7284
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_

python/tvm/auto_scheduler/relay_integration.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def auto_schedule_topi(func_name, outs):
318318
A tuned schedule or none (if not tuned) in the final build mode;
319319
None in the tracing mode so that the fallback topi schedule will be used.
320320
"""
321+
321322
# pylint: disable=import-outside-toplevel
322323
from tvm.auto_scheduler.measure import (
323324
prepare_input_map,
@@ -376,6 +377,15 @@ def auto_schedule_topi(func_name, outs):
376377
return schedule
377378

378379

380+
@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights")
381+
def te_compiler_update_weights(function_weights):
382+
"""A callback for updating the weights of extracted tasks."""
383+
env = TracingEnvironment.current
384+
if env is not None:
385+
for key in env.wkl_key_to_weight:
386+
env.wkl_key_to_weight[key] = function_weights[key[0]]
387+
388+
379389
def tensor_no_check_call(self, *indices):
380390
"""An indexing function without any check.
381391
This is the same as `tvm.te.Tensor::__call__` except that the safety

python/tvm/auto_scheduler/task_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def pre_tune(self, task_scheduler, task_id):
598598

599599
# overall info
600600
if all(cost < 1e9 for cost in task_scheduler.best_costs):
601-
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
601+
total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3)
602602
else:
603603
total_latency_str = "-"
604604
print(

python/tvm/relay/backend/compile_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def dump(self):
429429
res += "------------------------------------\n"
430430
res += "target={}\n".format(k.target)
431431
res += "use_count={}\n".format(v.use_count)
432-
res += "func_name={}\n".format(v.cached_func.func_name)
432+
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
433433
res += "----relay function----\n"
434434
res += k.source_func.astext() + "\n"
435435
res += "----tir function----- \n"
@@ -444,7 +444,7 @@ def dump(self):
444444
res += "------------------------------------\n"
445445
res += "target={}\n".format(k.target)
446446
res += "use_count={}\n".format(v.use_count)
447-
res += "func_name={}\n".format(v.cached_func.func_name)
447+
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
448448
res += "----relay function----\n"
449449
res += k.source_func.astext() + "\n"
450450
res += "----tir function----- \n"

python/tvm/relay/expr.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import tvm._ffi
2424
from tvm._ffi import base as _base
2525
from tvm.runtime import NDArray, ndarray as _nd
26-
from tvm.ir import RelayExpr, GlobalVar
26+
from tvm.ir import RelayExpr, GlobalVar, Node
2727

2828
from .base import RelayNode
2929
from . import _ffi_api
@@ -538,3 +538,25 @@ def bind(expr, binds):
538538
The expression or function after binding.
539539
"""
540540
return _ffi_api.Bind(expr, binds)
541+
542+
543+
@tvm._ffi.register_object("relay.StorageInfo")
544+
class StorageInfo(Node):
545+
"""StorageInfo
546+
547+
The static storage information produced by memory planning.
548+
Contains the storage ids where expressions are stored, the
549+
type of the "virtual devices" the expressions are stored on,
550+
and the sizes of each storage element."""
551+
552+
@property
553+
def storage_ids(self):
554+
return _ffi_api.StorageInfoStorageIds(self)
555+
556+
@property
557+
def device_types(self):
558+
return _ffi_api.StorageInfoDeviceTypes(self)
559+
560+
@property
561+
def storage_sizes(self):
562+
return _ffi_api.StorageInfoStorageSizes(self)

src/driver/driver_api.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,14 +437,18 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
437437
}
438438

439439
if (target->kind->device_type == kDLCPU && target_host == target) {
440-
ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
441-
<< "and host_target are both llvm target."
442-
<< "\n";
440+
// TODO(@jroesch): This check is no longer true we need to figure out if we care about this.
441+
// We need to relax this check for just TIR functions.
442+
// ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
443+
// << "and host_target are both llvm target."
444+
// << "\n";
443445
}
444446

445447
return {mhost, mdevice};
446448
}
447449

450+
// Can we make this take one annotated IRModule?
451+
//
448452
// Build for heterogeneous execution.
449453
runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& target_host_arg) {
450454
auto pass_ctx = transform::PassContext::Current();

src/relay/backend/aot_executor_codegen.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ class AOTExecutorCodegen : public ExprVisitor {
439439
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
440440
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
441441
}
442-
function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
442+
function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
443443
}
444444

445445
void VisitExpr_(const CallNode* op) override {
@@ -465,20 +465,18 @@ class AOTExecutorCodegen : public ExprVisitor {
465465
<< "(i.e functions composed of fusable operator invocations)";
466466
}
467467

468-
auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
469-
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
470468
Target target;
471469

472470
// Handle external function
473471
if (func->GetAttr<String>(attr::kCompiler).defined()) {
474472
target = Target("ext_dev");
475-
CCacheKey key = (*pf0)(func, target);
476-
CachedFunc ext_func = (*pf1)(compile_engine_, key, mod_name_);
473+
CCacheKey key = CCacheKey(func, target);
474+
CachedFunc ext_func = compile_engine_->Lower(key, mod_name_);
477475
ICHECK(ext_func.defined()) << "External function is not defined.";
478476
UpdateConstants(func, &params_);
479477

480478
// Generate the TIR function call
481-
CreateFuncCall(GetRef<Call>(op), ext_func->func_name);
479+
CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
482480
return;
483481
}
484482

@@ -503,8 +501,10 @@ class AOTExecutorCodegen : public ExprVisitor {
503501
}
504502
target = targets_[call_dev_type];
505503
}
506-
CCacheKey key = (*pf0)(func, target);
507-
CachedFunc lowered_func = (*pf1)(compile_engine_, key, mod_name_);
504+
505+
CCacheKey key = CCacheKey(func, target);
506+
CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_);
507+
508508
if (!lowered_funcs_.count(target->str())) {
509509
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
510510
}
@@ -513,7 +513,7 @@ class AOTExecutorCodegen : public ExprVisitor {
513513
UpdateFunctionMetadata(lowered_func, func, target);
514514

515515
// Generate the TIR function call
516-
CreateFuncCall(GetRef<Call>(op), lowered_func->func_name);
516+
CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
517517
}
518518

519519
void VisitExpr_(const VarNode* op) override {

0 commit comments

Comments
 (0)