Skip to content

Commit c6ef075

Browse files
Use MemoryOverlap
1 parent 4439bf7 commit c6ef075

File tree

1 file changed

+11
-35
lines changed

1 file changed

+11
-35
lines changed

dpctl/tensor/libtensor/source/elementwise_functions.hpp

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <vector>
3535

3636
#include "simplify_iteration_space.hpp"
37+
#include "utils/memory_overlap.hpp"
3738
#include "utils/offset_utils.hpp"
3839
#include "utils/type_dispatch.hpp"
3940

@@ -122,23 +123,14 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src,
122123
}
123124

124125
// check memory overlap
125-
const char *src_data = src.get_data();
126-
char *dst_data = dst.get_data();
127-
128-
// check that arrays do not overlap, and concurrent copying is safe.
129-
auto src_offsets = src.get_minmax_offsets();
130-
int src_elem_size = src.get_elemsize();
131-
int dst_elem_size = dst.get_elemsize();
132-
133-
bool memory_overlap =
134-
((dst_data - src_data > src_offsets.second * src_elem_size -
135-
dst_offsets.first * dst_elem_size) &&
136-
(src_data - dst_data > dst_offsets.second * dst_elem_size -
137-
src_offsets.first * src_elem_size));
138-
if (memory_overlap) {
126+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
127+
if (overlap(src, dst)) {
139128
throw py::value_error("Arrays index overlapping segments of memory");
140129
}
141130

131+
const char *src_data = src.get_data();
132+
char *dst_data = dst.get_data();
133+
142134
// handle contiguous inputs
143135
bool is_src_c_contig = src.is_c_contiguous();
144136
bool is_src_f_contig = src.is_f_contiguous();
@@ -378,32 +370,16 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
378370
}
379371
}
380372

373+
// check memory overlap
374+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
375+
if (overlap(src1, dst) || overlap(src2, dst)) {
376+
throw py::value_error("Arrays index overlapping segments of memory");
377+
}
381378
// check memory overlap
382379
const char *src1_data = src1.get_data();
383380
const char *src2_data = src2.get_data();
384381
char *dst_data = dst.get_data();
385382

386-
// check that arrays do not overlap, and concurrent copying is safe.
387-
auto src1_offsets = src1.get_minmax_offsets();
388-
int src1_elem_size = src1.get_elemsize();
389-
auto src2_offsets = src2.get_minmax_offsets();
390-
int src2_elem_size = src2.get_elemsize();
391-
int dst_elem_size = dst.get_elemsize();
392-
393-
bool memory_overlap_src1_dst =
394-
((dst_data - src1_data > src1_offsets.second * src1_elem_size -
395-
dst_offsets.first * dst_elem_size) &&
396-
(src1_data - dst_data > dst_offsets.second * dst_elem_size -
397-
src1_offsets.first * src1_elem_size));
398-
bool memory_overlap_src2_dst =
399-
((dst_data - src2_data > src2_offsets.second * src2_elem_size -
400-
dst_offsets.first * dst_elem_size) &&
401-
(src2_data - dst_data > dst_offsets.second * dst_elem_size -
402-
src2_offsets.first * src2_elem_size));
403-
if (memory_overlap_src1_dst || memory_overlap_src2_dst) {
404-
throw py::value_error("Arrays index overlapping segments of memory");
405-
}
406-
407383
// handle contiguous inputs
408384
bool is_src1_c_contig = src1.is_c_contiguous();
409385
bool is_src1_f_contig = src1.is_f_contiguous();

0 commit comments

Comments
 (0)