Skip to content

Commit 7d1904c

Browse files
Merge pull request #1016 from IntelPython/convenience-signature-for-util-function
Add queues_are_compatible signature for list of usm_ndarray instances
2 parents da56dce + b02745d commit 7d1904c

File tree

8 files changed

+26
-24
lines changed

8 files changed

+26
-24
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,8 @@ sycl::event keep_args_alive(sycl::queue q,
987987
return host_task_ev;
988988
}
989989

990+
/*! @brief Check if all allocation queues are the same as the
991+
execution queue */
990992
template <std::size_t num>
991993
bool queues_are_compatible(sycl::queue exec_q,
992994
const sycl::queue (&alloc_qs)[num])
@@ -1000,6 +1002,21 @@ bool queues_are_compatible(sycl::queue exec_q,
10001002
return true;
10011003
}
10021004

1005+
/*! @brief Check if all allocation queues of usm_ndarays are the same as
1006+
the execution queue */
1007+
template <std::size_t num>
1008+
bool queues_are_compatible(sycl::queue exec_q,
1009+
const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1010+
{
1011+
for (std::size_t i = 0; i < num; ++i) {
1012+
1013+
if (exec_q != arrs[i].get_queue()) {
1014+
return false;
1015+
}
1016+
}
1017+
return true;
1018+
}
1019+
10031020
} // end namespace utils
10041021

10051022
} // end namespace dpctl

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
160160
}
161161

162162
// check compatibility of execution queue and allocation queue
163-
sycl::queue src_q = src.get_queue();
164-
sycl::queue dst_q = dst.get_queue();
165-
166-
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
163+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
167164
throw py::value_error(
168165
"Execution queue is not compatible with allocation queues");
169166
}

dpctl/tensor/libtensor/source/copy_for_reshape.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
101101
}
102102

103103
// check same contexts
104-
sycl::queue src_q = src.get_queue();
105-
sycl::queue dst_q = dst.get_queue();
106-
107-
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
104+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
108105
throw py::value_error(
109106
"Execution queue is not compatible with allocation queues");
110107
}

dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ void copy_numpy_ndarray_into_usm_ndarray(
101101
}
102102
}
103103

104-
sycl::queue dst_q = dst.get_queue();
105-
106-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
104+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
107105
throw py::value_error("Execution queue is not compatible with the "
108106
"allocation queue");
109107
}

dpctl/tensor/libtensor/source/eye_ctor.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ usm_ndarray_eye(py::ssize_t k,
6161
"usm_ndarray_eye: Expecting 2D array to populate");
6262
}
6363

64-
sycl::queue dst_q = dst.get_queue();
65-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
64+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
6665
throw py::value_error("Execution queue is not compatible with the "
6766
"allocation queue");
6867
}

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ usm_ndarray_full(py::object py_value,
6969
return std::make_pair(sycl::event(), sycl::event());
7070
}
7171

72-
sycl::queue dst_q = dst.get_queue();
73-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
72+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
7473
throw py::value_error(
7574
"Execution queue is not compatible with the allocation queue");
7675
}

dpctl/tensor/libtensor/source/linear_sequences.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ usm_ndarray_linear_sequence_step(py::object start,
7878
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
7979
}
8080

81-
sycl::queue dst_q = dst.get_queue();
82-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
81+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
8382
throw py::value_error(
8483
"Execution queue is not compatible with the allocation queue");
8584
}
@@ -127,8 +126,7 @@ usm_ndarray_linear_sequence_affine(py::object start,
127126
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
128127
}
129128

130-
sycl::queue dst_q = dst.get_queue();
131-
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
129+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst})) {
132130
throw py::value_error(
133131
"Execution queue context is not the same as allocation context");
134132
}

dpctl/tensor/libtensor/source/triul_ctor.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,8 @@ usm_ndarray_triul(sycl::queue exec_q,
121121
throw py::value_error("Array dtype are not the same.");
122122
}
123123

124-
// check same contexts
125-
sycl::queue src_q = src.get_queue();
126-
sycl::queue dst_q = dst.get_queue();
127-
128-
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
124+
// check same queues
125+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
129126
throw py::value_error(
130127
"Execution queue context is not the same as allocation contexts");
131128
}

0 commit comments

Comments
 (0)