Skip to content

Commit 5768442

Browse files
Example of calling MKL on usm_ndarray inputs
1 parent 335fb23 commit 5768442

File tree

8 files changed

+324
-0
lines changed

8 files changed

+324
-0
lines changed

cmake/FindDpctl.cmake

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#.rst:
2+
#
3+
# Find the include directory for ``dpctl_capi.h``, ``dpctl4pybind11.hpp``.
4+
#
5+
# This module sets the following variables:
6+
#
7+
# ``Dpctl_FOUND``
8+
# True if DPCTL was found.
9+
# ``Dpctl_INCLUDE_DIRS``
10+
# The include directories needed to use Dpctl.
11+
# ``Dpctl_VERSION``
12+
# The version of DPCTL found.
13+
#
14+
# The module will also explicitly define one cache variable:
15+
#
16+
# ``Dpctl_INCLUDE_DIR``
17+
#
18+
19+
if(NOT Dpctl_FOUND)
20+
set(_find_extra_args)
21+
if(Dpctl_FIND_REQUIRED)
22+
list(APPEND _find_extra_args REQUIRED)
23+
endif()
24+
if(Dpctl_FIND_QUIET)
25+
list(APPEND _find_extra_args QUIET)
26+
endif()
27+
find_package(PythonInterp ${_find_extra_args})
28+
find_package(PythonLibs ${_find_extra_args})
29+
30+
if(PYTHON_EXECUTABLE)
31+
execute_process(COMMAND "${PYTHON_EXECUTABLE}"
32+
-c "import dpctl; print(dpctl.get_include())"
33+
OUTPUT_VARIABLE _dpctl_include_dir
34+
OUTPUT_STRIP_TRAILING_WHITESPACE
35+
ERROR_QUIET
36+
)
37+
execute_process(COMMAND "${PYTHON_EXECUTABLE}"
38+
-c "import dpctl; print(dpctl.__version__)"
39+
OUTPUT_VARIABLE Dpctl_VERSION
40+
OUTPUT_STRIP_TRAILING_WHITESPACE
41+
ERROR_QUIET
42+
)
43+
44+
endif()
45+
endif()
46+
47+
find_path(Dpctl_INCLUDE_DIR
48+
dpctl_capi.h dpctl4pybind11.hpp dpctl_sycl_interface.h
49+
PATHS "${_dpctl_include_dir}" "${PYTHON_INCLUDE_DIR}"
50+
PATH_SUFFIXES dpctl/include
51+
)
52+
53+
set(Dpctl_INCLUDE_DIRS ${Dpctl_INCLUDE_DIR})
54+
55+
# handle the QUIETLY and REQUIRED arguments and set Dpctl_FOUND to TRUE if
56+
# all listed variables are TRUE
57+
include(FindPackageHandleStandardArgs)
58+
find_package_handle_standard_args(Dpctl
59+
REQUIRED_VARS
60+
Dpctl_INCLUDE_DIR
61+
VERSION_VAR Dpctl_VERSION
62+
)
63+
64+
mark_as_advanced(Dpctl_INCLUDE_DIR)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
cmake_minimum_required(VERSION 3.22 FATAL_ERROR)
2+
3+
project(example_use_mkl_gemm LANGUAGES CXX)
4+
set(DPCTL_CMAKE_MODULES_PATH "${CMAKE_SOURCE_DIR}/../../../cmake")
5+
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${DPCTL_CMAKE_MODULES_PATH})
6+
find_package(IntelDPCPP REQUIRED PATHS ${DPCTL_CMAKE_MODULES_PATH} NO_DEFAULT_PATH)
7+
8+
9+
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}")
10+
set(CMAKE_CXX_STANDARD 17)
11+
set(CMAKE_CXX_STANDARD_REQUIRED True)
12+
13+
# Define CMAKE_INSTALL_xxx: LIBDIR, INCLUDEDIR
14+
include(GNUInstallDirs)
15+
16+
# Fetch pybind11
17+
include(FetchContent)
18+
FetchContent_Declare(
19+
pybind11
20+
URL https://github.com/pybind/pybind11/archive/refs/tags/v2.9.0.tar.gz
21+
URL_HASH SHA256=057fb68dafd972bc13afb855f3b0d8cf0fa1a78ef053e815d9af79be7ff567cb
22+
)
23+
FetchContent_MakeAvailable(pybind11)
24+
25+
find_package(PythonExtensions REQUIRED)
26+
find_package(Dpctl REQUIRED)
27+
28+
find_library(mkl_core NAMES mkl_core PATH ${MKL_LIBRARY_DIR})
29+
find_library(mkl_sycl NAMES mkl_sycl PATH ${MKL_LIBRARY_DIR})
30+
find_library(mkl_intel_ilp64 NAMES mkl_intel_ilp64 PATH ${MKL_LIBRARY_DIR})
31+
find_library(mkl_tbb_thread NAMES mkl_tbb_thread PATH ${MKL_LIBRARY_DIR})
32+
find_library(tbb NAMES tbb PATH ${TBB_LIBRARY_DIR})
33+
34+
set(py_module_name _onemkl)
35+
36+
pybind11_add_module(${py_module_name}
37+
MODULE
38+
sycl_gemm/_onemkl.cpp
39+
)
40+
target_include_directories(${py_module_name}
41+
PUBLIC ${MKL_INCLUDE_DIR} ${TBB_INCLUDE_DIR}
42+
)
43+
target_link_libraries(${py_module_name}
44+
PUBLIC mkl_sycl mkl_intel_ilp64 mkl_tbb_thread mkl_core tbb
45+
)
46+
47+
install(TARGETS ${py_module_name} DESTINATION sycl_gemm)
48+
target_include_directories(${py_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})
49+
50+
get_target_property(_sycl_gemm_sources ${py_module_name} SOURCES)
51+
set_source_files_properties(${_sycl_gemm_sources}
52+
PROPERTIES
53+
COMPILE_OPTIONS "-O3;-Wno-deprecated-declarations"
54+
)
55+
56+
set(ignoreMe "${SKBUILD}")
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Example of SYCL built pybind11 extension
2+
3+
To build, use (assumes scikit-build and dpcpp) is installed
4+
5+
```sh
6+
python setup.py develop -- -G "Ninja" -DCMAKE_C_COMPILER:PATH=icx -DCMAKE_CXX_COMPILER:PATH=icpx -DTBB_LIBRARY_DIR=$CONDA_PREFIX/lib -DMKL_LIBRARY_DIR=${CONDA_PREFIX}/lib -DMKL_INCLUDE_DIR=${CONDA_PREFIX}/include -DTBB_INCLUDE_DIR=${CONDA_PREFIX}/include
7+
```
8+
9+
To run test suite
10+
11+
```sh
12+
python -m pytest tests
13+
```
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from skbuild import setup
2+
3+
setup(
4+
name="sycl_gemm",
5+
version="0.0.1",
6+
description="an example of SYCL-powered Python package (with pybind11)",
7+
author="Intel Scripting",
8+
license="Apache 2.0",
9+
packages=["sycl_gemm"],
10+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._onemkl import gemv
2+
3+
__all__ = ["gemv"]
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#include "dpctl4pybind11.hpp"
2+
#include <CL/sycl.hpp>
3+
#include <oneapi/mkl.hpp>
4+
#include <pybind11/pybind11.h>
5+
#include <pybind11/stl.h>
6+
7+
namespace py = pybind11;
8+
9+
/* DPCTL C-API for usm_ndarray
10+
UsmNDArray_GetData
11+
UsmNDArray_GetNDim
12+
UsmNDArray_GetShape
13+
UsmNDArray_GetStrides
14+
UsmNDArray_GetTypenum
15+
UsmNDArray_GetFlags
16+
UsmNDArray_GetQueueRef
17+
*/
18+
19+
sycl::event gemv(sycl::queue q,
20+
py::object matrix,
21+
py::object vector,
22+
py::object result,
23+
const std::vector<sycl::event> &depends = {})
24+
{
25+
PyObject *m_src = matrix.ptr();
26+
if (!PyObject_TypeCheck(m_src, &PyUSMArrayType)) {
27+
throw std::runtime_error("Matrix is not a dpctl.tensor.usm_ndarray");
28+
}
29+
30+
PyObject *v_src = vector.ptr();
31+
if (!PyObject_TypeCheck(v_src, &PyUSMArrayType)) {
32+
throw std::runtime_error("Vector is not a dpctl.tensor.usm_ndarray");
33+
}
34+
35+
PyObject *r_src = result.ptr();
36+
if (!PyObject_TypeCheck(r_src, &PyUSMArrayType)) {
37+
throw std::runtime_error("Result is not a dpctl.tensor.usm_ndarray");
38+
}
39+
40+
PyUSMArrayObject *m_usm_ary = reinterpret_cast<PyUSMArrayObject *>(m_src);
41+
PyUSMArrayObject *v_usm_ary = reinterpret_cast<PyUSMArrayObject *>(v_src);
42+
PyUSMArrayObject *r_usm_ary = reinterpret_cast<PyUSMArrayObject *>(r_src);
43+
44+
if (UsmNDArray_GetNDim(m_usm_ary) != 2 ||
45+
UsmNDArray_GetNDim(v_usm_ary) != 1 ||
46+
UsmNDArray_GetNDim(r_usm_ary) != 1)
47+
{
48+
throw std::runtime_error(
49+
"Inconsistent dimensions, expecting matrix and a vector");
50+
}
51+
52+
py::ssize_t *m_sh = UsmNDArray_GetShape(m_usm_ary);
53+
py::ssize_t n = m_sh[0];
54+
py::ssize_t m = m_sh[1];
55+
56+
py::ssize_t *v_sh = UsmNDArray_GetShape(v_usm_ary);
57+
py::ssize_t *r_sh = UsmNDArray_GetShape(r_usm_ary);
58+
if (v_sh[0] != m || r_sh[0] != n) {
59+
throw std::runtime_error("Inconsistent shapes.");
60+
}
61+
62+
int mat_flags = UsmNDArray_GetFlags(m_usm_ary);
63+
int v_flags = UsmNDArray_GetFlags(v_usm_ary);
64+
int r_flags = UsmNDArray_GetFlags(r_usm_ary);
65+
66+
if (!((mat_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
67+
(v_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
68+
(r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS))))
69+
{
70+
throw std::runtime_error("Arrays must be contiguous.");
71+
}
72+
73+
int mat_typenum = UsmNDArray_GetTypenum(m_usm_ary);
74+
int v_typenum = UsmNDArray_GetTypenum(v_usm_ary);
75+
int r_typenum = UsmNDArray_GetTypenum(r_usm_ary);
76+
77+
if ((mat_typenum != v_typenum) || (r_typenum != v_typenum) ||
78+
!((v_typenum == UAR_DOUBLE) || (v_typenum == UAR_FLOAT) ||
79+
(v_typenum == UAR_CDOUBLE) || (v_typenum == UAR_CFLOAT)))
80+
{
81+
std::cout << "Found: [" << mat_typenum << ", " << v_typenum << ", "
82+
<< r_typenum << "]" << std::endl;
83+
std::cout << "Expected: [" << UAR_DOUBLE << ", " << UAR_FLOAT << ", "
84+
<< UAR_CDOUBLE << ", " << UAR_CFLOAT << "]" << std::endl;
85+
throw std::runtime_error(
86+
"Only real and complex floating point arrays are supported.");
87+
}
88+
89+
char *mat_typeless_ptr = UsmNDArray_GetData(m_usm_ary);
90+
char *v_typeless_ptr = UsmNDArray_GetData(v_usm_ary);
91+
char *r_typeless_ptr = UsmNDArray_GetData(r_usm_ary);
92+
93+
sycl::event res_ev;
94+
if (v_typenum == UAR_DOUBLE) {
95+
using T = double;
96+
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
97+
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
98+
reinterpret_cast<T *>(mat_typeless_ptr), m,
99+
reinterpret_cast<T *>(v_typeless_ptr), 1, T(0),
100+
reinterpret_cast<T *>(r_typeless_ptr), 1, depends);
101+
res_ev = gemv_ev;
102+
}
103+
else if (v_typenum == UAR_FLOAT) {
104+
using T = float;
105+
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
106+
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
107+
reinterpret_cast<T *>(mat_typeless_ptr), m,
108+
reinterpret_cast<T *>(v_typeless_ptr), 1, T(0),
109+
reinterpret_cast<T *>(r_typeless_ptr), 1, depends);
110+
res_ev = gemv_ev;
111+
}
112+
else if (v_typenum == UAR_CDOUBLE) {
113+
using T = std::complex<double>;
114+
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
115+
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
116+
reinterpret_cast<T *>(mat_typeless_ptr), m,
117+
reinterpret_cast<T *>(v_typeless_ptr), 1, T(0),
118+
reinterpret_cast<T *>(r_typeless_ptr), 1, depends);
119+
res_ev = gemv_ev;
120+
}
121+
else if (v_typenum == UAR_CFLOAT) {
122+
using T = std::complex<float>;
123+
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
124+
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
125+
reinterpret_cast<T *>(mat_typeless_ptr), m,
126+
reinterpret_cast<T *>(v_typeless_ptr), 1, T(0),
127+
reinterpret_cast<T *>(r_typeless_ptr), 1, depends);
128+
res_ev = gemv_ev;
129+
}
130+
else {
131+
throw std::runtime_error("Type dispatch ran into trouble.");
132+
}
133+
134+
return res_ev;
135+
}
136+
137+
PYBIND11_MODULE(_onemkl, m)
138+
{
139+
// Import the dpctl extensions
140+
import_dpctl();
141+
m.def("gemv", &gemv, "Uses oneMKL to compute dot(matrix, vector)");
142+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash -x
2+
3+
export PYBIND11_INCLUDES=$(python3 -m pybind11 --includes)
4+
export DPCTL_INCLUDE_DIR=$(python -c "import dpctl; print(dpctl.get_include())")
5+
export DPCTL_LIB_DIR=${DPCTL_INCLUDE_DIR}/..
6+
export PY_EXT_SUFFIX=$(python3-config --extension-suffix)
7+
export HOST_COMPILER_FLAGS="-g -std=c++2a -O3 -Wno-return-type -Wno-deprecated-declarations -fPIC ${PYBIND11_INCLUDES} -I${DPCTL_INCLUDE_DIR}"
8+
9+
# -fsycl-host-compiler=g++ \
10+
# -fsycl-host-compiler-options="${HOST_COMPILER_FLAGS}" \
11+
12+
dpcpp -O3 -fsycl -Wno-deprecated-declarations \
13+
-fpic -fPIC -shared \
14+
${PYBIND11_INCLUDES} -I${DPCTL_INCLUDE_DIR} \
15+
sycl_gemm.cpp -o _sycl_gemm${PY_EXT_SUFFIX}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
import pytest
3+
from sycl_gemm import gemv
4+
5+
import dpctl
6+
import dpctl.tensor as dpt
7+
8+
9+
def test_gemv():
10+
try:
11+
q = dpctl.SyclQueue()
12+
except dpctl.SyclQueueCreationError:
13+
pytest.skip("Queue could not be created")
14+
Mnp, vnp = np.random.randn(5, 3), np.random.randn(3)
15+
r = dpt.empty((5,), dtype="d", sycl_queue=q)
16+
M = dpt.asarray(Mnp, sycl_queue=q)
17+
v = dpt.asarray(vnp, sycl_queue=q)
18+
ev = gemv(M.sycl_queue, M, v, r, [])
19+
ev.wait()
20+
rnp = dpt.asnumpy(r)
21+
assert np.allclose(rnp, Mnp @ vnp)

0 commit comments

Comments
 (0)