Skip to content

Commit 01c73c5

Browse files
committed
[checkpoint] Bug fixes. AOT is broken.
[checkpoint] Separate eval-to-clousure and apply-closure phases at last [checkpoint] Fix GetType recursion, get debug going. [checkpoint] Audit python to collapse create_executor and evaluate phases Just a few places where this doesn't work, seem harmless. [checkpoint] Get interpreter working using tec::LowerTE, but no dynamic shapes. - Hide TECompiler impl inside te_compiler.cc. However I think it is already exposed into Python land so this probably won't be possible now. - Move 'optimize' pre-transforms from interpreter.py to interpreter.cc so can be applied uniformly to both mod and expr. - Don't push the expr into the mod in interpreter.py since it's done again in interpreter.cc. Instead just build the Call node with the reflected args. - Both the mod and the expr are prepared identically (same transforms, of which LowerTensorExpr should be one). - LowerTensorExpr can look through let-bound and global vars, eg let f = fn (..., Primitive=1) { ... } ... f(...) ==> @lowered_f = ... @lowered_f(...) - Lots of DLOGs that need to be removed or reorganized. [checkpoint] Support shape functions. TODO: - Unit tests. - Cleanup logging (VLOG?) - Don't build all prims on each apply. [checkpoint] typo [checkpoint] Don't allow evaling expr independently of preparing module. TODO: - Make eval(mod, expr) the interface. - GlobalVar's don't line up. - Rework use of interpreter in fold_constant.cc to make clear it is evaling prim calls which have already been prepared. - Find a dynamic shape example that works at HEAD. - Unit tests. [checkpoint] Interpreting expression with refs to module defs working Commit to interpreter evaling expr w.r.t. mod in single phase. Thankfully turns out no existing uses broke that assumption so we dodged a bullet. Binding of expr into mod to be evaled is a mess, needs to be fixed. Still can't confirm dynamic shapes working since don't have an example working at HEAD. Change to partial_eval needs to be tested, but smells ok. [checkpoint] Dynamic shapes working The use of TIRCallAttrs is pretty hacky but the shape function calls are at least working. Next is to tackle the 'build everything just to project one prim fun' problem. [checkpoint] Cache built prims, make sure build with minimal deps. [checkpoint] Cleanup expr-to-module hackery.
1 parent b3e832a commit 01c73c5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1120
-533
lines changed

docs/langref/relay_pattern.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ Either match the first pattern or the second pattern.
406406
Domination
407407
**********
408408

409-
Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern.
409+
Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node between the child and the pattern matches the path pattern.
410410

411411
Function Pattern
412412
****************

include/tvm/ir/module.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -307,20 +307,30 @@ class IRModule : public ObjectRef {
307307
}
308308

309309
/*!
310-
* \brief Construct a module from a standalone expression.
310+
* \brief Constructs a module from a standalone expression \p expr.
311311
*
312-
* Allows one to optionally pass a global function map and
313-
* map of type definitions as well.
312+
* If \p expr is a function it will be bound directly. Otherwise a function over the free
313+
* variables of \p expr (possibly none) with \p expr as body is created and bound.
314+
*
315+
* The function is bound to, in preference order:
316+
* - The "global_symbol" attribute of \p expr, if it is a function with that attribute.
317+
* - \p name_hint, if non-empty.
318+
* - "main"
319+
*
320+
* Additional global functions and type definitions may be included in the result module.
314321
*
315322
* \param expr The expression to set as the main function to the module.
316-
* \param global_funcs The global function map.
317-
* \param type_definitions Map of global type definitions
323+
* \param global_funcs The global function map. Default empty.
324+
* \param type_definitions Map of global type definitions. Default empty.
325+
* \param name_hint Name hint for global var. Default empty.
318326
*
319-
* \returns A module with expr set as the main function.
327+
* \returns A module with \p expr set as the main function.
320328
*/
321329
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
322330
const Map<GlobalVar, BaseFunc>& global_funcs = {},
323-
const Map<GlobalTypeVar, TypeData>& type_definitions = {});
331+
const Map<GlobalTypeVar, TypeData>& type_definitions = {},
332+
std::unordered_set<String> import_set = {},
333+
const std::string& name_hint = std::string());
324334

325335
/*!
326336
* \brief Parse text format source file into an IRModule.

include/tvm/relay/interpreter.h

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,28 +43,6 @@
4343
namespace tvm {
4444
namespace relay {
4545

46-
/*!
47-
*\brief Create a Interpreter function that can
48-
* evaluate an expression and produce a value.
49-
*
50-
* The resulting value can be passed to Python, making it easy to use
51-
* for testing and debugging.
52-
*
53-
* The interpreter interprets the program fragments not supported by the
54-
* TVM runtime, although the interpreter is naively implemented it uses
55-
* TVM operators for evaluating all operators.
56-
*
57-
* Our intent is that this will never be the most efficient implementation of
58-
* Relay's semantics, but a readable and clear one.
59-
*
60-
* \param mod The function module.
61-
* \param device The primary device that the interepreter runs on.
62-
* \param target Compiler target flag to compile the functions on the context.
63-
* \return A function that takes in an expression and returns a value.
64-
*/
65-
runtime::TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device device,
66-
Target target);
67-
6846
/*! \brief The container type of Closures used by the interpreter. */
6947
class InterpreterClosureObj : public runtime::ClosureObj {
7048
public:
@@ -164,6 +142,50 @@ class ConstructorValue : public ObjectRef {
164142
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
165143
};
166144

145+
/*!
146+
* \brief Returns a packed function over Relay expressions which will evaluate \p expr
147+
* applied to those arguments, where \p expr is w.r.t. the definitions in \p mod.
148+
*
149+
* This function is intended to support the Python 'debug' executor.
150+
*
151+
* The given \p expr should have function type. The given \p mod may be empty or
152+
* undefined if \p expr is self-contained. Relay arguments passed to the result
153+
* packed function must be constants, references, or constructors/tuples over such.
154+
* As much work as possible is done while constructing the result packed function, and
155+
* that function may be reasonably efficiently applied multiple times without redoing
156+
* unnecessary work.
157+
*
158+
* Primitives are lowered and compiled to packed functions for execution on \p device
159+
* with properties given by \p target. All other Relay constructs are interpreted.
160+
*
161+
* The interpreter is intended to be a 'reference' implementation of the Relay semantics
162+
* for testing and interactive use. It is not intended to be particularly efficient.
163+
*
164+
* \param mod A module containing definitions which can be referenced from
165+
* \p expr. May be empty or undefined.
166+
* \param expr An expression of function type to evaluate. May reference definitions from \p mod.
167+
* \param device The device on which all primitives will be executed.
168+
* \param target The compiler target flag for compiling primitives.
169+
* \return A packed function that takes an array of Relay expressions and returns the
170+
* result of applying \p expr to those arguments.
171+
*/
172+
TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device,
173+
Target target);
174+
175+
/*!
176+
* \brief Evaluates \p expr and returns its result.
177+
*
178+
* This function is intended to support TVM constant evaluation.
179+
*
180+
* @param expr An expression to evaluate.
181+
* \param device The device on which all primitives will be executed.
182+
* \param target The compiler target flag for compiling primitives.
183+
* @return The object representing the result.
184+
*/
185+
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
186+
std::unordered_set<String> import_set, Device device, Target target);
187+
167188
} // namespace relay
168189
} // namespace tvm
190+
169191
#endif // TVM_RELAY_INTERPRETER_H_

include/tvm/runtime/logging.h

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -487,41 +487,53 @@ TVM_CHECK_FUNC(_NE, !=)
487487
#define DLOG_IF(severity, condition) \
488488
LOG_IF(severity, ::tvm::runtime::detail::DebugLoggingEnabled() && (condition))
489489

490+
#ifdef VLOG_LEVEL
491+
#define VLOG(level) DLOG_IF(INFO, (level) <= VLOG_LEVEL)
492+
#else
493+
#define VLOG(level) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(INFO)
494+
#endif
495+
490496
#else
491497

492498
#define LOG_DFATAL LOG_ERROR
493499
#define DFATAL ERROR
494500
#define DLOG(severity) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
495501
#define DLOG_IF(severity, condition) \
496502
(true || !(condition)) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
503+
#define VLOG(level) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(INFO)
497504

498505
#endif
499506

500507
#if TVM_LOG_DEBUG
508+
#define DCHECK(x) CHECK(x)
509+
#define DCHECK_LT(x, y) CHECK((x) < (y))
510+
#define DCHECK_GT(x, y) CHECK((x) > (y))
511+
#define DCHECK_LE(x, y) CHECK((x) <= (y))
512+
#define DCHECK_GE(x, y) CHECK((x) >= (y))
513+
#define DCHECK_EQ(x, y) CHECK((x) == (y))
514+
#define DCHECK_NE(x, y) CHECK((x) != (y))
515+
#else
501516
#define DCHECK(x) \
502-
while (false) CHECK(x)
517+
while (false) CHECK(x)
503518
#define DCHECK_LT(x, y) \
504-
while (false) CHECK((x) < (y))
519+
while (false) CHECK((x) < (y))
505520
#define DCHECK_GT(x, y) \
506-
while (false) CHECK((x) > (y))
521+
while (false) CHECK((x) > (y))
507522
#define DCHECK_LE(x, y) \
508-
while (false) CHECK((x) <= (y))
523+
while (false) CHECK((x) <= (y))
509524
#define DCHECK_GE(x, y) \
510-
while (false) CHECK((x) >= (y))
525+
while (false) CHECK((x) >= (y))
511526
#define DCHECK_EQ(x, y) \
512-
while (false) CHECK((x) == (y))
527+
while (false) CHECK((x) == (y))
513528
#define DCHECK_NE(x, y) \
514-
while (false) CHECK((x) != (y))
529+
while (false) CHECK((x) != (y))
530+
#endif
531+
532+
#if TVM_LOG_DEBUG
515533
#else
516-
#define DCHECK(x) CHECK(x)
517-
#define DCHECK_LT(x, y) CHECK((x) < (y))
518-
#define DCHECK_GT(x, y) CHECK((x) > (y))
519-
#define DCHECK_LE(x, y) CHECK((x) <= (y))
520-
#define DCHECK_GE(x, y) CHECK((x) >= (y))
521-
#define DCHECK_EQ(x, y) CHECK((x) == (y))
522-
#define DCHECK_NE(x, y) CHECK((x) != (y))
523534
#endif
524535

536+
525537
#define TVM_ICHECK_INDENT " "
526538

527539
#define ICHECK_BINARY_OP(name, op, x, y) \
@@ -552,5 +564,8 @@ TVM_CHECK_FUNC(_NE, !=)
552564
// Re-export error types
553565
using runtime::Error;
554566
using runtime::InternalError;
567+
568+
void InitLogging();
569+
555570
} // namespace tvm
556571
#endif // TVM_RUNTIME_LOGGING_H_

python/tvm/relay/backend/interpreter.py

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -196,65 +196,55 @@ class Interpreter(Executor):
196196
197197
target : tvm.Target
198198
The target option to build the function.
199+
200+
CAUTION: Despite the API the module is prepared upon each call to evaluate
201+
rather than once in create_executor.
202+
That is:
203+
.. code-block:: python
204+
205+
executor = relay.create_executor(kind="debug", mod=module)
206+
a = executor.evaluate(expr)(args1)
207+
b = executor.evaluate(expr)(args2)
208+
209+
will prepare all the bindings in module twice. For efficiency, try to hoist
210+
calls to evaluate as high as possible, preferably immediately after create_executor:
211+
.. code-block:: python
212+
213+
func = relay.create_executor(kind="debug", mod=module).evaluate(expr)
214+
a = func(args1)
215+
b = func(args2)
199216
"""
200217

201218
def __init__(self, mod, device, target):
202219
self.mod = mod
203220
self.device = device
204221
self.target = target
205222

206-
def optimize(self):
207-
"""Optimize functions in a module.
208-
209-
Returns
210-
-------
211-
opt_mod : tvm.IRModule
212-
The optimized module.
213-
"""
214-
seq = tvm.transform.Sequential(
215-
[
216-
# tvm.parser.AnnotateSpans(),
217-
transform.SimplifyInference(),
218-
transform.FuseOps(0),
219-
transform.ToANormalForm(),
220-
transform.InferType(),
221-
]
222-
)
223-
mod = seq(self.mod)
224-
return mod
225-
226223
def _make_executor(self, expr=None):
227224
if expr is None or isinstance(expr, GlobalVar):
228225
assert self.mod is not None
229226

230-
_intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target)
227+
if expr is None:
228+
# A missing expr denotes 'main' in the given module.
229+
expr = self.mod.get_global_var("main")
231230

232-
def _interp_wrapper(*args, **kwargs):
233-
if expr is None:
234-
args = self._convert_args(self.mod["main"], args, kwargs)
235-
else:
236-
args = self._convert_args(expr, args, kwargs)
231+
# Evaluate expr to a packed function we can efficiently re-apply
232+
# to Relay arguments.
233+
print("before EvalFunction\n")
234+
func = _backend.EvalFunction(self.mod, expr, self.device, self.target)
235+
print("after EvalFunction\n")
237236

237+
def _apply_args(*args, **kwargs):
238+
if expr is GlobalVar:
239+
# When expanding args, look inside the actual global definition so kwargs can be matched.
240+
args = self._convert_args(self.mod[expr.name_hint], args, kwargs)
241+
else:
242+
args = self._convert_args(expr, args, kwargs)
243+
# Reflect python arguments up into Relay.
238244
relay_args = []
239245
for arg in args:
240246
relay_args.append(_arg_to_ast(self.mod, arg))
247+
# Apply func to Relay args
248+
return func(relay_args)
241249

242-
# Set the entry function for the module.
243-
if expr is None:
244-
pass
245-
elif isinstance(expr, GlobalVar):
246-
self.mod["main"] = self.mod[expr]
247-
else:
248-
assert isinstance(expr, Function)
249-
func = Function([], Call(expr, relay_args))
250-
relay_args = []
251-
if self.mod:
252-
self.mod["main"] = func
253-
else:
254-
self.mod = IRModule.from_expr(func)
255-
256-
mod = self.optimize()
257-
opt_expr = Call(mod["main"], relay_args)
258-
return _intrp(opt_expr)
259-
260-
return _interp_wrapper
250+
return _apply_args

python/tvm/relay/frontend/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,10 +545,10 @@ def infer_value(input_val, params, mod=None):
545545
mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
546546
else:
547547
mod = IRModule.from_expr(input_val)
548-
exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm")
549548
inputs = []
550549
for param in mod["main"].params:
551550
inputs.append(params[param.name_hint])
551+
exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm")
552552
result = exc.evaluate()(*inputs)
553553
return result
554554

python/tvm/relay/testing/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,15 @@ def check_grad(
134134
test_inputs = inputs
135135

136136
for target, dev in enabled_targets():
137+
# Eval the backward and forward functions
137138
intrp = relay.create_executor(device=dev, target=target)
139+
bwd_func_compiled, fwd_func_compiled = intrp.evaluate(relay.Tuple([bwd_func, fwd_func]))
138140

139141
# Get analytic gradients.
140-
_, grads = intrp.evaluate(bwd_func)(*inputs)
142+
_, grads = bwd_func_compiled(*inputs)
141143
grads = [grad.numpy().astype("float64") for grad in grads]
142144

145+
143146
# Throw out gradients we aren't testing
144147
if inputs != test_inputs:
145148
tmp = []
@@ -154,7 +157,6 @@ def check_grad(
154157
assert len(grads) > 0, "You must test at least one gradient."
155158

156159
# Get numeric gradients for each dimension of each param, using two-sided approximation.
157-
fwd_func_compiled = intrp.evaluate(fwd_func)
158160
approx_grads = []
159161
for x in test_inputs:
160162
approx_grad = np.zeros(x.shape)

src/ir/module.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,20 +349,24 @@ void IRModuleNode::Update(const IRModule& mod) {
349349

350350
IRModule IRModule::FromExpr(const RelayExpr& expr,
351351
const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
352-
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
353-
auto mod = IRModule(global_funcs, type_definitions);
352+
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
353+
std::unordered_set<String> import_set,
354+
const std::string& name_hint) {
355+
auto mod = IRModule(global_funcs, type_definitions, std::move(import_set));
354356
BaseFunc func;
355-
std::string gv_name = "main";
357+
std::string gv_name = name_hint;
356358

357359
if (auto* func_node = expr.as<BaseFuncNode>()) {
358360
func = GetRef<BaseFunc>(func_node);
359361
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
360362
gv_name = opt.value();
361363
}
362-
363364
} else {
364365
func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
365366
}
367+
if (gv_name.empty()) {
368+
gv_name = "main";
369+
}
366370
auto main_gv = GlobalVar(gv_name);
367371
mod->Add(main_gv, func);
368372
return mod;

src/ir/transform.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ Pass GetPass(const String& pass_name) {
466466
return (*f)();
467467
}
468468

469-
// TODO(zhiics): we currenlty only sequentially execute each pass in
469+
// TODO(zhiics): we currently only sequentially execute each pass in
470470
// a Sequential without the consideration of their orders. The phase
471471
// ordering problem needs to be handled in the future.
472472
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {

0 commit comments

Comments
 (0)