Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit f7a481e

Browse files
authored
Allow parallel for range step to be a non-constant. (#41)
* Allow step to be a non-constant. * Move previously asserting test to the working section. That test was flawed and had a 0 step in it, so that was fixed as well. * Add range step var as firstprivate. * Add arbitrary step support tests for target.
1 parent 296fb85 commit f7a481e

File tree

2 files changed

+79
-35
lines changed

2 files changed

+79
-35
lines changed

numba/openmp.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def arg_to_str(self, x, lowerer, struct_lower=False, var_table=None, gen_copy=Fa
330330
elif isinstance(arg_str, lir.instructions.AllocaInstr):
331331
decl = arg_str.get_decl()
332332
else:
333-
breakpoint()
333+
assert False, f"Don't know how to get decl string for variable {arg_str} of type {type(arg_str)}"
334334

335335
if struct_lower and isinstance(xtyp, types.npytypes.Array):
336336
dm = lowerer.context.data_model_manager.lookup(xtyp)
@@ -3550,9 +3550,7 @@ def _get_loop_kind(func_var, call_table):
35503550
if len(call) == 0:
35513551
return False
35523552

3553-
return call[0] # or call[0] == prange
3554-
#or call[0] == 'internal_prange' or call[0] == internal_prange
3555-
#$or call[0] == 'pndindex' or call[0] == pndindex)
3553+
return call[0]
35563554

35573555
loop = loops[0]
35583556
entry = list(loop.entries)[0]
@@ -3744,19 +3742,18 @@ def _get_loop_kind(func_var, call_table):
37443742
size_var = range_args[1]
37453743
try:
37463744
step = self.func_ir.get_definition(range_args[2])
3745+
# Only use get_definition to get a const if
3746+
# available. Otherwise use the variable.
3747+
if not isinstance(step, (int, ir.Const)):
3748+
step = range_args[2]
37473749
except KeyError:
3748-
raise NotImplementedError(
3749-
"Only known step size is supported for prange")
3750-
if not isinstance(step, ir.Const):
3751-
raise NotImplementedError(
3752-
"Only constant step size is supported for prange")
3753-
step = step.value
3754-
# if step != 1:
3755-
# print("unsupported step:", step, type(step))
3756-
# raise NotImplementedError(
3757-
# "Only constant step size of 1 is supported for prange")
3758-
3759-
#assert(start == 0 or (isinstance(start, ir.Const) and start.value == 0))
3750+
# If there is more than one definition possible for the
3751+
# step variable then just use the variable and don't try
3752+
# to convert to a const.
3753+
step = range_args[2]
3754+
if isinstance(step, ir.Const):
3755+
step = step.value
3756+
37603757
if config.DEBUG_OPENMP >= 1:
37613758
print("size_var:", size_var, type(size_var))
37623759

@@ -3848,7 +3845,15 @@ def _get_loop_kind(func_var, call_table):
38483845
detect_step_assign = ir.Assign(ir.Const(0, inst.loc), step_var, inst.loc)
38493846
after_start.append(detect_step_assign)
38503847

3851-
step_assign = ir.Assign(ir.Const(step, inst.loc), step_var, inst.loc)
3848+
if isinstance(step, int):
3849+
step_assign = ir.Assign(ir.Const(step, inst.loc), step_var, inst.loc)
3850+
elif isinstance(step, ir.Var):
3851+
step_assign = ir.Assign(step, step_var, inst.loc)
3852+
start_tags.append(openmp_tag("QUAL.OMP.FIRSTPRIVATE", step.name))
3853+
else:
3854+
print("Unsupported step:", step, type(step))
3855+
raise NotImplementedError(
3856+
f"Unknown step type that isn't a constant or variable but {type(step)} instead.")
38523857
scale_var = loop_index.scope.redefine("$scale", inst.loc)
38533858
fake_iternext = ir.Assign(ir.Const(0, inst.loc), iternext_inst.target, inst.loc)
38543859
fake_second = ir.Assign(ir.Const(0, inst.loc), pair_second_inst.target, inst.loc)
@@ -4606,9 +4611,7 @@ def some_data_clause_directive(self, args, start_tags, end_tags, lexer_count, ha
46064611
end_tags,
46074612
scope)
46084613
vars_in_explicit_clauses, explicit_privates, non_user_explicits = self.get_explicit_vars(clauses)
4609-
46104614
found_loop, blocks_for_io, blocks_in_region, entry_pred, exit_block, inst, size_var, step_var, latest_index, loop_index = prepare_out
4611-
46124615
assert(found_loop)
46134616
else:
46144617
blocks_for_io = self.body_blocks
@@ -6363,15 +6366,13 @@ def omp_shared_array(size, dtype):
63636366

63646367
@overload(omp_shared_array, target='cpu', inline='always', prefer_literal=True)
63656368
def omp_shared_array_overload(size, dtype):
6366-
breakpoint()
63676369
assert isinstance(size, types.IntegerLiteral)
63686370
def impl(size, dtype):
63696371
return np.empty(size, dtype=dtype)
63706372
return impl
63716373

63726374
@overload(omp_shared_array, target='cuda', inline='always', prefer_literal=True)
63736375
def omp_shared_array_overload(size, dtype):
6374-
breakpoint()
63756376
assert isinstance(size, types.IntegerLiteral)
63766377
def impl(size, dtype):
63776378
return numba_cuda.shared.array(size, dtype)

numba/tests/test_openmp.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -606,12 +606,33 @@ def test_parallel_for_range_step_2(self):
606606
def test_impl(N):
607607
a = np.zeros(N, dtype=np.int32)
608608
with openmp("parallel for"):
609-
for i in range(0, 10, 2):
609+
for i in range(0, len(a), 2):
610610
a[i] = i + 1
611611

612612
return a
613613
self.check(test_impl, 12)
614614

615+
def test_parallel_for_range_step_arg(self):
616+
def test_impl(N, step):
617+
a = np.zeros(N, dtype=np.int32)
618+
with openmp("parallel for"):
619+
for i in range(0, len(a), step):
620+
a[i] = i + 1
621+
622+
return a
623+
self.check(test_impl, 12, 2)
624+
625+
def test_parallel_for_incremented_step(self):
626+
@njit
627+
def test_impl(v, n):
628+
for i in range(n):
629+
with openmp("parallel for"):
630+
for j in range(0, len(v), i + 1):
631+
v[j] = i + 1
632+
return v
633+
634+
self.check(test_impl, np.zeros(100), 3)
635+
615636
def test_parallel_for_range_backward_step(self):
616637
def test_impl(N):
617638
a = np.zeros(N, dtype=np.int32)
@@ -1844,19 +1865,6 @@ def test_impl():
18441865
test_impl()
18451866
self.assertIn("Extra code near line", str(raises.exception))
18461867

1847-
def test_parallel_for_incremented_step(self):
1848-
@njit
1849-
def test_impl(v, n):
1850-
for i in range(n):
1851-
with openmp("parallel for"):
1852-
for j in range(0, len(v), i):
1853-
v[j] = i
1854-
return v
1855-
1856-
with self.assertRaises(NotImplementedError) as raises:
1857-
test_impl(np.zeros(100), 3)
1858-
self.assertIn("Only constant step", str(raises.exception))
1859-
18601868
def test_nonstring_var_omp_statement(self):
18611869
@njit
18621870
def test_impl(v):
@@ -3350,6 +3358,41 @@ def test_impl():
33503358
r = test_impl()
33513359
np.testing.assert_equal(r, np.full(32, 1))
33523360

3361+
def target_parallel_for_range_step_arg(self, device):
3362+
target_pragma = f"target device({device}) map(tofrom: a)"
3363+
parallel_pragma = "parallel for"
3364+
N = 10
3365+
step = 2
3366+
@njit
3367+
def test_impl():
3368+
a = np.zeros(N, dtype=np.int32)
3369+
with openmp(target_pragma):
3370+
with openmp(parallel_pragma):
3371+
for i in range(0, len(a), step):
3372+
a[i] = i + 1
3373+
3374+
return a
3375+
r = test_impl()
3376+
np.testing.assert_equal(r, np.array([1,0,3,0,5,0,7,0,9,0]))
3377+
3378+
def target_parallel_for_incremented_step(self, device):
3379+
target_pragma = f"target device({device}) map(tofrom: a)"
3380+
parallel_pragma = "parallel for"
3381+
N = 10
3382+
step_range = 3
3383+
@njit
3384+
def test_impl():
3385+
a = np.zeros(N, dtype=np.int32)
3386+
for i in range(step_range):
3387+
with openmp(target_pragma):
3388+
with openmp(parallel_pragma):
3389+
for j in range(0, len(a), i + 1):
3390+
a[j] = i + 1
3391+
return a
3392+
3393+
r = test_impl()
3394+
np.testing.assert_equal(r, np.array([3,1,2,3,2,1,3,1,2,3]))
3395+
33533396
def target_teams(self, device):
33543397
target_pragma = f"target teams num_teams(100) device({device}) map(from: a, nteams)"
33553398
@njit

0 commit comments

Comments
 (0)