Skip to content

Commit 9bb47f7

Browse files
authored
[Flang] Add Maxloc to fir simplify intrinsics pass (llvm#75463)
This takes the code from D144103 and extends it to maxloc, to allow the simplifyMinMaxlocReduction method to work with both min and max intrinsics by switching condition and limit/initial value.
1 parent 111a229 commit 9bb47f7

File tree

2 files changed

+293
-36
lines changed

2 files changed

+293
-36
lines changed

flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ class SimplifyIntrinsicsPass
9999
void simplifyLogicalDim1Reduction(fir::CallOp call,
100100
const fir::KindMapping &kindMap,
101101
GenReductionBodyTy genBodyFunc);
102-
void simplifyMinlocReduction(fir::CallOp call,
103-
const fir::KindMapping &kindMap);
102+
void simplifyMinMaxlocReduction(fir::CallOp call,
103+
const fir::KindMapping &kindMap, bool isMax);
104104
void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap,
105105
GenReductionBodyTy genBodyFunc,
106106
fir::FirOpBuilder &builder,
@@ -353,16 +353,15 @@ genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
353353
// Return the reduction value from the function.
354354
builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
355355
}
356-
using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
356+
using MinMaxlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
357357
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
358358
mlir::Value, llvm::SmallVector<mlir::Value, Fortran::common::maxRank> &)>;
359359

360-
static void
361-
genMinlocReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
362-
InitValGeneratorTy initVal,
363-
MinlocBodyOpGeneratorTy genBody, unsigned rank,
364-
mlir::Type elementType, mlir::Location loc, bool hasMask,
365-
mlir::Type maskElemType, mlir::Value resultArr) {
360+
static void genMinMaxlocReductionLoop(
361+
fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
362+
InitValGeneratorTy initVal, MinMaxlocBodyOpGeneratorTy genBody,
363+
unsigned rank, mlir::Type elementType, mlir::Location loc, bool hasMask,
364+
mlir::Type maskElemType, mlir::Value resultArr) {
366365

367366
mlir::IndexType idxTy = builder.getIndexType();
368367

@@ -751,21 +750,24 @@ static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
751750
{boxRefType, boxType, boxType}, {});
752751
}
753752

754-
static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
755-
mlir::func::FuncOp &funcOp, unsigned rank,
756-
int maskRank, mlir::Type elementType,
757-
mlir::Type maskElemType,
758-
mlir::Type resultElemTy) {
759-
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
760-
mlir::Type elementType) {
753+
static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
754+
mlir::func::FuncOp &funcOp, bool isMax,
755+
unsigned rank, int maskRank,
756+
mlir::Type elementType,
757+
mlir::Type maskElemType,
758+
mlir::Type resultElemTy) {
759+
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
760+
mlir::Type elementType) {
761761
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
762762
const llvm::fltSemantics &sem = ty.getFloatSemantics();
763763
return builder.createRealConstant(
764-
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/false));
764+
loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
765765
}
766766
unsigned bits = elementType.getIntOrFloatBitWidth();
767-
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
768-
return builder.createIntegerConstant(loc, elementType, maxInt);
767+
int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
768+
: llvm::APInt::getSignedMaxValue(bits))
769+
.getSExtValue();
770+
return builder.createIntegerConstant(loc, elementType, initValue);
769771
};
770772

771773
mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
@@ -797,18 +799,24 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
797799
}
798800

799801
auto genBodyOp =
800-
[&rank, &resultArr](
801-
fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
802-
mlir::Value elem1, mlir::Value elem2,
803-
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
802+
[&rank, &resultArr,
803+
isMax](fir::FirOpBuilder builder, mlir::Location loc,
804+
mlir::Type elementType, mlir::Value elem1, mlir::Value elem2,
805+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
804806
-> mlir::Value {
805807
mlir::Value cmp;
806808
if (elementType.isa<mlir::FloatType>()) {
807809
cmp = builder.create<mlir::arith::CmpFOp>(
808-
loc, mlir::arith::CmpFPredicate::OLT, elem1, elem2);
810+
loc,
811+
isMax ? mlir::arith::CmpFPredicate::OGT
812+
: mlir::arith::CmpFPredicate::OLT,
813+
elem1, elem2);
809814
} else if (elementType.isa<mlir::IntegerType>()) {
810815
cmp = builder.create<mlir::arith::CmpIOp>(
811-
loc, mlir::arith::CmpIPredicate::slt, elem1, elem2);
816+
loc,
817+
isMax ? mlir::arith::CmpIPredicate::sgt
818+
: mlir::arith::CmpIPredicate::slt,
819+
elem1, elem2);
812820
} else {
813821
llvm_unreachable("unsupported type");
814822
}
@@ -875,9 +883,8 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
875883
// bit of a hack - maskRank is set to -1 for absent mask arg, so don't
876884
// generate high level mask or element by element mask.
877885
bool hasMask = maskRank > 0;
878-
879-
genMinlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
880-
loc, hasMask, maskElemType, resultArr);
886+
genMinMaxlocReductionLoop(builder, funcOp, init, genBodyOp, rank, elementType,
887+
loc, hasMask, maskElemType, resultArr);
881888
}
882889

883890
/// Generate function type for the simplified version of RTNAME(DotProduct)
@@ -1150,8 +1157,8 @@ void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
11501157
intElementType);
11511158
}
11521159

1153-
void SimplifyIntrinsicsPass::simplifyMinlocReduction(
1154-
fir::CallOp call, const fir::KindMapping &kindMap) {
1160+
void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
1161+
fir::CallOp call, const fir::KindMapping &kindMap, bool isMax) {
11551162

11561163
mlir::Operation::operand_range args = call.getArgs();
11571164

@@ -1217,11 +1224,11 @@ void SimplifyIntrinsicsPass::simplifyMinlocReduction(
12171224
auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
12181225
return genRuntimeMinlocType(builder, rank);
12191226
};
1220-
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType,
1221-
outType](fir::FirOpBuilder &builder,
1222-
mlir::func::FuncOp &funcOp) {
1223-
genRuntimeMinlocBody(builder, funcOp, rank, maskRank, inputType,
1224-
logicalElemType, outType);
1227+
auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1228+
isMax](fir::FirOpBuilder &builder,
1229+
mlir::func::FuncOp &funcOp) {
1230+
genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
1231+
logicalElemType, outType);
12251232
};
12261233

12271234
mlir::func::FuncOp newFunc =
@@ -1367,7 +1374,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
13671374
return;
13681375
}
13691376
if (funcName.starts_with(RTNAME_STRING(Minloc))) {
1370-
simplifyMinlocReduction(call, kindMap);
1377+
simplifyMinMaxlocReduction(call, kindMap, false);
1378+
return;
1379+
}
1380+
if (funcName.starts_with(RTNAME_STRING(Maxloc))) {
1381+
simplifyMinMaxlocReduction(call, kindMap, true);
13711382
return;
13721383
}
13731384
}

0 commit comments

Comments
 (0)