@@ -118,9 +118,89 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
118
118
return ib .get ()
119
119
120
120
121
- def get_valid_counts_ir (
122
- data , valid_count , out , out_indices , score_threshold , id_index , score_index
123
- ):
121
def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
    """Low level IR that flags which boxes pass the validity test.

    For batch ``i`` and anchor ``j``, writes ``1`` into ``valid_boxes`` when
    the box score exceeds ``score_threshold`` and, if ``id_index`` is
    non-negative, the class id stored in the box is non-negative; otherwise
    writes ``0``.

    Parameters
    ----------
    data : Buffer
        3-D input buffer; indexed below as (batch, num_anchors, elem_length).
    valid_boxes : Buffer
        Output buffer of int32 validity flags, one per (batch, anchor).
    score_threshold : float or Expr
        Lower bound (exclusive) for a box score to count as valid.
    id_index : int
        Offset of the class id inside one box row; negative disables the
        class-id check.
    score_index : int
        Offset of the score inside one box row.

    Returns
    -------
    stmt : Stmt
        The IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data)
    valid_boxes = ib.buffer_ptr(valid_boxes)

    # Normalize Python scalars into TIR immediates.
    if isinstance(score_threshold, float):
        score_threshold = tvm.tir.FloatImm("float32", score_threshold)
    id_index = tvm.tir.IntImm("int32", id_index)
    score_index = tvm.tir.IntImm("int32", score_index)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    with ib.new_scope():
        # One thread per anchor, one y-block per batch element.
        nthread_tx = max_threads
        nthread_bx = num_anchors // max_threads + 1
        nthread_by = batch_size
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        tid = bx * max_threads + tx

        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            # Flat offset of the first element of box (i, j).
            box_base = (i * num_anchors + j) * elem_length
            score = data[box_base + score_index]
            is_valid = tvm.tir.all(
                score > score_threshold,
                # A negative id_index disables the class-id check entirely.
                tvm.tir.any(id_index < 0, data[box_base + id_index] >= 0),
            )
            with ib.if_scope(is_valid):
                valid_boxes[i * num_anchors + j] = 1
            with ib.else_scope():
                valid_boxes[i * num_anchors + j] = 0
    return ib.get()
165
+
166
+
167
def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
    """Low level IR computing per-batch valid counts and compaction indices.

    For each batch row, sums the ``valid_boxes`` flags into ``valid_count``
    and performs an exclusive running count so that every valid box gets its
    destination slot in ``valid_indices``; invalid boxes get ``-1``.

    Parameters
    ----------
    valid_boxes : Buffer
        2-D buffer of 0/1 validity flags, shape (batch_size, num_anchors).
    valid_count : Buffer
        1-D output buffer of per-batch valid box counts.
    valid_indices : Buffer
        2-D output buffer of compacted destination indices, -1 for invalid.

    Returns
    -------
    stmt : Stmt
        The IR statement.
    """
    batch_size = valid_boxes.shape[0]
    num_anchors = valid_boxes.shape[1]

    ib = tvm.tir.ir_builder.create()

    valid_boxes = ib.buffer_ptr(valid_boxes)
    valid_count = ib.buffer_ptr(valid_count)
    valid_indices = ib.buffer_ptr(valid_indices)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    with ib.new_scope():
        # One thread per batch row; each thread scans its anchors serially.
        nthread_tx = max_threads
        nthread_bx = batch_size // max_threads + 1
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        tid = bx * max_threads + tx
        # TODO(mbrookhart): Parallelize the sum and cumsum here
        running = ib.allocate("int32", (1,), name="current_index", scope="local")
        with ib.if_scope(tid < batch_size):
            running[0] = 0
            valid_count[tid] = 0
            with ib.for_range(0, num_anchors) as anchor:
                flat = tid * num_anchors + anchor
                valid_count[tid] = valid_count[tid] + valid_boxes[flat]
                with ib.if_scope(valid_boxes[flat] == 1):
                    # Exclusive prefix count: slot first, then increment.
                    valid_indices[flat] = running[0]
                    running[0] = running[0] + 1
                with ib.else_scope():
                    valid_indices[flat] = -1
    return ib.get()
201
+
202
+
203
def get_valid_counts_ir(data, valid_indices, out, out_indices):
    """Low level IR to get valid count of bounding boxes
    given a score threshold. Also prepares to move valid boxes to the
    top of input data.

    NOTE(review): the middle of this function (remainder of the original
    docstring and the setup boilerplate) was elided in the diff this was
    reconstructed from; the setup below is inferred from the visible uses of
    ``batch_size`` / ``num_anchors`` / ``elem_length`` — confirm against the
    full file.

    Parameters
    ----------
    data : Buffer
        3-D input buffer; indexed as (batch, num_anchors, elem_length).
    valid_indices : Buffer
        2-D buffer (batch, num_anchors) holding each valid box's destination
        slot within its batch row, -1 for invalid boxes (as produced by
        ``get_valid_indices_ir``).
    out : Buffer
        Output buffer, same shape as ``data``: valid boxes scattered to the
        front of each batch row, all remaining elements set to -1.
    out_indices : Buffer
        2-D output buffer (batch, num_anchors): original anchor index of
        each kept box, -1 elsewhere.

    Returns
    -------
    stmt : Stmt
        The IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data)
    valid_indices = ib.buffer_ptr(valid_indices)
    out = ib.buffer_ptr(out)
    out_indices = ib.buffer_ptr(out_indices)
    one = tvm.tir.const(1, dtype=out.dtype)

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    nthread_by = batch_size
    nthread_bz = elem_length

    # First kernel: initialize every output slot to the "empty" marker (-1)
    # so the scatter below only needs to touch valid boxes.
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        bz = te.thread_axis("blockIdx.z")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        ib.scope_attr(bz, "thread_extent", nthread_bz)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            k = bz
            out[(i * num_anchors + j) * elem_length + k] = -one
            # NOTE: written once per z-block; all writers store the same
            # value, so the redundancy is benign.
            out_indices[i * num_anchors + j] = -1

    # Second kernel: scatter each valid box (and its original anchor index)
    # to its compacted destination slot.
    with ib.new_scope():
        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")
        by = te.thread_axis("blockIdx.y")
        bz = te.thread_axis("blockIdx.z")
        ib.scope_attr(tx, "thread_extent", nthread_tx)
        ib.scope_attr(bx, "thread_extent", nthread_bx)
        ib.scope_attr(by, "thread_extent", nthread_by)
        ib.scope_attr(bz, "thread_extent", nthread_bz)
        tid = bx * max_threads + tx
        with ib.if_scope(tid < num_anchors):
            i = by
            j = tid
            k = bz
            # Consistency fix: use flat indexing into valid_indices like every
            # other buffer access here (the original mixed `valid_indices[i, tid]`
            # tuple indexing with flat offsets; both linearize identically for
            # a (batch, num_anchors) buffer).
            dest = valid_indices[i * num_anchors + j]
            with ib.if_scope(dest >= 0):
                out[(i * num_anchors + dest) * elem_length + k] = data[
                    (i * num_anchors + j) * elem_length + k
                ]
                out_indices[i * num_anchors + dest] = j
    return ib.get()
203
287
204
288
@@ -231,23 +315,51 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
231
315
batch_size = data .shape [0 ]
232
316
num_anchors = data .shape [1 ]
233
317
data_buf = tvm .tir .decl_buffer (data .shape , data .dtype , "data_buf" , data_alignment = 8 )
318
+ valid_boxes_buf = tvm .tir .decl_buffer (
319
+ (batch_size , num_anchors ), "int32" , "valid_boxes_buf" , data_alignment = 8
320
+ )
321
+ valid_boxes = te .extern (
322
+ [(batch_size , num_anchors )],
323
+ [data ],
324
+ lambda ins , outs : get_valid_boxes_ir (
325
+ ins [0 ], outs [0 ], score_threshold , id_index , score_index
326
+ ),
327
+ dtype = ["int32" ],
328
+ in_buffers = [data_buf ],
329
+ out_buffers = [valid_boxes_buf ],
330
+ name = "get_valid_boxes" ,
331
+ tag = "get_valid_boxes_gpu" ,
332
+ )
333
+
334
+ valid_indices_buf = tvm .tir .decl_buffer (
335
+ (batch_size , num_anchors ), "int32" , "valid_indices_buf" , data_alignment = 8
336
+ )
234
337
valid_count_buf = tvm .tir .decl_buffer (
235
338
(batch_size ,), "int32" , "valid_count_buf" , data_alignment = 8
236
339
)
340
+ valid_count , valid_indices = te .extern (
341
+ [(batch_size ,), (batch_size , num_anchors )],
342
+ [valid_boxes ],
343
+ lambda ins , outs : get_valid_indices_ir (ins [0 ], outs [0 ], outs [1 ]),
344
+ dtype = ["int32" ],
345
+ in_buffers = [valid_boxes_buf ],
346
+ out_buffers = [valid_count_buf , valid_indices_buf ],
347
+ name = "get_valid_indices" ,
348
+ tag = "get_valid_indices_gpu" ,
349
+ )
350
+
237
351
out_buf = tvm .tir .decl_buffer (data .shape , data .dtype , "out_buf" , data_alignment = 8 )
238
352
out_indices_buf = tvm .tir .decl_buffer (
239
353
(batch_size , num_anchors ), "int32" , "out_buf" , data_alignment = 8
240
354
)
241
355
242
- valid_count , out , out_indices = te .extern (
243
- [(batch_size ,), data .shape , (batch_size , num_anchors )],
244
- [data ],
245
- lambda ins , outs : get_valid_counts_ir (
246
- ins [0 ], outs [0 ], outs [1 ], outs [2 ], score_threshold , id_index , score_index
247
- ),
356
+ out , out_indices = te .extern (
357
+ [data .shape , (batch_size , num_anchors )],
358
+ [data , valid_indices ],
359
+ lambda ins , outs : get_valid_counts_ir (ins [0 ], ins [1 ], outs [0 ], outs [1 ]),
248
360
dtype = ["int32" , data .dtype ],
249
- in_buffers = [data_buf ],
250
- out_buffers = [valid_count_buf , out_buf , out_indices_buf ],
361
+ in_buffers = [data_buf , valid_indices_buf ],
362
+ out_buffers = [out_buf , out_indices_buf ],
251
363
name = "get_valid_counts" ,
252
364
tag = "get_valid_counts_gpu" ,
253
365
)
0 commit comments