Skip to content

Commit 89d965d

Browse files
committed
[RELAY]Vision ops for yolo
1 parent 4300bbc commit 89d965d

File tree

6 files changed

+231
-0
lines changed

6 files changed

+231
-0
lines changed

docs/langref/relay_op.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ This level enables additional math and transform operators.
116116
:nosignatures:
117117

118118
tvm.relay.image.resize
119+
tvm.relay.vision.yolo_regorg
120+
tvm.relay.vision.yolo_region
121+
tvm.relay.vision.yolov3_yolo
119122

120123

121124
Level 1 Definitions
@@ -192,3 +195,6 @@ Level 4 Definitions
192195
Level 5 Definitions
193196
-------------------
194197
.. autofunction:: tvm.relay.image.resize
198+
autofunction:: tvm.relay.vision.yolo_regorg
199+
autofunction:: tvm.relay.vision.yolo_region
200+
autofunction:: tvm.relay.vision.yolov3_yolo

include/tvm/relay/attrs/vision.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
4040
}
4141
};
4242

43+
/*! \brief Attributes used in yolo reorg operators */
44+
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
45+
IndexExpr stride;
46+
47+
TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
48+
TVM_ATTR_FIELD(stride)
49+
.set_default(1)
50+
.describe("Stride value for yolo reorg");
51+
}
52+
};
53+
4354
} // namespace relay
4455
} // namespace tvm
4556
#endif // TVM_RELAY_ATTRS_VISION_H_

python/tvm/relay/op/vision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from __future__ import absolute_import as _abs
44

55
from .multibox import *
6+
from .yolo import *

python/tvm/relay/op/vision/yolo.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Yolo operations."""
2+
from . import _make
3+
4+
def yolo_reorg(data, stride=1):
5+
"""Yolo reorg operation. This layer reorganize the output based on the stride value.
6+
Its function is mostly shape transform.
7+
8+
Parameters
9+
----------
10+
data : relay.Expr
11+
The input data tensor.
12+
13+
stride : int
14+
The stride value for reorganisation.
15+
16+
Returns
17+
-------
18+
ret : relay.Expr
19+
The computed result.
20+
"""
21+
return _make.yolo_reorg(data, stride)
22+
23+
24+
def yolo_region(data):
25+
"""Yolo region operation used for detection.
26+
27+
Parameters
28+
----------
29+
data : relay.Expr
30+
The input data tensor.
31+
32+
Returns
33+
-------
34+
ret : relay.Expr
35+
The computed result.
36+
"""
37+
return _make.yolo_region(data)
38+
39+
40+
def yolov3_yolo(data):
41+
"""Yolo operation used for detection
42+
43+
Parameters
44+
----------
45+
data : relay.Expr
46+
The input data tensor.
47+
48+
Returns
49+
-------
50+
ret : relay.Expr
51+
The computed result.
52+
"""
53+
return _make.yolov3_yolo(data)

src/relay/op/vision/yolo.cc

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*!
2+
* Copyright (c) 2018 by Contributors
3+
* \file yolo.cc
4+
* \brief Yolo related operators
5+
*/
6+
#include <tvm/relay/op.h>
7+
#include <tvm/relay/attrs/vision.h>
8+
#include <vector>
9+
#include "../op_common.h"
10+
#include "../type_relations.h"
11+
12+
namespace tvm {
13+
namespace relay {
14+
15+
TVM_REGISTER_NODE_TYPE(YoloReorgAttrs);
16+
17+
/*!
18+
* \brief YoloReorgRel Output type and shape relation evaluation function.
19+
* \param num_inputs Number of input types in the args.
20+
* \param attrs The additional attributes of the operator.
21+
* \param reporter The reporter to report solution to.
22+
* \return false if This relation cannot be resolved. true if this relation has been resolved.
23+
*/
24+
bool YoloReorgRel(const Array<Type>& types,
25+
int num_inputs,
26+
const Attrs& attrs,
27+
const TypeReporter& reporter) {
28+
CHECK_EQ(types.size(), 2);
29+
const auto* data = types[0].as<TensorTypeNode>();
30+
if (data == nullptr) return false;
31+
32+
const YoloReorgAttrs* param = attrs.as<YoloReorgAttrs>();
33+
CHECK(param != nullptr);
34+
35+
CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
36+
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
37+
oshape[1] = oshape[1] * param->stride * param->stride;
38+
oshape[2] = oshape[2] / param->stride;
39+
oshape[3] = oshape[3] / param->stride;
40+
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
41+
return true;
42+
}
43+
44+
Expr MakeYoloReorg(Expr data,
45+
IndexExpr stride) {
46+
auto attrs = make_node<YoloReorgAttrs>();
47+
attrs->stride = stride;
48+
static const Op& op = Op::Get("vision.yolo_reorg");
49+
return CallNode::make(op, {data}, Attrs(attrs), {});
50+
}
51+
52+
53+
TVM_REGISTER_API("relay.op.vision._make.yolo_reorg")
54+
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
55+
runtime::detail::unpack_call<Expr, 2>(MakeYoloReorg, args, rv);
56+
});
57+
58+
59+
RELAY_REGISTER_OP("vision.yolo_reorg")
60+
.describe(R"doc("Yolo reorg operation. This layer reorganize the output.
61+
Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
62+
.add_argument("data", "Tensor", "The input tensor.")
63+
.set_num_inputs(1)
64+
.set_support_level(5)
65+
.set_attrs_type_key("relay.attrs.YoloReorgAttrs")
66+
.add_type_rel("YoloReorg", YoloReorgRel);
67+
68+
69+
Expr MakeYoloRegion(Expr data) {
70+
static const Op& op = Op::Get("vision.yolo_region");
71+
return CallNode::make(op, {data}, Attrs(), {});
72+
}
73+
74+
75+
TVM_REGISTER_API("relay.op.vision._make.yolo_region")
76+
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
77+
runtime::detail::unpack_call<Expr, 1>(MakeYoloRegion, args, rv);
78+
});
79+
80+
81+
RELAY_REGISTER_OP("vision.yolo_region")
82+
.describe(R"doc("Yolo region operation used for detection."
83+
)doc" TVM_ADD_FILELINE)
84+
.add_argument("data", "Tensor", "The input tensor.")
85+
.set_num_inputs(1)
86+
.set_support_level(5)
87+
.add_type_rel("Identity", IdentityRel);
88+
89+
90+
Expr MakeYolov3Yolo(Expr data) {
91+
static const Op& op = Op::Get("vision.yolov3_yolo");
92+
return CallNode::make(op, {data}, Attrs(), {});
93+
}
94+
95+
96+
TVM_REGISTER_API("relay.op.vision._make.yolov3_yolo")
97+
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
98+
runtime::detail::unpack_call<Expr, 1>(MakeYolov3Yolo, args, rv);
99+
});
100+
101+
102+
RELAY_REGISTER_OP("vision.yolov3_yolo")
103+
.describe(R"doc("Yolov3 operation used for detection."
104+
)doc" TVM_ADD_FILELINE)
105+
.add_argument("data", "Tensor", "The input tensor.")
106+
.set_num_inputs(1)
107+
.set_support_level(5)
108+
.add_type_rel("Identity", IdentityRel);
109+
110+
} // namespace relay
111+
} // namespace tvm

tests/python/relay/test_op_level5.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,55 @@ def test_multibox_prior():
6060
(1, h * w, 4), "float32")
6161

6262

63+
def test_yolo_reorg():
64+
ib = relay.ir_builder.IRBuilder()
65+
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
66+
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
67+
with ib.function(x) as func:
68+
ib.ret(relay.vision.yolo_reorg(x))
69+
ib.ret(func)
70+
func = relay.ir_pass.infer_type(ib.env, func.to_func())
71+
ftype = func.checked_type
72+
assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")
73+
74+
ib = relay.ir_builder.IRBuilder()
75+
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
76+
77+
with ib.function(x) as func:
78+
ib.ret(relay.vision.yolo_reorg(x, stride=2))
79+
ib.ret(func)
80+
func = relay.ir_pass.infer_type(ib.env, func.to_func())
81+
ftype = func.checked_type
82+
assert ftype.ret_type == relay.ty.TensorType((n, c*2*2, h/2, w/2), "float32")
83+
84+
85+
def test_yolo_region():
86+
ib = relay.ir_builder.IRBuilder()
87+
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
88+
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
89+
with ib.function(x) as func:
90+
ib.ret(relay.vision.yolo_region(x))
91+
ib.ret(func)
92+
func = relay.ir_pass.infer_type(ib.env, func.to_func())
93+
ftype = func.checked_type
94+
assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")
95+
96+
97+
def test_yolov3_yolo():
98+
ib = relay.ir_builder.IRBuilder()
99+
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
100+
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
101+
with ib.function(x) as func:
102+
ib.ret(relay.vision.yolov3_yolo(x))
103+
ib.ret(func)
104+
func = relay.ir_pass.infer_type(ib.env, func.to_func())
105+
ftype = func.checked_type
106+
assert ftype.ret_type == relay.ty.TensorType((n, c, h, w), "float32")
107+
108+
63109
if __name__ == "__main__":
64110
test_resize_infer_type()
65111
test_multibox_prior()
112+
test_yolo_reorg()
113+
test_yolo_region()
114+
test_yolov3_yolo()

0 commit comments

Comments
 (0)