Skip to content

Commit e6e35f4

Browse files
Implement np.cumsum and np.cumprod in kernel by dpnp (#258)
1 parent 6959fd9 commit e6e35f4

File tree

5 files changed

+125
-1
lines changed

5 files changed

+125
-1
lines changed

numba_dppy/dpnp_glue/dpnp_array_ops_impl.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,90 @@
1717
from numba.core.typing import signature
1818
from . import stubs
1919
import numba_dppy.dpnp_glue as dpnp_lowering
20-
from numba.core.extending import overload
20+
from numba.core.extending import overload, register_jitable
2121
import numpy as np
2222
from numba_dppy import dpctl_functions
2323
import numba_dppy
2424

2525

26+
@register_jitable
def common_impl(a, out, dpnp_func, print_debug):
    """Run a dpnp array kernel on ``a`` and copy its result into ``out``.

    ``a`` is copied into a USM shared buffer, ``dpnp_func`` is invoked as
    ``dpnp_func(a_usm, out_usm, a.size)``, and ``out.size * out.itemsize``
    bytes are then copied from the result buffer back into ``out``. Both
    USM buffers are freed before returning.

    Parameters
    ----------
    a : array
        Input array; must contain at least one element.
    out : array
        Pre-allocated destination for the result. The cumsum/cumprod
        overloads allocate it with the same size and dtype as ``a``.
    dpnp_func : callable
        Native dpnp entry point, called with
        (input pointer, result pointer, element count).
    print_debug : bool
        When true, "dpnp implementation" is printed after the call.

    Raises
    ------
    ValueError
        If ``a`` has no elements.
    """
    if a.size == 0:
        raise ValueError("Passed Empty array")

    sycl_queue = dpctl_functions.get_current_queue()

    a_nbytes = a.size * a.itemsize
    a_usm = dpctl_functions.malloc_shared(a_nbytes, sycl_queue)
    dpctl_functions.queue_memcpy(sycl_queue, a_usm, a.ctypes, a_nbytes)

    # The result buffer must hold the whole output array: out_nbytes bytes
    # are copied back from it below. Allocating a single element
    # (a.itemsize) here, as was done previously, lets the kernel write and
    # the copy-back read past the end of the allocation.
    out_nbytes = out.size * out.itemsize
    out_usm = dpctl_functions.malloc_shared(out_nbytes, sycl_queue)

    dpnp_func(a_usm, out_usm, a.size)

    dpctl_functions.queue_memcpy(sycl_queue, out.ctypes, out_usm, out_nbytes)

    dpctl_functions.free_with_queue(a_usm, sycl_queue)
    dpctl_functions.free_with_queue(out_usm, sycl_queue)

    # NOTE(review): presumably keeps ``a`` and ``out`` alive until the
    # copies above have completed -- confirm against dpnp_ext.
    dpnp_ext._dummy_liveness_func([a.size, out.size])

    if print_debug:
        print("dpnp implementation")
50+
51+
52+
@overload(stubs.dpnp.cumsum)
def dpnp_cumsum_impl(a):
    """Overload of ``stubs.dpnp.cumsum`` that dispatches to dpnp's cumsum.

    Looks up the native ``dpnp_cumsum`` entry point for ``a``'s dtype and
    returns the implementation Numba compiles for the call. That
    implementation allocates a 1-D array of ``a.size`` elements with
    ``a``'s dtype, lets ``common_impl`` fill it, and returns it.
    """
    dpnp_lowering.ensure_dpnp("cumsum")

    # dpnp source:
    # https://github.com/IntelPython/dpnp/blob/0.5.1/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp#L135
    # Function declaration:
    # void dpnp_cumsum_c(void* array1_in, void* result1, size_t size)
    native_sig = signature(types.void, types.voidptr, types.voidptr, types.intp)
    native_func = dpnp_ext.dpnp_func(
        "dpnp_cumsum", [a.dtype.name, "NONE"], native_sig
    )

    # Captured at typing time so the compiled code sees a constant flag.
    debug_enabled = dpnp_lowering.DEBUG

    def dpnp_impl(a):
        # arange is only used to allocate; every element is overwritten
        # by the copy-back in common_impl.
        result = np.arange(a.size, dtype=a.dtype)
        common_impl(a, result, native_func, debug_enabled)
        return result

    return dpnp_impl
76+
77+
78+
@overload(stubs.dpnp.cumprod)
def dpnp_cumprod_impl(a):
    """Overload of ``stubs.dpnp.cumprod`` that dispatches to dpnp's cumprod.

    Looks up the native ``dpnp_cumprod`` entry point for ``a``'s dtype and
    returns the implementation Numba compiles for the call. That
    implementation allocates a 1-D array of ``a.size`` elements with
    ``a``'s dtype, lets ``common_impl`` fill it, and returns it.
    """
    dpnp_lowering.ensure_dpnp("cumprod")

    # dpnp source:
    # https://github.com/IntelPython/dpnp/blob/0.5.1/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp#L110
    # Function declaration:
    # void dpnp_cumprod_c(void* array1_in, void* result1, size_t size)
    native_sig = signature(types.void, types.voidptr, types.voidptr, types.intp)
    native_func = dpnp_ext.dpnp_func(
        "dpnp_cumprod", [a.dtype.name, "NONE"], native_sig
    )

    # Captured at typing time so the compiled code sees a constant flag.
    debug_enabled = dpnp_lowering.DEBUG

    def dpnp_impl(a):
        # arange is only used to allocate; every element is overwritten
        # by the copy-back in common_impl.
        result = np.arange(a.size, dtype=a.dtype)
        common_impl(a, result, native_func, debug_enabled)
        return result

    return dpnp_impl
102+
103+
26104
@overload(stubs.dpnp.sort)
27105
def dpnp_sort_impl(a):
28106
name = "sort"

numba_dppy/dpnp_glue/dpnp_fptr_interface.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3030
DPNP_FN_COS
3131
DPNP_FN_COSH
3232
DPNP_FN_COV
33+
DPNP_FN_CUMPROD
34+
DPNP_FN_CUMSUM
3335
DPNP_FN_DEGREES
3436
DPNP_FN_DET
3537
DPNP_FN_DIVIDE
@@ -201,6 +203,10 @@ cdef DPNPFuncName get_DPNPFuncName_from_str(name):
201203
return DPNPFuncName.DPNP_FN_DET
202204
elif name == "dpnp_matrix_rank":
203205
return DPNPFuncName.DPNP_FN_MATRIX_RANK
206+
elif name == "dpnp_cumsum":
207+
return DPNPFuncName.DPNP_FN_CUMSUM
208+
elif name == "dpnp_cumprod":
209+
return DPNPFuncName.DPNP_FN_CUMPROD
204210
elif name == "dpnp_sort":
205211
return DPNPFuncName.DPNP_FN_SORT
206212
else:

numba_dppy/dpnp_glue/stubs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,11 @@ class nansum(Stub):
179179
class nanprod(Stub):
180180
pass
181181

182+
class cumsum(Stub):
    """Placeholder for cumsum; has no behavior of its own."""
184+
185+
class cumprod(Stub):
    """Placeholder for cumprod; has no behavior of its own."""
187+
182188
class sort(Stub):
183189
pass

numba_dppy/rename_numpy_functions_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@
8383
"prod": (["numpy"], "prod"),
8484
"sum": (["numpy"], "sum"),
8585
# array ops
86+
"cumsum": (["numpy"], "cumsum"),
87+
"cumprod": (["numpy"], "cumprod"),
8688
"sort": (["numpy"], "sort"),
8789
}
8890

numba_dppy/tests/test_dpnp_functions.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,38 @@ def f(a, b):
14081408
class Testdpnp_array_ops_functions(unittest.TestCase):
14091409
tys = [np.int32, np.uint32, np.int64, np.uint64, np.float, np.double]
14101410

1411+
def test_cumsum(self):
    """Exercise ``np.cumsum`` inside an ``@njit`` function.

    Runs the shared checkers against ``np.cumsum`` for a length-10 input
    and for the shapes [10, 2] and [10, 2, 3], over every dtype in
    ``self.tys``, all inside ``assert_dpnp_implementaion()``.
    """

    @njit
    def cumsum_kernel(arr):
        return np.cumsum(arr)

    with assert_dpnp_implementaion():
        self.assertTrue(
            check_for_different_datatypes(
                cumsum_kernel, np.cumsum, [10], 1, self.tys, True
            )
        )
        # Same checker, once per multi-dimensional shape.
        for dims in ([10, 2], [10, 2, 3]):
            self.assertTrue(
                check_for_dimensions(cumsum_kernel, np.cumsum, dims, self.tys, True)
            )
1425+
1426+
def test_cumprod(self):
    """Exercise ``np.cumprod`` inside an ``@njit`` function.

    Runs the shared checkers against ``np.cumprod`` for a length-10 input
    and for the shapes [10, 2] and [10, 2, 3], over every dtype in
    ``self.tys``, all inside ``assert_dpnp_implementaion()``.
    """

    @njit
    def cumprod_kernel(arr):
        return np.cumprod(arr)

    with assert_dpnp_implementaion():
        self.assertTrue(
            check_for_different_datatypes(
                cumprod_kernel, np.cumprod, [10], 1, self.tys, True
            )
        )
        # Same checker, once per multi-dimensional shape.
        for dims in ([10, 2], [10, 2, 3]):
            self.assertTrue(
                check_for_dimensions(cumprod_kernel, np.cumprod, dims, self.tys, True)
            )
1442+
14111443
def test_sort(self):
14121444
@njit
14131445
def f(a):

0 commit comments

Comments
 (0)