Commit 6213c94

upd (#17)
1 parent dbe8061 commit 6213c94

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
from tvm.script import tir as T


@T.prim_func
def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
    n = T.var("int32")
    m = T.var("int32")
    k = T.var("int32")
    nnz = T.var("int32")
    # Axes of the sparse iteration space: I and K are dense, J is the
    # variable (CSR) axis backed by indptr/indices.
    I = T.dense_fixed(m)
    J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32")
    K = T.dense_fixed(k)
    A = T.match_sparse_buffer(a, (I, J), nnz, "float32")
    B = T.match_sparse_buffer(b, (T.to_dense(J), K), n * k, "float32")
    C = T.match_sparse_buffer(c, (I, K), m * k, "float32")
    # "SRS": i and k are spatial axes, j is the reduction axis.
    with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]:
        with T.init():
            C[vi, vk] = 0.0
        C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]


@T.prim_func
def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    n = T.var("int32")
    m = T.var("int32")
    k = T.var("int32")
    nnz = T.var("int32")
    # CSR storage of A plus dense B and C.
    A_data = T.match_buffer(a, (nnz,), "float32")
    B = T.match_buffer(b, (n, k), "float32")
    C = T.match_buffer(c, (m, k), "float32")
    A_indptr = T.match_buffer(indptr, (m + 1,), "int32")
    A_indices = T.match_buffer(indices, (nnz,), "int32")
    for i, k in T.grid(m, k):
        with T.block("spmm_outer"):
            vi, vk = T.axis.remap("SS", [i, k])
            with T.init():
                C[vi, vk] = 0.0
            # Reduce over the nonzeros of row vi.
            for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
                with T.block("spmm_inner"):
                    vj = T.axis.R(n, j + A_indptr[vi])
                    C[vi, vk] = C[vi, vk] + A_data[vj] * B[A_indices[vj], vk]


def test_csrmm():
    # generate random input
    A = sp.random(4096, 4096, dtype="float32", density=0.0125, format="csr")
    x = np.random.rand(4096, 256).astype("float32")
    y_ground_truth = A * x
    y = np.zeros((4096, 256)).astype("float32")

    # specialize function
    sch = tir.Schedule(csrmm_tir)
    blk_outer = sch.get_block("spmm_outer")
    i, k = sch.get_loops(blk_outer)
    sch.bind(i, "blockIdx.x")
    sch.bind(k, "threadIdx.x")

    # convert numpy tensor to tvm ndarray
    A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0))
    A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0))
    A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0))
    X_nd = tvm.nd.array(x, device=tvm.cuda(0))
    Y_nd = tvm.nd.array(y, device=tvm.cuda(0))

    # build function
    f = tvm.build(sch.mod, target="cuda")
    f(A_data, X_nd, Y_nd, A_indptr, A_indices)

    assert np.allclose(y_ground_truth, Y_nd.numpy())


if __name__ == "__main__":
    test_csrmm()
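
For reference, a minimal NumPy sketch of the row-by-row CSR-times-dense product that csrmm_tir expresses; the sizes and names below are illustrative only and not part of the committed file:

import numpy as np
import scipy.sparse as sp

# Small hypothetical sizes, just to mirror the loop structure of csrmm_tir.
m, n, k = 8, 8, 4
A = sp.random(m, n, dtype="float32", density=0.25, format="csr")
B = np.random.rand(n, k).astype("float32")
C = np.zeros((m, k), dtype="float32")

for i in range(m):                                 # spatial row axis (spmm_outer)
    for j in range(A.indptr[i], A.indptr[i + 1]):  # reduce over row i's nonzeros (spmm_inner)
        C[i, :] += A.data[j] * B[A.indices[j], :]  # the column axis is vectorized here

assert np.allclose(C, A @ B)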
