Skip to content

Commit c73190c

Browse files
Implemented SyclKernel.max_sub_group_size
1 parent 9259083 commit c73190c

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

dpctl/_backend.pxd

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -271,8 +271,7 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_interface.h":
271271
cdef size_t DPCTLKernel_GetPreferredWorkGroupSizeMultiple(const DPCTLSyclKernelRef KRef)
272272
cdef size_t DPCTLKernel_GetPrivateMemSize(const DPCTLSyclKernelRef KRef)
273273
cdef uint32_t DPCTLKernel_GetMaxNumSubGroups(const DPCTLSyclKernelRef KRef)
274-
## Next line is commented out due to issue in DPC++ runtime
275-
# cdef uint32_t DPCTLKernel_GetMaxSubGroupSize(const DPCTLSyclKernelRef KRef)
274+
cdef uint32_t DPCTLKernel_GetMaxSubGroupSize(const DPCTLSyclKernelRef KRef)
276275
cdef uint32_t DPCTLKernel_GetCompileNumSubGroups(const DPCTLSyclKernelRef KRef)
277276
cdef uint32_t DPCTLKernel_GetCompileSubGroupSize(const DPCTLSyclKernelRef KRef)
278277

dpctl/program/_program.pyx

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ from dpctl._backend cimport ( # noqa: E211, E402;
3434
DPCTLKernel_GetCompileNumSubGroups,
3535
DPCTLKernel_GetCompileSubGroupSize,
3636
DPCTLKernel_GetMaxNumSubGroups,
37+
DPCTLKernel_GetMaxSubGroupSize,
3738
DPCTLKernel_GetNumArgs,
3839
DPCTLKernel_GetPreferredWorkGroupSizeMultiple,
3940
DPCTLKernel_GetPrivateMemSize,
@@ -146,8 +147,8 @@ cdef class SyclKernel:
146147
def max_sub_group_size(self):
147148
""" Returns the maximum sub-groups size for this kernel.
148149
"""
149-
cdef uint32_t sz = 0
150-
return NotImplemented
150+
cdef uint32_t sz = DPCTLKernel_GetMaxSubGroupSize(self._kernel_ref)
151+
return sz
151152

152153
@property
153154
def compile_num_sub_groups(self):

dpctl/tests/test_sycl_program.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -180,9 +180,7 @@ def _check_multi_kernel_program(prog):
180180
vmnsg = krn.max_num_sub_groups
181181
assert type(vmnsg) is int
182182
v = krn.max_sub_group_size
183-
assert (
184-
v == NotImplemented
185-
), "SyclKernel.max_sub_group_size acquired implementation, fix the test"
183+
assert type(v) is int
186184
cmnsg = krn.compile_num_sub_groups
187185
assert type(cmnsg) is int
188186
cmsgsz = krn.compile_sub_group_size

0 commit comments

Comments
 (0)