Skip to content

Commit c49ab5b

Browse files
committed
Merge branch 'add_by_scalar' of https://github.com/antonwolfy/dpnp into add_by_scalar
2 parents 430eca3 + 398225d commit c49ab5b

File tree

8 files changed

+87
-29
lines changed

8 files changed

+87
-29
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@
111111

112112
MACRO_2ARG_3TYPES_OP(dpnp_add_c,
113113
input1_elem + input2_elem,
114-
sycl::add_sat(x1, x2),
115-
MACRO_UNPACK_TYPES(int, long),
114+
nullptr,
115+
std::false_type,
116116
oneapi::mkl::vm::add,
117117
MACRO_UNPACK_TYPES(float, double, std::complex<float>, std::complex<double>))
118118

dpnp/dpnp_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,10 @@ def __bool__(self):
140140
return self._array_obj.__bool__()
141141

142142
# '__class__',
143-
# '__complex__',
143+
144+
def __complex__(self):
145+
return self._array_obj.__complex__()
146+
144147
# '__contains__',
145148
# '__copy__',
146149
# '__deepcopy__',
@@ -187,7 +190,10 @@ def __gt__(self, other):
187190
# '__imatmul__',
188191
# '__imod__',
189192
# '__imul__',
190-
# '__index__',
193+
194+
def __index__(self):
195+
return self._array_obj.__index__()
196+
191197
# '__init__',
192198
# '__init_subclass__',
193199

dpnp/dpnp_iface_arraycreation.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from dpnp.dpnp_utils import *
4949

5050
import dpnp.dpnp_container as dpnp_container
51+
import dpctl.tensor as dpt
5152

5253

5354
__all__ = [
@@ -530,7 +531,7 @@ def empty_like(x1,
530531
531532
Limitations
532533
-----------
533-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
534+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
534535
Parameter ``order`` is supported with values ``"C"`` or ``"F"``.
535536
Parameter ``subok`` is supported only with default value ``False``.
536537
Otherwise the function will be executed sequentially on CPU.
@@ -552,7 +553,7 @@ def empty_like(x1,
552553
553554
"""
554555

555-
if not isinstance(x1, dpnp.ndarray):
556+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
556557
pass
557558
elif order not in ('C', 'c', 'F', 'f', None):
558559
pass
@@ -762,7 +763,7 @@ def full_like(x1,
762763
763764
Limitations
764765
-----------
765-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
766+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
766767
Parameter ``order`` is supported only with values ``"C"`` and ``"F"``.
767768
Parameter ``subok`` is supported only with default value ``False``.
768769
Otherwise the function will be executed sequentially on CPU.
@@ -783,7 +784,7 @@ def full_like(x1,
783784
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
784785
785786
"""
786-
if not isinstance(x1, dpnp.ndarray):
787+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
787788
pass
788789
elif order not in ('C', 'c', 'F', 'f', None):
789790
pass
@@ -1189,7 +1190,7 @@ def ones_like(x1,
11891190
11901191
Limitations
11911192
-----------
1192-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
1193+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
11931194
Parameter ``order`` is supported with values ``"C"`` or ``"F"``.
11941195
Parameter ``subok`` is supported only with default value ``False``.
11951196
Otherwise the function will be executed sequentially on CPU.
@@ -1211,7 +1212,7 @@ def ones_like(x1,
12111212
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
12121213
12131214
"""
1214-
if not isinstance(x1, dpnp.ndarray):
1215+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
12151216
pass
12161217
elif order not in ('C', 'c', 'F', 'f', None):
12171218
pass
@@ -1502,7 +1503,7 @@ def zeros_like(x1,
15021503
15031504
Limitations
15041505
-----------
1505-
Parameters ``x1`` is supported only as :class:`dpnp.dpnp_array`.
1506+
Parameter ``x1`` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray`
15061507
Parameter ``order`` is supported with values ``"C"`` or ``"F"``.
15071508
Parameter ``subok`` is supported only with default value ``False``.
15081509
Otherwise the function will be executed sequentially on CPU.
@@ -1523,8 +1524,8 @@ def zeros_like(x1,
15231524
>>> [i for i in np.zeros_like(x)]
15241525
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
15251526
1526-
"""
1527-
if not isinstance(x1, dpnp.ndarray):
1527+
"""
1528+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
15281529
pass
15291530
elif order not in ('C', 'c', 'F', 'f', None):
15301531
pass

dpnp/random/dpnp_random_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def randint(self, low, high=None, size=None, dtype=int, usm_type="device"):
337337
if not use_origin_backend(low):
338338
if not dpnp.isscalar(low):
339339
pass
340-
elif not dpnp.isscalar(high):
340+
elif not (high is None or dpnp.isscalar(high)):
341341
pass
342342
else:
343343
_dtype = dpnp.int32 if dtype is int else dpnp.dtype(dtype)

tests/skipped_tests.tbl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,10 @@ tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNu
769769
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_10_{name='remainder', nargs=2}::test_raises_with_numpy_input
770770
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_11_{name='mod', nargs=2}::test_raises_with_numpy_input
771771
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_1_{name='angle', nargs=1}::test_raises_with_numpy_input
772+
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_4_{name='divide', nargs=2}::test_raises_with_numpy_input
773+
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_5_{name='power', nargs=2}::test_raises_with_numpy_input
774+
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_6_{name='subtract', nargs=2}::test_raises_with_numpy_input
775+
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_7_{name='true_divide', nargs=2}::test_raises_with_numpy_input
772776
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_8_{name='floor_divide', nargs=2}::test_raises_with_numpy_input
773777
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_9_{name='fmod', nargs=2}::test_raises_with_numpy_input
774778
tests/third_party/cupy/math_tests/test_arithmetic.py::TestBoolSubtract_param_3_{shape=(), xp=dpnp}::test_bool_subtract

tests/test_arraycreation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,26 @@ def test_ones_like(array, dtype, order):
485485
a = numpy.array(array)
486486
ia = dpnp.array(array)
487487
assert_array_equal(func(numpy, a), func(dpnp, ia))
488+
489+
490+
@pytest.mark.parametrize(
491+
"func, args",
492+
[
493+
pytest.param("full_like",
494+
['x0', '4']),
495+
pytest.param("zeros_like",
496+
['x0']),
497+
pytest.param("ones_like",
498+
['x0']),
499+
pytest.param("empty_like",
500+
['x0']),
501+
])
502+
def test_dpctl_tensor_input(func, args):
503+
x0 = dpt.reshape(dpt.arange(9), (3,3))
504+
new_args = [eval(val, {'x0' : x0}) for val in args]
505+
X = getattr(dpt, func)(*new_args)
506+
Y = getattr(dpnp, func)(*new_args)
507+
if func is 'empty_like':
508+
assert X.shape == Y.shape
509+
else:
510+
assert_array_equal(X, Y)

tests/test_dparray.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
import dpnp
2-
import numpy
31
import pytest
2+
from .helper import get_all_dtypes
3+
4+
import dpnp
45
import dpctl.tensor as dpt
56

7+
import numpy
8+
from numpy.testing import (
9+
assert_array_equal
10+
)
11+
612

7-
@pytest.mark.parametrize("res_dtype",
8-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool_, numpy.complex_],
9-
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'complex'])
10-
@pytest.mark.parametrize("arr_dtype",
11-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool_, numpy.complex_],
12-
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'complex'])
13+
@pytest.mark.parametrize("res_dtype", get_all_dtypes())
14+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
1315
@pytest.mark.parametrize("arr",
1416
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],
1517
ids=['[-2, -1, 0, 1, 2]', '[[-2, -1], [1, 2]]', '[]'])
@@ -18,12 +20,10 @@ def test_astype(arr, arr_dtype, res_dtype):
1820
dpnp_array = dpnp.array(numpy_array)
1921
expected = numpy_array.astype(res_dtype)
2022
result = dpnp_array.astype(res_dtype)
21-
numpy.testing.assert_array_equal(expected, result)
23+
assert_array_equal(expected, result)
2224

2325

24-
@pytest.mark.parametrize("arr_dtype",
25-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool_, numpy.complex_],
26-
ids=['float64', 'float32', 'int64', 'int32', 'bool', 'complex'])
26+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
2727
@pytest.mark.parametrize("arr",
2828
[[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []],
2929
ids=['[-2, -1, 0, 1, 2]', '[[-2, -1], [1, 2]]', '[]'])
@@ -32,7 +32,7 @@ def test_flatten(arr, arr_dtype):
3232
dpnp_array = dpnp.array(arr, dtype=arr_dtype)
3333
expected = numpy_array.flatten()
3434
result = dpnp_array.flatten()
35-
numpy.testing.assert_array_equal(expected, result)
35+
assert_array_equal(expected, result)
3636

3737

3838
@pytest.mark.parametrize("shape",
@@ -68,3 +68,29 @@ def test_flags_strides(dtype, order, strides):
6868
assert usm_array.flags == dpnp_array.flags
6969
assert numpy_array.flags.c_contiguous == dpnp_array.flags.c_contiguous
7070
assert numpy_array.flags.f_contiguous == dpnp_array.flags.f_contiguous
71+
72+
73+
@pytest.mark.parametrize("func", [bool, float, int, complex])
74+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
75+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False, no_complex=True))
76+
def test_scalar_type_casting(func, shape, dtype):
77+
numpy_array = numpy.full(shape, 5, dtype=dtype)
78+
dpnp_array = dpnp.full(shape, 5, dtype=dtype)
79+
assert func(numpy_array) == func(dpnp_array)
80+
81+
82+
@pytest.mark.parametrize("method", ["__bool__", "__float__", "__int__", "__complex__"])
83+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
84+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False, no_complex=True, no_none=True))
85+
def test_scalar_type_casting_by_method(method, shape, dtype):
86+
numpy_array = numpy.full(shape, 4.7, dtype=dtype)
87+
dpnp_array = dpnp.full(shape, 4.7, dtype=dtype)
88+
assert getattr(numpy_array, method)() == getattr(dpnp_array, method)()
89+
90+
91+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
92+
@pytest.mark.parametrize("index_dtype", [dpnp.int32, dpnp.int64])
93+
def test_array_as_index(shape, index_dtype):
94+
ind_arr = dpnp.ones(shape, dtype=index_dtype)
95+
a = numpy.arange(ind_arr.size + 1)
96+
assert a[tuple(ind_arr)] == a[1]

tests/third_party/cupy/random_tests/test_sample.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def test_lo_hi_nonrandom(self):
3333
a = random.randint(-1.1, -0.9, size=(2, 2))
3434
numpy.testing.assert_array_equal(a, cupy.full((2, 2), -1))
3535

36-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
3736
def test_zero_sizes(self):
3837
a = random.randint(10, size=(0,))
3938
numpy.testing.assert_array_equal(a, cupy.array(()))
@@ -112,7 +111,6 @@ def test_goodness_of_fit_2(self):
112111
self.assertTrue(hypothesis.chi_square_test(counts, expected))
113112

114113

115-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
116114
@testing.gpu
117115
class TestRandintDtype(unittest.TestCase):
118116

0 commit comments

Comments
 (0)