Skip to content

Commit

Permalink
Add support for N-D array
Browse files Browse the repository at this point in the history
add N-dimension
  • Loading branch information
vtavana committed Nov 19, 2023
1 parent 3444816 commit b8f7f00
Show file tree
Hide file tree
Showing 9 changed files with 617 additions and 140 deletions.
1 change: 1 addition & 0 deletions dpnp/backend/extensions/blas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ set(python_module_name _blas_impl)
# Source files compiled into the pybind11 extension module created by the
# pybind11_add_module() call below; gemm_batch.cpp is listed alongside gemm.cpp.
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
)

# Build the sources above as a loadable Python module named ${python_module_name}.
pybind11_add_module(${python_module_name} MODULE ${_module_src})
Expand Down
20 changes: 15 additions & 5 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,25 @@ namespace py = pybind11;
// Initialize every dispatch table this extension module relies on.
// Invoked from the PYBIND11_MODULE body before any function is bound.
void init_dispatch_tables()
{
    // gemm table first, then the gemm_batch table.
    blas_ext::init_gemm_dispatch_table();
    blas_ext::init_gemm_batch_dispatch_table();
}

// Python module `_blas_impl`: exposes the BLAS extension entry points.
//
// Bound functions:
//   _gemm       -> blas_ext::gemm        (single matrix-matrix product)
//   _gemm_batch -> blas_ext::gemm_batch  (batched matrix-matrix product)
//
// Every argument is named so callers may pass keywords, and `depends`
// defaults to an empty list for both functions.
PYBIND11_MODULE(_blas_impl, m)
{
    // Initialize the dispatch tables before any binding is registered.
    init_dispatch_tables();

    {
        // Single registration of `_gemm`: the argument annotations must match
        // the six parameters of blas_ext::gemm (including `isRowMajor`),
        // otherwise pybind11 rejects the binding at compile time.
        m.def("_gemm", &blas_ext::gemm,
              "Call `gemm` from OneMKL BLAS library to return "
              "the matrix-matrix product with 2-D matrices.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
              py::arg("matrixC"), py::arg("isRowMajor"),
              py::arg("depends") = py::list());
    }

    {
        // Argument names follow the order of the gemm_batch declaration in
        // gemm.hpp; positional calls keep working unchanged.
        m.def("_gemm_batch", &blas_ext::gemm_batch,
              "Call `gemm_batch` from OneMKL BLAS library to return "
              "the matrix-matrix product with general matrices.",
              py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
              py::arg("matrixC"), py::arg("m"), py::arg("n"), py::arg("k"),
              py::arg("batch_size"), py::arg("ld_array_1"),
              py::arg("ld_array_2"), py::arg("ld_result"),
              py::arg("stridea"), py::arg("strideb"), py::arg("stridec"),
              py::arg("transA"), py::arg("transB"),
              py::arg("depends") = py::list());
    }
}
72 changes: 53 additions & 19 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
const std::int64_t,
char *,
const std::int64_t,
const bool,
const std::vector<sycl::event> &);

static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
Expand All @@ -76,6 +77,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
const std::int64_t ld_array_2,
char *resultC,
const std::int64_t ld_result,
const bool isRowMajor,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
Expand All @@ -92,24 +94,54 @@ static sycl::event gemm_impl(sycl::queue exec_q,
sycl::event gemm_event;
try {
// Dispatch to the row-major or column-major oneMKL gemm based on isRowMajor.
gemm_event = mkl_blas::row_major::gemm(
exec_q,
transA, // Parameter indicating whether matrix A is not transposed
// ('N'), transposed ('T'), or conjugate transposed ('C').
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
ld_array_1, // Leading dimension of matrix A, which is the stride
// between successive rows (for row major layout).
b, // Pointer to matrix B.
ld_array_2, // Leading dimension of matrix B, similar to ld_array_1.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ld_result, // Leading dimension of matrix C.
depends);
if (isRowMajor) {
gemm_event = mkl_blas::row_major::gemm(
exec_q,
transA, // Parameter indicating whether matrix A is not
// transposed
// ('N'), transposed ('T'), or conjugate transposed
// ('C').
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
ld_array_1, // Leading dimension of matrix A, which is the
// stride between successive rows (for row major
// layout).
b, // Pointer to matrix B.
ld_array_2, // Leading dimension of matrix B, similar to
// ld_array_1.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ld_result, // Leading dimension of matrix C.
depends);
}
else {
gemm_event = mkl_blas::column_major::gemm(
exec_q,
transA, // Parameter indicating whether matrix A is not
// transposed
// ('N'), transposed ('T'), or conjugate transposed
// ('C').
transB, // Same as transA but for matrix B.
m, // Number of rows in matrices A and C.
n, // Number of columns in matrices B and C.
k, // Number of columns in matrix A and rows in matrix B.
Tab(1), // Scaling factor for the product of matrices A and B.
a, // Pointer to matrix A.
ld_array_1, // Leading dimension of matrix A, which is the
// stride between successive columns (for column
// major layout).
b, // Pointer to matrix B.
ld_array_2, // Leading dimension of matrix B, similar to
// ld_array_1.
Tab(0), // Scaling factor for matrix C.
res, // Pointer to matrix C, where the result is stored.
ld_result, // Leading dimension of matrix C.
depends);
}
} catch (oneapi::mkl::exception const &e) {
error_msg
<< "Unexpected MKL exception caught during gemm() call:\nreason: "
Expand All @@ -134,6 +166,7 @@ std::pair<sycl::event, sycl::event>
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const bool isRowMajor,
const std::vector<sycl::event> &depends)
{
const int matrixA_nd = matrixA.get_ndim();
Expand Down Expand Up @@ -234,7 +267,8 @@ std::pair<sycl::event, sycl::event>
std::vector<sycl::event> host_task_events;
sycl::event gemm_ev =
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, ld_array_1,
b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result, depends);
b_typeless_ptr, ld_array_2, r_typeless_ptr, ld_result,
isRowMajor, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, host_task_events);
Expand Down
22 changes: 22 additions & 0 deletions dpnp/backend/extensions/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,31 @@ extern std::pair<sycl::event, sycl::event>
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const bool isRowMajor,
const std::vector<sycl::event> &depends);

// Batched matrix-matrix product entry point; bound to Python as
// `_gemm_batch` in blas_py.cpp.
//
// Parameters:
//   q                      - SYCL queue the work is submitted to.
//   matrixA, matrixB       - input arrays.
//   resultC                - output array.
//   m, n, k                - matrix dimensions (presumably the per-matrix
//                            GEMM sizes, as in gemm - TODO confirm in
//                            gemm_batch.cpp).
//   batch_size             - number of matrix products in the batch.
//   ld_array_1, ld_array_2,
//   ld_result              - leading dimensions of A, B and C respectively.
//   stridea, strideb,
//   stridec                - strides for A, B and C (presumably the distance
//                            between consecutive matrices in the batch - TODO
//                            confirm). Note these are size_t while the other
//                            sizes are std::int64_t.
//   transA_int, transB_int - integer-encoded transpose flags; the encoding is
//                            not visible in this header - see gemm_batch.cpp.
//   depends                - events that must complete before the computation
//                            starts.
//
// Returns a pair of SYCL events (NOTE(review): presumably the keep-args-alive
// event and the computation event, mirroring gemm - confirm the order).
extern std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::int64_t m,
const std::int64_t n,
const std::int64_t k,
const std::int64_t batch_size,
const std::int64_t ld_array_1,
const std::int64_t ld_array_2,
const std::int64_t ld_result,
size_t stridea,
size_t strideb,
size_t stridec,
const std::int64_t transA_int,
const std::int64_t transB_int,
const std::vector<sycl::event> &depends);

// Dispatch-table initializers, called from init_dispatch_tables() in
// blas_py.cpp when the Python module is initialized.
// Initializer for gemm (presumably fills the per-type gemm dispatch table).
extern void init_gemm_dispatch_table(void);
// Initializer for gemm_batch (presumably the batched counterpart - see
// gemm_batch.cpp).
extern void init_gemm_batch_dispatch_table(void);
} // namespace blas
} // namespace ext
} // namespace backend
Expand Down
Loading

0 comments on commit b8f7f00

Please sign in to comment.