@@ -69,6 +69,46 @@ def test_bound3():
69
69
assert (bounds [A1 .op .axis [0 ]].extent .value == 32 )
70
70
assert (bounds [A1 .op .axis [1 ]].extent .value == 16 )
71
71
72
+ def test_bound_fusesplit1 ():
73
+ m = tvm .var ('m' )
74
+ l = tvm .var ('l' )
75
+ split = tvm .var ('s' )
76
+ A = tvm .placeholder ((m , l ), name = 'A' )
77
+ A1 = tvm .compute ((m , l ), lambda i , j : A [i , j ], name = 'A1' )
78
+ A2 = tvm .compute ((m , l ), lambda i , j : A1 [i , j ] + 3 , name = 'A2' )
79
+
80
+ s = tvm .create_schedule (A2 .op )
81
+ fused_axes = s [A2 ].fuse (A2 .op .axis [0 ], A2 .op .axis [1 ])
82
+ xo , xi = s [A2 ].split (fused_axes , split )
83
+ s [A1 ].compute_at (s [A2 ], xo )
84
+
85
+ bounds = tvm .schedule .InferBound (s )
86
+ assert isinstance (bounds , tvm .container .Map )
87
+ assert (tvm .ir_pass .Simplify (bounds [A1 .op .axis [0 ]].min - (xo * split ) / l ).value == 0 )
88
+ assert (tvm .ir_pass .Simplify (bounds [A1 .op .axis [0 ]].extent - (((xo + 1 ) * split - 1 ) / l - (xo * split ) / l + 1 )).value == 0 )
89
+ assert (tvm .ir_pass .Simplify (bounds [A1 .op .axis [1 ]].extent - l ).value == 0 )
90
+
91
+ def test_bound_fusesplit2 ():
92
+ m = tvm .var ("m" )
93
+ l = tvm .convert (6 )
94
+ split = tvm .convert (3 )
95
+ A = tvm .placeholder ((m , l ), name = 'A' )
96
+ A1 = tvm .compute ((m , l ), lambda i , j : A [i , j ], name = 'A1' )
97
+ A2 = tvm .compute ((m , l ), lambda i , j : A1 [i , j ] + 3 , name = 'A2' )
98
+
99
+ s = tvm .create_schedule (A2 .op )
100
+ fused_axes = s [A2 ].fuse (A2 .op .axis [0 ], A2 .op .axis [1 ])
101
+ xo , xi = s [A2 ].split (fused_axes , split )
102
+ s [A1 ].compute_at (s [A2 ], xo )
103
+
104
+ bounds = tvm .schedule .InferBound (s )
105
+ assert isinstance (bounds , tvm .container .Map )
106
+ vars = tvm .convert ({xo .var : tvm .const (5 , "int32" )})
107
+ assert (tvm .ir_pass .Simplify (tvm .ir_pass .Substitute (bounds [A1 .op .axis [0 ]].min , vars )).value == 2 )
108
+ assert (tvm .ir_pass .Simplify (tvm .ir_pass .Substitute (bounds [A1 .op .axis [1 ]].min , vars )).value == 3 )
109
+ assert (tvm .ir_pass .Simplify (tvm .ir_pass .Substitute (bounds [A1 .op .axis [0 ]].extent , vars )).value == 1 )
110
+ assert (tvm .ir_pass .Simplify (tvm .ir_pass .Substitute (bounds [A1 .op .axis [1 ]].extent , vars )).value == 3 )
111
+
72
112
73
113
def test_bound_warp ():
74
114
m = tvm .var ('m' )
@@ -320,3 +360,5 @@ def _body():
320
360
test_gemm_bound ()
321
361
test_bound_warp ()
322
362
test_bound_tensor_compute_op ()
363
+ test_bound_fusesplit1 ()
364
+ test_bound_fusesplit2 ()
0 commit comments