Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpnp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ endfunction()

# Build the `dparray` Cython extension from dparray.pyx.
# NOTE(review): build_dpnp_cython_ext_with_backend is presumably the helper
# function defined just above this hunk — confirm.
build_dpnp_cython_ext_with_backend(dparray ${CMAKE_CURRENT_SOURCE_DIR}/dparray.pyx dpnp)
# Descend into the backend and then into each backend extension directory;
# every listed directory must provide its own CMakeLists.txt.
add_subdirectory(backend)
add_subdirectory(backend/extensions/blas)
add_subdirectory(backend/extensions/lapack)
add_subdirectory(backend/extensions/vm)
add_subdirectory(backend/extensions/sycl_ext)
Expand Down
83 changes: 83 additions & 0 deletions dpnp/backend/extensions/blas/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# *****************************************************************************
# Copyright (c) 2016-2023, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# - Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************


# Name of the pybind11 extension module; it has to match the name passed to
# PYBIND11_MODULE(...) in blas_py.cpp.
set(python_module_name _blas_impl)
set(_module_src
    ${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
)

pybind11_add_module(${python_module_name} MODULE ${_module_src})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_module_src})

if (WIN32)
    if (CMAKE_VERSION VERSION_LESS "3.27")
        # Work-around: with CMake older than 3.27, target_link_options() places
        # the option after the `-link` option, which causes the linker to
        # ignore it. Pass it through CMAKE_CXX_LINK_FLAGS instead.
        set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel")
    endif()
endif()

# POSITION_INDEPENDENT_CODE is the target property; CMAKE_POSITION_INDEPENDENT_CODE
# is only the variable that initializes it and is not a valid property name.
set_target_properties(${python_module_name} PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Headers of the dpnp backend itself.
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)

# Headers of dpctl and of the dpctl tensor utilities.
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

# Disable approximate math functions and the finite-math-only assumption.
# On Windows the same flags are forwarded with the `/clang:` prefix.
if (WIN32)
    target_compile_options(${python_module_name} PRIVATE
        /clang:-fno-approx-func
        /clang:-fno-finite-math-only
    )
else()
    target_compile_options(${python_module_name} PRIVATE
        -fno-approx-func
        -fno-finite-math-only
    )
endif()

target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel)
if (UNIX)
    # This option is supported on Linux only.
    target_link_options(${python_module_name} PUBLIC -fsycl-link-huge-device-code)
endif()

# Link with coverage instrumentation when a coverage build is requested.
if (DPNP_GENERATE_COVERAGE)
    target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping)
endif()

# MKL_VERSION_2024 is set outside this file; it selects the BLAS-only SYCL
# target when available and the monolithic DPC++ target otherwise.
if (MKL_VERSION_2024)
    target_link_libraries(${python_module_name} PUBLIC MKL::MKL_SYCL::BLAS)
else()
    target_link_libraries(${python_module_name} PUBLIC MKL::MKL_DPCPP)
endif()

install(TARGETS ${python_module_name}
    DESTINATION "dpnp/backend/extensions/blas"
)
62 changes: 62 additions & 0 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************
//
// This file defines functions of dpnp.backend._blas_impl extensions
//
//*****************************************************************************

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "gemm.hpp"

namespace blas_ext = dpnp::backend::ext::blas;
namespace py = pybind11;

// populate dispatch tables
void init_dispatch_tables(void)
{
blas_ext::init_gemm_dispatch_table();
blas_ext::init_gemm_batch_dispatch_table();
}

// Entry point of the `_blas_impl` Python extension module.
// The dispatch tables are populated first so that every binding registered
// below can resolve a typed implementation as soon as it is called.
PYBIND11_MODULE(_blas_impl, m)
{
    init_dispatch_tables();

    {
        // gemm is a BLAS (not LAPACK) routine; the docstring names the
        // library that is actually called.
        m.def("_gemm", &blas_ext::gemm,
              "Call `gemm` from OneMKL BLAS library to return "
              "the matrix-matrix product with 2-D matrices.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
              py::arg("matrixC"), py::arg("depends") = py::list());
    }

    {
        // Registered without argument names, so it is positional-only from
        // Python.
        m.def("_gemm_batch", &blas_ext::gemm_batch,
              "Call `gemm_batch` from OneMKL BLAS library to return "
              "the matrix-matrix product with general matrices.");
    }
}
244 changes: 244 additions & 0 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_utils.hpp"

#include "gemm.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace blas
{
namespace mkl_blas = oneapi::mkl::blas;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

// Signature shared by all typed gemm implementations. Matrix data is passed
// as untyped `char *` so that one pointer type fits every table entry.
using gemm_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue,                       // execution queue
                    oneapi::mkl::transpose,            // transA
                    oneapi::mkl::transpose,            // transB
                    std::int64_t,                      // m
                    std::int64_t,                      // n
                    std::int64_t,                      // k
                    char *,                            // matrix A data
                    std::int64_t,                      // lda
                    char *,                            // matrix B data
                    std::int64_t,                      // ldb
                    char *,                            // result C data
                    std::int64_t,                      // ldc
                    const std::vector<sycl::event> &); // dependencies

// Lookup table indexed as [type id of A and B][type id of C].
// An entry is nullptr when the type pair has no gemm implementation.
static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
                                             [dpctl_td_ns::num_types];

/**
 * @brief Typed gemm launcher stored in the dispatch table.
 *
 * Casts the untyped data pointers to `Tab` (inputs) and `Tc` (output) and
 * submits a row-major oneMKL gemm with scaling factors 1 (for A*B) and 0
 * (for C), so the output is overwritten with the product.
 *
 * @return Event of the submitted gemm computation.
 * @throws std::runtime_error if oneMKL or SYCL throws while submitting.
 */
template <typename Tab, typename Tc>
static sycl::event gemm_impl(sycl::queue exec_q,
                             oneapi::mkl::transpose transA,
                             oneapi::mkl::transpose transB,
                             const std::int64_t m,
                             const std::int64_t n,
                             const std::int64_t k,
                             char *matrixA,
                             const std::int64_t lda,
                             char *matrixB,
                             const std::int64_t ldb,
                             char *resultC,
                             const std::int64_t ldc,
                             const std::vector<sycl::event> &depends)
{
    // Both element types have to be usable on the device behind the queue.
    type_utils::validate_type_for_device<Tab>(exec_q);
    type_utils::validate_type_for_device<Tc>(exec_q);

    Tab *lhs = reinterpret_cast<Tab *>(matrixA);
    Tab *rhs = reinterpret_cast<Tab *>(matrixB);
    Tc *out = reinterpret_cast<Tc *>(resultC);

    try {
        return mkl_blas::row_major::gemm(
            exec_q,
            transA,  // op(A): not transposed, transposed or conj-transposed
            transB,  // op(B): same choices as for A
            m,       // rows of op(A) and of C
            n,       // columns of op(B) and of C
            k,       // columns of op(A) and rows of op(B)
            Tab(1),  // alpha: scaling factor of the product A*B
            lhs,     // matrix A
            lda,     // leading dimension of A (row stride, row-major layout)
            rhs,     // matrix B
            ldb,     // leading dimension of B
            Tab(0),  // beta: scaling factor of C
            out,     // matrix C, receives the result
            ldc,     // leading dimension of C
            depends);
    } catch (oneapi::mkl::exception const &e) {
        std::stringstream what;
        what << "Unexpected MKL exception caught during gemm() call:\nreason: "
             << e.what();
        throw std::runtime_error(what.str());
    } catch (sycl::exception const &e) {
        std::stringstream what;
        what << "Unexpected SYCL exception caught during gemm() call:\n"
             << e.what();
        throw std::runtime_error(what.str());
    }
}

/**
 * @brief Compute `resultC = matrixA @ matrixB` for 2-D arrays with oneMKL gemm.
 *
 * Validates the arguments, selects the typed implementation from the
 * dispatch table and submits it to `exec_q`.
 *
 * @param exec_q   Queue the computation is submitted to.
 * @param matrixA  Left operand, shape (m, k), C- or F-contiguous.
 * @param matrixB  Right operand, shape (k, n), C- or F-contiguous, same type
 *                 as matrixA.
 * @param resultC  Output, shape (m, n), C-contiguous, must not overlap the
 *                 inputs.
 * @param depends  Events the computation has to wait for.
 *
 * @return Pair (event keeping the Python arguments alive, gemm event).
 *
 * @throws py::value_error    on wrong dimensionality, output shape, layout,
 *                            memory overlap, type mismatch or unsupported
 *                            type pair.
 * @throws std::runtime_error on incompatible queues or when the inner
 *                            dimensions of A and B differ.
 */
std::pair<sycl::event, sycl::event>
    gemm(sycl::queue exec_q,
         dpctl::tensor::usm_ndarray matrixA,
         dpctl::tensor::usm_ndarray matrixB,
         dpctl::tensor::usm_ndarray resultC,
         const std::vector<sycl::event> &depends)
{
    const int matrixA_nd = matrixA.get_ndim();
    const int matrixB_nd = matrixB.get_ndim();
    const int resultC_nd = resultC.get_ndim();

    if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) {
        throw py::value_error("The input matrices must be of 2 dimensions.");
    }

    // check compatibility of execution queue and allocation queue
    if (!dpctl::utils::queues_are_compatible(
            exec_q,
            {matrixA.get_queue(), matrixB.get_queue(), resultC.get_queue()}))
    {
        throw std::runtime_error(
            "USM allocations are not compatible with the execution queue.");
    }

    // gemm reads A and B while it writes C, so the output must not alias
    // either input.
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(matrixA, resultC) || overlap(matrixB, resultC)) {
        throw py::value_error("The input matrices and the output matrix must "
                              "not overlap in memory.");
    }

    const py::ssize_t *a_shape = matrixA.get_shape_raw();
    const py::ssize_t *b_shape = matrixB.get_shape_raw();
    const py::ssize_t *res_shape = resultC.get_shape_raw();

    if (a_shape[1] != b_shape[0]) {
        throw std::runtime_error("The number of columns in A must be equal to "
                                 "the number of rows in B.");
    }

    const std::int64_t m = a_shape[0];
    const std::int64_t n = b_shape[1];
    const std::int64_t k = a_shape[1];

    // Without this check an undersized output is written out of bounds.
    if ((res_shape[0] != m) || (res_shape[1] != n)) {
        throw py::value_error("The shape of the output matrix must be (number "
                              "of rows in A, number of columns in B).");
    }

    const bool is_matrixA_f_contig = matrixA.is_f_contiguous();
    const bool is_matrixB_f_contig = matrixB.is_f_contiguous();

    // The leading dimensions below are derived from the shapes only, which is
    // valid solely for contiguous data.
    if (!is_matrixA_f_contig && !matrixA.is_c_contiguous()) {
        throw py::value_error(
            "matrixA must be either C-contiguous or F-contiguous.");
    }
    if (!is_matrixB_f_contig && !matrixB.is_c_contiguous()) {
        throw py::value_error(
            "matrixB must be either C-contiguous or F-contiguous.");
    }
    if (!resultC.is_c_contiguous()) {
        throw py::value_error("The output matrix must be C-contiguous.");
    }

    // An F-contiguous matrix is handed to the row-major gemm as the
    // transpose of a row-major matrix.
    oneapi::mkl::transpose transA = is_matrixA_f_contig
                                        ? oneapi::mkl::transpose::T
                                        : oneapi::mkl::transpose::N;
    oneapi::mkl::transpose transB = is_matrixB_f_contig
                                        ? oneapi::mkl::transpose::T
                                        : oneapi::mkl::transpose::N;

    const std::int64_t lda =
        (transA == oneapi::mkl::transpose::N) ? a_shape[1] : a_shape[0];
    const std::int64_t ldb =
        (transB == oneapi::mkl::transpose::N) ? b_shape[1] : b_shape[0];
    const std::int64_t ldc = res_shape[1];

    int matrixA_typenum = matrixA.get_typenum();
    int matrixB_typenum = matrixB.get_typenum();
    int resultC_typenum = resultC.get_typenum();

    if (matrixA_typenum != matrixB_typenum) {
        throw py::value_error("matrixA and matrixB must be of the same type.");
    }

    auto array_types = dpctl_td_ns::usm_ndarray_types();
    int matrixAB_type_id = array_types.typenum_to_lookup_id(matrixA_typenum);
    int resultC_type_id = array_types.typenum_to_lookup_id(resultC_typenum);

    // nullptr marks an (input type, output type) pair without implementation.
    gemm_impl_fn_ptr_t gemm_fn =
        gemm_dispatch_table[matrixAB_type_id][resultC_type_id];
    if (gemm_fn == nullptr) {
        throw py::value_error("Type dispatch ran into trouble.");
    }

    char *a_typeless_ptr = matrixA.get_data();
    char *b_typeless_ptr = matrixB.get_data();
    char *r_typeless_ptr = resultC.get_data();

    std::vector<sycl::event> host_task_events;
    sycl::event gemm_ev =
        gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
                b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);

    // Keep the Python array objects alive until the computation has finished.
    host_task_events.push_back(gemm_ev);
    sycl::event args_ev = dpctl::utils::keep_args_alive(
        exec_q, {matrixA, matrixB, resultC}, host_task_events);

    return std::make_pair(args_ev, gemm_ev);
}

template <typename fnT, typename Tab, typename Tc>
struct GemmContigFactory
{
fnT get()
{
if constexpr (types::GemmTypePairSupportFactory<Tab, Tc>::is_defined) {
return gemm_impl<Tab, Tc>;
}
else {
return nullptr;
}
}
};

void init_gemm_dispatch_table(void)
{
dpctl_td_ns::DispatchTableBuilder<gemm_impl_fn_ptr_t, GemmContigFactory,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_table(gemm_dispatch_table);
}
} // namespace blas
} // namespace ext
} // namespace backend
} // namespace dpnp
Loading