|
34 | 34 | #include <vector>
|
35 | 35 |
|
36 | 36 | #include "simplify_iteration_space.hpp"
|
| 37 | +#include "utils/memory_overlap.hpp" |
37 | 38 | #include "utils/offset_utils.hpp"
|
38 | 39 | #include "utils/type_dispatch.hpp"
|
39 | 40 |
|
@@ -122,23 +123,14 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src,
|
122 | 123 | }
|
123 | 124 |
|
124 | 125 | // check memory overlap
|
125 |
| - const char *src_data = src.get_data(); |
126 |
| - char *dst_data = dst.get_data(); |
127 |
| - |
128 |
| - // check that arrays do not overlap, and concurrent copying is safe. |
129 |
| - auto src_offsets = src.get_minmax_offsets(); |
130 |
| - int src_elem_size = src.get_elemsize(); |
131 |
| - int dst_elem_size = dst.get_elemsize(); |
132 |
| - |
133 |
| - bool memory_overlap = |
134 |
| - ((dst_data - src_data > src_offsets.second * src_elem_size - |
135 |
| - dst_offsets.first * dst_elem_size) && |
136 |
| - (src_data - dst_data > dst_offsets.second * dst_elem_size - |
137 |
| - src_offsets.first * src_elem_size)); |
138 |
| - if (memory_overlap) { |
| 126 | + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); |
| 127 | + if (overlap(src, dst)) { |
139 | 128 | throw py::value_error("Arrays index overlapping segments of memory");
|
140 | 129 | }
|
141 | 130 |
|
| 131 | + const char *src_data = src.get_data(); |
| 132 | + char *dst_data = dst.get_data(); |
| 133 | + |
142 | 134 | // handle contiguous inputs
|
143 | 135 | bool is_src_c_contig = src.is_c_contiguous();
|
144 | 136 | bool is_src_f_contig = src.is_f_contiguous();
|
@@ -378,32 +370,16 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
|
378 | 370 | }
|
379 | 371 | }
|
380 | 372 |
|
| 373 | + // check memory overlap |
| 374 | + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); |
| 375 | + if (overlap(src1, dst) || overlap(src2, dst)) { |
| 376 | + throw py::value_error("Arrays index overlapping segments of memory"); |
| 377 | + } |
381 | 378 | // get raw data pointers of the inputs and the output
|
382 | 379 | const char *src1_data = src1.get_data();
|
383 | 380 | const char *src2_data = src2.get_data();
|
384 | 381 | char *dst_data = dst.get_data();
|
385 | 382 |
|
386 |
| - // check that arrays do not overlap, and concurrent copying is safe. |
387 |
| - auto src1_offsets = src1.get_minmax_offsets(); |
388 |
| - int src1_elem_size = src1.get_elemsize(); |
389 |
| - auto src2_offsets = src2.get_minmax_offsets(); |
390 |
| - int src2_elem_size = src2.get_elemsize(); |
391 |
| - int dst_elem_size = dst.get_elemsize(); |
392 |
| - |
393 |
| - bool memory_overlap_src1_dst = |
394 |
| - ((dst_data - src1_data > src1_offsets.second * src1_elem_size - |
395 |
| - dst_offsets.first * dst_elem_size) && |
396 |
| - (src1_data - dst_data > dst_offsets.second * dst_elem_size - |
397 |
| - src1_offsets.first * src1_elem_size)); |
398 |
| - bool memory_overlap_src2_dst = |
399 |
| - ((dst_data - src2_data > src2_offsets.second * src2_elem_size - |
400 |
| - dst_offsets.first * dst_elem_size) && |
401 |
| - (src2_data - dst_data > dst_offsets.second * dst_elem_size - |
402 |
| - src2_offsets.first * src2_elem_size)); |
403 |
| - if (memory_overlap_src1_dst || memory_overlap_src2_dst) { |
404 |
| - throw py::value_error("Arrays index overlapping segments of memory"); |
405 |
| - } |
406 |
| - |
407 | 383 | // handle contiguous inputs
|
408 | 384 | bool is_src1_c_contig = src1.is_c_contiguous();
|
409 | 385 | bool is_src1_f_contig = src1.is_f_contiguous();
|
|
0 commit comments