@@ -212,6 +212,14 @@ def dpnp_add(x1, x2, out=None, order="K"):
212212"""
213213
214214
215+ bitwise_and_func = BinaryElementwiseFunc (
216+ "bitwise_and" ,
217+ ti ._bitwise_and_result_type ,
218+ ti ._bitwise_and ,
219+ _bitwise_and_docstring_ ,
220+ )
221+
222+
215223def dpnp_bitwise_and (x1 , x2 , out = None , order = "K" ):
216224 """Invokes bitwise_and() from dpctl.tensor implementation for bitwise_and() function."""
217225
@@ -220,13 +228,9 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"):
220228 x2_usm_or_scalar = dpnp .get_usm_ndarray_or_scalar (x2 )
221229 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
222230
223- func = BinaryElementwiseFunc (
224- "bitwise_and" ,
225- ti ._bitwise_and_result_type ,
226- ti ._bitwise_and ,
227- _bitwise_and_docstring_ ,
231+ res_usm = bitwise_and_func (
232+ x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order
228233 )
229- res_usm = func (x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order )
230234 return dpnp_array ._create_from_usm_ndarray (res_usm )
231235
232236
@@ -256,6 +260,14 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"):
256260"""
257261
258262
263+ bitwise_or_func = BinaryElementwiseFunc (
264+ "bitwise_or" ,
265+ ti ._bitwise_or_result_type ,
266+ ti ._bitwise_or ,
267+ _bitwise_or_docstring_ ,
268+ )
269+
270+
259271def dpnp_bitwise_or (x1 , x2 , out = None , order = "K" ):
260272 """Invokes bitwise_or() from dpctl.tensor implementation for bitwise_or() function."""
261273
@@ -264,13 +276,9 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"):
264276 x2_usm_or_scalar = dpnp .get_usm_ndarray_or_scalar (x2 )
265277 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
266278
267- func = BinaryElementwiseFunc (
268- "bitwise_or" ,
269- ti ._bitwise_or_result_type ,
270- ti ._bitwise_or ,
271- _bitwise_or_docstring_ ,
279+ res_usm = bitwise_or_func (
280+ x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order
272281 )
273- res_usm = func (x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order )
274282 return dpnp_array ._create_from_usm_ndarray (res_usm )
275283
276284
@@ -300,6 +308,14 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"):
300308"""
301309
302310
311+ bitwise_xor_func = BinaryElementwiseFunc (
312+ "bitwise_xor" ,
313+ ti ._bitwise_xor_result_type ,
314+ ti ._bitwise_xor ,
315+ _bitwise_xor_docstring_ ,
316+ )
317+
318+
303319def dpnp_bitwise_xor (x1 , x2 , out = None , order = "K" ):
304320 """Invokes bitwise_xor() from dpctl.tensor implementation for bitwise_xor() function."""
305321
@@ -308,13 +324,9 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
308324 x2_usm_or_scalar = dpnp .get_usm_ndarray_or_scalar (x2 )
309325 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
310326
311- func = BinaryElementwiseFunc (
312- "bitwise_xor" ,
313- ti ._bitwise_xor_result_type ,
314- ti ._bitwise_xor ,
315- _bitwise_xor_docstring_ ,
327+ res_usm = bitwise_xor_func (
328+ x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order
316329 )
317- res_usm = func (x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order )
318330 return dpnp_array ._create_from_usm_ndarray (res_usm )
319331
320332
@@ -629,20 +641,22 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"):
629641"""
630642
631643
644+ invert_func = UnaryElementwiseFunc (
645+ "invert" ,
646+ ti ._bitwise_invert_result_type ,
647+ ti ._bitwise_invert ,
648+ _invert_docstring ,
649+ )
650+
651+
632652def dpnp_invert (x , out = None , order = "K" ):
633653 """Invokes bitwise_invert() from dpctl.tensor implementation for invert() function."""
634654
635655 # dpctl.tensor only works with usm_ndarray or scalar
636656 x_usm = dpnp .get_usm_ndarray (x )
637657 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
638658
639- func = UnaryElementwiseFunc (
640- "invert" ,
641- ti ._bitwise_invert_result_type ,
642- ti ._bitwise_invert ,
643- _invert_docstring ,
644- )
645- res_usm = func (x_usm , out = out_usm , order = order )
659+ res_usm = invert_func (x_usm , out = out_usm , order = order )
646660 return dpnp_array ._create_from_usm_ndarray (res_usm )
647661
648662
@@ -778,6 +792,14 @@ def dpnp_isnan(x, out=None, order="K"):
778792"""
779793
780794
795+ left_shift_func = BinaryElementwiseFunc (
796+ "bitwise_leftt_shift" ,
797+ ti ._bitwise_left_shift_result_type ,
798+ ti ._bitwise_left_shift ,
799+ _left_shift_docstring_ ,
800+ )
801+
802+
781803def dpnp_left_shift (x1 , x2 , out = None , order = "K" ):
782804 """Invokes bitwise_left_shift() from dpctl.tensor implementation for left_shift() function."""
783805
@@ -786,13 +808,9 @@ def dpnp_left_shift(x1, x2, out=None, order="K"):
786808 x2_usm_or_scalar = dpnp .get_usm_ndarray_or_scalar (x2 )
787809 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
788810
789- func = BinaryElementwiseFunc (
790- "bitwise_leftt_shift" ,
791- ti ._bitwise_left_shift_result_type ,
792- ti ._bitwise_left_shift ,
793- _left_shift_docstring_ ,
811+ res_usm = left_shift_func (
812+ x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order
794813 )
795- res_usm = func (x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order )
796814 return dpnp_array ._create_from_usm_ndarray (res_usm )
797815
798816
@@ -1199,6 +1217,14 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
11991217"""
12001218
12011219
1220+ right_shift_func = BinaryElementwiseFunc (
1221+ "bitwise_right_shift" ,
1222+ ti ._bitwise_right_shift_result_type ,
1223+ ti ._bitwise_right_shift ,
1224+ _right_shift_docstring_ ,
1225+ )
1226+
1227+
12021228def dpnp_right_shift (x1 , x2 , out = None , order = "K" ):
12031229 """Invokes bitwise_right_shift() from dpctl.tensor implementation for right_shift() function."""
12041230
@@ -1207,13 +1233,9 @@ def dpnp_right_shift(x1, x2, out=None, order="K"):
12071233 x2_usm_or_scalar = dpnp .get_usm_ndarray_or_scalar (x2 )
12081234 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
12091235
1210- func = BinaryElementwiseFunc (
1211- "bitwise_right_shift" ,
1212- ti ._bitwise_right_shift_result_type ,
1213- ti ._bitwise_right_shift ,
1214- _right_shift_docstring_ ,
1236+ res_usm = right_shift_func (
1237+ x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order
12151238 )
1216- res_usm = func (x1_usm_or_scalar , x2_usm_or_scalar , out = out_usm , order = order )
12171239 return dpnp_array ._create_from_usm_ndarray (res_usm )
12181240
12191241
0 commit comments