Enable support for Python operators in usm_ndarray class #1324

Merged
merged 28 commits on Aug 9, 2023
Commits
c5fe314
Changed behavior of __array_namespace__
oleksandr-pavlyk Aug 2, 2023
4555bf1
Corrected text of exception message
oleksandr-pavlyk Aug 2, 2023
964db08
Corrected operator true_divide with divide
oleksandr-pavlyk Aug 2, 2023
78aa99b
Fixed test per change in dpctl implementation
oleksandr-pavlyk Aug 3, 2023
51bf9e1
Fixed _slice_len for vacuous slices
oleksandr-pavlyk Aug 3, 2023
f4a31bb
Fixed issue discovered by array API tests
oleksandr-pavlyk Aug 3, 2023
05aa952
Corrected remaining operators in _usmarray.pyx
ndgrigorian Aug 3, 2023
25547e7
Fixed bug in contiguity flag computation found by array-api-tests
oleksandr-pavlyk Aug 3, 2023
22c95b6
Fixed flags test case for changes to contiguity flag computation
ndgrigorian Aug 3, 2023
4803f13
unpacked chained method calls
oleksandr-pavlyk Aug 3, 2023
cc08b5d
Fixed Cython warning
oleksandr-pavlyk Aug 6, 2023
5c1a961
Fixed array API test failure by adding validation
oleksandr-pavlyk Aug 6, 2023
07faf2b
Use bitwise_invert for __invert__
oleksandr-pavlyk Aug 6, 2023
3ddf51c
Corrected order='K' support in astype
oleksandr-pavlyk Aug 6, 2023
440872d
Merge remote-tracking branch 'origin/master' into enable-operators
oleksandr-pavlyk Aug 6, 2023
51e3f15
Moved 2 tests from test_type_utils to elementwise/test_type_utils
oleksandr-pavlyk Aug 7, 2023
e5785ca
Fixed bug in concat uncovered by array API tests
oleksandr-pavlyk Aug 7, 2023
a3c00bc
Closes gh-1325
oleksandr-pavlyk Aug 7, 2023
80eae6e
Corrected order="K" support in copy
ndgrigorian Aug 7, 2023
ff1081a
Fixed logaddexp for mixed nan and number operands
ndgrigorian Aug 7, 2023
1b5419f
logaddexp now handles both NaNs and infinities correctly per array API
ndgrigorian Aug 7, 2023
3c87433
Broke up 'or' conditional in logaddexp logic for inf and NaN
ndgrigorian Aug 7, 2023
3c0aeed
Modularized logic implementing logaddexp
oleksandr-pavlyk Aug 7, 2023
cf7d9bf
Change to test_complex_special_cases
oleksandr-pavlyk Aug 7, 2023
8343edc
Array-API conformance testing can start as soon as build_linux jobs f…
oleksandr-pavlyk Aug 7, 2023
ebd1faf
Fixed log-add-exp per review feedback
oleksandr-pavlyk Aug 8, 2023
df8eb5f
Merge pull request #1328 from IntelPython/fix-some-array-api-test-cases
oleksandr-pavlyk Aug 8, 2023
7c94a33
Simplified flags_ computation in to_device method
oleksandr-pavlyk Aug 7, 2023
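
Taken together, these commits wire the Python operators of usm_ndarray (arithmetic, comparison, bitwise, and their in-place forms) to the corresponding dpctl.tensor elementwise functions. A minimal illustrative sketch of the resulting behavior, assuming a dpctl build that includes this PR and an available default SYCL device:

    import dpctl.tensor as dpt

    a = dpt.arange(8, dtype="f4")
    b = dpt.ones(8, dtype="f4")
    m = dpt.arange(8, dtype="i4")

    c = a + b     # __add__ dispatches to dpt.add
    d = a / b     # __truediv__ dispatches to dpt.divide (true division)
    e = ~m        # __invert__ dispatches to dpt.bitwise_invert
    a += b        # in-place __iadd__
    print(c.dtype, d.dtype, e.dtype)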
2 changes: 1 addition & 1 deletion .github/workflows/conda-package.yml
@@ -486,7 +486,7 @@ jobs:
done

array-api-conformity:
needs: test_linux
needs: build_linux
runs-on: ${{ matrix.runner }}

strategy:
2 changes: 1 addition & 1 deletion dpctl/_sycl_queue.pxd
@@ -29,7 +29,7 @@ from ._sycl_event cimport SyclEvent
from .program._program cimport SyclKernel


cdef void default_async_error_handler(int) nogil except *
cdef void default_async_error_handler(int) except * nogil

cdef public api class _SyclQueue [
object Py_SyclQueueObject, type Py_SyclQueueType
141 changes: 105 additions & 36 deletions dpctl/tensor/_copy_utils.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import operator

import numpy as np
@@ -289,6 +290,96 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
_copy_same_shape(dst, src_same_shape)


def _empty_like_orderK(X, dt, usm_type=None, dev=None):
"""Returns empty array like `x`, using order='K'

For an array `x` that was obtained by permutation of a contiguous
array the returned array will have the same shape and the same
strides as `x`.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray, got {type(X)}")
if usm_type is None:
usm_type = X.usm_type
if dev is None:
dev = X.device
fl = X.flags
if fl["C"] or X.size <= 1:
return dpt.empty_like(
X, dtype=dt, usm_type=usm_type, device=dev, order="C"
)
elif fl["F"]:
return dpt.empty_like(
X, dtype=dt, usm_type=usm_type, device=dev, order="F"
)
st = list(X.strides)
perm = sorted(
range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True
)
inv_perm = sorted(range(X.ndim), key=lambda i: perm[i])
st_sorted = [st[i] for i in perm]
sh = X.shape
sh_sorted = tuple(sh[i] for i in perm)
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
if min(st_sorted) < 0:
sl = tuple(
slice(None, None, -1)
if st_sorted[i] < 0
else slice(None, None, None)
for i in range(X.ndim)
)
R = R[sl]
return dpt.permute_dims(R, inv_perm)
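
The helper above allocates the result of an order="K" operation so that its memory layout follows the stride order of the input. A small usage sketch of the intended behavior (the helper is private to dpctl.tensor._copy_utils; assumes a default SYCL device):

    import dpctl.tensor as dpt
    from dpctl.tensor._copy_utils import _empty_like_orderK

    x = dpt.empty((2, 3, 4), dtype="f4")   # C-contiguous
    xt = dpt.permute_dims(x, (2, 0, 1))    # permuted view, neither C- nor F-contiguous
    r = _empty_like_orderK(xt, xt.dtype)
    # r is expected to have the same shape and the same element strides as xt,
    # so an elementwise copy from xt into r walks both buffers in memory order
    print(xt.strides, r.strides)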


def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
if not isinstance(X1, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
if not isinstance(X2, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
nd1 = X1.ndim
nd2 = X2.ndim
if nd1 > nd2 and X1.shape == res_shape:
return _empty_like_orderK(X1, dt, usm_type, dev)
elif nd1 < nd2 and X2.shape == res_shape:
return _empty_like_orderK(X2, dt, usm_type, dev)
fl1 = X1.flags
fl2 = X2.flags
if fl1["C"] or fl2["C"]:
return dpt.empty(
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
)
if fl1["F"] and fl2["F"]:
return dpt.empty(
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
)
st1 = list(X1.strides)
st2 = list(X2.strides)
max_ndim = max(nd1, nd2)
st1 += [0] * (max_ndim - len(st1))
st2 += [0] * (max_ndim - len(st2))
perm = sorted(
range(max_ndim),
key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
reverse=True,
)
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
st1_sorted = [st1[i] for i in perm]
st2_sorted = [st2[i] for i in perm]
sh = res_shape
sh_sorted = tuple(sh[i] for i in perm)
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
if max(min(st1_sorted), min(st2_sorted)) < 0:
sl = tuple(
slice(None, None, -1)
if (st1_sorted[i] < 0 and st2_sorted[i] < 0)
else slice(None, None, None)
for i in range(nd1)
)
R = R[sl]
return dpt.permute_dims(R, inv_perm)


def copy(usm_ary, order="K"):
"""copy(ary, order="K")

@@ -334,28 +425,15 @@ def copy(usm_ary, order="K"):
"Unrecognized value of the order keyword. "
"Recognized values are 'A', 'C', 'F', or 'K'"
)
c_contig = usm_ary.flags.c_contiguous
f_contig = usm_ary.flags.f_contiguous
R = dpt.usm_ndarray(
usm_ary.shape,
dtype=usm_ary.dtype,
buffer=usm_ary.usm_type,
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
if order == "K" and (not c_contig and not f_contig):
original_strides = usm_ary.strides
ind = sorted(
range(usm_ary.ndim),
key=lambda i: abs(original_strides[i]),
reverse=True,
)
new_strides = tuple(R.strides[ind[i]] for i in ind)
if order == "K":
R = _empty_like_orderK(usm_ary, usm_ary.dtype)
else:
R = dpt.usm_ndarray(
usm_ary.shape,
dtype=usm_ary.dtype,
buffer=R.usm_data,
strides=new_strides,
buffer=usm_ary.usm_type,
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
_copy_same_shape(R, usm_ary)
return R
@@ -432,26 +510,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
"Unrecognized value of the order keyword. "
"Recognized values are 'A', 'C', 'F', or 'K'"
)
R = dpt.usm_ndarray(
usm_ary.shape,
dtype=target_dtype,
buffer=usm_ary.usm_type,
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
if order == "K" and (not c_contig and not f_contig):
original_strides = usm_ary.strides
ind = sorted(
range(usm_ary.ndim),
key=lambda i: abs(original_strides[i]),
reverse=True,
)
new_strides = tuple(R.strides[ind[i]] for i in ind)
if order == "K":
R = _empty_like_orderK(usm_ary, target_dtype)
else:
R = dpt.usm_ndarray(
usm_ary.shape,
dtype=target_dtype,
buffer=R.usm_data,
strides=new_strides,
buffer=usm_ary.usm_type,
order=copy_order,
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
)
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
return R
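
Both copy and astype now delegate order="K" allocation to _empty_like_orderK instead of recomputing permuted strides by hand. A hedged usage sketch (assumes a default SYCL device):

    import dpctl.tensor as dpt

    x = dpt.reshape(dpt.arange(24, dtype="i4"), (2, 3, 4))
    xt = dpt.permute_dims(x, (2, 0, 1))    # non-contiguous view

    y = dpt.copy(xt, order="K")            # copy of xt keeps xt's stride order
    z = dpt.astype(xt, "f4", order="K")    # astype allocates the result the same way
    print(xt.strides, y.strides, z.strides)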
@@ -492,6 +559,8 @@ def _extract_impl(ary, ary_mask, axis=0):
dst = dpt.empty(
dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device
)
if dst.size == 0:
return dst
hev, _ = ti._extract(
src=ary,
cumsum=cumsum,
3 changes: 1 addition & 2 deletions dpctl/tensor/_elementwise_common.py
@@ -26,10 +26,9 @@
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
from dpctl.utils import ExecutionPlacementError

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._type_utils import (
_acceptance_fn_default,
_empty_like_orderK,
_empty_like_pair_orderK,
_find_buf_dtype,
_find_buf_dtype2,
_find_inplace_dtype,
11 changes: 10 additions & 1 deletion dpctl/tensor/_manipulation_functions.py
@@ -25,6 +25,8 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dputils

from ._type_utils import _to_device_supported_dtype

__doc__ = (
"Implementation module for array manipulation "
"functions in :module:`dpctl.tensor`"
@@ -504,8 +506,10 @@ def _arrays_validation(arrays, check_ndim=True):
_supported_dtype(Xi.dtype for Xi in arrays)

res_dtype = X0.dtype
dev = exec_q.sycl_device
for i in range(1, n):
res_dtype = np.promote_types(res_dtype, arrays[i])
res_dtype = _to_device_supported_dtype(res_dtype, dev)

if check_ndim:
for i in range(1, n):
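
The promoted result type is now mapped to a dtype the execution device can actually hold: np.promote_types may yield float64 or complex128, which devices lacking the fp64 aspect do not support. An illustrative sketch (the helper is private to dpctl.tensor._type_utils; output depends on the device):

    import numpy as np
    import dpctl.tensor as dpt
    from dpctl.tensor._type_utils import _to_device_supported_dtype

    a = dpt.ones(4, dtype="i8")
    b = dpt.ones(4, dtype="f4")
    promoted = np.promote_types(a.dtype, b.dtype)   # float64
    # expected to fall back to float32 on devices without fp64 support
    print(_to_device_supported_dtype(promoted, a.sycl_device))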
@@ -554,8 +558,13 @@ def _concat_axis_None(arrays):
sycl_queue=exec_q,
)
else:
src_ = array
# _copy_usm_ndarray_for_reshape requires src and dst to have
# the same data type
if not array.dtype == res_dtype:
src_ = dpt.astype(src_, res_dtype)
hev, _ = ti._copy_usm_ndarray_for_reshape(
src=array,
src=src_,
dst=res[fill_start:fill_end],
shift=0,
sycl_queue=exec_q,
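
With the cast added above, concatenation with axis=None handles inputs of differing dtypes: each array is promoted to the common result type before the reshape-copy. A short hedged example (assumes a default SYCL device):

    import dpctl.tensor as dpt

    a = dpt.ones(3, dtype="i4")
    b = dpt.zeros(3, dtype="i8")
    c = dpt.concat((a, b), axis=None)   # flattened concatenation
    print(c.dtype, c.shape)             # int64, (6,) expected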
8 changes: 8 additions & 0 deletions dpctl/tensor/_slicing.pxi
@@ -33,9 +33,13 @@ cdef Py_ssize_t _slice_len(
if sl_start == sl_stop:
return 0
if sl_step > 0:
if sl_start > sl_stop:
return 0
# 1 + argmax k such that sl_start + sl_step*k < sl_stop
return 1 + ((sl_stop - sl_start - 1) // sl_step)
else:
if sl_start < sl_stop:
return 0
return 1 + ((sl_stop - sl_start + 1) // sl_step)
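
The two early returns above handle vacuous slices such as x[5:2:1] or x[2:5:-1], whose normalized start/stop/step previously produced a negative length. A pure-Python sketch of the corrected length computation, for illustration only:

    def slice_len(start, stop, step):
        # assumes start and stop have already been normalized against the extent
        if start == stop:
            return 0
        if step > 0:
            if start > stop:
                return 0                            # vacuous, e.g. slice(5, 2, 1)
            return 1 + (stop - start - 1) // step   # ceil((stop - start) / step)
        else:
            if start < stop:
                return 0                            # vacuous, e.g. slice(2, 5, -1)
            return 1 + (stop - start + 1) // step

    # agrees with len(range(start, stop, step)) for integer arguments
    assert slice_len(5, 2, 1) == len(range(5, 2, 1)) == 0
    assert slice_len(0, 7, 2) == len(range(0, 7, 2)) == 4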


@@ -221,6 +225,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
k_new = k + ellipses_count
new_shape.extend(shape[k:k_new])
new_strides.extend(strides[k:k_new])
if any(dim == 0 for dim in shape[k:k_new]):
is_empty = True
new_offset = offset
k = k_new
elif ind_i is None:
new_shape.append(1)
@@ -236,6 +243,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
new_offset = new_offset + sl_start * strides[k]
if sh_i == 0:
is_empty = True
new_offset = offset
k = k_new
elif _is_boolean(ind_i):
new_shape.append(1 if ind_i else 0)
8 changes: 8 additions & 0 deletions dpctl/tensor/_stride_utils.pxi
@@ -64,6 +64,8 @@ cdef int _from_input_shape_strides(
cdef int j
cdef bint all_incr = 1
cdef bint all_decr = 1
cdef bint all_incr_modified = 0
cdef bint all_decr_modified = 0
cdef Py_ssize_t elem_count = 1
cdef Py_ssize_t min_shift = 0
cdef Py_ssize_t max_shift = 0
@@ -166,12 +168,14 @@
j = j + 1
if j < nd:
if all_incr:
all_incr_modified = 1
all_incr = (
(strides_arr[i] > 0) and
(strides_arr[j] > 0) and
(strides_arr[i] <= strides_arr[j])
)
if all_decr:
all_decr_modified = 1
all_decr = (
(strides_arr[i] > 0) and
(strides_arr[j] > 0) and
@@ -180,6 +184,10 @@
i = j
else:
break
# should only set contig flags on actually obtained
# values, rather than default values
all_incr = all_incr and all_incr_modified
all_decr = all_decr and all_decr_modified
if all_incr and all_decr:
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
elif all_incr:
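
The new all_incr_modified/all_decr_modified flags ensure the C/F contiguity bits are set only when the stride comparisons above actually ran, not from the loop's default True values. As an independent cross-check (not the implementation in this file), contiguity of element-based strides can be verified against the expected contiguous layout, ignoring extent-1 dimensions and zero-size edge cases:

    def contiguity_flags(shape, strides):
        """Reference check: (is_c_contig, is_f_contig) for element-based strides."""
        dims = [(sh, st) for sh, st in zip(shape, strides) if sh > 1]
        if not dims:
            return True, True           # 0-d or all-ones shape: both flags set
        c_ok, f_ok, expected = True, True, 1
        for sh, st in reversed(dims):   # C order: last axis varies fastest
            c_ok = c_ok and st == expected
            expected *= sh
        expected = 1
        for sh, st in dims:             # F order: first axis varies fastest
            f_ok = f_ok and st == expected
            expected *= sh
        return c_ok, f_ok

    print(contiguity_flags((2, 3, 4), (12, 4, 1)))   # (True, False)
    print(contiguity_flags((2, 3, 4), (1, 2, 6)))    # (False, True)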