Skip to content

Commit d3d75a7

Browse files
authored
Add more examples part 1 (sddmm) (#22)
* upd * upd * upd
1 parent 9f6a2cd commit d3d75a7

File tree

3 files changed

+89
-23
lines changed

3 files changed

+89
-23
lines changed

src/target/source/literal/cuda_binary_search.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626

2727
static constexpr const char* _cuda_binary_search_def = R"(
2828
template <typename DType>
29-
__forceinline__ __device__ int32_t __lower_bound(
29+
__forceinline__ __device__ int __lower_bound(
3030
const DType* __restrict__ arr,
3131
DType val,
32-
int32_t l,
33-
int32_t r) {
34-
int32_t low = l - 1, high = r;
32+
int l,
33+
int r) {
34+
int low = l - 1, high = r;
3535
/* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */
3636
while (low + 1 < high) {
37-
int32_t mid = (low + high) >> 1;
37+
int mid = (low + high) >> 1;
3838
if (arr[mid] < val) {
3939
low = mid;
4040
} else {
@@ -46,15 +46,15 @@ __forceinline__ __device__ int32_t __lower_bound(
4646
}
4747
4848
template <typename DType>
49-
__forceinline__ __device__ int32_t __upper_bound(
49+
__forceinline__ __device__ int __upper_bound(
5050
const DType* __restrict__ arr,
5151
DType val,
52-
int32_t l,
53-
int32_t r) {
54-
int32_t low = l - 1, high = r;
52+
int l,
53+
int r) {
54+
int low = l - 1, high = r;
5555
/* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */
5656
while (low + 1 < high) {
57-
int32_t mid = (low + high) >> 1;
57+
int mid = (low + high) >> 1;
5858
if (arr[mid] > val) {
5959
high = mid;
6060
} else {

tests/python/sparsetir/test_tir_sparse_correctness.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,25 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
9292
B[A_indices[vi * NNZ_COLS + vj] * K + vk]
9393

9494

95+
@T.prim_func
def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None:
    # SDDMM (sampled dense-dense matmul): C_data[ij] = sum_k A[row(ij), k] * B[col(ij), k]
    # where the output sparsity pattern is given in CSR form (C_indptr, C_indices).
    # Fix 1: "tir.noalis" -> "tir.noalias" (typo meant the attribute was never applied).
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (M * K,), "float32")            # dense A, flattened (M, K)
    B = T.match_buffer(b, (N * K,), "float32")            # dense B, flattened (N, K)
    C_data = T.match_buffer(c, (NNZ,), "float32")         # CSR values of the output
    C_indptr = T.match_buffer(indptr, (M + 1,), "int32")  # CSR row pointers
    C_indices = T.match_buffer(indices, (NNZ,), "int32")  # CSR column indices
    for ij, k in T.grid(NNZ, K):
        with T.block("sddmm"):
            vij, vk = T.axis.remap("SR", [ij, k])
            T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]])
            T.writes([C_data[vij]])
            with T.init():
                C_data[vij] = 0.
            # Fix 2: recover the row of nnz position vij as upper_bound(indptr, vij) - 1,
            # i.e. the r with indptr[r] <= vij < indptr[r + 1]. The previous
            # lower_bound(indptr, vij) returns r + 1 whenever indptr[r] < vij < indptr[r + 1],
            # reading the wrong row of A for every non-row-leading nonzero.
            C_data[vij] = C_data[vij] + \
                A[(T.upper_bound(C_indptr.data, vij, 0, M + 1) - 1) * K + vk] * B[C_indices[vij] * K + vk]
112+
113+
95114
def test_csrmm():
96115
# generate random input
97116
m = 4096
@@ -219,6 +238,50 @@ def test_ellmm():
219238
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())
220239

221240

241+
def test_sddmm():
    """End-to-end SDDMM check: build sddmm_tir for CUDA and compare against numpy.

    Fix: the final `np.allclose(...)` result was discarded, so the test could
    never fail; it is now asserted.
    """
    # generate random input
    m = 4096
    n = 4096
    k = 256
    C = sp.random(m, n, dtype="float32", density=0.0125, format='csr')
    indptr = C.indptr
    indices = C.indices
    C_coo = C.tocoo()
    nnz = C.nnz
    x = np.random.rand(m, k).astype("float32")
    y = np.random.rand(n, k).astype("float32")
    # ground truth: dense x @ y^T sampled at the nonzero positions of C
    z_ground_truth = np.matmul(x, y.transpose())[C_coo.row, C_coo.col]
    z = np.zeros((nnz,)).astype("float32")

    # specialize function: bind the symbolic shape vars to concrete sizes
    _, _, _, _, _, M, N, K, NNZ = sddmm_tir.params
    sch = tir.Schedule(
        sddmm_tir.specialize(
            {M: m, N: n, K: k, NNZ: nnz}
        )
    )
    blk = sch.get_block("sddmm")
    ij, k = sch.get_loops(blk)
    # sch.decompose_reduction(blk, ij)
    sch.bind(ij, "blockIdx.x")
    ko, ki = sch.split(k, [None, 1])
    sch.bind(ki, "threadIdx.x")

    # convert numpy tensors to tvm ndarrays on the GPU
    C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
    C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0))
    X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
    Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0))
    C_data = tvm.nd.array(z, device=tvm.cuda(0))

    # build and run
    f = tvm.build(sch.mod['main'], target="cuda")
    f(X_nd, Y_nd, C_data, C_indptr, C_indices)

    # assertion (was previously a bare call whose boolean result was ignored)
    assert np.allclose(z_ground_truth, C_data.numpy())
283+
284+
222285
def test_bmm():
223286
# TODO(zihao)
224287
pass
@@ -228,4 +291,5 @@ def test_bmm():
228291
test_csrmm()
229292
test_bsrmm()
230293
test_ellmm()
294+
test_sddmm()
231295
test_bmm()

tests/python/unittest/test_tir_intrin.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -253,19 +253,21 @@ def test_fma():
253253
assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"
254254

255255

256-
@tvm.script.tir
257-
def binary_search(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
258-
n = tir.var('int32')
259-
m = tir.var('int32')
260-
A = tir.match_buffer(a, (n,), dtype='int32')
261-
B = tir.match_buffer(b, (m,), dtype='int32')
262-
C = tir.match_buffer(c, (m,), dtype='int32')
263-
D = tir.match_buffer(d, (m,), dtype='int32')
264-
with tir.block([m], 'search') as [vi]:
265-
tir.reads([A[0:n], B[vi]])
266-
tir.writes([C[vi], D[vi]])
267-
C[vi] = tir.lower_bound(A.data, B[vi], 0, n)
268-
D[vi] = tir.upper_bound(A.data, B[vi], 0, n)
256+
@T.prim_func
def binary_search(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # For each query B[i], record lower_bound and upper_bound positions of
    # B[i] within the sorted array A into C[i] and D[i] respectively.
    n = T.var('int32')  # length of the searched array A
    m = T.var('int32')  # number of queries
    A = T.match_buffer(a, (n,), dtype='int32')  # sorted input array (assumed sorted — TODO confirm with caller)
    B = T.match_buffer(b, (m,), dtype='int32')  # query values
    C = T.match_buffer(c, (m,), dtype='int32')  # lower_bound results
    D = T.match_buffer(d, (m,), dtype='int32')  # upper_bound results
    for i in T.serial(0, m):
        with T.block('search'):
            vi = T.axis.S(m, i)
            T.reads([A[0:n], B[vi]])
            T.writes([C[vi], D[vi]])
            C[vi] = T.lower_bound(A.data, B[vi], 0, n)
            D[vi] = T.upper_bound(A.data, B[vi], 0, n)
269271

270272

271273
def test_binary_search():

0 commit comments

Comments
 (0)