diff --git a/clients/samples/example_compress.cpp b/clients/samples/example_compress.cpp index c4845982..501ec87b 100644 --- a/clients/samples/example_compress.cpp +++ b/clients/samples/example_compress.cpp @@ -45,7 +45,7 @@ error, \ __FILE__, \ __LINE__); \ - exit(EXIT_FAILURE); \ + throw EXIT_FAILURE; \ } #endif @@ -65,10 +65,52 @@ if(error == HIPSPARSE_STATUS_ARCH_MISMATCH) \ fprintf(stderr, "HIPSPARSE_STATUS_ARCH_MISMATCH"); \ fprintf(stderr, "\n"); \ - exit(EXIT_FAILURE); \ + throw EXIT_FAILURE; \ } #endif +template +class SMART_DESTROYER +{ + typedef To (*destroy_func)(T*); + +public: + SMART_DESTROYER(T* ptr, destroy_func func) + { + _ptr = ptr; + _func = func; + } + ~SMART_DESTROYER() + { + if(_ptr != nullptr) + _func(_ptr); + } + + T* _ptr = nullptr; + destroy_func _func; +}; + +template +class SMART_DESTROYER_NON_PTR +{ + typedef To (*destroy_func)(T); + +public: + SMART_DESTROYER_NON_PTR(T* ptr, destroy_func func) + { + _ptr = ptr; + _func = func; + } + ~SMART_DESTROYER_NON_PTR() + { + if(_ptr != nullptr) + _func(*_ptr); + } + + T* _ptr = nullptr; + destroy_func _func; +}; + inline unsigned char generate_metadata(int a, int b, int c, int d) { unsigned char metadata = (a)&0x03; @@ -709,9 +751,13 @@ void run(int64_t m, int num_streams = 0; hipStream_t stream = nullptr; hipStreamCreate(&stream); + SMART_DESTROYER_NON_PTR sS(&stream, hipStreamDestroy); CHECK_HIP_ERROR(hipMalloc(&d, size * sizeof(T))); + SMART_DESTROYER sD(d, hipFree); + CHECK_HIP_ERROR(hipMalloc(&d_test, size * sizeof(T))); + SMART_DESTROYER sDt(d_test, hipFree); // copy matrices from host to device CHECK_HIP_ERROR(hipMemcpy(d, hp.data(), sizeof(T) * size, hipMemcpyHostToDevice)); @@ -722,6 +768,7 @@ void run(int64_t m, hipsparseLtMatmulPlan_t plan; CHECK_HIPSPARSELT_ERROR(hipsparseLtInit(&handle)); + SMART_DESTROYER sH(&handle, hipsparseLtDestroy); if(!sparse_b) { @@ -735,18 +782,22 @@ void run(int64_t m, type, HIPSPARSE_ORDER_COL, HIPSPARSELT_SPARSITY_50_PERCENT)); - 
CHECK_HIPSPARSELT_ERROR( - hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL)); - CHECK_HIPSPARSELT_ERROR( - hipsparseLtDenseDescriptorInit(&handle, &matC, m, m, m, 16, type, HIPSPARSE_ORDER_COL)); - CHECK_HIPSPARSELT_ERROR( - hipsparseLtDenseDescriptorInit(&handle, &matD, m, m, m, 16, type, HIPSPARSE_ORDER_COL)); } else { CHECK_HIPSPARSELT_ERROR( hipsparseLtDenseDescriptorInit(&handle, &matA, n, m, n, 16, type, HIPSPARSE_ORDER_COL)); + } + SMART_DESTROYER smA( + &matA, hipsparseLtMatDescriptorDestroy); + if(!sparse_b) + { + CHECK_HIPSPARSELT_ERROR( + hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL)); + } + else + { CHECK_HIPSPARSELT_ERROR( hipsparseLtStructuredDescriptorInit(&handle, &matB, @@ -757,11 +808,21 @@ void run(int64_t m, type, HIPSPARSE_ORDER_COL, HIPSPARSELT_SPARSITY_50_PERCENT)); - CHECK_HIPSPARSELT_ERROR( - hipsparseLtDenseDescriptorInit(&handle, &matC, n, n, n, 16, type, HIPSPARSE_ORDER_COL)); - CHECK_HIPSPARSELT_ERROR( - hipsparseLtDenseDescriptorInit(&handle, &matD, n, n, n, 16, type, HIPSPARSE_ORDER_COL)); } + SMART_DESTROYER smB( + &matB, hipsparseLtMatDescriptorDestroy); + + auto tmp_m = (!sparse_b) ? 
m : n; + + CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( + &handle, &matC, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL)); + SMART_DESTROYER smC( + &matC, hipsparseLtMatDescriptorDestroy); + + CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( + &handle, &matD, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL)); + SMART_DESTROYER smD( + &matD, hipsparseLtMatDescriptorDestroy); CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( &handle, &matA, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); @@ -811,6 +872,8 @@ void run(int64_t m, size_t workspace_size, compressed_size, compress_buffer_size; CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel)); + SMART_DESTROYER sP( + &plan, hipsparseLtMatmulPlanDestroy); CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size)); @@ -824,7 +887,11 @@ void run(int64_t m, std::vector hp_compressed(compressed_size / sizeof(T)); std::vector hp_compressBuffer(compress_buffer_size / sizeof(T)); CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size)); + SMART_DESTROYER sDc(d_compressed, hipFree); + CHECK_HIP_ERROR(hipMalloc(&d_compressBuffer, compress_buffer_size)); + SMART_DESTROYER sDcB(d_compressBuffer, hipFree); + CHECK_HIPSPARSELT_ERROR( hipsparseLtSpMMACompress(&handle, &plan, d_test, d_compressed, d_compressBuffer, stream)); hipStreamSynchronize(stream); @@ -1028,89 +1095,84 @@ void run(int64_t m, //validate(&hp_test[0], &hp_compressed[0], reinterpret_cast(&hp_compressed[c_stride_b * batch_count]), m, n, batch_count, stride_1, stride_2, stride, m, n/2, batch_count, c_stride_1, c_stride_2, c_stride_b, m, n/8, batch_count, m_stride_1, m_stride_2, m_stride_b); //CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size)); - - CHECK_HIP_ERROR(hipFree(d)); - CHECK_HIP_ERROR(hipFree(d_test)); - CHECK_HIP_ERROR(hipFree(d_compressed)); - CHECK_HIP_ERROR(hipFree(d_compressBuffer)); - - 
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanDestroy(&plan)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matA)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matB)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matC)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matD)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtDestroy(&handle)); } int main(int argc, char* argv[]) { - // initialize parameters with default values - hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE; + try + { + // initialize parameters with default values + hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE; - // invalid int and float for hipsparselt spmm int and float arguments - int64_t invalid_int64 = std::numeric_limits::min() + 1; - int invalid_int = std::numeric_limits::min() + 1; - float invalid_float = std::numeric_limits::quiet_NaN(); + // invalid int and float for hipsparselt spmm int and float arguments + int64_t invalid_int64 = std::numeric_limits::min() + 1; + int invalid_int = std::numeric_limits::min() + 1; + float invalid_float = std::numeric_limits::quiet_NaN(); - // initialize to invalid value to detect if values not specified on command line - int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64; + // initialize to invalid value to detect if values not specified on command line + int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64; - int batch_count = invalid_int; - hipDataType type = HIP_R_16F; + int batch_count = invalid_int; + hipDataType type = HIP_R_16F; - bool verbose = false; - bool header = false; - bool sparse_b = false; + bool verbose = false; + bool header = false; + bool sparse_b = false; - if(parse_arguments( - argc, argv, m, n, ld, stride, batch_count, trans, type, header, sparse_b, verbose)) - { - show_usage(argv); - return EXIT_FAILURE; - } + if(parse_arguments( + argc, argv, m, n, ld, stride, batch_count, trans, 
type, header, sparse_b, verbose)) + { + show_usage(argv); + return EXIT_FAILURE; + } - // when arguments not specified, set to default values - if(m == invalid_int64) - m = DIM1; - if(n == invalid_int64) - n = DIM2; - if(ld == invalid_int64) - ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n; - if(stride == invalid_int64) - stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ld * n : ld * m; - if(batch_count == invalid_int) - batch_count = BATCH_COUNT; - - if(bad_argument(trans, m, n, ld, stride, batch_count)) - { - show_usage(argv); - return EXIT_FAILURE; - } + // when arguments not specified, set to default values + if(m == invalid_int64) + m = DIM1; + if(n == invalid_int64) + n = DIM2; + if(ld == invalid_int64) + ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n; + if(stride == invalid_int64) + stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ld * n : ld * m; + if(batch_count == invalid_int) + batch_count = BATCH_COUNT; + + if(bad_argument(trans, m, n, ld, stride, batch_count)) + { + show_usage(argv); + return EXIT_FAILURE; + } - if(header) - { - std::cout << "type,trans,M,N,K,ld,stride,batch_count," - "result,error"; - std::cout << std::endl; - } + if(header) + { + std::cout << "type,trans,M,N,K,ld,stride,batch_count," + "result,error"; + std::cout << std::endl; + } - switch(type) + switch(type) + { + case HIP_R_16F: + std::cout << "H_"; + run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); + break; + case HIP_R_16BF: + std::cout << "BF16_"; + run(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); + break; + case HIP_R_8I: + std::cout << "I8_"; + run(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); + break; + default: + break; + } + + return EXIT_SUCCESS; + } + catch(int i) { - case HIP_R_16F: - std::cout << "H_"; - run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); - break; - case HIP_R_16BF: - std::cout << "BF16_"; - run(m, n, ld, stride, batch_count, trans, type, 
sparse_b, verbose); - break; - case HIP_R_8I: - std::cout << "I8_"; - run(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); - break; - default: - break; + return EXIT_FAILURE; } - - return EXIT_SUCCESS; } diff --git a/clients/samples/example_prune_strip.cpp b/clients/samples/example_prune_strip.cpp index 8bb06c72..8e6ea856 100644 --- a/clients/samples/example_prune_strip.cpp +++ b/clients/samples/example_prune_strip.cpp @@ -44,7 +44,7 @@ error, \ __FILE__, \ __LINE__); \ - exit(EXIT_FAILURE); \ + throw EXIT_FAILURE; \ } #endif @@ -64,10 +64,52 @@ if(error == HIPSPARSE_STATUS_ARCH_MISMATCH) \ fprintf(stderr, "HIPSPARSE_STATUS_ARCH_MISMATCH"); \ fprintf(stderr, "\n"); \ - exit(EXIT_FAILURE); \ + throw EXIT_FAILURE; \ } #endif +template +class SMART_DESTROYER +{ + typedef To (*destroy_func)(T*); + +public: + SMART_DESTROYER(T* ptr, destroy_func func) + { + _ptr = ptr; + _func = func; + } + ~SMART_DESTROYER() + { + if(_ptr != nullptr) + _func(_ptr); + } + + T* _ptr = nullptr; + destroy_func _func; +}; + +template +class SMART_DESTROYER_NON_PTR +{ + typedef To (*destroy_func)(T); + +public: + SMART_DESTROYER_NON_PTR(T* ptr, destroy_func func) + { + _ptr = ptr; + _func = func; + } + ~SMART_DESTROYER_NON_PTR() + { + if(_ptr != nullptr) + _func(*_ptr); + } + + T* _ptr = nullptr; + destroy_func _func; +}; + // default sizes #define DIM1 127 #define DIM2 128 @@ -466,9 +508,13 @@ void run(int64_t m, int num_streams = 0; hipStream_t stream = nullptr; hipStreamCreate(&stream); + SMART_DESTROYER_NON_PTR sS(&stream, hipStreamDestroy); CHECK_HIP_ERROR(hipMalloc(&d, size * sizeof(T))); + SMART_DESTROYER sD(d, hipFree); + CHECK_HIP_ERROR(hipMalloc(&d_test, size * sizeof(T))); + SMART_DESTROYER sDt(d_test, hipFree); // copy matrices from host to device CHECK_HIP_ERROR(hipMemcpy(d, hp.data(), sizeof(T) * size, hipMemcpyHostToDevice)); @@ -477,6 +523,7 @@ void run(int64_t m, hipsparseLtMatmulDescriptor_t matmul; CHECK_HIPSPARSELT_ERROR(hipsparseLtInit(&handle)); + 
SMART_DESTROYER<const hipsparseLtHandle_t, hipsparseStatus_t> sH(&handle, hipsparseLtDestroy);
 
     if(!sparse_b)
     {
@@ -490,18 +537,22 @@ void run(int64_t m,
                                                 type,
                                                 HIPSPARSE_ORDER_COL,
                                                 HIPSPARSELT_SPARSITY_50_PERCENT));
-        CHECK_HIPSPARSELT_ERROR(
-            hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
-        CHECK_HIPSPARSELT_ERROR(
-            hipsparseLtDenseDescriptorInit(&handle, &matC, m, m, m, 16, type, HIPSPARSE_ORDER_COL));
-        CHECK_HIPSPARSELT_ERROR(
-            hipsparseLtDenseDescriptorInit(&handle, &matD, m, m, m, 16, type, HIPSPARSE_ORDER_COL));
     }
     else
     {
         CHECK_HIPSPARSELT_ERROR(
             hipsparseLtDenseDescriptorInit(&handle, &matA, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
+    }
+    SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smA(
+        &matA, hipsparseLtMatDescriptorDestroy);
+    if(!sparse_b)
+    {
+        CHECK_HIPSPARSELT_ERROR(
+            hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
+    }
+    else
+    {
         CHECK_HIPSPARSELT_ERROR(
             hipsparseLtStructuredDescriptorInit(&handle,
                                                 &matB,
@@ -512,11 +563,21 @@ void run(int64_t m,
                                                 type,
                                                 HIPSPARSE_ORDER_COL,
                                                 HIPSPARSELT_SPARSITY_50_PERCENT));
-        CHECK_HIPSPARSELT_ERROR(
-            hipsparseLtDenseDescriptorInit(&handle, &matC, n, n, n, 16, type, HIPSPARSE_ORDER_COL));
-        CHECK_HIPSPARSELT_ERROR(
-            hipsparseLtDenseDescriptorInit(&handle, &matD, n, n, n, 16, type, HIPSPARSE_ORDER_COL));
     }
+    SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smB(
+        &matB, hipsparseLtMatDescriptorDestroy);
+
+    auto tmp_m = (!sparse_b) ? m : n;
+
+    CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit(
+        &handle, &matC, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL));
+    SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smC(
+        &matC, hipsparseLtMatDescriptorDestroy);
+
+    CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit(
+        &handle, &matD, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL));
+    SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smD(
+        &matD, hipsparseLtMatDescriptorDestroy);
 
     CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute(
         &handle, &matA, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count)));
@@ -583,75 +644,83 @@ void run(int64_t m,
     CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matD));
     CHECK_HIPSPARSELT_ERROR(hipsparseLtDestroy(&handle));
 }
+
 int main(int argc, char* argv[])
 {
-    // initialize parameters with default values
-    hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE;
+    try
+    {
+        // initialize parameters with default values
+        hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE;
 
-    // invalid int and float for hipsparselt spmm int and float arguments
-    int64_t invalid_int64 = std::numeric_limits<int64_t>::min() + 1;
-    int     invalid_int   = std::numeric_limits<int>::min() + 1;
-    float   invalid_float = std::numeric_limits<float>::quiet_NaN();
+        // invalid int and float for hipsparselt spmm int and float arguments
+        int64_t invalid_int64 = std::numeric_limits<int64_t>::min() + 1;
+        int     invalid_int   = std::numeric_limits<int>::min() + 1;
+        float   invalid_float = std::numeric_limits<float>::quiet_NaN();
 
-    // initialize to invalid value to detect if values not specified on command line
-    int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64;
+        // initialize to invalid value to detect if values not specified on command line
+        int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64;
 
-    int         batch_count = invalid_int;
-    hipDataType type        = HIP_R_16F;
+        int         batch_count = invalid_int;
+        hipDataType type        = HIP_R_16F;
 
-    bool verbose  = false;
-    bool header   = false;
-    bool sparse_b 
= false; + bool verbose = false; + bool header = false; + bool sparse_b = false; - if(parse_arguments( - argc, argv, m, n, ld, stride, batch_count, trans, type, header, sparse_b, verbose)) - { - show_usage(argv); - return EXIT_FAILURE; - } + if(parse_arguments( + argc, argv, m, n, ld, stride, batch_count, trans, type, header, sparse_b, verbose)) + { + show_usage(argv); + return EXIT_FAILURE; + } - // when arguments not specified, set to default values - if(m == invalid_int64) - m = DIM1; - if(n == invalid_int64) - n = DIM2; - if(ld == invalid_int64) - ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n; - if(stride == invalid_int64) - stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ld * n : ld * m; - if(batch_count == invalid_int) - batch_count = BATCH_COUNT; - - if(bad_argument(trans, m, n, ld, stride, batch_count)) - { - show_usage(argv); - return EXIT_FAILURE; - } + // when arguments not specified, set to default values + if(m == invalid_int64) + m = DIM1; + if(n == invalid_int64) + n = DIM2; + if(ld == invalid_int64) + ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n; + if(stride == invalid_int64) + stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? 
ld * n : ld * m; + if(batch_count == invalid_int) + batch_count = BATCH_COUNT; + + if(bad_argument(trans, m, n, ld, stride, batch_count)) + { + show_usage(argv); + return EXIT_FAILURE; + } - if(header) - { - std::cout << "type,trans,M,N,ld,stride,batch_count," - "result,error"; - std::cout << std::endl; - } + if(header) + { + std::cout << "type,trans,M,N,ld,stride,batch_count," + "result,error"; + std::cout << std::endl; + } - switch(type) + switch(type) + { + case HIP_R_16F: + std::cout << "H"; + run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); + break; + case HIP_R_16BF: + std::cout << "BF16"; + run(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); + break; + case HIP_R_8I: + std::cout << "I8"; + run(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); + break; + default: + break; + } + + return EXIT_SUCCESS; + } + catch(int i) { - case HIP_R_16F: - std::cout << "H"; - run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); - break; - case HIP_R_16BF: - std::cout << "BF16"; - run(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); - break; - case HIP_R_8I: - std::cout << "I8"; - run(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose); - break; - default: - break; + return EXIT_FAILURE; } - - return EXIT_SUCCESS; } diff --git a/clients/samples/example_spmm_strided_batched.cpp b/clients/samples/example_spmm_strided_batched.cpp index 7297e49d..54be9214 100644 --- a/clients/samples/example_spmm_strided_batched.cpp +++ b/clients/samples/example_spmm_strided_batched.cpp @@ -44,7 +44,7 @@ error, \ __FILE__, \ __LINE__); \ - exit(EXIT_FAILURE); \ + throw EXIT_FAILURE; \ } #endif @@ -64,10 +64,31 @@ if(error == HIPSPARSE_STATUS_ARCH_MISMATCH) \ fprintf(stderr, "HIPSPARSE_STATUS_ARCH_MISMATCH"); \ fprintf(stderr, "\n"); \ - exit(EXIT_FAILURE); \ + throw EXIT_FAILURE; \ } #endif +template +class SMART_DESTROYER +{ + typedef To (*destroy_func)(T*); + +public: + SMART_DESTROYER(T* 
ptr, destroy_func func) + { + _ptr = ptr; + _func = func; + } + ~SMART_DESTROYER() + { + if(_ptr != nullptr) + _func(_ptr); + } + + T* _ptr = nullptr; + destroy_func _func; +}; + // default sizes #define DIM1 127 #define DIM2 128 @@ -538,464 +559,490 @@ void initialize_a_b_c(std::vector<__half>& ha, int main(int argc, char* argv[]) { - // initialize parameters with default values - hipsparseOperation_t trans_a = HIPSPARSE_OPERATION_NON_TRANSPOSE; - hipsparseOperation_t trans_b = HIPSPARSE_OPERATION_TRANSPOSE; - - // invalid int and float for hipsparselt spmm int and float arguments - int64_t invalid_int64 = std::numeric_limits::min() + 1; - int invalid_int = std::numeric_limits::min() + 1; - float invalid_float = std::numeric_limits::quiet_NaN(); - - // initialize to invalid value to detect if values not specified on command line - int64_t m = invalid_int64, lda = invalid_int64, stride_a = invalid_int64; - int64_t n = invalid_int64, ldb = invalid_int64, stride_b = invalid_int64; - int64_t k = invalid_int64, ldc = invalid_int64, stride_c = invalid_int64; - int64_t ldd = invalid_int64, stride_d = invalid_int64; - - int batch_count = invalid_int; - - float alpha = invalid_float; - float beta = invalid_float; - - bool sparse_b = false; - bool verbose = false; - bool header = false; - - if(parse_arguments(argc, - argv, - m, - n, - k, - lda, - ldb, - ldc, - ldd, - stride_a, - stride_b, - stride_c, - stride_d, - batch_count, - alpha, - beta, - trans_a, - trans_b, - sparse_b, - header, - verbose)) + try { - show_usage(argv); - return EXIT_FAILURE; - } + // initialize parameters with default values + hipsparseOperation_t trans_a = HIPSPARSE_OPERATION_NON_TRANSPOSE; + hipsparseOperation_t trans_b = HIPSPARSE_OPERATION_TRANSPOSE; - // when arguments not specified, set to default values - if(m == invalid_int64) - m = DIM1; - if(n == invalid_int64) - n = DIM2; - if(k == invalid_int64) - k = DIM3; - if(lda == invalid_int64) - lda = trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE ? 
m : k; - if(ldb == invalid_int64) - ldb = trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE ? k : n; - if(ldc == invalid_int64) - ldc = m; - if(ldd == invalid_int64) - ldd = m; - if(stride_a == invalid_int64) - stride_a = trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE ? lda * k : lda * m; - if(stride_b == invalid_int64) - stride_b = trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ldb * n : ldb * k; - if(stride_c == invalid_int64) - stride_c = ldc * n; - if(stride_d == invalid_int64) - stride_d = ldd * n; - if(alpha != alpha) - alpha = ALPHA; // check for alpha == invalid_float == NaN - if(beta != beta) - beta = BETA; // check for beta == invalid_float == NaN - if(batch_count == invalid_int) - batch_count = BATCH_COUNT; - - if(bad_argument(trans_a, - trans_b, - m, - n, - k, - lda, - ldb, - ldc, - ldd, - stride_a, - stride_b, - stride_c, - stride_d, - batch_count)) - { - show_usage(argv); - return EXIT_FAILURE; - } + // invalid int and float for hipsparselt spmm int and float arguments + int64_t invalid_int64 = std::numeric_limits::min() + 1; + int invalid_int = std::numeric_limits::min() + 1; + float invalid_float = std::numeric_limits::quiet_NaN(); - if(header) - { - std::cout << "transAB,M,N,K,lda,ldb,ldc,stride_a,stride_b,stride_c,batch_count,alpha,beta," - "result,error"; - std::cout << std::endl; - } + // initialize to invalid value to detect if values not specified on command line + int64_t m = invalid_int64, lda = invalid_int64, stride_a = invalid_int64; + int64_t n = invalid_int64, ldb = invalid_int64, stride_b = invalid_int64; + int64_t k = invalid_int64, ldc = invalid_int64, stride_c = invalid_int64; + int64_t ldd = invalid_int64, stride_d = invalid_int64; - int64_t a_stride_1, a_stride_2, b_stride_1, b_stride_2; - int64_t row_a, col_a, row_b, col_b, row_c, col_c; - int size_a1, size_b1, size_c1 = ldc * n, size_d1 = ldd * n; - if(trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) - { - std::cout << "N"; - row_a = m; - col_a = k; - a_stride_1 = 1; - a_stride_2 = lda; - 
size_a1 = lda * k; - } - else - { - std::cout << "T"; - row_a = k; - col_a = m; - a_stride_1 = lda; - a_stride_2 = 1; - size_a1 = lda * m; - } - if(trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) - { - std::cout << "N, "; - row_b = k; - col_b = n; - b_stride_1 = 1; - b_stride_2 = ldb; - size_b1 = ldb * n; - } - else - { - std::cout << "T, "; - row_b = n; - col_b = k; - b_stride_1 = ldb; - b_stride_2 = 1; - size_b1 = ldb * k; - } - row_c = m; - col_c = n; - - std::cout << m << ", " << n << ", " << k << ", " << lda << ", " << ldb << ", " << ldc << ", " - << ldd << ", " << stride_a << ", " << stride_b << ", " << stride_c << ", " << stride_d - << ", " << batch_count << ", " << alpha << ", " << beta << ", "; - int64_t stride_a_r = stride_a == 0 ? size_a1 : stride_a; - int64_t stride_b_r = stride_b == 0 ? size_b1 : stride_b; - int64_t stride_c_r = stride_c == 0 ? size_c1 : stride_c; - int64_t stride_d_r = stride_d == 0 ? size_d1 : stride_d; - - int64_t size_a = stride_a_r * (stride_a == 0 ? 1 : batch_count); - int64_t size_b = stride_b_r * (stride_b == 0 ? 1 : batch_count); - int64_t size_c = stride_c_r * (stride_c == 0 ? 1 : batch_count); - int64_t size_d = stride_d_r * (stride_d == 0 ? 1 : batch_count); - // Naming: da is in GPU (device) memory. ha is in CPU (host) memory - std::vector<__half> ha(size_a); - std::vector<__half> h_prune(sparse_b ? 
size_b : size_a); - std::vector<__half> hb(size_b); - std::vector<__half> hc(size_c); - std::vector<__half> hd(size_d); - std::vector<__half> hd_gold(size_d); - - // initial data on host - initialize_a_b_c(ha, size_a, hb, size_b, hc, size_c); - - if(verbose) - { - printf("\n"); + int batch_count = invalid_int; + + float alpha = invalid_float; + float beta = invalid_float; + + bool sparse_b = false; + bool verbose = false; + bool header = false; + + if(parse_arguments(argc, + argv, + m, + n, + k, + lda, + ldb, + ldc, + ldd, + stride_a, + stride_b, + stride_c, + stride_d, + batch_count, + alpha, + beta, + trans_a, + trans_b, + sparse_b, + header, + verbose)) + { + show_usage(argv); + return EXIT_FAILURE; + } + + // when arguments not specified, set to default values + if(m == invalid_int64) + m = DIM1; + if(n == invalid_int64) + n = DIM2; + if(k == invalid_int64) + k = DIM3; + if(lda == invalid_int64) + lda = trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : k; + if(ldb == invalid_int64) + ldb = trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE ? k : n; + if(ldc == invalid_int64) + ldc = m; + if(ldd == invalid_int64) + ldd = m; + if(stride_a == invalid_int64) + stride_a = trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE ? lda * k : lda * m; + if(stride_b == invalid_int64) + stride_b = trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE ? 
ldb * n : ldb * k; + if(stride_c == invalid_int64) + stride_c = ldc * n; + if(stride_d == invalid_int64) + stride_d = ldd * n; + if(alpha != alpha) + alpha = ALPHA; // check for alpha == invalid_float == NaN + if(beta != beta) + beta = BETA; // check for beta == invalid_float == NaN + if(batch_count == invalid_int) + batch_count = BATCH_COUNT; + + if(bad_argument(trans_a, + trans_b, + m, + n, + k, + lda, + ldb, + ldc, + ldd, + stride_a, + stride_b, + stride_c, + stride_d, + batch_count)) + { + show_usage(argv); + return EXIT_FAILURE; + } + + if(header) + { + std::cout + << "transAB,M,N,K,lda,ldb,ldc,stride_a,stride_b,stride_c,batch_count,alpha,beta," + "result,error"; + std::cout << std::endl; + } + + int64_t a_stride_1, a_stride_2, b_stride_1, b_stride_2; + int64_t row_a, col_a, row_b, col_b, row_c, col_c; + int size_a1, size_b1, size_c1 = ldc * n, size_d1 = ldd * n; if(trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) { - print_strided_batched("ha initial", &ha[0], m, k, batch_count, 1, lda, stride_a_r); + std::cout << "N"; + row_a = m; + col_a = k; + a_stride_1 = 1; + a_stride_2 = lda; + size_a1 = lda * k; } else { - print_strided_batched("ha initial", &ha[0], m, k, batch_count, lda, 1, stride_a_r); + std::cout << "T"; + row_a = k; + col_a = m; + a_stride_1 = lda; + a_stride_2 = 1; + size_a1 = lda * m; } if(trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) { - print_strided_batched("hb initial", &hb[0], k, n, batch_count, 1, ldb, stride_b_r); + std::cout << "N, "; + row_b = k; + col_b = n; + b_stride_1 = 1; + b_stride_2 = ldb; + size_b1 = ldb * n; } else { - print_strided_batched("hb initial", &hb[0], k, n, batch_count, ldb, 1, stride_b_r); + std::cout << "T, "; + row_b = n; + col_b = k; + b_stride_1 = ldb; + b_stride_2 = 1; + size_b1 = ldb * k; } - print_strided_batched("hc initial", &hc[0], m, n, batch_count, 1, ldc, stride_c_r); - } + row_c = m; + col_c = n; + + std::cout << m << ", " << n << ", " << k << ", " << lda << ", " << ldb << ", " << ldc + << ", " << ldd << 
", " << stride_a << ", " << stride_b << ", " << stride_c << ", " + << stride_d << ", " << batch_count << ", " << alpha << ", " << beta << ", "; + int64_t stride_a_r = stride_a == 0 ? size_a1 : stride_a; + int64_t stride_b_r = stride_b == 0 ? size_b1 : stride_b; + int64_t stride_c_r = stride_c == 0 ? size_c1 : stride_c; + int64_t stride_d_r = stride_d == 0 ? size_d1 : stride_d; + + int64_t size_a = stride_a_r * (stride_a == 0 ? 1 : batch_count); + int64_t size_b = stride_b_r * (stride_b == 0 ? 1 : batch_count); + int64_t size_c = stride_c_r * (stride_c == 0 ? 1 : batch_count); + int64_t size_d = stride_d_r * (stride_d == 0 ? 1 : batch_count); + // Naming: da is in GPU (device) memory. ha is in CPU (host) memory + std::vector<__half> ha(size_a); + std::vector<__half> h_prune(sparse_b ? size_b : size_a); + std::vector<__half> hb(size_b); + std::vector<__half> hc(size_c); + std::vector<__half> hd(size_d); + std::vector<__half> hd_gold(size_d); + + // initial data on host + initialize_a_b_c(ha, size_a, hb, size_b, hc, size_c); + + if(verbose) + { + printf("\n"); + if(trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) + { + print_strided_batched("ha initial", &ha[0], m, k, batch_count, 1, lda, stride_a_r); + } + else + { + print_strided_batched("ha initial", &ha[0], m, k, batch_count, lda, 1, stride_a_r); + } + if(trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) + { + print_strided_batched("hb initial", &hb[0], k, n, batch_count, 1, ldb, stride_b_r); + } + else + { + print_strided_batched("hb initial", &hb[0], k, n, batch_count, ldb, 1, stride_b_r); + } + print_strided_batched("hc initial", &hc[0], m, n, batch_count, 1, ldc, stride_c_r); + } + + // allocate memory on device + __half * da, *dp, *db, *dc, *dd, *d_compressed, *d_compressBuffer; + void* d_workspace = nullptr; + int num_streams = 1; + hipStream_t stream = nullptr; + hipStream_t streams[1] = {stream}; + + CHECK_HIP_ERROR(hipMalloc(&da, size_a * sizeof(__half))); + SMART_DESTROYER sDA(da, hipFree); + 
CHECK_HIP_ERROR(hipMalloc(&dp, (sparse_b ? size_b : size_a) * sizeof(__half))); + SMART_DESTROYER sDP(dp, hipFree); + CHECK_HIP_ERROR(hipMalloc(&db, size_b * sizeof(__half))); + SMART_DESTROYER sDB(db, hipFree); + CHECK_HIP_ERROR(hipMalloc(&dc, size_c * sizeof(__half))); + SMART_DESTROYER sDC(dc, hipFree); + CHECK_HIP_ERROR(hipMalloc(&dd, size_d * sizeof(__half))); + SMART_DESTROYER sDD(dd, hipFree); + // copy matrices from host to device + CHECK_HIP_ERROR(hipMemcpy(da, ha.data(), sizeof(__half) * size_a, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(db, hb.data(), sizeof(__half) * size_b, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dc, hc.data(), sizeof(__half) * size_c, hipMemcpyHostToDevice)); + + hipsparseLtHandle_t handle; + hipsparseLtMatDescriptor_t matA, matB, matC, matD; + hipsparseLtMatmulDescriptor_t matmul; + hipsparseLtMatmulAlgSelection_t alg_sel; + hipsparseLtMatmulPlan_t plan; + + CHECK_HIPSPARSELT_ERROR(hipsparseLtInit(&handle)); + SMART_DESTROYER sH(&handle, + hipsparseLtDestroy); + + if(!sparse_b) + { + CHECK_HIPSPARSELT_ERROR( + hipsparseLtStructuredDescriptorInit(&handle, + &matA, + row_a, + col_a, + lda, + 16, + HIP_R_16F, + HIPSPARSE_ORDER_COL, + HIPSPARSELT_SPARSITY_50_PERCENT)); + } + else + { + CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( + &handle, &matA, row_a, col_a, lda, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); + } + SMART_DESTROYER smA( + &matA, hipsparseLtMatDescriptorDestroy); + + if(!sparse_b) + { + CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( + &handle, &matB, row_b, col_b, ldb, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); + } + else + { + CHECK_HIPSPARSELT_ERROR( + hipsparseLtStructuredDescriptorInit(&handle, + &matB, + row_b, + col_b, + ldb, + 16, + HIP_R_16F, + HIPSPARSE_ORDER_COL, + HIPSPARSELT_SPARSITY_50_PERCENT)); + } + SMART_DESTROYER smB( + &matB, hipsparseLtMatDescriptorDestroy); - // allocate memory on device - __half * da, *dp, *db, *dc, *dd, *d_compressed, *d_compressBuffer; - void* 
d_workspace; - int num_streams = 1; - hipStream_t stream = nullptr; - hipStream_t streams[1] = {stream}; - - CHECK_HIP_ERROR(hipMalloc(&da, size_a * sizeof(__half))); - CHECK_HIP_ERROR(hipMalloc(&dp, (sparse_b ? size_b : size_a) * sizeof(__half))); - CHECK_HIP_ERROR(hipMalloc(&db, size_b * sizeof(__half))); - CHECK_HIP_ERROR(hipMalloc(&dc, size_c * sizeof(__half))); - CHECK_HIP_ERROR(hipMalloc(&dd, size_d * sizeof(__half))); - // copy matrices from host to device - CHECK_HIP_ERROR(hipMemcpy(da, ha.data(), sizeof(__half) * size_a, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(db, hb.data(), sizeof(__half) * size_b, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dc, hc.data(), sizeof(__half) * size_c, hipMemcpyHostToDevice)); - - hipsparseLtHandle_t handle; - hipsparseLtMatDescriptor_t matA, matB, matC, matD; - hipsparseLtMatmulDescriptor_t matmul; - hipsparseLtMatmulAlgSelection_t alg_sel; - hipsparseLtMatmulPlan_t plan; - - CHECK_HIPSPARSELT_ERROR(hipsparseLtInit(&handle)); - - if(!sparse_b) - { - CHECK_HIPSPARSELT_ERROR( - hipsparseLtStructuredDescriptorInit(&handle, - &matA, - row_a, - col_a, - lda, - 16, - HIP_R_16F, - HIPSPARSE_ORDER_COL, - HIPSPARSELT_SPARSITY_50_PERCENT)); CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( - &handle, &matB, row_b, col_b, ldb, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); - } - else - { + &handle, &matC, row_c, col_c, ldc, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); + SMART_DESTROYER smC( + &matC, hipsparseLtMatDescriptorDestroy); + CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( - &handle, &matA, row_a, col_a, lda, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); - CHECK_HIPSPARSELT_ERROR( - hipsparseLtStructuredDescriptorInit(&handle, - &matB, - row_b, - col_b, - ldb, - 16, - HIP_R_16F, - HIPSPARSE_ORDER_COL, - HIPSPARSELT_SPARSITY_50_PERCENT)); - } + &handle, &matD, row_c, col_c, ldd, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); + SMART_DESTROYER smD( + &matD, hipsparseLtMatDescriptorDestroy); - 
CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( - &handle, &matC, row_c, col_c, ldc, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit( - &handle, &matD, row_c, col_c, ldd, 16, HIP_R_16F, HIPSPARSE_ORDER_COL)); - - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matA, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matA, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_a, sizeof(stride_a))); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matB, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matB, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_b, sizeof(stride_b))); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matC, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matC, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_c, sizeof(stride_c))); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matD, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( - &handle, &matD, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_d, sizeof(stride_d))); - - auto compute_type = + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matA, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matA, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_a, sizeof(stride_a))); + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matB, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matB, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_b, sizeof(stride_b))); + 
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matC, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matC, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_c, sizeof(stride_c))); + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matD, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count))); + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute( + &handle, &matD, HIPSPARSELT_MAT_BATCH_STRIDE, &stride_d, sizeof(stride_d))); + + auto compute_type = #ifdef __HIP_PLATFORM_AMD__ - HIPSPARSELT_COMPUTE_32F; + HIPSPARSELT_COMPUTE_32F; #else - HIPSPARSELT_COMPUTE_16F; + HIPSPARSELT_COMPUTE_16F; #endif - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulDescriptorInit( - &handle, &matmul, trans_a, trans_b, &matA, &matB, &matC, &matD, compute_type)); - - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulAlgSelectionInit( - &handle, &alg_sel, &matmul, HIPSPARSELT_MATMUL_ALG_DEFAULT)); - - CHECK_HIPSPARSELT_ERROR(hipsparseLtSpMMAPrune( - &handle, &matmul, sparse_b ? db : da, dp, HIPSPARSELT_PRUNE_SPMMA_STRIP, stream)); - - size_t workspace_size, compressed_size, compress_buffer_size; - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel)); - - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size)); - - CHECK_HIPSPARSELT_ERROR( - hipsparseLtSpMMACompressedSize(&handle, &plan, &compressed_size, &compress_buffer_size)); - - CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size)); - CHECK_HIP_ERROR(hipMalloc(&d_compressBuffer, compress_buffer_size)); - - CHECK_HIPSPARSELT_ERROR( - hipsparseLtSpMMACompress(&handle, &plan, dp, d_compressed, d_compressBuffer, stream)); - if(workspace_size > 0) - CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); - - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmul(&handle, - &plan, - &alpha, - sparse_b ? da : d_compressed, - sparse_b ? 
d_compressed : db, - &beta, - dc, - dd, - d_workspace, - &streams[0], - num_streams)); - hipStreamSynchronize(stream); - // copy output from device to CPU - CHECK_HIP_ERROR(hipMemcpy(hd.data(), dd, sizeof(__half) * size_c, hipMemcpyDeviceToHost)); - CHECK_HIP_ERROR(hipMemcpy( - h_prune.data(), dp, sizeof(__half) * (sparse_b ? size_b : size_a), hipMemcpyDeviceToHost)); - // calculate golden or correct result - for(int i = 0; i < batch_count; i++) - { - __half* a_ptr = sparse_b ? &ha[i * stride_a] : &h_prune[i * stride_a]; - __half* b_ptr = sparse_b ? &h_prune[i * stride_b] : &hb[i * stride_b]; - __half* c_ptr = &hc[i * stride_c]; - __half* d_ptr = &hd_gold[i * stride_d]; - mat_mat_mult<__half, __half, float>(alpha, - beta, - m, - n, - k, - a_ptr, - a_stride_1, - a_stride_2, - b_ptr, - b_stride_1, - b_stride_2, - c_ptr, - 1, - ldc, - d_ptr, - 1, - ldd); - } - if(verbose) - { - std::vector<__half> h_compressed(compressed_size); - CHECK_HIP_ERROR( - hipMemcpy(&h_compressed[0], d_compressed, compressed_size, hipMemcpyDeviceToHost)); + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulDescriptorInit( + &handle, &matmul, trans_a, trans_b, &matA, &matB, &matC, &matD, compute_type)); - auto batch_count_c = ((sparse_b ? stride_b : stride_a) == 0) ? 1 : batch_count; + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulAlgSelectionInit( + &handle, &alg_sel, &matmul, HIPSPARSELT_MATMUL_ALG_DEFAULT)); - int64_t c_stride_1, c_stride_2, c_stride_b, c_stride_b_r; - int64_t m_stride_1, m_stride_2, m_stride_b, m_stride_b_r; - if(!sparse_b) + CHECK_HIPSPARSELT_ERROR(hipsparseLtSpMMAPrune( + &handle, &matmul, sparse_b ? 
db : da, dp, HIPSPARSELT_PRUNE_SPMMA_STRIP, stream)); + + size_t workspace_size, compressed_size, compress_buffer_size; + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel)); + SMART_DESTROYER sP( + &plan, hipsparseLtMatmulPlanDestroy); + + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size)); + + CHECK_HIPSPARSELT_ERROR(hipsparseLtSpMMACompressedSize( + &handle, &plan, &compressed_size, &compress_buffer_size)); + + CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size)); + SMART_DESTROYER sd_compressed(d_compressed, hipFree); + + CHECK_HIP_ERROR(hipMalloc(&d_compressBuffer, compress_buffer_size)); + SMART_DESTROYER sd_compressBuffer(d_compressBuffer, hipFree); + + CHECK_HIPSPARSELT_ERROR( + hipsparseLtSpMMACompress(&handle, &plan, dp, d_compressed, d_compressBuffer, stream)); + if(workspace_size > 0) { - c_stride_1 = (trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? 1 : k / 2; - c_stride_2 = (trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? m : 1; - c_stride_b_r = (trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? k / 2 * c_stride_2 - : m * c_stride_1; - - m_stride_1 = k / 8; - m_stride_2 = 1; - m_stride_b_r = m * m_stride_1; + CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); } - else + SMART_DESTROYER sd_workspace(d_workspace, hipFree); + + CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmul(&handle, + &plan, + &alpha, + sparse_b ? da : d_compressed, + sparse_b ? d_compressed : db, + &beta, + dc, + dd, + d_workspace, + &streams[0], + num_streams)); + hipStreamSynchronize(stream); + // copy output from device to CPU + CHECK_HIP_ERROR(hipMemcpy(hd.data(), dd, sizeof(__half) * size_c, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(h_prune.data(), + dp, + sizeof(__half) * (sparse_b ? size_b : size_a), + hipMemcpyDeviceToHost)); + // calculate golden or correct result + for(int i = 0; i < batch_count; i++) { - c_stride_1 = (trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? 
1 : n; - c_stride_2 = (trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? k / 2 : 1; - c_stride_b_r = (trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? n * c_stride_2 - : k / 2 * c_stride_1; - - m_stride_1 = 1; - m_stride_2 = k / 8; - m_stride_b_r = n * m_stride_2; + __half* a_ptr = sparse_b ? &ha[i * stride_a] : &h_prune[i * stride_a]; + __half* b_ptr = sparse_b ? &h_prune[i * stride_b] : &hb[i * stride_b]; + __half* c_ptr = &hc[i * stride_c]; + __half* d_ptr = &hd_gold[i * stride_d]; + mat_mat_mult<__half, __half, float>(alpha, + beta, + m, + n, + k, + a_ptr, + a_stride_1, + a_stride_2, + b_ptr, + b_stride_1, + b_stride_2, + c_ptr, + 1, + ldc, + d_ptr, + 1, + ldd); } + if(verbose) + { + std::vector<__half> h_compressed(compressed_size); + CHECK_HIP_ERROR( + hipMemcpy(&h_compressed[0], d_compressed, compressed_size, hipMemcpyDeviceToHost)); - c_stride_b = (sparse_b ? stride_b : stride_a) == 0 ? 0 : c_stride_b_r; - m_stride_b = (sparse_b ? stride_b : stride_a) == 0 ? 0 : m_stride_b_r; + auto batch_count_c = ((sparse_b ? stride_b : stride_a) == 0) ? 1 : batch_count; - if(!sparse_b) + int64_t c_stride_1, c_stride_2, c_stride_b, c_stride_b_r; + int64_t m_stride_1, m_stride_2, m_stride_b, m_stride_b_r; + if(!sparse_b) + { + c_stride_1 = (trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? 1 : k / 2; + c_stride_2 = (trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? m : 1; + c_stride_b_r = (trans_a == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? k / 2 * c_stride_2 + : m * c_stride_1; + + m_stride_1 = k / 8; + m_stride_2 = 1; + m_stride_b_r = m * m_stride_1; + } + else + { + c_stride_1 = (trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? 1 : n; + c_stride_2 = (trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? k / 2 : 1; + c_stride_b_r = (trans_b == HIPSPARSE_OPERATION_NON_TRANSPOSE) ? n * c_stride_2 + : k / 2 * c_stride_1; + + m_stride_1 = 1; + m_stride_2 = k / 8; + m_stride_b_r = n * m_stride_2; + } + + c_stride_b = (sparse_b ? stride_b : stride_a) == 0 ? 
0 : c_stride_b_r; + m_stride_b = (sparse_b ? stride_b : stride_a) == 0 ? 0 : m_stride_b_r; + + if(!sparse_b) + { + print_strided_batched("device compress calculated", + &h_compressed[0], + m, + k / 2, + batch_count_c, + c_stride_1, + c_stride_2, + c_stride_b_r); + print_strided_batched_meta( + "device metadata calculated", + reinterpret_cast(&h_compressed[c_stride_b_r * batch_count_c]), + m, + k / 8, + batch_count_c, + m_stride_1, + m_stride_2, + m_stride_b_r); + } + else + { + print_strided_batched("device compress calculated", + &h_compressed[0], + k / 2, + n, + batch_count_c, + c_stride_1, + c_stride_2, + c_stride_b_r); + print_strided_batched_meta( + "device metadata calculated", + reinterpret_cast(&h_compressed[c_stride_b_r * batch_count_c]), + k / 8, + n, + batch_count_c, + m_stride_1, + m_stride_2, + m_stride_b_r); + } + print_strided_batched( + "hc_gold calculated", &hd_gold[0], m, n, batch_count, 1, ldd, stride_d_r); + print_strided_batched("hd calculated", &hd[0], m, n, batch_count, 1, ldd, stride_d_r); + } + + bool passed = true; + for(int i = 0; i < size_c; i++) { - print_strided_batched("device compress calculated", - &h_compressed[0], - m, - k / 2, - batch_count_c, - c_stride_1, - c_stride_2, - c_stride_b_r); - print_strided_batched_meta( - "device metadata calculated", - reinterpret_cast(&h_compressed[c_stride_b_r * batch_count_c]), - m, - k / 8, - batch_count_c, - m_stride_1, - m_stride_2, - m_stride_b_r); + if(!AlmostEqual(hd_gold[i], hd[i])) + { + printf( + "Err: %f vs %f\n", static_cast(hd_gold[i]), static_cast(hd[i])); + passed = false; + } } - else + if(!passed) { - print_strided_batched("device compress calculated", - &h_compressed[0], - k / 2, - n, - batch_count_c, - c_stride_1, - c_stride_2, - c_stride_b_r); - print_strided_batched_meta( - "device metadata calculated", - reinterpret_cast(&h_compressed[c_stride_b_r * batch_count_c]), - k / 8, - n, - batch_count_c, - m_stride_1, - m_stride_2, - m_stride_b_r); + std::cout << "FAIL" << 
std::endl; } - print_strided_batched( - "hc_gold calculated", &hd_gold[0], m, n, batch_count, 1, ldd, stride_d_r); - print_strided_batched("hd calculated", &hd[0], m, n, batch_count, 1, ldd, stride_d_r); - } - - bool passed = true; - for(int i = 0; i < size_c; i++) - { - if(!AlmostEqual(hd_gold[i], hd[i])) + else { - printf("Err: %f vs %f\n", static_cast(hd_gold[i]), static_cast(hd[i])); - passed = false; + std::cout << "PASS" << std::endl; } + + return EXIT_SUCCESS; } - if(!passed) - { - std::cout << "FAIL" << std::endl; - } - else + catch(int) { - std::cout << "PASS" << std::endl; + return EXIT_FAILURE; } - - CHECK_HIP_ERROR(hipFree(da)); - CHECK_HIP_ERROR(hipFree(dp)); - CHECK_HIP_ERROR(hipFree(db)); - CHECK_HIP_ERROR(hipFree(dc)); - CHECK_HIP_ERROR(hipFree(dd)); - CHECK_HIP_ERROR(hipFree(d_compressed)); - CHECK_HIP_ERROR(hipFree(d_compressBuffer)); - if(workspace_size > 0) - CHECK_HIP_ERROR(hipFree(d_workspace)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanDestroy(&plan)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matA)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matB)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matC)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matD)); - CHECK_HIPSPARSELT_ERROR(hipsparseLtDestroy(&handle)); - - return EXIT_SUCCESS; }