Skip to content

Feature/add linear indexing overload #1368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ItemType,
NdItemType,
)
from numba_dpex.kernel_api import Group, Item, NdItem
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext

from ..target import DPEX_KERNEL_EXP_TARGET_NAME
Expand Down Expand Up @@ -180,26 +181,31 @@ def ol_item_get_index_impl(item, dim):
return ol_item_gen_index


_index_const_overload_methods = [
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
(ItemType, "get_range", _intrinsic_spirv_global_size),
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
]
def register_index_const_methods():
"""Register indexing related methods that can be defined as spirv const."""
_index_const_overload_methods = [
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
(ItemType, "get_range", _intrinsic_spirv_global_size),
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
]

for index_overload in _index_const_overload_methods:
_type, method, _intrinsic = index_overload
for index_overload in _index_const_overload_methods:
_type, method, _intrinsic = index_overload

ol_index_func = generate_index_overload(_type, _intrinsic)
ol_index_func = generate_index_overload(_type, _intrinsic)

overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
ol_index_func
)
overload_method(_type, method, target=DPEX_KERNEL_EXP_TARGET_NAME)(
ol_index_func
)


register_index_const_methods()


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
Expand Down Expand Up @@ -269,3 +275,37 @@ def ol_nd_item_get_group_impl(item):
return dimensions

return ol_nd_item_get_group_impl


def _generate_method_overload(method):
"""Generates naive method overload with no argument, except self."""

def ol_method(self): # pylint: disable=unused-argument
return method

return ol_method


def register_jitable_method(type_, method):
"""
Register a regular python method that can be executed by the
python interpreter and can be compiled into a nopython
function when referenced by other jit'ed functions.

Same as register_jitable, but for methods with no arguments.
"""
overloaded_method = _generate_method_overload(method)
overload_method(type_, method.__name__, target=DPEX_KERNEL_EXP_TARGET_NAME)(
overloaded_method
)


register_jitable_method(ItemType, Item.get_linear_id)
register_jitable_method(ItemType, Item.get_linear_range)
register_jitable_method(NdItemType, NdItem.get_global_linear_id)
register_jitable_method(NdItemType, NdItem.get_global_linear_range)
register_jitable_method(NdItemType, NdItem.get_local_linear_range)
register_jitable_method(NdItemType, NdItem.get_local_linear_id)
register_jitable_method(GroupType, Group.get_group_linear_id)
register_jitable_method(GroupType, Group.get_group_linear_range)
register_jitable_method(GroupType, Group.get_local_linear_range)
104 changes: 84 additions & 20 deletions numba_dpex/kernel_api/index_space_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,20 @@ def get_group_id(self, dim):

def get_group_linear_id(self):
"""Returns a linearized version of the work-group index."""
if len(self._index) == 1:
return self._index[0]
if len(self._index) == 2:
return self._index[0] * self._group_range[1] + self._index[1]
if self.dimensions == 1:
return self.get_group_id(0)
if self.dimensions == 2:
return self.get_group_id(0) * self.get_group_range(
1
) + self.get_group_id(1)
return (
(self._index[0] * self._group_range[1] * self._group_range[2])
+ (self._index[1] * self._group_range[2])
+ (self._index[2])
(
self.get_group_id(0)
* self.get_group_range(1)
* self.get_group_range(2)
)
+ (self.get_group_id(1) * self.get_group_range(2))
+ (self.get_group_id(2))
)

def get_group_range(self, dim):
Expand All @@ -61,8 +67,8 @@ def get_group_range(self, dim):
def get_group_linear_range(self):
"""Return the total number of work-groups in the nd_range."""
num_wg = 1
for ext in self._group_range:
num_wg *= ext
for i in range(self.dimensions):
num_wg *= self.get_group_range(i)

return num_wg

Expand All @@ -76,8 +82,8 @@ def get_local_range(self, dim):
def get_local_linear_range(self):
"""Return the total number of work-items in the work-group."""
num_wi = 1
for ext in self._local_range:
num_wi *= ext
for i in range(self.dimensions):
num_wi *= self.get_local_range(i)

return num_wi

Expand Down Expand Up @@ -128,14 +134,14 @@ def get_linear_id(self):
Returns:
int: The linear id.
"""
if len(self._extent) == 1:
return self._index[0]
if len(self._extent) == 2:
return self._index[0] * self._extent[1] + self._index[1]
if self.dimensions == 1:
return self.get_id(0)
if self.dimensions == 2:
return self.get_id(0) * self.get_range(1) + self.get_id(1)
return (
(self._index[0] * self._extent[1] * self._extent[2])
+ (self._index[1] * self._extent[2])
+ (self._index[2])
(self.get_id(0) * self.get_range(1) * self.get_range(2))
+ (self.get_id(1) * self.get_range(2))
+ (self.get_id(2))
)

def get_id(self, idx):
Expand All @@ -146,6 +152,14 @@ def get_id(self, idx):
"""
return self._index[idx]

def get_linear_range(self):
"""Return the total number of work-items in the work-group."""
num_wi = 1
for i in range(self.dimensions):
num_wi *= self.get_range(i)

return num_wi

def get_range(self, idx):
"""Get the range size for a specific dimension.

Expand Down Expand Up @@ -193,7 +207,24 @@ def get_global_linear_id(self):
Returns:
int: The global linear id.
"""
return self._global_item.get_linear_id()
# Instead of calling self._global_item.get_linear_id(), the linearization
# logic is duplicated here so that the method can be JIT compiled by
# numba-dpex and works in both Python and Numba nopython modes.
if self.dimensions == 1:
return self.get_global_id(0)
if self.dimensions == 2:
return self.get_global_id(0) * self.get_global_range(
1
) + self.get_global_id(1)
return (
(
self.get_global_id(0)
* self.get_global_range(1)
* self.get_global_range(2)
)
+ (self.get_global_id(1) * self.get_global_range(2))
+ (self.get_global_id(2))
)

def get_local_id(self, idx):
"""Get the local id for a specific dimension.
Expand All @@ -210,7 +241,24 @@ def get_local_linear_id(self):
Returns:
int: The local linear id.
"""
return self._local_item.get_linear_id()
# Instead of calling self._local_item.get_linear_id(), the linearization
# logic is duplicated here so that the method can be JIT compiled by
# numba-dpex and works in both Python and Numba nopython modes.
if self.dimensions == 1:
return self.get_local_id(0)
if self.dimensions == 2:
return self.get_local_id(0) * self.get_local_range(
1
) + self.get_local_id(1)
return (
(
self.get_local_id(0)
* self.get_local_range(1)
* self.get_local_range(2)
)
+ (self.get_local_id(1) * self.get_local_range(2))
+ (self.get_local_id(2))
)

def get_global_range(self, idx):
"""Get the global range size for a specific dimension.
Expand All @@ -228,6 +276,22 @@ def get_local_range(self, idx):
"""
return self._local_item.get_range(idx)

def get_local_linear_range(self):
"""Return the total number of work-items in the work-group."""
num_wi = 1
for i in range(self.dimensions):
num_wi *= self.get_local_range(i)

return num_wi

def get_global_linear_range(self):
"""Return the total number of work-items in the work-group."""
num_wi = 1
for i in range(self.dimensions):
num_wi *= self.get_global_range(i)

return num_wi

def get_group(self):
"""Returns the group.

Expand Down
Loading