Skip to content

Commit 438ded5

Browse files
Merge pull request #1323 from IntelPython/nonzero-regression
Nonzero regression
2 parents 79994c1 + 3fe2718 commit 438ded5

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,10 @@ def _nonzero_impl(ary):
517517
mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C"
518518
)
519519
mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q)
520+
indexes_dt = ti.default_device_int_type(exec_q.sycl_device)
520521
indexes = dpt.empty(
521522
(ary.ndim, mask_count),
522-
dtype=cumsum.dtype,
523+
dtype=indexes_dt,
523524
usm_type=usm_type,
524525
sycl_queue=exec_q,
525526
order="C",

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,3 +1345,15 @@ def test_nonzero_arg_validation():
13451345
dpt.nonzero(list())
13461346
with pytest.raises(ValueError):
13471347
dpt.nonzero(dpt.asarray(1))
1348+
1349+
1350+
def test_nonzero_dtype():
1351+
"See gh-1322"
1352+
get_queue_or_skip()
1353+
x = dpt.ones((3, 4))
1354+
idx, idy = dpt.nonzero(x)
1355+
# create array using device's
1356+
# default integral data type
1357+
ref = dpt.arange(8)
1358+
assert idx.dtype == ref.dtype
1359+
assert idy.dtype == ref.dtype

0 commit comments

Comments
 (0)