Skip to content

Commit

Permalink
Update error message
Browse files Browse the repository at this point in the history
Update error message.
  • Loading branch information
abuccts committed Apr 4, 2023
1 parent 857a8ba commit 0719c2a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,18 @@ template <typename T> cudaDataType_t get_datatype() {
}

template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(size_t m, size_t n, size_t k, int batch, int warmup, int iter) {
float timing_matmul_tn(size_t m, size_t n, size_t k, size_t batch, int warmup, int iter) {
// init matrix
Ta *matrix_a = nullptr;
Tb *matrix_b = nullptr;
Tout *matrix_out = nullptr;
cudaMalloc(&matrix_a, m * k * std::max(batch, 1) * sizeof(Ta));
cudaMalloc(&matrix_b, k * n * std::max(batch, 1) * sizeof(Tb));
cudaMalloc(&matrix_out, m * n * std::max(batch, 1) * sizeof(Tout));
batch = std::max<size_t>(batch, 1);
cudaMalloc(&matrix_a, m * k * batch * sizeof(Ta));
cudaMalloc(&matrix_b, k * n * batch * sizeof(Tb));
cudaMalloc(&matrix_out, m * n * batch * sizeof(Tout));

init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * std::max(batch, 1));
init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * std::max(batch, 1));
init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * batch);
init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * batch);

// init gemm
size_t lda = k, ldb = k, ldd = m;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

// Creates the cuBLASLt library handle and the matmul preference object and
// hands both to RAII members (handle_, preference_) for automatic cleanup.
void cublasLtGemm::Init() {
cublasLtHandle_t handle;
// NOTE(review): the next two lines are the pre-/post-change versions of the
// same call rendered from a diff (checkCublasStatus -> CUBLAS_CHECK); only
// one should exist in the real file — confirm before relying on this span.
checkCublasStatus(cublasLtCreate(&handle));
CUBLAS_CHECK(cublasLtCreate(&handle));
handle_.reset(handle);

/* preference can be initialized without arguments */
cublasLtMatmulPreference_t preference;
// NOTE(review): same diff artifact — old and new versions of one call.
checkCublasStatus(cublasLtMatmulPreferenceCreate(&preference));
CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
preference_.reset(preference);
}

Expand All @@ -24,32 +24,32 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
// force c_type
cudaDataType_t c_type = d_type;
// Create matrix descriptors.
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatrixLayoutCreate(&a_desc, a_type, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatrixLayoutCreate(&b_desc, b_type, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
checkCublasStatus(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldd));
checkCublasStatus(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&c_desc, c_type, m, n, ldd));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&d_desc, d_type, m, n, ldd));

// strided batch gemm
if (batch > 0) {
int64_t stridea = m * k, strideb = k * n, stridec = m * n, strided = m * n;
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stridea, sizeof(stridea)));
checkCublasStatus(
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(a_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea,
sizeof(stridea)));
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&strideb, sizeof(strideb)));
checkCublasStatus(
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(b_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb,
sizeof(strideb)));
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stridec, sizeof(stridec)));
checkCublasStatus(
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(c_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec,
sizeof(stridec)));
CUBLAS_CHECK(
cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&strided, sizeof(strided)));
CUBLAS_CHECK(cublasLtMatrixLayoutSetAttribute(d_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strided,
sizeof(strided)));
}
a_desc_.reset(a_desc);
b_desc_.reset(b_desc);
Expand All @@ -64,7 +64,7 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
gemm_compute_type = CUBLAS_COMPUTE_64F;

cublasLtMatmulDesc_t op_desc = nullptr;
checkCublasStatus(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
CUBLAS_CHECK(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
op_desc_.reset(op_desc);

if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) {
Expand All @@ -73,33 +73,31 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode));
}

checkCublasStatus(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
checkCublasStatus(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));

if (a_scale_inverse != nullptr) {
checkCublasStatus(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&a_scale_inverse, sizeof(a_scale_inverse)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&a_scale_inverse, sizeof(a_scale_inverse)));
}
if (b_scale_inverse != nullptr) {
checkCublasStatus(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&b_scale_inverse, sizeof(b_scale_inverse)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&b_scale_inverse, sizeof(b_scale_inverse)));
}
checkCublasStatus(
CUBLAS_CHECK(
cublasLtMatmulDescSetAttribute(op_desc_.get(), CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
}

size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_size) {
checkCublasStatus(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size)));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference_.get(), CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size, sizeof(max_workspace_size)));

int found_algorithm_count = 0;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
// Though we query all possible algorithms, we will use only the first one later
checkCublasStatus(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(),
max_algorithm_count, results.data(), &found_algorithm_count));
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(),
c_desc_.get(), d_desc_.get(), preference_.get(), max_algorithm_count,
results.data(), &found_algorithm_count));
if (found_algorithm_count == 0) {
throw std::runtime_error("Unable to find any suitable algorithms");
}
Expand All @@ -111,13 +109,13 @@ size_t cublasLtGemm::GetAlgorithm(int max_algorithm_count, size_t max_workspace_

// Runs D = alpha * op(A) * op(B) + beta * C via cublasLtMatmul, using the
// descriptors prepared earlier and the first heuristic algorithm found by
// GetAlgorithm (heuristic_results_.front().algo). The call is enqueued on
// `stream`; `workspace`/`workspace_size` supply scratch memory for the algo.
// NOTE(review): this span comes from a rendered diff — the checkCublasStatus
// and CUBLAS_CHECK invocations below are the pre-/post-change versions of
// the SAME single call, not two sequential matmuls. Confirm against the
// actual file before editing.
void cublasLtGemm::Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
void *workspace, size_t workspace_size, cudaStream_t stream) {
checkCublasStatus(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
matrix_a, /* A */
a_desc_.get(), matrix_b, /* B */
b_desc_.get(), static_cast<const void *>(&beta), /* beta */
matrix_c, /* C */
c_desc_.get(), matrix_d, /* D */
d_desc_.get(), &heuristic_results_.front().algo, /* algo */
workspace, /* workspace */
workspace_size, stream)); /* stream */
CUBLAS_CHECK(cublasLtMatmul(handle_.get(), op_desc_.get(), static_cast<const void *>(&alpha), /* alpha */
matrix_a, /* A */
a_desc_.get(), matrix_b, /* B */
b_desc_.get(), static_cast<const void *>(&beta), /* beta */
matrix_c, /* C */
c_desc_.get(), matrix_d, /* D */
d_desc_.get(), &heuristic_results_.front().algo, /* algo */
workspace, /* workspace */
workspace_size, stream)); /* stream */
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

#include <cublasLt.h>

// Validates a cuBLAS status code: on any non-success status, prints the
// human-readable status string and throws std::logic_error.
// NOTE(review): this is the pre-change helper being replaced by the
// CUBLAS_CHECK macro in this commit; the throw (vs. the macro's exit)
// changes failure semantics for callers that catch exceptions.
inline void checkCublasStatus(cublasStatus_t status) {
if (status != CUBLAS_STATUS_SUCCESS) {
printf("cuBLAS API failed with status %s\n", cublasGetStatusString(status));
throw std::logic_error("cuBLAS API failed");
}
}
// Evaluates a cuBLAS/cuBLASLt call and terminates the process on failure,
// reporting file, line, and the human-readable status string.
// Fixes over the original:
//   - diagnostics go to stderr (fprintf) instead of stdout (printf), so
//     error text is not lost when stdout is redirected/buffered
//   - the macro argument is parenthesized at expansion, so any valid
//     expression argument is evaluated with the intended precedence
//   - the local uses a distinctive name so it cannot shadow a caller's
//     own `status` variable inside the do/while scope
#define CUBLAS_CHECK(func)                                                                                   \
    do {                                                                                                     \
        cublasStatus_t cublas_check_status_ = (func);                                                        \
        if (cublas_check_status_ != CUBLAS_STATUS_SUCCESS) {                                                 \
            fprintf(stderr, "cuBLAS failure: %s:%d '%s'\n", __FILE__, __LINE__,                              \
                    cublasGetStatusString(cublas_check_status_));                                            \
            exit(EXIT_FAILURE);                                                                              \
        }                                                                                                    \
    } while (0)

class cublasLtGemm {
public:
Expand Down

0 comments on commit 0719c2a

Please sign in to comment.