Skip to content

Commit 215ef2f

Browse files
Simplified tests to leverage support for __eq__
Added tests to covert methods of Flags class
1 parent d628bed commit 215ef2f

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ def test_allocate_usm_ndarray(shape, usm_type):
5959

6060

6161
def test_usm_ndarray_flags():
62-
assert dpt.usm_ndarray((5,)).flags.flags == 3
63-
assert dpt.usm_ndarray((5, 2)).flags.flags == 1
64-
assert dpt.usm_ndarray((5, 2), order="F").flags.flags == 2
65-
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.flags == 2
66-
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.flags == 1
67-
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.flags == 2
68-
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.flags == 3
62+
assert dpt.usm_ndarray((5,)).flags.fnc
63+
assert dpt.usm_ndarray((5, 2)).flags.c_contiguous
64+
assert dpt.usm_ndarray((5, 2), order="F").flags.f_contiguous
65+
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.f_contiguous
66+
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.c_contiguous
67+
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.f_contiguous
68+
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.fnc
6969

7070

7171
@pytest.mark.parametrize(
@@ -326,7 +326,7 @@ def test_usm_ndarray_props():
326326
Xusm = dpt.usm_ndarray((10, 5), dtype="c16", order="F")
327327
Xusm.ndim
328328
repr(Xusm)
329-
Xusm.flags.flags
329+
Xusm.flags
330330
Xusm.__sycl_usm_array_interface__
331331
Xusm.device
332332
Xusm.strides
@@ -465,7 +465,7 @@ def test_pyx_capi_get_flags():
465465
fn_restype=ctypes.c_int,
466466
)
467467
flags = get_flags_fn(X)
468-
assert type(flags) is int and flags == X.flags.flags
468+
assert type(flags) is int and X.flags == flags
469469

470470

471471
def test_pyx_capi_get_offset():
@@ -919,7 +919,7 @@ def test_reshape():
919919

920920
X = dpt.usm_ndarray((1,))
921921
Y = dpt.reshape(X, X.shape)
922-
assert Y.flags.flags == X.flags.flags
922+
assert Y.flags == X.flags
923923

924924
A = dpt.usm_ndarray((0,), "i4")
925925
A1 = dpt.reshape(A, (0,))
@@ -1402,7 +1402,7 @@ def test_triu_order_k(order, k):
14021402
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
14031403
Ynp = np.triu(Xnp, k)
14041404
assert Y.dtype == Ynp.dtype
1405-
assert X.flags.flags == Y.flags.flags
1405+
assert X.flags == Y.flags
14061406
assert np.array_equal(Ynp, dpt.asnumpy(Y))
14071407

14081408

@@ -1423,7 +1423,7 @@ def test_tril_order_k(order, k):
14231423
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
14241424
Ynp = np.tril(Xnp, k)
14251425
assert Y.dtype == Ynp.dtype
1426-
assert X.flags.flags == Y.flags.flags
1426+
assert X.flags == Y.flags
14271427
assert np.array_equal(Ynp, dpt.asnumpy(Y))
14281428

14291429

@@ -1463,3 +1463,17 @@ def test_common_arg_validation():
14631463
dpt.tril(X)
14641464
with pytest.raises(TypeError):
14651465
dpt.triu(X)
1466+
1467+
1468+
def test_flags():
1469+
x = dpt.empty(tuple(), "i4")
1470+
f = x.flags
1471+
f.__repr__()
1472+
f.c_contiguous
1473+
f.f_contiguous
1474+
f.contiguous
1475+
f.fnc
1476+
f.forc
1477+
f.writable
1478+
# check comparison with generic types
1479+
f == Ellipsis

0 commit comments

Comments
 (0)