@@ -297,6 +297,45 @@ def test_parallel_alloc():
297
297
298
298
assert isinstance (body .body .body .body .body , tvm .tir .Allocate )
299
299
300
+ ib = tvm .tir .ir_builder .create ()
301
+ n = te .var ("n" )
302
+ with ib .for_range (0 , n , name = "i" , kind = "parallel" ) as i :
303
+ j = ib .allocate ("int32" , 1 , name = "j" , scope = "global" )
304
+ j [0 ] = 0
305
+ with ib .while_loop (j [0 ] < 10 ):
306
+ A = ib .allocate ("float32" , n , name = "A" , scope = "global" )
307
+ A [j [0 ]] = A [j [0 ]] + 2
308
+ j [0 ] += j [0 ] + 1
309
+
310
+ body = ib .get ()
311
+ # parallel (i, 0, n) {
312
+ # // attr [j] storage_scope = "global"
313
+ # allocate j[int32 * 1]
314
+ # j[0] = 0
315
+ # while((j[0] < 10)){
316
+ # // attr [A] storage_scope = "global"
317
+ # allocate A[float32 * n]
318
+ # A[j[0]] = (A[j[0]] + 2f)
319
+ # j[0] = (j[0] + (j[0] + 1))
320
+ # }
321
+ # }
322
+
323
+ mod = tvm .IRModule .from_expr (tvm .tir .PrimFunc ([n ], body ))
324
+ body = tvm .tir .transform .StorageRewrite ()(mod )["main" ].body
325
+
326
+ # parallel (i, 0, n) {
327
+ # // attr [j] storage_scope = "global"
328
+ # allocate j[int32 * 1]
329
+ # // attr [A] storage_scope = "global"
330
+ # allocate A[float32 * n]
331
+ # j[0] = 0
332
+ # while((j[0] < 10)){
333
+ # A[j[0]] = (A[j[0]] + 2f)
334
+ # j[0] = (j[0] + (j[0] + 1))
335
+ # }
336
+ # }
337
+ assert isinstance (body .body .body , tvm .tir .Allocate )
338
+
300
339
301
340
def test_inplace_rule2 (scope_tb = "local_TB2" , max_bits = 1024 * 1024 * 1024 ):
302
341
# Test Buffer
0 commit comments