Skip to content

Commit

Permalink
Auto free allocated memory and objects in the sample code.
Browse files Browse the repository at this point in the history
  • Loading branch information
vin-huang committed Oct 1, 2024
1 parent ebe1d14 commit 56b4c00
Show file tree
Hide file tree
Showing 3 changed files with 753 additions and 575 deletions.
228 changes: 145 additions & 83 deletions clients/samples/example_compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
error, \
__FILE__, \
__LINE__); \
exit(EXIT_FAILURE); \
throw EXIT_FAILURE; \
}
#endif

Expand All @@ -65,10 +65,52 @@
if(error == HIPSPARSE_STATUS_ARCH_MISMATCH) \
fprintf(stderr, "HIPSPARSE_STATUS_ARCH_MISMATCH"); \
fprintf(stderr, "\n"); \
exit(EXIT_FAILURE); \
throw EXIT_FAILURE; \
}
#endif

template <typename T, typename To>
class SMART_DESTROYER
{
typedef To (*destroy_func)(T*);

public:
SMART_DESTROYER(T* ptr, destroy_func func)
{
_ptr = ptr;
_func = func;
}
~SMART_DESTROYER()
{
if(_ptr != nullptr)
_func(_ptr);
}

T* _ptr = nullptr;
destroy_func _func;
};

template <typename T, typename To>
class SMART_DESTROYER_NON_PTR
{
typedef To (*destroy_func)(T);

public:
SMART_DESTROYER_NON_PTR(T* ptr, destroy_func func)
{
_ptr = ptr;
_func = func;
}
~SMART_DESTROYER_NON_PTR()
{
if(_ptr != nullptr)
_func(*_ptr);
}

T* _ptr = nullptr;
destroy_func _func;
};

inline unsigned char generate_metadata(int a, int b, int c, int d)
{
unsigned char metadata = (a)&0x03;
Expand Down Expand Up @@ -709,9 +751,13 @@ void run(int64_t m,
int num_streams = 0;
hipStream_t stream = nullptr;
hipStreamCreate(&stream);
SMART_DESTROYER_NON_PTR<hipStream_t, hipError_t> sS(&stream, hipStreamDestroy);

CHECK_HIP_ERROR(hipMalloc(&d, size * sizeof(T)));
SMART_DESTROYER<void, hipError_t> sD(d, hipFree);

CHECK_HIP_ERROR(hipMalloc(&d_test, size * sizeof(T)));
SMART_DESTROYER<void, hipError_t> sDt(d_test, hipFree);
// copy matrices from host to device
CHECK_HIP_ERROR(hipMemcpy(d, hp.data(), sizeof(T) * size, hipMemcpyHostToDevice));

Expand All @@ -722,6 +768,7 @@ void run(int64_t m,
hipsparseLtMatmulPlan_t plan;

CHECK_HIPSPARSELT_ERROR(hipsparseLtInit(&handle));
SMART_DESTROYER<const hipsparseLtHandle_t, hipsparseStatus_t> sH(&handle, hipsparseLtDestroy);

if(!sparse_b)
{
Expand All @@ -735,18 +782,22 @@ void run(int64_t m,
type,
HIPSPARSE_ORDER_COL,
HIPSPARSELT_SPARSITY_50_PERCENT));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matC, m, m, m, 16, type, HIPSPARSE_ORDER_COL));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matD, m, m, m, 16, type, HIPSPARSE_ORDER_COL));
}
else
{
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matA, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
}
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smA(
&matA, hipsparseLtMatDescriptorDestroy);

if(!sparse_b)
{
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
}
else
{
CHECK_HIPSPARSELT_ERROR(
hipsparseLtStructuredDescriptorInit(&handle,
&matB,
Expand All @@ -757,11 +808,21 @@ void run(int64_t m,
type,
HIPSPARSE_ORDER_COL,
HIPSPARSELT_SPARSITY_50_PERCENT));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matC, n, n, n, 16, type, HIPSPARSE_ORDER_COL));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matD, n, n, n, 16, type, HIPSPARSE_ORDER_COL));
}
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smB(
&matB, hipsparseLtMatDescriptorDestroy);

auto tmp_m = (!sparse_b) ? m : n;

CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit(
&handle, &matC, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL));
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smC(
&matC, hipsparseLtMatDescriptorDestroy);

CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit(
&handle, &matD, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL));
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smD(
&matD, hipsparseLtMatDescriptorDestroy);

CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute(
&handle, &matA, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count)));
Expand Down Expand Up @@ -811,6 +872,8 @@ void run(int64_t m,

size_t workspace_size, compressed_size, compress_buffer_size;
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel));
SMART_DESTROYER<const hipsparseLtMatmulPlan_t, hipsparseStatus_t> sP(
&plan, hipsparseLtMatmulPlanDestroy);

CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size));

Expand All @@ -824,7 +887,11 @@ void run(int64_t m,
std::vector<T> hp_compressed(compressed_size / sizeof(T));
std::vector<T> hp_compressBuffer(compress_buffer_size / sizeof(T));
CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size));
SMART_DESTROYER<void, hipError_t> sDc(d_compressed, hipFree);

CHECK_HIP_ERROR(hipMalloc(&d_compressBuffer, compress_buffer_size));
SMART_DESTROYER<void, hipError_t> sDcB(d_compressBuffer, hipFree);

CHECK_HIPSPARSELT_ERROR(
hipsparseLtSpMMACompress(&handle, &plan, d_test, d_compressed, d_compressBuffer, stream));
hipStreamSynchronize(stream);
Expand Down Expand Up @@ -1028,89 +1095,84 @@ void run(int64_t m,
//validate(&hp_test[0], &hp_compressed[0], reinterpret_cast<unsigned char*>(&hp_compressed[c_stride_b * batch_count]), m, n, batch_count, stride_1, stride_2, stride, m, n/2, batch_count, c_stride_1, c_stride_2, c_stride_b, m, n/8, batch_count, m_stride_1, m_stride_2, m_stride_b);

//CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size));

CHECK_HIP_ERROR(hipFree(d));
CHECK_HIP_ERROR(hipFree(d_test));
CHECK_HIP_ERROR(hipFree(d_compressed));
CHECK_HIP_ERROR(hipFree(d_compressBuffer));

CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanDestroy(&plan));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matA));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matB));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matC));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matD));
CHECK_HIPSPARSELT_ERROR(hipsparseLtDestroy(&handle));
}

int main(int argc, char* argv[])
{
// initialize parameters with default values
hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE;
try
{
// initialize parameters with default values
hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE;

// invalid int and float for hipsparselt spmm int and float arguments
int64_t invalid_int64 = std::numeric_limits<int64_t>::min() + 1;
int invalid_int = std::numeric_limits<int>::min() + 1;
float invalid_float = std::numeric_limits<float>::quiet_NaN();
// invalid int and float for hipsparselt spmm int and float arguments
int64_t invalid_int64 = std::numeric_limits<int64_t>::min() + 1;
int invalid_int = std::numeric_limits<int>::min() + 1;
float invalid_float = std::numeric_limits<float>::quiet_NaN();

// initialize to invalid value to detect if values not specified on command line
int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64;
// initialize to invalid value to detect if values not specified on command line
int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64;

int batch_count = invalid_int;
hipDataType type = HIP_R_16F;
int batch_count = invalid_int;
hipDataType type = HIP_R_16F;

bool verbose = false;
bool header = false;
bool sparse_b = false;
bool verbose = false;
bool header = false;
bool sparse_b = false;

if(parse_arguments(
argc, argv, m, n, ld, stride, batch_count, trans, type, header, sparse_b, verbose))
{
show_usage(argv);
return EXIT_FAILURE;
}
if(parse_arguments(
argc, argv, m, n, ld, stride, batch_count, trans, type, header, sparse_b, verbose))
{
show_usage(argv);
return EXIT_FAILURE;
}

// when arguments not specified, set to default values
if(m == invalid_int64)
m = DIM1;
if(n == invalid_int64)
n = DIM2;
if(ld == invalid_int64)
ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n;
if(stride == invalid_int64)
stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ld * n : ld * m;
if(batch_count == invalid_int)
batch_count = BATCH_COUNT;

if(bad_argument(trans, m, n, ld, stride, batch_count))
{
show_usage(argv);
return EXIT_FAILURE;
}
// when arguments not specified, set to default values
if(m == invalid_int64)
m = DIM1;
if(n == invalid_int64)
n = DIM2;
if(ld == invalid_int64)
ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n;
if(stride == invalid_int64)
stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ld * n : ld * m;
if(batch_count == invalid_int)
batch_count = BATCH_COUNT;

if(bad_argument(trans, m, n, ld, stride, batch_count))
{
show_usage(argv);
return EXIT_FAILURE;
}

if(header)
{
std::cout << "type,trans,M,N,K,ld,stride,batch_count,"
"result,error";
std::cout << std::endl;
}
if(header)
{
std::cout << "type,trans,M,N,K,ld,stride,batch_count,"
"result,error";
std::cout << std::endl;
}

switch(type)
switch(type)
{
case HIP_R_16F:
std::cout << "H_";
run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_16BF:
std::cout << "BF16_";
run<hip_bfloat16>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_8I:
std::cout << "I8_";
run<int8_t>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
default:
break;
}

return EXIT_SUCCESS;
}
catch(int i)
{
case HIP_R_16F:
std::cout << "H_";
run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_16BF:
std::cout << "BF16_";
run<hip_bfloat16>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_8I:
std::cout << "I8_";
run<int8_t>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
default:
break;
return EXIT_FAILURE;
}

return EXIT_SUCCESS;
}
Loading

0 comments on commit 56b4c00

Please sign in to comment.