Skip to content

Commit d9c0d08

Browse files
committed
Implemented ndarray flag helper class
1 parent 7b368a1 commit d9c0d08

File tree

5 files changed

+116
-30
lines changed

5 files changed

+116
-30
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,18 +261,18 @@ def copy(usm_ary, order="K"):
261261
elif order == "F":
262262
copy_order = order
263263
elif order == "A":
264-
if usm_ary.flags & 2:
264+
if usm_ary.flags.f_contiguous:
265265
copy_order = "F"
266266
elif order == "K":
267-
if usm_ary.flags & 2:
267+
if usm_ary.flags.f_contiguous:
268268
copy_order = "F"
269269
else:
270270
raise ValueError(
271271
"Unrecognized value of the order keyword. "
272272
"Recognized values are 'A', 'C', 'F', or 'K'"
273273
)
274-
c_contig = usm_ary.flags & 1
275-
f_contig = usm_ary.flags & 2
274+
c_contig = usm_ary.flags.c_contiguous
275+
f_contig = usm_ary.flags.f_contiguous
276276
R = dpt.usm_ndarray(
277277
usm_ary.shape,
278278
dtype=usm_ary.dtype,
@@ -325,8 +325,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
325325
ary_dtype, newdtype, casting
326326
)
327327
)
328-
c_contig = usm_ary.flags & 1
329-
f_contig = usm_ary.flags & 2
328+
c_contig = usm_ary.flags.c_contiguous
329+
f_contig = usm_ary.flags.f_contiguous
330330
needs_copy = copy or not (ary_dtype == target_dtype)
331331
if not needs_copy and (order != "K"):
332332
needs_copy = (c_contig and order not in ["A", "C"]) or (
@@ -339,10 +339,10 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
339339
elif order == "F":
340340
copy_order = order
341341
elif order == "A":
342-
if usm_ary.flags & 2:
342+
if usm_ary.flags.f_contiguous:
343343
copy_order = "F"
344344
elif order == "K":
345-
if usm_ary.flags & 2:
345+
if usm_ary.flags.f_contiguous:
346346
copy_order = "F"
347347
else:
348348
raise ValueError(

dpctl/tensor/_ctors.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def _asarray_from_usm_ndarray(
133133
# sycl_queue is unchanged
134134
can_zero_copy = can_zero_copy and copy_q is usm_ndary.sycl_queue
135135
# order is unchanged
136-
c_contig = usm_ndary.flags & 1
137-
f_contig = usm_ndary.flags & 2
138-
fc_contig = usm_ndary.flags & 3
136+
c_contig = usm_ndary.flags.c_contiguous
137+
f_contig = usm_ndary.flags.f_contiguous
138+
fc_contig = usm_ndary.flags.forc
139139
if can_zero_copy:
140140
if order == "C" and c_contig:
141141
pass
@@ -1130,7 +1130,7 @@ def tril(X, k=0):
11301130
k = operator.index(k)
11311131

11321132
# F_CONTIGUOUS = 2
1133-
order = "F" if (X.flags & 2) else "C"
1133+
order = "F" if (X.flags.f_contiguous) else "C"
11341134

11351135
shape = X.shape
11361136
nd = X.ndim
@@ -1171,7 +1171,7 @@ def triu(X, k=0):
11711171
k = operator.index(k)
11721172

11731173
# F_CONTIGUOUS = 2
1174-
order = "F" if (X.flags & 2) else "C"
1174+
order = "F" if (X.flags.f_contiguous) else "C"
11751175

11761176
shape = X.shape
11771177
nd = X.ndim

dpctl/tensor/_flags.pyx

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
# cython: linetrace=True
20+
21+
from dpctl.tensor._usmarray cimport (
22+
USM_ARRAY_C_CONTIGUOUS,
23+
USM_ARRAY_F_CONTIGUOUS,
24+
USM_ARRAY_WRITEABLE,
25+
)
26+
27+
28+
class Flags:
29+
30+
def __init__(self, arr, flags):
31+
self.arr_ = arr
32+
self.flags_ = flags
33+
34+
@property
35+
def flags(self):
36+
return self.flags_
37+
38+
@property
39+
def c_contiguous(self):
40+
return ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
41+
== USM_ARRAY_C_CONTIGUOUS)
42+
43+
@property
44+
def f_contiguous(self):
45+
return ((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
46+
== USM_ARRAY_F_CONTIGUOUS)
47+
48+
@property
49+
def writable(self):
50+
return False if ((self.flags & USM_ARRAY_WRITEABLE)
51+
== USM_ARRAY_WRITEABLE) else True
52+
53+
@property
54+
def forc(self):
55+
return True if (((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
56+
== USM_ARRAY_F_CONTIGUOUS)
57+
or ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
58+
== USM_ARRAY_C_CONTIGUOUS)) else False
59+
60+
@property
61+
def fnc(self):
62+
return True if (((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
63+
== USM_ARRAY_F_CONTIGUOUS)
64+
and not ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
65+
== USM_ARRAY_C_CONTIGUOUS)) else False
66+
67+
@property
68+
def contiguous(self):
69+
return self.forc
70+
71+
def __getitem__(self, name):
72+
if name in ["C_CONTIGUOUS", "C"]:
73+
return self.c_contiguous
74+
elif name in ["F_CONTIGUOUS", "F"]:
75+
return self.f_contiguous
76+
elif name == "WRITABLE":
77+
return self.writable
78+
elif name == "CONTIGUOUS":
79+
return self.forc
80+
81+
def __repr__(self):
82+
out = []
83+
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":
84+
out.append(" {} : {}".format(name, self[name]))
85+
return '\n'.join(out)

dpctl/tensor/_usmarray.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
3333
cimport dpctl as c_dpctl
3434
cimport dpctl.memory as c_dpmem
3535
cimport dpctl.tensor._dlpack as c_dlpack
36+
import dpctl.tensor._flags as _flags
3637

3738
include "_stride_utils.pxi"
3839
include "_types.pxi"
@@ -503,9 +504,9 @@ cdef class usm_ndarray:
503504
@property
504505
def flags(self):
505506
"""
506-
Currently returns integer whose bits correspond to the flags.
507+
Returns dpctl.tensor._flags object.
507508
"""
508-
return self.flags_
509+
return _flags.Flags(self, self.flags_)
509510

510511
@property
511512
def usm_type(self):
@@ -663,7 +664,7 @@ cdef class usm_ndarray:
663664
strides=self.strides,
664665
offset=self.get_offset()
665666
)
666-
res.flags_ = self.flags
667+
res.flags_ = self.flags.flags
667668
return res
668669
else:
669670
nbytes = self.usm_data.nbytes
@@ -678,7 +679,7 @@ cdef class usm_ndarray:
678679
strides=self.strides,
679680
offset=self.get_offset()
680681
)
681-
res.flags_ = self.flags
682+
res.flags_ = self.flags.flags
682683
return res
683684

684685
def _set_namespace(self, mod):

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ def test_allocate_usm_ndarray(shape, usm_type):
5959

6060

6161
def test_usm_ndarray_flags():
62-
assert dpt.usm_ndarray((5,)).flags == 3
63-
assert dpt.usm_ndarray((5, 2)).flags == 1
64-
assert dpt.usm_ndarray((5, 2), order="F").flags == 2
65-
assert dpt.usm_ndarray((5, 1, 2), order="F").flags == 2
66-
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags == 1
67-
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags == 2
68-
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags == 3
62+
assert dpt.usm_ndarray((5,)).flags.flags == 3
63+
assert dpt.usm_ndarray((5, 2)).flags.flags == 1
64+
assert dpt.usm_ndarray((5, 2), order="F").flags.flags == 2
65+
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.flags == 2
66+
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.flags == 1
67+
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.flags == 2
68+
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.flags == 3
6969

7070

7171
@pytest.mark.parametrize(
@@ -326,7 +326,7 @@ def test_usm_ndarray_props():
326326
Xusm = dpt.usm_ndarray((10, 5), dtype="c16", order="F")
327327
Xusm.ndim
328328
repr(Xusm)
329-
Xusm.flags
329+
Xusm.flags.flags
330330
Xusm.__sycl_usm_array_interface__
331331
Xusm.device
332332
Xusm.strides
@@ -465,7 +465,7 @@ def test_pyx_capi_get_flags():
465465
fn_restype=ctypes.c_int,
466466
)
467467
flags = get_flags_fn(X)
468-
assert type(flags) is int and flags == X.flags
468+
assert type(flags) is int and flags == X.flags.flags
469469

470470

471471
def test_pyx_capi_get_offset():
@@ -753,7 +753,7 @@ def relaxed_strides_equal(st1, st2, sh):
753753
X.shape = sh_f
754754
assert X.shape == sh_f
755755
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
756-
assert X.flags & 1, "reshaped array expected to be C-contiguous"
756+
assert X.flags.c_contiguous, "reshaped array expected to be C-contiguous"
757757

758758
sh_s = (
759759
2,
@@ -919,7 +919,7 @@ def test_reshape():
919919

920920
X = dpt.usm_ndarray((1,))
921921
Y = dpt.reshape(X, X.shape)
922-
assert Y.flags == X.flags
922+
assert Y.flags.flags == X.flags.flags
923923

924924
A = dpt.usm_ndarray((0,), "i4")
925925
A1 = dpt.reshape(A, (0,))
@@ -1402,7 +1402,7 @@ def test_triu_order_k(order, k):
14021402
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
14031403
Ynp = np.triu(Xnp, k)
14041404
assert Y.dtype == Ynp.dtype
1405-
assert X.flags == Y.flags
1405+
assert X.flags.flags == Y.flags.flags
14061406
assert np.array_equal(Ynp, dpt.asnumpy(Y))
14071407

14081408

@@ -1423,7 +1423,7 @@ def test_tril_order_k(order, k):
14231423
Xnp = np.arange(np.prod(shape), dtype="int").reshape(shape, order=order)
14241424
Ynp = np.tril(Xnp, k)
14251425
assert Y.dtype == Ynp.dtype
1426-
assert X.flags == Y.flags
1426+
assert X.flags.flags == Y.flags.flags
14271427
assert np.array_equal(Ynp, dpt.asnumpy(Y))
14281428

14291429

0 commit comments

Comments
 (0)