Skip to content

Commit fe06049

Browse files
vinx13Laurawly
authored andcommitted
[RELAY][OP] Faster-RCNN Proposal OP (#2725)
* [RELAY][OP] Proposal * Fix * Fix test
1 parent c8a3a59 commit fe06049

File tree

7 files changed

+277
-3
lines changed

7 files changed

+277
-3
lines changed

include/tvm/relay/attrs/vision.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,44 @@ struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
109109
}
110110
};
111111

112+
/*! \brief Attributes used in proposal operators */
113+
struct ProposalAttrs : public tvm::AttrsNode<ProposalAttrs> {
114+
Array<IndexExpr> scales;
115+
Array<IndexExpr> ratios;
116+
int feature_stride;
117+
double threshold;
118+
int rpn_pre_nms_top_n;
119+
int rpn_post_nms_top_n;
120+
int rpn_min_size;
121+
bool iou_loss;
122+
123+
TVM_DECLARE_ATTRS(ProposalAttrs, "relay.attrs.ProposalAttrs") {
124+
TVM_ATTR_FIELD(scales)
125+
.set_default(Array<IndexExpr>({4.0f, 8.0f, 16.0f, 32.0f}))
126+
.describe("Used to generate anchor windows by enumerating scales");
127+
TVM_ATTR_FIELD(ratios)
128+
.set_default(Array<IndexExpr>({0.5f, 1.0f, 2.0f}))
129+
.describe("Used to generate anchor windows by enumerating ratios");
130+
TVM_ATTR_FIELD(feature_stride)
131+
.set_default(16)
132+
.describe(
133+
"The size of the receptive field each unit in the convolution layer of the rpn,"
134+
"for example the product of all stride's prior to this layer.");
135+
TVM_ATTR_FIELD(threshold)
136+
.set_default(0.7)
137+
.describe(
138+
"IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)");
139+
TVM_ATTR_FIELD(rpn_pre_nms_top_n)
140+
.set_default(6000)
141+
.describe("Number of top scoring boxes to apply NMS. -1 to use all boxes");
142+
TVM_ATTR_FIELD(rpn_post_nms_top_n)
143+
.set_default(300)
144+
.describe("Number of top scoring boxes to keep after applying NMS to RPN proposals");
145+
TVM_ATTR_FIELD(rpn_min_size).set_default(16).describe("Minimum height or width in proposal");
146+
TVM_ATTR_FIELD(iou_loss).set_default(false).describe("Usage of IoU Loss");
147+
}
148+
};
149+
112150
} // namespace relay
113151
} // namespace tvm
114152
#endif // TVM_RELAY_ATTRS_VISION_H_

python/tvm/relay/frontend/mxnet.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,20 @@ def _mx_roi_align(inputs, attrs):
351351
return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs)
352352

353353

354+
def _mx_proposal(inputs, attrs):
355+
new_attrs = {}
356+
new_attrs["scales"] = attrs.get_float_tuple("scales", (4.0, 8.0, 16.0, 32.0))
357+
new_attrs["ratios"] = attrs.get_float_tuple("ratios", (0.5, 1.0, 2.0))
358+
new_attrs["feature_stride"] = attrs.get_int("feature_stride", 16)
359+
new_attrs["threshold"] = attrs.get_float("threshold", 0.7)
360+
new_attrs["rpn_pre_nms_top_n"] = attrs.get_int("rpn_pre_nms_top_n", 6000)
361+
new_attrs["rpn_post_nms_top_n"] = attrs.get_int("rpn_post_nms_top_n", 300)
362+
new_attrs["rpn_min_size"] = attrs.get_int("rpn_min_size", 16)
363+
new_attrs["iou_loss"] = attrs.get_bool("iou_loss", False)
364+
assert not attrs.get_bool("output_score", False), "proposal doesn't support output score"
365+
return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs)
366+
367+
354368
# Note: due to attribute conversion constraint
355369
# ops in the identity set must be attribute free
356370
_identity_list = [
@@ -466,6 +480,8 @@ def _mx_roi_align(inputs, attrs):
466480
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
467481
"_contrib_MultiBoxDetection" : _mx_multibox_detection,
468482
"_contrib_ROIAlign" : _mx_roi_align,
483+
"_contrib_Proposal" : _mx_proposal,
484+
"_contrib_MultiProposal" : _mx_proposal,
469485
# List of missing operators that are present in NNVMv1
470486
# TODO(tvm-tvm): support all operators.
471487
#

python/tvm/relay/op/vision/_rcnn.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# pylint: disable=invalid-name, unused-argument
22
"""Faster R-CNN and Mask R-CNN operations."""
33
import topi
4-
from topi.util import get_const_tuple
4+
from topi.util import get_const_tuple, get_float_tuple, get_const_int
55
from .. import op as reg
66
from ..op import OpPattern
77

@@ -21,3 +21,29 @@ def schedule_roi_align(_, outs, target):
2121
return topi.generic.vision.schedule_roi_align(outs)
2222

2323
reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
24+
25+
@reg.register_compute("vision.proposal")
26+
def compute_proposal(attrs, inputs, _, target):
27+
"""Compute definition of proposal"""
28+
scales = get_float_tuple(attrs.scales)
29+
ratios = get_float_tuple(attrs.ratios)
30+
feature_stride = attrs.feature_stride
31+
threshold = attrs.threshold
32+
rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n
33+
rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
34+
rpn_min_size = attrs.rpn_min_size
35+
iou_loss = bool(get_const_int(attrs.iou_loss))
36+
with target:
37+
return [
38+
topi.vision.rcnn.proposal(inputs[0], inputs[1], inputs[2], scales, ratios,
39+
feature_stride, threshold, rpn_pre_nms_top_n,
40+
rpn_post_nms_top_n, rpn_min_size, iou_loss)
41+
]
42+
43+
@reg.register_schedule("vision.proposal")
44+
def schedule_proposal(_, outs, target):
45+
"""Schedule definition of proposal"""
46+
with target:
47+
return topi.generic.schedule_proposal(outs)
48+
49+
reg.register_pattern("vision.proposal", OpPattern.OPAQUE)

python/tvm/relay/op/vision/rcnn.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,63 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='N
3030
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
3131
"""
3232
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)
33+
34+
35+
def proposal(cls_prob,
36+
bbox_pred,
37+
im_info,
38+
scales,
39+
ratios,
40+
feature_stride,
41+
threshold,
42+
rpn_pre_nms_top_n,
43+
rpn_post_nms_top_n,
44+
rpn_min_size,
45+
iou_loss):
46+
"""Proposal operator.
47+
48+
Parameters
49+
----------
50+
cls_prob : relay.Expr
51+
4-D tensor with shape [batch, 2 * num_anchors, height, width].
52+
53+
bbox_pred : relay.Expr
54+
4-D tensor with shape [batch, 4 * num_anchors, height, width].
55+
56+
im_info : relay.Expr
57+
2-D tensor with shape [batch, 3]. The last dimension should be in format of
58+
[im_height, im_width, im_scale]
59+
60+
scales : list/tuple of float
61+
Scales of anchor windoes.
62+
63+
ratios : list/tuple of float
64+
Ratios of anchor windoes.
65+
66+
feature_stride : int
67+
The size of the receptive field each unit in the convolution layer of the rpn, for example
68+
the product of all stride's prior to this layer.
69+
70+
threshold : float
71+
Non-maximum suppression threshold.
72+
73+
rpn_pre_nms_top_n : int
74+
Number of top scoring boxes to apply NMS. -1 to use all boxes.
75+
76+
rpn_post_nms_top_n : int
77+
Number of top scoring boxes to keep after applying NMS to RPN proposals.
78+
79+
rpn_min_size : int
80+
Minimum height or width in proposal.
81+
82+
iou_loss : bool
83+
Usage of IoU loss.
84+
85+
Returns
86+
-------
87+
output : relay.Expr
88+
2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
89+
[batch_index, w_start, h_start, w_end, h_end].
90+
"""
91+
return _make.proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
92+
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss)

src/relay/op/vision/rcnn_op.cc

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,72 @@ RELAY_REGISTER_OP("vision.roi_align")
6363
.set_support_level(5)
6464
.add_type_rel("ROIAlign", ROIAlignRel);
6565

66+
TVM_REGISTER_NODE_TYPE(ProposalAttrs);
67+
68+
bool ProposalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
69+
const TypeReporter& reporter) {
70+
auto proposal_attrs = attrs.as<ProposalAttrs>();
71+
CHECK_EQ(types.size(), 4);
72+
const auto* cls_prob = types[0].as<TensorTypeNode>();
73+
const auto* bbox_pred = types[1].as<TensorTypeNode>();
74+
const auto* im_info = types[2].as<TensorTypeNode>();
75+
76+
if (!cls_prob || !bbox_pred || !im_info) {
77+
return false;
78+
}
79+
80+
CHECK_EQ(cls_prob->shape.size(), 4U)
81+
<< "The dimension of class probability should be 4, but received " << cls_prob->shape.size();
82+
CHECK_EQ(bbox_pred->shape.size(), 4U)
83+
<< "The dimension of box prediction should be 4, but received " << bbox_pred->shape.size();
84+
CHECK_EQ(im_info->shape.size(), 2U)
85+
<< "The dimension of image info should be 2, but received " << im_info->shape.size();
86+
CHECK(reporter->AssertEQ(im_info->shape[1], 3));
87+
88+
auto batch = cls_prob->shape[0];
89+
90+
std::vector<IndexExpr> oshape(
91+
{batch * proposal_attrs->rpn_post_nms_top_n, 5});
92+
reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype));
93+
return true;
94+
}
95+
96+
Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr> scales,
97+
Array<IndexExpr> ratios, int feature_stride, double threshold,
98+
int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size,
99+
bool iou_loss) {
100+
auto attrs = make_node<ProposalAttrs>();
101+
attrs->scales = scales;
102+
attrs->ratios = ratios;
103+
attrs->feature_stride = feature_stride;
104+
attrs->threshold = threshold;
105+
attrs->rpn_pre_nms_top_n = rpn_pre_nms_top_n;
106+
attrs->rpn_post_nms_top_n = rpn_post_nms_top_n;
107+
attrs->rpn_min_size = rpn_min_size;
108+
attrs->iou_loss = iou_loss;
109+
static const Op& op = Op::Get("vision.proposal");
110+
return CallNode::make(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {});
111+
}
112+
113+
TVM_REGISTER_API("relay.op.vision._make.proposal")
114+
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
115+
runtime::detail::unpack_call<Expr, 11>(MakeProposal, args, rv);
116+
});
117+
118+
RELAY_REGISTER_OP("vision.proposal")
119+
.describe(R"code(Generate region proposals via RPN.
120+
121+
- **cls_prob**: 4-D with shape [batch, 2 * num_anchors, height, width].
122+
- **bbox_pred**: 4-D with shape [batch, 4 * num_anchors, height, width].
123+
- **im_info**: 2-D with shape [batch, 3].
124+
- **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5].
125+
)code" TVM_ADD_FILELINE)
126+
.set_num_inputs(3)
127+
.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object")
128+
.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals")
129+
.add_argument("im_info", "Tensor", "Image size and scale")
130+
.set_support_level(5)
131+
.add_type_rel("Proposal", ProposalRel);
132+
66133
} // namespace relay
67134
} // namespace tvm

tests/python/relay/test_op_level5.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,72 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_
306306
verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)
307307

308308

309+
def test_proposal():
310+
def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
311+
cls_prob = relay.var("cls_prob", relay.ty.TensorType(np_cls_prob.shape, "float32"))
312+
bbox_pred = relay.var("bbox_pred", relay.ty.TensorType(np_bbox_pred.shape, "float32"))
313+
im_info = relay.var("im_info", relay.ty.TensorType(np_im_info.shape, "float32"))
314+
z = relay.vision.proposal(cls_prob, bbox_pred, im_info, **attrs)
315+
zz = relay.ir_pass.infer_type(z)
316+
317+
assert zz.checked_type == relay.ty.TensorType(np_out.shape, "float32")
318+
319+
func = relay.Function([cls_prob, bbox_pred, im_info], z)
320+
func = relay.ir_pass.infer_type(func)
321+
for target in ['cuda']:
322+
if not tvm.module.enabled(target):
323+
print("Skip test because %s is not enabled." % target)
324+
continue
325+
ctx = tvm.context(target, 0)
326+
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
327+
op_res1 = intrp1.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info)
328+
tvm.testing.assert_allclose(op_res1.asnumpy(), np_out, rtol=1e-4)
329+
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
330+
op_res2 = intrp2.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info)
331+
tvm.testing.assert_allclose(op_res2.asnumpy(), np_out, rtol=1e-4)
332+
333+
attrs = {
334+
'scales': (0.5,),
335+
'ratios': (0.5,),
336+
'feature_stride': 16,
337+
'iou_loss': False,
338+
'rpn_min_size': 16,
339+
'threshold': 0.7,
340+
'rpn_pre_nms_top_n': 200,
341+
'rpn_post_nms_top_n': 4,
342+
}
343+
344+
np_cls_prob = np.array([[
345+
[[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
346+
[[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]]
347+
]], dtype='float32')
348+
np_bbox_pred = np.array([[
349+
[[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]],
350+
[[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]],
351+
[[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
352+
[[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
353+
]], dtype='float32')
354+
np_im_info = np.array([[48., 48., 1.]], dtype='float32')
355+
np_out = np.array([
356+
[0., 0., 2.8451548,28.38012, 18.154846],
357+
[0., 0., 15.354933, 41.96971, 41.245064],
358+
[0., 18.019852, 1.0538368, 51.98015, 25.946163],
359+
[0., 27.320923, -1.266357, 55., 24.666357]
360+
], dtype='float32')
361+
362+
363+
verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
364+
365+
np_out = np.array([
366+
[ 0., -5.25, -2.5, 21.75, 19.],
367+
[ 0., 11.25, -2., 37.25, 18.5],
368+
[ 0., 26.849998, -2.3000002, 53.45, 18.6],
369+
[ 0., -4.95, 13.799999, 22.25, 35.5]
370+
], dtype='float32')
371+
attrs['iou_loss'] = True
372+
verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs)
373+
374+
309375
def test_yolo_reorg_infer_shape():
310376
def verify_yolo_reorg(shape, stride, out_shape):
311377
x = relay.var("x", relay.TensorType(shape, "float32"))
@@ -347,5 +413,6 @@ def verify_yolo_reorg(shape, stride):
347413
test_multibox_transform_loc()
348414
test_nms()
349415
test_roi_align()
416+
test_proposal()
350417
test_yolo_reorg_infer_shape()
351418
test_yolo_reorg()

topi/tests/python/test_topi_vision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def test_roi_align():
210210
def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
211211
cls_prob = tvm.placeholder(np_cls_prob.shape)
212212
bbox_pred = tvm.placeholder(np_bbox_pred.shape)
213-
im_info = tvm.placeholder(np_im_info.shape, dtype='int32')
213+
im_info = tvm.placeholder(np_im_info.shape)
214214

215215
def check_device(device):
216216
ctx = tvm.context(device, 0)
@@ -252,7 +252,7 @@ def test_proposal():
252252
[[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
253253
[[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
254254
]], dtype='float32')
255-
np_im_info = np.array([[48, 48, 1]], dtype='int32')
255+
np_im_info = np.array([[48., 48., 1.]], dtype='float32')
256256
np_out = np.array([
257257
[0., 0., 2.8451548,28.38012, 18.154846],
258258
[0., 0., 15.354933, 41.96971, 41.245064],

0 commit comments

Comments
 (0)