-
Notifications
You must be signed in to change notification settings - Fork 23
Add new blas extension and update dpnp.matmul func #1616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
3444816
Add new blas extension and update matmul impl
vlad-perevezentsev b8f7f00
Add support for N-D array
c74e884
support more special cases + add new tests
31dbb36
fix random behavior on cpu
ae15bd0
Merge branch 'master' into update_matmul
bf3ba86
correct dtypes + support more keywords
cf89931
add strided support
2f61f94
Merge branch 'master' into update_matmul
d45cb4a
check input arrays
7b8d29b
address comments - first round
57bb9a3
Merge branch 'master' into update_matmul
vtavana df89d23
address comments - second round
8f2ea97
Merge branch 'master' into update_matmul
vtavana bea7f60
address comments - third round
5633c33
Merge branch 'master' into update_matmul
vtavana ff2b0d4
fix pre-commit
f522647
improve test coverage
07a5e52
address comments
28f9c27
Merge branch 'master' into update_matmul
vtavana 0955e31
update _gemm_res_dtype func
f1a30aa
fix a test for result_type
1a45c68
fix minor issues
566f154
Merge branch 'master' into update_matmul
vtavana 0bc7172
skip tests for matmul
1e6fe9e
Merge branch 'master' into update_matmul
antonwolfy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# ***************************************************************************** | ||
# Copyright (c) 2016-2023, Intel Corporation | ||
# All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions are met: | ||
# - Redistributions of source code must retain the above copyright notice, | ||
# this list of conditions and the following disclaimer. | ||
# - Redistributions in binary form must reproduce the above copyright notice, | ||
# this list of conditions and the following disclaimer in the documentation | ||
# and/or other materials provided with the distribution. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
# THE POSSIBILITY OF SUCH DAMAGE. | ||
# ***************************************************************************** | ||
|
||
|
||
set(python_module_name _blas_impl) | ||
set(_module_src | ||
${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp | ||
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp | ||
) | ||
|
||
pybind11_add_module(${python_module_name} MODULE ${_module_src}) | ||
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src}) | ||
|
||
if (WIN32) | ||
if (${CMAKE_VERSION} VERSION_LESS "3.27") | ||
# this is a work-around for target_link_options inserting option after -link option, cause | ||
# linker to ignore it. | ||
set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel") | ||
endif() | ||
endif() | ||
|
||
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON) | ||
|
||
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include) | ||
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src) | ||
|
||
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS}) | ||
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR}) | ||
|
||
if (WIN32) | ||
target_compile_options(${python_module_name} PRIVATE | ||
/clang:-fno-approx-func | ||
/clang:-fno-finite-math-only | ||
) | ||
else() | ||
target_compile_options(${python_module_name} PRIVATE | ||
-fno-approx-func | ||
-fno-finite-math-only | ||
) | ||
endif() | ||
|
||
target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel) | ||
if (UNIX) | ||
# this option is support on Linux only | ||
target_link_options(${python_module_name} PUBLIC -fsycl-link-huge-device-code) | ||
endif() | ||
|
||
if (DPNP_GENERATE_COVERAGE) | ||
target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping) | ||
endif() | ||
|
||
if (MKL_VERSION_2024) | ||
target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS) | ||
else() | ||
target_link_libraries(${python_module_name} PUBLIC MKL::MKL_DPCPP) | ||
endif() | ||
|
||
install(TARGETS ${python_module_name} | ||
DESTINATION "dpnp/backend/extensions/blas" | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
// | ||
// This file defines functions of dpnp.backend._lapack_impl extensions | ||
// | ||
//***************************************************************************** | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "gemm.hpp" | ||
|
||
namespace blas_ext = dpnp::backend::ext::blas; | ||
namespace py = pybind11; | ||
|
||
// populate dispatch tables | ||
void init_dispatch_tables(void) | ||
{ | ||
blas_ext::init_gemm_dispatch_table(); | ||
blas_ext::init_gemm_batch_dispatch_table(); | ||
} | ||
|
||
PYBIND11_MODULE(_blas_impl, m) | ||
{ | ||
init_dispatch_tables(); | ||
|
||
{ | ||
m.def("_gemm", &blas_ext::gemm, | ||
"Call `gemm` from OneMKL LAPACK library to return " | ||
"the matrix-matrix product with 2-D matrices.", | ||
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), | ||
py::arg("matrixC"), py::arg("depends") = py::list()); | ||
vtavana marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} | ||
|
||
{ | ||
m.def("_gemm_batch", &blas_ext::gemm_batch, | ||
"Call `gemm_batch` from OneMKL LAPACK library to return " | ||
"the matrix-matrix product with general matrices."); | ||
vtavana marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
// dpctl tensor headers | ||
#include "utils/memory_overlap.hpp" | ||
#include "utils/type_utils.hpp" | ||
|
||
#include "gemm.hpp" | ||
#include "types_matrix.hpp" | ||
|
||
#include "dpnp_utils.hpp" | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace blas | ||
{ | ||
namespace mkl_blas = oneapi::mkl::blas; | ||
namespace py = pybind11; | ||
namespace type_utils = dpctl::tensor::type_utils; | ||
|
||
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue, | ||
oneapi::mkl::transpose, | ||
oneapi::mkl::transpose, | ||
const std::int64_t, | ||
const std::int64_t, | ||
const std::int64_t, | ||
char *, | ||
const std::int64_t, | ||
char *, | ||
const std::int64_t, | ||
char *, | ||
const std::int64_t, | ||
const std::vector<sycl::event> &); | ||
|
||
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types] | ||
[dpctl_td_ns::num_types]; | ||
|
||
template <typename Tab, typename Tc> | ||
static sycl::event gemm_impl(sycl::queue exec_q, | ||
oneapi::mkl::transpose transA, | ||
oneapi::mkl::transpose transB, | ||
const std::int64_t m, | ||
const std::int64_t n, | ||
const std::int64_t k, | ||
char *matrixA, | ||
const std::int64_t lda, | ||
char *matrixB, | ||
const std::int64_t ldb, | ||
char *resultC, | ||
const std::int64_t ldc, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
type_utils::validate_type_for_device<Tab>(exec_q); | ||
type_utils::validate_type_for_device<Tc>(exec_q); | ||
|
||
Tab *a = reinterpret_cast<Tab *>(matrixA); | ||
Tab *b = reinterpret_cast<Tab *>(matrixB); | ||
Tc *res = reinterpret_cast<Tc *>(resultC); | ||
|
||
std::stringstream error_msg; | ||
std::int64_t info = 0; | ||
bool mkl_exception_caught = false; | ||
|
||
sycl::event gemm_event; | ||
try { | ||
gemm_event = mkl_blas::row_major::gemm( | ||
exec_q, | ||
transA, // Parameter indicating whether matrix A is not | ||
// transposed ('N'), transposed ('T'), | ||
// or conjugate transposed ('C'). | ||
vtavana marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
transB, // Same as transA but for matrix B. | ||
m, // Number of rows in matrices A and C. | ||
n, // Number of columns in matrices B and C. | ||
k, // Number of columns in matrix A and rows in matrix B. | ||
Tab(1), // Scaling factor for the product of matrices A and B. | ||
a, // Pointer to matrix A. | ||
lda, // Leading dimension of matrix A, which is the | ||
// stride between successive rows (for row major | ||
// layout). | ||
b, // Pointer to matrix B. | ||
ldb, // Leading dimension of matrix B, similar to lda | ||
vtavana marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Tab(0), // Scaling factor for matrix C. | ||
res, // Pointer to matrix C, where the result is stored. | ||
ldc, // Leading dimension of matrix C. | ||
depends); | ||
} catch (oneapi::mkl::exception const &e) { | ||
error_msg | ||
<< "Unexpected MKL exception caught during gemm() call:\nreason: " | ||
<< e.what(); | ||
mkl_exception_caught = true; | ||
vtavana marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} catch (sycl::exception const &e) { | ||
error_msg << "Unexpected SYCL exception caught during gemm() call:\n" | ||
<< e.what(); | ||
info = -1; | ||
} | ||
|
||
if (info != 0 || mkl_exception_caught) // an unexpected error occurs | ||
{ | ||
throw std::runtime_error(error_msg.str()); | ||
} | ||
|
||
return gemm_event; | ||
} | ||
|
||
std::pair<sycl::event, sycl::event> | ||
gemm(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray matrixA, | ||
dpctl::tensor::usm_ndarray matrixB, | ||
dpctl::tensor::usm_ndarray resultC, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
const int matrixA_nd = matrixA.get_ndim(); | ||
const int matrixB_nd = matrixB.get_ndim(); | ||
const int resultC_nd = resultC.get_ndim(); | ||
|
||
if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) { | ||
throw py::value_error("The input matrices must be of 2 dimensions."); | ||
} | ||
|
||
// check compatibility of execution queue and allocation queue | ||
if (!dpctl::utils::queues_are_compatible( | ||
exec_q, | ||
{matrixA.get_queue(), matrixB.get_queue(), resultC.get_queue()})) | ||
{ | ||
throw std::runtime_error( | ||
"USM allocations are not compatible with the execution queue."); | ||
} | ||
|
||
bool is_matrixA_f_contig = matrixA.is_f_contiguous(); | ||
bool is_matrixB_f_contig = matrixB.is_f_contiguous(); | ||
|
||
const py::ssize_t *a_shape = matrixA.get_shape_raw(); | ||
const py::ssize_t *b_shape = matrixB.get_shape_raw(); | ||
const py::ssize_t *res_shape = resultC.get_shape_raw(); | ||
|
||
if (a_shape[1] != b_shape[0]) { | ||
throw std::runtime_error("The number of columns in A must be equal to " | ||
"the number of rows in B."); | ||
} | ||
|
||
vtavana marked this conversation as resolved.
Show resolved
Hide resolved
|
||
oneapi::mkl::transpose transA = is_matrixA_f_contig | ||
vtavana marked this conversation as resolved.
Show resolved
Hide resolved
|
||
? oneapi::mkl::transpose::T | ||
: oneapi::mkl::transpose::N; | ||
oneapi::mkl::transpose transB = is_matrixB_f_contig | ||
? oneapi::mkl::transpose::T | ||
: oneapi::mkl::transpose::N; | ||
|
||
const std::int64_t m = a_shape[0]; | ||
const std::int64_t n = b_shape[1]; | ||
const std::int64_t k = a_shape[1]; | ||
|
||
const std::int64_t lda = | ||
(transA == oneapi::mkl::transpose::N) ? a_shape[1] : a_shape[0]; | ||
const std::int64_t ldb = | ||
(transB == oneapi::mkl::transpose::N) ? b_shape[1] : b_shape[0]; | ||
const std::int64_t ldc = res_shape[1]; | ||
|
||
int matrixA_typenum = matrixA.get_typenum(); | ||
int matrixB_typenum = matrixB.get_typenum(); | ||
int resultC_typenum = resultC.get_typenum(); | ||
|
||
if (matrixA_typenum != matrixB_typenum) { | ||
throw py::value_error("matrixA and matrixB must be of the same type."); | ||
} | ||
|
||
auto array_types = dpctl_td_ns::usm_ndarray_types(); | ||
int matrixAB_type_id = array_types.typenum_to_lookup_id(matrixA_typenum); | ||
int resultC_type_id = array_types.typenum_to_lookup_id(resultC_typenum); | ||
|
||
gemm_impl_fn_ptr_t gemm_fn = | ||
gemm_dispatch_table[matrixAB_type_id][resultC_type_id]; | ||
if (gemm_fn == nullptr) { | ||
throw py::value_error("Type dispatch ran into trouble."); | ||
vtavana marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} | ||
|
||
char *a_typeless_ptr = matrixA.get_data(); | ||
char *b_typeless_ptr = matrixB.get_data(); | ||
char *r_typeless_ptr = resultC.get_data(); | ||
|
||
std::vector<sycl::event> host_task_events; | ||
sycl::event gemm_ev = | ||
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda, | ||
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends); | ||
|
||
host_task_events.push_back(gemm_ev); | ||
sycl::event args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {matrixA, matrixB, resultC}, host_task_events); | ||
vtavana marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
return std::make_pair(args_ev, gemm_ev); | ||
} | ||
|
||
template <typename fnT, typename Tab, typename Tc> | ||
struct GemmContigFactory | ||
{ | ||
fnT get() | ||
{ | ||
if constexpr (types::GemmTypePairSupportFactory<Tab, Tc>::is_defined) { | ||
return gemm_impl<Tab, Tc>; | ||
} | ||
else { | ||
return nullptr; | ||
} | ||
} | ||
}; | ||
|
||
void init_gemm_dispatch_table(void) | ||
{ | ||
dpctl_td_ns::DispatchTableBuilder<gemm_impl_fn_ptr_t, GemmContigFactory, | ||
dpctl_td_ns::num_types> | ||
contig; | ||
contig.populate_dispatch_table(gemm_dispatch_table); | ||
} | ||
} // namespace blas | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.