This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit e86d157

More loop variants. (#33)

* Add loop directive variants
* Add tests
* Add a way to skip a test only for device=1
* Remove old comments
* Process schedule_clause and dist_schedule_clause from the grammar

1 parent fb08a5d · commit e86d157
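The new directive variants use pyomp's with-statement syntax. As a rough sketch of what this commit enables, based on the target_nest_teams_nest_loop_collapse test added below (the imports and the wrapper function are illustrative assumptions, not part of this diff):

    import numpy as np
    from numba import njit
    from numba.openmp import openmp_context as openmp  # assumed import path

    @njit
    def add_matrices(n):
        a = np.ones((n, n))
        b = np.ones((n, n))
        c = np.zeros((n, n))
        with openmp("target map(tofrom: a, b, c)"):
            with openmp("teams"):
                # A "loop" nested under a teams region is rewritten by the new
                # loop_directive handler into DISTRIBUTE.PARALLEL.LOOP.
                with openmp("loop collapse(2)"):
                    for i in range(n):
                        for j in range(n):
                            c[i, j] = a[i, j] + b[i, j]
        return c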

2 files changed: +91 −54 lines


numba/openmp.py

Lines changed: 51 additions & 54 deletions
@@ -515,7 +515,7 @@ def replace_vars_inner(self, var_dict):
 
     def add_to_usedef_set(self, use_set, def_set, start):
         assert start==True or start==False
-        if config.DEBUG_OPENMP >= 1:
+        if config.DEBUG_OPENMP >= 3:
            print("add_to_usedef_set", start, self.name, "is_dsa=", is_dsa(self.name))
 
        def add_arg(arg, the_set):
@@ -3360,6 +3360,7 @@ def get_loops_in_region(all_loops):
         collapse_tags = get_tags_of_type(clauses, "QUAL.OMP.COLLAPSE")
         new_stmts_for_iterspace = []
         collapse_iterspace_block = set()
+        iterspace_vars = []
         if len(collapse_tags) > 0:
             # Limit all_loops to just loops within the openmp region.
             all_loops = get_loops_in_region(all_loops)
@@ -3469,7 +3470,6 @@ def get_loops_in_region(all_loops):
         new_var_scope = last_loop_entry_block.body[0].target.scope
 
         # -------- Add vars to remember cumulative product of iteration space sizes.
-        iterspace_vars = []
         new_iterspace_var = new_var_scope.redefine("new_iterspace0", self.loc)
         start_tags.append(openmp_tag("QUAL.OMP.FIRSTPRIVATE", new_iterspace_var.name))
         iterspace_vars.append(new_iterspace_var)
@@ -3873,8 +3873,7 @@ def _get_loop_kind(func_var, call_table):
         start_tags.append(openmp_tag("QUAL.OMP.FIRSTPRIVATE", omp_start_var.name))
         start_tags.append(openmp_tag("QUAL.OMP.FIRSTPRIVATE", omp_lb_var.name))
         start_tags.append(openmp_tag("QUAL.OMP.FIRSTPRIVATE", omp_ub_var.name))
-        tags_for_enclosing = [cmp_var.name, omp_lb_var.name, omp_start_var.name, omp_iv_var.name, types_mod_var.name, int64_var.name, itercount_var.name, omp_ub_var.name, const1_var.name, const1_latch_var.name]
-        #tags_for_enclosing = [omp_lb_var.name, omp_start_var.name, omp_iv_var.name, types_mod_var.name, int64_var.name, itercount_var.name, omp_ub_var.name, const1_var.name, const1_latch_var.name]
+        tags_for_enclosing = [cmp_var.name, omp_lb_var.name, omp_start_var.name, omp_iv_var.name, types_mod_var.name, int64_var.name, itercount_var.name, omp_ub_var.name, const1_var.name, const1_latch_var.name, get_itercount_var.name] + [x.name for x in iterspace_vars]
         tags_for_enclosing = [openmp_tag("QUAL.OMP.PRIVATE", x) for x in tags_for_enclosing]
         # Don't blindly copy code here...this isn't doing what the other spots are doing with privatization.
         #self.add_private_to_enclosing(replace_vardict, tags_for_enclosing)
@@ -3891,15 +3890,6 @@ def some_for_directive(self, args, main_start_tag, main_end_tag, first_clause, g
         start_tags = [openmp_tag(main_start_tag)]
         end_tags = [openmp_tag(main_end_tag)]
         clauses = self.some_data_clause_directive(args, start_tags, end_tags, first_clause, has_loop=True)
-        #sblk = self.blocks[self.blk_start]
-        #scope = sblk.scope
-        #eblk = self.blocks[self.blk_end]
-        #clauses, default_shared = self.flatten(args[first_clause:], sblk)
-
-        #if config.DEBUG_OPENMP >= 1:
-        #    print("visit", main_start_tag, args, type(args), default_shared)
-        #    for clause in clauses:
-        #        print("post-process clauses:", clause)
 
         if "PARALLEL" in main_start_tag:
             # ---- Back propagate THREAD_LIMIT to enclosed target region. ----
@@ -3969,6 +3959,18 @@ def for_simd_clause(self, args):
                   args, type(args), args[0])
         return args[0]
 
+    def schedule_clause(self, args):
+        if config.DEBUG_OPENMP >= 1:
+            print("visit schedule_clause",
+                  args, type(args), args[0])
+        return args[0]
+
+    def dist_schedule_clause(self, args):
+        if config.DEBUG_OPENMP >= 1:
+            print("visit dist_schedule_clause",
+                  args, type(args), args[0])
+        return args[0]
+
     # Don't need a rule for parallel_for_simd_construct.
 
     def parallel_for_simd_directive(self, args):
@@ -4071,7 +4073,7 @@ def map_clause(self, args):
             assert(len(args) == 2)
         else:
             map_type = "TOFROM" # is this default right? FIX ME
-            var_list = args
+            var_list = args[0]
         ret = []
         for var in var_list:
             ret.append(openmp_tag("QUAL.OMP.MAP." + map_type, var))
@@ -4267,7 +4269,7 @@ def teams_back_prop(self, clauses):
     def check_distribute_nesting(self, dir_tag):
         if "DISTRIBUTE" in dir_tag and "TEAMS" not in dir_tag:
             enclosing_regions = get_enclosing_region(self.func_ir, self.blk_start)
-            if len(enclosing_regions) < 1 or "TEAMS" not in enclosing_regions[0].tags[0].name:
+            if len(enclosing_regions) < 1 or "TEAMS" not in enclosing_regions[-1].tags[0].name:
                 raise NotImplementedError("DISTRIBUTE must be nested under or combined with TEAMS.")
 
     def teams_directive(self, args):
@@ -4330,10 +4332,11 @@ def target_teams_directive(self, args):
     def target_teams_distribute_directive(self, args):
         self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE", 3, has_loop=True)
 
+    def target_loop_directive(self, args):
+        self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 2, has_loop=True)
+
     def target_teams_loop_directive(self, args):
         self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 3, has_loop=True)
-        #self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD", 3, has_loop=True)
-        #self.some_target_directive(args, "TARGET.TEAMS.LOOP", 3, has_loop=True)
 
     def target_teams_distribute_parallel_for_directive(self, args):
         self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 5, has_loop=True)
@@ -4415,6 +4418,26 @@ def teams_distribute_directive(self, args):
     def teams_distribute_simd_directive(self, args):
         self.some_distribute_directive(args, "TEAMS.DISTRIBUTE.SIMD", 3, has_loop=True)
 
+    def teams_loop_directive(self, args):
+        self.some_distribute_directive(args, "TEAMS.DISTRIBUTE.PARALLEL.LOOP", 2, has_loop=True)
+
+    def loop_directive(self, args):
+        # TODO: Add error checking of the clauses the parser accepts here, if we
+        # find that loop can even take clauses, which we're not sure that it can.
+        enclosing_regions = get_enclosing_region(self.func_ir, self.blk_start)
+        if not enclosing_regions or len(enclosing_regions) < 1:
+            self.some_for_directive(args, "DIR.OMP.PARALLEL.LOOP", "DIR.OMP.END.PARALLEL.LOOP", 1, True)
+        else:
+            if "DISTRIBUTE" in enclosing_regions[-1].tags[0].name:
+                self.some_distribute_directive(args, "PARALLEL.LOOP", 1, has_loop=True)
+            elif "TEAMS" in enclosing_regions[-1].tags[0].name:
+                self.some_distribute_directive(args, "DISTRIBUTE.PARALLEL.LOOP", 1, has_loop=True)
+            else:
+                if "TARGET" in enclosing_regions[-1].tags[0].name:
+                    self.some_distribute_directive(args, "TEAMS.DISTRIBUTE.PARALLEL.LOOP", 1, has_loop=True)
+                else:
+                    self.some_for_directive(args, "DIR.OMP.PARALLEL.LOOP", "DIR.OMP.END.PARALLEL.LOOP", 1, True)
+
     def distribute_directive(self, args):
         self.some_distribute_directive(args, "DISTRIBUTE", 1, has_loop=True)
 
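The loop_directive handler above chooses the lowering from the innermost enclosing region: under a DISTRIBUTE region a bare loop becomes PARALLEL.LOOP, under TEAMS it becomes DISTRIBUTE.PARALLEL.LOOP, under a bare TARGET it becomes TEAMS.DISTRIBUTE.PARALLEL.LOOP, and with no enclosing region it falls back to DIR.OMP.PARALLEL.LOOP. A minimal hedged sketch of that fallback case (imports and function are illustrative, not from this diff):

    import numpy as np
    from numba import njit
    from numba.openmp import openmp_context as openmp  # assumed import path

    @njit
    def scale(a):
        # No enclosing OpenMP region here, so loop_directive routes this
        # through some_for_directive as DIR.OMP.PARALLEL.LOOP.
        with openmp("loop"):
            for i in range(a.shape[0]):
                a[i] = 2.0 * a[i]
        return a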
@@ -4453,8 +4476,6 @@ def some_distribute_directive(self, args, dir_tag, lexer_count, has_loop=False):
             start_tags.append(openmp_tag("QUAL.OMP.THREAD_LIMIT", 0))
             self.teams_back_prop(clauses)
         elif "PARALLEL" in dir_tag:
-            if len(self.get_clauses_by_name(clauses, "QUAL.OMP.THREAD_LIMIT")) == 0:
-                start_tags.append(openmp_tag("QUAL.OMP.THREAD_LIMIT", 0))
             self.parallel_back_prop(clauses)
 
         if config.DEBUG_OPENMP >= 1:
@@ -4796,13 +4817,6 @@ def target_teams_distribute_parallel_for_clause(self, args):
             print(args[0][0])
         return args[0]
 
-    def target_teams_loop_clause(self, args):
-        if config.DEBUG_OPENMP >= 1:
-            print("visit target_teams_loop_clause", args, type(args), args[0])
-        if isinstance(args[0], list):
-            print(args[0][0])
-        return args[0]
-
     # Don't need a rule for target_update_construct.
 
     def target_update_directive(self, args):
@@ -5514,12 +5528,15 @@ def NUMBER(self, args):
         | teams_distribute_simd_construct
         | teams_distribute_parallel_for_construct
         | teams_distribute_parallel_for_simd_construct
+        | loop_construct
+        | teams_loop_construct
         | target_construct
         | target_teams_construct
         | target_teams_distribute_construct
         | target_teams_distribute_simd_construct
         | target_teams_distribute_parallel_for_simd_construct
         | target_teams_distribute_parallel_for_construct
+        | target_loop_construct
         | target_teams_loop_construct
         | target_enter_data_construct
         | target_exit_data_construct
@@ -5539,8 +5556,6 @@ def NUMBER(self, args):
         | parallel_sections_construct
         | master_construct
         | ordered_construct
-    //teams_distribute_parallel_for_simd_clause: target_clause
-    //    | teams_distribute_parallel_for_simd_clause
     for_simd_construct: for_simd_directive
     for_simd_directive: FOR SIMD [for_simd_clause*]
     for_simd_clause: for_clause
@@ -5735,6 +5750,9 @@ def NUMBER(self, args):
     target_teams_distribute_parallel_for_construct: target_teams_distribute_parallel_for_directive
     teams_distribute_parallel_for_construct: teams_distribute_parallel_for_directive
     teams_distribute_parallel_for_simd_construct: teams_distribute_parallel_for_simd_directive
+    loop_construct: loop_directive
+    teams_loop_construct: teams_loop_directive
+    target_loop_construct: target_loop_directive
     target_teams_loop_construct: target_teams_loop_directive
     target_teams_construct: target_teams_directive
     target_teams_distribute_construct: target_teams_distribute_directive
@@ -5903,30 +5921,10 @@ def NUMBER(self, args):
 
     ompx_attribute: OMPX_ATTRIBUTE "(" PYTHON_NAME "(" number_list ")" ")"
     OMPX_ATTRIBUTE: "ompx_attribute"
-    //target_teams_loop_directive: TARGET TEAMS LOOP [target_teams_loop_clause*]
-    target_teams_loop_directive: TARGET TEAMS LOOP [target_teams_distribute_parallel_for_simd_clause*]
-    target_teams_loop_clause: if_clause
-        | device_clause
-        | private_clause
-        | firstprivate_clause
-        // | in_reduction_clause
-        | map_clause
-        | is_device_ptr_clause
-        // | defaultmap_clause
-        | NOWAIT
-        | allocate_clause
-        | depend_with_modifier_clause
-        // | uses_allocators_clause
-        | num_teams_clause
-        | thread_limit_clause
-        | data_default_clause
-        | data_sharing_clause
-        // | reduction_default_only_clause
-        // | bind_clause
-        | collapse_clause
-        | ORDERED
-        | lastprivate_clause
-        | ompx_attribute
+    loop_directive: LOOP [teams_distribute_parallel_for_clause*]
+    teams_loop_directive: TEAMS LOOP [teams_distribute_parallel_for_clause*]
+    target_loop_directive: TARGET LOOP [target_teams_distribute_parallel_for_clause*]
+    target_teams_loop_directive: TARGET TEAMS LOOP [target_teams_distribute_parallel_for_clause*]
 
     target_teams_directive: TARGET TEAMS [target_teams_clause*]
     target_teams_clause: if_clause
@@ -6149,8 +6147,7 @@ def NUMBER(self, args):
     for_directive: FOR [for_clause*]
     for_clause: unique_for_clause | data_clause | NOWAIT
     unique_for_clause: ORDERED
-        | sched_no_expr
-        | sched_expr
+        | schedule_clause
         | collapse_clause
     LINEAR: "linear"
     linear_clause: LINEAR "(" var_list ":" const_num_or_var ")"
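With sched_no_expr and sched_expr folded into the new schedule_clause rule, a schedule clause is now parsed uniformly on worksharing loops. A hedged usage sketch mirroring the test_pi_loop_directive test added below (imports assumed; the midpoint formula is simplified relative to the test):

    from numba import njit
    from numba.openmp import openmp_context as openmp  # assumed import path

    @njit
    def pi_estimate(num_steps):
        step = 1.0 / num_steps
        the_sum = 0.0
        with openmp("parallel"):
            # schedule(static) is accepted via the new schedule_clause rule.
            with openmp("for reduction(+:the_sum) schedule(static)"):
                for j in range(num_steps):
                    x = (j + 0.5) * step
                    the_sum += 4.0 / (1.0 + x * x)
        return step * the_sum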

numba/tests/test_openmp.py

Lines changed: 40 additions & 0 deletions
@@ -2953,6 +2953,10 @@ class TestOpenmpTarget(TestOpenmpBase):
     def __init__(self, *args):
         TestOpenmpBase.__init__(self, *args)
 
+    @classmethod
+    def is_testing_cpu(cls):
+        return 1 in cls.devices
+
     # How to check for nowait?
     # Currently checks only compilation.
     # Numba optimizes the whole target away? This runs too fast.
@@ -4409,6 +4413,25 @@ def test_impl(n):
         c = test_impl(n)
         np.testing.assert_array_equal(c, np.full((n,n), 2))
 
+    def target_nest_teams_nest_loop_collapse(self, device):
+        target_pragma = f"""target device({device}) map(tofrom: a, b, c)"""
+        @njit
+        def test_impl(n):
+            a = np.ones((n,n))
+            b = np.ones((n,n))
+            c = np.zeros((n,n))
+            with openmp(target_pragma):
+                with openmp("teams"):
+                    with openmp("loop collapse(2)"):
+                        for i in range(n):
+                            for j in range(n):
+                                c[i,j] = a[i,j] + b[i,j]
+            return c
+
+        n = 10
+        c = test_impl(n)
+        np.testing.assert_array_equal(c, np.full((n,n), 2))
+
 
 for memberName in dir(TestOpenmpTarget):
     if memberName.startswith("target"):
@@ -4462,6 +4485,23 @@ def test_impl(num_steps):
 
         self.check(test_impl, 100000)
 
+    def test_pi_loop_directive(self):
+        def test_impl(num_steps):
+            step = 1.0 / num_steps
+
+            the_sum = 0.0
+            omp_set_num_threads(4)
+
+            with openmp("loop reduction(+:the_sum) schedule(static)"):
+                for j in range(num_steps):
+                    x = ((j-1) - 0.5) * step
+                    the_sum += 4.0 / (1.0 + x * x)
+
+            pi = step * the_sum
+            return pi
+
+        self.check(test_impl, 100000)
+
     def test_pi_spmd(self):
         def test_impl(num_steps):
             step = 1.0 / num_steps