Skip to content

Commit

Permalink
Initialized information of compressed matrix's row, col and ld when u…
Browse files Browse the repository at this point in the history
…sing hipsparseLtSpMMACompress2() directly.
  • Loading branch information
vin-huang committed Sep 5, 2024
1 parent b344687 commit e3ff0a8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 7 deletions.
53 changes: 46 additions & 7 deletions clients/include/spmm/testing_compress.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,23 @@ void testing_compress(const Arguments& arg)
arg.b_type,
orderB);

hipsparselt_local_mat_descr matAv2(arg.sparse_b ? hipsparselt_matrix_type_dense
: hipsparselt_matrix_type_structured,
handle,
A_row,
A_col,
lda,
arg.a_type,
orderA);
hipsparselt_local_mat_descr matBv2(arg.sparse_b ? hipsparselt_matrix_type_structured
: hipsparselt_matrix_type_dense,
handle,
B_row,
B_col,
ldb,
arg.b_type,
orderB);

hipsparselt_local_mat_descr matC(
hipsparselt_matrix_type_dense, handle, M, N, ldc, arg.c_type, orderC);
hipsparselt_local_mat_descr matD(
Expand Down Expand Up @@ -484,6 +501,14 @@ void testing_compress(const Arguments& arg)
hipsparseLtMatDescSetAttribute(
handle, matB, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matAv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matBv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
HIPSPARSE_STATUS_SUCCESS);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matC, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)),
Expand All @@ -509,6 +534,20 @@ void testing_compress(const Arguments& arg)
eStatus);
if(eStatus != HIPSPARSE_STATUS_SUCCESS)
return;
eStatus = expected_hipsparse_status_of_matrix_stride(stride_a, A_row, A_col, lda, orderA);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matAv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_a, sizeof(int64_t)),
eStatus);
if(eStatus != HIPSPARSE_STATUS_SUCCESS)
return;
eStatus = expected_hipsparse_status_of_matrix_stride(stride_b, B_row, B_col, ldb, orderB);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
handle, matBv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_b, sizeof(int64_t)),
eStatus);
if(eStatus != HIPSPARSE_STATUS_SUCCESS)
return;
eStatus = expected_hipsparse_status_of_matrix_stride(stride_c, M, N, ldc, orderC);
EXPECT_HIPSPARSE_STATUS(
hipsparseLtMatDescSetAttribute(
Expand All @@ -525,13 +564,13 @@ void testing_compress(const Arguments& arg)
return;
}

hipsparselt_local_matmul_descr matmul(
handle, transA, transB, matA, matB, matC, matD, arg.compute_type);
hipsparselt_local_matmul_descr matmul(
handle, transA, transB, matA, matB, matC, matD, arg.compute_type);

hipsparselt_local_matmul_alg_selection alg_sel(handle, matmul, HIPSPARSELT_MATMUL_ALG_DEFAULT);
hipsparselt_local_matmul_alg_selection alg_sel(handle, matmul, HIPSPARSELT_MATMUL_ALG_DEFAULT);

size_t workspace_size, compressed_size, compress_buffer_size;
hipsparselt_local_matmul_plan plan(handle, matmul, alg_sel);
hipsparselt_local_matmul_plan plan(handle, matmul, alg_sel);

hipsparseLtMatmulGetWorkspace(handle, plan, &workspace_size);

Expand All @@ -545,7 +584,7 @@ void testing_compress(const Arguments& arg)
{
EXPECT_HIPSPARSE_STATUS(
hipsparseLtSpMMACompressedSize2(
handle, arg.sparse_b ? matB : matA, &compressed_size, &compress_buffer_size),
handle, arg.sparse_b ? matBv2 : matAv2, &compressed_size, &compress_buffer_size),
HIPSPARSE_STATUS_SUCCESS);
}
const size_t size_A = stride_a == 0
Expand Down Expand Up @@ -623,7 +662,7 @@ void testing_compress(const Arguments& arg)
else if(arg.func_version == 2)
{
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPrune2(handle,
arg.sparse_b ? matB : matA,
arg.sparse_b ? matBv2 : matAv2,
!arg.sparse_b,
arg.sparse_b ? transB : transA,
dT,
Expand Down Expand Up @@ -717,7 +756,7 @@ void testing_compress(const Arguments& arg)
HIPSPARSE_STATUS_SUCCESS);
else if(arg.func_version == 2)
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMACompress2(handle,
arg.sparse_b ? matB : matA,
arg.sparse_b ? matBv2 : matAv2,
!arg.sparse_b,
arg.sparse_b ? transB : transA,
dT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,34 @@ rocsparselt_status rocsparselt_smfmac_compress2(const rocsparselt_handle* han
return rocsparselt_status_not_implemented;
}

if(isSparseA)
{
auto m = _sparseMatDescr->m;
auto k = _sparseMatDescr->n;
if (op == rocsparselt_operation_transpose)
std::swap(m, k);
_sparseMatDescr->c_k = k / 2;
_sparseMatDescr->c_ld = m;
_sparseMatDescr->c_n = _sparseMatDescr->c_k;
if((op == rocsparselt_operation_transpose)
!= (_sparseMatDescr->order == rocsparselt_order_row))
std::swap(_sparseMatDescr->c_ld, _sparseMatDescr->c_n);
}
else
{
auto k = _sparseMatDescr->m;
auto n = _sparseMatDescr->n;
if (op == rocsparselt_operation_transpose)
std::swap(n, k);
_sparseMatDescr->c_k = k / 2;
_sparseMatDescr->c_ld = _sparseMatDescr->c_k;
_sparseMatDescr->c_n = n;
if((op == rocsparselt_operation_transpose)
!= (_sparseMatDescr->order == rocsparselt_order_row))
std::swap(_sparseMatDescr->c_ld, _sparseMatDescr->c_n);
}


log_api(_handle,
__func__,
"sparseMatDescr[in]",
Expand Down

0 comments on commit e3ff0a8

Please sign in to comment.