35
35
#include < pybind11/stl.h>
36
36
37
37
#include " kernels/reductions.hpp"
38
- #include " reductions .hpp"
38
+ #include " sum_reductions .hpp"
39
39
40
40
#include " simplify_iteration_space.hpp"
41
+ #include " utils/memory_overlap.hpp"
41
42
#include " utils/offset_utils.hpp"
42
43
#include " utils/type_dispatch.hpp"
43
44
@@ -135,9 +136,23 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
135
136
reduction_nelems *= static_cast <size_t >(src_shape_ptr[i]);
136
137
}
137
138
138
- // FIXME: check that dst and src do not overlap
139
- // check that dst is ample enough (memory span is sufficient
140
- // to accommodate all elements)
139
+ // check that dst and src do not overlap
140
+ auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
141
+ if (overlap (src, dst)) {
142
+ throw py::value_error (" Arrays index overlapping segments of memory" );
143
+ }
144
+
145
+ // destination must be ample enough to accommodate all elements
146
+ {
147
+ auto dst_offsets = dst.get_minmax_offsets ();
148
+ size_t range =
149
+ static_cast <size_t >(dst_offsets.second - dst_offsets.first );
150
+ if (range + 1 < dst_nelems) {
151
+ throw py::value_error (
152
+ " Destination array can not accomodate all the "
153
+ " elements of source array." );
154
+ }
155
+ }
141
156
142
157
int src_typenum = src.get_typenum ();
143
158
int dst_typenum = dst.get_typenum ();
@@ -297,38 +312,33 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
297
312
}
298
313
299
314
std::vector<sycl::event> host_task_events{};
300
- const auto &iter_src_dst_metadata_packing_triple_ =
301
- dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t >(
302
- exec_q, host_task_events, simplified_iteration_shape,
303
- simplified_iteration_src_strides, simplified_iteration_dst_strides);
304
- py::ssize_t *iter_shape_and_strides =
305
- std::get<0 >(iter_src_dst_metadata_packing_triple_);
306
- if (iter_shape_and_strides == nullptr ) {
315
+
316
+ using dpctl::tensor::offset_utils::device_allocate_and_pack;
317
+
318
+ const auto &arrays_metainfo_packing_triple_ =
319
+ device_allocate_and_pack<py::ssize_t >(
320
+ exec_q, host_task_events,
321
+ // iteration metadata
322
+ simplified_iteration_shape, simplified_iteration_src_strides,
323
+ simplified_iteration_dst_strides,
324
+ // reduction metadata
325
+ simplified_reduction_shape, simplified_reduction_src_strides);
326
+ py::ssize_t *temp_allocation_ptr =
327
+ std::get<0 >(arrays_metainfo_packing_triple_);
328
+ if (temp_allocation_ptr == nullptr ) {
307
329
throw std::runtime_error (" Unable to allocate memory on device" );
308
330
}
309
- const auto &copy_iter_metadata_ev =
310
- std::get<2 >(iter_src_dst_metadata_packing_triple_);
331
+ const auto &copy_metadata_ev = std::get<2 >(arrays_metainfo_packing_triple_);
311
332
312
- const auto &reduction_metadata_packing_triple_ =
313
- dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t >(
314
- exec_q, host_task_events, simplified_reduction_shape,
315
- simplified_reduction_src_strides);
333
+ py::ssize_t *iter_shape_and_strides = temp_allocation_ptr;
316
334
py::ssize_t *reduction_shape_stride =
317
- std::get<0 >(reduction_metadata_packing_triple_);
318
- if (reduction_shape_stride == nullptr ) {
319
- sycl::event::wait (host_task_events);
320
- sycl::free (iter_shape_and_strides, exec_q);
321
- throw std::runtime_error (" Unable to allocate memory on device" );
322
- }
323
- const auto &copy_reduction_metadata_ev =
324
- std::get<2 >(reduction_metadata_packing_triple_);
335
+ temp_allocation_ptr + 3 * simplified_iteration_shape.size ();
325
336
326
337
std::vector<sycl::event> all_deps;
327
- all_deps.reserve (depends.size () + 2 );
338
+ all_deps.reserve (depends.size () + 1 );
328
339
all_deps.resize (depends.size ());
329
340
std::copy (depends.begin (), depends.end (), all_deps.begin ());
330
- all_deps.push_back (copy_iter_metadata_ev);
331
- all_deps.push_back (copy_reduction_metadata_ev);
341
+ all_deps.push_back (copy_metadata_ev);
332
342
333
343
auto comp_ev = fn (exec_q, dst_nelems, reduction_nelems, src.get_data (),
334
344
dst.get_data (), iteration_nd, iter_shape_and_strides,
@@ -339,9 +349,8 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
339
349
sycl::event temp_cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
340
350
cgh.depends_on (comp_ev);
341
351
auto ctx = exec_q.get_context ();
342
- cgh.host_task ([ctx, iter_shape_and_strides, reduction_shape_stride] {
343
- sycl::free (iter_shape_and_strides, ctx);
344
- sycl::free (reduction_shape_stride, ctx);
352
+ cgh.host_task ([ctx, temp_allocation_ptr] {
353
+ sycl::free (temp_allocation_ptr, ctx);
345
354
});
346
355
});
347
356
host_task_events.push_back (temp_cleanup_ev);
0 commit comments