
Commit 0def9c9

slyubomirsky authored and jroesch committed

Add missing tests and modify attributes (#5)

Restores the tests that were lost in the repo port, makes it possible for conv2d to typecheck, and adds some integration tests.

* Add back python tests that were missing from the repo move (history lost, sorry)
* Ensure shape evaluator doesn't trip up on type vars or type ids
* Add tests for shape evaluator when faced with type var or type id
* Use Strings as keys for Attributes, repair tests and uses
* Add preliminary information for conv2d operator
* Add test of attr propagation
* Ensure attributes hash by string value rather than string pointer identity (tests still failing though, idk why)
* Adjust shape equality checks to use ordinary visitor, as nested type id information was not transferring
* Repair integration test by using std::unordered_map for attrs, add cases
* Add regression test for alpha-eq comparison of type IDs across nested shapes (was losing the type id equality map in alpha_eq)
* Add more integration test variants
* Correct import names in tests
* Add clarifying comment to typechecker
* Fix missing paren in operators.py
* Correct python source directory for mypy

1 parent 3d1b67b   commit 0def9c9

32 files changed: +3151 -67 lines changed

relay/Makefile

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ cyclean:
 lint: pylint cpplint
 
 mypy:
-	python3.6 -m mypy --ignore-missing-imports python/tvm/relay tests/python/relay/
+	python3.6 -m mypy --ignore-missing-imports python/relay tests/python/relay/
 
 cpplint:
 	python3.6 dmlc-core/scripts/lint.py relay cpp include src

relay/include/tvm/relay/ir/base.h

Lines changed: 50 additions & 3 deletions
@@ -150,6 +150,41 @@ struct Expr : public NodeRef {
   using ContainerType = ExprNode;
 };
 
+struct StringNode;
+
+/*! \brief an entry that represents output data from a node */
+class String : public NodeRef {
+ public:
+  /*! \brief default constructor, used internally */
+  String() {}
+  explicit String(std::shared_ptr<tvm::Node> n) : NodeRef(n) {}
+  inline const StringNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = StringNode;
+};
+
+struct StringNode : public ExprNode {
+ public:
+  std::string name;
+
+  StringNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("span", &span);
+    v->Visit("name", &name);
+  }
+
+  TVM_DLL static String make(std::string name);
+
+  static constexpr const char* _type_key = "nnvm.String";
+  TVM_DECLARE_NODE_TYPE_INFO(StringNode, ExprNode);
+};
+
+inline const StringNode* String::operator->() const {
+  return static_cast<const StringNode*>(node_.get());
+}
+
 class LocalId;
 
 /*! \brief A LocalId from the node's current type to target type. */
@@ -194,21 +229,33 @@ class GlobalIdNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(GlobalId, GlobalIdNode, Expr);
 
+struct StringHash {
+  size_t operator()(const String &key) const {
+    return std::hash<std::string>()(key->name);
+  }
+};
+
+struct StringEqual {
+  bool operator()(const String &lhs, const String &rhs) const {
+    return lhs->name == rhs->name;
+  }
+};
+
 class Attributes;
 
 /*! \brief A floating point value. */
 class AttributesNode : public Node {
  public:
-  tvm::Map<LocalId, Expr> attributes;
+  std::unordered_map<String, Expr, StringHash, StringEqual> attributes;
 
   AttributesNode() {}
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("span", &span);
-    v->Visit("attributes", &attributes);
   }
 
-  TVM_DLL static Attributes make(tvm::Map<LocalId, Expr> attributes);
+  TVM_DLL static Attributes make(std::unordered_map<String, Expr,
+                                                    StringHash, StringEqual> attributes);
 
   static constexpr const char* _type_key = "nnvm.Attributes";
   TVM_DECLARE_NODE_TYPE_INFO(AttributesNode, Node);
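
The switch from tvm::Map<LocalId, Expr> to a std::unordered_map keyed by String only works because StringHash and StringEqual look at the underlying name, not the NodeRef pointer. A minimal, self-contained sketch of that idea, using a simplified stand-in type (not the actual relay String class):

// Standalone illustration of value-based hashing for reference-like keys.
// "Name" is a hypothetical stand-in for relay's String NodeRef; only the
// hash/equality pattern mirrors the diff above.
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct Name {
  std::shared_ptr<std::string> node;  // distinct allocations may hold equal text
  explicit Name(std::string s) : node(std::make_shared<std::string>(std::move(s))) {}
};

struct NameHash {
  size_t operator()(const Name &key) const { return std::hash<std::string>()(*key.node); }
};

struct NameEqual {
  bool operator()(const Name &lhs, const Name &rhs) const { return *lhs.node == *rhs.node; }
};

int main() {
  std::unordered_map<Name, int, NameHash, NameEqual> attrs;
  attrs[Name("stride_h")] = 1;
  // A second Name("stride_h") is a different allocation, but it hashes and
  // compares by its text, so the lookup finds the existing entry.
  assert(attrs.count(Name("stride_h")) == 1);
  return 0;
}

Hashing by pointer identity instead would make every freshly constructed key miss, which is exactly the attribute-lookup failure the commit message describes.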

relay/include/tvm/relay/ir/expr.h

Lines changed: 0 additions & 35 deletions
@@ -17,41 +17,6 @@ namespace relay {
 
 typedef HalideIR::Type HType;
 
-struct StringNode;
-
-/*! \brief an entry that represents output data from a node */
-class String : public NodeRef {
- public:
-  /*! \brief default constructor, used internally */
-  String() {}
-  explicit String(std::shared_ptr<tvm::Node> n) : NodeRef(n) {}
-  inline const StringNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = StringNode;
-};
-
-struct StringNode : public ExprNode {
- public:
-  std::string name;
-
-  StringNode() {}
-
-  void VisitAttrs(tvm::AttrVisitor* v) final {
-    v->Visit("span", &span);
-    v->Visit("name", &name);
-  }
-
-  TVM_DLL static String make(std::string name);
-
-  static constexpr const char* _type_key = "nnvm.String";
-  TVM_DECLARE_NODE_TYPE_INFO(StringNode, ExprNode);
-};
-
-inline const StringNode* String::operator->() const {
-  return static_cast<const StringNode*>(node_.get());
-}
-
 class FloatLit;
 
 /*! \brief Floating point literal `0.0`, `5e10`. */

relay/include/tvm/relay/ir/type.h

Lines changed: 2 additions & 2 deletions
@@ -271,7 +271,7 @@ class ShapeAttr;
 /*! \brief Shape singleton that captures the value of an attribute */
 class ShapeAttrNode : public TypeNode {
  public:
-  LocalId id;
+  String id;
 
   ShapeAttrNode() {}
 
@@ -280,7 +280,7 @@ class ShapeAttrNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  TVM_DLL static ShapeAttr make(LocalId id);
+  TVM_DLL static ShapeAttr make(String id);
 
   static constexpr const char* _type_key = "nnvm.ShapeAttr";
   TVM_DECLARE_NODE_TYPE_INFO(ShapeAttrNode, TypeNode);

relay/python/relay/ir/base.py

Lines changed: 1 addition & 3 deletions
@@ -88,9 +88,7 @@ def set_span(self, span: Span):
 
 @register_nnvm_node
 class Attributes(NodeBase):
-    def __getitem__(self, index):
-        return self.attributes[index]
-
+    pass
 
 class Value(NodeBase):
     """Base class of all values.

relay/python/relay/ir/expr.py

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,11 @@
 
 @register_nnvm_node
 class String(Expr):
-    value: str
+    name: str
+
+    # need to define hash to use in maps (e.g., for attrs)
+    def __hash__(self):
+        return self.name.__hash__()
 
 
 @register_nnvm_node

relay/python/relay/make.pyi

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ def TypeVar() -> ir.Type: ...
 def PlaceholderType() -> ir.Type: ...
 def ShapeSeq(shapes: List[ir.Type]) -> ir.ShapeSeq: ...
 def ShapeSingleton(value: int) -> ir.ShapeSingleton: ...
-def ShapeAttr(id: ir.LocalId) -> ir.ShapeAttr: ...
+def ShapeAttr(id: ir.String) -> ir.ShapeAttr: ...
 def ShapeProjection(shape: ir.Type, value: int) -> ir.ShapeProjection: ...
 def ShapeBinaryOp(op: ir.ShapeOp, left: ir.Type, right: ir.Type) -> ir.ShapeBinaryOp: ...
 
@@ -46,7 +46,7 @@ def TensorLit(value: List[ir.Expr]) -> ir.TensorLit: ...
 def ProductLit(fields: List[ir.Expr]) -> ir.Expr: ...
 def BoolLit(value: bool) -> ir.BoolLit: ...
 def String(value: str) -> ir.String: ...
-def Attributes(attrs: Dict[ir.LocalId, ir.Expr]) -> ir.Attributes: ...
+def Attributes(attrs: Dict[ir.String, ir.Expr]) -> ir.Attributes: ...
 def Call(func: ir.Expr, args: List[ir.Expr], attrs: ir.Attributes) -> ir.Call: ...
 def UnaryOp(op: ir.UOp, arg: ir.Expr) -> ir.Expr: ...
 def BinaryOp(op: ir.BOp, left: ir.Expr, right: ir.Expr) -> ir.Expr: ...

relay/python/relay/operators.py

Lines changed: 52 additions & 1 deletion
@@ -5,8 +5,9 @@
 import topi
 from relay.env import Environment
 import relay.ir as ir
-from relay.make import Operator, IntrinsicId, TypeId, TensorType, FloatType
+from relay.make import String, Operator, IntrinsicId, TypeId, TensorType, FloatType
 from relay.make import TypeQuantifier, TypeArrow, ProductType
+from relay.make import ShapeAttr, ShapeBinaryOp, ShapeProjection, ShapeSingleton, ShapeSeq
 
 # TODO(@jroesch): Fix up my type
 __operator_registry__: Dict[str, Any] = {}
@@ -128,6 +129,15 @@ def broadcast_mul_compiler(func_ty: ir.Type) -> Any:
     module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="broadcast_mul_compiler")
     return module.get_function("broadcast_mul_compiler")
 
+# TODO(@jroesch): ensure this interfaces correctly
+# note that the type provided doesn't handle padding
+# feel free to assume some default behavior
+def conv2d_compiler(func_ty: ir.Type) -> Any:
+    Inputs, ret_ty = func_ty_to_placeholders(func_ty)
+    Output = topi.nn.conv2d(*Inputs)
+    schedule = tvm.create_schedule(Output.op)
+    module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="conv2d_compiler")
+    return module.get_function("conv2d_compiler")
 
 def initialize_operators(env) -> None:
     """Initialize the default set of operators for the system, this will populate
@@ -175,3 +185,44 @@ def initialize_operators(env) -> None:
     bmul_type = TypeQuantifier(shape, TypeArrow(ProductType([in_out_type, in_out_type]), in_out_type))
     # TODO: reverse mode
     register_op(env, 'broadcast_mul', bmul_type, broadcast_mul_compiler)
+
+    # Conv2d
+    # input: [batch, in_channel, in_height, in_width]
+    # filter: [filter_height, filter_width, in_channel, num_filter]
+    # output shape: [out_height, out_width, num_filter, batch]
+    # out_height = (in_height - filter_h)/stride_h + 1
+    # out_width = (in_width - filter_w)/stride_w + 1
+    stride_h = ShapeAttr(String("stride_h"))
+    stride_w = ShapeAttr(String("stride_w"))
+    btvar = TypeId("bt", Kind.BaseType)
+    input_shape = TypeId("input_shape", Kind.Shape)
+    filter_shape = TypeId("filter_shape", Kind.Shape)
+    output_shape = ShapeSeq([
+        ShapeBinaryOp(ShapeOp.SHPLUS,
+                      ShapeBinaryOp(ShapeOp.SHDIV,
+                                    ShapeBinaryOp(ShapeOp.SHSUB,
+                                                  ShapeProjection(input_shape, 2),
+                                                  ShapeProjection(filter_shape, 0)),
+                                    stride_h),
+                      ShapeSingleton(1)),
+        ShapeBinaryOp(ShapeOp.SHPLUS,
+                      ShapeBinaryOp(ShapeOp.SHDIV,
+                                    ShapeBinaryOp(ShapeOp.SHSUB,
+                                                  ShapeProjection(input_shape, 3),
+                                                  ShapeProjection(filter_shape, 1)),
+                                    stride_w),
+                      ShapeSingleton(1)),
+        ShapeProjection(filter_shape, 3),
+        ShapeProjection(input_shape, 0)
+    ])
+    conv2d_type = TypeQuantifier(
+        btvar,
+        TypeQuantifier(
+            input_shape,
+            TypeQuantifier(
                filter_shape,
+                TypeArrow(ProductType([TensorType(btvar, input_shape), TensorType(btvar, filter_shape)],
+                          TensorType(btvar, output_shape)
+                )))))
+    # TODO: reverse mode
+    register_op(env, 'conv2d', conv2d_type, conv2d_compiler)
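
The out_height/out_width comments above are the usual no-padding convolution arithmetic, which the shape expression built from ShapeBinaryOp encodes symbolically. A quick standalone check of the formula with illustrative values (not taken from the repo's tests):

// Sanity check of out = (in - filter) / stride + 1 for a hypothetical
// 1x3x32x32 input, 5x5x3x8 filter, and stride 1 (illustrative values only).
#include <cstdio>

int main() {
  int in_height = 32, in_width = 32;
  int filter_h = 5, filter_w = 5;
  int stride_h = 1, stride_w = 1;
  int num_filter = 8, batch = 1;

  int out_height = (in_height - filter_h) / stride_h + 1;  // 28
  int out_width  = (in_width  - filter_w) / stride_w + 1;  // 28
  // Output layout per the comment: [out_height, out_width, num_filter, batch]
  std::printf("[%d, %d, %d, %d]\n", out_height, out_width, num_filter, batch);
  return 0;
}

The stride values are not fixed numbers in the type itself; they are looked up through ShapeAttr(String("stride_h")) and ShapeAttr(String("stride_w")), which is why the String-keyed attributes introduced in this commit are needed for conv2d to typecheck.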

relay/src/tvm/relay/alpha_eq.cc

Lines changed: 6 additions & 5 deletions
@@ -399,8 +399,8 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
 
   void VisitType_(const ShapeAttrNode *sn1, const Type &t2) override {
     if (const ShapeAttrNode *sn2 = t2.as<ShapeAttrNode>()) {
-      // require exact quality of identifiers
-      equal = equal && (sn1->id == sn2->id);
+      // check equality of names
+      equal = equal && (sn1->id->name == sn2->id->name);
     } else {
       equal = false;
     }
@@ -418,7 +418,7 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
       auto size = shape1->shapes.size();
       for (size_t i = 0U; i < size; i++) {
         if (!equal) { return; }
-        equal = equal && alpha_eq(shape1->shapes[i], shape2->shapes[i]);
+        this->VisitType(shape1->shapes[i], shape2->shapes[i]);
       }
     } else {
       equal = false;
@@ -433,7 +433,7 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
         return;
       }
       ShapeProjection proj2 = GetRef<ShapeProjection>(spn2);
-      equal = equal && alpha_eq(proj1->shape, proj2->shape);
+      this->VisitType(proj1->shape, proj2->shape);
     } else {
       equal = false;
     }
@@ -447,7 +447,8 @@ struct TypeAlphaEq : TypeVisitor<const Type &> {
        return;
       }
      ShapeBinaryOp op2 = GetRef<ShapeBinaryOp>(sbn2);
-      equal = equal && alpha_eq(op1->left, op2->left) && alpha_eq(op1->right, op2->right);
+      this->VisitType(op1->left, op2->left);
+      this->VisitType(op1->right, op2->right);
     } else {
       equal = false;
     }
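
The switch from the free function alpha_eq to this->VisitType matters because TypeAlphaEq carries the map of type IDs already bound equal by enclosing binders; re-entering alpha_eq builds a fresh comparer and drops that map for nested shapes, which is the regression the commit message mentions. A minimal sketch of the pattern with simplified stand-ins (not the actual relay visitor):

// Simplified stand-in for the visitor-state issue described above: recursing
// through the member function keeps the accumulated equality map, while the
// free-function entry point starts from an empty one.
#include <cassert>
#include <unordered_map>

struct Node { int id; Node* child; };

struct AlphaEq {
  std::unordered_map<int, int> eq_ids;  // ids already bound equal by enclosing binders
  bool equal = true;

  void Visit(const Node& a, const Node& b) {
    auto it = eq_ids.find(a.id);
    equal = equal && it != eq_ids.end() && it->second == b.id;
    if (a.child && b.child) Visit(*a.child, *b.child);  // state preserved across recursion
  }
};

// Free-function entry point: constructs a fresh comparer, losing any bindings.
bool alpha_eq(const Node& a, const Node& b) {
  AlphaEq v;
  v.Visit(a, b);
  return v.equal;
}

int main() {
  Node a2{1, nullptr}, a1{1, &a2};
  Node b2{2, nullptr}, b1{2, &b2};
  AlphaEq v;
  v.eq_ids[1] = 2;            // binder established: id 1 corresponds to id 2
  v.Visit(a1, b1);
  assert(v.equal);            // the nested node still sees the binding
  assert(!alpha_eq(a2, b2));  // a fresh comparer has no bindings, so it fails
  return 0;
}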

relay/src/tvm/relay/ir/expr.cc

Lines changed: 25 additions & 6 deletions
@@ -22,6 +22,11 @@ TVM_REGISTER_API("nnvm.make.String")
     .set_body([](TVMArgs args,
                  TVMRetValue *ret) { *ret = StringNode::make(args[0]); });
 
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+    .set_dispatch<StringNode>([](const StringNode *node, tvm::IRPrinter *p) {
+      p->stream << "String(" << node->name << ")";
+    });
+
 FloatLit FloatLitNode::make(double value) {
   std::shared_ptr<FloatLitNode> n = std::make_shared<FloatLitNode>();
   n->value = std::move(value);
@@ -262,7 +267,8 @@ Call CallNode::make(Expr fn, tvm::Array<Expr> args, Attributes attrs) {
 
 TVM_REGISTER_API("nnvm.make.Call").set_body([](TVMArgs args, TVMRetValue *ret) {
   if (args.size() < 3) {
-    Attributes attrs = AttributesNode::make(tvm::Map<LocalId, Expr>());
+    Attributes attrs = AttributesNode::make(
+        std::unordered_map<String, Expr, StringHash, StringEqual>());
     *ret = CallNode::make(args[0], args[1], attrs);
   } else {
     *ret = CallNode::make(args[0], args[1], args[2]);
@@ -441,18 +447,31 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                 << node->value << ", " << node->body;
     });
 
-static void validate_attributes(tvm::Map<LocalId, Expr> attrs) { return; }
+static void
+validate_attributes(tvm::Map<String, Expr, StringHash, StringEqual> attrs) {
+  return;
+}
 
-Attributes AttributesNode::make(tvm::Map<LocalId, Expr> attrs) {
+Attributes AttributesNode::make(
+    std::unordered_map<String, Expr, StringHash, StringEqual> attrs) {
   std::shared_ptr<AttributesNode> n = std::make_shared<AttributesNode>();
   validate_attributes(attrs);
-  n->attributes = std::move(attrs);
+  n->attributes = attrs;
   return Attributes(n);
 }
 
 TVM_REGISTER_API("nnvm.make.Attributes")
-    .set_body([](TVMArgs args,
-                 TVMRetValue *ret) { *ret = AttributesNode::make(args[0]); });
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      // ensure attrs are moved to appropriate map
+      tvm::Map<String, Expr> map = args[0];
+      std::unordered_map<String, Expr, StringHash, StringEqual> attrs;
+
+      for (auto p : map) {
+        attrs[p.first] = p.second;
+      }
+
+      *ret = AttributesNode::make(attrs);
+    });
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
     .set_dispatch<AttributesNode>([](const AttributesNode *node,
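
With the registration above, the tvm::Map arriving over the FFI is copied into the string-keyed std::unordered_map before AttributesNode::make runs. From C++, an attributes node can also be built directly against the new signature; a rough sketch using only constructors that appear in this diff (the include paths and namespace are assumptions, and FloatLit merely stands in for a real attribute value):

// Sketch only: header paths and namespace are guessed from the files touched
// in this commit; FloatLit is used as a placeholder attribute value.
#include <tvm/relay/ir/base.h>
#include <tvm/relay/ir/expr.h>

using namespace tvm::relay;

Attributes make_conv2d_attrs() {
  std::unordered_map<String, Expr, StringHash, StringEqual> attrs;
  attrs[StringNode::make("stride_h")] = FloatLitNode::make(1.0);
  attrs[StringNode::make("stride_w")] = FloatLitNode::make(1.0);
  return AttributesNode::make(attrs);
}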
