Skip to content

Commit f16d6d1

Browse files
Merge pull request #827 from IntelPython/add-default-context
Add default context
2 parents 7a86fbf + 534fb61 commit f16d6d1

9 files changed

+110
-3
lines changed

dpctl/_backend.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
278278
cdef const char *DPCTLPlatform_GetVendor(const DPCTLSyclPlatformRef)
279279
cdef const char *DPCTLPlatform_GetVersion(const DPCTLSyclPlatformRef)
280280
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
281+
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
282+
const DPCTLSyclPlatformRef)
281283

282284

283285
cdef extern from "syclinterface/dpctl_sycl_context_interface.h":

dpctl/_sycl_device.pyx

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ from ._backend cimport ( # noqa: E211
5050
DPCTLDevice_GetMaxWriteImageArgs,
5151
DPCTLDevice_GetName,
5252
DPCTLDevice_GetParentDevice,
53+
DPCTLDevice_GetPlatform,
5354
DPCTLDevice_GetPreferredVectorWidthChar,
5455
DPCTLDevice_GetPreferredVectorWidthDouble,
5556
DPCTLDevice_GetPreferredVectorWidthFloat,
@@ -80,6 +81,7 @@ from ._backend cimport ( # noqa: E211
8081
DPCTLSize_t_Array_Delete,
8182
DPCTLSyclDeviceRef,
8283
DPCTLSyclDeviceSelectorRef,
84+
DPCTLSyclPlatformRef,
8385
_aspect_type,
8486
_backend_type,
8587
_device_type,
@@ -91,6 +93,8 @@ from .enum_types import backend_type, device_type
9193
from libc.stdint cimport int64_t, uint32_t
9294
from libc.stdlib cimport free, malloc
9395

96+
from ._sycl_platform cimport SyclPlatform
97+
9498
import collections
9599
import warnings
96100

@@ -639,6 +643,22 @@ cdef class SyclDevice(_SyclDevice):
639643
self._device_ref
640644
)
641645

646+
@property
647+
def sycl_platform(self):
648+
""" Returns the platform associated with this device.
649+
650+
Returns:
651+
:class:`dpctl.SyclPlatform`: The platform associated with this
652+
device.
653+
"""
654+
cdef DPCTLSyclPlatformRef PRef = (
655+
DPCTLDevice_GetPlatform(self._device_ref)
656+
)
657+
if (PRef == NULL):
658+
raise RuntimeError("Could not get platform for device.")
659+
else:
660+
return SyclPlatform._create(PRef)
661+
642662
@property
643663
def preferred_vector_width_char(self):
644664
""" Returns the preferred native vector width size for built-in scalar

dpctl/_sycl_platform.pyx

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ from ._backend cimport ( # noqa: E211
3030
DPCTLPlatform_CreateFromSelector,
3131
DPCTLPlatform_Delete,
3232
DPCTLPlatform_GetBackend,
33+
DPCTLPlatform_GetDefaultContext,
3334
DPCTLPlatform_GetName,
3435
DPCTLPlatform_GetPlatforms,
3536
DPCTLPlatform_GetVendor,
@@ -40,15 +41,19 @@ from ._backend cimport ( # noqa: E211
4041
DPCTLPlatformVector_GetAt,
4142
DPCTLPlatformVector_Size,
4243
DPCTLPlatformVectorRef,
44+
DPCTLSyclContextRef,
4345
DPCTLSyclDeviceSelectorRef,
4446
DPCTLSyclPlatformRef,
4547
_backend_type,
4648
)
4749

4850
import warnings
4951

52+
from ._sycl_context import SyclContextCreationError
5053
from .enum_types import backend_type
5154

55+
from ._sycl_context cimport SyclContext
56+
5257
__all__ = [
5358
"get_platforms",
5459
"lsplatform",
@@ -236,10 +241,10 @@ cdef class SyclPlatform(_SyclPlatform):
236241

237242
@property
238243
def backend(self):
239-
"""Returns the backend_type enum value for this device
244+
"""Returns the backend_type enum value for this platform
240245
241246
Returns:
242-
backend_type: The backend for the device.
247+
backend_type: The backend for the platform.
243248
"""
244249
cdef _backend_type BTy = (
245250
DPCTLPlatform_GetBackend(self._platform_ref)
@@ -255,6 +260,22 @@ cdef class SyclPlatform(_SyclPlatform):
255260
else:
256261
raise ValueError("Unknown backend type.")
257262

263+
@property
264+
def default_context(self):
265+
"""Returns the default platform context for this platform
266+
267+
Returns:
268+
SyclContext: The default context for the platform.
269+
"""
270+
cdef DPCTLSyclContextRef CRef = (
271+
DPCTLPlatform_GetDefaultContext(self._platform_ref)
272+
)
273+
274+
if (CRef == NULL):
275+
raise
276+
else:
277+
return SyclContext._create(CRef)
278+
258279

259280
def lsplatform(verbosity=0):
260281
"""

dpctl/tests/test_sycl_device.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,11 @@ def check_profiling_timer_resolution(device):
496496
assert isinstance(resol, int) and resol > 0
497497

498498

499+
def check_platform(device):
500+
p = device.sycl_platform
501+
assert isinstance(p, dpctl.SyclPlatform)
502+
503+
499504
list_of_checks = [
500505
check_get_max_compute_units,
501506
check_get_max_work_item_dims,
@@ -552,6 +557,8 @@ def check_profiling_timer_resolution(device):
552557
check_repr,
553558
check_get_global_mem_size,
554559
check_get_local_mem_size,
560+
check_profiling_timer_resolution,
561+
check_platform,
555562
]
556563

557564

dpctl/tests/test_sycl_platform.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ def check_repr(platform):
8787
assert r != ""
8888

8989

90+
def check_default_context(platform):
91+
r = platform.default_context
92+
assert type(r) is dpctl.SyclContext
93+
94+
9095
list_of_checks = [
9196
check_name,
9297
check_vendor,

libsyclinterface/include/dpctl_sycl_platform_interface.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,16 @@ DPCTLPlatform_GetVersion(__dpctl_keep const DPCTLSyclPlatformRef PRef);
142142
DPCTL_API
143143
__dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms(void);
144144

145+
/*!
146+
* @brief Returns a DPCTLSyclContextRef for default platform context.
147+
*
148+
* @param PRef Opaque pointer to a sycl::platform
149+
* @return A DPCTLSyclContextRef value for the default platform associated
150+
* with this platform.
151+
* @ingroup PlatformInterface
152+
*/
153+
DPCTL_API
154+
__dpctl_give DPCTLSyclContextRef
155+
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef);
156+
145157
DPCTL_C_EXTERN_C_END

libsyclinterface/source/dpctl_sycl_platform_interface.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ using namespace cl::sycl;
4141
namespace
4242
{
4343
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(platform, DPCTLSyclPlatformRef);
44+
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(context, DPCTLSyclContextRef);
4445
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(device_selector, DPCTLSyclDeviceSelectorRef);
4546
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclPlatformRef>,
4647
DPCTLPlatformVectorRef);
@@ -202,3 +203,19 @@ __dpctl_give DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
202203
// the wrap function is defined inside dpctl_vector_templ.cpp
203204
return wrap(Platforms);
204205
}
206+
207+
__dpctl_give DPCTLSyclContextRef
208+
DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef)
209+
{
210+
auto P = unwrap(PRef);
211+
if (P) {
212+
auto default_ctx = P->ext_oneapi_get_default_context();
213+
return wrap(new context(default_ctx));
214+
}
215+
else {
216+
error_handler(
217+
"Default platform cannot be obtained up for a NULL platform.",
218+
__FILE__, __func__, __LINE__);
219+
return nullptr;
220+
}
221+
}

libsyclinterface/tests/test_sycl_platform_interface.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
//===----------------------------------------------------------------------===//
2626

2727
#include "Support/CBindingWrapping.h"
28+
#include "dpctl_sycl_context_interface.h"
2829
#include "dpctl_sycl_device_selector_interface.h"
2930
#include "dpctl_sycl_platform_interface.h"
3031
#include "dpctl_sycl_platform_manager.h"
@@ -82,6 +83,16 @@ void check_platform_backend(__dpctl_keep const DPCTLSyclPlatformRef PRef)
8283
}());
8384
}
8485

86+
void check_platform_default_context(
87+
__dpctl_keep const DPCTLSyclPlatformRef PRef)
88+
{
89+
DPCTLSyclContextRef CRef = nullptr;
90+
EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(PRef));
91+
EXPECT_TRUE(CRef != nullptr);
92+
93+
EXPECT_NO_FATAL_FAILURE(DPCTLContext_Delete(CRef));
94+
}
95+
8596
} // namespace
8697

8798
struct TestDPCTLSyclPlatformInterface
@@ -167,6 +178,14 @@ TEST_F(TestDPCTLSyclPlatformNull, ChkGetVersion)
167178
ASSERT_TRUE(version == nullptr);
168179
}
169180

181+
TEST_F(TestDPCTLSyclPlatformNull, ChkGetDefaultConext)
182+
{
183+
DPCTLSyclContextRef CRef = nullptr;
184+
185+
EXPECT_NO_FATAL_FAILURE(CRef = DPCTLPlatform_GetDefaultContext(NullPRef));
186+
EXPECT_TRUE(CRef == nullptr);
187+
}
188+
170189
struct TestDPCTLSyclDefaultPlatform : public ::testing::Test
171190
{
172191
DPCTLSyclPlatformRef PRef = nullptr;
@@ -207,6 +226,11 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkGetBackend)
207226
check_platform_backend(PRef);
208227
}
209228

229+
TEST_P(TestDPCTLSyclPlatformInterface, ChkGetDefaultContext)
230+
{
231+
check_platform_default_context(PRef);
232+
}
233+
210234
TEST_P(TestDPCTLSyclPlatformInterface, ChkCopy)
211235
{
212236
DPCTLSyclPlatformRef Copied_PRef = nullptr;

libsyclinterface/tests/test_sycl_queue_interface.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,6 @@ TEST_P(TestDPCTLQueueMemberFunctions, CheckMemset)
446446

447447
ASSERT_NO_FATAL_FAILURE(DPCTLfree_with_queue(p, QRef));
448448

449-
bool equal = true;
450449
for (size_t i = 0; i < nbytes; ++i) {
451450
ASSERT_TRUE(host_arr[i] == val);
452451
}

0 commit comments

Comments
 (0)