Skip to content

Commit

Permalink
Merge branch 'master' into use_dpctl_bitwise_op
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy authored and vtavana committed Aug 14, 2023
2 parents 7bcc426 + 1f15ccc commit 77bb04d
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ repos:
rev: 23.7.0
hooks:
- id: black
args: ["--check", "--diff", "--color"]
args: ["--check", "--diff","--color"]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
Expand Down
96 changes: 59 additions & 37 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ def dpnp_add(x1, x2, out=None, order="K"):
"""


bitwise_and_func = BinaryElementwiseFunc(
"bitwise_and",
ti._bitwise_and_result_type,
ti._bitwise_and,
_bitwise_and_docstring_,
)


def dpnp_bitwise_and(x1, x2, out=None, order="K"):
"""Invokes bitwise_and() from dpctl.tensor implementation for bitwise_and() function."""

Expand All @@ -220,13 +228,9 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"):
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc(
"bitwise_and",
ti._bitwise_and_result_type,
ti._bitwise_and,
_bitwise_and_docstring_,
res_usm = bitwise_and_func(
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


Expand Down Expand Up @@ -256,6 +260,14 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"):
"""


bitwise_or_func = BinaryElementwiseFunc(
"bitwise_or",
ti._bitwise_or_result_type,
ti._bitwise_or,
_bitwise_or_docstring_,
)


def dpnp_bitwise_or(x1, x2, out=None, order="K"):
"""Invokes bitwise_or() from dpctl.tensor implementation for bitwise_or() function."""

Expand All @@ -264,13 +276,9 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"):
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc(
"bitwise_or",
ti._bitwise_or_result_type,
ti._bitwise_or,
_bitwise_or_docstring_,
res_usm = bitwise_or_func(
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


Expand Down Expand Up @@ -300,6 +308,14 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"):
"""


bitwise_xor_func = BinaryElementwiseFunc(
"bitwise_xor",
ti._bitwise_xor_result_type,
ti._bitwise_xor,
_bitwise_xor_docstring_,
)


def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
"""Invokes bitwise_xor() from dpctl.tensor implementation for bitwise_xor() function."""

Expand All @@ -308,13 +324,9 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc(
"bitwise_xor",
ti._bitwise_xor_result_type,
ti._bitwise_xor,
_bitwise_xor_docstring_,
res_usm = bitwise_xor_func(
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


Expand Down Expand Up @@ -629,20 +641,22 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"):
"""


invert_func = UnaryElementwiseFunc(
"invert",
ti._bitwise_invert_result_type,
ti._bitwise_invert,
_invert_docstring,
)


def dpnp_invert(x, out=None, order="K"):
"""Invokes bitwise_invert() from dpctl.tensor implementation for invert() function."""

# dpctl.tensor only works with usm_ndarray or scalar
x_usm = dpnp.get_usm_ndarray(x)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = UnaryElementwiseFunc(
"invert",
ti._bitwise_invert_result_type,
ti._bitwise_invert,
_invert_docstring,
)
res_usm = func(x_usm, out=out_usm, order=order)
res_usm = invert_func(x_usm, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


Expand Down Expand Up @@ -778,6 +792,14 @@ def dpnp_isnan(x, out=None, order="K"):
"""


left_shift_func = BinaryElementwiseFunc(
"bitwise_leftt_shift",
ti._bitwise_left_shift_result_type,
ti._bitwise_left_shift,
_left_shift_docstring_,
)


def dpnp_left_shift(x1, x2, out=None, order="K"):
"""Invokes bitwise_left_shift() from dpctl.tensor implementation for left_shift() function."""

Expand All @@ -786,13 +808,9 @@ def dpnp_left_shift(x1, x2, out=None, order="K"):
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc(
"bitwise_leftt_shift",
ti._bitwise_left_shift_result_type,
ti._bitwise_left_shift,
_left_shift_docstring_,
res_usm = left_shift_func(
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


Expand Down Expand Up @@ -1199,6 +1217,14 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
"""


right_shift_func = BinaryElementwiseFunc(
"bitwise_right_shift",
ti._bitwise_right_shift_result_type,
ti._bitwise_right_shift,
_right_shift_docstring_,
)


def dpnp_right_shift(x1, x2, out=None, order="K"):
"""Invokes bitwise_right_shift() from dpctl.tensor implementation for right_shift() function."""

Expand All @@ -1207,13 +1233,9 @@ def dpnp_right_shift(x1, x2, out=None, order="K"):
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc(
"bitwise_right_shift",
ti._bitwise_right_shift_result_type,
ti._bitwise_right_shift,
_right_shift_docstring_,
res_usm = right_shift_func(
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
)
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


Expand Down
34 changes: 24 additions & 10 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,16 @@ def test_mean_over_axis_0_unsupported_out_types(
input = dpt.empty((height, width), dtype=input_type, device=device)
output = dpt.empty(width, dtype=output_type, device=device)

if func(input, output):
print(output_type)
assert func(input, output) is None


@pytest.mark.parametrize(
"func, device, input_type, output_type",
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
)
def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
def test_mean_sum_over_axis_0_f_contig_input(
func, device, input_type, output_type
):
skip_unsupported(device, input_type)
skip_unsupported(device, output_type)

Expand All @@ -212,16 +212,14 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
input = dpt.empty((height, width), dtype=input_type, device=device).T
output = dpt.empty(width, dtype=output_type, device=device)

if func(input, output):
print(output_type)
assert func(input, output) is None


@pytest.mark.parametrize(
"func, device, input_type, output_type",
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
)
def test_mean_over_axis_0_f_contig_output(
def test_mean_sum_over_axis_0_f_contig_output(
func, device, input_type, output_type
):
skip_unsupported(device, input_type)
Expand All @@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output(
height = 1
width = 10

input = dpt.empty((height, 10), dtype=input_type, device=device)
output = dpt.empty(20, dtype=output_type, device=device)[::2]
input = dpt.empty((height, width), dtype=input_type, device=device)
output = dpt.empty(width * 2, dtype=output_type, device=device)[::2]

assert func(input, output) is None


@pytest.mark.parametrize(
"func, device, input_type, output_type",
product(mean_sum, all_devices, [dpt.float32], [dpt.float32, dpt.float64]),
)
def test_mean_sum_over_axis_0_big_output(func, device, input_type, output_type):
skip_unsupported(device, input_type)
skip_unsupported(device, output_type)

local_mem_size = device.local_mem_size
height = 1
width = 1 + local_mem_size // output_type.itemsize

input = dpt.empty((height, width), dtype=input_type, device=device)
output = dpt.empty(width, dtype=output_type, device=device)

if func(input, output):
print(output_type)
assert func(input, output) is None
51 changes: 36 additions & 15 deletions tests/test_mathematical.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import permutations

import numpy
import pytest
from numpy.testing import (
Expand Down Expand Up @@ -1056,23 +1058,42 @@ def test_sum_empty_out(dtype):


@pytest.mark.parametrize(
"shape", [(), (1, 2, 3), (1, 0, 2), (10), (3, 3, 3), (5, 5), (0, 6)]
)
@pytest.mark.parametrize(
"dtype_in", get_all_dtypes(no_complex=True, no_bool=True)
)
@pytest.mark.parametrize(
"dtype_out", get_all_dtypes(no_complex=True, no_bool=True)
"shape",
[
(),
(1, 2, 3),
(1, 0, 2),
(10,),
(3, 3, 3),
(5, 5),
(0, 6),
(10, 1),
(1, 10),
],
)
def test_sum(shape, dtype_in, dtype_out):
a_np = numpy.ones(shape, dtype=dtype_in)
a = dpnp.ones(shape, dtype=dtype_in)
axes = [None, 0, 1, 2]
@pytest.mark.parametrize("dtype_in", get_all_dtypes())
@pytest.mark.parametrize("dtype_out", get_all_dtypes())
@pytest.mark.parametrize("transpose", [True, False])
@pytest.mark.parametrize("keepdims", [True, False])
def test_sum(shape, dtype_in, dtype_out, transpose, keepdims):
size = numpy.prod(shape)
a_np = numpy.arange(size).astype(dtype_in).reshape(shape)
a = dpnp.asarray(a_np)

if transpose:
a_np = a_np.T
a = a.T

axes_range = list(numpy.arange(len(shape)))
axes = [None]
axes += axes_range
axes += permutations(axes_range, 2)
axes.append(tuple(axes_range))

for axis in axes:
if axis is None or axis < a.ndim:
numpy_res = a_np.sum(axis=axis, dtype=dtype_out)
dpnp_res = a.sum(axis=axis, dtype=dtype_out)
assert_array_equal(numpy_res, dpnp_res.asnumpy())
numpy_res = a_np.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
dpnp_res = a.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
assert_array_equal(numpy_res, dpnp_res.asnumpy())


class TestMean:
Expand Down

0 comments on commit 77bb04d

Please sign in to comment.