Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 2982405

Browse files
committed
Add tests for teams and parallel shared arrays.
1 parent e975e83 commit 2982405

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

numba/tests/test_openmp.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3775,6 +3775,63 @@ def test_impl():
37753775
test_impl()
37763776
input("ok?")
37773777

3778+
def target_teams_shared_array(self, device):
3779+
target_pragma = f"target teams num_teams(10) map(tofrom: outside) device({device})"
3780+
@njit
3781+
def test_impl():
3782+
outside = np.zeros(10, dtype=np.int32)
3783+
3784+
with openmp (target_pragma):
3785+
team_shared_array = np.empty(10, dtype=np.int32)
3786+
for i in range(10):
3787+
team_shared_array[i] = omp_get_team_num()
3788+
3789+
lasum = 0
3790+
for i in range(10):
3791+
lasum += team_shared_array[i]
3792+
outside[omp_get_team_num()] = lasum
3793+
3794+
return outside
3795+
3796+
r = test_impl()
3797+
np.testing.assert_array_equal(r, np.arange(10) * 10)
3798+
3799+
def target_teams_parallel_shared_array(self, device):
3800+
target_pragma = f"target teams num_teams(10) map(tofrom: outside) device({device})"
3801+
@njit
3802+
def test_impl():
3803+
outside = np.zeros(10, dtype=np.int32)
3804+
3805+
with openmp (target_pragma):
3806+
team_shared_array = np.empty(10, dtype=np.int32)
3807+
outside_parallel = np.empty(10, dtype=np.int32)
3808+
with openmp ("parallel num_threads(32)"):
3809+
thread_shared_array = np.empty(32, dtype=np.int32)
3810+
for i in range(32):
3811+
thread_shared_array[i] = omp_get_thread_num()
3812+
3813+
lasum = 0
3814+
for i in range(32):
3815+
lasum += thread_shared_array[i]
3816+
outside_parallel[omp_get_thread_num()] = lasum
3817+
3818+
for i in range(10):
3819+
if outside_parallel[i] == i * 32:
3820+
team_shared_array[i] = omp_get_team_num()
3821+
else:
3822+
team_shared_array[i] = 0
3823+
3824+
lasum = 0
3825+
for i in range(10):
3826+
lasum += team_shared_array[i]
3827+
outside[omp_get_team_num()] = lasum
3828+
3829+
return outside
3830+
3831+
r = test_impl()
3832+
np.testing.assert_array_equal(r, np.arange(10) * 10)
3833+
3834+
37783835
for memberName in dir(TestOpenmpTarget):
37793836
if memberName.startswith("target"):
37803837
test_func = getattr(TestOpenmpTarget, memberName)

0 commit comments

Comments
 (0)