Skip to content

Commit 37f6ba4

Browse files
authored
[flang][OpenMP] Fix construct privatization in default clause (#72510)
Current implementation of default clause privatization incorrectly fails to privatize in presence of non-OpenMP constructs (i.e. nested constructs with regions whose symbols need to be privatized in the scope of the parent OpenMP construct). This patch fixes the same by considering non-OpenMP constructs separately by collecting symbols of a nested region if it is a non-OpenMP construct with a region, and privatizing it in the scope of the parent OpenMP construct. Fixes #71914 and #71915
1 parent 6b94870 commit 37f6ba4

File tree

6 files changed

+92
-22
lines changed

6 files changed

+92
-22
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,12 @@ class AbstractConverter {
134134
virtual bool isPresentShallowLookup(Fortran::semantics::Symbol &sym) = 0;
135135

136136
/// Collect the set of symbols with \p flag in \p eval
137-
/// region if \p collectSymbols is true. Likewise, collect the
137+
/// region if \p collectSymbols is true. Otherwise, collect the
138138
/// set of the host symbols with \p flag of the associated symbols in \p eval
139-
/// region if collectHostAssociatedSymbols is true.
139+
/// region if collectHostAssociatedSymbols is true. This allows gathering
140+
/// host association details of symbols particularly in nested directives
141+
/// irrespective of \p flag \p, and can be useful where host
142+
/// association details are needed in flag-agnostic manner.
140143
virtual void collectSymbolSet(
141144
pft::Evaluation &eval,
142145
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet,

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
810810
bool collectSymbol) {
811811
if (collectSymbol && oriSymbol.test(flag))
812812
symbolSet.insert(&oriSymbol);
813-
if (checkHostAssociatedSymbols)
813+
else if (checkHostAssociatedSymbols)
814814
if (const auto *details{
815815
oriSymbol
816816
.detailsIf<Fortran::semantics::HostAssocDetails>()})

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -302,21 +302,38 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
302302
}
303303
}
304304

305+
void DataSharingProcessor::collectSymbolsInNestedRegions(
306+
Fortran::lower::pft::Evaluation &eval,
307+
Fortran::semantics::Symbol::Flag flag,
308+
llvm::SetVector<const Fortran::semantics::Symbol *>
309+
&symbolsInNestedRegions) {
310+
for (Fortran::lower::pft::Evaluation &nestedEval :
311+
eval.getNestedEvaluations()) {
312+
if (nestedEval.hasNestedEvaluations()) {
313+
if (nestedEval.isConstruct())
314+
// Recursively look for OpenMP constructs within `nestedEval`'s region
315+
collectSymbolsInNestedRegions(nestedEval, flag, symbolsInNestedRegions);
316+
else
317+
converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag,
318+
/*collectSymbols=*/true,
319+
/*collectHostAssociatedSymbols=*/false);
320+
}
321+
}
322+
}
323+
324+
// Collect symbols to be default privatized in two steps.
325+
// In step 1, collect all symbols in `eval` that match `flag` into
326+
// `defaultSymbols`. In step 2, for nested constructs (if any), if and only if
327+
// the nested construct is an OpenMP construct, collect those nested
328+
// symbols skipping host associated symbols into `symbolsInNestedRegions`.
329+
// Later, in current context, all symbols in the set
330+
// `defaultSymbols` - `symbolsInNestedRegions` will be privatized.
305331
void DataSharingProcessor::collectSymbols(
306332
Fortran::semantics::Symbol::Flag flag) {
307333
converter.collectSymbolSet(eval, defaultSymbols, flag,
308334
/*collectSymbols=*/true,
309335
/*collectHostAssociatedSymbols=*/true);
310-
for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
311-
if (e.hasNestedEvaluations())
312-
converter.collectSymbolSet(e, symbolsInNestedRegions, flag,
313-
/*collectSymbols=*/true,
314-
/*collectHostAssociatedSymbols=*/false);
315-
else
316-
converter.collectSymbolSet(e, symbolsInParentRegions, flag,
317-
/*collectSymbols=*/false,
318-
/*collectHostAssociatedSymbols=*/true);
319-
}
336+
collectSymbolsInNestedRegions(eval, flag, symbolsInNestedRegions);
320337
}
321338

322339
void DataSharingProcessor::collectDefaultSymbols() {
@@ -367,7 +384,6 @@ void DataSharingProcessor::defaultPrivatize(
367384
!sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() &&
368385
!Fortran::semantics::IsImpliedDoIndex(sym->GetUltimate()) &&
369386
!symbolsInNestedRegions.contains(sym) &&
370-
!symbolsInParentRegions.contains(sym) &&
371387
!privatizedSymbols.contains(sym))
372388
doPrivatize(sym, clauseOps, privateSyms);
373389
}

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class DataSharingProcessor {
4040
llvm::SetVector<const Fortran::semantics::Symbol *> privatizedSymbols;
4141
llvm::SetVector<const Fortran::semantics::Symbol *> defaultSymbols;
4242
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInNestedRegions;
43-
llvm::SetVector<const Fortran::semantics::Symbol *> symbolsInParentRegions;
4443
llvm::DenseMap<const Fortran::semantics::Symbol *, mlir::omp::PrivateClauseOp>
4544
symToPrivatizer;
4645
Fortran::lower::AbstractConverter &converter;
@@ -52,6 +51,11 @@ class DataSharingProcessor {
5251

5352
bool needBarrier();
5453
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
54+
void collectSymbolsInNestedRegions(
55+
Fortran::lower::pft::Evaluation &eval,
56+
Fortran::semantics::Symbol::Flag flag,
57+
llvm::SetVector<const Fortran::semantics::Symbol *>
58+
&symbolsInNestedRegions);
5559
void collectOmpObjectListSymbol(
5660
const omp::ObjectList &objects,
5761
llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet);

flang/test/Lower/OpenMP/default-clause-byref.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,6 @@ subroutine nested_default_clause_tests
226226
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
227227
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
228228
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
229-
!CHECK: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
230-
!CHECK: %[[PRIVATE_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
231229
!CHECK: omp.parallel {
232230
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
233231
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -242,12 +240,14 @@ subroutine nested_default_clause_tests
242240
!CHECK: omp.terminator
243241
!CHECK: }
244242
!CHECK: omp.parallel {
243+
!CHECK: %[[PRIVATE_INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
244+
!CHECK: %[[PRIVATE_INNER_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
245245
!CHECK: %[[PRIVATE_INNER_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
246246
!CHECK: %[[PRIVATE_INNER_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
247247
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
248248
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
249249
!CHECK: %[[TEMP_1:.*]] = fir.load %[[PRIVATE_INNER_X_DECL]]#0 : !fir.ref<i32>
250-
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_Z_DECL]]#0 : !fir.ref<i32>
250+
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_INNER_Z_DECL]]#0 : !fir.ref<i32>
251251
!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
252252
!CHECK: hlfir.assign %[[RESULT]] to %[[PRIVATE_INNER_W_DECL]]#0 : i32, !fir.ref<i32>
253253
!CHECK: omp.terminator

flang/test/Lower/OpenMP/default-clause.f90

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ program default_clause_lowering
149149
end program default_clause_lowering
150150

151151
subroutine nested_default_clause_tests
152-
integer :: x, y, z, w, k, a
152+
integer :: x, y, z, w, k
153153
!CHECK: %[[K:.*]] = fir.alloca i32 {bindc_name = "k", uniq_name = "_QFnested_default_clause_testsEk"}
154154
!CHECK: %[[K_DECL:.*]]:2 = hlfir.declare %[[K]] {uniq_name = "_QFnested_default_clause_testsEk"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
155155
!CHECK: %[[W:.*]] = fir.alloca i32 {bindc_name = "w", uniq_name = "_QFnested_default_clause_testsEw"}
@@ -221,13 +221,12 @@ subroutine nested_default_clause_tests
221221

222222

223223
!CHECK: omp.parallel {
224+
!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
224225
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
225226
!CHECK: %[[PRIVATE_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_default_clause_testsEy"}
226227
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
227228
!CHECK: %[[PRIVATE_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
228229
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
229-
!CHECK: %[[PRIVATE_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
230-
!CHECK: %[[PRIVATE_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
231230
!CHECK: omp.parallel {
232231
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
233232
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -242,12 +241,14 @@ subroutine nested_default_clause_tests
242241
!CHECK: omp.terminator
243242
!CHECK: }
244243
!CHECK: omp.parallel {
244+
!CHECK: %[[PRIVATE_INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_default_clause_testsEz"}
245+
!CHECK: %[[PRIVATE_INNER_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
245246
!CHECK: %[[PRIVATE_INNER_W:.*]] = fir.alloca i32 {bindc_name = "w", pinned, uniq_name = "_QFnested_default_clause_testsEw"}
246247
!CHECK: %[[PRIVATE_INNER_W_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_W]] {uniq_name = "_QFnested_default_clause_testsEw"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
247248
!CHECK: %[[PRIVATE_INNER_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFnested_default_clause_testsEx"}
248249
!CHECK: %[[PRIVATE_INNER_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_INNER_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
249250
!CHECK: %[[TEMP_1:.*]] = fir.load %[[PRIVATE_INNER_X_DECL]]#0 : !fir.ref<i32>
250-
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_Z_DECL]]#0 : !fir.ref<i32>
251+
!CHECK: %[[TEMP_2:.*]] = fir.load %[[PRIVATE_INNER_Z_DECL]]#0 : !fir.ref<i32>
251252
!CHECK: %[[RESULT:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
252253
!CHECK: hlfir.assign %[[RESULT]] to %[[PRIVATE_INNER_W_DECL]]#0 : i32, !fir.ref<i32>
253254
!CHECK: omp.terminator
@@ -415,3 +416,49 @@ subroutine threadprivate_with_default
415416
end do
416417
!$omp end parallel do
417418
end subroutine
419+
420+
subroutine nested_constructs
421+
!CHECK: %[[I:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFnested_constructsEi"}
422+
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[I]] {{.*}}
423+
!CHECK: %[[J:.*]] = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFnested_constructsEj"}
424+
!CHECK: %[[J_DECL:.*]]:2 = hlfir.declare %[[J]] {{.*}}
425+
!CHECK: %[[Y:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFnested_constructsEy"}
426+
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {{.*}}
427+
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_constructsEz"}
428+
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {{.*}}
429+
430+
integer :: y, z
431+
!CHECK: omp.parallel {
432+
!CHECK: %[[INNER_J:.*]] = fir.alloca i32 {bindc_name = "j", pinned}
433+
!CHECK: %[[INNER_J_DECL:.*]]:2 = hlfir.declare %[[INNER_J]] {{.*}}
434+
!CHECK: %[[INNER_I:.*]] = fir.alloca i32 {bindc_name = "i", pinned}
435+
!CHECK: %[[INNER_I_DECL:.*]]:2 = hlfir.declare %[[INNER_I]] {{.*}}
436+
!CHECK: %[[INNER_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_constructsEy"}
437+
!CHECK: %[[INNER_Y_DECL:.*]]:2 = hlfir.declare %[[INNER_Y]] {{.*}}
438+
!CHECK: %[[TEMP:.*]] = fir.load %[[Y_DECL]]#0 : !fir.ref<i32>
439+
!CHECK: hlfir.assign %[[TEMP]] to %[[INNER_Y_DECL]]#0 temporary_lhs : i32, !fir.ref<i32>
440+
!CHECK: %[[INNER_Z:.*]] = fir.alloca i32 {bindc_name = "z", pinned, uniq_name = "_QFnested_constructsEz"}
441+
!CHECK: %[[INNER_Z_DECL:.*]]:2 = hlfir.declare %[[INNER_Z]] {{.*}}
442+
!$omp parallel default(private) firstprivate(y)
443+
!CHECK: {{.*}} = fir.do_loop {{.*}} {
444+
do i = 1, 10
445+
!CHECK: %[[CONST_1:.*]] = arith.constant 1 : i32
446+
!CHECK: hlfir.assign %[[CONST_1]] to %[[INNER_Y_DECL]]#0 : i32, !fir.ref<i32>
447+
y = 1
448+
!CHECK: {{.*}} = fir.do_loop {{.*}} {
449+
do j = 1, 10
450+
!CHECK: %[[CONST_20:.*]] = arith.constant 20 : i32
451+
!CHECK: hlfir.assign %[[CONST_20]] to %[[INNER_Z_DECL]]#0 : i32, !fir.ref<i32>
452+
z = 20
453+
!CHECK: omp.parallel {
454+
!CHECK: %[[NESTED_Y:.*]] = fir.alloca i32 {bindc_name = "y", pinned, uniq_name = "_QFnested_constructsEy"}
455+
!CHECK: %[[NESTED_Y_DECL:.*]]:2 = hlfir.declare %[[NESTED_Y]] {{.*}}
456+
!CHECK: %[[CONST_2:.*]] = arith.constant 2 : i32
457+
!CHECK: hlfir.assign %[[CONST_2]] to %[[NESTED_Y_DECL]]#0 : i32, !fir.ref<i32>
458+
!$omp parallel default(private)
459+
y = 2
460+
!$omp end parallel
461+
end do
462+
end do
463+
!$omp end parallel
464+
end subroutine

0 commit comments

Comments
 (0)