Skip to content

Commit b5f46c4

Browse files
siju-samueltqchen
authored andcommitted
yolo reorg op for relay (#1941)
1 parent 0edb332 commit b5f46c4

File tree

7 files changed

+172
-1
lines changed

7 files changed

+172
-1
lines changed

docs/langref/relay_op.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ This level enables additional math and transform operators.
135135
tvm.relay.vision.multibox_prior
136136
tvm.relay.vision.multibox_transform_loc
137137
tvm.relay.vision.nms
138+
tvm.relay.vision.yolo_reorg
138139

139140

140141
**Level 10: Temporary Operators**
@@ -251,6 +252,7 @@ Level 5 Definitions
251252
.. autofunction:: tvm.relay.vision.multibox_prior
252253
.. autofunction:: tvm.relay.vision.multibox_transform_loc
253254
.. autofunction:: tvm.relay.vision.nms
255+
.. autofunction:: tvm.relay.vision.yolo_reorg
254256

255257

256258
Level 10 Definitions

include/tvm/relay/attrs/vision.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
9898
}
9999
};
100100

101+
/*! \brief Attributes used in yolo reorg operators */
102+
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
103+
Integer stride;
104+
105+
TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
106+
TVM_ATTR_FIELD(stride)
107+
.set_default(1)
108+
.describe("Stride value for yolo reorg");
109+
}
110+
};
111+
101112
} // namespace relay
102113
} // namespace tvm
103114
#endif // TVM_RELAY_ATTRS_VISION_H_

python/tvm/relay/op/vision/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,7 @@
55
from .multibox import *
66
from .nms import *
77
from .rcnn import *
8+
from .yolo import *
89
from . import _multibox
910
from . import _rcnn
11+
from . import _yolo
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pylint: disable=invalid-name, unused-argument
2+
"""Backend compiler related feature registration"""
3+
from __future__ import absolute_import
4+
from ..op import register_schedule, register_pattern
5+
from ..op import schedule_injective, OpPattern
6+
7+
# reorg
8+
register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE)
9+
register_schedule("vision.yolo_reorg", schedule_injective)

python/tvm/relay/op/vision/yolo.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Yolo operations."""
2+
from . import _make
3+
4+
def yolo_reorg(data, stride):
5+
"""Yolo reorg operation used in darknet models.
6+
This layer shuffles the input tensor values based on the stride value.
7+
Along with the shuffling, it does the shape transform.
8+
If '(n, c, h, w)' is the data shape and 's' is stride, output shape is '(n, c*s*s, h/s, w/s)'
9+
Example: data(1, 4, 2, 2) = [[[[ 0 1] [ 2 3]]
10+
[[ 4 5] [ 6 7]]
11+
[[ 8 9] [10 11]]
12+
[[12 13] [14 15]]]]
13+
stride = 2
14+
ret(1, 16, 1, 1) = [[[[ 0]] [[ 2]] [[ 8]] [[10]]
15+
[[ 1]] [[ 3]] [[ 9]] [[11]]
16+
[[ 4]] [[ 6]] [[12]] [[14]]
17+
[[ 5]] [[ 7]] [[13]] [[15]]]]
18+
19+
Note: stride=1 has no significance for reorg operation.
20+
21+
Parameters
22+
----------
23+
data : relay.Expr
24+
The input data tensor.
25+
26+
stride : int
27+
The stride value for reorganisation.
28+
29+
Returns
30+
-------
31+
ret : relay.Expr
32+
The computed result.
33+
"""
34+
return _make.yolo_reorg(data, stride)

src/relay/op/vision/yolo.cc

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*!
2+
* Copyright (c) 2018 by Contributors
3+
* \file yolo.cc
4+
* \brief Yolo related operators
5+
*/
6+
#include <tvm/relay/op.h>
7+
#include <tvm/relay/attrs/vision.h>
8+
#include <topi/vision/reorg.h>
9+
#include <vector>
10+
#include "../op_common.h"
11+
#include "../type_relations.h"
12+
13+
namespace tvm {
14+
namespace relay {
15+
16+
TVM_REGISTER_NODE_TYPE(YoloReorgAttrs);
17+
18+
/*!
19+
* \brief YoloReorgRel Output type and shape relation evaluation function.
20+
* \param num_inputs Number of input types in the args.
21+
* \param attrs The additional attributes of the operator.
22+
* \param reporter The reporter to report solution to.
23+
* \return false if This relation cannot be resolved. true if this relation has been resolved.
24+
*/
25+
bool YoloReorgRel(const Array<Type>& types,
26+
int num_inputs,
27+
const Attrs& attrs,
28+
const TypeReporter& reporter) {
29+
CHECK_EQ(types.size(), 2);
30+
const auto* data = types[0].as<TensorTypeNode>();
31+
if (data == nullptr) return false;
32+
33+
const YoloReorgAttrs* param = attrs.as<YoloReorgAttrs>();
34+
CHECK(param != nullptr);
35+
36+
CHECK(data->shape.size() == 4) << "Yolo reorg supports only 4 dimension.";
37+
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
38+
oshape[1] = oshape[1] * param->stride * param->stride;
39+
oshape[2] = oshape[2] / param->stride;
40+
oshape[3] = oshape[3] / param->stride;
41+
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
42+
return true;
43+
}
44+
45+
Expr MakeYoloReorg(Expr data,
46+
Integer stride) {
47+
auto attrs = make_node<YoloReorgAttrs>();
48+
attrs->stride = stride;
49+
static const Op& op = Op::Get("vision.yolo_reorg");
50+
return CallNode::make(op, {data}, Attrs(attrs), {});
51+
}
52+
53+
54+
TVM_REGISTER_API("relay.op.vision._make.yolo_reorg")
55+
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
56+
runtime::detail::unpack_call<Expr, 2>(MakeYoloReorg, args, rv);
57+
});
58+
59+
60+
RELAY_REGISTER_OP("vision.yolo_reorg")
61+
.describe(R"doc("Yolo reorg operation. This layer reorganize the output.
62+
Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
63+
.add_argument("data", "Tensor", "The input tensor.")
64+
.set_num_inputs(1)
65+
.set_support_level(5)
66+
.set_attrs_type_key("relay.attrs.YoloReorgAttrs")
67+
.add_type_rel("YoloReorg", YoloReorgRel)
68+
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
69+
const Array<Tensor>& inputs,
70+
const Type& out_type,
71+
const Target& target) {
72+
const auto* params = attrs.as<YoloReorgAttrs>();
73+
CHECK(params != nullptr);
74+
return Array<Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
75+
});
76+
77+
} // namespace relay
78+
} // namespace tvm

tests/python/relay/test_op_level5.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from tvm.relay.testing import ctx_list
88
import topi.testing
99

10-
1110
def test_resize_infer_type():
1211
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
1312
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
@@ -307,10 +306,46 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_
307306
verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)
308307

309308

309+
def test_yolo_reorg_infer_shape():
310+
def verify_yolo_reorg(shape, stride, out_shape):
311+
x = relay.var("x", relay.TensorType(shape, "float32"))
312+
z = relay.vision.yolo_reorg(x, stride=stride)
313+
zz = relay.ir_pass.infer_type(z)
314+
assert "stride=" in z.astext()
315+
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
316+
317+
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
318+
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
319+
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
320+
321+
def test_yolo_reorg():
322+
def verify_yolo_reorg(shape, stride):
323+
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
324+
ref_res = topi.testing.reorg_python(x_data, stride)
325+
326+
x = relay.var("x", relay.TensorType(shape, "float32"))
327+
z = relay.vision.yolo_reorg(x, stride=stride)
328+
zz = relay.ir_pass.infer_type(z)
329+
assert "stride=" in z.astext()
330+
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")
331+
332+
func = relay.Function([x], z)
333+
334+
for target, ctx in ctx_list():
335+
for kind in ["graph", "debug"]:
336+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
337+
op_res = intrp.evaluate(func)(x_data)
338+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
339+
340+
verify_yolo_reorg((1, 100, 20, 20), 10)
341+
verify_yolo_reorg((1, 4, 6, 6), 2)
342+
310343
if __name__ == "__main__":
311344
test_resize_infer_type()
312345
test_resize()
313346
test_multibox_prior()
314347
test_multibox_transform_loc()
315348
test_nms()
316349
test_roi_align()
350+
test_yolo_reorg_infer_shape()
351+
test_yolo_reorg()

0 commit comments

Comments
 (0)