Skip to content

Commit 4a19aa2

Browse files
committed
rewrite C extension test
1 parent 3032038 commit 4a19aa2

File tree

3 files changed

+21
-21
lines changed

3 files changed

+21
-21
lines changed

dpctl/tests/_c_ext.c

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include "dpctl_capi.h"
3030
// clang-format on
3131

32-
PyObject *py_is_usm_ndarray(PyObject *self_unused, PyObject *args)
32+
PyObject *py_is_sycl_queue(PyObject *self_unused, PyObject *args)
3333
{
3434
PyObject *arg = NULL;
3535
PyObject *res = NULL;
@@ -43,7 +43,7 @@ PyObject *py_is_usm_ndarray(PyObject *self_unused, PyObject *args)
4343
return NULL;
4444
}
4545

46-
check = PyObject_TypeCheck(arg, &PyUSMArrayType);
46+
check = PyObject_TypeCheck(arg, &PySyclQueueType);
4747
if (check == -1) {
4848
PyErr_SetString(PyExc_RuntimeError, "Type check failed");
4949
return NULL;
@@ -55,35 +55,36 @@ PyObject *py_is_usm_ndarray(PyObject *self_unused, PyObject *args)
5555
return res;
5656
}
5757

58-
PyObject *py_usm_ndarray_ndim(PyObject *self_unused, PyObject *args)
58+
PyObject *py_check_queue_ref(PyObject *self_unused, PyObject *args)
5959
{
6060
PyObject *arg = NULL;
61-
struct PyUSMArrayObject *array = NULL;
6261
PyObject *res = NULL;
6362
int status = -1;
64-
int ndim = -1;
63+
struct PySyclQueueObject *q_obj = NULL;
64+
DPCTLSyclQueueRef qref = NULL;
6565

6666
(void)(self_unused); // avoid unused arguments warning
67-
status = PyArg_ParseTuple(args, "O!", &PyUSMArrayType, &arg);
67+
status = PyArg_ParseTuple(args, "O!", &PySyclQueueType, &arg);
6868
if (!status) {
69-
PyErr_SetString(
70-
PyExc_TypeError,
71-
"Expecting single argument of type dpctl.tensor.usm_ndarray");
69+
PyErr_SetString(PyExc_TypeError,
70+
"Expecting single argument of type dpctl.SyclQueue");
7271
return NULL;
7372
}
7473

75-
array = (struct PyUSMArrayObject *)arg;
76-
ndim = UsmNDArray_GetNDim(array);
74+
q_obj = (struct PySyclQueueObject *)arg;
75+
qref = SyclQueue_GetQueueRef((struct PySyclQueueObject *)arg);
76+
77+
res = (qref != NULL) ? Py_True : Py_False;
78+
Py_INCREF(res);
7779

78-
res = PyLong_FromLong(ndim);
7980
return res;
8081
}
8182

8283
static PyMethodDef CExtMethods[] = {
83-
{"is_usm_ndarray", py_is_usm_ndarray, METH_VARARGS,
84-
"Checks if input object is an usm_ndarray instance"},
85-
{"usm_ndarray_ndim", py_usm_ndarray_ndim, METH_VARARGS,
86-
"Get ndim property of an usm_ndarray instance"},
84+
{"is_sycl_queue", py_is_sycl_queue, METH_VARARGS,
85+
"Checks if input object is a dpctl.SyclQueue instance"},
86+
{"check_queue_ref", py_check_queue_ref, METH_VARARGS,
87+
"Checks that queue ref obtained via C-API is not NULL"},
8788
{NULL, NULL, 0, NULL} /* Sentinel */
8889
};
8990

dpctl/tests/test_headers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22

33
import dpctl
4-
import dpctl.tensor as dpt
54

65

76
@pytest.fixture(scope="session")
@@ -44,9 +43,9 @@ def dpctl_c_extension(tmp_path_factory):
4443

4544
def test_c_headers(dpctl_c_extension):
4645
try:
47-
x = dpt.empty(10)
46+
q = dpctl.SyclQueue()
4847
except (dpctl.SyclDeviceCreationError, dpctl.SyclQueueCreationError):
4948
pytest.skip()
5049

51-
assert dpctl_c_extension.is_usm_ndarray(x)
52-
assert dpctl_c_extension.usm_ndarray_ndim(x) == x.ndim
50+
assert dpctl_c_extension.is_sycl_queue(q)
51+
assert dpctl_c_extension.check_queue_ref(q)

dpctl/tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def test_cpython_api_SyclQueue_GetQueueRef():
244244
try:
245245
q = dpctl.SyclQueue()
246246
except dpctl.SyclQueueCreationError:
247-
pytest.skip("Can not defaul-construct SyclQueue")
247+
pytest.skip("Can not default-construct SyclQueue")
248248
mod = sys.modules[q.__class__.__module__]
249249
# get capsule storign SyclQueue_GetQueueRef function ptr
250250
q_ref_fn_cap = mod.__pyx_capi__["SyclQueue_GetQueueRef"]

0 commit comments

Comments
 (0)