Commit aa4e8f0

yzh119 authored and MasterJH5574 committed
ELL and BSR correctness test scripts (#19)
1 parent 97e36b3 commit aa4e8f0

File tree

1 file changed: +152 −13 lines changed


tests/python/sparsetir/test_tir_sparse_correctness.py

Lines changed: 152 additions & 13 deletions
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+from tvm.runtime.ndarray import device
 import tvm.tir as tir
 import scipy.sparse as sp
 import numpy as np
@@ -36,49 +37,90 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha


 @T.prim_func
-def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, nnz: T.int32) -> None:
+def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
-    A_data = T.match_buffer(a, (nnz,), "float32")
+    A_data = T.match_buffer(a, (NNZ,), "float32")
     B = T.match_buffer(b, (N * K,), "float32")
     C = T.match_buffer(c, (M * K,), "float32")
     A_indptr = T.match_buffer(indptr, (M + 1,), "int32")
-    A_indices = T.match_buffer(indices, (nnz,), "int32")
+    A_indices = T.match_buffer(indices, (NNZ,), "int32")
     for i, k in T.grid(M, K):
         with T.block("spmm_outer"):
             vi, vk = T.axis.remap("SS", [i, k])
             with T.init():
                 C[vi * K + vk] = 0.
             for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
                 with T.block("spmm_inner"):
-                    vj = T.axis.R(N, j + A_indptr[vi])
-                    C[vi * K + vk] = C[vi * K + vk] + A_data[vj] * B[A_indices[vj] * K + vk]
+                    vj = T.axis.R(NNZ, j + A_indptr[vi])
+                    C[vi * K + vk] = C[vi * K + vk] + \
+                        A_data[vj] * B[A_indices[vj] * K + vk]
+
+
+@T.prim_func
+def bsrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, MB: T.int32, NB: T.int32, K: T.int32, BLOCK_SIZE: T.int32, NNZB: T.int32) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    A_data = T.match_buffer(a, (NNZB * BLOCK_SIZE * BLOCK_SIZE), "float32")
+    B = T.match_buffer(b, (NB * BLOCK_SIZE * K,), "float32")
+    C = T.match_buffer(c, (MB * BLOCK_SIZE * K,), "float32")
+    A_indptr = T.match_buffer(indptr, (MB + 1,), "int32")
+    A_indices = T.match_buffer(indices, (NNZB,), "int32")
+    for io, ii, ji, k in T.grid(MB, BLOCK_SIZE, BLOCK_SIZE, K):
+        with T.block("spmm_outer"):
+            vio, vii, vji, vk = T.axis.remap("SSSS", [io, ii, ji, k])
+            with T.init():
+                C[(vio * BLOCK_SIZE + vii) * K + vk] = 0.
+            for jo in T.serial(0, A_indptr[vio + 1] - A_indptr[vio]):
+                with T.block("spmm_inner"):
+                    vjo = T.axis.R(NNZB, jo + A_indptr[vio])
+                    C[(vio * BLOCK_SIZE + vii) * K + vk] = C[(vio * BLOCK_SIZE + vii) * K + vk] + A_data[(
+                        vjo * BLOCK_SIZE + vii) * BLOCK_SIZE + vji] * B[(A_indices[vjo] * BLOCK_SIZE + vji) * K + vk]
+
+
+@T.prim_func
+def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ_COLS: T.int32) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    A_data = T.match_buffer(a, (M * NNZ_COLS,), "float32")
+    B = T.match_buffer(b, (N * K,), "float32")
+    C = T.match_buffer(c, (M * K,), "float32")
+    A_indices = T.match_buffer(indices, (M * NNZ_COLS,), "int32")
+    for i, j, k in T.grid(M, NNZ_COLS, K):
+        with T.block("spmm"):
+            vi, vj, vk = T.axis.remap("SRS", [i, j, k])
+            with T.init():
+                C[vi * K + vk] = 0.
+            C[vi * K + vk] = C[vi * K + vk] + A_data[vi * NNZ_COLS + vj] * \
+                B[A_indices[vi * NNZ_COLS + vj] * K + vk]


 def test_csrmm():
     # generate random input
-    A = sp.random(4096, 4096, dtype="float32", density=0.0125, format='csr')
-    x = np.random.rand(4096, 256).astype("float32")
+    m = 4096
+    n = 4096
+    k = 256
+    A = sp.random(m, n, dtype="float32", density=0.0125, format='csr')
+    nnz = A.nnz
+    x = np.random.rand(n, k).astype("float32")
     y_ground_truth = A * x
-    y = np.zeros((4096, 256)).astype("float32")
+    y = np.zeros((m * k,)).astype("float32")

     # specialize function
-    _, _, _, _, _, m, n, k, nnz = csrmm_tir.params
+    _, _, _, _, _, M, N, K, NNZ = csrmm_tir.params
     sch = tir.Schedule(
         csrmm_tir.specialize(
-            {m: 4096, n: 4096, k: 256, nnz: A.nnz}
+            {M: m, N: n, K: k, NNZ: nnz}
         )
     )
     blk_outer = sch.get_block("spmm_outer")
     i, k = sch.get_loops(blk_outer)
     sch.bind(i, "blockIdx.x")
     sch.bind(k, "threadIdx.x")
-
+
     # convert numpy tensor to tvm ndarray
     A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0))
     A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0))
     A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0))
     X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
-    Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0))
+    Y_nd = tvm.nd.array(y, device=tvm.cuda(0))

     # build function
     f = tvm.build(sch.mod, target='cuda')
@@ -88,5 +130,102 @@ def test_csrmm():
     assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())


+def test_bsrmm():
+    # generate random input
+    block_size = 1
+    mb = 64
+    nb = 64
+    k = 256
+    m = mb * block_size
+    n = nb * block_size
+    A_block = sp.random(mb, nb, dtype="float32", density=0.05, format='csr')
+    indptr = A_block.indptr
+    indices = A_block.indices
+    nnzb = A_block.nnz
+    data = np.random.rand(nnzb, block_size, block_size)
+    A = sp.bsr_matrix((data, indices, indptr), shape=(m, n))
+    x = np.random.rand(n, k).astype("float32")
+    y_ground_truth = A * x
+    y = np.zeros((m * k,)).astype("float32")
+
+    # specialize function
+    _, _, _, _, _, MB, NB, K, BLOCK_SIZE, NNZB = bsrmm_tir.params
+    sch = tir.Schedule(
+        bsrmm_tir.specialize(
+            {MB: mb, NB: nb, K: k, BLOCK_SIZE: block_size, NNZB: nnzb}
+        )
+    )
+    blk_outer = sch.get_block("spmm_outer")
+    io, ii, ji, k = sch.get_loops(blk_outer)
+    sch.unroll(ii)
+    sch.unroll(ji)
+    sch.bind(io, "blockIdx.x")
+    sch.bind(k, "threadIdx.x")
+
+    # convert numpy tensor to tvm ndarray
+    A_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0))
+    A_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
+    A_data = tvm.nd.array(
+        data.reshape(-1).astype("float32"), device=tvm.cuda(0))
+    X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
+    Y_nd = tvm.nd.array(y, device=tvm.cuda(0))
+
+    # build function
+    f = tvm.build(sch.mod, target="cuda")
+    f(A_data, X_nd, Y_nd, A_indptr, A_indices)
+
+    # assertion
+    assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())
+
+
+def test_ellmm():
+    # generate random input
+    nnz_cols = 64
+    m = 4096
+    n = 4096
+    k = 256
+    nnz = nnz_cols * m
+    indptr = np.arange(0, (m + 1) * nnz_cols, nnz_cols)
+    indices = np.random.randint(0, n, size=(nnz,))
+    data = np.random.rand(nnz)
+    A = sp.csr_matrix((data, indices, indptr), shape=(m, n))
+    x = np.random.rand(n, k).astype("float32")
+    y_ground_truth = A * x
+    y = np.zeros((m * k,)).astype("float32")
+    # specialize function
+    _, _, _, _, M, N, K, NNZ_COLS = ellmm_tir.params
+    sch = tir.Schedule(
+        ellmm_tir.specialize(
+            {M: m, N: n, K: k, NNZ_COLS: nnz_cols}
+        )
+    )
+    blk = sch.get_block("spmm")
+    i, j, k = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(k, "threadIdx.x")
+    sch.unroll(j)
+
+    # convert numpy tensor to tvm ndarray
+    A_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
+    A_data = tvm.nd.array(data.astype("float32"), device=tvm.cuda(0))
+    X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
+    Y_nd = tvm.nd.array(y, device=tvm.cuda(0))
+
+    # build function
+    f = tvm.build(sch.mod, target="cuda")
+    f(A_data, X_nd, Y_nd, A_indices)
+
+    # assertion
+    assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())
+
+
+def test_bmm():
+    # TODO(zihao)
+    pass
+
+
 if __name__ == "__main__":
-    test_csrmm()
+    test_csrmm()
+    test_bsrmm()
+    test_ellmm()
+    test_bmm()
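The kernels in this commit flatten every buffer to 1-D, so the tests stand or fall on the index arithmetic. The following NumPy sketch is not part of the commit; it mirrors the flattened BSR indexing used by bsrmm_tir with smaller, illustrative sizes (mb = nb = 16, block_size = 2, k = 8) and checks it against scipy's own bsr_matrix product.

import numpy as np
import scipy.sparse as sp

mb, nb, block_size, k = 16, 16, 2, 8
m, n = mb * block_size, nb * block_size

# Random block-sparse matrix: CSR over blocks, dense (block_size x block_size) blocks.
A_block = sp.random(mb, nb, dtype="float32", density=0.2, format="csr")
indptr, indices = A_block.indptr, A_block.indices
data = np.random.rand(A_block.nnz, block_size, block_size).astype("float32")
A = sp.bsr_matrix((data, indices, indptr), shape=(m, n))
x = np.random.rand(n, k).astype("float32")

# Flattened buffers, laid out the way the TIR kernel consumes them.
A_data = data.reshape(-1)            # (nnzb * block_size * block_size,)
B = x.reshape(-1)                    # (n * k,), row-major
C = np.zeros(m * k, dtype="float32")

for io in range(mb):                              # block row
    for jo in range(indptr[io], indptr[io + 1]):  # nonzero block in that row
        for ii in range(block_size):              # row inside the block
            for ji in range(block_size):          # column inside the block
                a = A_data[(jo * block_size + ii) * block_size + ji]
                for kk in range(k):
                    C[(io * block_size + ii) * k + kk] += \
                        a * B[(indices[jo] * block_size + ji) * k + kk]

assert np.allclose(C, (A @ x).reshape(-1))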

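Similarly, a small NumPy sketch (again not from the commit; sizes are hypothetical) of the padded ELL layout consumed by ellmm_tir: each of the m rows stores exactly nnz_cols (index, value) pairs, so A_data and A_indices are flat arrays of length m * nnz_cols and the reduction runs over a fixed trip count, which is why the test can unroll loop j.

import numpy as np

m, n, k, nnz_cols = 16, 16, 8, 4
indices = np.random.randint(0, n, size=(m * nnz_cols,))
data = np.random.rand(m * nnz_cols).astype("float32")
x = np.random.rand(n, k).astype("float32")

# Dense reference: scatter-add every stored entry (duplicates within a row sum up).
A_dense = np.zeros((m, n), dtype="float32")
for i in range(m):
    for j in range(nnz_cols):
        A_dense[i, indices[i * nnz_cols + j]] += data[i * nnz_cols + j]
y_ref = A_dense @ x

# Same computation with the flat ELL indexing used by the TIR kernel.
C = np.zeros(m * k, dtype="float32")
for i in range(m):
    for j in range(nnz_cols):
        for kk in range(k):
            C[i * k + kk] += data[i * nnz_cols + j] * x[indices[i * nnz_cols + j], kk]

assert np.allclose(C, y_ref.reshape(-1))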