|
| 1 | +//===- InlineHLFIRCopyIn.cpp - Inline hlfir.copy_in ops -------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// Transform hlfir.copy_in array operations into loop nests performing element |
| 9 | +// per element assignments. For simplicity, the inlining is done for trivial |
| 10 | +// data types when the copy_in does not require a corresponding copy_out and |
| 11 | +// when the input array is not behind a pointer. This may change in the future. |
| 12 | +//===----------------------------------------------------------------------===// |
| 13 | + |
| 14 | +#include "flang/Optimizer/Builder/FIRBuilder.h" |
| 15 | +#include "flang/Optimizer/Builder/HLFIRTools.h" |
| 16 | +#include "flang/Optimizer/Dialect/FIRType.h" |
| 17 | +#include "flang/Optimizer/HLFIR/HLFIROps.h" |
| 18 | +#include "flang/Optimizer/OpenMP/Passes.h" |
| 19 | +#include "mlir/IR/PatternMatch.h" |
| 20 | +#include "mlir/Support/LLVM.h" |
| 21 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 22 | + |
| 23 | +namespace hlfir { |
| 24 | +#define GEN_PASS_DEF_INLINEHLFIRCOPYIN |
| 25 | +#include "flang/Optimizer/HLFIR/Passes.h.inc" |
| 26 | +} // namespace hlfir |
| 27 | + |
| 28 | +#define DEBUG_TYPE "inline-hlfir-copy-in" |
| 29 | + |
| 30 | +static llvm::cl::opt<bool> noInlineHLFIRCopyIn( |
| 31 | + "no-inline-hlfir-copy-in", |
| 32 | + llvm::cl::desc("Do not inline hlfir.copy_in operations"), |
| 33 | + llvm::cl::init(false)); |
| 34 | + |
| 35 | +namespace { |
| 36 | +class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> { |
| 37 | +public: |
| 38 | + using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern; |
| 39 | + |
| 40 | + llvm::LogicalResult |
| 41 | + matchAndRewrite(hlfir::CopyInOp copyIn, |
| 42 | + mlir::PatternRewriter &rewriter) const override; |
| 43 | +}; |
| 44 | + |
| 45 | +llvm::LogicalResult |
| 46 | +InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn, |
| 47 | + mlir::PatternRewriter &rewriter) const { |
| 48 | + fir::FirOpBuilder builder(rewriter, copyIn.getOperation()); |
| 49 | + mlir::Location loc = copyIn.getLoc(); |
| 50 | + hlfir::Entity inputVariable{copyIn.getVar()}; |
| 51 | + if (!fir::isa_trivial(inputVariable.getFortranElementType())) |
| 52 | + return rewriter.notifyMatchFailure(copyIn, |
| 53 | + "CopyInOp's data type is not trivial"); |
| 54 | + |
| 55 | + if (fir::isPointerType(inputVariable.getType())) |
| 56 | + return rewriter.notifyMatchFailure( |
| 57 | + copyIn, "CopyInOp's input variable is a pointer"); |
| 58 | + |
| 59 | + // There should be exactly one user of WasCopied - the corresponding |
| 60 | + // CopyOutOp. |
| 61 | + if (copyIn.getWasCopied().getUses().empty()) |
| 62 | + return rewriter.notifyMatchFailure(copyIn, |
| 63 | + "CopyInOp's WasCopied has no uses"); |
| 64 | + // The copy out should always be present, either to actually copy or just |
| 65 | + // deallocate memory. |
| 66 | + auto copyOut = mlir::dyn_cast<hlfir::CopyOutOp>( |
| 67 | + copyIn.getWasCopied().getUsers().begin().getCurrent().getUser()); |
| 68 | + |
| 69 | + if (!copyOut) |
| 70 | + return rewriter.notifyMatchFailure(copyIn, |
| 71 | + "CopyInOp has no direct CopyOut"); |
| 72 | + |
| 73 | + // Only inline the copy_in when copy_out does not need to be done, i.e. in |
| 74 | + // case of intent(in). |
| 75 | + if (copyOut.getVar()) |
| 76 | + return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out"); |
| 77 | + |
| 78 | + inputVariable = |
| 79 | + hlfir::derefPointersAndAllocatables(loc, builder, inputVariable); |
| 80 | + mlir::Type resultAddrType = copyIn.getCopiedIn().getType(); |
| 81 | + mlir::Value isContiguous = |
| 82 | + builder.create<fir::IsContiguousBoxOp>(loc, inputVariable); |
| 83 | + mlir::Operation::result_range results = |
| 84 | + builder |
| 85 | + .genIfOp(loc, {resultAddrType, builder.getI1Type()}, isContiguous, |
| 86 | + /*withElseRegion=*/true) |
| 87 | + .genThen([&]() { |
| 88 | + mlir::Value falseVal = builder.create<mlir::arith::ConstantOp>( |
| 89 | + loc, builder.getI1Type(), builder.getBoolAttr(false)); |
| 90 | + builder.create<fir::ResultOp>( |
| 91 | + loc, mlir::ValueRange{inputVariable, falseVal}); |
| 92 | + }) |
| 93 | + .genElse([&] { |
| 94 | + auto [temp, cleanup] = |
| 95 | + hlfir::createTempFromMold(loc, builder, inputVariable); |
| 96 | + mlir::Value shape = hlfir::genShape(loc, builder, inputVariable); |
| 97 | + llvm::SmallVector<mlir::Value> extents = |
| 98 | + hlfir::getIndexExtents(loc, builder, shape); |
| 99 | + hlfir::LoopNest loopNest = hlfir::genLoopNest( |
| 100 | + loc, builder, extents, /*isUnordered=*/true, |
| 101 | + flangomp::shouldUseWorkshareLowering(copyIn)); |
| 102 | + builder.setInsertionPointToStart(loopNest.body); |
| 103 | + hlfir::Entity elem = hlfir::getElementAt( |
| 104 | + loc, builder, inputVariable, loopNest.oneBasedIndices); |
| 105 | + elem = hlfir::loadTrivialScalar(loc, builder, elem); |
| 106 | + hlfir::Entity tempElem = hlfir::getElementAt( |
| 107 | + loc, builder, temp, loopNest.oneBasedIndices); |
| 108 | + builder.create<hlfir::AssignOp>(loc, elem, tempElem); |
| 109 | + builder.setInsertionPointAfter(loopNest.outerOp); |
| 110 | + |
| 111 | + mlir::Value result; |
| 112 | + // Make sure the result is always a boxed array by boxing it |
| 113 | + // ourselves if need be. |
| 114 | + if (mlir::isa<fir::BaseBoxType>(temp.getType())) { |
| 115 | + result = temp; |
| 116 | + } else { |
| 117 | + fir::ReferenceType refTy = |
| 118 | + fir::ReferenceType::get(temp.getElementOrSequenceType()); |
| 119 | + mlir::Value refVal = builder.createConvert(loc, refTy, temp); |
| 120 | + result = |
| 121 | + builder.create<fir::EmboxOp>(loc, resultAddrType, refVal); |
| 122 | + } |
| 123 | + |
| 124 | + builder.create<fir::ResultOp>(loc, |
| 125 | + mlir::ValueRange{result, cleanup}); |
| 126 | + }) |
| 127 | + .getResults(); |
| 128 | + |
| 129 | + mlir::OpResult addr = results[0]; |
| 130 | + mlir::OpResult needsCleanup = results[1]; |
| 131 | + |
| 132 | + builder.setInsertionPoint(copyOut); |
| 133 | + builder.genIfOp(loc, {}, needsCleanup, /*withElseRegion=*/false).genThen([&] { |
| 134 | + auto boxAddr = builder.create<fir::BoxAddrOp>(loc, addr); |
| 135 | + fir::HeapType heapType = |
| 136 | + fir::HeapType::get(fir::BoxValue(addr).getBaseTy()); |
| 137 | + mlir::Value heapVal = |
| 138 | + builder.createConvert(loc, heapType, boxAddr.getResult()); |
| 139 | + builder.create<fir::FreeMemOp>(loc, heapVal); |
| 140 | + }); |
| 141 | + rewriter.eraseOp(copyOut); |
| 142 | + |
| 143 | + mlir::Value tempBox = copyIn.getTempBox(); |
| 144 | + |
| 145 | + rewriter.replaceOp(copyIn, {addr, builder.genNot(loc, isContiguous)}); |
| 146 | + |
| 147 | + // The TempBox is only needed for flang-rt calls which we're no longer |
| 148 | + // generating. It should have no uses left at this stage. |
| 149 | + if (!tempBox.getUses().empty()) |
| 150 | + return mlir::failure(); |
| 151 | + rewriter.eraseOp(tempBox.getDefiningOp()); |
| 152 | + |
| 153 | + return mlir::success(); |
| 154 | +} |
| 155 | + |
| 156 | +class InlineHLFIRCopyInPass |
| 157 | + : public hlfir::impl::InlineHLFIRCopyInBase<InlineHLFIRCopyInPass> { |
| 158 | +public: |
| 159 | + void runOnOperation() override { |
| 160 | + mlir::MLIRContext *context = &getContext(); |
| 161 | + |
| 162 | + mlir::GreedyRewriteConfig config; |
| 163 | + // Prevent the pattern driver from merging blocks. |
| 164 | + config.setRegionSimplificationLevel( |
| 165 | + mlir::GreedySimplifyRegionLevel::Disabled); |
| 166 | + |
| 167 | + mlir::RewritePatternSet patterns(context); |
| 168 | + if (!noInlineHLFIRCopyIn) { |
| 169 | + patterns.insert<InlineCopyInConversion>(context); |
| 170 | + } |
| 171 | + |
| 172 | + if (mlir::failed(mlir::applyPatternsGreedily( |
| 173 | + getOperation(), std::move(patterns), config))) { |
| 174 | + mlir::emitError(getOperation()->getLoc(), |
| 175 | + "failure in hlfir.copy_in inlining"); |
| 176 | + signalPassFailure(); |
| 177 | + } |
| 178 | + } |
| 179 | +}; |
| 180 | +} // namespace |
0 commit comments