Skip to content

Commit 4ed21d3

Browse files
committed
compute and schedule updated for yolo reorg
1 parent 5a81ae0 commit 4ed21d3

File tree

5 files changed

+62
-14
lines changed

5 files changed

+62
-14
lines changed

include/tvm/relay/attrs/vision.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
5353
.describe("Suppress all detections regardless of class_id.");
5454
TVM_ATTR_FIELD(topk).set_default(-1)
5555
.describe("Keep maximum top k detections before nms, -1 for no limit.");
56+
}
57+
};
58+
5659
/*! \brief Attributes used in yolo reorg operators */
5760
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
58-
IndexExpr stride;
61+
Integer stride;
5962

6063
TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
6164
TVM_ATTR_FIELD(stride)

python/tvm/relay/op/vision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from .multibox import *
66
from .nms import *
77
from .yolo import *
8+
from ._yolo import *
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pylint: disable=invalid-name, unused-argument
2+
"""Backend compiler related feature registration"""
3+
from __future__ import absolute_import
4+
from ..op import register_schedule, register_pattern
5+
from ..op import schedule_injective, OpPattern
6+
7+
# reorg
8+
register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE)
9+
register_schedule("vision.yolo_reorg", schedule_injective)

src/relay/op/vision/yolo.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
#include <tvm/relay/op.h>
77
#include <tvm/relay/attrs/vision.h>
8+
#include <topi/vision/reorg.h>
89
#include <vector>
910
#include "../op_common.h"
1011
#include "../type_relations.h"
@@ -42,7 +43,7 @@ bool YoloReorgRel(const Array<Type>& types,
4243
}
4344

4445
Expr MakeYoloReorg(Expr data,
45-
IndexExpr stride) {
46+
Integer stride) {
4647
auto attrs = make_node<YoloReorgAttrs>();
4748
attrs->stride = stride;
4849
static const Op& op = Op::Get("vision.yolo_reorg");
@@ -63,7 +64,15 @@ Its function is mostly shape transform.")doc" TVM_ADD_FILELINE)
6364
.set_num_inputs(1)
6465
.set_support_level(5)
6566
.set_attrs_type_key("relay.attrs.YoloReorgAttrs")
66-
.add_type_rel("YoloReorg", YoloReorgRel);
67+
.add_type_rel("YoloReorg", YoloReorgRel)
68+
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
69+
const Array<Tensor>& inputs,
70+
const Type& out_type,
71+
const Target& target) {
72+
const auto* params = attrs.as<YoloReorgAttrs>();
73+
CHECK(params != nullptr);
74+
return Array<Tensor>{ topi::vision::reorg(inputs[0], params->stride) };
75+
});
6776

6877
} // namespace relay
6978
} // namespace tvm

tests/python/relay/test_op_level5.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
""" Support level5 operator test cases.
22
"""
3+
import numpy as np
34
import tvm
45
from tvm import relay
6+
from tvm.relay.testing import ctx_list
7+
import topi.testing
58

69
def test_resize_infer_type():
710
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
@@ -70,22 +73,45 @@ def test_nms():
7073
zz = relay.ir_pass.infer_type(z)
7174
assert zz.checked_type == relay.ty.TensorType(
7275
(n, num_anchors, 6), "float32")
73-
def test_yolo_reorg():
76+
77+
78+
def test_yolo_reorg_infer_shape():
79+
def verify_yolo_reorg(shape, stride, out_shape):
80+
x = relay.var("x", relay.TensorType(shape, "float32"))
81+
z = relay.vision.yolo_reorg(x, stride=stride)
82+
zz = relay.ir_pass.infer_type(z)
83+
assert "stride=" in z.astext()
84+
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
85+
7486
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
75-
x = relay.var("x", relay.TensorType((n, c, 20, 20), "float32"))
76-
z = relay.vision.yolo_reorg(x, stride=10)
77-
zz = relay.ir_pass.infer_type(z)
78-
assert "stride=10" in z.astext()
79-
assert zz.checked_type == relay.ty.TensorType((n, c*10*10, 2, 2), "float32")
87+
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
88+
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
8089

81-
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
82-
z = relay.vision.yolo_reorg(x, stride=2)
83-
assert "stride=2" in z.astext()
84-
zz = relay.ir_pass.infer_type(z)
85-
assert zz.checked_type == relay.ty.TensorType((n, c*2*2, h/2, w/2), "float32")
90+
def test_yolo_reorg():
91+
def verify_yolo_reorg(shape, stride):
92+
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
93+
ref_res = topi.testing.reorg_python(x_data, stride)
94+
95+
x = relay.var("x", relay.TensorType(shape, "float32"))
96+
z = relay.vision.yolo_reorg(x, stride=stride)
97+
zz = relay.ir_pass.infer_type(z)
98+
assert "stride=" in z.astext()
99+
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")
100+
101+
func = relay.Function([x], z)
102+
103+
for target, ctx in ctx_list():
104+
for kind in ["graph", "debug"]:
105+
intrp = relay.create_executor(kind, ctx=ctx, target=target)
106+
op_res = intrp.evaluate(func)(x_data)
107+
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
108+
109+
verify_yolo_reorg((1, 100, 20, 20), 10)
110+
verify_yolo_reorg((1, 4, 6, 6), 2)
86111

87112
if __name__ == "__main__":
88113
test_resize_infer_type()
89114
test_multibox_prior()
90115
test_nms()
116+
test_yolo_reorg_infer_shape()
91117
test_yolo_reorg()

0 commit comments

Comments
 (0)