Skip to content

Commit 8cbed99

Browse files
Merge pull request #921 from IntelPython/flags-helper-class
Implements dpctl.tensor._flags.Flags
2 parents 72daccc + 74043c3 commit 8cbed99

File tree

6 files changed

+154
-26
lines changed

6 files changed

+154
-26
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ per-file-ignores =
2525
dpctl/program/_program.pyx: E999, E225, E226, E227
2626
dpctl/tensor/_usmarray.pyx: E999, E225, E226, E227
2727
dpctl/tensor/_dlpack.pyx: E999, E225, E226, E227
28+
dpctl/tensor/_flags.pyx: E999, E225, E226, E227
2829
dpctl/tensor/numpy_usm_shared.py: F821
2930
dpctl/tests/_cython_api.pyx: E999, E225, E227, E402
3031
dpctl/utils/_compute_follows_data.pyx: E999, E225, E227

dpctl/tensor/_copy_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,18 +261,18 @@ def copy(usm_ary, order="K"):
261261
elif order == "F":
262262
copy_order = order
263263
elif order == "A":
264-
if usm_ary.flags & 2:
264+
if usm_ary.flags.f_contiguous:
265265
copy_order = "F"
266266
elif order == "K":
267-
if usm_ary.flags & 2:
267+
if usm_ary.flags.f_contiguous:
268268
copy_order = "F"
269269
else:
270270
raise ValueError(
271271
"Unrecognized value of the order keyword. "
272272
"Recognized values are 'A', 'C', 'F', or 'K'"
273273
)
274-
c_contig = usm_ary.flags & 1
275-
f_contig = usm_ary.flags & 2
274+
c_contig = usm_ary.flags.c_contiguous
275+
f_contig = usm_ary.flags.f_contiguous
276276
R = dpt.usm_ndarray(
277277
usm_ary.shape,
278278
dtype=usm_ary.dtype,
@@ -325,8 +325,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
325325
ary_dtype, newdtype, casting
326326
)
327327
)
328-
c_contig = usm_ary.flags & 1
329-
f_contig = usm_ary.flags & 2
328+
c_contig = usm_ary.flags.c_contiguous
329+
f_contig = usm_ary.flags.f_contiguous
330330
needs_copy = copy or not (ary_dtype == target_dtype)
331331
if not needs_copy and (order != "K"):
332332
needs_copy = (c_contig and order not in ["A", "C"]) or (
@@ -339,10 +339,10 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
339339
elif order == "F":
340340
copy_order = order
341341
elif order == "A":
342-
if usm_ary.flags & 2:
342+
if usm_ary.flags.f_contiguous:
343343
copy_order = "F"
344344
elif order == "K":
345-
if usm_ary.flags & 2:
345+
if usm_ary.flags.f_contiguous:
346346
copy_order = "F"
347347
else:
348348
raise ValueError(

dpctl/tensor/_ctors.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def _asarray_from_usm_ndarray(
133133
# sycl_queue is unchanged
134134
can_zero_copy = can_zero_copy and copy_q is usm_ndary.sycl_queue
135135
# order is unchanged
136-
c_contig = usm_ndary.flags & 1
137-
f_contig = usm_ndary.flags & 2
138-
fc_contig = usm_ndary.flags & 3
136+
c_contig = usm_ndary.flags.c_contiguous
137+
f_contig = usm_ndary.flags.f_contiguous
138+
fc_contig = usm_ndary.flags.forc
139139
if can_zero_copy:
140140
if order == "C" and c_contig:
141141
pass
@@ -1130,7 +1130,7 @@ def tril(X, k=0):
11301130
k = operator.index(k)
11311131

11321132
# F_CONTIGUOUS = 2
1133-
order = "F" if (X.flags & 2) else "C"
1133+
order = "F" if (X.flags.f_contiguous) else "C"
11341134

11351135
shape = X.shape
11361136
nd = X.ndim
@@ -1171,7 +1171,7 @@ def triu(X, k=0):
11711171
k = operator.index(k)
11721172

11731173
# F_CONTIGUOUS = 2
1174-
order = "F" if (X.flags & 2) else "C"
1174+
order = "F" if (X.flags.f_contiguous) else "C"
11751175

11761176
shape = X.shape
11771177
nd = X.ndim

dpctl/tensor/_flags.pyx

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
# cython: linetrace=True
20+
21+
from libcpp cimport bool as cpp_bool
22+
23+
from dpctl.tensor._usmarray cimport (
24+
USM_ARRAY_C_CONTIGUOUS,
25+
USM_ARRAY_F_CONTIGUOUS,
26+
USM_ARRAY_WRITEABLE,
27+
usm_ndarray,
28+
)
29+
30+
31+
cdef cpp_bool _check_bit(int flag, int mask):
32+
return (flag & mask) == mask
33+
34+
35+
cdef class Flags:
36+
"""Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`."""
37+
cdef int flags_
38+
cdef usm_ndarray arr_
39+
40+
def __cinit__(self, usm_ndarray arr, int flags):
41+
self.arr_ = arr
42+
self.flags_ = flags
43+
44+
@property
45+
def flags(self):
46+
return self.flags_
47+
48+
@property
49+
def c_contiguous(self):
50+
return _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
51+
52+
@property
53+
def f_contiguous(self):
54+
return _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
55+
56+
@property
57+
def writable(self):
58+
return _check_bit(self.flags_, USM_ARRAY_WRITEABLE)
59+
60+
@property
61+
def fc(self):
62+
return (
63+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
64+
and _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
65+
)
66+
67+
@property
68+
def forc(self):
69+
return (
70+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
71+
or _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
72+
)
73+
74+
@property
75+
def fnc(self):
76+
return (
77+
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
78+
and not _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
79+
)
80+
81+
@property
82+
def contiguous(self):
83+
return self.forc
84+
85+
def __getitem__(self, name):
86+
if name in ["C_CONTIGUOUS", "C"]:
87+
return self.c_contiguous
88+
elif name in ["F_CONTIGUOUS", "F"]:
89+
return self.f_contiguous
90+
elif name == "WRITABLE":
91+
return self.writable
92+
elif name == "FC":
93+
return self.fc
94+
elif name == "CONTIGUOUS":
95+
return self.forc
96+
97+
def __repr__(self):
98+
out = []
99+
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":
100+
out.append(" {} : {}".format(name, self[name]))
101+
return '\n'.join(out)
102+
103+
def __eq__(self, other):
104+
cdef Flags other_
105+
if isinstance(other, self.__class__):
106+
other_ = <Flags>other
107+
return self.flags_ == other_.flags_
108+
elif isinstance(other, int):
109+
return self.flags_ == <int>other
110+
else:
111+
return False

dpctl/tensor/_usmarray.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
3333
cimport dpctl as c_dpctl
3434
cimport dpctl.memory as c_dpmem
3535
cimport dpctl.tensor._dlpack as c_dlpack
36+
import dpctl.tensor._flags as _flags
3637

3738
include "_stride_utils.pxi"
3839
include "_types.pxi"
@@ -503,9 +504,9 @@ cdef class usm_ndarray:
503504
@property
504505
def flags(self):
505506
"""
506-
Currently returns integer whose bits correspond to the flags.
507+
Returns dpctl.tensor._flags object.
507508
"""
508-
return self.flags_
509+
return _flags.Flags(self, self.flags_)
509510

510511
@property
511512
def usm_type(self):
@@ -663,7 +664,7 @@ cdef class usm_ndarray:
663664
strides=self.strides,
664665
offset=self.get_offset()
665666
)
666-
res.flags_ = self.flags
667+
res.flags_ = self.flags.flags
667668
return res
668669
else:
669670
nbytes = self.usm_data.nbytes
@@ -678,7 +679,7 @@ cdef class usm_ndarray:
678679
strides=self.strides,
679680
offset=self.get_offset()
680681
)
681-
res.flags_ = self.flags
682+
res.flags_ = self.flags.flags
682683
return res
683684

684685
def _set_namespace(self, mod):

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ def test_allocate_usm_ndarray(shape, usm_type):
5959

6060

6161
def test_usm_ndarray_flags():
62-
assert dpt.usm_ndarray((5,)).flags == 3
63-
assert dpt.usm_ndarray((5, 2)).flags == 1
64-
assert dpt.usm_ndarray((5, 2), order="F").flags == 2
65-
assert dpt.usm_ndarray((5, 1, 2), order="F").flags == 2
66-
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags == 1
67-
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags == 2
68-
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags == 3
62+
assert dpt.usm_ndarray((5,)).flags.fc
63+
assert dpt.usm_ndarray((5, 2)).flags.c_contiguous
64+
assert dpt.usm_ndarray((5, 2), order="F").flags.f_contiguous
65+
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.f_contiguous
66+
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.c_contiguous
67+
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.f_contiguous
68+
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.fc
6969

7070

7171
@pytest.mark.parametrize(
@@ -465,7 +465,7 @@ def test_pyx_capi_get_flags():
465465
fn_restype=ctypes.c_int,
466466
)
467467
flags = get_flags_fn(X)
468-
assert type(flags) is int and flags == X.flags
468+
assert type(flags) is int and X.flags == flags
469469

470470

471471
def test_pyx_capi_get_offset():
@@ -753,7 +753,7 @@ def relaxed_strides_equal(st1, st2, sh):
753753
X.shape = sh_f
754754
assert X.shape == sh_f
755755
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
756-
assert X.flags & 1, "reshaped array expected to be C-contiguous"
756+
assert X.flags.c_contiguous, "reshaped array expected to be C-contiguous"
757757

758758
sh_s = (
759759
2,
@@ -1516,3 +1516,18 @@ def test_common_arg_validation():
15161516
dpt.triu(X)
15171517
with pytest.raises(TypeError):
15181518
dpt.meshgrid(X)
1519+
1520+
1521+
def test_flags():
1522+
x = dpt.empty(tuple(), "i4")
1523+
f = x.flags
1524+
f.__repr__()
1525+
f.c_contiguous
1526+
f.f_contiguous
1527+
f.contiguous
1528+
f.fc
1529+
f.fnc
1530+
f.forc
1531+
f.writable
1532+
# check comparison with generic types
1533+
f == Ellipsis

0 commit comments

Comments
 (0)