Skip to content

Commit 1aa0b5e

Browse files
authored
Merge pull request #7 from MarisaKirisame/type_of
Type of
2 parents 5a1c2f4 + ded0621 commit 1aa0b5e

File tree

12 files changed

+103
-7
lines changed

12 files changed

+103
-7
lines changed

include/tvm/relay/expr.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,37 @@ std::string RelayPrint(
486486
const NodeRef& node,
487487
bool show_meta_data = true,
488488
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
489+
490+
/*!
491+
* \brief User defined type relation, is an input-output relation on types.
492+
*/
493+
class TypeOf;
494+
/*!
495+
* \brief TypeRelation container.
496+
* \note This node is not directly serializable.
497+
* The type function need to be lookedup in the module.
498+
*/
499+
class TypeOfNode : public TypeNode {
500+
public:
501+
/*!
502+
* \brief The function on input and output variables which
503+
* this is not directly serializable,
504+
* need to be looked-up in the module.
505+
*/
506+
relay::Expr expr;
507+
508+
void VisitAttrs(tvm::AttrVisitor* v) final {
509+
v->Visit("expr", &expr);
510+
}
511+
512+
TVM_DLL static TypeOf make(relay::Expr expr);
513+
514+
static constexpr const char* _type_key = "relay.TypeOf";
515+
TVM_DECLARE_NODE_TYPE_INFO(TypeOfNode, TypeNode);
516+
};
517+
518+
RELAY_DEFINE_NODE_REF(TypeOf, TypeOfNode, Type);
519+
489520
} // namespace relay
490521
} // namespace tvm
491522
#endif // TVM_RELAY_EXPR_H_

python/tvm/relay/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
scalar_type = ty.scalar_type
4949
GlobalTypeVar = ty.GlobalTypeVar
5050
TypeCall = ty.TypeCall
51+
TypeOf = ty.TypeOf
5152

5253
# Expr
5354
Expr = expr.Expr

python/tvm/relay/op/_tensor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .op import register_gradient
77
from .op import schedule_injective, OpPattern
88
from .transform import collapse_sum_like
9-
from .tensor import negative
9+
from .tensor import negative, zeros_like, ones_like
1010

1111

1212
def add_grad(orig, grad):
@@ -29,6 +29,12 @@ def multiply_grad(orig, grad):
2929

3030
register_gradient("multiply", multiply_grad)
3131

32+
def take_grad(orig, grad):
33+
return [zeros_like(orig.args[0]), zeros_like(orig.args[1])]
34+
35+
36+
register_gradient("take", take_grad)
37+
3238
schedule_broadcast = schedule_injective
3339
schedule_elemwise = schedule_injective
3440

python/tvm/relay/ty.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ def __init__(self, func, args, num_inputs, attrs):
270270
self.__init_handle_by_constructor__(_make.TypeRelation,
271271
func, args, num_inputs, attrs)
272272

273+
@register_relay_node
274+
class TypeOf(Type):
275+
def __init__(self, expr):
276+
self.__init_handle_by_constructor__(_make.TypeOf, expr)
273277

274278
def scalar_type(dtype):
275279
"""Creates a scalar type.

src/relay/ir/error.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
8787
std::endl <<
8888
rang::style::reset;
8989

90-
// for (auto pair : err_map) {
91-
// std::cout << "Key: " << pair.first << " Value: " << pair.second << std::endl;
92-
// }
90+
for (auto pair : err_map) {
91+
std::cout << "Key: " << pair.first << std::endl << " Value: " << pair.second << std::endl;
92+
}
9393

9494
// We then call into the Relay printer to generate the program.
9595
//

src/relay/ir/expr.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,24 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
283283
*ret = temp->Realize();
284284
});
285285

286+
TypeOf TypeOfNode::make(relay::Expr expr) {
287+
NodePtr<TypeOfNode> n = make_node<TypeOfNode>();
288+
n->expr = std::move(expr);
289+
return TypeOf(n);
290+
}
291+
292+
TVM_REGISTER_NODE_TYPE(TypeOfNode);
293+
294+
TVM_REGISTER_API("relay._make.TypeOf")
295+
.set_body([](TVMArgs args, TVMRetValue* ret) {
296+
*ret = TypeOfNode::make(args[0]);
297+
});
298+
299+
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
300+
.set_dispatch<TypeOfNode>([](const TypeOfNode* node,
301+
tvm::IRPrinter* p) {
302+
p->stream << "TypeOf(" << node->expr << ")";
303+
});
286304

287305
} // namespace relay
288306
} // namespace tvm

src/relay/ir/type.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,5 +207,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
207207
p->stream << "TupleTypeNode(" << node->fields << ")";
208208
});
209209

210+
210211
} // namespace relay
211212
} // namespace tvm

src/relay/ir/type_functor.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
6868
}
6969
}
7070

71+
void TypeVisitor::VisitType_(const TypeOfNode* op) {}
72+
7173
// Type Mutator.
7274
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
7375
// The array will do copy on write
@@ -172,6 +174,10 @@ Type TypeMutator::VisitType_(const TypeDataNode* op) {
172174
return GetRef<Type>(op);
173175
}
174176

177+
Type TypeMutator::VisitType_(const TypeOfNode* op) {
178+
return GetRef<Type>(op);
179+
}
180+
175181
// Implements bind.
176182
class TypeBinder : public TypeMutator {
177183
public:

src/relay/ir/type_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
7272
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
7373
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
7474
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
75+
virtual R VisitType_(const TypeOfNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
7576

7677
virtual R VisitTypeDefault_(const Node* op, Args...) {
7778
LOG(FATAL) << "Do not have a default for " << op->type_key();
@@ -93,6 +94,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
9394
RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
9495
RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
9596
RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
97+
RELAY_TYPE_FUNCTOR_DISPATCH(TypeOfNode);
9698
return vtable;
9799
}
98100
};
@@ -111,6 +113,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
111113
void VisitType_(const GlobalTypeVarNode* op) override;
112114
void VisitType_(const TypeCallNode* op) override;
113115
void VisitType_(const TypeDataNode* op) override;
116+
void VisitType_(const TypeOfNode* op) override;
114117
};
115118

116119
// Mutator that transform a type to another one.
@@ -125,6 +128,7 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
125128
Type VisitType_(const GlobalTypeVarNode* op) override;
126129
Type VisitType_(const TypeCallNode* op) override;
127130
Type VisitType_(const TypeDataNode* op) override;
131+
Type VisitType_(const TypeOfNode* op) override;
128132

129133
private:
130134
Array<Type> MutateArray(Array<Type> arr);

src/relay/pass/kind_check.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
156156
return Kind::kTypeData;
157157
}
158158

159+
Kind VisitType_(const TypeOfNode* op) override {
160+
return kType;
161+
}
162+
159163
Kind Check(const Type& t) {
160164
return this->VisitType(t);
161165
}

0 commit comments

Comments
 (0)