diff --git a/clients/include/spmm/testing_compress.hpp b/clients/include/spmm/testing_compress.hpp index 8aab6f0e..1a639575 100644 --- a/clients/include/spmm/testing_compress.hpp +++ b/clients/include/spmm/testing_compress.hpp @@ -447,6 +447,23 @@ void testing_compress(const Arguments& arg) arg.b_type, orderB); + hipsparselt_local_mat_descr matAv2(arg.sparse_b ? hipsparselt_matrix_type_dense + : hipsparselt_matrix_type_structured, + handle, + A_row, + A_col, + lda, + arg.a_type, + orderA); + hipsparselt_local_mat_descr matBv2(arg.sparse_b ? hipsparselt_matrix_type_structured + : hipsparselt_matrix_type_dense, + handle, + B_row, + B_col, + ldb, + arg.b_type, + orderB); + hipsparselt_local_mat_descr matC( hipsparselt_matrix_type_dense, handle, M, N, ldc, arg.c_type, orderC); hipsparselt_local_mat_descr matD( @@ -484,6 +501,14 @@ void testing_compress(const Arguments& arg) hipsparseLtMatDescSetAttribute( handle, matB, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)), HIPSPARSE_STATUS_SUCCESS); + EXPECT_HIPSPARSE_STATUS( + hipsparseLtMatDescSetAttribute( + handle, matAv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)), + HIPSPARSE_STATUS_SUCCESS); + EXPECT_HIPSPARSE_STATUS( + hipsparseLtMatDescSetAttribute( + handle, matBv2, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)), + HIPSPARSE_STATUS_SUCCESS); EXPECT_HIPSPARSE_STATUS( hipsparseLtMatDescSetAttribute( handle, matC, HIPSPARSELT_MAT_NUM_BATCHES, &num_batches, sizeof(int)), @@ -509,6 +534,20 @@ void testing_compress(const Arguments& arg) eStatus); if(eStatus != HIPSPARSE_STATUS_SUCCESS) return; + eStatus = expected_hipsparse_status_of_matrix_stride(stride_a, A_row, A_col, lda, orderA); + EXPECT_HIPSPARSE_STATUS( + hipsparseLtMatDescSetAttribute( + handle, matAv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_a, sizeof(int64_t)), + eStatus); + if(eStatus != HIPSPARSE_STATUS_SUCCESS) + return; + eStatus = expected_hipsparse_status_of_matrix_stride(stride_b, B_row, B_col, ldb, orderB); + EXPECT_HIPSPARSE_STATUS( + hipsparseLtMatDescSetAttribute( + handle, matBv2, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_b, sizeof(int64_t)), + eStatus); + if(eStatus != HIPSPARSE_STATUS_SUCCESS) + return; eStatus = expected_hipsparse_status_of_matrix_stride(stride_c, M, N, ldc, orderC); EXPECT_HIPSPARSE_STATUS( hipsparseLtMatDescSetAttribute( @@ -525,13 +564,13 @@ void testing_compress(const Arguments& arg) return; } - hipsparselt_local_matmul_descr matmul( - handle, transA, transB, matA, matB, matC, matD, arg.compute_type); + hipsparselt_local_matmul_descr matmul( + handle, transA, transB, matA, matB, matC, matD, arg.compute_type); - hipsparselt_local_matmul_alg_selection alg_sel(handle, matmul, HIPSPARSELT_MATMUL_ALG_DEFAULT); + hipsparselt_local_matmul_alg_selection alg_sel(handle, matmul, HIPSPARSELT_MATMUL_ALG_DEFAULT); size_t workspace_size, compressed_size, compress_buffer_size; - hipsparselt_local_matmul_plan plan(handle, matmul, alg_sel); + hipsparselt_local_matmul_plan plan(handle, matmul, alg_sel); hipsparseLtMatmulGetWorkspace(handle, plan, &workspace_size); @@ -545,7 +584,7 @@ void testing_compress(const Arguments& arg) { EXPECT_HIPSPARSE_STATUS( hipsparseLtSpMMACompressedSize2( - handle, arg.sparse_b ? matB : matA, &compressed_size, &compress_buffer_size), + handle, arg.sparse_b ? matBv2 : matAv2, &compressed_size, &compress_buffer_size), HIPSPARSE_STATUS_SUCCESS); } const size_t size_A = stride_a == 0 @@ -623,7 +662,7 @@ void testing_compress(const Arguments& arg) else if(arg.func_version == 2) { EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPrune2(handle, - arg.sparse_b ? matB : matA, + arg.sparse_b ? matBv2 : matAv2, !arg.sparse_b, arg.sparse_b ? transB : transA, dT, @@ -717,7 +756,7 @@ void testing_compress(const Arguments& arg) HIPSPARSE_STATUS_SUCCESS); else if(arg.func_version == 2) EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMACompress2(handle, - arg.sparse_b ? matB : matA, + arg.sparse_b ? matBv2 : matAv2, !arg.sparse_b, arg.sparse_b ? transB : transA, dT, diff --git a/library/src/hcc_detail/rocsparselt/src/spmm/rocsparselt_compress.cpp b/library/src/hcc_detail/rocsparselt/src/spmm/rocsparselt_compress.cpp index 5ffb085a..7b6664e0 100644 --- a/library/src/hcc_detail/rocsparselt/src/spmm/rocsparselt_compress.cpp +++ b/library/src/hcc_detail/rocsparselt/src/spmm/rocsparselt_compress.cpp @@ -644,6 +644,34 @@ rocsparselt_status rocsparselt_smfmac_compress2(const rocsparselt_handle* han return rocsparselt_status_not_implemented; } + if(isSparseA) + { + auto m = _sparseMatDescr->m; + auto k = _sparseMatDescr->n; + if (op == rocsparselt_operation_transpose) + std::swap(m, k); + _sparseMatDescr->c_k = k / 2; + _sparseMatDescr->c_ld = m; + _sparseMatDescr->c_n = _sparseMatDescr->c_k; + if((op == rocsparselt_operation_transpose) + != (_sparseMatDescr->order == rocsparselt_order_row)) + std::swap(_sparseMatDescr->c_ld, _sparseMatDescr->c_n); + } + else + { + auto k = _sparseMatDescr->m; + auto n = _sparseMatDescr->n; + if (op == rocsparselt_operation_transpose) + std::swap(n, k); + _sparseMatDescr->c_k = k / 2; + _sparseMatDescr->c_ld = _sparseMatDescr->c_k; + _sparseMatDescr->c_n = n; + if((op == rocsparselt_operation_transpose) + != (_sparseMatDescr->order == rocsparselt_order_row)) + std::swap(_sparseMatDescr->c_ld, _sparseMatDescr->c_n); + } + + log_api(_handle, __func__, "sparseMatDescr[in]",