Skip to content

Commit 85f12d0

Browse files
committed
Test the writable flag of a usm_ndarray from read-only DLPack capsule
Also adds a test checking for the writable flag of a usm_ndarray constructed with a read-only usm_ndarray as the buffer Removes some commented out `flags` checks from `test_meshgrid2`. These checks were malformed, as `dpt.meshgrid` returns a non-contiguous view in the test.
1 parent aba58c4 commit 85f12d0

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,15 @@ def test_usm_ndarray_writable_flag_views():
148148
assert not a.imag.flags.writable
149149

150150

151+
def test_usm_ndarray_from_usm_ndarray_readonly():
152+
get_queue_or_skip()
153+
154+
x1 = dpt.arange(10, dtype="f4")
155+
x1.flags["W"] = False
156+
x2 = dpt.usm_ndarray(x1.shape, dtype="f4", buffer=x1)
157+
assert not x2.flags.writable
158+
159+
151160
@pytest.mark.parametrize(
152161
"dtype",
153162
[
@@ -2159,9 +2168,6 @@ def test_meshgrid2():
21592168
assert z1.shape == z2.shape and z2.shape == z3.shape
21602169
assert y1.shape == (len(x2), len(x1), len(x3))
21612170
assert z1.shape == (len(x1), len(x2), len(x3))
2162-
# FIXME: uncomment out once gh-921 is merged
2163-
# assert all(z.flags["C"] for z in (z1, z2, z3))
2164-
# assert all(y.flags["C"] for y in (y1, y2, y3))
21652171

21662172

21672173
def test_common_arg_validation():

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,10 @@ def test_versioned_dlpack_capsule():
314314
cap = x.__dlpack__(max_version=max_supported_ver)
315315
y = _dlp.from_dlpack_versioned_capsule(cap)
316316
assert x._pointer == y._pointer
317+
assert not y.flags.writable
317318

318319
# read-only array, and copy
319320
cap = x.__dlpack__(max_version=max_supported_ver, copy=True)
320321
y = _dlp.from_dlpack_versioned_capsule(cap)
321322
assert x._pointer != y._pointer
323+
assert not y.flags.writable

0 commit comments

Comments
 (0)