Skip to content

Commit

Permalink
following NEP-50 for dpnp.einsum (#2120)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana authored Oct 24, 2024
1 parent 7c45c10 commit 0fbf329
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 22 deletions.
2 changes: 1 addition & 1 deletion dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def correlate(x1, x2, mode="valid"):
-----------
Input arrays are supported as :obj:`dpnp.ndarray`.
Size and shape of input arrays are supported to be equal.
Parameter `mode` is supported only with default value ``"valid``.
Parameter `mode` is supported only with default value ``"valid"``.
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Expand Down
16 changes: 7 additions & 9 deletions dpnp/dpnp_utils/dpnp_utils_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
from dpctl.utils import ExecutionPlacementError

import dpnp
from dpnp.dpnp_utils import get_usm_allocations

from ..dpnp_array import dpnp_array
from dpnp.dpnp_array import dpnp_array
from dpnp.dpnp_utils import get_usm_allocations, map_dtype_to_device

_einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

Expand Down Expand Up @@ -1027,17 +1026,16 @@ def dpnp_einsum(
"Input and output allocation queues are not compatible"
)

result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype
for id, a in enumerate(operands):
if dpnp.isscalar(a):
scalar_dtype = map_dtype_to_device(type(a), exec_q.sycl_device)
operands[id] = dpnp.array(
a, dtype=result_dtype, usm_type=res_usm_type, sycl_queue=exec_q
a, dtype=scalar_dtype, usm_type=res_usm_type, sycl_queue=exec_q
)
arrays.append(operands[id])
result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype
if order in ["a", "A"]:
order = (
"F" if not any(arr.flags.c_contiguous for arr in arrays) else "C"
)
if order in "aA":
order = "F" if all(arr.flags.fnc for arr in arrays) else "C"

input_subscripts = [
_parse_ellipsis_subscript(sub, idx, ndim=arr.ndim)
Expand Down
12 changes: 5 additions & 7 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,14 +1139,12 @@ def check_einsum_sums(self, dtype, do_opt=False):
result = inp.einsum(*args, dtype="?", casting="unsafe", optimize=do_opt)
assert_dtype_allclose(result, expected)

# with an scalar, NumPy < 2.0.0 uses the other input arrays to
# determine the output type while for NumPy > 2.0.0 the scalar
# with default machine dtype is used to determine the output
# data type
# NumPy >= 2.0 follows NEP-50 to determine the output dtype when one of
# the inputs is a scalar while NumPy < 2.0 does not
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
check_type = True
else:
check_type = False
else:
check_type = True
a = numpy.arange(9, dtype=dtype)
a_dp = inp.array(a)
expected = numpy.einsum(",i->", 3, a)
Expand Down Expand Up @@ -1712,7 +1710,7 @@ def test_broadcasting_dot_cases(self):

def test_output_order(self):
# Ensure output order is respected for optimize cases, the below
# conraction should yield a reshaped tensor view
# contraction should yield a reshaped tensor view
a = inp.ones((2, 3, 5), order="F")
b = inp.ones((4, 3), order="F")

Expand Down
9 changes: 4 additions & 5 deletions tests/third_party/cupy/linalg_tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,12 @@ def test_einsum_binary(self, xp, dtype_a, dtype_b):


class TestEinSumBinaryOperationWithScalar:
# with an scalar, NumPy < 2.0.0 uses the other input arrays to determine
# the output type while for NumPy > 2.0.0 the scalar with default machine
# dtype is used to determine the output type
# NumPy >= 2.0 follows NEP-50 to determine the output dtype when one of
# the inputs is a scalar while NumPy < 2.0 does not
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
type_check = has_support_aspect64()
else:
type_check = False
else:
type_check = has_support_aspect64()

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(contiguous_check=False, type_check=type_check)
Expand Down

0 comments on commit 0fbf329

Please sign in to comment.