@@ -136,8 +136,47 @@ def verify_multibox_prior(x, dshape, ref_res, sizes=(1.0,),
136
136
verify_multibox_prior (x , dshape , ref_res , clip = False , check_type_only = True )
137
137
138
138
139
+ def test_get_valid_counts ():
140
+ def verify_get_valid_counts (dshape , score_threshold ):
141
+ dtype = "float32"
142
+ batch_size , num_anchor , elem_length = dshape
143
+ np_data = np .random .uniform (size = dshape ).astype (dtype )
144
+ np_out1 = np .zeros (shape = (batch_size ,))
145
+ np_out2 = np .zeros (shape = dshape ).astype (dtype )
146
+ for i in range (batch_size ):
147
+ np_out1 [i ] = 0
148
+ inter_idx = 0
149
+ for j in range (num_anchor ):
150
+ score = np_data [i , j , 1 ]
151
+ if score >= score_threshold :
152
+ for k in range (elem_length ):
153
+ np_out2 [i , inter_idx , k ] = np_data [i , j , k ]
154
+ np_out1 [i ] += 1
155
+ inter_idx += 1
156
+ if j >= np_out1 [i ]:
157
+ for k in range (elem_length ):
158
+ np_out2 [i , j , k ] = - 1
159
+
160
+ x = relay .var ("x" , relay .ty .TensorType (dshape , dtype ))
161
+ z = relay .vision .get_valid_counts (x , score_threshold )
162
+ assert "score_threshold" in z .astext ()
163
+ func = relay .Function ([x ], z .astuple ())
164
+ func = relay .ir_pass .infer_type (func )
165
+ ctx_list = [("llvm" , tvm .cpu (0 ))]
166
+ for target , ctx in ctx_list :
167
+ intrp = relay .create_executor ("debug" , ctx = ctx , target = target )
168
+ out = intrp .evaluate (func )(np_data )
169
+ tvm .testing .assert_allclose (out [0 ].asnumpy (), np_out1 , rtol = 1e-3 )
170
+ tvm .testing .assert_allclose (out [1 ].asnumpy (), np_out2 , rtol = 1e-3 )
171
+
172
+ verify_get_valid_counts ((1 , 2500 , 6 ), 0 )
173
+ verify_get_valid_counts ((1 , 2500 , 6 ), - 1 )
174
+ verify_get_valid_counts ((3 , 1000 , 6 ), 0.55 )
175
+ verify_get_valid_counts ((16 , 500 , 6 ), 0.95 )
176
+
177
+
139
178
def test_nms ():
140
- def verify_nms (x0_data , x1_data , dshape , ref_res , valid_count ,
179
+ def verify_nms (x0_data , x1_data , dshape , ref_res ,
141
180
overlap_threshold = 0.5 , force_suppress = False , topk = - 1 ,
142
181
check_type_only = False ):
143
182
x0 = relay .var ("x0" , relay .ty .TensorType (dshape , "float32" ))
@@ -166,26 +205,24 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, valid_count,
166
205
[1 , 0.5 , 100 , 60 , 70 , 110 ]]]).astype ("float32" )
167
206
np_valid_count = np .array ([4 ]).astype ("int32" )
168
207
np_result = np .array ([[[2 , 0.9 , 35 , 61 , 52 , 79 ], [0 , 0.8 , 1 , 20 , 25 , 45 ],
169
- [0 , 0.4 , 4 , 21 , 19 , 40 ], [- 1 , 0.9 , 35 , 61 , 52 , 79 ],
208
+ [- 1 , - 1 , - 1 , - 1 , - 1 , - 1 ], [- 1 , - 1 , - 1 , - 1 , - 1 , - 1 ],
170
209
[- 1 , - 1 , - 1 , - 1 , - 1 , - 1 ]]])
171
210
num_anchors = 5
172
211
173
212
dshape = (tvm .var ("n" ), num_anchors , 6 )
174
- verify_nms (np_data , np_valid_count , dshape , np_result , dshape [ 0 ],
213
+ verify_nms (np_data , np_valid_count , dshape , np_result ,
175
214
force_suppress = True , topk = 2 , check_type_only = True )
176
215
dshape = (1 , num_anchors , 6 )
177
- verify_nms (np_data , np_valid_count , dshape , np_result , dshape [ 0 ],
216
+ verify_nms (np_data , np_valid_count , dshape , np_result ,
178
217
force_suppress = True , topk = 2 , check_type_only = False )
179
218
180
219
np_result = np .array ([[[2 , 0.9 , 35 , 61 , 52 , 79 ], [0 , 0.8 , 1 , 20 , 25 , 45 ],
181
- [1 , 0.7 , 30 , 60 , 50 , 80 ], [- 1 , 0.9 , 35 , 61 , 52 , 79 ],
220
+ [1 , 0.7 , 30 , 60 , 50 , 80 ], [- 1 , - 1 , - 1 , - 1 , - 1 , - 1 ],
182
221
[- 1 , - 1 , - 1 , - 1 , - 1 , - 1 ]]])
183
222
dshape = (tvm .var ("n" ), num_anchors , 6 )
184
- verify_nms (np_data , np_valid_count , dshape , np_result , dshape [0 ],
185
- check_type_only = True )
223
+ verify_nms (np_data , np_valid_count , dshape , np_result , check_type_only = True )
186
224
dshape = (1 , num_anchors , 6 )
187
- verify_nms (np_data , np_valid_count , dshape , np_result , dshape [0 ],
188
- topk = 3 )
225
+ verify_nms (np_data , np_valid_count , dshape , np_result , topk = 3 )
189
226
190
227
191
228
def test_multibox_transform_loc ():
@@ -278,4 +315,5 @@ def test_threshold():
278
315
test_resize ()
279
316
test_multibox_prior ()
280
317
test_multibox_transform_loc ()
318
+ test_get_valid_counts ()
281
319
test_nms ()
0 commit comments