Skip to content

[flang][hlfir] optimize hlfir.eval_in_mem bufferization #118069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ struct AliasAnalysis {
/// Return the modify-reference behavior of `op` on `location`.
mlir::ModRefResult getModRef(mlir::Operation *op, mlir::Value location);

/// Return the modify-reference behavior of operations inside `region` on
/// `location`. Contrary to getModRef(operation, location), this will visit
/// nested regions recursively according to the HasRecursiveMemoryEffects
/// trait.
mlir::ModRefResult getModRef(mlir::Region &region, mlir::Value location);

/// Return the memory source of a value.
/// If getLastInstantiationPoint is true, the search for the source
/// will stop at [hl]fir.declare if it represents a dummy
Expand Down
41 changes: 40 additions & 1 deletion flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ bool AliasAnalysis::Source::isDummyArgument() const {
return false;
}

static bool isEvaluateInMemoryBlockArg(mlir::Value v) {
if (auto evalInMem = llvm::dyn_cast_or_null<hlfir::EvaluateInMemoryOp>(
v.getParentRegion()->getParentOp()))
return evalInMem.getMemory() == v;
return false;
}

bool AliasAnalysis::Source::isData() const { return origin.isData; }
bool AliasAnalysis::Source::isBoxData() const {
return mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(valueType)) &&
Expand Down Expand Up @@ -457,6 +464,33 @@ ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
return result;
}

ModRefResult AliasAnalysis::getModRef(mlir::Region &region,
mlir::Value location) {
ModRefResult result = ModRefResult::getNoModRef();
for (mlir::Operation &op : region.getOps()) {
if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
for (mlir::Region &subRegion : op.getRegions()) {
result = result.merge(getModRef(subRegion, location));
// Fast return is already mod and ref.
if (result.isModAndRef())
return result;
}
// In MLIR, RecursiveMemoryEffects can be combined with
// MemoryEffectOpInterface to describe extra effects on top of the
// effects of the nested operations. However, the presence of
// RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
// implies the operation has no other memory effects than the one of its
// nested operations.
if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
continue;
}
result = result.merge(getModRef(&op, location));
if (result.isModAndRef())
return result;
}
return result;
}

AliasAnalysis::Source::Attributes
getAttrsFromVariable(fir::FortranVariableOpInterface var) {
AliasAnalysis::Source::Attributes attrs;
Expand Down Expand Up @@ -698,7 +732,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
breakFromLoop = true;
});
}
if (!defOp && type == SourceKind::Unknown)
if (!defOp && type == SourceKind::Unknown) {
// Check if the memory source is coming through a dummy argument.
if (isDummyArgument(v)) {
type = SourceKind::Argument;
Expand All @@ -708,7 +742,12 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,

if (isPointerReference(ty))
attributes.set(Attribute::Pointer);
} else if (isEvaluateInMemoryBlockArg(v)) {
// hlfir.eval_in_mem block operands is allocated by the operation.
type = SourceKind::Allocate;
ty = v.getType();
}
}

if (type == SourceKind::Global) {
return {{global, instantiationPoint, followingData},
Expand Down
95 changes: 95 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,100 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
}
};

class EvaluateIntoMemoryAssignBufferization
: public mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp> {

public:
using mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp>::OpRewritePattern;

llvm::LogicalResult
matchAndRewrite(hlfir::EvaluateInMemoryOp,
mlir::PatternRewriter &rewriter) const override;
};

static llvm::LogicalResult
tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
mlir::PatternRewriter &rewriter) {
mlir::Location loc = evalInMem.getLoc();
hlfir::DestroyOp destroy;
hlfir::AssignOp assign;
for (auto user : llvm::enumerate(evalInMem->getUsers())) {
if (user.index() > 2)
return mlir::failure();
mlir::TypeSwitch<mlir::Operation *, void>(user.value())
.Case([&](hlfir::AssignOp op) { assign = op; })
.Case([&](hlfir::DestroyOp op) { destroy = op; });
}
if (!assign || !destroy || destroy.mustFinalizeExpr() ||
assign.isAllocatableAssignment())
return mlir::failure();

hlfir::Entity lhs{assign.getLhs()};
// EvaluateInMemoryOp memory is contiguous, so in general, it can only be
// replace by the LHS if the LHS is contiguous.
if (!lhs.isSimplyContiguous())
return mlir::failure();
// Character assignment may involves truncation/padding, so the LHS
// cannot be used to evaluate RHS in place without proving the LHS and
// RHS lengths are the same.
if (lhs.isCharacter())
return mlir::failure();
fir::AliasAnalysis aliasAnalysis;
// The region must not read or write the LHS.
// Note that getModRef is used instead of mlir::MemoryEffects because
// EvaluateInMemoryOp is typically expected to hold fir.calls and that
// Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects:
// it is hard/impossible to list all the read/written SSA values in a call,
// but it is often possible to tell that an SSA value cannot be accessed,
// hence getModRef is needed here and below. Also note that getModRef uses
// mlir::MemoryEffects for operations that do not have special handling in
// getModRef.
if (aliasAnalysis.getModRef(evalInMem.getBody(), lhs).isModOrRef())
return mlir::failure();
// Any variables affected between the hlfir.evalInMem and assignment must not
// be read or written inside the region since it will be moved at the
// assignment insertion point.
auto effects = getEffectsBetween(evalInMem->getNextNode(), assign);
if (!effects) {
LLVM_DEBUG(
llvm::dbgs()
<< "operation with unknown effects between eval_in_mem and assign\n");
return mlir::failure();
}
for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
mlir::Value affected = effect.getValue();
if (!affected ||
aliasAnalysis.getModRef(evalInMem.getBody(), affected).isModOrRef())
return mlir::failure();
}

rewriter.setInsertionPoint(assign);
fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
mlir::Value rawLhs = hlfir::genVariableRawAddress(loc, builder, lhs);
hlfir::computeEvaluateOpIn(loc, builder, evalInMem, rawLhs);
rewriter.eraseOp(assign);
rewriter.eraseOp(destroy);
rewriter.eraseOp(evalInMem);
return mlir::success();
}

llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite(
hlfir::EvaluateInMemoryOp evalInMem,
mlir::PatternRewriter &rewriter) const {
if (mlir::succeeded(tryUsingAssignLhsDirectly(evalInMem, rewriter)))
return mlir::success();
// Rewrite to temp + as_expr here so that the assign + as_expr pattern can
// kick-in for simple types and at least implement the assignment inline
// instead of call Assign runtime.
fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
mlir::Location loc = evalInMem.getLoc();
auto [temp, isHeapAllocated] = hlfir::computeEvaluateOpInNewTemp(
loc, builder, evalInMem, evalInMem.getShape(), evalInMem.getTypeparams());
rewriter.replaceOpWithNewOp<hlfir::AsExprOp>(
evalInMem, temp, /*mustFree=*/builder.createBool(loc, isHeapAllocated));
return mlir::success();
}

class OptimizedBufferizationPass
: public hlfir::impl::OptimizedBufferizationBase<
OptimizedBufferizationPass> {
Expand All @@ -1130,6 +1224,7 @@ class OptimizedBufferizationPass
patterns.insert<ElementalAssignBufferization>(context);
patterns.insert<BroadcastAssignBufferization>(context);
patterns.insert<VariableAssignBufferization>(context);
patterns.insert<EvaluateIntoMemoryAssignBufferization>(context);
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
Expand Down
67 changes: 67 additions & 0 deletions flang/test/HLFIR/opt-bufferization-eval_in_mem.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: fir-opt --opt-bufferization %s | FileCheck %s

// Fortran F2023 15.5.2.14 point 4. ensures that _QPfoo cannot access _QFtestEx
// and the temporary storage for the result can be avoided.
func.func @_QPtest(%arg0: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x"}) {
%c10 = arith.constant 10 : index
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
%2:2 = hlfir.declare %arg0(%1) dummy_scope %0 {uniq_name = "_QFtestEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
%3 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
^bb0(%arg1: !fir.ref<!fir.array<10xf32>>):
%4 = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
fir.save_result %4 to %arg1(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
}
hlfir.assign %3 to %2#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
hlfir.destroy %3 : !hlfir.expr<10xf32>
return
}
func.func private @_QPfoo() -> !fir.array<10xf32>

// CHECK-LABEL: func.func @_QPtest(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x"}) {
// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index
// CHECK: %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) dummy_scope %[[VAL_2]] {uniq_name = "_QFtestEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
// CHECK: %[[VAL_5:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
// CHECK: return
// CHECK: }


// Temporary storage cannot be avoided in this case since
// _QFnegative_test_is_targetEx has the TARGET attribute.
func.func @_QPnegative_test_is_target(%arg0: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x", fir.target}) {
%c10 = arith.constant 10 : index
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
%2:2 = hlfir.declare %arg0(%1) dummy_scope %0 {fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFnegative_test_is_targetEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
%3 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
^bb0(%arg1: !fir.ref<!fir.array<10xf32>>):
%4 = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
fir.save_result %4 to %arg1(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
}
hlfir.assign %3 to %2#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
hlfir.destroy %3 : !hlfir.expr<10xf32>
return
}
// CHECK-LABEL: func.func @_QPnegative_test_is_target(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x", fir.target}) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant false
// CHECK: %[[VAL_3:.*]] = arith.constant 10 : index
// CHECK: %[[VAL_4:.*]] = fir.alloca !fir.array<10xf32>
// CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]]{{.*}}
// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_4]]{{.*}}
// CHECK: %[[VAL_9:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
// CHECK: fir.save_result %[[VAL_9]] to %[[VAL_8]]#1{{.*}}
// CHECK: %[[VAL_10:.*]] = hlfir.as_expr %[[VAL_8]]#0 move %[[VAL_2]] : (!fir.ref<!fir.array<10xf32>>, i1) -> !hlfir.expr<10xf32>
// CHECK: fir.do_loop %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_1]] unordered {
// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_10]], %[[VAL_11]] : (!hlfir.expr<10xf32>, index) -> f32
// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_11]]) : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] : f32, !fir.ref<f32>
// CHECK: }
// CHECK: hlfir.destroy %[[VAL_10]] : !hlfir.expr<10xf32>
// CHECK: return
// CHECK: }
Loading