Skip to content

Commit 09a387a

Browse files
authored
Add support of comparison operations (#1278)
1 parent 67e7f87 commit 09a387a

File tree

12 files changed

+259
-153
lines changed

12 files changed

+259
-153
lines changed

dpnp/backend/include/dpnp_gen_2arg_2type_tbl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@
8686

8787
#endif
8888

89+
MACRO_2ARG_2TYPES_LOGIC_OP(dpnp_equal_c, input1_elem == input2_elem)
90+
MACRO_2ARG_2TYPES_LOGIC_OP(dpnp_greater_c, input1_elem > input2_elem)
91+
MACRO_2ARG_2TYPES_LOGIC_OP(dpnp_greater_equal_c, input1_elem >= input2_elem)
92+
MACRO_2ARG_2TYPES_LOGIC_OP(dpnp_less_c, input1_elem < input2_elem)
8993
MACRO_2ARG_2TYPES_LOGIC_OP(dpnp_less_equal_c, input1_elem <= input2_elem)
94+
MACRO_2ARG_2TYPES_LOGIC_OP(dpnp_not_equal_c, input1_elem != input2_elem)
9095

9196
#undef MACRO_2ARG_2TYPES_LOGIC_OP

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ enum class DPNPFuncName : size_t
151151
DPNP_FN_EIG_EXT, /**< Used in numpy.linalg.eig() impl, requires extra parameters */
152152
DPNP_FN_EIGVALS, /**< Used in numpy.linalg.eigvals() impl */
153153
DPNP_FN_EIGVALS_EXT, /**< Used in numpy.linalg.eigvals() impl, requires extra parameters */
154+
DPNP_FN_EQUAL_EXT, /**< Used in numpy.equal() impl, requires extra parameters */
154155
DPNP_FN_ERF, /**< Used in scipy.special.erf impl */
155156
DPNP_FN_ERF_EXT, /**< Used in scipy.special.erf impl, requires extra parameters */
156157
DPNP_FN_EYE, /**< Used in numpy.eye() impl */
@@ -179,6 +180,8 @@ enum class DPNPFuncName : size_t
179180
DPNP_FN_FMOD_EXT, /**< Used in numpy.fmod() impl, requires extra parameters */
180181
DPNP_FN_FULL, /**< Used in numpy.full() impl */
181182
DPNP_FN_FULL_LIKE, /**< Used in numpy.full_like() impl */
183+
DPNP_FN_GREATER_EXT, /**< Used in numpy.greater() impl, requires extra parameters */
184+
DPNP_FN_GREATER_EQUAL_EXT, /**< Used in numpy.greater_equal() impl, requires extra parameters */
182185
DPNP_FN_HYPOT, /**< Used in numpy.hypot() impl */
183186
DPNP_FN_HYPOT_EXT, /**< Used in numpy.hypot() impl, requires extra parameters */
184187
DPNP_FN_IDENTITY, /**< Used in numpy.identity() impl */
@@ -193,6 +196,7 @@ enum class DPNPFuncName : size_t
193196
DPNP_FN_KRON_EXT, /**< Used in numpy.kron() impl, requires extra parameters */
194197
DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */
195198
DPNP_FN_LEFT_SHIFT_EXT, /**< Used in numpy.left_shift() impl, requires extra parameters */
199+
DPNP_FN_LESS_EXT, /**< Used in numpy.less() impl, requires extra parameters */
196200
DPNP_FN_LESS_EQUAL_EXT, /**< Used in numpy.less_equal() impl, requires extra parameters */
197201
DPNP_FN_LOG, /**< Used in numpy.log() impl */
198202
DPNP_FN_LOG_EXT, /**< Used in numpy.log() impl, requires extra parameters */
@@ -228,6 +232,7 @@ enum class DPNPFuncName : size_t
228232
DPNP_FN_NEGATIVE_EXT, /**< Used in numpy.negative() impl, requires extra parameters */
229233
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() impl */
230234
DPNP_FN_NONZERO_EXT, /**< Used in numpy.nonzero() impl, requires extra parameters */
235+
DPNP_FN_NOT_EQUAL_EXT, /**< Used in numpy.not_equal() impl, requires extra parameters */
231236
DPNP_FN_ONES, /**< Used in numpy.ones() impl */
232237
DPNP_FN_ONES_LIKE, /**< Used in numpy.ones_like() impl */
233238
DPNP_FN_PARTITION, /**< Used in numpy.partition() impl */

dpnp/backend/kernels/dpnp_krnl_logic.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,8 +536,18 @@ DPCTLSyclEventRef (*dpnp_any_ext_c)(DPCTLSyclQueueRef,
536536
template <DPNPFuncType FT1, DPNPFuncType ... FTs>
537537
static void func_map_logic_2arg_2type_core(func_map_t& fmap)
538538
{
539+
((fmap[DPNPFuncName::DPNP_FN_EQUAL_EXT][FT1][FTs] =
540+
{eft_BLN, (void*)dpnp_equal_c_ext<func_type_map_t::find_type<FT1>, func_type_map_t::find_type<FTs>>}), ...);
541+
((fmap[DPNPFuncName::DPNP_FN_GREATER_EXT][FT1][FTs] =
542+
{eft_BLN, (void*)dpnp_greater_c_ext<func_type_map_t::find_type<FT1>, func_type_map_t::find_type<FTs>>}), ...);
543+
((fmap[DPNPFuncName::DPNP_FN_GREATER_EQUAL_EXT][FT1][FTs] =
544+
{eft_BLN, (void*)dpnp_greater_equal_c_ext<func_type_map_t::find_type<FT1>, func_type_map_t::find_type<FTs>>}), ...);
545+
((fmap[DPNPFuncName::DPNP_FN_LESS_EXT][FT1][FTs] =
546+
{eft_BLN, (void*)dpnp_less_c_ext<func_type_map_t::find_type<FT1>, func_type_map_t::find_type<FTs>>}), ...);
539547
((fmap[DPNPFuncName::DPNP_FN_LESS_EQUAL_EXT][FT1][FTs] =
540548
{eft_BLN, (void*)dpnp_less_equal_c_ext<func_type_map_t::find_type<FT1>, func_type_map_t::find_type<FTs>>}), ...);
549+
((fmap[DPNPFuncName::DPNP_FN_NOT_EQUAL_EXT][FT1][FTs] =
550+
{eft_BLN, (void*)dpnp_not_equal_c_ext<func_type_map_t::find_type<FT1>, func_type_map_t::find_type<FTs>>}), ...);
541551
}
542552

543553
template <DPNPFuncType ... FTs>

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
127127
DPNP_FN_EIG_EXT
128128
DPNP_FN_EIGVALS
129129
DPNP_FN_EIGVALS_EXT
130+
DPNP_FN_EQUAL_EXT
130131
DPNP_FN_ERF
131132
DPNP_FN_ERF_EXT
132133
DPNP_FN_EYE
@@ -155,6 +156,8 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
155156
DPNP_FN_FMOD_EXT
156157
DPNP_FN_FULL
157158
DPNP_FN_FULL_LIKE
159+
DPNP_FN_GREATER_EXT
160+
DPNP_FN_GREATER_EQUAL_EXT
158161
DPNP_FN_HYPOT
159162
DPNP_FN_HYPOT_EXT
160163
DPNP_FN_IDENTITY
@@ -169,6 +172,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
169172
DPNP_FN_KRON_EXT
170173
DPNP_FN_LEFT_SHIFT
171174
DPNP_FN_LEFT_SHIFT_EXT
175+
DPNP_FN_LESS_EXT
172176
DPNP_FN_LESS_EQUAL_EXT
173177
DPNP_FN_LOG
174178
DPNP_FN_LOG_EXT
@@ -204,6 +208,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
204208
DPNP_FN_NEGATIVE_EXT
205209
DPNP_FN_NONZERO
206210
DPNP_FN_NONZERO_EXT
211+
DPNP_FN_NOT_EQUAL_EXT
207212
DPNP_FN_ONES
208213
DPNP_FN_ONES_LIKE
209214
DPNP_FN_PARTITION

dpnp/dpnp_algo/dpnp_algo_logic.pyx

Lines changed: 30 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -166,46 +166,28 @@ cpdef utils.dpnp_descriptor dpnp_any(utils.dpnp_descriptor array1):
166166
return result
167167

168168

169-
cpdef utils.dpnp_descriptor dpnp_equal(utils.dpnp_descriptor input1, utils.dpnp_descriptor input2):
170-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(input1, input2)
171-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(input1.shape,
172-
dpnp.bool,
173-
None,
174-
device=result_sycl_device,
175-
usm_type=result_usm_type,
176-
sycl_queue=result_sycl_queue)
177-
for i in range(result.size):
178-
result.get_pyobj()[i] = dpnp.bool(input1.get_pyobj()[i] == input2.get_pyobj()[i])
169+
cpdef utils.dpnp_descriptor dpnp_equal(utils.dpnp_descriptor x1_obj,
170+
utils.dpnp_descriptor x2_obj,
171+
object dtype=None,
172+
utils.dpnp_descriptor out=None,
173+
object where=True):
174+
return call_fptr_2in_1out_strides(DPNP_FN_EQUAL_EXT, x1_obj, x2_obj, dtype, out, where, func_name="equal")
179175

180-
return result
181176

177+
cpdef utils.dpnp_descriptor dpnp_greater(utils.dpnp_descriptor x1_obj,
178+
utils.dpnp_descriptor x2_obj,
179+
object dtype=None,
180+
utils.dpnp_descriptor out=None,
181+
object where=True):
182+
return call_fptr_2in_1out_strides(DPNP_FN_GREATER_EXT, x1_obj, x2_obj, dtype, out, where, func_name="greater")
182183

183-
cpdef utils.dpnp_descriptor dpnp_greater(utils.dpnp_descriptor input1, utils.dpnp_descriptor input2):
184-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(input1, input2)
185-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(input1.shape,
186-
dpnp.bool,
187-
None,
188-
device=result_sycl_device,
189-
usm_type=result_usm_type,
190-
sycl_queue=result_sycl_queue)
191-
for i in range(result.size):
192-
result.get_pyobj()[i] = dpnp.bool(input1.get_pyobj()[i] > input2.get_pyobj()[i])
193-
194-
return result
195-
196-
197-
cpdef utils.dpnp_descriptor dpnp_greater_equal(utils.dpnp_descriptor input1, utils.dpnp_descriptor input2):
198-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(input1, input2)
199-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(input1.shape,
200-
dpnp.bool,
201-
None,
202-
device=result_sycl_device,
203-
usm_type=result_usm_type,
204-
sycl_queue=result_sycl_queue)
205-
for i in range(result.size):
206-
result.get_pyobj()[i] = dpnp.bool(input1.get_pyobj()[i] >= input2.get_pyobj()[i])
207184

208-
return result
185+
cpdef utils.dpnp_descriptor dpnp_greater_equal(utils.dpnp_descriptor x1_obj,
186+
utils.dpnp_descriptor x2_obj,
187+
object dtype=None,
188+
utils.dpnp_descriptor out=None,
189+
object where=True):
190+
return call_fptr_2in_1out_strides(DPNP_FN_GREATER_EQUAL_EXT, x1_obj, x2_obj, dtype, out, where, func_name="greater_equal")
209191

210192

211193
cpdef utils.dpnp_descriptor dpnp_isclose(utils.dpnp_descriptor input1,
@@ -272,18 +254,12 @@ cpdef utils.dpnp_descriptor dpnp_isnan(utils.dpnp_descriptor input1):
272254
return result
273255

274256

275-
cpdef utils.dpnp_descriptor dpnp_less(utils.dpnp_descriptor input1, utils.dpnp_descriptor input2):
276-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(input1, input2)
277-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(input1.shape,
278-
dpnp.bool,
279-
None,
280-
device=result_sycl_device,
281-
usm_type=result_usm_type,
282-
sycl_queue=result_sycl_queue)
283-
for i in range(result.size):
284-
result.get_pyobj()[i] = dpnp.bool(input1.get_pyobj()[i] < input2.get_pyobj()[i])
285-
286-
return result
257+
cpdef utils.dpnp_descriptor dpnp_less(utils.dpnp_descriptor x1_obj,
258+
utils.dpnp_descriptor x2_obj,
259+
object dtype=None,
260+
utils.dpnp_descriptor out=None,
261+
object where=True):
262+
return call_fptr_2in_1out_strides(DPNP_FN_LESS_EXT, x1_obj, x2_obj, dtype, out, where, func_name="less")
287263

288264

289265
cpdef utils.dpnp_descriptor dpnp_less_equal(utils.dpnp_descriptor x1_obj,
@@ -355,15 +331,9 @@ cpdef utils.dpnp_descriptor dpnp_logical_xor(utils.dpnp_descriptor input1, utils
355331
return result
356332

357333

358-
cpdef utils.dpnp_descriptor dpnp_not_equal(utils.dpnp_descriptor input1, utils.dpnp_descriptor input2):
359-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(input1, input2)
360-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(input1.shape,
361-
dpnp.bool,
362-
None,
363-
device=result_sycl_device,
364-
usm_type=result_usm_type,
365-
sycl_queue=result_sycl_queue)
366-
for i in range(result.size):
367-
result.get_pyobj()[i] = dpnp.bool(input1.get_pyobj()[i] != input2.get_pyobj()[i])
368-
369-
return result
334+
cpdef utils.dpnp_descriptor dpnp_not_equal(utils.dpnp_descriptor x1_obj,
335+
utils.dpnp_descriptor x2_obj,
336+
object dtype=None,
337+
utils.dpnp_descriptor out=None,
338+
object where=True):
339+
return call_fptr_2in_1out_strides(DPNP_FN_NOT_EQUAL_EXT, x1_obj, x2_obj, dtype, out, where, func_name="not_equal")

0 commit comments

Comments
 (0)