Skip to content

Commit

Permalink
Backward support hipsparseLtDatatype_t
Browse files Browse the repository at this point in the history
 *TODO remove this support in the later version
  • Loading branch information
vin-huang committed Jul 29, 2024
1 parent 105a25e commit a6bf6f8
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 15 deletions.
21 changes: 21 additions & 0 deletions library/include/hipsparselt.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,27 @@ typedef struct {uint8_t data[11024];} hipsparseLtMatmulAlgSelection_t;
typedef struct {uint8_t data[11024];} hipsparseLtMatmulPlan_t;
#endif


/* Types definitions */
/*! @deprecated use hipDataType instead
* @TODO Keep this enum for backward supporting, will be deprecated in the later version.
* \ingroup types_module
* \brief List of hipsparselt data types.
*
* \details
* Indicates the precision width of data stored in a hipsparselt type.
* Should use hipDatatype_t instead in the furture.
*/
typedef enum
{
HIPSPARSELT_R_16F = 150, /**< 16 bit floating point, real */
HIPSPARSELT_R_32F = 151, /**< 32 bit floating point, real */
HIPSPARSELT_R_8I = 160, /**< 8 bit signed integer, real */
HIPSPARSELT_R_16BF = 168, /**< 16 bit bfloat, real */
HIPSPARSELT_R_8F = 170, /**< 8 bit floating point, real */
HIPSPARSELT_R_8BF = 171, /**< 8 bit bfloat, real */
} hipsparseLtDatatype_t;

/*! \ingroup types_module
* \brief Specify the sparsity of the structured matrix.
*
Expand Down
13 changes: 13 additions & 0 deletions library/src/auxiliary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@
#include "activation.hpp"

// clang-format off
/* @deprecated */
const hipsparseLtDatatype_t string_to_hipsparselt_datatype(const std::string& value)
{
return
value == "f32_r" || value == "s" ? HIPSPARSELT_R_32F :
value == "f16_r" || value == "h" ? HIPSPARSELT_R_16F :
value == "bf16_r" ? HIPSPARSELT_R_16BF :
value == "i8_r" ? HIPSPARSELT_R_8I :
value == "f8_r" ? HIPSPARSELT_R_8F :
value == "bf8_r" ? HIPSPARSELT_R_8BF :
static_cast<hipsparseLtDatatype_t>(-1);
}

const hipDataType string_to_hip_datatype(const std::string& value)
{
return
Expand Down
8 changes: 8 additions & 0 deletions library/src/hcc_detail/rocsparselt/src/include/handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ struct _rocsparselt_mat_descr
, c_k(rhs.c_k)
, c_ld(rhs.c_ld)
, c_n(rhs.c_n)
, is_hipsparselt_datatype(rhs.is_hipsparselt_datatype)
{
is_init = (uintptr_t)handle;
};
Expand Down Expand Up @@ -179,6 +180,9 @@ struct _rocsparselt_mat_descr
int64_t c_ld = -1;

int64_t c_n = -1;

//@TODO This is used to backward support hipsparseLtDatatype_t, should remove in the later version.
bool is_hipsparselt_datatype = false;
};

/********************************************************************************
Expand Down Expand Up @@ -223,6 +227,7 @@ struct _rocsparselt_matmul_descr
, _ldb(rhs._ldb)
, _is_sparse_a(rhs._is_sparse_a)
, _swap_ab(rhs._swap_ab)
, bias_is_hipsparselt_datatype(rhs.bias_is_hipsparselt_datatype)
{
matrix_A = rhs.matrix_A->clone();
matrix_B = rhs.matrix_B->clone();
Expand Down Expand Up @@ -294,6 +299,9 @@ struct _rocsparselt_matmul_descr
bool _is_sparse_a = true;
bool _swap_ab = false;

//@TODO This is used to backward support hipsparseLtDatatype_t, should remove in the later version.
bool bias_is_hipsparselt_datatype = false;

private:
bool is_reference = true;
uintptr_t is_init = 0;
Expand Down
152 changes: 137 additions & 15 deletions library/src/hcc_detail/rocsparselt/src/rocsparselt_auxiliary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,62 @@
extern "C" {
#endif

//@TODO This is used to backward support hipsparseLtDatatype_t, should be deprecated in the later version.
hipDataType HIPSparseLtDatatypeToHipDatatype(hipsparseLtDatatype_t type)
{
switch(type)
{
case HIPSPARSELT_R_32F:
return HIP_R_32F;

case HIPSPARSELT_R_16BF:
return HIP_R_16BF;

case HIPSPARSELT_R_16F:
return HIP_R_16F;

case HIPSPARSELT_R_8I:
return HIP_R_8I;

case HIPSPARSELT_R_8F:
return HIP_R_8F_E4M3_FNUZ;

case HIPSPARSELT_R_8BF:
return HIP_R_8F_E5M2_FNUZ;

default:
throw HIPSPARSE_STATUS_NOT_SUPPORTED;
}
}

//@TODO This is used to backward support hipsparseLtDatatype_t, should be deprecated in the later version.
hipsparseLtDatatype_t HipDatatypeToHIPSparseLtDatatype(hipDataType type)
{
switch(type)
{
case HIP_R_32F:
return HIPSPARSELT_R_32F;

case HIP_R_16BF:
return HIPSPARSELT_R_16BF;

case HIP_R_16F:
return HIPSPARSELT_R_16F;

case HIP_R_8I:
return HIPSPARSELT_R_8I;

case HIP_R_8F_E4M3_FNUZ:
return HIPSPARSELT_R_8F;

case HIP_R_8F_E5M2_FNUZ:
return HIPSPARSELT_R_8BF;

default:
throw HIPSPARSE_STATUS_NOT_SUPPORTED;
}
}

/********************************************************************************
* \brief rocsparselt_handle is a structure holding the rocsparselt library context.
* It must be initialized using rocsparselt_init()
Expand Down Expand Up @@ -148,15 +204,25 @@ rocsparselt_status rocsparselt_dense_descr_init(const rocsparselt_handle* handle
// Allocate
try
{
/*
* @TODO
* This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version.
*/
hipDataType vtype_;
bool is_hipsparselt_datatype = false;
try
{
vtype_ = HIPSparseLtDatatypeToHipDatatype(
static_cast<hipsparseLtDatatype_t>(valueType));
is_hipsparselt_datatype = true;
}
catch(...)
{
vtype_ = valueType;
}

auto status = validateMatrixArgs(_handle,
rows,
cols,
ld,
alignment,
valueType,
order,
rocsparselt_matrix_type_dense);
auto status = validateMatrixArgs(
_handle, rows, cols, ld, alignment, vtype_, order, rocsparselt_matrix_type_dense);
if(status != rocsparselt_status_success)
throw status;

Expand All @@ -168,10 +234,11 @@ rocsparselt_status rocsparselt_dense_descr_init(const rocsparselt_handle* handle
_matDescr->n = cols;
_matDescr->ld = ld;
_matDescr->alignment = alignment;
_matDescr->type = valueType;
_matDescr->type = vtype_;
_matDescr->order = order;
_matDescr->num_batches = 1;
_matDescr->batch_stride = order == rocsparselt_order_column ? cols * ld : rows * ld;
_matDescr->is_hipsparselt_datatype = is_hipsparselt_datatype;
log_api(_handle,
__func__,
"_matDescr[out]",
Expand All @@ -185,7 +252,7 @@ rocsparselt_status rocsparselt_dense_descr_init(const rocsparselt_handle* handle
"alignment[in]",
alignment,
"valueType[in]",
hipDataType_to_string(valueType),
hipDataType_to_string(vtype_),
"order[in]",
rocsparselt_order_to_string(order));
}
Expand Down Expand Up @@ -239,12 +306,30 @@ rocsparselt_status rocsparselt_structured_descr_init(const rocsparselt_handle* h
// Allocate
try
{

/*
* @TODO
* This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version.
*/
hipDataType vtype_;
bool is_hipsparselt_datatype = false;
try
{
vtype_ = HIPSparseLtDatatypeToHipDatatype(
static_cast<hipsparseLtDatatype_t>(valueType));
is_hipsparselt_datatype = true;
}
catch(...)
{
vtype_ = valueType;
}

auto status = validateMatrixArgs(_handle,
rows,
cols,
ld,
alignment,
valueType,
vtype_,
order,
rocsparselt_matrix_type_structured);
if(status != rocsparselt_status_success)
Expand All @@ -258,11 +343,12 @@ rocsparselt_status rocsparselt_structured_descr_init(const rocsparselt_handle* h
_matDescr->n = cols;
_matDescr->ld = ld;
_matDescr->alignment = alignment;
_matDescr->type = valueType;
_matDescr->type = vtype_;
_matDescr->order = order;
_matDescr->sparsity = sparsity;
_matDescr->num_batches = 1;
_matDescr->batch_stride = order == rocsparselt_order_column ? cols * ld : rows * ld;
_matDescr->is_hipsparselt_datatype = is_hipsparselt_datatype;
log_api(_handle,
__func__,
"_matDescr[out]",
Expand All @@ -276,7 +362,7 @@ rocsparselt_status rocsparselt_structured_descr_init(const rocsparselt_handle* h
"alignment[in]",
alignment,
"valueType[in]",
hipDataType_to_string(valueType),
hipDataType_to_string(vtype_),
"order[in]",
rocsparselt_order_to_string(order),
"sparsity[in]",
Expand Down Expand Up @@ -726,6 +812,8 @@ rocsparselt_status rocsparselt_matmul_descr_init(const rocsparselt_handle* ha
break;
}

_matmulDescr->bias_is_hipsparselt_datatype = _matA->is_hipsparselt_datatype;

_matmulDescr->_op_A = _matmulDescr->op_A;
if(_matA->order != _matC->order)
_matmulDescr->_op_A = _matmulDescr->op_A == rocsparselt_operation_none
Expand Down Expand Up @@ -931,7 +1019,23 @@ rocsparselt_status
}
case rocsparselt_matmul_bias_type:
{
assign_data(&_matmulDescr->bias_type);
/*
* @TODO
* This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version.
*/
hipDataType vtype_;
assign_data(&vtype_);
try
{
_matmulDescr->bias_type = HIPSparseLtDatatypeToHipDatatype(
static_cast<hipsparseLtDatatype_t>(vtype_));
_matmulDescr->bias_is_hipsparselt_datatype = true;
}
catch(...)
{
_matmulDescr->bias_type = vtype_;
_matmulDescr->bias_is_hipsparselt_datatype = false;
}
break;
}
default:
Expand Down Expand Up @@ -1070,7 +1174,25 @@ rocsparselt_status
break;
case rocsparselt_matmul_bias_type:
{
retrive_data(_matmulDescr->bias_type);
/*
* @TODO
* This is used to backward support hipsparseLtDatatype_t, will be deprecated in the later version.
*/
if(_matmulDescr->bias_is_hipsparselt_datatype)
{
try
{
hipsparseLtDatatype_t bias_type_
= HipDatatypeToHIPSparseLtDatatype(_matmulDescr->bias_type);
retrive_data(bias_type_);
}
catch(...)
{
retrive_data(_matmulDescr->bias_type);
}
}
else
retrive_data(_matmulDescr->bias_type);
break;
}
default:
Expand Down
28 changes: 28 additions & 0 deletions library/src/include/auxiliary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ constexpr hipsparseOrder_t char_to_hipsparselt_order(char value)
}
}

// return precision string for hipsparseLtDatatype_t
/* @deprecated */
HIPSPARSELT_EXPORT
constexpr const char* hipsparselt_datatype_to_string(hipsparseLtDatatype_t type)
{
switch(type)
{
case HIPSPARSELT_R_32F:
return "f32_r";
case HIPSPARSELT_R_16F:
return "f16_r";
case HIPSPARSELT_R_16BF:
return "bf16_r";
case HIPSPARSELT_R_8I:
return "i8_r";
case HIPSPARSELT_R_8F:
return "f8_r";
case HIPSPARSELT_R_8BF:
return "bf8_r";
}
return "invalid";
}

// return precision string for hipDataType
HIPSPARSELT_EXPORT
constexpr const char* hip_datatype_to_string(hipDataType type)
Expand Down Expand Up @@ -151,6 +174,11 @@ constexpr const char* hipsparselt_computetype_to_string(hipsparseLtComputetype_t
}

// clang-format off

/* @deprecated */
HIPSPARSELT_EXPORT
const hipsparseLtDatatype_t string_to_hipsparselt_datatype(const std::string& value);

HIPSPARSELT_EXPORT
const hipDataType string_to_hip_datatype(const std::string& value);

Expand Down
11 changes: 11 additions & 0 deletions library/src/include/hipsparselt_ostream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,17 @@ class HIPSPARSELT_EXPORT hipsparselt_internal_ostream
return os << s.c_str();
}

/*
* @deprecated
* hipsparseLtDatatype_t output
*/
friend hipsparselt_internal_ostream& operator<<(hipsparselt_internal_ostream& os,
hipsparseLtDatatype_t d)
{
os.m_os << hipsparselt_datatype_to_string(d);
return os;
}

// hipDataType output
friend hipsparselt_internal_ostream& operator<<(hipsparselt_internal_ostream& os, hipDataType d)
{
Expand Down
Loading

0 comments on commit a6bf6f8

Please sign in to comment.