gpu: nvidia: Add support for cublaslt matmul
ShanoToni authored and mgouicem committed Sep 30, 2024
1 parent ec24f24 commit eb146c4
Showing 38 changed files with 2,971 additions and 435 deletions.
48 changes: 48 additions & 0 deletions cmake/FindcublasLt.cmake
@@ -0,0 +1,48 @@
+# ===============================================================================
+# Copyright 2024 Intel Corporation
+# Copyright 2024 Codeplay Software Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ===============================================================================
+
+find_package(CUDA 10.0 REQUIRED)
+find_package(Threads REQUIRED)
+
+find_path(
+    CUBLASLT_INCLUDE_DIR "cublasLt.h"
+    HINTS ${CUDA_TOOLKIT_ROOT_DIR}
+    PATH_SUFFIXES include)
+
+find_library(CUDA_DRIVER_LIBRARY cuda)
+
+find_library(
+    CUBLAS_LIBRARY cublasLt
+    HINTS ${CUDA_TOOLKIT_ROOT_DIR}
+    PATH_SUFFIXES lib lib64 bin)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(
+    cublasLt REQUIRED_VARS CUBLASLT_INCLUDE_DIR CUDA_INCLUDE_DIRS CUBLAS_LIBRARY
+    CUDA_LIBRARIES CUDA_DRIVER_LIBRARY)
+
+if(NOT TARGET cublasLt::cublasLt)
+    add_library(cublasLt::cublasLt SHARED IMPORTED)
+    set_target_properties(
+        cublasLt::cublasLt
+        PROPERTIES IMPORTED_LOCATION ${CUBLAS_LIBRARY}
+                   INTERFACE_INCLUDE_DIRECTORIES
+                   "${CUBLASLT_INCLUDE_DIR};${CUDA_INCLUDE_DIRS}"
+                   INTERFACE_LINK_LIBRARIES
+                   "Threads::Threads;${CUDA_DRIVER_LIBRARY};${CUDA_LIBRARIES}"
+                   INTERFACE_COMPILE_DEFINITIONS CUDA_NO_HALF)
+endif()
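As a sanity check for the imported target, a minimal consumer could look like the sketch below. This is not part of the commit; it assumes only a CUDA toolchain and a `target_link_libraries(app PRIVATE cublasLt::cublasLt)` line in the consuming project, and uses nothing beyond the public cuBLASLt entry points `cublasLtCreate`, `cublasLtGetVersion`, and `cublasLtDestroy`.

```cpp
#include <cublasLt.h>
#include <cstdio>

int main() {
    cublasLtHandle_t handle;
    // cublasLt.h is found via the CUBLASLT_INCLUDE_DIR located above.
    if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
        std::fprintf(stderr, "failed to create a cublasLt handle\n");
        return 1;
    }
    std::printf("cublasLt version: %zu\n", cublasLtGetVersion());
    cublasLtDestroy(handle);
    return 0;
}
```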
3 changes: 2 additions & 1 deletion cmake/SYCL.cmake
@@ -88,9 +88,10 @@ endmacro()
 if(DNNL_SYCL_CUDA)
     suppress_warnings_for_nvidia_target()
     find_package(cuBLAS REQUIRED)
+    find_package(cublasLt REQUIRED)
     find_package(cuDNN REQUIRED)
 
-    adjust_headers_priority("cuBLAS::cuBLAS;cuDNN::cuDNN")
+    adjust_headers_priority("cuBLAS::cuBLAS;cuDNN::cuDNN;cublasLt::cublasLt")
     add_definitions_with_host_compiler("-DCUDA_NO_HALF")
 
     list(APPEND EXTRA_SHARED_LIBS cuBLAS::cuBLAS cuDNN::cuDNN)
1 change: 1 addition & 0 deletions include/oneapi/dnnl/dnnl.hpp
@@ -1215,6 +1215,7 @@ struct memory : public handle<dnnl_memory_t> {
         AB16b64a2b = dnnl_AB16b64a2b,
         Ab4a = dnnl_Ab4a,
         Ab8a = dnnl_Ab8a,
+        Ab32a = dnnl_Ab32a,
         Abc16a = dnnl_Abc16a,
         ABc16a16b = dnnl_ABc16a16b,
         ABc4a4b = dnnl_ABc4a4b,
1 change: 1 addition & 0 deletions include/oneapi/dnnl/dnnl_types.h
@@ -1037,6 +1037,7 @@ typedef enum {
     dnnl_bcad,
     dnnl_cabd,
     dnnl_dabc,
+    dnnl_Ab32a,
 
     /// Just a sentinel, not real memory format tag. Must be changed after new
     /// format tag is added.
2 changes: 2 additions & 0 deletions src/common/c_types_map.hpp
@@ -232,6 +232,7 @@ const format_kind_t sparse = static_cast<format_kind_t>(4);
 const format_kind_t internal_only_start = (format_kind_t)(1 << 8);
 const format_kind_t wino = internal_only_start;
 const format_kind_t rnn_packed = (format_kind_t)(internal_only_start + 1);
+const format_kind_t cublaslt_blocked = (format_kind_t)(internal_only_start + 2);
 } // namespace format_kind
 
 #ifdef DNNL_EXPERIMENTAL_PROFILING
@@ -371,6 +372,7 @@ const format_tag_t aCB16b64c4b = dnnl_aCB16b64c4b;
 
 const format_tag_t Ab4a = dnnl_Ab4a;
 const format_tag_t Ab8a = dnnl_Ab8a;
+const format_tag_t Ab32a = dnnl_Ab32a;
 const format_tag_t Abc16a = dnnl_Abc16a;
 const format_tag_t ABc16a16b = dnnl_ABc16a16b;
 const format_tag_t ABc4a2b = dnnl_ABc4a2b;
4 changes: 3 additions & 1 deletion src/common/dnnl_debug.cpp
@@ -49,7 +49,9 @@ const char *dnnl_fmt_kind2str(dnnl_format_kind_t v) {
 #ifdef DNNL_EXPERIMENTAL_SPARSE
     if (v == dnnl_format_kind_sparse) return "sparse";
 #endif
-    if (v == format_kind::wino || v == format_kind::rnn_packed) return "opaque";
+    if (v == format_kind::wino || v == format_kind::rnn_packed
+            || v == format_kind::cublaslt_blocked)
+        return "opaque";
     if (v == dnnl_format_kind_max) return "max";
     assert(!"unknown fmt_kind");
     return "unknown fmt_kind";
1 change: 1 addition & 0 deletions src/common/dnnl_debug_autogenerated.cpp
@@ -926,6 +926,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
     if (v == dnnl_AcdeB4b8a4b) return "AcdeB4b8a4b";
     if (v == dnnl_Ab4a) return "Ab4a";
     if (v == dnnl_Ab8a) return "Ab8a";
+    if (v == dnnl_Ab32a) return "Ab32a";
     if (v == dnnl_BA4b4a) return "BA4b4a";
     if (v == dnnl_BA8b4a) return "BA8b4a";
     if (v == dnnl_BA2a24b) return "BA2a24b";
3 changes: 2 additions & 1 deletion src/common/memory_desc.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022-2023 Intel Corporation
+* Copyright 2022-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -722,6 +722,7 @@ status_t dnnl_memory_desc_query(
         case query::format_kind:
             switch ((int)md->format_kind) {
                 case format_kind::rnn_packed:
+                case format_kind::cublaslt_blocked:
                 case format_kind::wino:
                     *(format_kind_t *)result = format_kind::opaque;
                     break;
11 changes: 10 additions & 1 deletion src/common/memory_desc.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022-2023 Intel Corporation
+* Copyright 2022-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -23,6 +23,8 @@
 namespace dnnl {
 namespace impl {
 
+enum class cublaslt_memory_format_t { col32_2r_4r4 };
+
 // Winograd-specific formats
 enum class wino_memory_format_t {
     // Undefined memory format, used for empty memory descriptors.
@@ -135,6 +137,11 @@ struct rnn_packed_desc_t {
     size_t size;
 };
 
+struct cublaslt_blocked_desc_t {
+    cublaslt_memory_format_t cublaslt_format;
+    size_t size;
+};
+
 struct sparse_desc_t {
     static constexpr int max_metadata_types = 2;
     // Each encoding defines the number of handles it requires and their
@@ -289,6 +296,8 @@ struct dnnl_memory_desc : public dnnl::impl::c_compatible {
         dnnl::impl::wino_desc_t wino_desc;
         // Tensor of packed weights for RNN.
         dnnl::impl::rnn_packed_desc_t rnn_packed_desc;
+        // Description of the data layout for memory formats used in cublasLt IMMA kernels.
+        dnnl::impl::cublaslt_blocked_desc_t cublaslt_blocked_desc;
         // Description of the sparse encodings.
         dnnl::impl::sparse_desc_t sparse_desc;
         // ... other descriptions possible
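To make the union plumbing above concrete, here is a hedged sketch (not from the commit) of how the new opaque descriptor can be read back once `format_kind` has been set to `cublaslt_blocked`; all type and field names come from the diff above, and the helper itself is hypothetical.

```cpp
#include <cassert>
#include <cstddef>

using namespace dnnl::impl;

size_t cublaslt_buffer_size(const memory_desc_t &md) {
    assert(md.format_kind == format_kind::cublaslt_blocked);
    const cublaslt_blocked_desc_t &d = md.format_desc.cublaslt_blocked_desc;
    // col32_2r_4r4 is the only cuBLASLt IMMA layout introduced here.
    assert(d.cublaslt_format == cublaslt_memory_format_t::col32_2r_4r4);
    return d.size; // opaque buffer size in bytes
}
```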
1 change: 1 addition & 0 deletions src/common/memory_desc_wrapper.cpp
@@ -192,6 +192,7 @@ status_t memory_desc_wrapper::compute_blocking(
 
     C(Ab4a, {0, 1}, {4}, {0});
     C(Ab8a, {0, 1}, {8}, {0});
+    C(Ab32a, {0, 1}, {32}, {0});
 
     C(BA4b4a, {1, 0}, {4, 4}, {1, 0});
     C(BA8b4a, {1, 0}, {8, 4}, {1, 0});
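For intuition, `C(Ab32a, {0, 1}, {32}, {0})` declares dim 0 as blocked by 32 with the block index outermost. A hedged standalone illustration of the resulting physical offset (my reading of the tag, assuming dim 1 is unblocked and dim 0 is padded up to a multiple of 32):

```cpp
#include <cstdint>

// offset of logical element (a, b) in an Ab32a-laid-out A x B matrix
int64_t ab32a_offset(int64_t a, int64_t b, int64_t B) {
    const int64_t blk = 32;
    return (a / blk) * B * blk // which 32-row block
            + b * blk // column within that block row
            + a % blk; // position inside the 32-wide inner block
}
```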
15 changes: 13 additions & 2 deletions src/common/memory_desc_wrapper.hpp
@@ -67,6 +67,9 @@ struct memory_desc_wrapper : public c_compatible {
     bool is_rnn_packed_desc() const {
         return format_kind() == format_kind::rnn_packed;
     }
+    bool is_cublaslt_blocked_desc() const {
+        return format_kind() == format_kind::cublaslt_blocked;
+    }
     bool is_sparse_desc() const { return format_kind() == format_kind::sparse; }
 
     const blocking_desc_t &blocking_desc() const {
@@ -82,6 +85,10 @@
         assert(is_rnn_packed_desc());
         return md_->format_desc.rnn_packed_desc;
     }
+    const cublaslt_blocked_desc_t &cublaslt_blocked_desc() const {
+        assert(is_cublaslt_blocked_desc());
+        return md_->format_desc.cublaslt_blocked_desc;
+    }
 
     const sparse_desc_t &sparse_desc() const {
         assert(is_sparse_desc());
@@ -224,7 +231,8 @@
             return 0;
 
         if (utils::one_of(format_kind(), format_kind::blocked,
-                    format_kind::wino, format_kind::rnn_packed)
+                    format_kind::wino, format_kind::rnn_packed,
+                    format_kind::cublaslt_blocked)
                 && index != 0) {
             return 0;
         }
@@ -235,6 +243,8 @@
             return wino_desc().size;
         } else if (is_rnn_packed_desc()) {
             return rnn_packed_desc().size;
+        } else if (is_cublaslt_blocked_desc()) {
+            return cublaslt_blocked_desc().size;
         } else if (is_blocking_desc()) {
             if (offset0() != 0) return 0;
 
@@ -581,7 +591,8 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
 
     if (one_of(format_kind(), format_kind::undef, format_kind::any))
         return false;
-    if (is_wino_desc() || is_rnn_packed_desc()) return false;
+    if (is_wino_desc() || is_rnn_packed_desc() || is_cublaslt_blocked_desc())
+        return false;
 
     const int ds = dim_start;
     const auto &blk = blocking_desc();
7 changes: 7 additions & 0 deletions src/common/memory_tracking.hpp
@@ -253,10 +253,14 @@
     key_lnorm_tmp_diff_ss,
     key_lnorm_reduction,
     key_matmul_dst_in_acc_dt,
+    key_matmul_lt_algo_scratch,
+    key_matmul_lt_block_c,
     key_matmul_src_trans,
     key_matmul_wei_trans,
     key_matmul_dst_trans,
     key_matmul_dst_cast_acc,
+    key_matmul_lt_src_scale,
+    key_matmul_lt_wei_scale,
     key_matmul_sparse_tmp_ptr,
     key_pool_dst_bf16cvt,
     key_pool_dst_plain2blocked_cvt,
@@ -282,6 +286,9 @@
     key_reorder_rnn_weights_reduction,
     key_reorder_rnn_weights_transposition,
     key_reorder_rnn_weights_xf16_cvt,
+    key_reorder_cublaslt_src_float,
+    key_reorder_cublaslt_dst_float,
+    key_reorder_cublaslt_generic,
     key_rnn_space,
     key_rnn_bf32_attention_trans,
     key_rnn_bf32_wei_layer_trans,
7 changes: 7 additions & 0 deletions src/common/primitive_hashing.cpp
@@ -139,6 +139,13 @@ size_t get_md_hash(const memory_desc_t &md) {
             seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale);
             seed = hash_combine(seed, md.format_desc.wino_desc.size);
             break;
+        case format_kind::cublaslt_blocked:
+            seed = hash_combine(seed,
+                    static_cast<size_t>(md.format_desc.cublaslt_blocked_desc
+                                                .cublaslt_format));
+            seed = hash_combine(
+                    seed, (md.format_desc.cublaslt_blocked_desc.size));
+            break;
         case format_kind::rnn_packed:
             seed = hash_combine(seed,
                     static_cast<size_t>(md.format_desc.rnn_packed_desc.format));
4 changes: 4 additions & 0 deletions src/common/reorder.hpp
@@ -29,6 +29,10 @@ status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
         engine_t *engine, const memory_desc_t *src_md,
         const memory_desc_t *dst_md, const primitive_attr_t *attr = nullptr);
 
+status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
+        engine_t *engine, const memory_desc_t *src_md, engine_t *src_engine,
+        const memory_desc_t *dst_md, engine_t *dst_engine,
+        const primitive_attr_t *attr = nullptr);
 } // namespace impl
 } // namespace dnnl
 
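The second overload is the new one: unlike the two-engine-implicit form above it, the source and destination descriptors may live on different engines, which a reorder between a host-side plain layout and the GPU-side cuBLASLt blocked layout needs. A hedged usage sketch of this internal API (the helper below is hypothetical; pd, descriptors, engines, and attr are assumed to be set up by the caller):

```cpp
#include <memory>

namespace impl = dnnl::impl;

impl::status_t make_cross_engine_reorder(
        std::shared_ptr<impl::primitive_desc_t> &pd, impl::engine_t *engine,
        const impl::memory_desc_t *src_md, impl::engine_t *src_engine,
        const impl::memory_desc_t *dst_md, impl::engine_t *dst_engine,
        const impl::primitive_attr_t *attr) {
    // src and dst engines may differ, e.g. host vs. NVIDIA GPU.
    return impl::reorder_primitive_desc_create(
            pd, engine, src_md, src_engine, dst_md, dst_engine, attr);
}
```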
5 changes: 5 additions & 0 deletions src/common/serialization.cpp
@@ -94,6 +94,11 @@ void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md) {
             sstream.write(&md.format_desc.wino_desc.adj_scale);
             sstream.write(&md.format_desc.wino_desc.size);
             break;
+        case format_kind::cublaslt_blocked:
+            sstream.write(
+                    &md.format_desc.cublaslt_blocked_desc.cublaslt_format);
+            sstream.write(&md.format_desc.cublaslt_blocked_desc.size);
+            break;
         case format_kind::rnn_packed:
             sstream.write(&md.format_desc.rnn_packed_desc.format);
             sstream.write(&md.format_desc.rnn_packed_desc.n_parts);
4 changes: 4 additions & 0 deletions src/common/type_helpers.hpp
@@ -346,6 +346,10 @@ inline bool wino_desc_is_equal(const wino_desc_t &lhs, const wino_desc_t &rhs) {
             && lhs.ic2_block == rhs.ic2_block && lhs.oc2_block == rhs.oc2_block
             && lhs.r == rhs.r;
 }
+inline bool cublaslt_blocked_desc_is_equal(const cublaslt_blocked_desc_t &lhs,
+        const cublaslt_blocked_desc_t &rhs) {
+    return lhs.cublaslt_format == rhs.cublaslt_format && lhs.size == rhs.size;
+}
 
 inline bool rnn_packed_desc_is_equal(
         const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) {
9 changes: 9 additions & 0 deletions src/common/verbose.cpp
@@ -392,6 +392,14 @@ std::string rnn_flags2str(unsigned flags) {
     return s;
 }
 
+std::string cublasltfmt2str(const memory_desc_t *md) {
+    if (md->format_desc.cublaslt_blocked_desc.cublaslt_format
+            == cublaslt_memory_format_t::col32_2r_4r4) {
+        return ":col32_2r_4r4";
+    }
+    return "";
+}
+
 std::ostream &operator<<(std::ostream &ss, const memory_extra_desc_t &extra) {
     using namespace memory_extra_flags;
 
@@ -512,6 +520,7 @@ std::string md2fmt_str(
         case format_kind::blocked:
             ss << ":" << md2fmt_tag_str(md) << ":" << md2fmt_strides_str(md);
             break;
+        case format_kind::cublaslt_blocked: ss << cublasltfmt2str(md); break;
        case format_kind::wino:
         case format_kind::rnn_packed:
         case format_kind::opaque: ss << "::"; break;
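With this hook, a cuBLASLt-blocked memory descriptor no longer prints the bare `::` used for the other opaque kinds; its verbose format field instead carries a suffix along the lines of `:col32_2r_4r4` (illustrative only; the exact verbose line depends on the rest of md2fmt_str).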
18 changes: 14 additions & 4 deletions src/gpu/generic/sycl/binary_kernels.hpp
@@ -91,7 +91,13 @@ struct binary_kernel_vec_t {
                                 == src1_mem.md().strides()[i]);
             }
         }
-        if (!any_broadcast && conf_.post_ops.get_post_op() == 0
+
+        const bool is_blocked_fmt = conf_.src0_md.inner_nblks() > 0
+                || conf_.src1_md.inner_nblks() > 0
+                || conf_.dst_md.inner_nblks() > 0;
+
+        if (!any_broadcast && !is_blocked_fmt
+                && conf_.post_ops.get_post_op() == 0
                 && sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
                         < conf_.wk_size
                 && is_same_tag) {
@@ -114,8 +120,12 @@
             for (int i = 0; i < conf_.block_size; i++) {
                 int idx = base_idx + i;
                 if (idx < conf_.wk_size) {
-                    for (int i = 0; i < max_supported_ndims; i++) {
-                        off_dst[i] = idx / strides[i] % dims[i];
+                    auto l_offset = idx;
+                    for (int i = 0; i < conf_.ndims; i++) {
+                        const int d = conf_.ndims - 1 - i;
+                        const dim_t cur_dim = conf_.dst_md.dims()[d];
+                        off_dst[d] = l_offset % cur_dim;
+                        l_offset = l_offset / cur_dim;
                     }
 
                     for (int i = 0; i < max_supported_ndims; i++) {
@@ -133,7 +143,7 @@
 
                     acc = conf_.post_ops.apply(
                             acc, dst_, idx, po_args_, off_dst);
-                    dst_mem.store(acc, idx);
+                    dst_mem.store_md(acc, off_dst);
                 }
             }
         }
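The replaced loop decoded coordinates from precomputed strides; the new one peels the flat index apart from the innermost dimension outwards and stores through the md-aware `store_md()`, which is what lets blocked layouts such as Ab32a land in the right place. A hedged standalone rendering of that index math (not from the commit):

```cpp
#include <cstdint>

// decompose a flat element index into logical coordinates, innermost first
void linear_to_coords(
        int64_t idx, const int64_t *dims, int ndims, int64_t *off) {
    for (int i = 0; i < ndims; i++) {
        const int d = ndims - 1 - i; // walk dims from the last to the first
        off[d] = idx % dims[d];
        idx /= dims[d];
    }
}
```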
2 changes: 2 additions & 0 deletions src/gpu/generic/sycl/ref_binary.cpp
@@ -65,6 +65,8 @@ status_t ref_binary_t::init(impl::engine_t *engine) {
 
 status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {
 
+    ctx.zero_pad_output(DNNL_ARG_TO);
+
     parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
         binary_kernel_vec_t binary_kernel(pd()->conf_, cgh, ctx);
 
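The new `zero_pad_output` call is presumably there because a blocked destination such as Ab32a carries padding (dim 0 rounded up to a multiple of 32) that the kernel's per-element stores never touch; zeroing the output first keeps those padded elements well-defined.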
5 changes: 4 additions & 1 deletion src/gpu/generic/sycl/ref_binary.hpp
@@ -99,8 +99,11 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t {
         using namespace format_tag;
 
         for (const auto &mdw : {src0, src1, dst}) {
-            if (!mdw.is_plain()) { return false; }
+            if (!(mdw.is_plain() || mdw.matches_tag(format_tag::Ab32a)
+                        || mdw.matches_tag(format_tag::aBc32b)))
+                return false;
         }
+
         return true;
     }
 };