Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Enable migration for CUBLASLT_EPILOGUE_DGELU & EPILOGUE_BGRADB #2449

Open
wants to merge 9 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/lib/DPCT/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2026,6 +2026,9 @@ void MapNames::setExplicitNamespaceMap(
{"CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER",
getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t::"
"attribute::epilogue_aux_pointer"},
{"CUBLASLT_EPILOGUE_DGELU",
getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t::"
"attribute::dgelu_epilogue"},
{"CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET",
getLibraryHelperNamespace() +
"blas_gemm::experimental::matmul_desc_t::attribute::unsupport"},
Expand Down
38 changes: 19 additions & 19 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ using matmul_desc_ptr = matmul_desc_t *;
class transform_desc_t;
using transform_desc_ptr = transform_desc_t *;

template <typename primitive_type, typename... args_type>
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
inline
typename primitive_type::primitive_desc dgelu_epilogue_(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function name looks like a type name, please refine it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

float alpha, float beta) {

auto alg = ::dnnl::algorithm::eltwise_gelu_erf;
const memory_desc_ext &dst_desc = new memory_desc_ext();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where to "delete" dst_desc and src_desc?

const memory_desc_ext &src_desc = new memory_desc_ext();
return create_primitive_desc<primitive_type>(
::dnnl::prop_kind::backward, alg, src_desc.get_desc(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For src_desc and dst_desc, is it OK to pass to the create_primitive_desc directly?
No need to fill data?

dst_desc.get_desc(), alpha, beta);
}

class descriptor {
public:
descriptor() {}
Expand All @@ -62,15 +75,7 @@ class descriptor {

class matrix_layout_t {
public:
enum class attribute {
type,
order,
rows,
cols,
ld,
batch_count,
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
strided_batch_offset
};
enum class attribute { type, order, rows, cols, ld };

matrix_layout_t(library_data_t type, std::uint64_t rows, std::uint64_t cols,
std::int64_t ld)
Expand Down Expand Up @@ -99,8 +104,6 @@ class matrix_layout_t {
CASE(rows)
CASE(cols)
CASE(ld)
CASE(batch_count)
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
CASE(strided_batch_offset)
}
#undef CASE
}
Expand All @@ -110,8 +113,6 @@ class matrix_layout_t {
std::uint64_t _rows;
std::uint64_t _cols;
std::int64_t _ld;
std::uint64_t _batch_count = 1;
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
std::uint64_t _strided_batch_offset = 0;

friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc,
const void *alpha, const void *a,
Expand All @@ -137,8 +138,7 @@ class matmul_desc_t {
trans_b,
trans_c,
epilogue,
epilogue_aux_ld,
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
epilogue_aux_pointer,
dgelu_epilogue
a_scale_pointer,
b_scale_pointer,
d_scale_pointer,
Expand Down Expand Up @@ -180,8 +180,7 @@ class matmul_desc_t {
CASE(b_scale_pointer)
CASE(d_scale_pointer)
CASE(absmax_d_pointer)
CASE(epilogue_aux_ld)
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
CASE(epilogue_aux_pointer)
CASE(dgelu_epilogue)
default:
break;
}
Expand All @@ -195,12 +194,12 @@ class matmul_desc_t {
oneapi::mkl::transpose _trans_b = oneapi::mkl::transpose::nontrans;
oneapi::mkl::transpose _trans_c = oneapi::mkl::transpose::nontrans;
epilogue_t _epilogue = epilogue_t::nop;
size_t _epilogue_aux_ld = 0;
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
void *_a_scale_pointer = nullptr;
void *_b_scale_pointer = nullptr;
void *_d_scale_pointer = nullptr;
void *_absmax_d_pointer = nullptr;
void *_epilogue_aux_pointer = nullptr;
abhilash1910 marked this conversation as resolved.
Show resolved Hide resolved
auto *_dgelu_epilogue = dgelu_epilogue_<::dnnl::eltwise_backward>(0.f, 0.f);


friend sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr computeDesc,
const void *alpha, const void *a,
Expand Down Expand Up @@ -732,6 +731,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
const void *c, matrix_layout_ptr c_desc, void *d,
matrix_layout_ptr d_desc,
::dpct::cs::queue_ptr q_ptr) {

const size_t m = compute_desc->_trans_a == oneapi::mkl::transpose::nontrans
? a_desc->_rows
: a_desc->_cols;
Expand Down
2 changes: 2 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ void foo3() {
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_ld;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_pointer;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::dgelu_epilogue;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport;
// CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::a_scale_pointer;
Expand All @@ -214,6 +215,7 @@ void foo3() {
d = CUBLASLT_MATMUL_DESC_EPILOGUE;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD;
d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER;
d = CUBLASLT_EPILOGUE_DGELU;
d = CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET;
d = CUBLASLT_MATMUL_DESC_FAST_ACCUM;
d = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
Expand Down
Loading