Skip to content

Commit 7bcc426

Browse files
committed
address reviewer's comments
1 parent bbafc89 commit 7bcc426

File tree

2 files changed

+54
-49
lines changed

2 files changed

+54
-49
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -186,55 +186,6 @@ def dpnp_add(x1, x2, out=None, order="K"):
186186
return dpnp_array._create_from_usm_ndarray(res_usm)
187187

188188

189-
_cos_docstring = """
190-
cos(x, out=None, order='K')
191-
Computes cosine for each element `x_i` for input array `x`.
192-
Args:
193-
x (dpnp.ndarray):
194-
Input array, expected to have numeric data type.
195-
out ({None, dpnp.ndarray}, optional):
196-
Output array to populate. Array must have the correct
197-
shape and the expected data type.
198-
order ("C","F","A","K", optional): memory layout of the new
199-
output array, if parameter `out` is `None`.
200-
Default: "K".
201-
Return:
202-
dpnp.ndarray:
203-
An array containing the element-wise cosine. The data type
204-
of the returned array is determined by the Type Promotion Rules.
205-
"""
206-
207-
208-
def dpnp_cos(x, out=None, order="K"):
209-
"""
210-
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
211-
212-
Otherwise fully relies on dpctl.tensor implementation for cos() function.
213-
214-
"""
215-
216-
def _call_cos(src, dst, sycl_queue, depends=None):
217-
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
218-
219-
if depends is None:
220-
depends = []
221-
222-
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
223-
# call pybind11 extension for cos() function from OneMKL VM
224-
return vmi._cos(sycl_queue, src, dst, depends)
225-
return ti._cos(src, dst, sycl_queue, depends)
226-
227-
# dpctl.tensor only works with usm_ndarray
228-
x1_usm = dpnp.get_usm_ndarray(x)
229-
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
230-
231-
func = UnaryElementwiseFunc(
232-
"cos", ti._cos_result_type, _call_cos, _cos_docstring
233-
)
234-
res_usm = func(x1_usm, out=out_usm, order=order)
235-
return dpnp_array._create_from_usm_ndarray(res_usm)
236-
237-
238189
_bitwise_and_docstring_ = """
239190
bitwise_and(x1, x2, out=None, order='K')
240191
@@ -367,6 +318,55 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"):
367318
return dpnp_array._create_from_usm_ndarray(res_usm)
368319

369320

321+
_cos_docstring = """
322+
cos(x, out=None, order='K')
323+
Computes cosine for each element `x_i` for input array `x`.
324+
Args:
325+
x (dpnp.ndarray):
326+
Input array, expected to have numeric data type.
327+
out ({None, dpnp.ndarray}, optional):
328+
Output array to populate. Array must have the correct
329+
shape and the expected data type.
330+
order ("C","F","A","K", optional): memory layout of the new
331+
output array, if parameter `out` is `None`.
332+
Default: "K".
333+
Return:
334+
dpnp.ndarray:
335+
An array containing the element-wise cosine. The data type
336+
of the returned array is determined by the Type Promotion Rules.
337+
"""
338+
339+
340+
def dpnp_cos(x, out=None, order="K"):
341+
"""
342+
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
343+
344+
Otherwise fully relies on dpctl.tensor implementation for cos() function.
345+
346+
"""
347+
348+
def _call_cos(src, dst, sycl_queue, depends=None):
349+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
350+
351+
if depends is None:
352+
depends = []
353+
354+
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
355+
# call pybind11 extension for cos() function from OneMKL VM
356+
return vmi._cos(sycl_queue, src, dst, depends)
357+
return ti._cos(src, dst, sycl_queue, depends)
358+
359+
# dpctl.tensor only works with usm_ndarray
360+
x1_usm = dpnp.get_usm_ndarray(x)
361+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
362+
363+
func = UnaryElementwiseFunc(
364+
"cos", ti._cos_result_type, _call_cos, _cos_docstring
365+
)
366+
res_usm = func(x1_usm, out=out_usm, order=order)
367+
return dpnp_array._create_from_usm_ndarray(res_usm)
368+
369+
370370
_divide_docstring_ = """
371371
divide(x1, x2, out=None, order="K")
372372

tests/test_bitwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def test_bitwise_and(self, lhs, rhs, dtype):
6868
assert_array_equal(dp_a & dp_b, np_a & np_b)
6969

7070
"""
71+
TODO: unmute once dpctl support that
7172
if (
7273
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
7374
and dp_a.shape == dp_b.shape
@@ -84,6 +85,7 @@ def test_bitwise_or(self, lhs, rhs, dtype):
8485
assert_array_equal(dp_a | dp_b, np_a | np_b)
8586

8687
"""
88+
TODO: unmute once dpctl support that
8789
if (
8890
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
8991
and dp_a.shape == dp_b.shape
@@ -100,6 +102,7 @@ def test_bitwise_xor(self, lhs, rhs, dtype):
100102
assert_array_equal(dp_a ^ dp_b, np_a ^ np_b)
101103

102104
"""
105+
TODO: unmute once dpctl support that
103106
if (
104107
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
105108
and dp_a.shape == dp_b.shape
@@ -120,6 +123,7 @@ def test_left_shift(self, lhs, rhs, dtype):
120123
assert_array_equal(dp_a << dp_b, np_a << np_b)
121124

122125
"""
126+
TODO: unmute once dpctl support that
123127
if (
124128
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
125129
and dp_a.shape == dp_b.shape
@@ -136,6 +140,7 @@ def test_right_shift(self, lhs, rhs, dtype):
136140
assert_array_equal(dp_a >> dp_b, np_a >> np_b)
137141

138142
"""
143+
TODO: unmute once dpctl support that
139144
if (
140145
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
141146
and dp_a.shape == dp_b.shape

0 commit comments

Comments
 (0)