Skip to content

Commit 1ecd8a8

Browse files
authored
Merge pull request #1732 from IntelPython/diff-count-nonzero-array-api
Implements `dpctl.tensor.count_nonzero` and `dpctl.tensor.diff`
2 parents f98fe15 + 6520cc0 commit 1ecd8a8

File tree

10 files changed

+835
-114
lines changed

10 files changed

+835
-114
lines changed

docs/doc_sources/api_reference/dpctl/tensor.elementwise_functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ function values computed for every element of input array(s).
6464
minimum
6565
multiply
6666
negative
67+
nextafter
6768
not_equal
6869
positive
6970
pow

docs/doc_sources/api_reference/dpctl/tensor.searching_functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Searching functions
1010

1111
argmax
1212
argmin
13+
count_nonzero
1314
nonzero
1415
searchsorted
1516
where

docs/doc_sources/api_reference/dpctl/tensor.utility_functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Utility functions
1111
all
1212
any
1313
allclose
14+
diff
1415

1516
Device object
1617
-------------

dpctl/tensor/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
from dpctl.tensor._search_functions import where
9595
from dpctl.tensor._statistical_functions import mean, std, var
9696
from dpctl.tensor._usmarray import usm_ndarray
97-
from dpctl.tensor._utility_functions import all, any
97+
from dpctl.tensor._utility_functions import all, any, diff
9898

9999
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
100100
from ._array_api import __array_api_version__, __array_namespace_info__
@@ -176,6 +176,7 @@
176176
from ._reduction import (
177177
argmax,
178178
argmin,
179+
count_nonzero,
179180
logsumexp,
180181
max,
181182
min,
@@ -373,4 +374,6 @@
373374
"cumulative_prod",
374375
"cumulative_sum",
375376
"nextafter",
377+
"diff",
378+
"count_nonzero",
376379
]

dpctl/tensor/_clip.py

Lines changed: 3 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -30,124 +30,15 @@
3030
_validate_dtype,
3131
)
3232
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
33-
from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype
33+
from dpctl.tensor._type_utils import _can_cast
3434
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3535

3636
from ._type_utils import (
37-
WeakComplexType,
38-
WeakIntegralType,
39-
_is_weak_dtype,
40-
_strong_dtype_num_kind,
41-
_weak_type_num_kind,
37+
_resolve_one_strong_one_weak_types,
38+
_resolve_one_strong_two_weak_types,
4239
)
4340

4441

45-
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
46-
"Resolves weak data types per NEP-0050,"
47-
"where the second and third arguments are"
48-
"permitted to be weak types"
49-
if _is_weak_dtype(st_dtype):
50-
raise ValueError
51-
if _is_weak_dtype(dtype1):
52-
if _is_weak_dtype(dtype2):
53-
kind_num1 = _weak_type_num_kind(dtype1)
54-
kind_num2 = _weak_type_num_kind(dtype2)
55-
st_kind_num = _strong_dtype_num_kind(st_dtype)
56-
57-
if kind_num1 > st_kind_num:
58-
if isinstance(dtype1, WeakIntegralType):
59-
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
60-
elif isinstance(dtype1, WeakComplexType):
61-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
62-
ret_dtype1 = dpt.complex64
63-
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
64-
else:
65-
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
66-
else:
67-
ret_dtype1 = st_dtype
68-
69-
if kind_num2 > st_kind_num:
70-
if isinstance(dtype2, WeakIntegralType):
71-
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
72-
elif isinstance(dtype2, WeakComplexType):
73-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
74-
ret_dtype2 = dpt.complex64
75-
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
76-
else:
77-
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
78-
else:
79-
ret_dtype2 = st_dtype
80-
81-
return ret_dtype1, ret_dtype2
82-
83-
max_dt_num_kind, max_dtype = max(
84-
[
85-
(_strong_dtype_num_kind(st_dtype), st_dtype),
86-
(_strong_dtype_num_kind(dtype2), dtype2),
87-
]
88-
)
89-
dt1_kind_num = _weak_type_num_kind(dtype1)
90-
if dt1_kind_num > max_dt_num_kind:
91-
if isinstance(dtype1, WeakIntegralType):
92-
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
93-
if isinstance(dtype1, WeakComplexType):
94-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
95-
return dpt.complex64, dtype2
96-
return (
97-
_to_device_supported_dtype(dpt.complex128, dev),
98-
dtype2,
99-
)
100-
return _to_device_supported_dtype(dpt.float64, dev), dtype2
101-
else:
102-
return max_dtype, dtype2
103-
elif _is_weak_dtype(dtype2):
104-
max_dt_num_kind, max_dtype = max(
105-
[
106-
(_strong_dtype_num_kind(st_dtype), st_dtype),
107-
(_strong_dtype_num_kind(dtype1), dtype1),
108-
]
109-
)
110-
dt2_kind_num = _weak_type_num_kind(dtype2)
111-
if dt2_kind_num > max_dt_num_kind:
112-
if isinstance(dtype2, WeakIntegralType):
113-
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
114-
if isinstance(dtype2, WeakComplexType):
115-
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
116-
return dtype1, dpt.complex64
117-
return (
118-
dtype1,
119-
_to_device_supported_dtype(dpt.complex128, dev),
120-
)
121-
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
122-
else:
123-
return dtype1, max_dtype
124-
else:
125-
# both are strong dtypes
126-
# return unmodified
127-
return dtype1, dtype2
128-
129-
130-
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
131-
"Resolves one weak data type with one strong data type per NEP-0050"
132-
if _is_weak_dtype(st_dtype):
133-
raise ValueError
134-
if _is_weak_dtype(dtype):
135-
st_kind_num = _strong_dtype_num_kind(st_dtype)
136-
kind_num = _weak_type_num_kind(dtype)
137-
if kind_num > st_kind_num:
138-
if isinstance(dtype, WeakIntegralType):
139-
return dpt.dtype(ti.default_device_int_type(dev))
140-
if isinstance(dtype, WeakComplexType):
141-
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
142-
return dpt.complex64
143-
return _to_device_supported_dtype(dpt.complex128, dev)
144-
return _to_device_supported_dtype(dpt.float64, dev)
145-
else:
146-
return st_dtype
147-
else:
148-
return dtype
149-
150-
15142
def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
15243
"Checks if both types `arg1_dtype` and `arg2_dtype` can be"
15344
"cast to `res_dtype` according to the rule `safe`"

dpctl/tensor/_reduction.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,3 +773,46 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
773773
default array index data type for the device of ``x``.
774774
"""
775775
return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis)
776+
777+
778+
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
779+
"""
780+
Counts the number of elements in the input array ``x`` which are non-zero.
781+
782+
Args:
783+
x (usm_ndarray):
784+
input array.
785+
axis (Optional[int, Tuple[int, ...]]):
786+
axis or axes along which to count. If a tuple of unique integers,
787+
the number of non-zero values are computed over multiple axes.
788+
If ``None``, the number of non-zero values is computed over the
789+
entire array.
790+
Default: ``None``.
791+
keepdims (Optional[bool]):
792+
if ``True``, the reduced axes (dimensions) are included in the
793+
result as singleton dimensions, so that the returned array remains
794+
compatible with the input arrays according to Array Broadcasting
795+
rules. Otherwise, if ``False``, the reduced axes are not included
796+
in the returned array. Default: ``False``.
797+
out (Optional[usm_ndarray]):
798+
the array into which the result is written.
799+
The data type of ``out`` must match the expected shape and data
800+
type.
801+
If ``None`` then a new array is returned. Default: ``None``.
802+
803+
Returns:
804+
usm_ndarray:
805+
an array containing the count of non-zero values. If the sum was
806+
computed over the entire array, a zero-dimensional array is
807+
returned. The returned array will have the default array index data
808+
type.
809+
"""
810+
if x.dtype != dpt.bool:
811+
x = dpt.astype(x, dpt.bool, copy=False)
812+
return sum(
813+
x,
814+
axis=axis,
815+
dtype=ti.default_device_index_type(x.sycl_device),
816+
keepdims=keepdims,
817+
out=out,
818+
)

dpctl/tensor/_type_utils.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,112 @@ def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev):
450450
return o1_dtype, o2_dtype
451451

452452

453+
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
454+
"Resolves weak data types per NEP-0050,"
455+
"where the second and third arguments are"
456+
"permitted to be weak types"
457+
if _is_weak_dtype(st_dtype):
458+
raise ValueError
459+
if _is_weak_dtype(dtype1):
460+
if _is_weak_dtype(dtype2):
461+
kind_num1 = _weak_type_num_kind(dtype1)
462+
kind_num2 = _weak_type_num_kind(dtype2)
463+
st_kind_num = _strong_dtype_num_kind(st_dtype)
464+
465+
if kind_num1 > st_kind_num:
466+
if isinstance(dtype1, WeakIntegralType):
467+
ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
468+
elif isinstance(dtype1, WeakComplexType):
469+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
470+
ret_dtype1 = dpt.complex64
471+
ret_dtype1 = _to_device_supported_dtype(dpt.complex128, dev)
472+
else:
473+
ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
474+
else:
475+
ret_dtype1 = st_dtype
476+
477+
if kind_num2 > st_kind_num:
478+
if isinstance(dtype2, WeakIntegralType):
479+
ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
480+
elif isinstance(dtype2, WeakComplexType):
481+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
482+
ret_dtype2 = dpt.complex64
483+
ret_dtype2 = _to_device_supported_dtype(dpt.complex128, dev)
484+
else:
485+
ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
486+
else:
487+
ret_dtype2 = st_dtype
488+
489+
return ret_dtype1, ret_dtype2
490+
491+
max_dt_num_kind, max_dtype = max(
492+
[
493+
(_strong_dtype_num_kind(st_dtype), st_dtype),
494+
(_strong_dtype_num_kind(dtype2), dtype2),
495+
]
496+
)
497+
dt1_kind_num = _weak_type_num_kind(dtype1)
498+
if dt1_kind_num > max_dt_num_kind:
499+
if isinstance(dtype1, WeakIntegralType):
500+
return dpt.dtype(ti.default_device_int_type(dev)), dtype2
501+
if isinstance(dtype1, WeakComplexType):
502+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
503+
return dpt.complex64, dtype2
504+
return (
505+
_to_device_supported_dtype(dpt.complex128, dev),
506+
dtype2,
507+
)
508+
return _to_device_supported_dtype(dpt.float64, dev), dtype2
509+
else:
510+
return max_dtype, dtype2
511+
elif _is_weak_dtype(dtype2):
512+
max_dt_num_kind, max_dtype = max(
513+
[
514+
(_strong_dtype_num_kind(st_dtype), st_dtype),
515+
(_strong_dtype_num_kind(dtype1), dtype1),
516+
]
517+
)
518+
dt2_kind_num = _weak_type_num_kind(dtype2)
519+
if dt2_kind_num > max_dt_num_kind:
520+
if isinstance(dtype2, WeakIntegralType):
521+
return dtype1, dpt.dtype(ti.default_device_int_type(dev))
522+
if isinstance(dtype2, WeakComplexType):
523+
if max_dtype is dpt.float16 or max_dtype is dpt.float32:
524+
return dtype1, dpt.complex64
525+
return (
526+
dtype1,
527+
_to_device_supported_dtype(dpt.complex128, dev),
528+
)
529+
return dtype1, _to_device_supported_dtype(dpt.float64, dev)
530+
else:
531+
return dtype1, max_dtype
532+
else:
533+
# both are strong dtypes
534+
# return unmodified
535+
return dtype1, dtype2
536+
537+
538+
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
539+
"Resolves one weak data type with one strong data type per NEP-0050"
540+
if _is_weak_dtype(st_dtype):
541+
raise ValueError
542+
if _is_weak_dtype(dtype):
543+
st_kind_num = _strong_dtype_num_kind(st_dtype)
544+
kind_num = _weak_type_num_kind(dtype)
545+
if kind_num > st_kind_num:
546+
if isinstance(dtype, WeakIntegralType):
547+
return dpt.dtype(ti.default_device_int_type(dev))
548+
if isinstance(dtype, WeakComplexType):
549+
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
550+
return dpt.complex64
551+
return _to_device_supported_dtype(dpt.complex128, dev)
552+
return _to_device_supported_dtype(dpt.float64, dev)
553+
else:
554+
return st_dtype
555+
else:
556+
return dtype
557+
558+
453559
class finfo_object:
454560
"""
455561
`numpy.finfo` subclass which returns Python floating-point scalars for
@@ -838,6 +944,8 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
838944
"_acceptance_fn_divide",
839945
"_acceptance_fn_negative",
840946
"_acceptance_fn_subtract",
947+
"_resolve_one_strong_one_weak_types",
948+
"_resolve_one_strong_two_weak_types",
841949
"_resolve_weak_types",
842950
"_resolve_weak_types_all_py_ints",
843951
"_weak_type_num_kind",

0 commit comments

Comments
 (0)