Skip to content

Commit 4029e38

Browse files
YuchenJin authored and yongwww committed
Shape and type deduction (apache#7)
* Shape and type deduction.
* Fix header.
* Add call attrs to the deduce signature.
* Address comments.
* Add DiagnosticContext to IRBuilder and inference signature.
* Fix nits.
1 parent 28b676e commit 4029e38

File tree

11 files changed

+369
-62
lines changed

11 files changed

+369
-62
lines changed

include/tvm/relax/ir_builder.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class IRBuilderNode : public Object {
8282
/*!
8383
* \brief Generate an output for the current dataflow block or function.
8484
* \param output The output variable of the block/function.
85-
* \return The variable being binded to \p ouput.
85+
* \return The variable being binded to \p output.
8686
*/
8787
Var EmitOutput(const Expr& output);
8888
/*!
@@ -107,13 +107,15 @@ class IRBuilderNode : public Object {
107107

108108
private:
109109
/*! \brief The state of the function currently being built. */
110-
RelaxFunction func;
110+
RelaxFunction func_;
111111
/*! \brief A flag tracking if currently inside a dataflow block or not. */
112-
bool is_dataflow = false;
112+
bool is_dataflow_ = false;
113113
/*! \brief A global variable counter for naming global variables. */
114-
int global_var_counter = 0;
114+
int global_var_counter_ = 0;
115115
/*! \brief A dataflow variable counter for naming dataflow variables. */
116-
int dataflow_var_counter = 0;
116+
int dataflow_var_counter_ = 0;
117+
/*! \brief A diagnostic context for reporting errors. */
118+
DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {}));
117119
};
118120

119121
class IRBuilder : public ObjectRef {

include/tvm/relax/op_attr_types.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/op_attr_types.h
22+
* \brief Data structures that can appear in operator attributes.
23+
*/
24+
#ifndef TVM_RELAX_OP_ATTR_TYPES_H_
25+
#define TVM_RELAX_OP_ATTR_TYPES_H_
26+
27+
#include <tvm/relay/expr.h>
28+
#include <tvm/relay/type.h>
29+
#include <tvm/te/schedule.h>
30+
#include <tvm/te/tensor.h>
31+
32+
#include <string>
33+
34+
namespace tvm {
35+
namespace relax {
36+
37+
using relay::Call;
38+
39+
/*!
40+
* \brief Infer the output shape for operators. This function will
41+
* be invoked to fill the \p shape_ field of expressions.
42+
* \param call The call node.
43+
* \param diag_ctx The diagnostic context for reporting errors.
44+
* \return The inferred output shape expression.
45+
*/
46+
using FInferShape = runtime::TypedPackedFunc<Optional<RelayExpr>(const Call& call, DiagnosticContext diag_ctx)>;
47+
48+
/*!
49+
* \brief Infer the output type for operators. This function will
50+
* be invoked to fill the \p checked_type_ field of expressions.
51+
* \param call The call node.
52+
* \param diag_ctx The diagnostic context for reporting errors.
53+
* \return The inferred output type.
54+
*/
55+
using FInferType = runtime::TypedPackedFunc<Type(const Call& call, DiagnosticContext diag_ctx)>;
56+
57+
} // namespace relax
58+
} // namespace tvm
59+
#endif // TVM_RELAX_OP_ATTR_TYPES_H_

python/tvm/ir/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def checked_type(self):
4545
checked_type : tvm.relay.Type
4646
The checked type.
4747
"""
48-
ret = self._checked_type_
48+
ret = self.checked_type_
4949
if ret is None:
5050
raise ValueError("The type checker has not populated" " the checked_type for this node")
5151
return ret

python/tvm/relax/expr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def __init__(
6666
type_annotation: Optional[Type] = None,
6767
span: Span = None,
6868
) -> None:
69+
if shape_annotation is not None:
70+
shape_annotation = make_shape(shape_annotation)
6971
self.__init_handle_by_constructor__(
7072
_ffi_api.Var, name_hint, shape_annotation, type_annotation, span
7173
)
@@ -86,6 +88,8 @@ def __init__(
8688
type_annotation: Optional[Type] = None,
8789
span: Span = None,
8890
) -> None:
91+
if shape_annotation is not None:
92+
shape_annotation = make_shape(shape_annotation)
8993
self.__init_handle_by_constructor__(
9094
_ffi_api.DataflowVar, name_hint, shape_annotation, type_annotation, span
9195
)

python/tvm/relax/op/tensor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
"""Basic tensor operations."""
117
from . import _ffi_api
218
from ..expr import Expr
319

src/relax/ir_builder.cc

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
*/
2323

2424
#include <tvm/relax/ir_builder.h>
25+
#include <tvm/relax/op_attr_types.h>
2526
#include <tvm/relay/op.h>
2627

2728
namespace tvm {
@@ -38,59 +39,84 @@ IRBuilder IRBuilderNode::Create() {
3839

3940
void IRBuilderNode::FillFuncNameParam(const Array<Var>& params, const std::string& func_name) {
4041
if (!func_name.empty()) {
41-
this->func.func_name = GlobalVar(func_name);
42+
this->func_.func_name = GlobalVar(func_name);
4243
}
43-
44-
this->func.params = params;
44+
45+
this->func_.params = params;
4546
}
4647

4748
void IRBuilderNode::BuildFunction() {
48-
SeqExpr seq = SeqExpr(this->func.binding_blocks, this->func.ret);
49-
this->func.func = Function(this->func.func_name, this->func.params, seq, {});
50-
this->global_var_counter = 0;
49+
SeqExpr seq = SeqExpr(this->func_.binding_blocks, this->func_.ret);
50+
this->func_.func = Function(this->func_.func_name, this->func_.params, seq, {});
51+
this->global_var_counter_ = 0;
5152
}
5253

5354
void IRBuilderNode::BuildBlock() {
54-
if (!this->func.bindings.empty()) {
55-
if (is_dataflow) {
56-
this->func.binding_blocks.emplace_back(DataflowBlock(this->func.bindings));
55+
if (!this->func_.bindings.empty()) {
56+
if (is_dataflow_) {
57+
this->func_.binding_blocks.emplace_back(DataflowBlock(this->func_.bindings));
5758
} else {
58-
this->func.binding_blocks.emplace_back(BindingBlock(this->func.bindings));
59+
this->func_.binding_blocks.emplace_back(BindingBlock(this->func_.bindings));
5960
}
60-
this->func.bindings.clear();
61+
this->func_.bindings.clear();
6162
}
62-
this->dataflow_var_counter = 0;
63-
this->is_dataflow = !this->is_dataflow;
63+
this->dataflow_var_counter_ = 0;
64+
this->is_dataflow_ = !this->is_dataflow_;
65+
}
66+
67+
Optional<RelayExpr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
68+
auto op_map = Op::GetAttrMap<FInferShape>("FInferShape");
69+
Op op = Downcast<Op>(call->op);
70+
return op_map[op](call, diag_ctx);
71+
}
72+
73+
Type InferType(const Call& call, DiagnosticContext diag_ctx) {
74+
auto op_map = Op::GetAttrMap<FInferType>("FInferType");
75+
Op op = Downcast<Op>(call->op);
76+
return op_map[op](call, diag_ctx);
6477
}
6578

6679
Var IRBuilderNode::Emit(const Call& call) {
6780
Var var;
68-
if (is_dataflow) {
69-
var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter++)), NullOpt, NullOpt);
81+
if (is_dataflow_) {
82+
var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter_++)), NullOpt, NullOpt);
7083
} else {
71-
var = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt);
84+
var = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt);
85+
}
86+
87+
// Shape inference
88+
auto inferred_shape = InferShape(call, this->diag_ctx_);
89+
if (inferred_shape.defined()) {
90+
if (auto* shape_expr = inferred_shape.value().as<ShapeExprNode>()) {
91+
call->shape_ = GetRef<Expr>(shape_expr);
92+
var->shape_ = call->shape_;
93+
}
7294
}
95+
// Type inference
96+
auto inferred_type = InferType(call, this->diag_ctx_);
97+
call->checked_type_ = inferred_type;
98+
var->checked_type_ = inferred_type;
7399

74-
this->func.bindings.emplace_back(VarBinding(var, call));
100+
this->func_.bindings.emplace_back(VarBinding(var, call));
75101
return var;
76102
}
77103

78104
Var IRBuilderNode::EmitOutput(const Expr& output) {
79105
Var ret;
80-
if (is_dataflow) {
81-
ret = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt);
106+
if (is_dataflow_) {
107+
ret = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt);
82108
ret->shape_ = output->shape_;
83109
ret->checked_type_ = output->checked_type_;
84-
this->func.bindings.emplace_back(VarBinding(ret, output));
110+
this->func_.bindings.emplace_back(VarBinding(ret, output));
85111
} else {
86-
this->func.ret = output;
112+
this->func_.ret = output;
87113
}
88114
return ret;
89115
}
90116

91-
Function IRBuilderNode::Get() { return this->func.func; }
117+
Function IRBuilderNode::Get() { return this->func_.func; }
92118

93-
std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func.binding_blocks; }
119+
std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; }
94120

95121
class FunctionScope::Internal {
96122
public:
@@ -121,20 +147,16 @@ DataflowScope::DataflowScope(IRBuilder ib) {
121147
data_ = std::move(n);
122148
}
123149

124-
void DataflowScope::EnterWithScope() {
125-
this->get()->ir_builder->BuildBlock();
126-
}
150+
void DataflowScope::EnterWithScope() { this->get()->ir_builder->BuildBlock(); }
127151

128-
void DataflowScope::ExitWithScope() {
129-
this->get()->ir_builder->BuildBlock();
130-
}
152+
void DataflowScope::ExitWithScope() { this->get()->ir_builder->BuildBlock(); }
131153

132154
TVM_REGISTER_GLOBAL("relax.IRBuilderCreate").set_body_typed(IRBuilderNode::Create);
133155

134156
TVM_REGISTER_GLOBAL("relax.IRBuilderFillFuncNameParam")
135-
.set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) {
136-
return builder->FillFuncNameParam(params, func_name);
137-
});
157+
.set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) {
158+
return builder->FillFuncNameParam(params, func_name);
159+
});
138160

139161
TVM_REGISTER_GLOBAL("relax.IRBuilderBuildFunction").set_body_typed([](IRBuilder builder) {
140162
return builder->BuildFunction();
@@ -145,9 +167,9 @@ TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder,
145167
});
146168

147169
TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput")
148-
.set_body_typed([](IRBuilder builder, const Expr& output) {
149-
return builder->EmitOutput(output);
150-
});
170+
.set_body_typed([](IRBuilder builder, const Expr& output) {
171+
return builder->EmitOutput(output);
172+
});
151173

152174
TVM_REGISTER_GLOBAL("relax.IRBuilderGet").set_body_typed([](IRBuilder builder) {
153175
return builder->Get();

src/relax/op/op_common.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
#ifndef TVM_RELAX_OP_OP_COMMON_H_
2626
#define TVM_RELAX_OP_OP_COMMON_H_
2727

28+
#include <tvm/relax/op_attr_types.h>
2829
#include <tvm/relay/expr.h>
2930
#include <tvm/relay/op.h>
30-
#include <tvm/relay/op_attr_types.h>
3131

3232
namespace tvm {
3333
namespace relax {
@@ -42,15 +42,17 @@ namespace relax {
4242
*
4343
* \param OpName the name of registry.
4444
*/
45-
#define RELAX_REGISTER_BINARY_OP(OpName) \
45+
#define RELAX_REGISTER_BINARY_BROADCAST_OP(OpName) \
4646
TVM_REGISTER_GLOBAL("relax.op." OpName).set_body_typed([](Expr lhs, Expr rhs) { \
47-
static const Op& op = Op::Get(OpName); \
47+
static const Op& op = Op::Get("relax." OpName); \
4848
return Call(op, {lhs, rhs}, Attrs(), {}); \
4949
}); \
5050
RELAY_REGISTER_OP("relax." OpName) \
5151
.set_num_inputs(2) \
5252
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
53-
.add_argument("rhs", "Tensor", "The right hand side tensor.")
53+
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
54+
.set_attr<FInferShape>("FInferShape", InferShapeBinaryBroadcast) \
55+
.set_attr<FInferType>("FInferType", InferTypeBinaryBroadcast)
5456

5557
} // namespace relax
5658
} // namespace tvm

src/relax/op/tensor/binary.cc

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,18 @@
2222
* \brief binary broadcast operators.
2323
*/
2424

25-
#include <tvm/arith/analyzer.h>
26-
#include <tvm/relax/expr.h>
27-
#include <tvm/relax/type.h>
28-
#include <tvm/tir/op.h>
29-
#include <tvm/topi/broadcast.h>
25+
#include "binary.h"
3026

3127
#include "../op_common.h"
3228

3329
namespace tvm {
3430
namespace relax {
3531

36-
using Expr = tvm::RelayExpr;
37-
using relay::Call;
38-
39-
#define RELAX_BINARY_COMPUTE(FTOPI) \
40-
[](const Attrs& attrs, const Array<te::Tensor>& inputs, \
41-
const Type& out_type) -> Array<te::Tensor> { \
42-
ICHECK_EQ(inputs.size(), 2U); \
43-
return {FTOPI(inputs[0], inputs[1])}; \
44-
}
45-
46-
RELAX_REGISTER_BINARY_OP("add")
32+
RELAX_REGISTER_BINARY_BROADCAST_OP("add")
4733
.describe("Elementwise add with broadcasting")
4834
.set_support_level(1);
4935

50-
RELAX_REGISTER_BINARY_OP("multiply")
36+
RELAX_REGISTER_BINARY_BROADCAST_OP("multiply")
5137
.describe("Elementwise multiply with broadcasting")
5238
.set_support_level(1);
5339

0 commit comments

Comments (0)