Skip to content

Commit 4c753e2

Browse files
committed
Masked shape and stride assertions moved to before device_allocate_and_pack call
1 parent 4f16b93 commit 4c753e2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,9 @@ py_extract(dpctl::tensor::usm_ndarray src,
469469
simplified_ortho_shape, simplified_ortho_src_strides,
470470
simplified_ortho_dst_strides, ortho_src_offset, ortho_dst_offset);
471471

472+
assert(masked_dst_shape.size() == 1);
473+
assert(masked_dst_strides.size() == 1);
474+
472475
using dpctl::tensor::offset_utils::device_allocate_and_pack;
473476
const auto &ptr_size_event_tuple1 =
474477
device_allocate_and_pack<py::ssize_t>(
@@ -485,9 +488,6 @@ py_extract(dpctl::tensor::usm_ndarray src,
485488
py::ssize_t *packed_masked_src_shape_strides =
486489
packed_shapes_strides + (3 * ortho_nd);
487490

488-
assert(masked_dst_shape.size() == 1);
489-
assert(masked_dst_strides.size() == 1);
490-
491491
std::vector<sycl::event> all_deps;
492492
all_deps.reserve(depends.size() + 1);
493493
all_deps.insert(all_deps.end(), depends.begin(), depends.end());
@@ -763,6 +763,9 @@ py_place(dpctl::tensor::usm_ndarray dst,
763763
simplified_ortho_shape, simplified_ortho_dst_strides,
764764
simplified_ortho_rhs_strides, ortho_dst_offset, ortho_rhs_offset);
765765

766+
assert(masked_rhs_shape.size() == 1);
767+
assert(masked_rhs_strides.size() == 1);
768+
766769
using dpctl::tensor::offset_utils::device_allocate_and_pack;
767770
const auto &ptr_size_event_tuple1 =
768771
device_allocate_and_pack<py::ssize_t>(
@@ -779,9 +782,6 @@ py_place(dpctl::tensor::usm_ndarray dst,
779782
py::ssize_t *packed_masked_dst_shape_strides =
780783
packed_shapes_strides + (3 * ortho_nd);
781784

782-
assert(masked_rhs_shape.size() == 1);
783-
assert(masked_rhs_strides.size() == 1);
784-
785785
std::vector<sycl::event> all_deps;
786786
all_deps.reserve(depends.size() + 1);
787787
all_deps.insert(all_deps.end(), depends.begin(), depends.end());

0 commit comments

Comments
 (0)