Skip to content

Commit 43f3b7b

Browse files
authored
Merge pull request #1239 from vlad-perevezentsev/comparison_funcs
Implementation of less_equal, greater, greater_equal for dpctl.tensor
2 parents 0185fa9 + 20cbd36 commit 43f3b7b

File tree

9 files changed

+2039
-9
lines changed

9 files changed

+2039
-9
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,14 @@
101101
exp,
102102
expm1,
103103
floor_divide,
104+
greater,
105+
greater_equal,
104106
imag,
105107
isfinite,
106108
isinf,
107109
isnan,
108110
less,
111+
less_equal,
109112
log,
110113
log1p,
111114
multiply,
@@ -199,11 +202,14 @@
199202
"cos",
200203
"exp",
201204
"expm1",
205+
"greater",
206+
"greater_equal",
202207
"imag",
203208
"isinf",
204209
"isnan",
205210
"isfinite",
206211
"less",
212+
"less_equal",
207213
"log",
208214
"log1p",
209215
"proj",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,62 @@
297297
)
298298

299299
# B11: ==== GREATER (x1, x2)
300-
# FIXME: implement B11
300+
_greater_docstring_ = """
301+
greater(x1, x2, out=None, order='K')
302+
Computes the greater-than test results for each element `x1_i` of
303+
the input array `x1` the respective element `x2_i` of the input array `x2`.
304+
Args:
305+
x1 (usm_ndarray):
306+
First input array, expected to have numeric data type.
307+
x2 (usm_ndarray):
308+
Second input array, also expected to have numeric data type.
309+
out ({None, usm_ndarray}, optional):
310+
Output array to populate.
311+
Array have the correct shape and the expected data type.
312+
order ("C","F","A","K", optional):
313+
Memory layout of the newly output array, if parameter `out` is `None`.
314+
Default: "K".
315+
Returns:
316+
usm_narray:
317+
An array containing the result of element-wise greater-than comparison.
318+
The data type of the returned array is determined by the
319+
Type Promotion Rules.
320+
"""
321+
322+
greater = BinaryElementwiseFunc(
323+
"greater", ti._greater_result_type, ti._greater, _greater_docstring_
324+
)
301325

302326
# B12: ==== GREATER_EQUAL (x1, x2)
303-
# FIXME: implement B12
327+
_greater_equal_docstring_ = """
328+
greater_equal(x1, x2, out=None, order='K')
329+
Computes the greater-than or equal-to test results for each element `x1_i` of
330+
the input array `x1` the respective element `x2_i` of the input array `x2`.
331+
Args:
332+
x1 (usm_ndarray):
333+
First input array, expected to have numeric data type.
334+
x2 (usm_ndarray):
335+
Second input array, also expected to have numeric data type.
336+
out ({None, usm_ndarray}, optional):
337+
Output array to populate.
338+
Array have the correct shape and the expected data type.
339+
order ("C","F","A","K", optional):
340+
Memory layout of the newly output array, if parameter `out` is `None`.
341+
Default: "K".
342+
Returns:
343+
usm_narray:
344+
An array containing the result of element-wise greater-than or equal-to
345+
comparison.
346+
The data type of the returned array is determined by the
347+
Type Promotion Rules.
348+
"""
349+
350+
greater_equal = BinaryElementwiseFunc(
351+
"greater_equal",
352+
ti._greater_equal_result_type,
353+
ti._greater_equal,
354+
_greater_equal_docstring_,
355+
)
304356

305357
# U16: ==== IMAG (x)
306358
_imag_docstring = """
@@ -434,7 +486,35 @@
434486
)
435487

436488
# B14: ==== LESS_EQUAL (x1, x2)
437-
# FIXME: implement B14
489+
_less_equal_docstring_ = """
490+
less_equal(x1, x2, out=None, order='K')
491+
Computes the less-than or equal-to test results for each element `x1_i` of
492+
the input array `x1` the respective element `x2_i` of the input array `x2`.
493+
Args:
494+
x1 (usm_ndarray):
495+
First input array, expected to have numeric data type.
496+
x2 (usm_ndarray):
497+
Second input array, also expected to have numeric data type.
498+
out ({None, usm_ndarray}, optional):
499+
Output array to populate.
500+
Array have the correct shape and the expected data type.
501+
order ("C","F","A","K", optional):
502+
Memory layout of the newly output array, if parameter `out` is `None`.
503+
Default: "K".
504+
Returns:
505+
usm_narray:
506+
An array containing the result of element-wise less-than or equal-to
507+
comparison.
508+
The data type of the returned array is determined by the
509+
Type Promotion Rules.
510+
"""
511+
512+
less_equal = BinaryElementwiseFunc(
513+
"less_equal",
514+
ti._less_equal_result_type,
515+
ti._less_equal,
516+
_less_equal_docstring_,
517+
)
438518

439519
# U20: ==== LOG (x)
440520
_log_docstring = """

0 commit comments

Comments
 (0)