Skip to content

Commit 24bbcbd

Browse files
committed
Separate copy_in inlining into its own pass, add flag
Signed-off-by: Kajetan Puchalski <kajetan.puchalski@arm.com>
1 parent f34d4d5 commit 24bbcbd

File tree

7 files changed

+336
-266
lines changed

7 files changed

+336
-266
lines changed

flang/include/flang/Optimizer/HLFIR/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def InlineHLFIRAssign : Pass<"inline-hlfir-assign"> {
6969
let summary = "Inline hlfir.assign operations";
7070
}
7171

72+
// Pass that inlines hlfir.copy_in operations. The rewrite pattern lives in
// flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRCopyIn.cpp.
def InlineHLFIRCopyIn : Pass<"inline-hlfir-copy-in"> {
73+
let summary = "Inline hlfir.copy_in operations";
74+
}
75+
7276
def PropagateFortranVariableAttributes : Pass<"propagate-fortran-attrs"> {
7377
let summary = "Propagate FortranVariableFlagsAttr attributes through HLFIR";
7478
}

flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_flang_library(HLFIRTransforms
55
ConvertToFIR.cpp
66
InlineElementals.cpp
77
InlineHLFIRAssign.cpp
8+
InlineHLFIRCopyIn.cpp
89
LowerHLFIRIntrinsics.cpp
910
LowerHLFIROrderedAssignments.cpp
1011
ScheduleOrderedAssignments.cpp

flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "flang/Optimizer/Analysis/AliasAnalysis.h"
1414
#include "flang/Optimizer/Builder/FIRBuilder.h"
1515
#include "flang/Optimizer/Builder/HLFIRTools.h"
16-
#include "flang/Optimizer/Dialect/FIRType.h"
1716
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1817
#include "flang/Optimizer/HLFIR/Passes.h"
1918
#include "flang/Optimizer/OpenMP/Passes.h"
@@ -128,126 +127,6 @@ class InlineHLFIRAssignConversion
128127
}
129128
};
130129

131-
class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> {
132-
public:
133-
using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern;
134-
135-
llvm::LogicalResult
136-
matchAndRewrite(hlfir::CopyInOp copyIn,
137-
mlir::PatternRewriter &rewriter) const override;
138-
};
139-
140-
llvm::LogicalResult
141-
InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn,
142-
mlir::PatternRewriter &rewriter) const {
143-
fir::FirOpBuilder builder(rewriter, copyIn.getOperation());
144-
mlir::Location loc = copyIn.getLoc();
145-
hlfir::Entity inputVariable{copyIn.getVar()};
146-
if (!fir::isa_trivial(inputVariable.getFortranElementType()))
147-
return rewriter.notifyMatchFailure(copyIn,
148-
"CopyInOp's data type is not trivial");
149-
150-
if (fir::isPointerType(inputVariable.getType()))
151-
return rewriter.notifyMatchFailure(
152-
copyIn, "CopyInOp's input variable is a pointer");
153-
154-
// There should be exactly one user of WasCopied - the corresponding
155-
// CopyOutOp.
156-
if (copyIn.getWasCopied().getUses().empty())
157-
return rewriter.notifyMatchFailure(copyIn,
158-
"CopyInOp's WasCopied has no uses");
159-
// The copy out should always be present, either to actually copy or just
160-
// deallocate memory.
161-
auto copyOut = mlir::dyn_cast<hlfir::CopyOutOp>(
162-
copyIn.getWasCopied().getUsers().begin().getCurrent().getUser());
163-
164-
if (!copyOut)
165-
return rewriter.notifyMatchFailure(copyIn,
166-
"CopyInOp has no direct CopyOut");
167-
168-
// Only inline the copy_in when copy_out does not need to be done, i.e. in
169-
// case of intent(in).
170-
if (copyOut.getVar())
171-
return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out");
172-
173-
inputVariable =
174-
hlfir::derefPointersAndAllocatables(loc, builder, inputVariable);
175-
mlir::Type resultAddrType = copyIn.getCopiedIn().getType();
176-
mlir::Value isContiguous =
177-
builder.create<fir::IsContiguousBoxOp>(loc, inputVariable);
178-
mlir::Operation::result_range results =
179-
builder
180-
.genIfOp(loc, {resultAddrType, builder.getI1Type()}, isContiguous,
181-
/*withElseRegion=*/true)
182-
.genThen([&]() {
183-
mlir::Value falseVal = builder.create<mlir::arith::ConstantOp>(
184-
loc, builder.getI1Type(), builder.getBoolAttr(false));
185-
builder.create<fir::ResultOp>(
186-
loc, mlir::ValueRange{inputVariable, falseVal});
187-
})
188-
.genElse([&] {
189-
auto [temp, cleanup] =
190-
hlfir::createTempFromMold(loc, builder, inputVariable);
191-
mlir::Value shape = hlfir::genShape(loc, builder, inputVariable);
192-
llvm::SmallVector<mlir::Value> extents =
193-
hlfir::getIndexExtents(loc, builder, shape);
194-
hlfir::LoopNest loopNest = hlfir::genLoopNest(
195-
loc, builder, extents, /*isUnordered=*/true,
196-
flangomp::shouldUseWorkshareLowering(copyIn));
197-
builder.setInsertionPointToStart(loopNest.body);
198-
hlfir::Entity elem = hlfir::getElementAt(
199-
loc, builder, inputVariable, loopNest.oneBasedIndices);
200-
elem = hlfir::loadTrivialScalar(loc, builder, elem);
201-
hlfir::Entity tempElem = hlfir::getElementAt(
202-
loc, builder, temp, loopNest.oneBasedIndices);
203-
builder.create<hlfir::AssignOp>(loc, elem, tempElem);
204-
builder.setInsertionPointAfter(loopNest.outerOp);
205-
206-
mlir::Value result;
207-
// Make sure the result is always a boxed array by boxing it
208-
// ourselves if need be.
209-
if (mlir::isa<fir::BaseBoxType>(temp.getType())) {
210-
result = temp;
211-
} else {
212-
fir::ReferenceType refTy =
213-
fir::ReferenceType::get(temp.getElementOrSequenceType());
214-
mlir::Value refVal = builder.createConvert(loc, refTy, temp);
215-
result =
216-
builder.create<fir::EmboxOp>(loc, resultAddrType, refVal);
217-
}
218-
219-
builder.create<fir::ResultOp>(loc,
220-
mlir::ValueRange{result, cleanup});
221-
})
222-
.getResults();
223-
224-
mlir::OpResult addr = results[0];
225-
mlir::OpResult needsCleanup = results[1];
226-
227-
builder.setInsertionPoint(copyOut);
228-
builder.genIfOp(loc, {}, needsCleanup, /*withElseRegion=*/false).genThen([&] {
229-
auto boxAddr = builder.create<fir::BoxAddrOp>(loc, addr);
230-
fir::HeapType heapType =
231-
fir::HeapType::get(fir::BoxValue(addr).getBaseTy());
232-
mlir::Value heapVal =
233-
builder.createConvert(loc, heapType, boxAddr.getResult());
234-
builder.create<fir::FreeMemOp>(loc, heapVal);
235-
});
236-
rewriter.eraseOp(copyOut);
237-
238-
mlir::Value tempBox = copyIn.getTempBox();
239-
240-
rewriter.replaceOp(copyIn, {addr, builder.genNot(loc, isContiguous)});
241-
242-
// The TempBox is only needed for flang-rt calls which we're no longer
243-
// generating. It should have no uses left at this stage.
244-
if (!tempBox.getUses().empty())
245-
return mlir::failure();
246-
rewriter.eraseOp(tempBox.getDefiningOp());
247-
248-
return mlir::success();
249-
}
250-
251130
class InlineHLFIRAssignPass
252131
: public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> {
253132
public:
@@ -261,7 +140,6 @@ class InlineHLFIRAssignPass
261140

262141
mlir::RewritePatternSet patterns(context);
263142
patterns.insert<InlineHLFIRAssignConversion>(context);
264-
patterns.insert<InlineCopyInConversion>(context);
265143

266144
if (mlir::failed(mlir::applyPatternsGreedily(
267145
getOperation(), std::move(patterns), config))) {
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
//===- InlineHLFIRCopyIn.cpp - Inline hlfir.copy_in ops -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Transform hlfir.copy_in array operations into loop nests that assign element
9+
// by element. For simplicity, the inlining is only done for trivial data types,
10+
// when the paired copy_out has nothing to copy back (e.g. intent(in)), and when
11+
// the input array is not behind a pointer. This may change in the future.
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "flang/Optimizer/Builder/FIRBuilder.h"
15+
#include "flang/Optimizer/Builder/HLFIRTools.h"
16+
#include "flang/Optimizer/Dialect/FIRType.h"
17+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
18+
#include "flang/Optimizer/OpenMP/Passes.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Support/LLVM.h"
21+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
23+
namespace hlfir {
24+
#define GEN_PASS_DEF_INLINEHLFIRCOPYIN
25+
#include "flang/Optimizer/HLFIR/Passes.h.inc"
26+
} // namespace hlfir
27+
28+
#define DEBUG_TYPE "inline-hlfir-copy-in"
29+
30+
/// Command-line switch `-no-inline-hlfir-copy-in`. It defaults to false; when
/// set, InlineHLFIRCopyInPass does not register the copy_in inlining pattern.
static llvm::cl::opt<bool> noInlineHLFIRCopyIn(
    "no-inline-hlfir-copy-in", llvm::cl::init(false),
    llvm::cl::desc("Do not inline hlfir.copy_in operations"));
34+
35+
namespace {
36+
/// Rewrite pattern anchored on hlfir.copy_in. The matching and rewriting logic
/// is defined out of line in InlineCopyInConversion::matchAndRewrite.
class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> {
public:
  // Inherit the base pattern constructors (context, benefit, ...).
  using OpRewritePattern::OpRewritePattern;

  /// Try to replace \p copyIn with inlined IR built through \p rewriter.
  llvm::LogicalResult
  matchAndRewrite(hlfir::CopyInOp copyIn,
                  mlir::PatternRewriter &rewriter) const override;
};
44+
45+
/// Inline an hlfir.copy_in whose paired hlfir.copy_out has nothing to copy
/// back.
///
/// The match is rejected, leaving the IR untouched, unless all of these hold:
///  - the Fortran element type of the input variable is trivial;
///  - the input variable is not a pointer;
///  - WasCopied has exactly one use, and that user is an hlfir.copy_out;
///  - that copy_out has no `var` operand (nothing to copy back);
///  - TempBox is the single result of an operation and is used only by this
///    copy_in / copy_out pair.
///
/// On success the copy_in is replaced by a contiguity test on the input
/// followed by an if/else yielding (address, needs-cleanup):
///  - contiguous: the input itself and `false`;
///  - otherwise: a temporary created from the input as mold, filled by an
///    unordered loop nest of element assignments, boxed if necessary, together
///    with the cleanup flag returned by createTempFromMold.
/// The copy_out is replaced by a conditional free of the temporary, and the
/// operation defining TempBox is erased.
///
/// \param copyIn   the hlfir.copy_in operation to inline.
/// \param rewriter pattern rewriter driving the transformation.
/// \returns success if the IR was rewritten, failure (with a match-failure
///          note) if any precondition above does not hold.
llvm::LogicalResult
InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn,
                                        mlir::PatternRewriter &rewriter) const {
  fir::FirOpBuilder builder(rewriter, copyIn.getOperation());
  mlir::Location loc = copyIn.getLoc();
  hlfir::Entity inputVariable{copyIn.getVar()};
  if (!fir::isa_trivial(inputVariable.getFortranElementType()))
    return rewriter.notifyMatchFailure(copyIn,
                                       "CopyInOp's data type is not trivial");

  if (fir::isPointerType(inputVariable.getType()))
    return rewriter.notifyMatchFailure(
        copyIn, "CopyInOp's input variable is a pointer");

  // There should be exactly one user of WasCopied - the corresponding
  // CopyOutOp. Enforce "exactly one" rather than "at least one": the rewrite
  // below erases that single user and must not leave other users behind.
  mlir::Value wasCopied = copyIn.getWasCopied();
  if (!wasCopied.hasOneUse())
    return rewriter.notifyMatchFailure(
        copyIn, "CopyInOp's WasCopied does not have exactly one use");

  // The copy out should always be present, either to actually copy or just
  // deallocate memory.
  auto copyOut = mlir::dyn_cast<hlfir::CopyOutOp>(*wasCopied.user_begin());
  if (!copyOut)
    return rewriter.notifyMatchFailure(copyIn,
                                       "CopyInOp has no direct CopyOut");

  // Only inline the copy_in when copy_out does not need to be done, i.e. in
  // case of intent(in).
  if (copyOut.getVar())
    return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out");

  // The TempBox is only needed for flang-rt calls which we're no longer
  // generating, so its defining operation is erased at the end. Validate that
  // this is possible *before* touching the IR: a pattern must not report
  // failure after it has already modified the IR.
  mlir::Value tempBox = copyIn.getTempBox();
  mlir::Operation *tempBoxDef = tempBox.getDefiningOp();
  if (!tempBoxDef || tempBoxDef->getNumResults() != 1)
    return rewriter.notifyMatchFailure(
        copyIn, "CopyInOp's TempBox is not the single result of an operation");
  for (mlir::Operation *user : tempBox.getUsers())
    if (user != copyIn.getOperation() && user != copyOut.getOperation())
      return rewriter.notifyMatchFailure(
          copyIn, "CopyInOp's TempBox has users other than the "
                  "CopyIn/CopyOut pair");

  // All preconditions hold; IR modifications start here.
  inputVariable =
      hlfir::derefPointersAndAllocatables(loc, builder, inputVariable);
  mlir::Type resultAddrType = copyIn.getCopiedIn().getType();
  mlir::Value isContiguous =
      builder.create<fir::IsContiguousBoxOp>(loc, inputVariable);
  mlir::Operation::result_range results =
      builder
          .genIfOp(loc, {resultAddrType, builder.getI1Type()}, isContiguous,
                   /*withElseRegion=*/true)
          .genThen([&]() {
            // Contiguous input: use it directly, nothing to clean up.
            mlir::Value falseVal = builder.create<mlir::arith::ConstantOp>(
                loc, builder.getI1Type(), builder.getBoolAttr(false));
            builder.create<fir::ResultOp>(
                loc, mlir::ValueRange{inputVariable, falseVal});
          })
          .genElse([&] {
            // Non-contiguous input: copy it element by element into a
            // temporary shaped like the input.
            auto [temp, cleanup] =
                hlfir::createTempFromMold(loc, builder, inputVariable);
            mlir::Value shape = hlfir::genShape(loc, builder, inputVariable);
            llvm::SmallVector<mlir::Value> extents =
                hlfir::getIndexExtents(loc, builder, shape);
            hlfir::LoopNest loopNest = hlfir::genLoopNest(
                loc, builder, extents, /*isUnordered=*/true,
                flangomp::shouldUseWorkshareLowering(copyIn));
            builder.setInsertionPointToStart(loopNest.body);
            hlfir::Entity elem = hlfir::getElementAt(
                loc, builder, inputVariable, loopNest.oneBasedIndices);
            elem = hlfir::loadTrivialScalar(loc, builder, elem);
            hlfir::Entity tempElem = hlfir::getElementAt(
                loc, builder, temp, loopNest.oneBasedIndices);
            builder.create<hlfir::AssignOp>(loc, elem, tempElem);
            builder.setInsertionPointAfter(loopNest.outerOp);

            mlir::Value result;
            // Make sure the result is always a boxed array by boxing it
            // ourselves if need be.
            if (mlir::isa<fir::BaseBoxType>(temp.getType())) {
              result = temp;
            } else {
              fir::ReferenceType refTy =
                  fir::ReferenceType::get(temp.getElementOrSequenceType());
              mlir::Value refVal = builder.createConvert(loc, refTy, temp);
              result =
                  builder.create<fir::EmboxOp>(loc, resultAddrType, refVal);
            }

            builder.create<fir::ResultOp>(loc,
                                          mlir::ValueRange{result, cleanup});
          })
          .getResults();

  mlir::OpResult addr = results[0];
  mlir::OpResult needsCleanup = results[1];

  // Where the copy_out used to be, free the temporary if one was created.
  builder.setInsertionPoint(copyOut);
  builder.genIfOp(loc, {}, needsCleanup, /*withElseRegion=*/false).genThen([&] {
    auto boxAddr = builder.create<fir::BoxAddrOp>(loc, addr);
    fir::HeapType heapType =
        fir::HeapType::get(fir::BoxValue(addr).getBaseTy());
    mlir::Value heapVal =
        builder.createConvert(loc, heapType, boxAddr.getResult());
    builder.create<fir::FreeMemOp>(loc, heapVal);
  });
  rewriter.eraseOp(copyOut);

  rewriter.replaceOp(copyIn, {addr, builder.genNot(loc, isContiguous)});

  // Both users of TempBox (copy_in and copy_out) are gone now, as verified
  // before rewriting, so its defining operation can be removed safely.
  rewriter.eraseOp(tempBoxDef);

  return mlir::success();
}
155+
156+
class InlineHLFIRCopyInPass
157+
: public hlfir::impl::InlineHLFIRCopyInBase<InlineHLFIRCopyInPass> {
158+
public:
159+
void runOnOperation() override {
160+
mlir::MLIRContext *context = &getContext();
161+
162+
mlir::GreedyRewriteConfig config;
163+
// Prevent the pattern driver from merging blocks.
164+
config.setRegionSimplificationLevel(
165+
mlir::GreedySimplifyRegionLevel::Disabled);
166+
167+
mlir::RewritePatternSet patterns(context);
168+
if (!noInlineHLFIRCopyIn) {
169+
patterns.insert<InlineCopyInConversion>(context);
170+
}
171+
172+
if (mlir::failed(mlir::applyPatternsGreedily(
173+
getOperation(), std::move(patterns), config))) {
174+
mlir::emitError(getOperation()->getLoc(),
175+
"failure in hlfir.copy_in inlining");
176+
signalPassFailure();
177+
}
178+
}
179+
};
180+
} // namespace

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
255255
pm, hlfir::createOptimizedBufferization);
256256
addNestedPassToAllTopLevelOperations<PassConstructor>(
257257
pm, hlfir::createInlineHLFIRAssign);
258+
259+
if (optLevel == llvm::OptimizationLevel::O3) {
260+
addNestedPassToAllTopLevelOperations<PassConstructor>(
261+
pm, hlfir::createInlineHLFIRCopyIn);
262+
}
258263
}
259264
pm.addPass(hlfir::createLowerHLFIROrderedAssignments());
260265
pm.addPass(hlfir::createLowerHLFIRIntrinsics());

0 commit comments

Comments
 (0)