Skip to content

Boolean indexing performance #1300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
":class:`dpctl.tensor.usm_ndarray`."
)

# Exclusive upper bound for choosing a 32-bit cumulative-sum dtype:
# 2147483648 == 2**31 == INT32_MAX + 1, so any element count strictly
# below this value fits in a signed 32-bit counter (used as
# `mask_nelems < int32_t_max`). NOTE(review): despite the name, this is
# one past int32 max — it is an exclusive bound, not the max itself.
int32_t_max = 2147483648


def _copy_to_numpy(ary):
if not isinstance(ary, dpt.usm_ndarray):
Expand Down Expand Up @@ -482,7 +484,8 @@ def _extract_impl(ary, ary_mask, axis=0):
"Parameter p is inconsistent with input array dimensions"
)
mask_nelems = ary_mask.size
cumsum = dpt.empty(mask_nelems, dtype=dpt.int64, device=ary_mask.device)
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
exec_q = cumsum.sycl_queue
mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q)
dst_shape = ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
Expand All @@ -509,8 +512,9 @@ def _nonzero_impl(ary):
exec_q = ary.sycl_queue
usm_type = ary.usm_type
mask_nelems = ary.size
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
cumsum = dpt.empty(
mask_nelems, dtype=dpt.int64, sycl_queue=exec_q, order="C"
mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C"
)
mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q)
indexes = dpt.empty(
Expand Down Expand Up @@ -604,7 +608,8 @@ def _place_impl(ary, ary_mask, vals, axis=0):
"Parameter p is inconsistent with input array dimensions"
)
mask_nelems = ary_mask.size
cumsum = dpt.empty(mask_nelems, dtype=dpt.int64, device=ary_mask.device)
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
exec_q = cumsum.sycl_queue
mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q)
expected_vals_shape = (
Expand Down
131 changes: 102 additions & 29 deletions dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,24 @@ size_t mask_positions_contig_impl(sycl::queue q,
throw std::bad_alloc();
}
sycl::event copy_e =
q.copy<std::int64_t>(last_elem, last_elem_host_usm, 1, {comp_ev});
q.copy<cumsumT>(last_elem, last_elem_host_usm, 1, {comp_ev});
copy_e.wait();
size_t return_val = static_cast<size_t>(*last_elem_host_usm);
sycl::free(last_elem_host_usm, q);

return return_val;
}

template <typename fnT, typename T> struct MaskPositionsContigFactory
template <typename fnT, typename T> struct MaskPositionsContigFactoryForInt32
{
fnT get()
{
fnT fn = mask_positions_contig_impl<T, std::int32_t>;
return fn;
}
};

template <typename fnT, typename T> struct MaskPositionsContigFactoryForInt64
{
fnT get()
{
Expand Down Expand Up @@ -452,15 +461,24 @@ size_t mask_positions_strided_impl(sycl::queue q,
throw std::bad_alloc();
}
sycl::event copy_e =
q.copy<std::int64_t>(last_elem, last_elem_host_usm, 1, {comp_ev});
q.copy<cumsumT>(last_elem, last_elem_host_usm, 1, {comp_ev});
copy_e.wait();
size_t return_val = static_cast<size_t>(*last_elem_host_usm);
sycl::free(last_elem_host_usm, q);

return return_val;
}

template <typename fnT, typename T> struct MaskPositionsStridedFactory
template <typename fnT, typename T> struct MaskPositionsStridedFactoryForInt32
{
fnT get()
{
fnT fn = mask_positions_strided_impl<T, std::int32_t>;
return fn;
}
};

template <typename fnT, typename T> struct MaskPositionsStridedFactoryForInt64
{
fnT get()
{
Expand Down Expand Up @@ -611,7 +629,18 @@ sycl::event masked_extract_some_slices_strided_impl(
return comp_ev;
}

template <typename fnT, typename T> struct MaskExtractAllSlicesStridedFactory
template <typename fnT, typename T>
struct MaskExtractAllSlicesStridedFactoryForInt32
{
fnT get()
{
fnT fn = masked_extract_all_slices_strided_impl<T, std::int32_t>;
return fn;
}
};

template <typename fnT, typename T>
struct MaskExtractAllSlicesStridedFactoryForInt64
{
fnT get()
{
Expand All @@ -620,7 +649,18 @@ template <typename fnT, typename T> struct MaskExtractAllSlicesStridedFactory
}
};

template <typename fnT, typename T> struct MaskExtractSomeSlicesStridedFactory
template <typename fnT, typename T>
struct MaskExtractSomeSlicesStridedFactoryForInt32
{
fnT get()
{
fnT fn = masked_extract_some_slices_strided_impl<T, std::int32_t>;
return fn;
}
};

template <typename fnT, typename T>
struct MaskExtractSomeSlicesStridedFactoryForInt64
{
fnT get()
{
Expand Down Expand Up @@ -763,7 +803,18 @@ sycl::event masked_place_some_slices_strided_impl(
return comp_ev;
}

template <typename fnT, typename T> struct MaskPlaceAllSlicesStridedFactory
template <typename fnT, typename T>
struct MaskPlaceAllSlicesStridedFactoryForInt32
{
fnT get()
{
fnT fn = masked_place_all_slices_strided_impl<T, std::int32_t>;
return fn;
}
};

template <typename fnT, typename T>
struct MaskPlaceAllSlicesStridedFactoryForInt64
{
fnT get()
{
Expand All @@ -772,7 +823,18 @@ template <typename fnT, typename T> struct MaskPlaceAllSlicesStridedFactory
}
};

template <typename fnT, typename T> struct MaskPlaceSomeSlicesStridedFactory
template <typename fnT, typename T>
struct MaskPlaceSomeSlicesStridedFactoryForInt32
{
fnT get()
{
fnT fn = masked_place_some_slices_strided_impl<T, std::int32_t>;
return fn;
}
};

template <typename fnT, typename T>
struct MaskPlaceSomeSlicesStridedFactoryForInt64
{
fnT get()
{
Expand All @@ -783,7 +845,17 @@ template <typename fnT, typename T> struct MaskPlaceSomeSlicesStridedFactory

// Non-zero

// Kernel name is a class template parameterized by the cumulative-sum
// element type (T1) and the output index type (T2) so that each
// (T1, T2) instantiation of non_zero_indexes_impl gets a distinct
// SYCL kernel name.
// NOTE: the stale non-template forward declaration
// `class non_zero_indexes_krn;` was removed — redeclaring the same name
// as a class template is ill-formed and would not compile.
template <typename T1, typename T2> class non_zero_indexes_krn;

// Type-erased function-pointer signature for non_zero_indexes
// implementations, presumably used to select an instantiation at
// runtime (matches this file's factory/dispatch pattern — verify
// against the dispatch-table population code).
typedef sycl::event (*non_zero_indexes_fn_ptr_t)(
    sycl::queue,
    py::ssize_t,
    py::ssize_t,
    int,
    const char *,
    char *,
    const py::ssize_t *,
    std::vector<sycl::event> const &);

template <typename indT1, typename indT2>
sycl::event non_zero_indexes_impl(sycl::queue exec_q,
Expand All @@ -800,28 +872,29 @@ sycl::event non_zero_indexes_impl(sycl::queue exec_q,

sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.parallel_for<class non_zero_indexes_krn>(
sycl::range<1>(iter_size), [=](sycl::id<1> idx) {
auto i = idx[0];

auto cs_curr_val = cumsum_data[i] - 1;
auto cs_prev_val = (i > 0) ? cumsum_data[i - 1] : indT1(0);
bool cond = (cs_curr_val == cs_prev_val);

py::ssize_t i_ = static_cast<py::ssize_t>(i);
for (int dim = nd; --dim > 0;) {
auto sd = mask_shape[dim];
py::ssize_t q = i_ / sd;
py::ssize_t r = (i_ - q * sd);
if (cond) {
indexes_data[cs_curr_val + dim * nz_elems] =
static_cast<indT2>(r);
}
i_ = q;
}
cgh.parallel_for<class non_zero_indexes_krn<indT1, indT2>>(
sycl::range<1>(iter_size), [=](sycl::id<1> idx)
{
auto i = idx[0];

auto cs_curr_val = cumsum_data[i] - 1;
auto cs_prev_val = (i > 0) ? cumsum_data[i - 1] : indT1(0);
bool cond = (cs_curr_val == cs_prev_val);

py::ssize_t i_ = static_cast<py::ssize_t>(i);
for (int dim = nd; --dim > 0;) {
auto sd = mask_shape[dim];
py::ssize_t q = i_ / sd;
py::ssize_t r = (i_ - q * sd);
if (cond) {
indexes_data[cs_curr_val] = static_cast<indT2>(i_);
indexes_data[cs_curr_val + dim * nz_elems] =
static_cast<indT2>(r);
}
i_ = q;
}
if (cond) {
indexes_data[cs_curr_val] = static_cast<indT2>(i_);
}
});
});

Expand Down
Loading