Skip to content

Commit 7a62edb

Browse files
committed
Moved shape and stride assertions for empty orthogonal directions
1 parent 4c753e2 commit 7a62edb

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,9 @@ py_extract(dpctl::tensor::usm_ndarray src,
385385
auto fn =
386386
masked_extract_all_slices_strided_impl_dispatch_vector[src_typeid];
387387

388+
assert(dst_shape_vec.size() == 1);
389+
assert(dst_strides_vec.size() == 1);
390+
388391
using dpctl::tensor::offset_utils::device_allocate_and_pack;
389392
const auto &ptr_size_event_tuple1 =
390393
device_allocate_and_pack<py::ssize_t>(
@@ -397,9 +400,6 @@ py_extract(dpctl::tensor::usm_ndarray src,
397400
sycl::event copy_src_shape_strides_ev =
398401
std::get<2>(ptr_size_event_tuple1);
399402

400-
assert(dst_shape_vec.size() == 1);
401-
assert(dst_strides_vec.size() == 1);
402-
403403
std::vector<sycl::event> all_deps;
404404
all_deps.reserve(depends.size() + 1);
405405
all_deps.insert(all_deps.end(), depends.begin(), depends.end());
@@ -680,6 +680,9 @@ py_place(dpctl::tensor::usm_ndarray dst,
680680
auto fn =
681681
masked_place_all_slices_strided_impl_dispatch_vector[dst_typeid];
682682

683+
assert(rhs_shape_vec.size() == 1);
684+
assert(rhs_strides_vec.size() == 1);
685+
683686
using dpctl::tensor::offset_utils::device_allocate_and_pack;
684687
const auto &ptr_size_event_tuple1 =
685688
device_allocate_and_pack<py::ssize_t>(
@@ -692,9 +695,6 @@ py_place(dpctl::tensor::usm_ndarray dst,
692695
sycl::event copy_dst_shape_strides_ev =
693696
std::get<2>(ptr_size_event_tuple1);
694697

695-
assert(rhs_shape_vec.size() == 1);
696-
assert(rhs_strides_vec.size() == 1);
697-
698698
std::vector<sycl::event> all_deps;
699699
all_deps.reserve(depends.size() + 1);
700700
all_deps.insert(all_deps.end(), depends.begin(), depends.end());

0 commit comments

Comments
 (0)