Skip to content

Commit 5b176d0

Browse files
chudur-budurkhaled
authored andcommitted
WIP for DpctlSyclQueue support
Adding proper typing for dpctl.SyclQueue Revert Tried to follow the interval example from numba Redo the implementation by following how the ArrayModel was done Fix arg names keep driver.py Added unboxing function for SyclQueueType Adding minimal test Testing with different pattern
1 parent 9d3d507 commit 5b176d0

File tree

9 files changed

+249
-20
lines changed

9 files changed

+249
-20
lines changed

driver.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from dpctl import SyclQueue
2+
from numba import njit
3+
4+
from numba_dpex import dpjit
5+
6+
if __name__ == "__main__":
7+
8+
@dpjit
9+
def test(q):
10+
pass
11+
12+
queue = SyclQueue()
13+
test(queue)

numba_dpex/core/datamodel/models.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from numba_dpex.utils import address_space
1111

12-
from ..types import Array, DpctlSyclQueue, DpnpNdArray, USMNdArray
12+
from ..types import Array, DpnpNdArray, SyclQueueType, USMNdArray
1313

1414

1515
class GenericPointerModel(PrimitiveModel):
@@ -54,6 +54,18 @@ def __init__(self, dmm, fe_type):
5454
super(ArrayModel, self).__init__(dmm, fe_type, members)
5555

5656

57+
class SyclQueueModel(StructModel):
58+
def __init__(self, dmm, fe_type):
59+
members = [
60+
("parent", types.CPointer),
61+
("queue_ref", types.PyObject),
62+
("context", types.PyObject),
63+
("device", types.PyObject),
64+
]
65+
# super(StructModel, self).__init__(dmm, fe_type, members)
66+
StructModel.__init__(self, dmm, fe_type, members)
67+
68+
5769
def _init_data_model_manager():
5870
dmm = datamodel.default_manager.copy()
5971
dmm.register(types.CPointer, GenericPointerModel)
@@ -84,5 +96,9 @@ def _init_data_model_manager():
8496
dpex_data_model_manager.register(DpnpNdArray, DpnpNdArrayModel)
8597

8698
# Register the DpctlSyclQueue type with Numba's OpaqueModel
87-
register_model(DpctlSyclQueue)(OpaqueModel)
88-
dpex_data_model_manager.register(DpctlSyclQueue, OpaqueModel)
99+
# register_model(DpctlSyclQueue)(OpaqueModel)
100+
# dpex_data_model_manager.register(DpctlSyclQueue, OpaqueModel)
101+
102+
# Register the DpctlSyclQueue type with Numba's OpaqueModel
103+
register_model(SyclQueueType)(SyclQueueModel)
104+
dpex_data_model_manager.register(SyclQueueType, SyclQueueModel)

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "_nrt_helper.h"
2020
#include "_nrt_python_helper.h"
2121

22+
#include "_queuestruct.h"
2223
#include "numba/_arraystruct.h"
2324

2425
/* Debugging facilities - enabled at compile-time */
@@ -66,6 +67,9 @@ DPEXRT_sycl_usm_ndarray_to_python_acqref(arystruct_t *arystruct,
6667
int ndim,
6768
int writeable,
6869
PyArray_Descr *descr);
70+
static struct PySyclQueueObject *to_py_syclqobject(PyObject *obj);
71+
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
72+
queuestruct_t *queue_struct);
6973

7074
/*
7175
* Debugging printf function used internally
@@ -645,6 +649,21 @@ static struct PyUSMArrayObject *PyUSMNdArray_ARRAYOBJ(PyObject *obj)
645649
return pyusmarrayobj;
646650
}
647651

652+
static struct PySyclQueueObject *to_py_syclqobject(PyObject *obj)
653+
{
654+
if (!obj)
655+
return NULL;
656+
if (!PyObject_TypeCheck(obj, &PySyclQueueType))
657+
return NULL;
658+
659+
struct PySyclQueueObject *pysyclqobj = (struct PySyclQueueObject *)(obj);
660+
// struct Py_SyclQueueObject py_syclqobj = pysyclqobj->__pyx_base;
661+
662+
// return &py_syclqobj;
663+
664+
return pysyclqobj;
665+
}
666+
648667
/*!
649668
* @brief Returns the product of the elements in an array of a given
650669
* length.
@@ -785,6 +804,62 @@ static int DPEXRT_sycl_usm_ndarray_from_python(PyObject *obj,
785804
return -1;
786805
}
787806

807+
static int DPEXRT_sycl_queue_from_python(PyObject *obj,
808+
queuestruct_t *queue_struct)
809+
{
810+
811+
struct PySyclQueueObject *queue_obj = NULL;
812+
// DPCTLSyclQueueRef queue_ref = NULL;
813+
PyGILState_STATE gstate;
814+
815+
// Increment the ref count on obj to prevent CPython from garbage
816+
// collecting the array.
817+
Py_IncRef(obj);
818+
819+
DPEXRT_DEBUG(
820+
nrt_debug_print("DPEXRT-DEBUG: In DPEXRT_sycl_queue_from_python.\n"));
821+
822+
// Check if the PyObject obj has an _array_obj attribute that is of
823+
// dpctl.tensor.usm_ndarray type.
824+
if (!(queue_obj = to_py_syclqobject(obj))) {
825+
DPEXRT_DEBUG(nrt_debug_print(
826+
"DPEXRT-ERROR: to_py_syclqobject() check failed %d\n", __FILE__,
827+
__LINE__));
828+
goto error;
829+
}
830+
831+
// if (!(queue_ref = SyclQueue_GetQueueRef(queue_obj))) {
832+
// DPEXRT_DEBUG(nrt_debug_print(
833+
// "DPEXRT-ERROR: SyclQueue_GetQueueRef returned NULL at "
834+
// "%s, line %d.\n",
835+
// __FILE__, __LINE__));
836+
// goto error;
837+
// }
838+
839+
queue_struct->parent = obj;
840+
// queue_struct->queue_ref = queue_ref;
841+
queue_struct->queue_ref = (PyObject *)queue_obj->__pyx_base._queue_ref;
842+
queue_struct->cotext = (PyObject *)queue_obj->__pyx_base._context;
843+
queue_struct->device = (PyObject *)queue_obj->__pyx_base._device;
844+
845+
error:
846+
// If the check failed then decrement the refcount and return an error
847+
// code of -1.
848+
// Decref the Pyobject of the array
849+
// ensure the GIL
850+
DPEXRT_DEBUG(nrt_debug_print(
851+
"DPEXRT-ERROR: Failed to unbox dpctl SyclQueue into a Numba "
852+
"queuestruct at %s, line %d\n",
853+
__FILE__, __LINE__));
854+
gstate = PyGILState_Ensure();
855+
// decref the python object
856+
Py_DECREF(obj);
857+
// release the GIL
858+
PyGILState_Release(gstate);
859+
860+
return -1;
861+
}
862+
788863
/*!
789864
* @brief A helper function that boxes a Numba arystruct_t object into a
790865
* dpnp.ndarray PyObject using the arystruct_t's parent attribute.
@@ -1082,6 +1157,8 @@ static PyObject *build_c_helpers_dict(void)
10821157
_declpointer("DPEXRT_MemInfo_fill", &DPEXRT_MemInfo_fill);
10831158
_declpointer("NRT_ExternalAllocator_new_for_usm",
10841159
&NRT_ExternalAllocator_new_for_usm);
1160+
_declpointer("DPEXRT_sycl_queue_from_python",
1161+
&DPEXRT_sycl_queue_from_python);
10851162

10861163
#undef _declpointer
10871164
return dct;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef NUMBA_DPEX_QUEUESTRUCT_H_
2+
#define NUMBA_DPEX_QUEUESTRUCT_H_
3+
/*
4+
* Fill in the *queuestruct* with information from the Numpy array *obj*.
5+
* *queuestruct*'s layout is defined in numba.targets.arrayobj (look
6+
* for the ArrayTemplate class).
7+
*/
8+
9+
#include "numpy/npy_common.h"
10+
#include <Python.h>
11+
12+
typedef struct
13+
{
14+
PyObject *parent;
15+
PyObject *queue_ref;
16+
PyObject *cotext;
17+
PyObject *device;
18+
} queuestruct_t;
19+
20+
#endif /* NUMBA_DPEX_QUEUESTRUCT_H_ */

numba_dpex/core/runtime/context.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,21 @@ def arraystruct_from_python(self, pyapi, obj, ptr):
128128

129129
return self.error
130130

131+
def queuestruct_from_python(self, pyapi, obj, ptr):
132+
# call the c function DPEXRT_sycl_queue_from_python
133+
134+
fnty = llvmir.FunctionType(
135+
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
136+
)
137+
138+
fn = pyapi._get_function(fnty, "DPEXRT_sycl_queue_from_python")
139+
fn.args[0].add_attribute("nocapture")
140+
fn.args[1].add_attribute("nocapture")
141+
142+
self.error = pyapi.builder.call(fn, (obj, ptr))
143+
144+
return self.error
145+
131146
def usm_ndarray_to_python_acqref(self, pyapi, aryty, ary, dtypeptr):
132147
"""Boxes a DpnpNdArray native object into a Python dpnp.ndarray.
133148

numba_dpex/core/types/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from .array_type import Array
6-
from .dpctl_types import DpctlSyclQueue
6+
from .dpctl_types import SyclQueueType
77
from .dpnp_ndarray_type import DpnpNdArray
88
from .numba_types_short_names import (
99
b1,
@@ -32,7 +32,7 @@
3232

3333
__all__ = [
3434
"Array",
35-
"DpctlSyclQueue",
35+
"SyclQueueType",
3636
"DpnpNdArray",
3737
"USMNdArray",
3838
"none",

numba_dpex/core/types/dpctl_types.py

Lines changed: 97 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,21 @@
44

55
from dpctl import SyclQueue
66
from numba import types
7-
from numba.extending import NativeValue, box, type_callable, unbox
7+
from numba.core import cgutils
8+
from numba.extending import (
9+
NativeValue,
10+
as_numba_type,
11+
box,
12+
type_callable,
13+
typeof_impl,
14+
unbox,
15+
)
816

17+
from numba_dpex.core.exceptions import UnreachableError
18+
from numba_dpex.core.runtime import context as dpxrtc
919

10-
class DpctlSyclQueue(types.Type):
20+
21+
class SyclQueueType(types.Type):
1122
"""A Numba type to represent a dpctl.SyclQueue PyObject.
1223
1324
For now, a dpctl.SyclQueue is represented as a Numba opaque type that allows
@@ -16,25 +27,99 @@ class DpctlSyclQueue(types.Type):
1627
"""
1728

1829
def __init__(self):
19-
super().__init__(name="DpctlSyclQueueType")
30+
super(SyclQueueType, self).__init__(name="SyclQueue")
31+
32+
33+
# sycl_queue_type = SyclQueueType()
2034

2135

22-
sycl_queue_ty = DpctlSyclQueue()
36+
# @typeof_impl.register(SyclQueue)
37+
# def typeof_index(val, c):
38+
# return sycl_queue_type
39+
40+
41+
# as_numba_type.register(SyclQueue, sycl_queue_type)
2342

2443

2544
@type_callable(SyclQueue)
26-
def type_interval(context):
27-
def typer():
28-
return sycl_queue_ty
45+
def type_sycl_queue(context):
46+
def typer(args):
47+
if isinstance(args, types.Tuple):
48+
if len(args) > 0:
49+
if (
50+
isinstance(args[0], types.PyObject)
51+
and isinstance(args[1], types.StringLiteral)
52+
and isinstance(args[2], types.PyObject)
53+
):
54+
return SyclQueueType()
55+
else:
56+
return SyclQueueType()
57+
elif isinstance(args, types.NoneType):
58+
return SyclQueueType()
59+
else:
60+
raise ValueError("Couldn't do type inference for 'SycleQueue'.")
2961

3062
return typer
3163

3264

33-
@unbox(DpctlSyclQueue)
65+
# @lower_builtin(SyclQueue, types.PyObject, types.StringLiteral, types.PyObject)
66+
# def impl_interval(context, builder, sig, args):
67+
# typ = sig.return_type
68+
# if len(args) > 0:
69+
# ctx, dev, property = args
70+
# sycl_queue = cgutils.create_struct_proxy(typ)(context, builder)
71+
# sycl_queue.ctx = ctx
72+
# sycl_queue.dev = dev
73+
# sycl_queue.property = property
74+
# else:
75+
# sycl_queue = cgutils.create_struct_proxy(typ)(context, builder)
76+
# return sycl_queue._getvalue()
77+
78+
79+
@unbox(SyclQueueType)
3480
def unbox_sycl_queue(typ, obj, c):
35-
return NativeValue(obj)
81+
"""
82+
Convert a SyclQueue object to a native structure.
83+
"""
84+
qstruct = cgutils.create_struct_proxy(typ)(c.context, c.builder)
85+
qptr = qstruct._getpointer()
86+
ptr = c.builder.bitcast(qptr, c.pyapi.voidptr)
87+
if c.context.enable_nrt:
88+
dpexrtCtx = dpxrtc.DpexRTContext(c.context)
89+
errcode = dpexrtCtx.queuestruct_from_python(c.pyapi, obj, ptr)
90+
else:
91+
raise UnreachableError
92+
93+
is_error = cgutils.is_not_null(c.builder, errcode)
94+
# Handle error
95+
with c.builder.if_then(is_error, likely=False):
96+
c.pyapi.err_set_string(
97+
"PyExc_TypeError",
98+
"can't unbox array from PyObject into "
99+
"native value. The object maybe of a "
100+
"different type",
101+
)
102+
103+
return NativeValue(c.builder.load(qptr), is_error=is_error)
36104

37105

38-
@box(DpctlSyclQueue)
39-
def box_pyobject(typ, val, c):
40-
return val
106+
# @box(SyclQueueType)
107+
# def box_sycl_queue_(typ, val, c):
108+
# """
109+
# Convert a native interval structure to an Interval object.
110+
# """
111+
# sycl_queue = cgutils.create_struct_proxy(typ)(
112+
# c.context, c.builder, value=val
113+
# )
114+
# ctx_obj = sycl_queue.ctx
115+
# dev_obj = sycl_queue.dev
116+
# property_obj = sycl_queue.property
117+
# class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(SyclQueue))
118+
# res = c.pyapi.call_function_objargs(
119+
# class_obj, (ctx_obj, dev_obj, property_obj)
120+
# )
121+
# c.pyapi.decref(ctx_obj)
122+
# c.pyapi.decref(dev_obj)
123+
# c.pyapi.decref(property_obj)
124+
# c.pyapi.decref(class_obj)
125+
# return res

numba_dpex/core/types/dpnp_ndarray_type.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ def unbox_dpnp_nd_array(typ, obj, c):
264264
# potential memory corruption
265265
#
266266
# --------------- End of Numba comment from @ubox(types.Array)
267-
nativearycls = c.context.make_array(typ)
267+
268+
nativearycls = c.context.make_array(typ) # make_array is in numba.core.base
268269
nativeary = nativearycls(c.context, c.builder)
269270
aryptr = nativeary._getpointer()
270271

numba_dpex/core/typing/typeof.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from numba_dpex.utils import address_space
1313

14-
from ..types.dpctl_types import sycl_queue_ty
14+
from ..types.dpctl_types import SyclQueueType
1515
from ..types.dpnp_ndarray_type import DpnpNdArray
1616
from ..types.usm_ndarray_type import USMNdArray
1717

@@ -107,4 +107,6 @@ def typeof_dpctl_sycl_queue(val, c):
107107
108108
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclQueue instance.
109109
"""
110-
return sycl_queue_ty
110+
# return sycl_queue_type
111+
# return _typeof_helper(val, SyclQueueType)
112+
return SyclQueueType()

0 commit comments

Comments
 (0)