Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed extra copy for transpose arrays in dot() #1477

Merged
merged 5 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ env:
test_special.py
test_umath.py
test_usm_type.py
third_party/cupy/linalg_tests/test_product.py
third_party/cupy/math_tests/test_explog.py
third_party/cupy/math_tests/test_misc.py
third_party/cupy/math_tests/test_trigonometric.py
Expand Down
159 changes: 87 additions & 72 deletions dpnp/backend/kernels/dpnp_krnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <dpnp_iface.hpp>

namespace mkl_blas = oneapi::mkl::blas;
namespace mkl_blas_cm = oneapi::mkl::blas::column_major;
namespace mkl_blas_rm = oneapi::mkl::blas::row_major;
namespace mkl_lapack = oneapi::mkl::lapack;

Expand Down Expand Up @@ -227,12 +228,10 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
DPCTLSyclEventRef event_ref = nullptr;
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));

DPNPC_ptr_adapter<_DataType_input1> input1_ptr(q_ref, input1_in,
input1_size);
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(q_ref, input2_in,
input2_size);
_DataType_input1 *input1 = input1_ptr.get_ptr();
_DataType_input2 *input2 = input2_ptr.get_ptr();
_DataType_input1 *input1 =
static_cast<_DataType_input1 *>(const_cast<void *>(input1_in));
_DataType_input2 *input2 =
static_cast<_DataType_input2 *>(const_cast<void *>(input2_in));
_DataType_output *result = reinterpret_cast<_DataType_output *>(result_out);

if (!input1_size || !input2_size) {
Expand All @@ -257,10 +256,12 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
// if both arrays are vectors
if ((input1_ndim == 1) && (input2_ndim == 1)) {
assert(input1_size == input2_size);

sycl::event event = dot(q, result, input1, input2, input1_strides[0],
input2_strides[0], input1_size);
event.wait();
return event_ref;

event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
return DPCTLEvent_Copy(event_ref);
}

// 1D vector
Expand Down Expand Up @@ -297,13 +298,17 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
size_t ext_result_ndim =
((input1_ndim == 1) || (input2_ndim == 1)) ? 2 : result_ndim;
shape_elem_type *ext_result_shape = new shape_elem_type[ext_result_ndim];
shape_elem_type *ext_result_strides = new shape_elem_type[ext_result_ndim];
if ((input1_ndim == 1) || (input2_ndim == 1)) {
ext_result_shape[0] = ext_input1_shape[0];
ext_result_shape[1] = ext_input2_shape[1];
ext_result_strides[0] = 0;
ext_result_strides[1] = result_strides[0];
}
else {
for (size_t i = 0; i < ext_result_ndim; ++i) {
ext_result_shape[i] = result_shape[i];
ext_result_strides[i] = result_strides[i];
}
}

Expand All @@ -316,80 +321,89 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
// check if GEMM can be executed (strides)
// TODO: rewrite the condition in general case for ndims > 2
// (looks like there are such another cases)

if (ext_input1_ndim == 2 && ext_input2_ndim == 2) {
// The trans and sizes params behave differently in previous versions of
// GEMM, so only the new version is supported; with an old version the
// computation goes the common way.
#if INTEL_MKL_VERSION >= 20210004
// is mat1 F-contiguous, C-contiguous
bool mat1_f_contig =
(((ext_input1_shape[0] == 1) || (ext_input1_strides[0] == 1)) &&
((ext_input1_shape[1] == 1) ||
(ext_input1_strides[1] == ext_input1_shape[0])));
bool mat1_c_contig =
(((ext_input1_shape[1] == 1) || (ext_input1_strides[1] == 1)) &&
((ext_input1_shape[0] == 1) ||
(ext_input1_strides[0] == ext_input1_shape[1])));
// is mat2 F-contiguous, C-contiguous
bool mat2_f_contig =
(((ext_input2_shape[0] == 1) || (ext_input2_strides[0] == 1)) &&
((ext_input2_shape[1] == 1) ||
(ext_input2_strides[1] == ext_input2_shape[0])));
bool mat2_c_contig =
(((ext_input2_shape[1] == 1) || (ext_input2_strides[1] == 1)) &&
((ext_input2_shape[0] == 1) ||
(ext_input2_strides[0] == ext_input2_shape[1])));

if ((mat1_f_contig || mat1_c_contig) &&
(mat2_f_contig || mat2_c_contig)) {
oneapi::mkl::transpose trans1 =
(mat1_f_contig && !mat1_c_contig)
? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;
oneapi::mkl::transpose trans2 =
(mat2_f_contig && !mat2_c_contig)
? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;
// OneMKL gemm supports only arrays contiguous on the inner dimension,
// so the stride for at least one dimension should be equal to 1
if ((ext_input1_strides[0] == 1 || ext_input1_strides[1] == 1) &&
(ext_input2_strides[0] == 1 || ext_input2_strides[1] == 1) &&
(ext_result_strides[0] == 1 || ext_result_strides[1] == 1))
{
const bool isRowmA =
(ext_input1_strides[1] == 1 || ext_input1_strides[0] == 0);
const bool isRowmB =
(ext_input2_strides[1] == 1 || ext_input2_strides[1] == 0);
const bool isRowmC =
(ext_result_strides[1] == 1 || ext_result_strides[0] == 0);

oneapi::mkl::transpose transA =
(isRowmA != isRowmC) ? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;
oneapi::mkl::transpose transB =
(isRowmB != isRowmC) ? oneapi::mkl::transpose::trans
: oneapi::mkl::transpose::nontrans;

const size_t size_m = ext_input1_shape[0];
const size_t size_n = ext_input2_shape[1];
const size_t size_k = ext_input1_shape[1];

const std::int64_t lda =
trans1 == oneapi::mkl::transpose::nontrans
? ext_input1_strides[0]
: ext_input1_strides[1];
const std::int64_t ldb =
trans2 == oneapi::mkl::transpose::nontrans
? ext_input2_strides[0]
: ext_input2_strides[1];

// definition of ldc will be another for result with
// non-standard (c-contiguous) strides const std::int64_t ldc =
// result_strides[0] == 1 ? result_strides[1] :
// result_strides[0];
const std::int64_t ldc = size_n;
auto getLdaLdc = [](const bool isRown, shape_elem_type *strides,
shape_elem_type *shapes) {
if (isRown) {
return (strides[0] != 0) ? strides[0] : shapes[1];
}
return strides[1];
};

const std::int64_t lda = static_cast<std::int64_t>(
getLdaLdc(isRowmA, ext_input1_strides, ext_input1_shape));
const std::int64_t ldb = static_cast<std::int64_t>(
isRowmB ? ext_input2_strides[0] : ext_input2_strides[1]);
const std::int64_t ldc = static_cast<std::int64_t>(
getLdaLdc(isRowmC, ext_result_strides, ext_result_shape));

constexpr _DataType_output alpha = 1;
constexpr _DataType_output beta = 0;

std::stringstream error_msg;
std::int64_t info = 0;

try {
sycl::event event = mkl_blas_rm::gemm(
q, trans1, trans2, size_m, size_n, size_k,
_DataType_output(1), // alpha
input1, lda, input2, ldb,
_DataType_output(0), // beta
result, ldc);
event.wait();
delete[] ext_input1_shape;
delete[] ext_input1_strides;
delete[] ext_input2_shape;
delete[] ext_input2_strides;
delete[] ext_result_shape;

return event_ref;
if (isRowmC) {
mkl_blas_rm::gemm(q, transA, transB, size_m, size_n,
size_k, alpha, input1, lda, input2,
ldb, beta, result, ldc)
.wait();
}
else {
mkl_blas_cm::gemm(q, transA, transB, size_m, size_n,
size_k, alpha, input1, lda, input2,
ldb, beta, result, ldc)
.wait();
}
} catch (mkl_lapack::exception const &e) {
error_msg << "Unexpected MKL exception caught during "
"gemm() call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
info = e.info();
} catch (const std::exception &e) {
// do nothing, proceed to general case
error_msg << "Unexpected SYCL exception caught during "
"gemm() call:\n"
<< e.what();
info = -1;
}
#endif

if (info != 0) // an unexected error occurs
{
throw std::runtime_error(error_msg.str());
}

delete[] ext_input1_shape;
delete[] ext_input1_strides;
delete[] ext_input2_shape;
delete[] ext_input2_strides;
delete[] ext_result_shape;
delete[] ext_result_strides;
return event_ref;
}
}
}
Expand Down Expand Up @@ -437,6 +451,7 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
delete[] ext_input2_shape;
delete[] ext_input2_strides;
delete[] ext_result_shape;
delete[] ext_result_strides;

return event_ref;
}
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_dot_t)(c_dpctl.DPCTLSyclQueueR
const shape_elem_type *, const shape_elem_type * ,
void * , const size_t, const size_t,
const shape_elem_type *, const shape_elem_type * ,
const c_dpctl.DPCTLEventVectorRef)
const c_dpctl.DPCTLEventVectorRef) except +
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQueueRef,
void * , const size_t, const size_t,
const shape_elem_type *, const shape_elem_type * ,
Expand Down
9 changes: 5 additions & 4 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,16 @@ def dot(x1, x2, out=None, **kwargs):
else (None, None)
)

# TODO: copy_when_strides=False (now it's done for faster implementation with transpose arrays)
x1_desc = dpnp.get_dpnp_descriptor(
x1,
copy_when_strides=True,
copy_when_strides=False,
copy_when_nondefault_queue=False,
alloc_usm_type=usm_type,
alloc_queue=queue,
)
x2_desc = dpnp.get_dpnp_descriptor(
x2,
copy_when_strides=True,
copy_when_strides=False,
copy_when_nondefault_queue=False,
alloc_usm_type=usm_type,
alloc_queue=queue,
Expand All @@ -131,7 +130,9 @@ def dot(x1, x2, out=None, **kwargs):
)
out_desc = (
dpnp.get_dpnp_descriptor(
out, copy_when_nondefault_queue=False
out,
copy_when_strides=False,
copy_when_nondefault_queue=False,
)
or None
)
Expand Down
Loading