diff --git a/dpnp/backend/extensions/blas/CMakeLists.txt b/dpnp/backend/extensions/blas/CMakeLists.txt index 2f1c2857f2f..d19f60c9792 100644 --- a/dpnp/backend/extensions/blas/CMakeLists.txt +++ b/dpnp/backend/extensions/blas/CMakeLists.txt @@ -28,6 +28,7 @@ set(python_module_name _blas_impl) set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp ) pybind11_add_module(${python_module_name} MODULE ${_module_src}) diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp index f3a48946019..8232cda85c7 100644 --- a/dpnp/backend/extensions/blas/blas_py.cpp +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -39,15 +39,25 @@ namespace py = pybind11; void init_dispatch_tables(void) { blas_ext::init_gemm_dispatch_table(); + blas_ext::init_gemm_batch_dispatch_table(); } PYBIND11_MODULE(_blas_impl, m) { init_dispatch_tables(); - m.def("_gemm", &blas_ext::gemm, - "Call `gemm` from OneMKL LAPACK library to return " - "the matrix-matrix product with general matrices.", - py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), - py::arg("matrixC"), py::arg("depends") = py::list()); + { + m.def("_gemm", &blas_ext::gemm, + "Call `gemm` from OneMKL LAPACK library to return " + "the matrix-matrix product with 2-D matrices.", + py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"), + py::arg("matrixC"), py::arg("isRowMajor"), + py::arg("depends") = py::list()); + } + + { + m.def("_gemm_batch", &blas_ext::gemm_batch, + "Call `gemm_batch` from OneMKL LAPACK library to return " + "the matrix-matrix product with general matrices."); + } } diff --git a/dpnp/backend/extensions/blas/gemm.cpp b/dpnp/backend/extensions/blas/gemm.cpp index b0fe2d5c5e2..b217cdd9025 100644 --- a/dpnp/backend/extensions/blas/gemm.cpp +++ b/dpnp/backend/extensions/blas/gemm.cpp @@ -58,6 +58,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue, const std::int64_t, char *, const std::int64_t, + const bool, const std::vector &); static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types] @@ -76,6 +77,7 @@ static sycl::event gemm_impl(sycl::queue exec_q, const std::int64_t ld_array_2, char *resultC, const std::int64_t ld_result, + const bool isRowMajor, const std::vector &depends) { type_utils::validate_type_for_device(exec_q); @@ -92,24 +94,54 @@ static sycl::event gemm_impl(sycl::queue exec_q, sycl::event gemm_event; try { // Need to add logic to call column_major::gemm - gemm_event = mkl_blas::row_major::gemm( - exec_q, - transA, // Parameter indicating whether matrix A is not transposed - // ('N'), transposed ('T'), or conjugate transposed ('C'). - transB, // Same as transA but for matrix B. - m, // Number of rows in matrices A and C. - n, // Number of columns in matrices B and C. - k, // Number of columns in matrix A and rows in matrix B. - Tab(1), // Scaling factor for the product of matrices A and B. - a, // Pointer to matrix A. - ld_array_1, // Leading dimension of matrix A, which is the stride - // between successive rows (for row major layout). - b, // Pointer to matrix B. - ld_array_2, // Leading dimension of matrix B, similar to ld_array_1. - Tab(0), // Scaling factor for matrix C. - res, // Pointer to matrix C, where the result is stored. - ld_result, // Leading dimension of matrix C. - depends); + if (isRowMajor) { + gemm_event = mkl_blas::row_major::gemm( + exec_q, + transA, // Parameter indicating whether matrix A is not + // transposed + // ('N'), transposed ('T'), or conjugate transposed + // ('C'). + transB, // Same as transA but for matrix B. + m, // Number of rows in matrices A and C. + n, // Number of columns in matrices B and C. + k, // Number of columns in matrix A and rows in matrix B. + Tab(1), // Scaling factor for the product of matrices A and B. + a, // Pointer to matrix A. + ld_array_1, // Leading dimension of matrix A, which is the + // stride between successive rows (for row major + // layout). + b, // Pointer to matrix B. + ld_array_2, // Leading dimension of matrix B, similar to + // ld_array_1. + Tab(0), // Scaling factor for matrix C. + res, // Pointer to matrix C, where the result is stored. + ld_result, // Leading dimension of matrix C. + depends); + } + else { + gemm_event = mkl_blas::column_major::gemm( + exec_q, + transA, // Parameter indicating whether matrix A is not + // transposed + // ('N'), transposed ('T'), or conjugate transposed + // ('C'). + transB, // Same as transA but for matrix B. + m, // Number of rows in matrices A and C. + n, // Number of columns in matrices B and C. + k, // Number of columns in matrix A and rows in matrix B. + Tab(1), // Scaling factor for the product of matrices A and B. + a, // Pointer to matrix A. + ld_array_1, // Leading dimension of matrix A, which is the + // stride between successive rows (for row major + // layout). + b, // Pointer to matrix B. + ld_array_2, // Leading dimension of matrix B, similar to + // ld_array_1. + Tab(0), // Scaling factor for matrix C. + res, // Pointer to matrix C, where the result is stored. + ld_result, // Leading dimension of matrix C. + depends); + } } catch (oneapi::mkl::exception const &e) { error_msg << "Unexpected MKL exception caught during gemm() call:\nreason: " @@ -134,6 +166,7 @@ std::pair dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, + const bool isRowMajor, const std::vector &depends) { const int matrixA_nd = matrixA.get_ndim(); @@ -234,7 +267,8 @@ std::pair std::vector host_task_events; sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, ld_array_1, - b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result, depends); + b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result, + isRowMajor, depends); sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {matrixA, matrixB, resultC}, host_task_events); diff --git a/dpnp/backend/extensions/blas/gemm.hpp b/dpnp/backend/extensions/blas/gemm.hpp index bff43a0d819..9eaa2f4ea22 100644 --- a/dpnp/backend/extensions/blas/gemm.hpp +++ b/dpnp/backend/extensions/blas/gemm.hpp @@ -43,9 +43,31 @@ extern std::pair dpctl::tensor::usm_ndarray matrixA, dpctl::tensor::usm_ndarray matrixB, dpctl::tensor::usm_ndarray resultC, + const bool isRowMajor, const std::vector &depends); +// extern sycl::event +extern std::pair + gemm_batch(sycl::queue q, + dpctl::tensor::usm_ndarray matrixA, + dpctl::tensor::usm_ndarray matrixB, + dpctl::tensor::usm_ndarray resultC, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + const std::int64_t batch_size, + const std::int64_t ld_array_1, + const std::int64_t ld_array_2, + const std::int64_t ld_result, + size_t stridea, + size_t strideb, + size_t stridec, + const std::int64_t transA_int, + const std::int64_t transB_int, + const std::vector &depends); + extern void init_gemm_dispatch_table(void); +extern void init_gemm_batch_dispatch_table(void); } // namespace blas } // namespace ext } // namespace backend diff --git a/dpnp/backend/extensions/blas/gemm_batch.cpp b/dpnp/backend/extensions/blas/gemm_batch.cpp new file mode 100644 index 00000000000..b6cf878f74f --- /dev/null +++ b/dpnp/backend/extensions/blas/gemm_batch.cpp @@ -0,0 +1,225 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "gemm.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace blas +{ +namespace mkl_blas = oneapi::mkl::blas; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( + sycl::queue, + const std::int64_t, + const std::int64_t, + const std::int64_t, + const std::int64_t, + const std::int64_t, + const std::int64_t, + const std::int64_t, + size_t, + size_t, + size_t, + oneapi::mkl::transpose, + oneapi::mkl::transpose, + char *, + char *, + char *, + const std::vector &); + +static gemm_batch_impl_fn_ptr_t + gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types]; + +template +static sycl::event gemm_batch_impl(sycl::queue exec_q, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + const std::int64_t batch_size, + const std::int64_t ld_array_1, + const std::int64_t ld_array_2, + const std::int64_t ld_result, + size_t stridea, + size_t strideb, + size_t stridec, + oneapi::mkl::transpose transA, + oneapi::mkl::transpose transB, + char *matrixA, + char *matrixB, + char *resultC, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + type_utils::validate_type_for_device(exec_q); + + Tab *a = reinterpret_cast(matrixA); + Tab *b = reinterpret_cast(matrixB); + Tc *res = reinterpret_cast(resultC); + + std::stringstream error_msg; + std::int64_t info = 0; + bool mkl_exception_caught = false; + + sycl::event gemm_batch_event; + try { + // Need to add logic to call column_major::gemm + gemm_batch_event = oneapi::mkl::blas::row_major::gemm_batch( + exec_q, transA, transB, m, n, k, Tab(1), a, ld_array_1, stridea, b, + ld_array_2, strideb, Tab(0), res, ld_result, stridec, batch_size, + depends); + } catch (oneapi::mkl::exception const &e) { + error_msg + << "Unexpected MKL exception caught during gemm() call:\nreason: " + << e.what(); + mkl_exception_caught = true; + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during gemm() call:\n" + << e.what(); + info = -1; + } + + if (info != 0 || mkl_exception_caught) // an unexpected error occurs + { + throw std::runtime_error(error_msg.str()); + } + + return gemm_batch_event; +} + +// std::pair +// sycl::event +std::pair + gemm_batch(sycl::queue exec_q, + dpctl::tensor::usm_ndarray matrixA, + dpctl::tensor::usm_ndarray matrixB, + dpctl::tensor::usm_ndarray resultC, + const std::int64_t m, + const std::int64_t n, + const std::int64_t k, + const std::int64_t batch_size, + const std::int64_t ld_array_1, + const std::int64_t ld_array_2, + const std::int64_t ld_result, + size_t stridea, + size_t strideb, + size_t stridec, + const std::int64_t transA_int, + const std::int64_t transB_int, + const std::vector &depends = {}) +{ + if (!dpctl::utils::queues_are_compatible( + exec_q, + {matrixA.get_queue(), matrixB.get_queue(), resultC.get_queue()})) + { + throw std::runtime_error( + "USM allocations are not compatible with the execution queue."); + } + + oneapi::mkl::transpose transA = (transA_int == 1) + ? oneapi::mkl::transpose::N + : oneapi::mkl::transpose::T; + oneapi::mkl::transpose transB = (transB_int == 1) + ? oneapi::mkl::transpose::N + : oneapi::mkl::transpose::T; + + int matrixA_typenum = matrixA.get_typenum(); + int matrixB_typenum = matrixB.get_typenum(); + int resultC_typenum = resultC.get_typenum(); + + if (matrixA_typenum != matrixB_typenum) { + throw py::value_error("matrixA and matrixB must be of the same type."); + } + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int matrixAB_type_id = array_types.typenum_to_lookup_id(matrixA_typenum); + int resultC_type_id = array_types.typenum_to_lookup_id(resultC_typenum); + + gemm_batch_impl_fn_ptr_t gemm_batch_fn = + gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id]; + if (gemm_batch_fn == nullptr) { + throw py::value_error("Type dispatch ran into trouble."); + } + + char *a_typeless_ptr = matrixA.get_data(); + char *b_typeless_ptr = matrixB.get_data(); + char *r_typeless_ptr = resultC.get_data(); + + std::vector host_task_events; + // sycl::event res_ev; + sycl::event gemm_batch_ev = + gemm_batch_fn(exec_q, m, n, k, batch_size, ld_array_1, ld_array_2, + ld_result, stridea, strideb, stridec, transA, transB, + a_typeless_ptr, b_typeless_ptr, r_typeless_ptr, depends); + + // res_ev = gemm_batch_ev; + + // return res_ev; + sycl::event args_batch_ev = dpctl::utils::keep_args_alive( + exec_q, {matrixA, matrixB, resultC}, host_task_events); + return std::make_pair(args_batch_ev, gemm_batch_ev); +} + +template +struct GemmBatchContigFactory +{ + fnT get() + { + if constexpr (types::GemmBatchTypePairSupportFactory::is_defined) { + return gemm_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_gemm_batch_dispatch_table(void) +{ + dpctl_td_ns::DispatchTableBuilder + contig; + contig.populate_dispatch_table(gemm_batch_dispatch_table); +} +} // namespace blas +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/blas/types_matrix.hpp b/dpnp/backend/extensions/blas/types_matrix.hpp index 7f2482cae85..f2191fc6e2b 100644 --- a/dpnp/backend/extensions/blas/types_matrix.hpp +++ b/dpnp/backend/extensions/blas/types_matrix.hpp @@ -60,6 +60,29 @@ struct GemmTypePairSupportFactory dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + Tc, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + Tc, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct GemmBatchTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, dpctl_td_ns::TypePairDefinedEntry, Tc, diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index 4ee98e92c56..1fcfe8ef0ee 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -42,11 +42,13 @@ import dpctl import dpctl.tensor as dpt +import dpctl.tensor._tensor_impl as ti import numpy import dpnp import dpnp.backend.extensions.blas._blas_impl as bi from dpnp.dpnp_algo import * +from dpnp.dpnp_array import dpnp_array from dpnp.dpnp_utils import * __all__ = [ @@ -284,110 +286,214 @@ def matmul(x1, x2, out=None, **kwargs): """ - # x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - # x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False) - # if x1_desc and x2_desc and not kwargs: - # if x1_desc.ndim != 2 or x2_desc.ndim != 2: - # pass - # elif not x1_desc.ndim: - # pass - # elif not x2_desc.ndim: - # pass - # elif not x1_desc.size: - # pass - # elif not x2_desc.size: - # pass - # else: - # if 0: - # """ - # Cost model checks - # """ - - # array1_size = x1_desc.size - # array2_size = x2_desc.size - # cost_size = 4096 # 2D array shape(64, 64) - - # if (x1_desc.dtype == dpnp.float64) or ( - # x1_desc.dtype == dpnp.float32 - # ): - # """ - # Floating point types are handled via original math library better than SYCL math library - # """ - # cost_size = 262144 # 2D array shape(512, 512) - - # if (array1_size > cost_size) and (array2_size > cost_size): - # return dpnp_matmul(x1_desc, x2_desc, out) - # else: - # out_desc = ( - # dpnp.get_dpnp_descriptor( - # out, copy_when_nondefault_queue=False - # ) - # if out is not None - # else None - # ) - # return dpnp_matmul(x1_desc, x2_desc, out_desc).get_pyobj() - - # return call_origin(numpy.matmul, x1, x2, out=out, **kwargs) - - if not dpnp.is_supported_array_type(x1): - raise TypeError( - "An array must be any of supported type, but got {}".format( - type(x1) - ) - ) - - if not dpnp.is_supported_array_type(x2): - raise TypeError( - "An array must be any of supported type, but got {}".format( - type(x2) - ) - ) + x1_ndim = x1.ndim + x2_ndim = x2.ndim - if x1.ndim != 2: + if x1_ndim == 0 or x2_ndim == 0: raise ValueError( - f"{x1.ndim}-dimensional array given. The input " - "array must be two-dimensional" + "matmul: Input operand does not have enough dimensions" ) - if x2.ndim != 2: + exec_q = dpctl.utils.get_execution_queue((x1.sycl_queue, x2.sycl_queue)) + if exec_q is None: raise ValueError( - f"{x2.ndim}-dimensional array given. The input " - "array must be two-dimensional" + "Execution placement can not be unambiguously inferred " + "from input arguments." ) - if x1.shape[1] != x2.shape[0]: + squeeze_flag = x1_ndim == 1 or x2_ndim == 1 + if x1_ndim == 1: + x1 = x1[dpnp.newaxis, :] + x1_ndim = x1.ndim + + if x2_ndim == 1: + x2 = x2[:, dpnp.newaxis] + x2_ndim = x2.ndim + + x1_shape = x1.shape + x2_shape = x2.shape + if x1_shape[-1] != x2_shape[-2]: raise ValueError( "Input operand 1 has a mismatch in its core dimension 0, " "with gufunc signature (n?,k),(k,m?)->(n?,m?) " - f"(size {x1.shape[1]} is different from {x2.shape[0]})" + f"(size {x1_shape[1]} is different from {x2_shape[0]})" ) - exec_q = dpctl.utils.get_execution_queue((x1.sycl_queue, x2.sycl_queue)) - if exec_q is None: + # Determine the result data type # should be corrected for integer data type # VAHID + res_dtype = _common_type(x1, x2) + if x1.dtype != res_dtype: + x1 = dpnp.astype(x1, res_dtype) + if x2.dtype != res_dtype: + x2 = dpnp.astype(x2, res_dtype) + + if x1_ndim == 2 and x2_ndim == 2: + res_shape = (x1.shape[0], x2.shape[1]) + else: + if x1_ndim != x2_ndim: + diff = abs(x1_ndim - x2_ndim) + + if x1_ndim < x2_ndim: + x1 = x1.reshape((1,) * diff + x1.shape) + x1_ndim = x1.ndim + x1_shape = x1.shape + res_shape = x2_shape[:-2] + (x1_shape[-2], x2_shape[-1]) + else: + x2 = x2.reshape((1,) * diff + x2.shape) + x2_ndim = x2.ndim + x2_shape = x2.shape + res_shape = x1_shape[:-2] + (x1_shape[-2], x2_shape[-1]) + else: + for i in range(x1_ndim - 2): + if x1_shape[i] != x2_shape[i]: + if x1_shape[i] == 1: + x1 = dpnp.repeat(x1, x2_shape[i], axis=i) + elif x2_shape[i] == 1: + x2 = dpnp.repeat(x2, x1_shape[i], axis=i) + else: + raise ValueError( + "operands could not be broadcast together with remapped shapes." + ) + x1_shape = x1.shape + x2_shape = x2.shape + res_shape = x1_shape[:-1] + (x2_shape[-1],) + + result = dpnp.empty(res_shape, dtype=res_dtype, sycl_queue=exec_q) + # Is it necessary to do a copy of the input arrays?! + isRowMajor = True + if result.size == 0: + pass + else: + if x1.size == 0 or x2.size == 0: + result = dpnp.zeros(res_shape, dtype=res_dtype, sycl_queue=exec_q) + else: + if x1_ndim == 2 and x2_ndim == 2: + ht_blas_ev, _ = bi._gemm( + exec_q, + dpnp.get_usm_ndarray(x1), + dpnp.get_usm_ndarray(x2), + dpnp.get_usm_ndarray(result), + isRowMajor, + [], + ) + else: + # if_a_f_contig = a.flags["F_CONTIGUOUS"] + # if_b_f_contig = b.flags["F_CONTIGUOUS"] + # if_out_f_contig = out.flags["F_CONTIGUOUS"] + + # x1_strides = a.strides if not if_a_f_contig else a.strides[::-1] + # x2_strides = b.strides if not if_b_f_contig else b.strides[::-1] + # res_strides = out.strides if not if_out_f_contig else out.strides[::-1] + + x1_strides = x1.strides + x2_strides = x2.strides + res_strides = result.strides + + is_support_gemm(x1_strides, x1_ndim) + is_support_gemm(x2_strides, x2_ndim) + + transa = is_row(x1_strides, x1_ndim) + transb = is_row(x2_strides, x2_ndim) + + batch_size = res_shape[:-2][0] # VAHID + m = x1_shape[-2] + n = x2_shape[-1] + k = x1_shape[-1] + + # lda = max(x1_shape[-2:]) + # ldb = max(x2_shape[-2:]) + # ldc = max(res_shape[-2:]) + lda = k if transa else m + ldb = n if transb else k + ldc = n # column major m, row major n # VAHID + + stridea = x1_strides[0] + strideb = x2_strides[0] + stridec = res_strides[-3] + + if x1_ndim > 3: + iter = ti._contract_iter2( + res_shape[:-2], x1_strides[:-2], x2_strides[:-2] + ) + if len(iter[0]) != 1: + raise ValueError( + "Input arrays cannot be used in gemm_batch" + ) + batch_size = iter[0][0] + stridea = iter[1][0] + strideb = iter[3][0] + + ht_blas_ev, _ = bi._gemm_batch( + exec_q, + dpnp.get_usm_ndarray(x1), + dpnp.get_usm_ndarray(x2), + dpnp.get_usm_ndarray(result), + m, + n, + k, + batch_size, + lda, + ldb, + ldc, + stridea, + strideb, + stridec, + transa, + transb, + [], + ) + + ht_blas_ev.wait() + + if squeeze_flag: + result = dpnp.squeeze(result) + + if out is None: + return result + else: + if out.shape != result.shape: + raise ValueError( + f"Output array of shape {result.shape} is needed, got {out.shape}." + ) + elif not isinstance(out, dpnp_array): + if isinstance(out, dpt.usm_ndarray): + out = dpnp_array._create_from_usm_ndarray(out) + else: + raise TypeError( + "Output array must be any of supported type, but got {}".format( + type(out) + ) + ) + + dpnp.copyto(out, result, casting="safe") + + return out + + +def is_support_gemm(strides, ndim): + if strides[ndim - 1] != 1 and strides[ndim - 2] != 1: raise ValueError( - "Execution placement can not be unambiguously inferred " - "from input arguments." + "The input matrices must be contiguous on inner dimension." ) - # Determine the resulting type - # Now supports input arrays of float type - result = dpnp.empty( - (x1.shape[0], x2.shape[1]), dtype="float32", sycl_queue=exec_q - ) - # x1_usm_arr = dpnp.get_usm_ndarray(x1) - # x2_usm_arr = dpnp.get_usm_ndarray(x2) - # res_usm_arr = dpnp.get_usm_ndarray(result) +def is_row(strides, ndim): + return strides[ndim - 1] == 1 - # Is it necessary to do a copy of the input arrays?! - ht_blas_ev, _ = bi._gemm( - exec_q, x1.get_array(), x2.get_array(), result.get_array(), [] - ) +def _common_type(*arrays): + dtypes = [arr.dtype for arr in arrays] + + default = dpnp.default_float_type().name + dtype_common = _common_type_internal(default, *dtypes) + + return dtype_common - ht_blas_ev.wait() - return result +def _common_type_internal(default_dtype, *dtypes): + inexact_dtypes = [ + dtype if dtype.kind in "fc" else default_dtype for dtype in dtypes + ] + return dpnp.result_type(*inexact_dtypes) def outer(x1, x2, out=None): diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 89a09a7dc29..f2dab08346e 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2044,3 +2044,52 @@ def test_inplace_remainder(dtype): dp_a %= 4 assert_allclose(dp_a, np_a) + + +@pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True) +) +def test_inplace_floor_divide(dtype): + size = 21 + np_a = numpy.arange(size, dtype=dtype) + dp_a = dpnp.arange(size, dtype=dtype) + + np_a //= 4 + dp_a //= 4 + + assert_allclose(dp_a, np_a) + + +@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) +@pytest.mark.parametrize( + "shape_pair", + [ + ((4,), (4, 4)), + ((4, 4), (4,)), + ((4, 4, 4, 4), (4, 4)), + ((4, 4), (4, 4, 4, 4)), + ((4,), (4,)), + ((4, 4, 4), (4, 4, 4)), + ], + ids=[ + "((4,), (4, 4))", + "((4, 4), (4,))", + "((4, 4, 4, 4), (4, 4))", + "((4, 4), (4, 4, 4, 4))", + "((4,), (4,))", + "((4, 4, 4), (4, 4, 4))", + ], +) +def test_matmul(dtype, shape_pair): + shape1, shape2 = shape_pair + size1 = numpy.prod(shape1) + size2 = numpy.prod(shape2) + a1 = numpy.arange(size1, dtype=dtype).reshape(shape1) + a2 = numpy.arange(size2, dtype=dtype).reshape(shape2) + + b1 = dpnp.asarray(a1) + b2 = dpnp.asarray(a2) + + result = dpnp.matmul(b1, b2) + expected = numpy.matmul(a1, a2) + assert_allclose(expected, result) diff --git a/tests/third_party/cupy/math_tests/test_matmul.py b/tests/third_party/cupy/math_tests/test_matmul.py index d0f3555373a..d4aef7e46d7 100644 --- a/tests/third_party/cupy/math_tests/test_matmul.py +++ b/tests/third_party/cupy/math_tests/test_matmul.py @@ -25,33 +25,33 @@ ((0,), (0,)), # matmul test ((5, 3, 2), (5, 2, 4)), - # ((0, 3, 2), (0, 2, 4)), - # ((5, 3, 2), (2, 4)), - # ((0, 3, 2), (2, 4)), - # ((3, 2), (5, 2, 4)), - # ((3, 2), (0, 2, 4)), - # ((5, 3, 2), (1, 2, 4)), - # ((0, 3, 2), (1, 2, 4)), - # ((1, 3, 2), (5, 2, 4)), - # ((1, 3, 2), (0, 2, 4)), - # ((5, 3, 2), (2,)), - # ((5, 3, 0), (0,)), - # ((2,), (5, 2, 4)), - # ((0,), (5, 0, 4)), - # ((2, 2, 3, 2), (2, 2, 2, 4)), - # ((5, 0, 3, 2), (5, 0, 2, 4)), - # ((6, 5, 3, 2), (2, 4)), - # ((5, 0, 3, 2), (2, 4)), - # ((3, 2), (6, 5, 2, 4)), - # ((3, 2), (5, 0, 2, 4)), - # ((1, 5, 3, 2), (6, 1, 2, 4)), - # ((1, 0, 3, 2), (6, 1, 2, 4)), - # ((6, 1, 3, 2), (1, 5, 2, 4)), - # ((6, 1, 3, 2), (1, 0, 2, 4)), - # ((6, 5, 3, 2), (2,)), - # ((6, 5, 3, 0), (0,)), - # ((2,), (6, 5, 2, 4)), - # ((0,), (6, 5, 0, 4)), + ((0, 3, 2), (0, 2, 4)), + ((5, 3, 2), (2, 4)), + ((0, 3, 2), (2, 4)), + ((3, 2), (5, 2, 4)), + ((3, 2), (0, 2, 4)), + ((5, 3, 2), (1, 2, 4)), + ((0, 3, 2), (1, 2, 4)), + ((1, 3, 2), (5, 2, 4)), + ((1, 3, 2), (0, 2, 4)), + ((5, 3, 2), (2,)), + ((5, 3, 0), (0,)), + ((2,), (5, 2, 4)), + ((0,), (5, 0, 4)), + ((2, 2, 3, 2), (2, 2, 2, 4)), + ((5, 0, 3, 2), (5, 0, 2, 4)), + ((6, 5, 3, 2), (2, 4)), + ((5, 0, 3, 2), (2, 4)), + ((3, 2), (6, 5, 2, 4)), + ((3, 2), (5, 0, 2, 4)), + ((1, 5, 3, 2), (6, 1, 2, 4)), + ((1, 0, 3, 2), (6, 1, 2, 4)), + ((6, 1, 3, 2), (1, 5, 2, 4)), + ((6, 1, 3, 2), (1, 0, 2, 4)), + ((6, 5, 3, 2), (2,)), + ((6, 5, 3, 0), (0,)), + ((2,), (6, 5, 2, 4)), + ((0,), (6, 5, 0, 4)), ((1, 3, 3), (10, 1, 3, 1)), ], } @@ -61,14 +61,18 @@ @testing.gpu class TestMatmul(unittest.TestCase): @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=False + ) # required for uint8 def test_operator_matmul(self, xp, dtype1): x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1) x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype1) return operator.matmul(x1, x2) @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=False + ) # required for uint8 def test_cupy_matmul(self, xp, dtype1): x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1) x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype1) @@ -110,7 +114,9 @@ class TestMatmulLarge(unittest.TestCase): } @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=False + ) # required for uint8 def test_operator_matmul(self, xp, dtype1): if (dtype1, dtype1) in self.skip_dtypes or ( dtype1, @@ -122,7 +128,9 @@ def test_operator_matmul(self, xp, dtype1): return operator.matmul(x1, x2) @testing.for_all_dtypes(name="dtype1") - @testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8 + @testing.numpy_cupy_allclose( + rtol=1e-3, atol=1e-3, type_check=False + ) # required for uint8 def test_cupy_matmul(self, xp, dtype1): if (dtype1, dtype1) in self.skip_dtypes or ( dtype1, @@ -151,7 +159,6 @@ def test_cupy_matmul(self, xp, dtype1): } ) ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.gpu class TestMatmulInvalidShape(unittest.TestCase): def test_invalid_shape(self):