Skip to content

Commit f35b34d

Browse files
authored
Merge pull request #1238 from vlad-perevezentsev/fix_dpctl_sum
Fix dpctl.tensor.sum for support zero-dimensional array
2 parents 47466ee + 9579d8a commit f35b34d

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

dpctl/tensor/_reduction.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,12 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
122122
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
123123

124124
res_usm_type = arr.usm_type
125-
if red_nd == 0:
125+
if arr.size == 0:
126126
return dpt.zeros(
127127
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
128128
)
129+
if red_nd == 0:
130+
return dpt.astype(arr, res_dt, copy=False)
129131

130132
host_tasks_list = []
131133
if ti._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q):

dpctl/tests/test_tensor_sum.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,30 @@ def test_sum_keepdims():
106106
assert isinstance(s, dpt.usm_ndarray)
107107
assert s.shape == (3, 1, 1, 6, 1)
108108
assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all()
109+
110+
111+
def test_sum_scalar():
112+
get_queue_or_skip()
113+
114+
m = dpt.ones(())
115+
s = dpt.sum(m)
116+
117+
assert isinstance(s, dpt.usm_ndarray)
118+
assert m.sycl_queue == s.sycl_queue
119+
assert s.shape == ()
120+
assert dpt.asnumpy(s) == np.full((), 1)
121+
122+
123+
@pytest.mark.parametrize("arg_dtype", _all_dtypes)
124+
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
125+
def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype):
126+
q = get_queue_or_skip()
127+
skip_if_dtype_not_supported(arg_dtype, q)
128+
skip_if_dtype_not_supported(out_dtype, q)
129+
130+
m = dpt.ones((), dtype=arg_dtype)
131+
r = dpt.sum(m, dtype=out_dtype)
132+
133+
assert isinstance(r, dpt.usm_ndarray)
134+
assert r.dtype == dpt.dtype(out_dtype)
135+
assert dpt.asnumpy(r) == 1

0 commit comments

Comments
 (0)