From a6bf6f8e0915a7817a945500ed0d3924fb5021af Mon Sep 17 00:00:00 2001 From: Vin Huang Date: Sat, 27 Jul 2024 03:42:02 +0800 Subject: [PATCH] Backward support hipsparseLtDatatype_t *TODO remove this support in the later version --- library/include/hipsparselt.h | 21 +++ library/src/auxiliary.cpp | 13 ++ .../rocsparselt/src/include/handle.h | 8 + .../rocsparselt/src/rocsparselt_auxiliary.cpp | 152 ++++++++++++++++-- library/src/include/auxiliary.hpp | 28 ++++ library/src/include/hipsparselt_ostream.hpp | 11 ++ library/src/nvcc_detail/hipsparselt.cpp | 22 +++ 7 files changed, 240 insertions(+), 15 deletions(-) diff --git a/library/include/hipsparselt.h b/library/include/hipsparselt.h index 480bbd89..6becfb93 100644 --- a/library/include/hipsparselt.h +++ b/library/include/hipsparselt.h @@ -147,6 +147,27 @@ typedef struct {uint8_t data[11024];} hipsparseLtMatmulAlgSelection_t; typedef struct {uint8_t data[11024];} hipsparseLtMatmulPlan_t; #endif + +/* Types definitions */ +/*! @deprecated use hipDataType instead + * @TODO Keep this enum for backward supporting, will be deprecated in the later version. + * \ingroup types_module + * \brief List of hipsparselt data types. + * + * \details + * Indicates the precision width of data stored in a hipsparselt type. + * Should use hipDatatype_t instead in the furture. + */ +typedef enum +{ + HIPSPARSELT_R_16F = 150, /**< 16 bit floating point, real */ + HIPSPARSELT_R_32F = 151, /**< 32 bit floating point, real */ + HIPSPARSELT_R_8I = 160, /**< 8 bit signed integer, real */ + HIPSPARSELT_R_16BF = 168, /**< 16 bit bfloat, real */ + HIPSPARSELT_R_8F = 170, /**< 8 bit floating point, real */ + HIPSPARSELT_R_8BF = 171, /**< 8 bit bfloat, real */ +} hipsparseLtDatatype_t; + /*! \ingroup types_module * \brief Specify the sparsity of the structured matrix. * diff --git a/library/src/auxiliary.cpp b/library/src/auxiliary.cpp index 43446825..4a9f9ea9 100644 --- a/library/src/auxiliary.cpp +++ b/library/src/auxiliary.cpp @@ -24,6 +24,19 @@ #include "activation.hpp" // clang-format off +/* @deprecated */ +const hipsparseLtDatatype_t string_to_hipsparselt_datatype(const std::string& value) +{ + return + value == "f32_r" || value == "s" ? HIPSPARSELT_R_32F : + value == "f16_r" || value == "h" ? HIPSPARSELT_R_16F : + value == "bf16_r" ? HIPSPARSELT_R_16BF : + value == "i8_r" ? HIPSPARSELT_R_8I : + value == "f8_r" ? HIPSPARSELT_R_8F : + value == "bf8_r" ? HIPSPARSELT_R_8BF : + static_cast(-1); +} + const hipDataType string_to_hip_datatype(const std::string& value) { return diff --git a/library/src/hcc_detail/rocsparselt/src/include/handle.h b/library/src/hcc_detail/rocsparselt/src/include/handle.h index 097d7041..adfafc11 100644 --- a/library/src/hcc_detail/rocsparselt/src/include/handle.h +++ b/library/src/hcc_detail/rocsparselt/src/include/handle.h @@ -116,6 +116,7 @@ struct _rocsparselt_mat_descr , c_k(rhs.c_k) , c_ld(rhs.c_ld) , c_n(rhs.c_n) + , is_hipsparselt_datatype(rhs.is_hipsparselt_datatype) { is_init = (uintptr_t)handle; }; @@ -179,6 +180,9 @@ struct _rocsparselt_mat_descr int64_t c_ld = -1; int64_t c_n = -1; + + //@TODO This is used to backward support hipsparseLtDatatype_t, should remove in the later version. + bool is_hipsparselt_datatype = false; }; /******************************************************************************** @@ -223,6 +227,7 @@ struct _rocsparselt_matmul_descr , _ldb(rhs._ldb) , _is_sparse_a(rhs._is_sparse_a) , _swap_ab(rhs._swap_ab) + , bias_is_hipsparselt_datatype(rhs.bias_is_hipsparselt_datatype) { matrix_A = rhs.matrix_A->clone(); matrix_B = rhs.matrix_B->clone(); @@ -294,6 +299,9 @@ struct _rocsparselt_matmul_descr bool _is_sparse_a = true; bool _swap_ab = false; + //@TODO This is used to backward support hipsparseLtDatatype_t, should remove in the later version. + bool bias_is_hipsparselt_datatype = false; + private: bool is_reference = true; uintptr_t is_init = 0; diff --git a/library/src/hcc_detail/rocsparselt/src/rocsparselt_auxiliary.cpp b/library/src/hcc_detail/rocsparselt/src/rocsparselt_auxiliary.cpp index eba633bf..1de767e2 100644 --- a/library/src/hcc_detail/rocsparselt/src/rocsparselt_auxiliary.cpp +++ b/library/src/hcc_detail/rocsparselt/src/rocsparselt_auxiliary.cpp @@ -42,6 +42,62 @@ extern "C" { #endif +//@TODO This is used to backward support hipsparseLtDatatype_t, should be deprecated in the later version. +hipDataType HIPSparseLtDatatypeToHipDatatype(hipsparseLtDatatype_t type) +{ + switch(type) + { + case HIPSPARSELT_R_32F: + return HIP_R_32F; + + case HIPSPARSELT_R_16BF: + return HIP_R_16BF; + + case HIPSPARSELT_R_16F: + return HIP_R_16F; + + case HIPSPARSELT_R_8I: + return HIP_R_8I; + + case HIPSPARSELT_R_8F: + return HIP_R_8F_E4M3_FNUZ; + + case HIPSPARSELT_R_8BF: + return HIP_R_8F_E5M2_FNUZ; + + default: + throw HIPSPARSE_STATUS_NOT_SUPPORTED; + } +} + +//@TODO This is used to backward support hipsparseLtDatatype_t, should be deprecated in the later version. +hipsparseLtDatatype_t HipDatatypeToHIPSparseLtDatatype(hipDataType type) +{ + switch(type) + { + case HIP_R_32F: + return HIPSPARSELT_R_32F; + + case HIP_R_16BF: + return HIPSPARSELT_R_16BF; + + case HIP_R_16F: + return HIPSPARSELT_R_16F; + + case HIP_R_8I: + return HIPSPARSELT_R_8I; + + case HIP_R_8F_E4M3_FNUZ: + return HIPSPARSELT_R_8F; + + case HIP_R_8F_E5M2_FNUZ: + return HIPSPARSELT_R_8BF; + + default: + throw HIPSPARSE_STATUS_NOT_SUPPORTED; + } +} + /******************************************************************************** * \brief rocsparselt_handle is a structure holding the rocsparselt library context. * It must be initialized using rocsparselt_init() @@ -148,15 +204,25 @@ rocsparselt_status rocsparselt_dense_descr_init(const rocsparselt_handle* handle // Allocate try { + /* + * @TODO + * This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version. + */ + hipDataType vtype_; + bool is_hipsparselt_datatype = false; + try + { + vtype_ = HIPSparseLtDatatypeToHipDatatype( + static_cast(valueType)); + is_hipsparselt_datatype = true; + } + catch(...) + { + vtype_ = valueType; + } - auto status = validateMatrixArgs(_handle, - rows, - cols, - ld, - alignment, - valueType, - order, - rocsparselt_matrix_type_dense); + auto status = validateMatrixArgs( + _handle, rows, cols, ld, alignment, vtype_, order, rocsparselt_matrix_type_dense); if(status != rocsparselt_status_success) throw status; @@ -168,10 +234,11 @@ rocsparselt_status rocsparselt_dense_descr_init(const rocsparselt_handle* handle _matDescr->n = cols; _matDescr->ld = ld; _matDescr->alignment = alignment; - _matDescr->type = valueType; + _matDescr->type = vtype_; _matDescr->order = order; _matDescr->num_batches = 1; _matDescr->batch_stride = order == rocsparselt_order_column ? cols * ld : rows * ld; + _matDescr->is_hipsparselt_datatype = is_hipsparselt_datatype; log_api(_handle, __func__, "_matDescr[out]", @@ -185,7 +252,7 @@ rocsparselt_status rocsparselt_dense_descr_init(const rocsparselt_handle* handle "alignment[in]", alignment, "valueType[in]", - hipDataType_to_string(valueType), + hipDataType_to_string(vtype_), "order[in]", rocsparselt_order_to_string(order)); } @@ -239,12 +306,30 @@ rocsparselt_status rocsparselt_structured_descr_init(const rocsparselt_handle* h // Allocate try { + + /* + * @TODO + * This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version. + */ + hipDataType vtype_; + bool is_hipsparselt_datatype = false; + try + { + vtype_ = HIPSparseLtDatatypeToHipDatatype( + static_cast(valueType)); + is_hipsparselt_datatype = true; + } + catch(...) + { + vtype_ = valueType; + } + auto status = validateMatrixArgs(_handle, rows, cols, ld, alignment, - valueType, + vtype_, order, rocsparselt_matrix_type_structured); if(status != rocsparselt_status_success) @@ -258,11 +343,12 @@ rocsparselt_status rocsparselt_structured_descr_init(const rocsparselt_handle* h _matDescr->n = cols; _matDescr->ld = ld; _matDescr->alignment = alignment; - _matDescr->type = valueType; + _matDescr->type = vtype_; _matDescr->order = order; _matDescr->sparsity = sparsity; _matDescr->num_batches = 1; _matDescr->batch_stride = order == rocsparselt_order_column ? cols * ld : rows * ld; + _matDescr->is_hipsparselt_datatype = is_hipsparselt_datatype; log_api(_handle, __func__, "_matDescr[out]", @@ -276,7 +362,7 @@ rocsparselt_status rocsparselt_structured_descr_init(const rocsparselt_handle* h "alignment[in]", alignment, "valueType[in]", - hipDataType_to_string(valueType), + hipDataType_to_string(vtype_), "order[in]", rocsparselt_order_to_string(order), "sparsity[in]", @@ -726,6 +812,8 @@ rocsparselt_status rocsparselt_matmul_descr_init(const rocsparselt_handle* ha break; } + _matmulDescr->bias_is_hipsparselt_datatype = _matA->is_hipsparselt_datatype; + _matmulDescr->_op_A = _matmulDescr->op_A; if(_matA->order != _matC->order) _matmulDescr->_op_A = _matmulDescr->op_A == rocsparselt_operation_none @@ -931,7 +1019,23 @@ rocsparselt_status } case rocsparselt_matmul_bias_type: { - assign_data(&_matmulDescr->bias_type); + /* + * @TODO + * This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version. + */ + hipDataType vtype_; + assign_data(&vtype_); + try + { + _matmulDescr->bias_type = HIPSparseLtDatatypeToHipDatatype( + static_cast(vtype_)); + _matmulDescr->bias_is_hipsparselt_datatype = true; + } + catch(...) + { + _matmulDescr->bias_type = vtype_; + _matmulDescr->bias_is_hipsparselt_datatype = false; + } break; } default: @@ -1070,7 +1174,25 @@ rocsparselt_status break; case rocsparselt_matmul_bias_type: { - retrive_data(_matmulDescr->bias_type); + /* + * @TODO + * This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version. + */ + if(_matmulDescr->bias_is_hipsparselt_datatype) + { + try + { + hipsparseLtDatatype_t bias_type_ + = HipDatatypeToHIPSparseLtDatatype(_matmulDescr->bias_type); + retrive_data(bias_type_); + } + catch(...) + { + retrive_data(_matmulDescr->bias_type); + } + } + else + retrive_data(_matmulDescr->bias_type); break; } default: diff --git a/library/src/include/auxiliary.hpp b/library/src/include/auxiliary.hpp index 66d73cbf..698e526d 100644 --- a/library/src/include/auxiliary.hpp +++ b/library/src/include/auxiliary.hpp @@ -107,6 +107,29 @@ constexpr hipsparseOrder_t char_to_hipsparselt_order(char value) } } +// return precision string for hipsparseLtDatatype_t +/* @deprecated */ +HIPSPARSELT_EXPORT +constexpr const char* hipsparselt_datatype_to_string(hipsparseLtDatatype_t type) +{ + switch(type) + { + case HIPSPARSELT_R_32F: + return "f32_r"; + case HIPSPARSELT_R_16F: + return "f16_r"; + case HIPSPARSELT_R_16BF: + return "bf16_r"; + case HIPSPARSELT_R_8I: + return "i8_r"; + case HIPSPARSELT_R_8F: + return "f8_r"; + case HIPSPARSELT_R_8BF: + return "bf8_r"; + } + return "invalid"; +} + // return precision string for hipDataType HIPSPARSELT_EXPORT constexpr const char* hip_datatype_to_string(hipDataType type) @@ -151,6 +174,11 @@ constexpr const char* hipsparselt_computetype_to_string(hipsparseLtComputetype_t } // clang-format off + +/* @deprecated */ +HIPSPARSELT_EXPORT +const hipsparseLtDatatype_t string_to_hipsparselt_datatype(const std::string& value); + HIPSPARSELT_EXPORT const hipDataType string_to_hip_datatype(const std::string& value); diff --git a/library/src/include/hipsparselt_ostream.hpp b/library/src/include/hipsparselt_ostream.hpp index c79a9d5c..4230133a 100644 --- a/library/src/include/hipsparselt_ostream.hpp +++ b/library/src/include/hipsparselt_ostream.hpp @@ -393,6 +393,17 @@ class HIPSPARSELT_EXPORT hipsparselt_internal_ostream return os << s.c_str(); } + /* + * @deprecated + * hipsparseLtDatatype_t output + */ + friend hipsparselt_internal_ostream& operator<<(hipsparselt_internal_ostream& os, + hipsparseLtDatatype_t d) + { + os.m_os << hipsparselt_datatype_to_string(d); + return os; + } + // hipDataType output friend hipsparselt_internal_ostream& operator<<(hipsparselt_internal_ostream& os, hipDataType d) { diff --git a/library/src/nvcc_detail/hipsparselt.cpp b/library/src/nvcc_detail/hipsparselt.cpp index d37f6919..b158c8b3 100644 --- a/library/src/nvcc_detail/hipsparselt.cpp +++ b/library/src/nvcc_detail/hipsparselt.cpp @@ -113,6 +113,28 @@ hipsparseStatus_t hipCUSPARSEStatusToHIPStatus(cusparseStatus_t cuStatus) #endif } +/* @deprecated */ +cudaDataType HIPDatatypeToCuSparseLtDatatype(hipsparseLtDatatype_t type) +{ + switch(type) + { + case HIPSPARSELT_R_32F: + return CUDA_R_32F; + + case HIPSPARSELT_R_16BF: + return CUDA_R_16BF; + + case HIPSPARSELT_R_16F: + return CUDA_R_16F; + + case HIPSPARSELT_R_8I: + return CUDA_R_8I; + + default: + throw HIPSPARSE_STATUS_NOT_SUPPORTED; + } +} + cudaDataType HIPDatatypeToCuSparseLtDatatype(hipDataType type) { switch(type)