Fix the memory overflow bug of the tune_cublaslt_gemm operator (#9076)
* fix bug

* add cudacheck
Hanyonggong authored Sep 4, 2024
1 parent 9939f84 commit fbbc0a2
Showing 1 changed file with 33 additions and 29 deletions.
62 changes: 33 additions & 29 deletions csrc/gpu/tune_cublaslt_gemm.cu
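The "add cudacheck" bullet in the commit message refers to wrapping the new raw CUDA runtime calls (cudaMalloc, cudaMemcpy, cudaFree) in the CUDA_CHECK macro provided by csrc/gpu/helper.h, so allocation failures surface immediately instead of silently corrupting later tuning results. The macro's real definition lives in that header, and since the file also applies it to cublasLt* calls it is presumably generic over status types; as orientation only, a minimal sketch of the common cudaError_t-only pattern:

  #include <cstdio>
  #include <cstdlib>
  #include <cuda_runtime.h>

  // Minimal sketch of a CUDA_CHECK-style macro; the real one in helper.h
  // may differ (e.g. it also accepts cuBLAS status codes).
  #define CUDA_CHECK(call)                                       \
    do {                                                         \
      cudaError_t err_ = (call);                                 \
      if (err_ != cudaSuccess) {                                 \
        fprintf(stderr, "CUDA error %s at %s:%d\n",              \
                cudaGetErrorString(err_), __FILE__, __LINE__);   \
        exit(EXIT_FAILURE);                                      \
      }                                                          \
    } while (0)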
@@ -22,6 +22,7 @@ limitations under the License. */
 #include <limits>
 #include <list>
 #include <vector>
+#include <iomanip>
 
 #include "helper.h"
 
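The new <iomanip> include supplies std::put_time, which the expanded timing log near the end of this diff uses to stamp each measurement with a wall-clock time.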
@@ -123,11 +124,8 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
       ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult);
   if (algoStatus == CUBLAS_STATUS_SUCCESS) {
     ScaleT alpha = static_cast<ScaleT>(1), beta = static_cast<ScaleT>(0);
-    paddle::Tensor workspace =
-        paddle::empty({static_cast<int64_t>(heurResult.workspaceSize)},
-                      paddle::DataType::UINT8,
-                      paddle::GPUPlace());
-    void* workspace_ptr = workspace.data<uint8_t>();
+    void* workSpace;
+    CUDA_CHECK(cudaMalloc(&workSpace, heurResult.workspaceSize));
     CUDA_CHECK(cudaEventRecord(startEvent, stream));
     int repeats = 100;
     for (int loop = 0; loop < repeats; loop++) {
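This hunk is the core of the memory fix in the benchmarking helper: the tuning loop previously materialized the workspace for every candidate algorithm as a paddle::Tensor, which is served by Paddle's allocator and, presumably (given the overflow this commit fixes), accumulated across the many candidate algorithms and shapes being tuned. A plain cudaMalloc here, paired with the explicit cudaFree added two hunks below, bounds the operator to one live workspace per tested algorithm.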
@@ -144,7 +142,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
               C,
               C_desc,
               &algo,
-              workspace_ptr,
+              workSpace,
               heurResult.workspaceSize,
               stream);
       if (currStatus != CUBLAS_STATUS_SUCCESS) {
@@ -166,6 +164,7 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
       perfResults.workspaceSize = heurResult.workspaceSize;
       perfResults.wavesCount = heurResult.wavesCount;
     }
+    CUDA_CHECK(cudaFree(workSpace));
   } else {
     std::cerr << "not enough workspace! current workspace is "
               << heurResult.workspaceSize;
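One caveat of manual cudaMalloc/cudaFree pairs is that any early return between them leaks the buffer. A small RAII wrapper would make the pairing automatic; a minimal sketch, not part of this commit (the type name is illustrative):

  // Illustrative RAII wrapper for a device buffer; frees on scope exit.
  // cudaFree(nullptr) is a safe no-op, so a moved-from/empty state is fine.
  struct DeviceWorkspace {
    void* ptr = nullptr;
    explicit DeviceWorkspace(size_t bytes) {
      CUDA_CHECK(cudaMalloc(&ptr, bytes));
    }
    ~DeviceWorkspace() { cudaFree(ptr); }  // no CUDA_CHECK: don't exit in a destructor
    DeviceWorkspace(const DeviceWorkspace&) = delete;
    DeviceWorkspace& operator=(const DeviceWorkspace&) = delete;
  };

TestMatmulRun could then declare DeviceWorkspace ws(heurResult.workspaceSize); once and pass ws.ptr to cublasLtMatmul.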
@@ -468,7 +467,7 @@ class DevContext {};
 class CPUContext : public DevContext {};
 
 class CUBLASLTContext : public DevContext {
-public:
+ public:
   CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }
 
   cublasLtHandle_t handle;
@@ -513,22 +512,20 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
                                                 bool is_test,
                                                 bool is_read_from_file,
                                                 const std::string& path) {
-  paddle::Tensor A = paddle::empty({static_cast<int64_t>(AVec.size())},
-                                   paddle::DataType::INT8,
-                                   paddle::GPUPlace());
-  paddle::Tensor B = paddle::empty({static_cast<int64_t>(BVec.size())},
-                                   paddle::DataType::INT8,
-                                   paddle::GPUPlace());
-  paddle::Tensor C = paddle::empty({static_cast<int64_t>(CVec.size())},
-                                   paddle::DataType::INT32,
-                                   paddle::GPUPlace());
-
-  int8_t* A_dev = A.data<int8_t>();
-  int8_t* B_dev = B.data<int8_t>();
-  int32_t* C_dev = C.data<int32_t>();
+  int8_t* A_dev;
+  int8_t* B_dev;
+  int32_t* C_dev;
+  char* workSpace;
+
+  CUDA_CHECK(cudaMalloc((void**)&A_dev, AVec.size() * sizeof(int8_t)));
+  CUDA_CHECK(cudaMalloc((void**)&B_dev, BVec.size() * sizeof(int8_t)));
+  CUDA_CHECK(cudaMalloc((void**)&C_dev, m * n * sizeof(int32_t)));
+  CUDA_CHECK(
+      cudaMemcpy(A_dev, AVec.data(), AVec.size(), cudaMemcpyHostToDevice));
+  CUDA_CHECK(
+      cudaMemcpy(B_dev, BVec.data(), BVec.size(), cudaMemcpyHostToDevice));
 
   // init data structure
 
   cublasLtMatmulDesc_t matmul_desc;
   cublasLtMatrixLayout_t A_desc;
   cublasLtMatrixLayout_t B_desc;
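Two details in the new allocation code are worth calling out. The cudaMemcpy calls pass AVec.size() and BVec.size() directly as byte counts, which is correct only because the element type is int8_t (one byte per element); a wider type would need size() * sizeof(T). And C_dev is now sized as m * n * sizeof(int32_t) rather than from CVec, consistent with its role as a pure output buffer that the GEMM overwrites. A typed helper would make the byte arithmetic self-documenting; a sketch, with CopyVecToDevice being an illustrative name rather than anything in this file:

  // Allocate device memory for a host vector and copy it over, with the
  // byte count spelled out; illustrative helper, not part of the commit.
  template <typename T>
  T* CopyVecToDevice(const std::vector<T>& host) {
    T* dev = nullptr;
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&dev),
                          host.size() * sizeof(T)));
    CUDA_CHECK(cudaMemcpy(dev, host.data(), host.size() * sizeof(T),
                          cudaMemcpyHostToDevice));
    return dev;
  }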
@@ -639,11 +636,8 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
     using_default_config();
   }
 
-  paddle::Tensor workspace =
-      paddle::empty({static_cast<int64_t>(work_space_size)},
-                    paddle::DataType::UINT8,
-                    paddle::GPUPlace());
-  void* workspace_ptr = workspace.data<uint8_t>();
+  CUDA_CHECK(cudaMalloc((void**)&workSpace, work_space_size));
 
   CUDA_CHECK(cublasLtMatmulAlgoInit(dev_ctx.handle,
                                     cudaComputeType,
                                     CUDA_R_32I,
@@ -692,7 +686,7 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
                              C_desc,
                              &algo,
                              //  nullptr,
-                             workspace_ptr,
+                             workSpace,
                              //  0,
                              work_space_size,
                              0));
@@ -701,8 +695,18 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
   CUDA_CHECK(cudaDeviceSynchronize());
   auto end = std::chrono::high_resolution_clock::now();
   double time = diffTime(start, end);
+  auto now = std::chrono::system_clock::now();
+  std::time_t now_time_t = std::chrono::system_clock::to_time_t(now);
+  std::tm now_tm = *std::localtime(&now_time_t);
+
   std::cout << "GEMM with cublaslt imma1 int8 spend " << time / repeats
-            << " ms in " << m << ", " << k << ", " << n << std::endl;
+            << " ms in " << m << ", " << k << ", " << n
+            << ", current time: " << std::put_time(&now_tm, "%H:%M:%S")
+            << std::endl;
+  CUDA_CHECK(cudaFree(A_dev));
+  CUDA_CHECK(cudaFree(B_dev));
+  CUDA_CHECK(cudaFree(C_dev));
+  CUDA_CHECK(cudaFree(workSpace));
 }
 
 void TuneCublasltGemm(const paddle::Tensor& M,
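This enlarged log line is what the new <iomanip> include at the top of the file is for: std::put_time lives there. Copying into a local std::tm before streaming is also the right pattern, since std::localtime returns a pointer into shared static storage and is not thread-safe. A self-contained sketch of the same idiom:

  #include <chrono>
  #include <ctime>
  #include <iomanip>
  #include <iostream>

  int main() {
    auto now = std::chrono::system_clock::now();
    std::time_t t = std::chrono::system_clock::to_time_t(now);
    std::tm tm_copy = *std::localtime(&t);  // copy out of localtime's shared buffer
    std::cout << std::put_time(&tm_copy, "%H:%M:%S") << std::endl;
    return 0;
  }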
@@ -730,7 +734,7 @@ void TuneCublasltGemm(const paddle::Tensor& M,
   assert(K_size == N_size);
 
   int m_data = (int)M_data[0];
-  assert(m_data > 0 && 4 <= 8192);
+  assert(m_data > 0);
 
   std::vector<int> mm;
 
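On the assertion change: the old condition's second clause, 4 <= 8192, compared two constants and was therefore always true, so the assert only ever enforced m_data > 0; the commit simply drops the dead clause. If the 8192 was meant as an upper bound on m_data, the intended check would presumably have been the following (illustrative, not what the commit does):

  assert(m_data > 0 && m_data <= 8192);  // presumed original intent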