Commit 1f741bd

[checkpoint] clarify lowering does not support flattened result tuple types

[checkpoint] Don't allow FromExprInContext to be ffi-able.
Update tests/python/relay/test_backend_interpreter.py
Co-authored-by: Altan Haan <altanh@cs.washington.edu>
[checkpoint] Fix target bug.
[checkpoint] Explicit hash struct.
[checkpoint] flubbed add
[checkpoint] Shape functions always run on cpu.
- Support target->lowered irmodule map in interpreter so device vs cpu ir functions can be kept separated.
- Cleanup flattening / construction of ADT/NDArray to mirror what we do for tuple types.
- Make sure unit tests compile clean for no signed comparison.
[checkpoint] logging to debug cuda failure
[checkpoint] No signed/unsigned compare. Why is -Werror=sign-compare not part of the cmake setup?
[checkpoint] Include lint.
[checkpoint] Cleanup tuple flattening.
[checkpoint] First batch of Jared's comments, mostly fixup FromExpr.
[checkpoint] Fixup target resolution. Somewhere early on I got confused about how the device context map works. Fixed.
[checkpoint] I will run black before every commit. I will run black before every commit.
[checkpoint] Bug fixes.
[checkpoint] bug fixes
[checkpoint] can't pass my own tests!
[checkpoint] format
[checkpoint] typos
[checkpoint] bit the bullet and brought create_executor & evaluate together as much as possible
[checkpoint] doc lints
[checkpoint] sigh
[checkpoint] more python lints, hoist executor for adt tests to exploit cache
[checkpoint] Python lints
[checkpoint] lint fixes
[checkpoint] cleanup, nuke vlog stuff, collapse create_executor() & evaluate()
[checkpoint] Bug fixes. AOT is broken.
[checkpoint] Separate eval-to-closure and apply-closure phases at last.
[checkpoint] Fix GetType recursion, get debug going.
[checkpoint] Audit python to collapse create_executor and evaluate phases. Just a few places where this doesn't work; they seem harmless.
[checkpoint] Get interpreter working using tec::LowerTE, but no dynamic shapes.
- Hide TECompiler impl inside te_compiler.cc. However I think it is already exposed into Python land so this probably won't be possible now.
- Move 'optimize' pre-transforms from interpreter.py to interpreter.cc so they can be applied uniformly to both mod and expr.
- Don't push the expr into the mod in interpreter.py since it's done again in interpreter.cc. Instead just build the Call node with the reflected args.
- Both the mod and the expr are prepared identically (same transforms, of which LowerTensorExpr should be one).
- LowerTensorExpr can look through let-bound and global vars, e.g.
    let f = fn (..., Primitive=1) { ... } ... f(...)
    ==> @lowered_f = ... @lowered_f(...)
- Lots of DLOGs that need to be removed or reorganized.
[checkpoint] Support shape functions. TODO:
- Unit tests.
- Cleanup logging (VLOG?)
- Don't build all prims on each apply.
[checkpoint] typo
[checkpoint] Don't allow evaling expr independently of preparing module. TODO:
- Make eval(mod, expr) the interface.
- GlobalVars don't line up.
- Rework use of interpreter in fold_constant.cc to make clear it is evaling prim calls which have already been prepared.
- Find a dynamic shape example that works at HEAD.
- Unit tests.
[checkpoint] Interpreting expressions with refs to module defs working. Commit to the interpreter evaling expr w.r.t. mod in a single phase. Thankfully it turns out no existing uses broke that assumption, so we dodged a bullet. Binding of expr into mod to be evaled is a mess and needs to be fixed. Still can't confirm dynamic shapes are working since we don't have an example working at HEAD. Change to partial_eval needs to be tested, but smells ok.
[checkpoint] Dynamic shapes working. The use of TIRCallAttrs is pretty hacky but the shape function calls are at least working.
Next is to tackle the 'build everything just to project one prim fun' problem.
[checkpoint] Cache built prims, make sure we build with minimal deps.
[checkpoint] Cleanup expr-to-module hackery.
1 parent 49756a5 commit 1f741bd

89 files changed (+1994, -1325 lines)


docs/langref/relay_pattern.rst

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ Either match the first pattern or the second pattern.
 Domination
 **********
 
-Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern.
+Match child pattern, find a match for the parent pattern, ensuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that every node between the child and the pattern matches the path pattern.
 
 Function Pattern
 ****************

include/tvm/ir/module.h

Lines changed: 37 additions & 6 deletions
@@ -36,6 +36,7 @@
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 
 namespace tvm {
@@ -253,6 +254,14 @@ class IRModuleNode : public Object {
   /*! \brief Helper function for registering a typedef's constructors */
   void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
 
+  /*!
+   * \brief Returns a version of \p name which is unique amongst all function definitions
+   * in the module.
+   *
+   * \param name The original name.
+   * \return Updated name which is unique.
+   */
+  String GetUniqueName(const String& name);
+
   /*! \brief A map from string names to global variables that
    * ensures global uniqueness.
    */
@@ -307,16 +316,38 @@ class IRModule : public ObjectRef {
   }
 
   /*!
-   * \brief Construct a module from a standalone expression.
+   * \brief Constructs a module from a standalone expression \p expr.
+   *
+   * If \p expr is a function it will be bound directly. Otherwise a function over the free
+   * variables of \p expr (possibly none) with \p expr as body is created and bound.
+   *
+   * The function is bound to, in preference order:
+   *  - The "global_symbol" attribute of \p expr, if it is a function with that attribute.
+   *  - 'main'
+   *  - A unique name derived from 'main' if 'main' is already bound in \p global_funcs.
    *
-   * Allows one to optionally pass a global function map and
-   * map of type definitions as well.
+   * Additional global functions and type definitions may be included in the result module.
+   *
+   * See also \p FromExpr.
    *
    * \param expr The expression to set as the main function to the module.
-   * \param global_funcs The global function map.
-   * \param type_definitions Map of global type definitions
+   * \param global_funcs The global function map. Default empty.
+   * \param type_definitions The global type definition map. Default empty.
+   * \param import_set Set of external modules already imported. Default empty.
+   *
+   * \returns A module with \p expr set as the main function, and the global var to which
+   * \p expr was bound (typically 'main').
    *
-   * \returns A module with expr set as the main function.
+   * TODO(mbs): Do import_set and the bound global var need to be exposed via ffi?
+   */
+  static std::pair<IRModule, GlobalVar> FromExprInContext(
+      const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
+      const Map<GlobalTypeVar, TypeData>& type_definitions = {},
+      std::unordered_set<String> import_set = {});
+
+  /*!
+   * \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no
+   * imports.
    */
   TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
                                    const Map<GlobalVar, BaseFunc>& global_funcs = {},
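The naming rules above lean on the new GetUniqueName helper. As an illustrative sketch only (the suffix scheme is an assumption, not TVM's actual C++ implementation), the uniquing behaviour can be modelled in Python as appending increasing numeric suffixes until the name is free:

```python
def get_unique_name(name, bound_names):
    """Model of IRModuleNode::GetUniqueName: return `name`, or `name` plus
    a numeric suffix, so the result is not already bound in the module.
    The exact suffix scheme here is a hypothetical for illustration."""
    if name not in bound_names:
        return name
    i = 1
    while f"{name}_{i}" in bound_names:
        i += 1
    return f"{name}_{i}"

print(get_unique_name("main", set()))               # -> main
print(get_unique_name("main", {"main"}))            # -> main_1
print(get_unique_name("main", {"main", "main_1"}))  # -> main_2
```

This is the fallback FromExprInContext needs when 'main' is already taken in global_funcs.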

include/tvm/relay/interpreter.h

Lines changed: 48 additions & 22 deletions
@@ -40,31 +40,11 @@
 #include <tvm/runtime/object.h>
 #include <tvm/target/target.h>
 
+#include <unordered_set>
+
 namespace tvm {
 namespace relay {
 
-/*!
- *\brief Create a Interpreter function that can
- * evaluate an expression and produce a value.
- *
- * The resulting value can be passed to Python, making it easy to use
- * for testing and debugging.
- *
- * The interpreter interprets the program fragments not supported by the
- * TVM runtime, although the interpreter is naively implemented it uses
- * TVM operators for evaluating all operators.
- *
- * Our intent is that this will never be the most efficient implementation of
- * Relay's semantics, but a readable and clear one.
- *
- * \param mod The function module.
- * \param device The primary device that the interepreter runs on.
- * \param target Compiler target flag to compile the functions on the context.
- * \return A function that takes in an expression and returns a value.
- */
-runtime::TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device device,
-                                                            Target target);
-
 /*! \brief The container type of Closures used by the interpreter. */
 class InterpreterClosureObj : public runtime::ClosureObj {
  public:
@@ -164,6 +144,52 @@ class ConstructorValue : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
 };
 
+/*!
+ * \brief Returns a packed function over Relay expressions which will evaluate \p expr
+ * applied to those arguments, where \p expr is interpreted w.r.t. the definitions in \p mod.
+ *
+ * This function is intended to support the Python 'debug' executor.
+ *
+ * The given \p expr should have function type. The given \p mod may be empty or
+ * undefined if \p expr is self-contained. Relay arguments passed to the result
+ * packed function must be constants, references, or constructors/tuples over such.
+ * As much work as possible is done while constructing the result packed function, and
+ * that function may be reasonably efficiently applied multiple times without redoing
+ * unnecessary work.
+ *
+ * Primitives are lowered and compiled to packed functions for execution on \p device
+ * with properties given by \p target. All other Relay constructs are interpreted.
+ *
+ * The interpreter is intended to be a 'reference' implementation of the Relay semantics
+ * for testing and interactive use. It is not intended to be particularly efficient.
+ *
+ * \param mod A module containing definitions which can be referenced from
+ * \p expr. May be empty or undefined.
+ * \param expr An expression of function type to evaluate. May reference definitions from \p mod.
+ * \param device The device on which all primitives will be executed.
+ * \param target The compiler target flag for compiling primitives.
+ * \return A packed function that takes an array of Relay expressions and returns the
+ * result of applying \p expr to those arguments.
+ */
+TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device,
+                                                     Target target);
+
+/*!
+ * \brief Evaluates \p expr and returns its result.
+ *
+ * This function is intended to support TVM constant evaluation.
+ *
+ * \param expr An expression to evaluate.
+ * \param type_definitions Global type definitions which \p expr may reference.
+ * \param import_set Already imported external modules.
+ * \param device The device on which all primitives will be executed.
+ * \param target The compiler target flag for compiling primitives.
+ * \return The object representing the result.
+ */
+ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
+               std::unordered_set<String> import_set, Device device, Target target);
+
 }  // namespace relay
 }  // namespace tvm
+
 #endif  // TVM_RELAY_INTERPRETER_H_
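The key contract of EvalFunction is prepare-once / apply-many: the expensive work happens while the packed function is constructed, and each application only runs the prepared result. A toy Python sketch of that shape (the counter stands in for TVM's lowering and compilation; none of these names are TVM's real API):

```python
prepare_count = 0

def eval_function(mod, expr):
    """Toy model of relay::EvalFunction: do all expensive preparation once,
    then return a closure that merely applies the prepared result."""
    global prepare_count
    prepare_count += 1  # stands in for lowering/compiling primitives

    def apply_args(args):
        # Cheap per-application work only; no re-preparation.
        return (expr, tuple(args))

    return apply_args

func = eval_function(mod={}, expr="double")
func([1])
func([2])
func([3])
assert prepare_count == 1  # prepared once, applied three times
```

This is why the Python Interpreter docs below advise hoisting the evaluate call out of any loop over arguments.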

python/tvm/relay/analysis/analysis.py

Lines changed: 1 addition & 2 deletions
@@ -433,8 +433,7 @@ def get_calibration_data(mod, data):
     mod = _ffi_api.get_calibrate_module(mod)
     mod = transform.Inline()(mod)
 
-    ref_ex = build_module.create_executor("graph", mod=mod, device=cpu(0))
-    ref_res = ref_ex.evaluate()(**data)
+    ref_res = build_module.create_executor("graph", mod=mod, device=cpu(0)).evaluate()(**data)
 
     calib_data = {}
     for gvar, indices in output_map.items():

python/tvm/relay/backend/interpreter.py

Lines changed: 35 additions & 46 deletions
@@ -22,10 +22,9 @@
 
 import tvm._ffi
 from tvm.runtime import container, Object
-from tvm.ir import IRModule
 
 from . import _backend
-from .. import _make, analysis, transform
+from .. import _make, analysis
 from ... import nd
 from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
 from ..function import Function
@@ -178,6 +177,7 @@ def evaluate(self, expr=None, binds=None):
             return self._make_executor(expr)
 
         # normal expression evaluated by running a function.
+        # TODO(mbs): This should really be type rather than syntax driven.
         func = Function([], expr)
         return self._make_executor(func)()
@@ -196,65 +196,54 @@ class Interpreter(Executor):
 
     target : tvm.Target
         The target option to build the function.
+
+    CAUTION: Despite the API the module is prepared upon each call to evaluate
+    rather than once in create_executor.
+    That is:
+    .. code-block:: python
+
+        executor = relay.create_executor(kind="debug", mod=module)
+        a = executor.evaluate(expr)(args1)
+        b = executor.evaluate(expr)(args2)
+
+    will prepare all the bindings in module twice. For efficiency, try to hoist
+    calls to evaluate as high as possible, preferably immediately after create_executor:
+    .. code-block:: python
+
+        func = relay.create_executor(kind="debug", mod=module).evaluate(expr)
+        a = func(args1)
+        b = func(args2)
     """
 
     def __init__(self, mod, device, target):
         self.mod = mod
         self.device = device
         self.target = target
 
-    def optimize(self):
-        """Optimize functions in a module.
-
-        Returns
-        -------
-        opt_mod : tvm.IRModule
-            The optimized module.
-        """
-        seq = tvm.transform.Sequential(
-            [
-                # tvm.parser.AnnotateSpans(),
-                transform.SimplifyInference(),
-                transform.FuseOps(0),
-                transform.ToANormalForm(),
-                transform.InferType(),
-            ]
-        )
-        mod = seq(self.mod)
-        return mod
-
     def _make_executor(self, expr=None):
         if expr is None or isinstance(expr, GlobalVar):
             assert self.mod is not None
 
-        _intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target)
+        if expr is None:
+            # A missing expr denotes 'main' in the given module.
+            expr = self.mod.get_global_var("main")
+
+        # Evaluate expr to a packed function we can efficiently re-apply
+        # to Relay arguments.
+        func = _backend.EvalFunction(self.mod, expr, self.device, self.target)
 
-        def _interp_wrapper(*args, **kwargs):
-            if expr is None:
-                args = self._convert_args(self.mod["main"], args, kwargs)
+        def _apply_args(*args, **kwargs):
+            if isinstance(expr, GlobalVar):
+                # When expanding args, look inside the actual global definition so kwargs
+                # can be matched.
+                args = self._convert_args(self.mod[expr.name_hint], args, kwargs)
             else:
                 args = self._convert_args(expr, args, kwargs)
-
+            # Reflect python arguments up into Relay.
             relay_args = []
             for arg in args:
                 relay_args.append(_arg_to_ast(self.mod, arg))
+            # Apply func to Relay args
+            return func(relay_args)
 
-            # Set the entry function for the module.
-            if expr is None:
-                pass
-            elif isinstance(expr, GlobalVar):
-                self.mod["main"] = self.mod[expr]
-            else:
-                assert isinstance(expr, Function)
-                func = Function([], Call(expr, relay_args))
-                relay_args = []
-                if self.mod:
-                    self.mod["main"] = func
-                else:
-                    self.mod = IRModule.from_expr(func)
-
-            mod = self.optimize()
-            opt_expr = Call(mod["main"], relay_args)
-            return _intrp(opt_expr)
-
-        return _interp_wrapper
+        return _apply_args
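The "reflect python arguments up into Relay" step turns each Python-level value into a Relay expression before the prepared function is applied. A toy model of that reflection (hypothetical names and AST encoding; the real code is _arg_to_ast building Relay Constant/Tuple nodes from NDArrays and containers):

```python
def arg_to_ast(arg):
    """Toy model of _arg_to_ast: reflect Python values into AST-like tuples.
    The ("Constant", ...) / ("Tuple", ...) encoding is purely illustrative."""
    if isinstance(arg, (int, float)):
        return ("Constant", arg)
    if isinstance(arg, (list, tuple)):
        return ("Tuple", [arg_to_ast(a) for a in arg])
    raise TypeError(f"cannot reflect {type(arg).__name__} into the AST")

assert arg_to_ast(3) == ("Constant", 3)
assert arg_to_ast([1, 2]) == ("Tuple", [("Constant", 1), ("Constant", 2)])
```

Reflection is what restricts the arguments accepted by EvalFunction to constants, references, and constructors/tuples over such.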

python/tvm/relay/build_module.py

Lines changed: 3 additions & 0 deletions
@@ -511,6 +511,9 @@ def _graph_wrapper(*args, **kwargs):
         return _graph_wrapper
 
 
+# TODO(mbs): Collapse the create_executor/evaluate phases together since a) most callers don't
+# reuse the executor for multiple expressions and b) any preparation necessary for the expression
+# evaluation needs to (currently) be done along with preparation for the module.
 def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None):
     """Factory function to create an executor.

python/tvm/relay/frontend/common.py

Lines changed: 3 additions & 2 deletions
@@ -545,11 +545,12 @@ def infer_value(input_val, params, mod=None):
         mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
     else:
         mod = IRModule.from_expr(input_val)
-    exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm")
     inputs = []
     for param in mod["main"].params:
         inputs.append(params[param.name_hint])
-    result = exc.evaluate()(*inputs)
+    result = tvm.relay.create_executor(
+        "debug", mod=mod, device=tvm.cpu(), target="llvm"
+    ).evaluate()(*inputs)
     return result

python/tvm/relay/testing/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -134,10 +134,13 @@ def check_grad(
         test_inputs = inputs
 
     for target, dev in enabled_targets():
-        intrp = relay.create_executor(device=dev, target=target)
+        # Eval the backward and forward functions
+        # TODO(mbs): Evaluate a pair of functions so can share preparation between them.
+        bwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(bwd_func)
+        fwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(fwd_func)
 
         # Get analytic gradients.
-        _, grads = intrp.evaluate(bwd_func)(*inputs)
+        _, grads = bwd_func_compiled(*inputs)
         grads = [grad.numpy().astype("float64") for grad in grads]
 
         # Throw out gradients we aren't testing
@@ -154,7 +157,6 @@ def check_grad(
         assert len(grads) > 0, "You must test at least one gradient."
 
         # Get numeric gradients for each dimension of each param, using two-sided approximation.
-        fwd_func_compiled = intrp.evaluate(fwd_func)
         approx_grads = []
         for x in test_inputs:
            approx_grad = np.zeros(x.shape)
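The "two-sided approximation" in the comment above is the central difference (f(x + eps) - f(x - eps)) / (2 * eps), applied per input coordinate. A minimal scalar sketch of the idea (check_grad itself works over numpy arrays):

```python
def central_difference(f, x, eps=1e-6):
    """Two-sided (central) numeric derivative of a scalar function f at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

# d/dx of x**2 at x = 3.0 is 6.0; central differencing has O(eps**2) error,
# which is why it is preferred over the one-sided (f(x+eps)-f(x))/eps form.
approx = central_difference(lambda v: v * v, 3.0)
assert abs(approx - 6.0) < 1e-4
```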
