Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit fb08a5d

Browse files
authored
Fix collapse loop (#30)
* Updates.
* Redirect TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD to TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.
1 parent 1fa0b32 commit fb08a5d

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

numba/openmp.py

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2809,8 +2809,6 @@ def get_dotted_type(x, typemap, lowerer):
28092809

28102810

28112811
def is_target_arg(name):
2812-
#return name in ["QUAL.OMP.FIRSTPRIVATE", "QUAL.OMP.TARGET.IMPLICIT", "QUAL.OMP.THREAD_LIMIT", "QUAL.OMP.NUM_TEAMS"] or name.startswith("QUAL.OMP.MAP")
2813-
#or name.startswith("QUAL.OMP.NORMALIZED")
28142812
return name in ["QUAL.OMP.FIRSTPRIVATE", "QUAL.OMP.TARGET.IMPLICIT"] or name.startswith("QUAL.OMP.MAP") or name.startswith("QUAL.OMP.REDUCTION")
28152813

28162814

@@ -2824,7 +2822,6 @@ def is_pointer_target_arg(name, typ):
28242822
if name.startswith("QUAL.OMP.MAP"):
28252823
if isinstance(typ, types.npytypes.Array):
28262824
return True
2827-
#return False
28282825
else:
28292826
return True
28302827
if name in ["QUAL.OMP.FIRSTPRIVATE", "QUAL.OMP.PRIVATE"]:
@@ -3408,15 +3405,25 @@ def get_loops_in_region(all_loops):
34083405

34093406
# Copy all stmts from the loop entry block up to the ir.Global
34103407
# for range.
3408+
call_offset = None
34113409
for entry_block_index, stmt in enumerate(loop_entry_block.body):
3410+
found_range = False
34123411
if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Global) and stmt.value.name == "range":
3412+
found_range = True
34133413
range_target = stmt.target
3414-
call_stmt = loop_entry_block.body[entry_block_index + 1]
3415-
assert isinstance(call_stmt, ir.Assign) and isinstance(call_stmt.value, ir.Expr) and call_stmt.value.op == 'call' and call_stmt.value.func == range_target
3416-
# Remove stmts that were retained.
3417-
loop_entry_block.body = loop_entry_block.body[entry_block_index:]
3414+
found_call = False
3415+
for call_index in range(entry_block_index + 1, len(loop_entry_block.body)):
3416+
call_stmt = loop_entry_block.body[call_index]
3417+
if isinstance(call_stmt, ir.Assign) and isinstance(call_stmt.value, ir.Expr) and call_stmt.value.op == 'call' and call_stmt.value.func == range_target:
3418+
found_call = True
3419+
# Remove stmts that were retained.
3420+
loop_entry_block.body = loop_entry_block.body[entry_block_index:]
3421+
call_offset = call_index - entry_block_index
3422+
break
3423+
assert found_call
34183424
break
34193425
stmts_to_retain.append(stmt)
3426+
assert found_range
34203427
for header_block_index, stmt in enumerate(loop_header_block.body):
34213428
if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr) and stmt.value.op == "iternext":
34223429
iternext_inst = loop_header_block.body[header_block_index]
@@ -3475,7 +3482,7 @@ def get_loops_in_region(all_loops):
34753482
new_stmts_for_iterspace.append(ir.Assign(mul_op, new_iterspace_var, self.loc))
34763483
# Change iteration space of innermost loop to the product of all the
34773484
# loops' iteration spaces.
3478-
last_loop_entry_block.body[1].value.args[0] = new_iterspace_var
3485+
last_loop_entry_block.body[call_offset].value.args[0] = new_iterspace_var
34793486

34803487
last_eliminated_loop_header_block.body = new_stmts_for_iterspace + last_eliminated_loop_header_block.body
34813488

@@ -4324,14 +4331,16 @@ def target_teams_distribute_directive(self, args):
43244331
self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE", 3, has_loop=True)
43254332

43264333
def target_teams_loop_directive(self, args):
4327-
self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD", 3, has_loop=True)
4334+
self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 3, has_loop=True)
4335+
#self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD", 3, has_loop=True)
43284336
#self.some_target_directive(args, "TARGET.TEAMS.LOOP", 3, has_loop=True)
43294337

43304338
def target_teams_distribute_parallel_for_directive(self, args):
43314339
self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 5, has_loop=True)
43324340

43334341
def target_teams_distribute_parallel_for_simd_directive(self, args):
4334-
self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD", 6, has_loop=True)
4342+
# Intentionally dropping "SIMD" from string as that typically isn't implemented on GPU.
4343+
self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 6, has_loop=True)
43354344

43364345
def get_clauses_by_name(self, clauses, names, remove_from_orig=False):
43374346
if not isinstance(names, list):

numba/tests/test_openmp.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4390,6 +4390,25 @@ def test_impl():
43904390
else:
43914391
raise ValueError(f"Device {device} must be 0 or 1")
43924392

4393+
def target_teams_loop_collapse(self, device):
4394+
target_pragma = f"""target teams loop collapse(2)
4395+
device({device})
4396+
map(tofrom: a, b, c)"""
4397+
@njit
4398+
def test_impl(n):
4399+
a = np.ones((n,n))
4400+
b = np.ones((n,n))
4401+
c = np.zeros((n,n))
4402+
with openmp(target_pragma):
4403+
for i in range(n):
4404+
for j in range(n):
4405+
c[i,j] = a[i,j] + b[i,j]
4406+
return c
4407+
4408+
n = 10
4409+
c = test_impl(n)
4410+
np.testing.assert_array_equal(c, np.full((n,n), 2))
4411+
43934412

43944413
for memberName in dir(TestOpenmpTarget):
43954414
if memberName.startswith("target"):

0 commit comments

Comments
 (0)