Skip to content

Commit d84fcb4

Browse files
numba-dppy tests and examples use with device_context (#46)
* change the test_prange * fix all tests and examples * Replace DPPLTestCase with DPPYTestCase * Fix typos * Use explicit device selection in the examples * Replace dppl with dppy * Fix int64 to long long conversion on windows * Fixed test_with_dppy_context_cpu Co-authored-by: Pokhodenko <sergey.pokhodenko@intel.com>
1 parent 0105e83 commit d84fcb4

17 files changed

+522
-325
lines changed

numba_dppy/dppy_host_fn_call_gen.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _init_llvm_types_and_constants(self):
5252
self.byte_ptr_t = lc.Type.pointer(self.byte_t)
5353
self.byte_ptr_ptr_t = lc.Type.pointer(self.byte_ptr_t)
5454
self.intp_t = self.context.get_value_type(types.intp)
55-
self.long_t = self.context.get_value_type(types.int64)
55+
self.int64_t = self.context.get_value_type(types.int64)
5656
self.int32_t = self.context.get_value_type(types.int32)
5757
self.int32_ptr_t = lc.Type.pointer(self.int32_t)
5858
self.uintp_t = self.context.get_value_type(types.uintp)
@@ -113,23 +113,26 @@ def allocate_kenrel_arg_array(self, num_kernel_args):
113113

114114

115115
def resolve_and_return_dpctl_type(self, ty):
116+
"""This function looks up the dpctl defined enum values from DPCTLKernelArgType.
117+
"""
118+
116119
val = None
117120
if ty == types.int32 or isinstance(ty, types.scalars.IntegerLiteral):
118-
val = self.context.get_constant(types.int32, 4)
121+
val = self.context.get_constant(types.int32, 9) # DPCTL_LONG_LONG
119122
elif ty == types.uint32:
120-
val = self.context.get_constant(types.int32, 5)
123+
val = self.context.get_constant(types.int32, 10) # DPCTL_UNSIGNED_LONG_LONG
121124
elif ty == types.boolean:
122-
val = self.context.get_constant(types.int32, 5)
125+
val = self.context.get_constant(types.int32, 5) # DPCTL_UNSIGNED_INT
123126
elif ty == types.int64:
124-
val = self.context.get_constant(types.int32, 7)
127+
val = self.context.get_constant(types.int32, 9) # DPCTL_LONG_LONG
125128
elif ty == types.uint64:
126-
val = self.context.get_constant(types.int32, 8)
129+
val = self.context.get_constant(types.int32, 11) # DPCTL_SIZE_T
127130
elif ty == types.float32:
128-
val = self.context.get_constant(types.int32, 12)
131+
val = self.context.get_constant(types.int32, 12) # DPCTL_FLOAT
129132
elif ty == types.float64:
130-
val = self.context.get_constant(types.int32, 13)
133+
val = self.context.get_constant(types.int32, 13) # DPCTL_DOUBLE
131134
elif ty == types.voidptr:
132-
val = self.context.get_constant(types.int32, 15)
135+
val = self.context.get_constant(types.int32, 15) # DPCTL_VOID_PTR
133136
else:
134137
raise NotImplementedError
135138

@@ -151,12 +154,12 @@ def process_kernel_arg(self, var, llvm_arg, arg_type, gu_sig, val_type, index, m
151154
if llvm_arg is None:
152155
raise NotImplementedError(arg_type, var)
153156

154-
storage = cgutils.alloca_once(self.builder, self.long_t)
157+
storage = cgutils.alloca_once(self.builder, self.int64_t)
155158
self.builder.store(self.context.get_constant(types.int64, 0), storage)
156159
ty = self.resolve_and_return_dpctl_type(types.int64)
157160
self.form_kernel_arg_and_arg_ty(self.builder.bitcast(storage, self.void_ptr_t), ty)
158161

159-
storage = cgutils.alloca_once(self.builder, self.long_t)
162+
storage = cgutils.alloca_once(self.builder, self.int64_t)
160163
self.builder.store(self.context.get_constant(types.int64, 0), storage)
161164
ty = self.resolve_and_return_dpctl_type(types.int64)
162165
self.form_kernel_arg_and_arg_ty(self.builder.bitcast(storage, self.void_ptr_t), ty)
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
from numba import njit, gdb
22
import numpy as np
3+
import dpctl
34

4-
@njit(parallel={'offload':True})
5+
6+
@njit
57
def f1(a, b):
68
c = a + b
79
return c
810

11+
912
N = 1000
1013
print("N", N)
1114

12-
a = np.ones((N,N), dtype=np.float32)
13-
b = np.ones((N,N), dtype=np.float32)
15+
a = np.ones((N, N), dtype=np.float32)
16+
b = np.ones((N, N), dtype=np.float32)
1417

1518
print("a:", a, hex(a.ctypes.data))
1619
print("b:", b, hex(b.ctypes.data))
17-
c = f1(a,b)
20+
21+
with dpctl.device_context("opencl:gpu:0"):
22+
c = f1(a, b)
23+
1824
print("BIG RESULT c:", c, hex(c.ctypes.data))
1925
for i in range(N):
2026
for j in range(N):
21-
if c[i,j] != 2.0:
27+
if c[i, j] != 2.0:
2228
print("First index not equal to 2.0 was", i, j)
2329
break
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,30 @@
11
from numba import njit, gdb
22
import numpy as np
3+
import dpctl
34

4-
@njit(parallel={'offload':True})
5+
6+
@njit
57
def f1(a, b):
68
c = a + b
79
return c
810

11+
912
N = 10
1013
print("N", N)
1114

12-
a = np.ones((N,N,N), dtype=np.float32)
13-
b = np.ones((N,N,N), dtype=np.float32)
15+
a = np.ones((N, N, N), dtype=np.float32)
16+
b = np.ones((N, N, N), dtype=np.float32)
1417

1518
print("a:", a, hex(a.ctypes.data))
1619
print("b:", b, hex(b.ctypes.data))
17-
c = f1(a,b)
20+
21+
with dpctl.device_context("opencl:gpu:0"):
22+
c = f1(a, b)
23+
1824
print("BIG RESULT c:", c, hex(c.ctypes.data))
1925
for i in range(N):
2026
for j in range(N):
2127
for k in range(N):
22-
if c[i,j,k] != 2.0:
28+
if c[i, j, k] != 2.0:
2329
print("First index not equal to 2.0 was", i, j, k)
2430
break
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
11
from numba import njit, gdb
22
import numpy as np
3+
import dpctl
34

4-
@njit(parallel={'offload':True})
5+
6+
@njit
57
def f1(a, b):
68
c = a + b
79
return c
810

11+
912
N = 10
1013
print("N", N)
1114

12-
a = np.ones((N,N,N,N), dtype=np.float32)
13-
b = np.ones((N,N,N,N), dtype=np.float32)
15+
a = np.ones((N, N, N, N), dtype=np.float32)
16+
b = np.ones((N, N, N, N), dtype=np.float32)
1417

1518
print("a:", a, hex(a.ctypes.data))
1619
print("b:", b, hex(b.ctypes.data))
17-
c = f1(a,b)
20+
21+
with dpctl.device_context("opencl:gpu:0"):
22+
c = f1(a, b)
23+
1824
print("BIG RESULT c:", c, hex(c.ctypes.data))
1925
for i in range(N):
2026
for j in range(N):
2127
for k in range(N):
2228
for l in range(N):
23-
if c[i,j,k,l] != 2.0:
29+
if c[i, j, k, l] != 2.0:
2430
print("First index not equal to 2.0 was", i, j, k, l)
2531
break
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,32 @@
11
from numba import njit, gdb
22
import numpy as np
3+
import dpctl
34

4-
@njit(parallel={'offload':True})
5+
6+
@njit
57
def f1(a, b):
68
c = a + b
79
return c
810

11+
912
N = 5
1013
print("N", N)
1114

12-
a = np.ones((N,N,N,N,N), dtype=np.float32)
13-
b = np.ones((N,N,N,N,N), dtype=np.float32)
15+
a = np.ones((N, N, N, N, N), dtype=np.float32)
16+
b = np.ones((N, N, N, N, N), dtype=np.float32)
1417

1518
print("a:", a, hex(a.ctypes.data))
1619
print("b:", b, hex(b.ctypes.data))
17-
c = f1(a,b)
20+
21+
with dpctl.device_context("opencl:gpu:0"):
22+
c = f1(a, b)
23+
1824
print("BIG RESULT c:", c, hex(c.ctypes.data))
1925
for i in range(N):
2026
for j in range(N):
2127
for k in range(N):
2228
for l in range(N):
2329
for m in range(N):
24-
if c[i,j,k,l,m] != 2.0:
30+
if c[i, j, k, l, m] != 2.0:
2531
print("First index not equal to 2.0 was", i, j, k, l, m)
2632
break

numba_dppy/examples/pa_examples/test1.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from numba import njit
22
import numpy as np
3+
import dpctl
34

45

5-
@njit(parallel={'offload':True})
6+
@njit
67
def f1(a, b):
78
c = a + b
89
return c
@@ -19,7 +20,10 @@ def main():
1920

2021
print("a:", a, hex(a.ctypes.data))
2122
print("b:", b, hex(b.ctypes.data))
22-
c = f1(a,b)
23+
24+
with dpctl.device_context("opencl:gpu:0"):
25+
c = f1(a, b)
26+
2327
print("RESULT c:", c, hex(c.ctypes.data))
2428
for i in range(N):
2529
if c[i] != 2.0:

0 commit comments

Comments
 (0)