Skip to content

Commit f9ddea7

Browse files
Renamed source/reductions.*pp to sum_reductions.*pp
Added MemoryOverlap check and the array range check, per FIXME note and PR review feedback. Also consolidated the transfer of iteration/reduction metadata into a single operation to improve test stability on CPU and reduce overall host submission overhead.
1 parent 0290954 commit f9ddea7

File tree

4 files changed

+42
-33
lines changed

4 files changed

+42
-33
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ pybind11_add_module(${python_module_name} MODULE
4747
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
4848
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
4949
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
50-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions.cpp
50+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
5151
)
5252
set(_clang_prefix "")
5353
if (WIN32)

dpctl/tensor/libtensor/source/reductions.cpp renamed to dpctl/tensor/libtensor/source/sum_reductions.cpp

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
#include <pybind11/stl.h>
3636

3737
#include "kernels/reductions.hpp"
38-
#include "reductions.hpp"
38+
#include "sum_reductions.hpp"
3939

4040
#include "simplify_iteration_space.hpp"
41+
#include "utils/memory_overlap.hpp"
4142
#include "utils/offset_utils.hpp"
4243
#include "utils/type_dispatch.hpp"
4344

@@ -135,9 +136,23 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
135136
reduction_nelems *= static_cast<size_t>(src_shape_ptr[i]);
136137
}
137138

138-
// FIXME: check that dst and src do not overlap
139-
// check that dst is ample enough (memory span is sufficient
140-
// to accommodate all elements)
139+
// check that dst and src do not overlap
140+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
141+
if (overlap(src, dst)) {
142+
throw py::value_error("Arrays index overlapping segments of memory");
143+
}
144+
145+
// destination must be ample enough to accommodate all elements
146+
{
147+
auto dst_offsets = dst.get_minmax_offsets();
148+
size_t range =
149+
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
150+
if (range + 1 < dst_nelems) {
151+
throw py::value_error(
152+
"Destination array can not accomodate all the "
153+
"elements of source array.");
154+
}
155+
}
141156

142157
int src_typenum = src.get_typenum();
143158
int dst_typenum = dst.get_typenum();
@@ -297,38 +312,33 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
297312
}
298313

299314
std::vector<sycl::event> host_task_events{};
300-
const auto &iter_src_dst_metadata_packing_triple_ =
301-
dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t>(
302-
exec_q, host_task_events, simplified_iteration_shape,
303-
simplified_iteration_src_strides, simplified_iteration_dst_strides);
304-
py::ssize_t *iter_shape_and_strides =
305-
std::get<0>(iter_src_dst_metadata_packing_triple_);
306-
if (iter_shape_and_strides == nullptr) {
315+
316+
using dpctl::tensor::offset_utils::device_allocate_and_pack;
317+
318+
const auto &arrays_metainfo_packing_triple_ =
319+
device_allocate_and_pack<py::ssize_t>(
320+
exec_q, host_task_events,
321+
// iteration metadata
322+
simplified_iteration_shape, simplified_iteration_src_strides,
323+
simplified_iteration_dst_strides,
324+
// reduction metadata
325+
simplified_reduction_shape, simplified_reduction_src_strides);
326+
py::ssize_t *temp_allocation_ptr =
327+
std::get<0>(arrays_metainfo_packing_triple_);
328+
if (temp_allocation_ptr == nullptr) {
307329
throw std::runtime_error("Unable to allocate memory on device");
308330
}
309-
const auto &copy_iter_metadata_ev =
310-
std::get<2>(iter_src_dst_metadata_packing_triple_);
331+
const auto &copy_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_);
311332

312-
const auto &reduction_metadata_packing_triple_ =
313-
dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t>(
314-
exec_q, host_task_events, simplified_reduction_shape,
315-
simplified_reduction_src_strides);
333+
py::ssize_t *iter_shape_and_strides = temp_allocation_ptr;
316334
py::ssize_t *reduction_shape_stride =
317-
std::get<0>(reduction_metadata_packing_triple_);
318-
if (reduction_shape_stride == nullptr) {
319-
sycl::event::wait(host_task_events);
320-
sycl::free(iter_shape_and_strides, exec_q);
321-
throw std::runtime_error("Unable to allocate memory on device");
322-
}
323-
const auto &copy_reduction_metadata_ev =
324-
std::get<2>(reduction_metadata_packing_triple_);
335+
temp_allocation_ptr + 3 * simplified_iteration_shape.size();
325336

326337
std::vector<sycl::event> all_deps;
327-
all_deps.reserve(depends.size() + 2);
338+
all_deps.reserve(depends.size() + 1);
328339
all_deps.resize(depends.size());
329340
std::copy(depends.begin(), depends.end(), all_deps.begin());
330-
all_deps.push_back(copy_iter_metadata_ev);
331-
all_deps.push_back(copy_reduction_metadata_ev);
341+
all_deps.push_back(copy_metadata_ev);
332342

333343
auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(),
334344
dst.get_data(), iteration_nd, iter_shape_and_strides,
@@ -339,9 +349,8 @@ std::pair<sycl::event, sycl::event> py_sum_over_axis(
339349
sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
340350
cgh.depends_on(comp_ev);
341351
auto ctx = exec_q.get_context();
342-
cgh.host_task([ctx, iter_shape_and_strides, reduction_shape_stride] {
343-
sycl::free(iter_shape_and_strides, ctx);
344-
sycl::free(reduction_shape_stride, ctx);
352+
cgh.host_task([ctx, temp_allocation_ptr] {
353+
sycl::free(temp_allocation_ptr, ctx);
345354
});
346355
});
347356
host_task_events.push_back(temp_cleanup_ev);

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
#include "full_ctor.hpp"
4545
#include "integer_advanced_indexing.hpp"
4646
#include "linear_sequences.hpp"
47-
#include "reductions.hpp"
4847
#include "simplify_iteration_space.hpp"
48+
#include "sum_reductions.hpp"
4949
#include "triul_ctor.hpp"
5050
#include "utils/memory_overlap.hpp"
5151
#include "utils/strided_iters.hpp"

0 commit comments

Comments
 (0)