Skip to content

Commit 06533eb

Browse files
authored
Add operation __index__ and __complex__ (#1285)
* Add operation __index__ and __complex__ * Add tests
1 parent 9308f64 commit 06533eb

File tree

2 files changed

+47
-15
lines changed

2 files changed

+47
-15
lines changed

dpnp/dpnp_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,10 @@ def __bool__(self):
140140
return self._array_obj.__bool__()
141141

142142
# '__class__',
143-
# '__complex__',
143+
144+
def __complex__(self):
145+
return self._array_obj.__complex__()
146+
144147
# '__contains__',
145148
# '__copy__',
146149
# '__deepcopy__',
@@ -187,7 +190,10 @@ def __gt__(self, other):
187190
# '__imatmul__',
188191
# '__imod__',
189192
# '__imul__',
190-
# '__index__',
193+
194+
def __index__(self):
195+
return self._array_obj.__index__()
196+
191197
# '__init__',
192198
# '__init_subclass__',
193199

tests/test_dparray.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
import dpnp
2-
import numpy
31
import pytest
2+
from .helper import get_all_dtypes
3+
4+
import dpnp
45
import dpctl.tensor as dpt
56

7+
import numpy
8+
from numpy.testing import (
9+
assert_array_equal
10+
)
11+
612

7-
@pytest.mark.parametrize("res_dtype",
8-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool_, numpy.complex_],
9-
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'complex'])
10-
@pytest.mark.parametrize("arr_dtype",
11-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool_, numpy.complex_],
12-
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'complex'])
13+
@pytest.mark.parametrize("res_dtype", get_all_dtypes())
14+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
1315
@pytest.mark.parametrize("arr",
1416
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],
1517
ids=['[-2, -1, 0, 1, 2]', '[[-2, -1], [1, 2]]', '[]'])
@@ -18,12 +20,10 @@ def test_astype(arr, arr_dtype, res_dtype):
1820
dpnp_array = dpnp.array(numpy_array)
1921
expected = numpy_array.astype(res_dtype)
2022
result = dpnp_array.astype(res_dtype)
21-
numpy.testing.assert_array_equal(expected, result)
23+
assert_array_equal(expected, result)
2224

2325

24-
@pytest.mark.parametrize("arr_dtype",
25-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool_, numpy.complex_],
26-
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'complex'])
26+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
2727
@pytest.mark.parametrize("arr",
2828
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],
2929
ids=['[-2, -1, 0, 1, 2]', '[[-2, -1], [1, 2]]', '[]'])
@@ -32,7 +32,7 @@ def test_flatten(arr, arr_dtype):
3232
dpnp_array = dpnp.array(arr, dtype=arr_dtype)
3333
expected = numpy_array.flatten()
3434
result = dpnp_array.flatten()
35-
numpy.testing.assert_array_equal(expected, result)
35+
assert_array_equal(expected, result)
3636

3737

3838
@pytest.mark.parametrize("shape",
@@ -68,3 +68,29 @@ def test_flags_strides(dtype, order, strides):
6868
assert usm_array.flags == dpnp_array.flags
6969
assert numpy_array.flags.c_contiguous == dpnp_array.flags.c_contiguous
7070
assert numpy_array.flags.f_contiguous == dpnp_array.flags.f_contiguous
71+
72+
73+
@pytest.mark.parametrize("func", [bool, float, int, complex])
74+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
75+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False, no_complex=True))
76+
def test_scalar_type_casting(func, shape, dtype):
77+
numpy_array = numpy.full(shape, 5, dtype=dtype)
78+
dpnp_array = dpnp.full(shape, 5, dtype=dtype)
79+
assert func(numpy_array) == func(dpnp_array)
80+
81+
82+
@pytest.mark.parametrize("method", ["__bool__", "__float__", "__int__", "__complex__"])
83+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
84+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False, no_complex=True, no_none=True))
85+
def test_scalar_type_casting_by_method(method, shape, dtype):
86+
numpy_array = numpy.full(shape, 4.7, dtype=dtype)
87+
dpnp_array = dpnp.full(shape, 4.7, dtype=dtype)
88+
assert getattr(numpy_array, method)() == getattr(dpnp_array, method)()
89+
90+
91+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
92+
@pytest.mark.parametrize("index_dtype", [dpnp.int32, dpnp.int64])
93+
def test_array_as_index(shape, index_dtype):
94+
ind_arr = dpnp.ones(shape, dtype=index_dtype)
95+
a = numpy.arange(ind_arr.size + 1)
96+
assert a[tuple(ind_arr)] == a[1]

0 commit comments

Comments
 (0)