Skip to content

Commit 2cd3fb0

Browse files
antonwolfyvtavana
authored andcommitted
Merge branch 'master' into use_dpctl_bitwise_op
2 parents 7bcc426 + 1f15ccc commit 2cd3fb0

File tree

3 files changed

+119
-62
lines changed

3 files changed

+119
-62
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,14 @@ def dpnp_add(x1, x2, out=None, order="K"):
212212
"""
213213

214214

215+
bitwise_and_func = BinaryElementwiseFunc(
216+
"bitwise_and",
217+
ti._bitwise_and_result_type,
218+
ti._bitwise_and,
219+
_bitwise_and_docstring_,
220+
)
221+
222+
215223
def dpnp_bitwise_and(x1, x2, out=None, order="K"):
216224
"""Invokes bitwise_and() from dpctl.tensor implementation for bitwise_and() function."""
217225

@@ -220,13 +228,9 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"):
220228
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
221229
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
222230

223-
func = BinaryElementwiseFunc(
224-
"bitwise_and",
225-
ti._bitwise_and_result_type,
226-
ti._bitwise_and,
227-
_bitwise_and_docstring_,
231+
res_usm = bitwise_and_func(
232+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
228233
)
229-
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
230234
return dpnp_array._create_from_usm_ndarray(res_usm)
231235

232236

@@ -256,6 +260,14 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"):
256260
"""
257261

258262

263+
bitwise_or_func = BinaryElementwiseFunc(
264+
"bitwise_or",
265+
ti._bitwise_or_result_type,
266+
ti._bitwise_or,
267+
_bitwise_or_docstring_,
268+
)
269+
270+
259271
def dpnp_bitwise_or(x1, x2, out=None, order="K"):
260272
"""Invokes bitwise_or() from dpctl.tensor implementation for bitwise_or() function."""
261273

@@ -264,13 +276,9 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"):
264276
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
265277
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
266278

267-
func = BinaryElementwiseFunc(
268-
"bitwise_or",
269-
ti._bitwise_or_result_type,
270-
ti._bitwise_or,
271-
_bitwise_or_docstring_,
279+
res_usm = bitwise_or_func(
280+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
272281
)
273-
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
274282
return dpnp_array._create_from_usm_ndarray(res_usm)
275283

276284

@@ -300,6 +308,14 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"):
300308
"""
301309

302310

311+
bitwise_xor_func = BinaryElementwiseFunc(
312+
"bitwise_xor",
313+
ti._bitwise_xor_result_type,
314+
ti._bitwise_xor,
315+
_bitwise_xor_docstring_,
316+
)
317+
318+
303319
def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
304320
"""Invokes bitwise_xor() from dpctl.tensor implementation for bitwise_xor() function."""
305321

@@ -308,13 +324,9 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
308324
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
309325
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
310326

311-
func = BinaryElementwiseFunc(
312-
"bitwise_xor",
313-
ti._bitwise_xor_result_type,
314-
ti._bitwise_xor,
315-
_bitwise_xor_docstring_,
327+
res_usm = bitwise_xor_func(
328+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
316329
)
317-
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
318330
return dpnp_array._create_from_usm_ndarray(res_usm)
319331

320332

@@ -629,20 +641,22 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"):
629641
"""
630642

631643

644+
invert_func = UnaryElementwiseFunc(
645+
"invert",
646+
ti._bitwise_invert_result_type,
647+
ti._bitwise_invert,
648+
_invert_docstring,
649+
)
650+
651+
632652
def dpnp_invert(x, out=None, order="K"):
633653
"""Invokes bitwise_invert() from dpctl.tensor implementation for invert() function."""
634654

635655
# dpctl.tensor only works with usm_ndarray or scalar
636656
x_usm = dpnp.get_usm_ndarray(x)
637657
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
638658

639-
func = UnaryElementwiseFunc(
640-
"invert",
641-
ti._bitwise_invert_result_type,
642-
ti._bitwise_invert,
643-
_invert_docstring,
644-
)
645-
res_usm = func(x_usm, out=out_usm, order=order)
659+
res_usm = invert_func(x_usm, out=out_usm, order=order)
646660
return dpnp_array._create_from_usm_ndarray(res_usm)
647661

648662

@@ -778,6 +792,14 @@ def dpnp_isnan(x, out=None, order="K"):
778792
"""
779793

780794

795+
left_shift_func = BinaryElementwiseFunc(
796+
"bitwise_leftt_shift",
797+
ti._bitwise_left_shift_result_type,
798+
ti._bitwise_left_shift,
799+
_left_shift_docstring_,
800+
)
801+
802+
781803
def dpnp_left_shift(x1, x2, out=None, order="K"):
782804
"""Invokes bitwise_left_shift() from dpctl.tensor implementation for left_shift() function."""
783805

@@ -786,13 +808,9 @@ def dpnp_left_shift(x1, x2, out=None, order="K"):
786808
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
787809
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
788810

789-
func = BinaryElementwiseFunc(
790-
"bitwise_leftt_shift",
791-
ti._bitwise_left_shift_result_type,
792-
ti._bitwise_left_shift,
793-
_left_shift_docstring_,
811+
res_usm = left_shift_func(
812+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
794813
)
795-
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
796814
return dpnp_array._create_from_usm_ndarray(res_usm)
797815

798816

@@ -1199,6 +1217,14 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
11991217
"""
12001218

12011219

1220+
right_shift_func = BinaryElementwiseFunc(
1221+
"bitwise_right_shift",
1222+
ti._bitwise_right_shift_result_type,
1223+
ti._bitwise_right_shift,
1224+
_right_shift_docstring_,
1225+
)
1226+
1227+
12021228
def dpnp_right_shift(x1, x2, out=None, order="K"):
12031229
"""Invokes bitwise_right_shift() from dpctl.tensor implementation for right_shift() function."""
12041230

@@ -1207,13 +1233,9 @@ def dpnp_right_shift(x1, x2, out=None, order="K"):
12071233
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
12081234
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
12091235

1210-
func = BinaryElementwiseFunc(
1211-
"bitwise_right_shift",
1212-
ti._bitwise_right_shift_result_type,
1213-
ti._bitwise_right_shift,
1214-
_right_shift_docstring_,
1236+
res_usm = right_shift_func(
1237+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
12151238
)
1216-
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
12171239
return dpnp_array._create_from_usm_ndarray(res_usm)
12181240

12191241

tests/test_extensions.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,16 @@ def test_mean_over_axis_0_unsupported_out_types(
193193
input = dpt.empty((height, width), dtype=input_type, device=device)
194194
output = dpt.empty(width, dtype=output_type, device=device)
195195

196-
if func(input, output):
197-
print(output_type)
198196
assert func(input, output) is None
199197

200198

201199
@pytest.mark.parametrize(
202200
"func, device, input_type, output_type",
203201
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
204202
)
205-
def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
203+
def test_mean_sum_over_axis_0_f_contig_input(
204+
func, device, input_type, output_type
205+
):
206206
skip_unsupported(device, input_type)
207207
skip_unsupported(device, output_type)
208208

@@ -212,16 +212,14 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
212212
input = dpt.empty((height, width), dtype=input_type, device=device).T
213213
output = dpt.empty(width, dtype=output_type, device=device)
214214

215-
if func(input, output):
216-
print(output_type)
217215
assert func(input, output) is None
218216

219217

220218
@pytest.mark.parametrize(
221219
"func, device, input_type, output_type",
222220
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
223221
)
224-
def test_mean_over_axis_0_f_contig_output(
222+
def test_mean_sum_over_axis_0_f_contig_output(
225223
func, device, input_type, output_type
226224
):
227225
skip_unsupported(device, input_type)
@@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output(
230228
height = 1
231229
width = 10
232230

233-
input = dpt.empty((height, 10), dtype=input_type, device=device)
234-
output = dpt.empty(20, dtype=output_type, device=device)[::2]
231+
input = dpt.empty((height, width), dtype=input_type, device=device)
232+
output = dpt.empty(width * 2, dtype=output_type, device=device)[::2]
233+
234+
assert func(input, output) is None
235+
236+
237+
@pytest.mark.parametrize(
238+
"func, device, input_type, output_type",
239+
product(mean_sum, all_devices, [dpt.float32], [dpt.float32, dpt.float64]),
240+
)
241+
def test_mean_sum_over_axis_0_big_output(func, device, input_type, output_type):
242+
skip_unsupported(device, input_type)
243+
skip_unsupported(device, output_type)
244+
245+
local_mem_size = device.local_mem_size
246+
height = 1
247+
width = 1 + local_mem_size // output_type.itemsize
248+
249+
input = dpt.empty((height, width), dtype=input_type, device=device)
250+
output = dpt.empty(width, dtype=output_type, device=device)
235251

236-
if func(input, output):
237-
print(output_type)
238252
assert func(input, output) is None

tests/test_mathematical.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from itertools import permutations
2+
13
import numpy
24
import pytest
35
from numpy.testing import (
@@ -1056,23 +1058,42 @@ def test_sum_empty_out(dtype):
10561058

10571059

10581060
@pytest.mark.parametrize(
1059-
"shape", [(), (1, 2, 3), (1, 0, 2), (10), (3, 3, 3), (5, 5), (0, 6)]
1060-
)
1061-
@pytest.mark.parametrize(
1062-
"dtype_in", get_all_dtypes(no_complex=True, no_bool=True)
1063-
)
1064-
@pytest.mark.parametrize(
1065-
"dtype_out", get_all_dtypes(no_complex=True, no_bool=True)
1061+
"shape",
1062+
[
1063+
(),
1064+
(1, 2, 3),
1065+
(1, 0, 2),
1066+
(10,),
1067+
(3, 3, 3),
1068+
(5, 5),
1069+
(0, 6),
1070+
(10, 1),
1071+
(1, 10),
1072+
],
10661073
)
1067-
def test_sum(shape, dtype_in, dtype_out):
1068-
a_np = numpy.ones(shape, dtype=dtype_in)
1069-
a = dpnp.ones(shape, dtype=dtype_in)
1070-
axes = [None, 0, 1, 2]
1074+
@pytest.mark.parametrize("dtype_in", get_all_dtypes())
1075+
@pytest.mark.parametrize("dtype_out", get_all_dtypes())
1076+
@pytest.mark.parametrize("transpose", [True, False])
1077+
@pytest.mark.parametrize("keepdims", [True, False])
1078+
def test_sum(shape, dtype_in, dtype_out, transpose, keepdims):
1079+
size = numpy.prod(shape)
1080+
a_np = numpy.arange(size).astype(dtype_in).reshape(shape)
1081+
a = dpnp.asarray(a_np)
1082+
1083+
if transpose:
1084+
a_np = a_np.T
1085+
a = a.T
1086+
1087+
axes_range = list(numpy.arange(len(shape)))
1088+
axes = [None]
1089+
axes += axes_range
1090+
axes += permutations(axes_range, 2)
1091+
axes.append(tuple(axes_range))
1092+
10711093
for axis in axes:
1072-
if axis is None or axis < a.ndim:
1073-
numpy_res = a_np.sum(axis=axis, dtype=dtype_out)
1074-
dpnp_res = a.sum(axis=axis, dtype=dtype_out)
1075-
assert_array_equal(numpy_res, dpnp_res.asnumpy())
1094+
numpy_res = a_np.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
1095+
dpnp_res = a.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
1096+
assert_array_equal(numpy_res, dpnp_res.asnumpy())
10761097

10771098

10781099
class TestMean:

0 commit comments

Comments
 (0)