diff --git a/dpnp/CMakeLists.txt b/dpnp/CMakeLists.txt index ebaf1d7b0ef..dadfb9d476e 100644 --- a/dpnp/CMakeLists.txt +++ b/dpnp/CMakeLists.txt @@ -56,6 +56,7 @@ endfunction() build_dpnp_cython_ext_with_backend(dparray ${CMAKE_CURRENT_SOURCE_DIR}/dparray.pyx dpnp) add_subdirectory(backend) +add_subdirectory(backend/extensions/blas) add_subdirectory(backend/extensions/lapack) add_subdirectory(backend/extensions/vm) add_subdirectory(backend/extensions/sycl_ext) diff --git a/dpnp/backend/extensions/blas/CMakeLists.txt b/dpnp/backend/extensions/blas/CMakeLists.txt new file mode 100644 index 00000000000..2f1c2857f2f --- /dev/null +++ b/dpnp/backend/extensions/blas/CMakeLists.txt @@ -0,0 +1,82 @@ +# ***************************************************************************** +# Copyright (c) 2016-2023, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +# Python extension module `_blas_impl`, built from the SYCL sources listed below +set(python_module_name _blas_impl) +set(_module_src + ${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp +) + +pybind11_add_module(${python_module_name} MODULE ${_module_src}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src}) + +if (WIN32) + if (CMAKE_VERSION VERSION_LESS "3.27") + # this is a work-around for target_link_options inserting the option after the -link option, + # which causes the linker to ignore it. 
+ set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel") + endif() +endif() + +set_target_properties(${python_module_name} PROPERTIES POSITION_INDEPENDENT_CODE ON) + +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include) +target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src) + +target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS}) +target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR}) + +if (WIN32) + target_compile_options(${python_module_name} PRIVATE + /clang:-fno-approx-func + /clang:-fno-finite-math-only + ) +else() + target_compile_options(${python_module_name} PRIVATE + -fno-approx-func + -fno-finite-math-only + ) +endif() + +target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel) +if (UNIX) + # this option is supported on Linux only + target_link_options(${python_module_name} PUBLIC -fsycl-link-huge-device-code) +endif() + +if (DPNP_GENERATE_COVERAGE) + target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping) +endif() + +if (MKL_VERSION_2024) + target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS) +else() + target_link_libraries(${python_module_name} PUBLIC MKL::MKL_DPCPP) +endif() + +install(TARGETS ${python_module_name} + DESTINATION "dpnp/backend/extensions/blas" +) diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp new file mode 100644 index 00000000000..f3a48946019 --- /dev/null +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -0,0 +1,53 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +// This file defines functions of dpnp.backend._blas_impl extensions +// +//***************************************************************************** + +#include +#include + +#include "gemm.hpp" + +namespace blas_ext = dpnp::backend::ext::blas; +namespace py = pybind11; + +// populate the dispatch tables that the bound functions look up at call time +void init_dispatch_tables(void) +{ + blas_ext::init_gemm_dispatch_table(); +} + +PYBIND11_MODULE(_blas_impl, m) +{ + init_dispatch_tables(); + + m.def("_gemm", &blas_ext::gemm, + "Call `gemm` from OneMKL BLAS library to return " + "the matrix-matrix product with general matrices.", + py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), + py::arg("matrixC"), py::arg("depends") = py::list()); +} diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp new file mode 100644 index 00000000000..b0fe2d5c5e2 --- /dev/null +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -0,0 +1,268 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "gemm.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace blas +{ +namespace mkl_blas = oneapi::mkl::blas; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; +// signature shared by every type-specialized gemm implementation +typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue, + oneapi::mkl::transpose, + oneapi::mkl::transpose, + const std::int64_t, + const std::int64_t, + const std::int64_t, + char *, + const std::int64_t, + char *, + const std::int64_t, + char *, + const std::int64_t, + const std::vector &); +// gemm implementations indexed by [type id of A and B][type id of C] +static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types] + [dpctl_td_ns::num_types]; +// computes C = op(A) * op(B) with row-major oneMKL gemm on typed pointers +template +static sycl::event gemm_impl(sycl::queue exec_q, + oneapi::mkl::transpose transA, + oneapi::mkl::transpose transB, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + char *matrixA, + const std::int64_t ld_array_1, + char *matrixB, + const std::int64_t ld_array_2, + char *resultC, + const std::int64_t ld_result, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + type_utils::validate_type_for_device(exec_q); + + Tab *a = reinterpret_cast(matrixA); + Tab *b = reinterpret_cast(matrixB); + Tc *res = 
reinterpret_cast(resultC); + + std::stringstream error_msg; + std::int64_t info = 0; + bool mkl_exception_caught = false; + + sycl::event gemm_event; + try { + // Need to add logic to call column_major::gemm + gemm_event = mkl_blas::row_major::gemm( + exec_q, + transA, // Parameter indicating whether matrix A is not transposed + // ('N'), transposed ('T'), or conjugate transposed ('C'). + transB, // Same as transA but for matrix B. + m, // Number of rows in matrices A and C. + n, // Number of columns in matrices B and C. + k, // Number of columns in matrix A and rows in matrix B. + Tab(1), // Scaling factor for the product of matrices A and B. + a, // Pointer to matrix A. + ld_array_1, // Leading dimension of matrix A, which is the stride + // between successive rows (for row major layout). + b, // Pointer to matrix B. + ld_array_2, // Leading dimension of matrix B, similar to ld_array_1. + Tab(0), // Scaling factor for matrix C. + res, // Pointer to matrix C, where the result is stored. + ld_result, // Leading dimension of matrix C. 
+ depends); + } catch (oneapi::mkl::exception const &e) { + error_msg + << "Unexpected MKL exception caught during gemm() call:\nreason: " + << e.what(); + mkl_exception_caught = true; + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during gemm() call:\n" + << e.what(); + info = -1; + } + + if (info != 0 || mkl_exception_caught) // an unexpected error occurs + { + throw std::runtime_error(error_msg.str()); + } + + return gemm_event; +} +// validates the arrays, submits gemm; returns (keep-alive event, gemm event) +std::pair + gemm(sycl::queue exec_q, + dpctl::tensor::usm_ndarray matrixA, + dpctl::tensor::usm_ndarray matrixB, + dpctl::tensor::usm_ndarray resultC, + const std::vector &depends) +{ + const int matrixA_nd = matrixA.get_ndim(); + const int matrixB_nd = matrixB.get_ndim(); + const int resultC_nd = resultC.get_ndim(); + + // TODO: Add support for more than two-dimensional arrays + if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) { + throw py::value_error("The input matrices must be of 2 dimensions."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible( + exec_q, + {matrixA.get_queue(), matrixB.get_queue(), resultC.get_queue()})) + { + throw std::runtime_error( + "USM allocations are not compatible with the execution queue."); + } + + bool is_matrixA_c_contig = matrixA.is_c_contiguous(); + bool is_matrixB_c_contig = matrixB.is_c_contiguous(); + + bool is_matrixA_f_contig = matrixA.is_f_contiguous(); + bool is_matrixB_f_contig = matrixB.is_f_contiguous(); + + const py::ssize_t *a_shape = matrixA.get_shape_raw(); + const py::ssize_t *b_shape = matrixB.get_shape_raw(); + const py::ssize_t *res_shape = resultC.get_shape_raw(); + + if (a_shape[1] != b_shape[0]) { + throw std::runtime_error("The number of columns in A must be equal to " + "the number of rows in B."); + } + + oneapi::mkl::transpose transA = is_matrixA_f_contig + ? 
oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + oneapi::mkl::transpose transB = is_matrixB_f_contig + ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + + // gemm can describe only C- or F-contiguous inputs (transpose flag plus + // leading dimension); arbitrarily strided arrays are rejected + if (!is_matrixA_c_contig && !is_matrixA_f_contig) { + throw py::value_error( + "matrixA must be either C-contiguous or F-contiguous."); + } + + if (!is_matrixB_c_contig && !is_matrixB_f_contig) { + throw py::value_error( + "matrixB must be either C-contiguous or F-contiguous."); + } + + // the result is addressed as row-major with ld_result = res_shape[1] + if (!resultC.is_c_contiguous()) { + throw py::value_error("resultC must be C-contiguous."); + } + + const std::int64_t m = a_shape[0]; + const std::int64_t n = b_shape[1]; + const std::int64_t k = a_shape[1]; + + // the result must be exactly m x n, otherwise gemm writes out of bounds + if ((res_shape[0] != m) || (res_shape[1] != n)) { + throw py::value_error( + "The shape of resultC must be (rows of A, columns of B)."); + } + + const std::int64_t ld_array_1 = + (transA == oneapi::mkl::transpose::N) ? a_shape[1] : a_shape[0]; + const std::int64_t ld_array_2 = + (transB == oneapi::mkl::transpose::N) ? 
b_shape[1] : b_shape[0]; + const std::int64_t ld_result = res_shape[1]; + + int matrixA_typenum = matrixA.get_typenum(); + int matrixB_typenum = matrixB.get_typenum(); + int resultC_typenum = resultC.get_typenum(); + + if (matrixA_typenum != matrixB_typenum) { + throw py::value_error("matrixA and matrixB must be of the same type."); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int matrixAB_type_id = array_types.typenum_to_lookup_id(matrixA_typenum); + int resultC_type_id = array_types.typenum_to_lookup_id(resultC_typenum); + + gemm_impl_fn_ptr_t gemm_fn = + gemm_dispatch_table[matrixAB_type_id][resultC_type_id]; + if (gemm_fn == nullptr) { + throw py::value_error("Type dispatch ran into trouble."); + } + + char *a_typeless_ptr = matrixA.get_data(); + char *b_typeless_ptr = matrixB.get_data(); + char *r_typeless_ptr = resultC.get_data(); + + sycl::event gemm_ev = + gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, ld_array_1, + b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result, depends); + + // the keep-alive host task must depend on gemm_ev so the arrays outlive it + sycl::event args_ev = dpctl::utils::keep_args_alive( + exec_q, {matrixA, matrixB, resultC}, {gemm_ev}); + return std::make_pair(args_ev, gemm_ev); +} + +template +struct GemmContigFactory +{ + fnT get() + { + if constexpr (types::GemmTypePairSupportFactory::is_defined) { + return gemm_impl; + } + else { + return nullptr; + } + } +}; + +void init_gemm_dispatch_table(void) +{ + dpctl_td_ns::DispatchTableBuilder + contig; + contig.populate_dispatch_table(gemm_dispatch_table); +} +} // namespace blas +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/blas/gemm.hpp b/dpnp/backend/extensions/blas/gemm.hpp new file mode 100644 index 00000000000..bff43a0d819 --- /dev/null +++ b/dpnp/backend/extensions/blas/gemm.hpp @@ -0,0 +1,52 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights 
reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace blas +{ +extern std::pair // first: keep-alive host-task event, second: gemm event + gemm(sycl::queue exec_q, + dpctl::tensor::usm_ndarray matrixA, + dpctl::tensor::usm_ndarray matrixB, + dpctl::tensor::usm_ndarray resultC, + const std::vector &depends); +// fills the gemm dispatch table; called from the _blas_impl module initializer +extern void init_gemm_dispatch_table(void); +} // namespace blas +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/blas/types_matrix.hpp b/dpnp/backend/extensions/blas/types_matrix.hpp new file mode 100644 index 00000000000..7f2482cae85 --- /dev/null +++ b/dpnp/backend/extensions/blas/types_matrix.hpp @@ -0,0 +1,78 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +// dpctl tensor headers +#include "utils/type_dispatch.hpp" + +// dpctl namespace for operations with types +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace blas +{ +namespace types +{ +/** + * @brief Compile-time list of the (input, output) type pairs for which + * oneapi::mkl::blas::gemm is used. GemmContigFactory in gemm.cpp reads + * `is_defined` to choose between gemm_impl and nullptr for a type pair. + * + * @tparam Tab Type of arrays containing input matrices A and B. + * @tparam Tc Type of array containing output matrix C. 
+ */ +template +struct GemmTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + Tc, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + Tc, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; +} // namespace types +} // namespace blas +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index 30b6134da17..4ee98e92c56 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -40,10 +40,12 @@ """ +import dpctl import dpctl.tensor as dpt import numpy import dpnp +import dpnp.backend.extensions.blas._blas_impl as bi from dpnp.dpnp_algo import * from dpnp.dpnp_utils import * @@ -282,50 +284,110 @@ def matmul(x1, x2, out=None, **kwargs): """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False) - if x1_desc and x2_desc and not kwargs: - if x1_desc.ndim != 2 or x2_desc.ndim != 2: - pass - elif not x1_desc.ndim: - pass - elif not x2_desc.ndim: - pass - elif not x1_desc.size: - pass - elif not x2_desc.size: - pass - else: - if 0: - """ - Cost model checks - """ - - array1_size = x1_desc.size - array2_size = x2_desc.size - cost_size = 4096 # 2D array shape(64, 64) - - if (x1_desc.dtype == dpnp.float64) or ( - x1_desc.dtype == dpnp.float32 - ): - """ - Floating point types are handled via original math library better than SYCL math library - """ - cost_size = 262144 # 2D array shape(512, 512) - - if (array1_size > cost_size) and (array2_size > cost_size): - return dpnp_matmul(x1_desc, x2_desc, out) - else: - out_desc = ( - dpnp.get_dpnp_descriptor( - out, 
copy_when_nondefault_queue=False - ) - if out is not None - else None - ) - return dpnp_matmul(x1_desc, x2_desc, out_desc).get_pyobj() + # x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) + # x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False) + # if x1_desc and x2_desc and not kwargs: + # if x1_desc.ndim != 2 or x2_desc.ndim != 2: + # pass + # elif not x1_desc.ndim: + # pass + # elif not x2_desc.ndim: + # pass + # elif not x1_desc.size: + # pass + # elif not x2_desc.size: + # pass + # else: + # if 0: + # """ + # Cost model checks + # """ + + # array1_size = x1_desc.size + # array2_size = x2_desc.size + # cost_size = 4096 # 2D array shape(64, 64) + + # if (x1_desc.dtype == dpnp.float64) or ( + # x1_desc.dtype == dpnp.float32 + # ): + # """ + # Floating point types are handled via original math library better than SYCL math library + # """ + # cost_size = 262144 # 2D array shape(512, 512) + + # if (array1_size > cost_size) and (array2_size > cost_size): + # return dpnp_matmul(x1_desc, x2_desc, out) + # else: + # out_desc = ( + # dpnp.get_dpnp_descriptor( + # out, copy_when_nondefault_queue=False + # ) + # if out is not None + # else None + # ) + # return dpnp_matmul(x1_desc, x2_desc, out_desc).get_pyobj() + + # return call_origin(numpy.matmul, x1, x2, out=out, **kwargs) + + if not dpnp.is_supported_array_type(x1): + raise TypeError( + "An array must be any of supported type, but got {}".format( + type(x1) + ) + ) + + if not dpnp.is_supported_array_type(x2): + raise TypeError( + "An array must be any of supported type, but got {}".format( + type(x2) + ) + ) + + if x1.ndim != 2: + raise ValueError( + f"{x1.ndim}-dimensional array given. The input " + "array must be two-dimensional" + ) + + if x2.ndim != 2: + raise ValueError( + f"{x2.ndim}-dimensional array given. 
The input " + "array must be two-dimensional" + ) + + if x1.shape[1] != x2.shape[0]: + raise ValueError( + "Input operand 1 has a mismatch in its core dimension 0, " + "with gufunc signature (n?,k),(k,m?)->(n?,m?) " + f"(size {x1.shape[1]} is different from {x2.shape[0]})" + ) + + exec_q = dpctl.utils.get_execution_queue((x1.sycl_queue, x2.sycl_queue)) + if exec_q is None: + raise ValueError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + + # Determine the resulting type + # Now supports input arrays of float type + result = dpnp.empty( + (x1.shape[0], x2.shape[1]), dtype="float32", sycl_queue=exec_q + ) + + # x1_usm_arr = dpnp.get_usm_ndarray(x1) + # x2_usm_arr = dpnp.get_usm_ndarray(x2) + # res_usm_arr = dpnp.get_usm_ndarray(result) + + # Is it necessary to do a copy of the input arrays?! + + ht_blas_ev, _ = bi._gemm( + exec_q, x1.get_array(), x2.get_array(), result.get_array(), [] + ) + + ht_blas_ev.wait() - return call_origin(numpy.matmul, x1, x2, out=out, **kwargs) + return result def outer(x1, x2, out=None):