Skip to content

Commit 6cb0dca

Browse files
committed
add while loop storage rewrite test
1 parent f442ecc commit 6cb0dca

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

tests/python/unittest/test_tir_transform_storage_rewrite.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,45 @@ def test_parallel_alloc():
297297

298298
assert isinstance(body.body.body.body.body, tvm.tir.Allocate)
299299

300+
ib = tvm.tir.ir_builder.create()
301+
n = te.var("n")
302+
with ib.for_range(0, n, name="i", kind="parallel") as i:
303+
j = ib.allocate("int32", 1, name="j", scope="global")
304+
j[0] = 0
305+
with ib.while_loop(j[0] < 10):
306+
A = ib.allocate("float32", n, name="A", scope="global")
307+
A[j[0]] = A[j[0]] + 2
308+
j[0] += j[0] + 1
309+
310+
body = ib.get()
311+
# parallel (i, 0, n) {
312+
# // attr [j] storage_scope = "global"
313+
# allocate j[int32 * 1]
314+
# j[0] = 0
315+
# while((j[0] < 10)){
316+
# // attr [A] storage_scope = "global"
317+
# allocate A[float32 * n]
318+
# A[j[0]] = (A[j[0]] + 2f)
319+
# j[0] = (j[0] + (j[0] + 1))
320+
# }
321+
# }
322+
323+
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
324+
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
325+
326+
# parallel (i, 0, n) {
327+
# // attr [j] storage_scope = "global"
328+
# allocate j[int32 * 1]
329+
# // attr [A] storage_scope = "global"
330+
# allocate A[float32 * n]
331+
# j[0] = 0
332+
# while((j[0] < 10)){
333+
# A[j[0]] = (A[j[0]] + 2f)
334+
# j[0] = (j[0] + (j[0] + 1))
335+
# }
336+
# }
337+
assert isinstance(body.body.body, tvm.tir.Allocate)
338+
300339

301340
def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024):
302341
# Test Buffer

0 commit comments

Comments
 (0)