Skip to content

Commit 8a9b978

Browse files
altanhyongwww
authored and committed
fix IRModule parsing by resolving GlobalVars later (apache#41)
* fix IRModule parsing by resolving GlobalVars later
* disable fast path that causes type inference problem for now
* print checked type on vars if present
* document ResolveGlobals
1 parent 3fc8133 commit 8a9b978

File tree

12 files changed

+297
-197
lines changed

12 files changed

+297
-197
lines changed

python/tvm/relax/transform/transform.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,68 +19,82 @@
1919
import tvm.ir
2020
from . import _ffi_api
2121

22+
2223
@tvm._ffi.register_object("relax.FunctionPass")
2324
class FunctionPass(tvm.ir.transform.Pass):
2425
"""A pass that works on each tvm.relax.Function in a module. A function
2526
pass class should be created through `function_pass`.
2627
"""
2728

28-
def FMARewrite() -> tvm.transform.Pass:
29+
30+
def FMARewrite() -> tvm.ir.transform.Pass:
2931
"""Perform fused multiply add rewriting in dataflow blocks.
3032
3133
Returns
3234
-------
33-
ret: tvm.transform.Pass
35+
ret: tvm.ir.transform.Pass
3436
"""
3537
return _ffi_api.FMARewrite()
3638

3739

38-
def ToNonDataflow() -> tvm.transform.Pass:
40+
def ToNonDataflow() -> tvm.ir.transform.Pass:
3941
"""Transform all dataflow structure to non-dataflow version.
4042
4143
Returns
4244
-------
43-
ret: tvm.transform.Pass
45+
ret: tvm.ir.transform.Pass
4446
"""
4547
return _ffi_api.ToNonDataflow()
4648

4749

48-
def CallDPSRewrite() -> tvm.transform.Pass:
50+
def CallDPSRewrite() -> tvm.ir.transform.Pass:
4951
"""Perform explicit tensor allocation for call_dps.
5052
5153
Returns
5254
-------
53-
ret: tvm.transform.Pass
55+
ret: tvm.ir.transform.Pass
5456
"""
5557
return _ffi_api.CallDPSRewrite()
5658

5759

58-
def VMMemoryLower() -> tvm.transform.Pass:
60+
def VMMemoryLower() -> tvm.ir.transform.Pass:
5961
"""Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
6062
6163
Returns
6264
-------
63-
ret: tvm.transform.Pass
65+
ret: tvm.ir.transform.Pass
6466
"""
6567
return _ffi_api.VMMemoryLower()
6668

6769

68-
def VMShapeLower() -> tvm.transform.Pass:
69-
"""Lower the shape expressions in relax to VM shape heap manipulations and generate related
70+
def VMShapeLower() -> tvm.ir.transform.Pass:
71+
"""Lower the shape expressions in relax to VM shape heap manipulations and generate related
7072
TIR functions to do shape calculations.
7173
7274
Returns
7375
-------
74-
ret: tvm.transform.Pass
76+
ret: tvm.ir.transform.Pass
7577
"""
7678
return _ffi_api.VMShapeLower()
7779

7880

79-
def ToANF() -> tvm.transform.Pass:
81+
def ToANF() -> tvm.ir.transform.Pass:
8082
"""Transforming Relax IR to A-normal form.
8183
8284
Returns
8385
-------
84-
ret: tvm.transform.Pass
86+
ret: tvm.ir.transform.Pass
8587
"""
8688
return _ffi_api.ToANF()
89+
90+
91+
def ResolveGlobals() -> tvm.ir.transform.Pass:
    """Resolve GlobalVars by name against the input IRModule.

    Ensures every GlobalVar appearing in the IR refers to the canonical
    GlobalVar owned by the module. An error is reported for any GlobalVar
    that cannot be resolved.

    Returns
    -------
    ret: tvm.ir.transform.Pass
        The registered ResolveGlobals pass.
    """
    return _ffi_api.ResolveGlobals()

python/tvm/script/parser.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from synr import ast, Transformer, to_ast
3030

3131
import tvm
32-
from tvm import IRModule
32+
from tvm import IRModule, relax
3333
from tvm._ffi.base import TVMError
3434
from tvm.ir import GlobalVar
3535
from tvm.ir.function import BaseFunc
@@ -1381,5 +1381,9 @@ def ir_module(input_module: type) -> IRModule:
13811381
func_dict = {
13821382
name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
13831383
}
1384-
return IRModule(func_dict)
1384+
mod = IRModule(func_dict)
1385+
mod = relax.transform.ResolveGlobals()(mod)
1386+
# FIXME(@altanh): where is the source map?
1387+
return mod
1388+
13851389
raise TypeError("Only class definitions are supported.")

python/tvm/script/relax/parser.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -970,9 +970,12 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr:
970970
var_name = expr.id.name
971971
if _is_registered(var_name, op_set=self._registered_ops):
972972
return relay.op.get(var_name)
973-
if var_name not in self.scope:
974-
self.report_error("undefined variable", expr.span)
975-
return self.scope[var_name]
973+
if var_name in self.scope:
974+
return self.scope[var_name]
975+
# NOTE: this is a "hack" to get around Python eagerly parsing class method decorators
976+
# first (meaning we need to resolve them after the functions are parsed). These
977+
# GlobalVars need to be resolved using string equality only.
978+
return relay.GlobalVar(var_name)
976979

977980
elif isinstance(expr, ast.Constant):
978981
# FIXME(@altanh): use internal representation that doesn't have precision limits here

src/printer/relax_script_printer.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -497,13 +497,18 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
497497
}
498498

499499
Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) {
500+
// TODO(@altanh): we should consider moving annotation into binding
500501
Doc doc;
501-
if (var->type_annotation.defined()) {
502+
Type annotation = var->checked_type_;
503+
if (!annotation.defined()) {
504+
annotation = var->type_annotation.value_or(Type());
505+
}
506+
if (annotation.defined()) {
502507
doc << ": ";
503-
if (const relax::DynTensorTypeNode* tty = var->type_annotation.as<relax::DynTensorTypeNode>()) {
508+
if (const relax::DynTensorTypeNode* tty = annotation.as<relax::DynTensorTypeNode>()) {
504509
doc << PrintTensorAnnotation(GetRef<DynTensorType>(tty), var->shape_);
505510
} else {
506-
doc << Print(var->type_annotation);
511+
doc << Print(annotation);
507512
}
508513
}
509514
return doc;

src/relax/ir/block_builder.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,10 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
203203

204204
private:
205205
/*!
206-
* \brief Memoization map for expressions using Id for equality of variables.
207-
*/
206+
* \brief Memoization map for expressions using Id for equality of variables.
207+
*/
208208
class ExprMemo {
209-
public:
209+
public:
210210
Optional<Expr> Get(const Expr& expr) {
211211
if (const VarNode* var = expr.as<VarNode>()) {
212212
auto it = var_memo_.find(var->vid);
@@ -230,7 +230,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor<Expr(const Expr&)> {
230230
}
231231
}
232232

233-
private:
233+
private:
234234
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
235235
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
236236
};
@@ -370,7 +370,9 @@ Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_
370370
Var BlockBuilderNode::Emit(const VarBinding& binding) {
371371
BlockFrame* cur_frame = CurrentFrame();
372372
if (cur_frame->is_dataflow) {
373-
ICHECK(binding->var.as<DataflowVarNode>());
373+
ICHECK(binding->var.as<DataflowVarNode>())
374+
<< "Emit can only be used for local bindings in a dataflow block, use EmitOutput for "
375+
"output bindings instead";
374376
}
375377
cur_frame->bindings.push_back(binding);
376378
binding_table_[binding->var->vid] = binding->value;
@@ -408,9 +410,11 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& p
408410

409411
Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) {
410412
BlockFrame* cur_frame = CurrentFrame();
411-
if (cur_frame->is_dataflow && binding->var.defined()) {
412-
ICHECK(!binding->var.as<DataflowVarNode>())
413-
<< "cannot bind DataflowVar outside dataflow block.";
413+
if (binding->var.defined()) {
414+
ICHECK(!cur_frame->is_dataflow || binding->var.as<DataflowVarNode>())
415+
<< "EmitMatchShape can only be used for local bindings in a dataflow block.";
416+
ICHECK(cur_frame->is_dataflow || !binding->var.as<DataflowVarNode>())
417+
<< "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint();
414418
}
415419
cur_frame->bindings.push_back(binding);
416420
// TODO(@altanh, @yuchen): what value should we bind? Consider
@@ -511,13 +515,9 @@ BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() {
511515
return &block_stack_.top();
512516
}
513517

514-
NameTable* BlockBuilderNode::name_table() {
515-
return name_table_.get();
516-
}
518+
NameTable* BlockBuilderNode::name_table() { return name_table_.get(); }
517519

518-
BlockBuilder BlockBuilder::Create() {
519-
return BlockBuilder(make_object<BlockBuilderNode>());
520-
}
520+
BlockBuilder BlockBuilder::Create() { return BlockBuilder(make_object<BlockBuilderNode>()); }
521521

522522
TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilder::Create);
523523

src/relax/ir/expr_functor.cc

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,20 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) {
354354
Expr new_value = this->VisitExpr(binding->value);
355355
Var new_var = this->VisitVarDef(binding->var);
356356

357-
if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
358-
// no-op if there is no change
359-
builder_->Emit(GetRef<VarBinding>(binding));
360-
return;
361-
}
357+
auto emit = [this](VarBinding b) {
358+
if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as<DataflowVarNode>()) {
359+
this->builder_->EmitOutput(b);
360+
} else {
361+
this->builder_->Emit(b);
362+
}
363+
};
364+
365+
// FIXME(@altanh): try to clean up all the fast paths and ty/shape infer, it's getting unwieldy
366+
// if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
367+
// // no-op if there is no change
368+
// emit(GetRef<VarBinding>(binding));
369+
// return;
370+
// }
362371

363372
{
364373
Var temp = WithShapeAndType(new_var, new_value->shape_, new_value->checked_type_);
@@ -368,11 +377,7 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) {
368377
}
369378
}
370379

371-
if (builder_->CurrentBlockIsDataFlow() && !new_var.as<DataflowVarNode>()) {
372-
builder_->EmitOutput(VarBinding(new_var, new_value));
373-
} else {
374-
builder_->Emit(VarBinding(new_var, new_value));
375-
}
380+
emit(VarBinding(new_var, new_value));
376381
}
377382

378383
void ExprMutator::VisitBinding_(const MatchShapeNode* binding) {
@@ -387,8 +392,8 @@ void ExprMutator::VisitBinding_(const MatchShapeNode* binding) {
387392
if (new_value->checked_type_.defined() && new_value->checked_type_.as<DynTensorTypeNode>()) {
388393
new_shape = new_pattern;
389394
}
390-
Var temp =
391-
WithShapeAndType(this->VisitVarDef(binding->var), new_shape, new_value->checked_type_);
395+
new_var = this->VisitVarDef(binding->var);
396+
Var temp = WithShapeAndType(new_var, new_shape, new_value->checked_type_);
392397
if (!temp.same_as(new_var)) {
393398
new_var = temp;
394399
this->var_remap_[binding->var->vid] = new_var;

src/relax/op/tensor/binary.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) {
8181
auto* t1 = rhs_type.as<DynTensorTypeNode>();
8282
if (!t0 || !t1) {
8383
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
84-
<< "Both lhs and rhs should be DynTensor for broadcasting");
84+
<< "Both lhs and rhs should be DynTensor for broadcasting, but got "
85+
<< lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey());
8586
}
8687

8788
DataType output_dtype;
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file src/relax/transform/resolve_globals.cc
21+
* \brief Resolve GlobalVars using string equality.
22+
*/
23+
#include <tvm/relax/expr_functor.h>
24+
#include <tvm/relax/transform.h>
25+
26+
namespace tvm {
27+
namespace relax {
28+
29+
class GlobalVarResolver : public ExprMutator {
30+
public:
31+
GlobalVarResolver(IRModule mod, DiagnosticContext diag_ctx) : mod_(mod), diag_ctx_(diag_ctx) {}
32+
33+
Expr VisitExpr_(const GlobalVarNode* gvar) {
34+
if (!mod_->ContainGlobalVar(gvar->name_hint)) {
35+
diag_ctx_.Emit(Diagnostic::Error(gvar->span)
36+
<< "undefined variable/global \"" << gvar->name_hint << "\"");
37+
return GetRef<GlobalVar>(gvar);
38+
}
39+
return mod_->GetGlobalVar(gvar->name_hint);
40+
}
41+
42+
private:
43+
/*! \brief the IRModule used for GlobalVar lookup. */
44+
IRModule mod_;
45+
DiagnosticContext diag_ctx_;
46+
};
47+
48+
namespace transform {
49+
50+
Pass ResolveGlobals() {
51+
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
52+
[](Function f, IRModule m, PassContext pc) {
53+
// TODO(@altanh): make sure pc always has diag_ctx?
54+
GlobalVarResolver resolver(m, pc->diag_ctx.value());
55+
return Downcast<Function>(resolver.VisitExpr(f));
56+
};
57+
return CreateFunctionPass(pass_func, 0, "ResolveGlobals", {});
58+
}
59+
60+
TVM_REGISTER_GLOBAL("relax.transform.ResolveGlobals").set_body_typed(ResolveGlobals);
61+
62+
} // namespace transform
63+
64+
} // namespace relax
65+
} // namespace tvm

0 commit comments

Comments
 (0)