Skip to content

Commit bd9e465

Browse files
Merge pull request #1141 from IntelPython/flags-gh-1138
Allow for mutation of usm_ndarray.flags.writable flag
2 parents 26b2eaa + b63907a commit bd9e465

File tree

4 files changed

+36
-4
lines changed

4 files changed

+36
-4
lines changed

dpctl/tensor/_flags.pyx

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ cdef class Flags:
7575
"""
7676
return _check_bit(self.flags_, USM_ARRAY_WRITABLE)
7777

78+
@writable.setter
79+
def writable(self, new_val):
80+
if not isinstance(new_val, bool):
81+
raise TypeError("Expecting a boolean value")
82+
self.arr_._set_writable_flag(new_val)
83+
7884
@property
7985
def fc(self):
8086
"""
@@ -129,6 +135,14 @@ cdef class Flags:
129135
elif name == "CONTIGUOUS":
130136
return self.forc
131137

138+
def __setitem__(self, name, val):
139+
if name in ["WRITABLE", "W"]:
140+
self.writable = val
141+
else:
142+
raise ValueError(
143+
"Only writable ('W' or 'WRITABLE') flag can be set"
144+
)
145+
132146
def __repr__(self):
133147
out = []
134148
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":

dpctl/tensor/_usmarray.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,6 @@ cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:
7272
cdef dpctl.DPCTLSyclQueueRef get_queue_ref(self) except *
7373
cdef dpctl.SyclQueue get_sycl_queue(self)
7474

75+
cdef _set_writable_flag(self, int)
76+
7577
cdef __cythonbufferdefaults__ = {"mode": "strided"}

dpctl/tensor/_usmarray.pyx

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,12 @@ cdef class usm_ndarray:
532532
"""
533533
return _flags.Flags(self, self.flags_)
534534

535+
cdef _set_writable_flag(self, int flag):
536+
cdef int arr_fl = self.flags_
537+
arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag
538+
arr_fl |= (USM_ARRAY_WRITABLE if flag else 0)
539+
self.flags_ = arr_fl
540+
535541
@property
536542
def usm_type(self):
537543
"""
@@ -1390,12 +1396,10 @@ cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr):
13901396
allocation"""
13911397
return arr.get_offset()
13921398

1399+
13931400
cdef api void UsmNDArray_SetWritableFlag(usm_ndarray arr, int flag):
13941401
"""Set/unset USM_ARRAY_WRITABLE in the given array `arr`."""
1395-
cdef int arr_fl = arr.flags_
1396-
arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag
1397-
arr_fl |= (USM_ARRAY_WRITABLE if flag else 0)
1398-
arr.flags_ = arr_fl
1402+
arr._set_writable_flag(flag)
13991403

14001404
cdef api object UsmNDArray_MakeSimpleFromMemory(
14011405
int nd, const Py_ssize_t *shape, int typenum,

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_allocate_usm_ndarray(shape, usm_type):
5757

5858

5959
def test_usm_ndarray_flags():
60+
get_queue_or_skip()
6061
assert dpt.usm_ndarray((5,), dtype="i4").flags.fc
6162
assert dpt.usm_ndarray((5, 2), dtype="i4").flags.c_contiguous
6263
assert dpt.usm_ndarray((5, 2), dtype="i4", order="F").flags.f_contiguous
@@ -68,6 +69,17 @@ def test_usm_ndarray_flags():
6869
(5, 1, 2), dtype="i4", strides=(1, 0, 5)
6970
).flags.f_contiguous
7071
assert dpt.usm_ndarray((5, 1, 1), dtype="i4", strides=(1, 0, 1)).flags.fc
72+
x = dpt.empty(5, dtype="u2")
73+
assert x.flags.writable is True
74+
x.flags.writable = False
75+
assert x.flags.writable is False
76+
with pytest.raises(ValueError):
77+
x[:] = 0
78+
x.flags["W"] = True
79+
assert x.flags.writable is True
80+
x.flags["WRITABLE"] = True
81+
assert x.flags.writable is True
82+
x[:] = 0
7183

7284

7385
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)