Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leakage #153

Merged
merged 2 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 145 additions & 83 deletions clients/samples/example_compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
error, \
__FILE__, \
__LINE__); \
exit(EXIT_FAILURE); \
throw EXIT_FAILURE; \
}
#endif

Expand All @@ -65,10 +65,52 @@
if(error == HIPSPARSE_STATUS_ARCH_MISMATCH) \
fprintf(stderr, "HIPSPARSE_STATUS_ARCH_MISMATCH"); \
fprintf(stderr, "\n"); \
exit(EXIT_FAILURE); \
throw EXIT_FAILURE; \
}
#endif

/// Scope guard that calls a C-style destroy function on a pointer when the
/// guard object is destroyed.
///
/// @tparam T   Pointee type handed to the destroy function.
/// @tparam To  Return type of the destroy function; the returned value is discarded.
template <typename T, typename To>
class SMART_DESTROYER
{
    using destroy_func = To (*)(T*);

public:
    /// @param ptr   Pointer passed to `func` on destruction; nullptr disables the guard.
    /// @param func  Destroy function invoked as `func(ptr)`; nullptr disables the guard.
    SMART_DESTROYER(T* ptr, destroy_func func)
        : _ptr(ptr)
        , _func(func)
    {
    }

    // A copy would invoke the destroy function a second time on the same
    // pointer, so the guard is made non-copyable (this also suppresses the
    // implicit move operations).
    SMART_DESTROYER(const SMART_DESTROYER&) = delete;
    SMART_DESTROYER& operator=(const SMART_DESTROYER&) = delete;

    ~SMART_DESTROYER()
    {
        // Skip the call when there is nothing to release or no function to call.
        if(_ptr != nullptr && _func != nullptr)
            _func(_ptr);
    }

    T*           _ptr  = nullptr;
    destroy_func _func = nullptr;
};

/// Scope guard for handles whose destroy function takes the handle by value.
/// The guard stores a pointer to the handle and, when destroyed, dereferences
/// it and passes the value to the destroy function.
///
/// @tparam T   Handle type passed by value to the destroy function.
/// @tparam To  Return type of the destroy function; the returned value is discarded.
template <typename T, typename To>
class SMART_DESTROYER_NON_PTR
{
    using destroy_func = To (*)(T);

public:
    /// @param ptr   Address of the handle; it is dereferenced when the guard is
    ///              destroyed, so the handle must still be alive at that point.
    ///              nullptr disables the guard.
    /// @param func  Destroy function invoked as `func(*ptr)`; nullptr disables the guard.
    SMART_DESTROYER_NON_PTR(T* ptr, destroy_func func)
        : _ptr(ptr)
        , _func(func)
    {
    }

    // A copy would destroy the same handle a second time, so the guard is made
    // non-copyable (this also suppresses the implicit move operations).
    SMART_DESTROYER_NON_PTR(const SMART_DESTROYER_NON_PTR&) = delete;
    SMART_DESTROYER_NON_PTR& operator=(const SMART_DESTROYER_NON_PTR&) = delete;

    ~SMART_DESTROYER_NON_PTR()
    {
        // Skip the call when there is nothing to release or no function to call.
        if(_ptr != nullptr && _func != nullptr)
            _func(*_ptr);
    }

    T*           _ptr  = nullptr;
    destroy_func _func = nullptr;
};

inline unsigned char generate_metadata(int a, int b, int c, int d)
{
unsigned char metadata = (a)&0x03;
Expand Down Expand Up @@ -709,9 +751,13 @@ void run(int64_t m,
int num_streams = 0;
hipStream_t stream = nullptr;
hipStreamCreate(&stream);
SMART_DESTROYER_NON_PTR<hipStream_t, hipError_t> sS(&stream, hipStreamDestroy);

CHECK_HIP_ERROR(hipMalloc(&d, size * sizeof(T)));
SMART_DESTROYER<void, hipError_t> sD(d, hipFree);

CHECK_HIP_ERROR(hipMalloc(&d_test, size * sizeof(T)));
SMART_DESTROYER<void, hipError_t> sDt(d_test, hipFree);
// copy matrices from host to device
CHECK_HIP_ERROR(hipMemcpy(d, hp.data(), sizeof(T) * size, hipMemcpyHostToDevice));

Expand All @@ -722,6 +768,7 @@ void run(int64_t m,
hipsparseLtMatmulPlan_t plan;

CHECK_HIPSPARSELT_ERROR(hipsparseLtInit(&handle));
SMART_DESTROYER<const hipsparseLtHandle_t, hipsparseStatus_t> sH(&handle, hipsparseLtDestroy);

if(!sparse_b)
{
Expand All @@ -735,18 +782,22 @@ void run(int64_t m,
type,
HIPSPARSE_ORDER_COL,
HIPSPARSELT_SPARSITY_50_PERCENT));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matC, m, m, m, 16, type, HIPSPARSE_ORDER_COL));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matD, m, m, m, 16, type, HIPSPARSE_ORDER_COL));
}
else
{
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matA, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
}
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smA(
&matA, hipsparseLtMatDescriptorDestroy);

if(!sparse_b)
{
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matB, n, m, n, 16, type, HIPSPARSE_ORDER_COL));
}
else
{
CHECK_HIPSPARSELT_ERROR(
hipsparseLtStructuredDescriptorInit(&handle,
&matB,
Expand All @@ -757,11 +808,21 @@ void run(int64_t m,
type,
HIPSPARSE_ORDER_COL,
HIPSPARSELT_SPARSITY_50_PERCENT));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matC, n, n, n, 16, type, HIPSPARSE_ORDER_COL));
CHECK_HIPSPARSELT_ERROR(
hipsparseLtDenseDescriptorInit(&handle, &matD, n, n, n, 16, type, HIPSPARSE_ORDER_COL));
}
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smB(
&matB, hipsparseLtMatDescriptorDestroy);

auto tmp_m = (!sparse_b) ? m : n;

CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit(
&handle, &matC, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL));
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smC(
&matC, hipsparseLtMatDescriptorDestroy);

CHECK_HIPSPARSELT_ERROR(hipsparseLtDenseDescriptorInit(
&handle, &matD, tmp_m, tmp_m, tmp_m, 16, type, HIPSPARSE_ORDER_COL));
SMART_DESTROYER<const hipsparseLtMatDescriptor_t, hipsparseStatus_t> smD(
&matD, hipsparseLtMatDescriptorDestroy);

CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescSetAttribute(
&handle, &matA, HIPSPARSELT_MAT_NUM_BATCHES, &batch_count, sizeof(batch_count)));
Expand Down Expand Up @@ -811,6 +872,8 @@ void run(int64_t m,

size_t workspace_size, compressed_size, compress_buffer_size;
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel));
SMART_DESTROYER<const hipsparseLtMatmulPlan_t, hipsparseStatus_t> sP(
&plan, hipsparseLtMatmulPlanDestroy);

CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size));

Expand All @@ -824,7 +887,11 @@ void run(int64_t m,
std::vector<T> hp_compressed(compressed_size / sizeof(T));
std::vector<T> hp_compressBuffer(compress_buffer_size / sizeof(T));
CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size));
SMART_DESTROYER<void, hipError_t> sDc(d_compressed, hipFree);

CHECK_HIP_ERROR(hipMalloc(&d_compressBuffer, compress_buffer_size));
SMART_DESTROYER<void, hipError_t> sDcB(d_compressBuffer, hipFree);

CHECK_HIPSPARSELT_ERROR(
hipsparseLtSpMMACompress(&handle, &plan, d_test, d_compressed, d_compressBuffer, stream));
hipStreamSynchronize(stream);
Expand Down Expand Up @@ -1028,89 +1095,84 @@ void run(int64_t m,
//validate(&hp_test[0], &hp_compressed[0], reinterpret_cast<unsigned char*>(&hp_compressed[c_stride_b * batch_count]), m, n, batch_count, stride_1, stride_2, stride, m, n/2, batch_count, c_stride_1, c_stride_2, c_stride_b, m, n/8, batch_count, m_stride_1, m_stride_2, m_stride_b);

//CHECK_HIP_ERROR(hipMalloc(&d_compressed, compressed_size));

CHECK_HIP_ERROR(hipFree(d));
CHECK_HIP_ERROR(hipFree(d_test));
CHECK_HIP_ERROR(hipFree(d_compressed));
CHECK_HIP_ERROR(hipFree(d_compressBuffer));

CHECK_HIPSPARSELT_ERROR(hipsparseLtMatmulPlanDestroy(&plan));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matA));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matB));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matC));
CHECK_HIPSPARSELT_ERROR(hipsparseLtMatDescriptorDestroy(&matD));
CHECK_HIPSPARSELT_ERROR(hipsparseLtDestroy(&handle));
}

int main(int argc, char* argv[])
{
// initialize parameters with default values
hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE;
try
{
// initialize parameters with default values
hipsparseOperation_t trans = HIPSPARSE_OPERATION_NON_TRANSPOSE;

// invalid int and float for hipsparselt spmm int and float arguments
int64_t invalid_int64 = std::numeric_limits<int64_t>::min() + 1;
int invalid_int = std::numeric_limits<int>::min() + 1;
float invalid_float = std::numeric_limits<float>::quiet_NaN();
// invalid int and float for hipsparselt spmm int and float arguments
int64_t invalid_int64 = std::numeric_limits<int64_t>::min() + 1;
int invalid_int = std::numeric_limits<int>::min() + 1;
float invalid_float = std::numeric_limits<float>::quiet_NaN();

// initialize to invalid value to detect if values not specified on command line
int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64;
// initialize to invalid value to detect if values not specified on command line
int64_t m = invalid_int64, n = invalid_int64, ld = invalid_int64, stride = invalid_int64;

int batch_count = invalid_int;
hipDataType type = HIP_R_16F;
int batch_count = invalid_int;
hipDataType type = HIP_R_16F;

bool verbose = false;
bool header = false;
bool sparse_b = false;
bool verbose = false;
bool header = false;
bool sparse_b = false;

if(parse_arguments(
argc, argv, m, n, ld, stride, batch_count, trans, type, header, sparse_b, verbose))
{
show_usage(argv);
return EXIT_FAILURE;
}
if(parse_arguments(
argc, argv, m, n, ld, stride, batch_count, trans, type, header, sparse_b, verbose))
{
show_usage(argv);
return EXIT_FAILURE;
}

// when arguments not specified, set to default values
if(m == invalid_int64)
m = DIM1;
if(n == invalid_int64)
n = DIM2;
if(ld == invalid_int64)
ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n;
if(stride == invalid_int64)
stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ld * n : ld * m;
if(batch_count == invalid_int)
batch_count = BATCH_COUNT;

if(bad_argument(trans, m, n, ld, stride, batch_count))
{
show_usage(argv);
return EXIT_FAILURE;
}
// when arguments not specified, set to default values
if(m == invalid_int64)
m = DIM1;
if(n == invalid_int64)
n = DIM2;
if(ld == invalid_int64)
ld = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? m : n;
if(stride == invalid_int64)
stride = trans == HIPSPARSE_OPERATION_NON_TRANSPOSE ? ld * n : ld * m;
if(batch_count == invalid_int)
batch_count = BATCH_COUNT;

if(bad_argument(trans, m, n, ld, stride, batch_count))
{
show_usage(argv);
return EXIT_FAILURE;
}

if(header)
{
std::cout << "type,trans,M,N,K,ld,stride,batch_count,"
"result,error";
std::cout << std::endl;
}
if(header)
{
std::cout << "type,trans,M,N,K,ld,stride,batch_count,"
"result,error";
std::cout << std::endl;
}

switch(type)
switch(type)
{
case HIP_R_16F:
std::cout << "H_";
run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_16BF:
std::cout << "BF16_";
run<hip_bfloat16>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_8I:
std::cout << "I8_";
run<int8_t>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
default:
break;
}

return EXIT_SUCCESS;
}
catch(int i)
{
case HIP_R_16F:
std::cout << "H_";
run<__half>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_16BF:
std::cout << "BF16_";
run<hip_bfloat16>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
case HIP_R_8I:
std::cout << "I8_";
run<int8_t>(m, n, ld, stride, batch_count, trans, type, sparse_b, verbose);
break;
default:
break;
return EXIT_FAILURE;
}

return EXIT_SUCCESS;
}
Loading
Loading