Skip to content

Commit f7b335d

Browse files
Merge pull request #589 from IntelPython/device_memory_sizes
Device memory sizes
2 parents 002d24b + 08da9e2 commit f7b335d

File tree

7 files changed

+105
-6
lines changed

dpctl-capi/include/dpctl_sycl_device_interface.h

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,28 @@ DPCTL_API
177177
uint32_t
178178
DPCTLDevice_GetMaxComputeUnits(__dpctl_keep const DPCTLSyclDeviceRef DRef);
179179

180+
/*!
181+
* @brief Returns the global memory size of the device in bytes; wrapper over device.get_info<info::device::global_mem_size>().
182+
*
183+
* @param DRef Opaque pointer to a ``sycl::device``
184+
* @return The queried size, or 0 if DRef does not unwrap to a device or the query throws a runtime_error.
185+
* @ingroup DeviceInterface
186+
*/
187+
DPCTL_API
188+
uint64_t
189+
DPCTLDevice_GetGlobalMemSize(__dpctl_keep const DPCTLSyclDeviceRef DRef);
190+
191+
/*!
192+
* @brief Returns the local memory size of the device in bytes; wrapper over device.get_info<info::device::local_mem_size>().
193+
*
194+
* @param DRef Opaque pointer to a ``sycl::device``
195+
* @return The queried size, or 0 if DRef does not unwrap to a device or the query throws a runtime_error.
196+
* @ingroup DeviceInterface
197+
*/
198+
DPCTL_API
199+
uint64_t
200+
DPCTLDevice_GetLocalMemSize(__dpctl_keep const DPCTLSyclDeviceRef DRef);
201+
180202
/*!
181203
* @brief Wrapper for get_info<info::device::max_work_item_dimensions>().
182204
*

dpctl-capi/source/dpctl_sycl_device_interface.cpp

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,37 @@ DPCTLDevice_GetMaxComputeUnits(__dpctl_keep const DPCTLSyclDeviceRef DRef)
186186
return nComputeUnits;
187187
}
188188

189+
uint64_t
190+
DPCTLDevice_GetGlobalMemSize(__dpctl_keep const DPCTLSyclDeviceRef DRef)
191+
{
192+
uint64_t GlobalMemSize = 0;
193+
auto D = unwrap(DRef);
194+
if (D) {
195+
try {
196+
GlobalMemSize = D->get_info<info::device::global_mem_size>();
197+
} catch (runtime_error const &re) {
198+
// \todo log error
199+
std::cerr << re.what() << '\n';
200+
}
201+
}
202+
return GlobalMemSize;
203+
}
204+
205+
/*
 * Queries info::device::local_mem_size for the device behind DRef.
 *
 * Returns 0 when DRef does not unwrap to a device, or when the query throws
 * a runtime_error (the message is written to std::cerr in that case).
 */
uint64_t
DPCTLDevice_GetLocalMemSize(__dpctl_keep const DPCTLSyclDeviceRef DRef)
{
    auto Device = unwrap(DRef);
    if (!Device) {
        return 0;
    }

    try {
        return Device->get_info<info::device::local_mem_size>();
    } catch (runtime_error const &Err) {
        // \todo log error
        std::cerr << Err.what() << '\n';
    }
    // Reached only when the query above threw.
    return 0;
}
219+
189220
uint32_t
190221
DPCTLDevice_GetMaxWorkItemDims(__dpctl_keep const DPCTLSyclDeviceRef DRef)
191222
{

dpctl-capi/tests/test_sycl_device_interface.cpp

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,20 @@ TEST_P(TestDPCTLSyclDeviceInterface, ChkGetMaxComputeUnits)
132132
EXPECT_TRUE(n > 0);
133133
}
134134

135+
TEST_P(TestDPCTLSyclDeviceInterface, ChkGetGlobalMemSize)
{
    // DPCTLDevice_GetGlobalMemSize returns uint64_t; hold the result in the
    // same type so it is not narrowed where size_t is only 32 bits wide.
    uint64_t gm_sz = 0;
    EXPECT_NO_FATAL_FAILURE(gm_sz = DPCTLDevice_GetGlobalMemSize(DRef));
    // 0 is what the wrapper returns when the query could not be performed.
    EXPECT_TRUE(gm_sz > 0);
}
141+
142+
TEST_P(TestDPCTLSyclDeviceInterface, ChkGetLocalMemSize)
{
    // DPCTLDevice_GetLocalMemSize returns uint64_t; hold the result in the
    // same type so it is not narrowed where size_t is only 32 bits wide.
    uint64_t lm_sz = 0;
    EXPECT_NO_FATAL_FAILURE(lm_sz = DPCTLDevice_GetLocalMemSize(DRef));
    // 0 is what the wrapper returns when the query could not be performed.
    EXPECT_TRUE(lm_sz > 0);
}
148+
135149
TEST_P(TestDPCTLSyclDeviceInterface, ChkGetMaxWorkItemDims)
136150
{
137151
size_t n = 0;

dpctl/_backend.pxd

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
types defined by dpctl's C API.
2222
"""
2323

24-
from libc.stdint cimport int64_t, uint32_t, uint64_t
24+
from libc.stdint cimport int64_t, uint32_t
2525
from libcpp cimport bool
2626

2727

@@ -149,6 +149,8 @@ cdef extern from "dpctl_sycl_device_interface.h":
149149
cdef _backend_type DPCTLDevice_GetBackend(const DPCTLSyclDeviceRef)
150150
cdef _device_type DPCTLDevice_GetDeviceType(const DPCTLSyclDeviceRef)
151151
cdef const char *DPCTLDevice_GetDriverVersion(const DPCTLSyclDeviceRef DRef)
152+
cdef size_t DPCTLDevice_GetGlobalMemSize(const DPCTLSyclDeviceRef DRef)
153+
cdef size_t DPCTLDevice_GetLocalMemSize(const DPCTLSyclDeviceRef DRef)
152154
cdef uint32_t DPCTLDevice_GetMaxComputeUnits(const DPCTLSyclDeviceRef DRef)
153155
cdef uint32_t DPCTLDevice_GetMaxNumSubGroups(const DPCTLSyclDeviceRef DRef)
154156
cdef size_t DPCTLDevice_GetMaxWorkGroupSize(const DPCTLSyclDeviceRef DRef)
@@ -239,9 +241,9 @@ cdef extern from "dpctl_sycl_event_interface.h":
239241
size_t index)
240242
cdef DPCTLEventVectorRef DPCTLEvent_GetWaitList(
241243
DPCTLSyclEventRef ERef)
242-
cdef uint64_t DPCTLEvent_GetProfilingInfoSubmit(DPCTLSyclEventRef ERef)
243-
cdef uint64_t DPCTLEvent_GetProfilingInfoStart(DPCTLSyclEventRef ERef)
244-
cdef uint64_t DPCTLEvent_GetProfilingInfoEnd(DPCTLSyclEventRef ERef)
244+
cdef size_t DPCTLEvent_GetProfilingInfoSubmit(DPCTLSyclEventRef ERef)
245+
cdef size_t DPCTLEvent_GetProfilingInfoStart(DPCTLSyclEventRef ERef)
246+
cdef size_t DPCTLEvent_GetProfilingInfoEnd(DPCTLSyclEventRef ERef)
245247

246248

247249
cdef extern from "dpctl_sycl_kernel_interface.h":

dpctl/_sycl_device.pyx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ from ._backend cimport ( # noqa: E211
3434
DPCTLDevice_GetBackend,
3535
DPCTLDevice_GetDeviceType,
3636
DPCTLDevice_GetDriverVersion,
37+
DPCTLDevice_GetGlobalMemSize,
3738
DPCTLDevice_GetImage2dMaxHeight,
3839
DPCTLDevice_GetImage2dMaxWidth,
3940
DPCTLDevice_GetImage3dMaxDepth,
4041
DPCTLDevice_GetImage3dMaxHeight,
4142
DPCTLDevice_GetImage3dMaxWidth,
43+
DPCTLDevice_GetLocalMemSize,
4244
DPCTLDevice_GetMaxComputeUnits,
4345
DPCTLDevice_GetMaxNumSubGroups,
4446
DPCTLDevice_GetMaxReadImageArgs,
@@ -656,6 +658,22 @@ cdef class SyclDevice(_SyclDevice):
656658
"""
657659
return DPCTLDevice_GetPreferredVectorWidthHalf(self._device_ref)
658660

661+
@property
def global_mem_size(self):
    """ Returns the size of global memory on this device in bytes.

    Returns:
        int: The value reported by ``DPCTLDevice_GetGlobalMemSize`` for
        this device.
    """
    cdef size_t nbytes = DPCTLDevice_GetGlobalMemSize(self._device_ref)
    return nbytes
668+
669+
@property
def local_mem_size(self):
    """ Returns the size of local memory on this device in bytes.

    Returns:
        int: The value reported by ``DPCTLDevice_GetLocalMemSize`` for
        this device.
    """
    cdef size_t nbytes = DPCTLDevice_GetLocalMemSize(self._device_ref)
    return nbytes
676+
659677
@property
660678
def vendor(self):
661679
""" Returns the device vendor name as a string.

dpctl/tests/test_sycl_device.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ def check_get_max_compute_units(device):
6161
assert max_compute_units > 0
6262

6363

64+
def check_get_global_mem_size(device):
    """Assert that ``device.global_mem_size`` is a positive value."""
    assert device.global_mem_size > 0
67+
68+
69+
def check_get_local_mem_size(device):
    """Assert that ``device.local_mem_size`` is a positive value."""
    assert device.local_mem_size > 0
72+
73+
6474
def check_get_max_work_item_dims(device):
6575
max_work_item_dims = device.max_work_item_dims
6676
assert max_work_item_dims > 0
@@ -529,6 +539,8 @@ def check_repr(device):
529539
check_create_sub_devices_by_affinity_next_partitionable,
530540
check_print_device_info,
531541
check_repr,
542+
check_get_global_mem_size,
543+
check_get_local_mem_size,
532544
]
533545

534546

dpctl/tests/test_sycl_kernel_submit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_create_program_from_source(ctype_str, dtype, ctypes_ctor):
8787
q.submit(axpyKernel, args, r).wait()
8888
ref_c = a * np.array(d, dtype=dtype) + b
8989
host_dt, device_dt = timer.dt
90-
assert host_dt > device_dt
90+
assert type(host_dt) is float and type(device_dt) is float
9191
assert np.allclose(c, ref_c), "Failed for {}".format(r)
9292

9393
for gr, lr in (
@@ -106,5 +106,5 @@ def test_create_program_from_source(ctype_str, dtype, ctypes_ctor):
106106
q.submit(axpyKernel, args, gr, lr, [dpctl.SyclEvent()]).wait()
107107
ref_c = a * np.array(d, dtype=dtype) + b
108108
host_dt, device_dt = timer.dt
109-
assert host_dt > device_dt
109+
assert type(host_dt) is float and type(device_dt) is float
110110
assert np.allclose(c, ref_c), "Faled for {}, {}".formatg(r, lr)

0 commit comments

Comments
 (0)