Skip to content

Implements dpctl.tensor._flags #921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ per-file-ignores =
dpctl/program/_program.pyx: E999, E225, E226, E227
dpctl/tensor/_usmarray.pyx: E999, E225, E226, E227
dpctl/tensor/_dlpack.pyx: E999, E225, E226, E227
dpctl/tensor/_flags.pyx: E999, E225, E226, E227
dpctl/tensor/numpy_usm_shared.py: F821
dpctl/tests/_cython_api.pyx: E999, E225, E227, E402
dpctl/utils/_compute_follows_data.pyx: E999, E225, E227
Expand Down
16 changes: 8 additions & 8 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,18 +261,18 @@ def copy(usm_ary, order="K"):
elif order == "F":
copy_order = order
elif order == "A":
if usm_ary.flags & 2:
if usm_ary.flags.f_contiguous:
copy_order = "F"
elif order == "K":
if usm_ary.flags & 2:
if usm_ary.flags.f_contiguous:
copy_order = "F"
else:
raise ValueError(
"Unrecognized value of the order keyword. "
"Recognized values are 'A', 'C', 'F', or 'K'"
)
c_contig = usm_ary.flags & 1
f_contig = usm_ary.flags & 2
c_contig = usm_ary.flags.c_contiguous
f_contig = usm_ary.flags.f_contiguous
R = dpt.usm_ndarray(
usm_ary.shape,
dtype=usm_ary.dtype,
Expand Down Expand Up @@ -325,8 +325,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
ary_dtype, newdtype, casting
)
)
c_contig = usm_ary.flags & 1
f_contig = usm_ary.flags & 2
c_contig = usm_ary.flags.c_contiguous
f_contig = usm_ary.flags.f_contiguous
needs_copy = copy or not (ary_dtype == target_dtype)
if not needs_copy and (order != "K"):
needs_copy = (c_contig and order not in ["A", "C"]) or (
Expand All @@ -339,10 +339,10 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
elif order == "F":
copy_order = order
elif order == "A":
if usm_ary.flags & 2:
if usm_ary.flags.f_contiguous:
copy_order = "F"
elif order == "K":
if usm_ary.flags & 2:
if usm_ary.flags.f_contiguous:
copy_order = "F"
else:
raise ValueError(
Expand Down
10 changes: 5 additions & 5 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ def _asarray_from_usm_ndarray(
# sycl_queue is unchanged
can_zero_copy = can_zero_copy and copy_q is usm_ndary.sycl_queue
# order is unchanged
c_contig = usm_ndary.flags & 1
f_contig = usm_ndary.flags & 2
fc_contig = usm_ndary.flags & 3
c_contig = usm_ndary.flags.c_contiguous
f_contig = usm_ndary.flags.f_contiguous
fc_contig = usm_ndary.flags.forc
if can_zero_copy:
if order == "C" and c_contig:
pass
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def tril(X, k=0):
k = operator.index(k)

# F_CONTIGUOUS = 2
order = "F" if (X.flags & 2) else "C"
order = "F" if (X.flags.f_contiguous) else "C"

shape = X.shape
nd = X.ndim
Expand Down Expand Up @@ -1171,7 +1171,7 @@ def triu(X, k=0):
k = operator.index(k)

# F_CONTIGUOUS = 2
order = "F" if (X.flags & 2) else "C"
order = "F" if (X.flags.f_contiguous) else "C"

shape = X.shape
nd = X.ndim
Expand Down
111 changes: 111 additions & 0 deletions dpctl/tensor/_flags.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# distutils: language = c++
# cython: language_level=3
# cython: linetrace=True

from libcpp cimport bool as cpp_bool

from dpctl.tensor._usmarray cimport (
USM_ARRAY_C_CONTIGUOUS,
USM_ARRAY_F_CONTIGUOUS,
USM_ARRAY_WRITEABLE,
usm_ndarray,
)


cdef cpp_bool _check_bit(int flag, int mask):
return (flag & mask) == mask


cdef class Flags:
"""Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`."""
cdef int flags_
cdef usm_ndarray arr_

def __cinit__(self, usm_ndarray arr, int flags):
self.arr_ = arr
self.flags_ = flags

@property
def flags(self):
return self.flags_

@property
def c_contiguous(self):
return _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)

@property
def f_contiguous(self):
return _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)

@property
def writable(self):
return _check_bit(self.flags_, USM_ARRAY_WRITEABLE)

@property
def fc(self):
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
and _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def forc(self):
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
or _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def fnc(self):
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
and not _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def contiguous(self):
return self.forc

def __getitem__(self, name):
if name in ["C_CONTIGUOUS", "C"]:
return self.c_contiguous
elif name in ["F_CONTIGUOUS", "F"]:
return self.f_contiguous
elif name == "WRITABLE":
return self.writable
elif name == "FC":
return self.fc
elif name == "CONTIGUOUS":
return self.forc

def __repr__(self):
out = []
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":
out.append(" {} : {}".format(name, self[name]))
return '\n'.join(out)

def __eq__(self, other):
cdef Flags other_
if isinstance(other, self.__class__):
other_ = <Flags>other
return self.flags_ == other_.flags_
elif isinstance(other, int):
return self.flags_ == <int>other
else:
return False
9 changes: 5 additions & 4 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
cimport dpctl as c_dpctl
cimport dpctl.memory as c_dpmem
cimport dpctl.tensor._dlpack as c_dlpack
import dpctl.tensor._flags as _flags

include "_stride_utils.pxi"
include "_types.pxi"
Expand Down Expand Up @@ -503,9 +504,9 @@ cdef class usm_ndarray:
@property
def flags(self):
"""
Currently returns integer whose bits correspond to the flags.
Returns dpctl.tensor._flags object.
"""
return self.flags_
return _flags.Flags(self, self.flags_)

@property
def usm_type(self):
Expand Down Expand Up @@ -663,7 +664,7 @@ cdef class usm_ndarray:
strides=self.strides,
offset=self.get_offset()
)
res.flags_ = self.flags
res.flags_ = self.flags.flags
return res
else:
nbytes = self.usm_data.nbytes
Expand All @@ -678,7 +679,7 @@ cdef class usm_ndarray:
strides=self.strides,
offset=self.get_offset()
)
res.flags_ = self.flags
res.flags_ = self.flags.flags
return res

def _set_namespace(self, mod):
Expand Down
33 changes: 24 additions & 9 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def test_allocate_usm_ndarray(shape, usm_type):


def test_usm_ndarray_flags():
assert dpt.usm_ndarray((5,)).flags == 3
assert dpt.usm_ndarray((5, 2)).flags == 1
assert dpt.usm_ndarray((5, 2), order="F").flags == 2
assert dpt.usm_ndarray((5, 1, 2), order="F").flags == 2
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags == 1
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags == 2
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags == 3
assert dpt.usm_ndarray((5,)).flags.fc
assert dpt.usm_ndarray((5, 2)).flags.c_contiguous
assert dpt.usm_ndarray((5, 2), order="F").flags.f_contiguous
assert dpt.usm_ndarray((5, 1, 2), order="F").flags.f_contiguous
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags.c_contiguous
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags.f_contiguous
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags.fc


@pytest.mark.parametrize(
Expand Down Expand Up @@ -465,7 +465,7 @@ def test_pyx_capi_get_flags():
fn_restype=ctypes.c_int,
)
flags = get_flags_fn(X)
assert type(flags) is int and flags == X.flags
assert type(flags) is int and X.flags == flags


def test_pyx_capi_get_offset():
Expand Down Expand Up @@ -753,7 +753,7 @@ def relaxed_strides_equal(st1, st2, sh):
X.shape = sh_f
assert X.shape == sh_f
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
assert X.flags & 1, "reshaped array expected to be C-contiguous"
assert X.flags.c_contiguous, "reshaped array expected to be C-contiguous"

sh_s = (
2,
Expand Down Expand Up @@ -1516,3 +1516,18 @@ def test_common_arg_validation():
dpt.triu(X)
with pytest.raises(TypeError):
dpt.meshgrid(X)


def test_flags():
x = dpt.empty(tuple(), "i4")
f = x.flags
f.__repr__()
f.c_contiguous
f.f_contiguous
f.contiguous
f.fc
f.fnc
f.forc
f.writable
# check comparison with generic types
f == Ellipsis