@@ -99,8 +99,8 @@ class SimplifyIntrinsicsPass
99
99
void simplifyLogicalDim1Reduction (fir::CallOp call,
100
100
const fir::KindMapping &kindMap,
101
101
GenReductionBodyTy genBodyFunc);
102
- void simplifyMinlocReduction (fir::CallOp call,
103
- const fir::KindMapping &kindMap);
102
+ void simplifyMinMaxlocReduction (fir::CallOp call,
103
+ const fir::KindMapping &kindMap, bool isMax );
104
104
void simplifyReductionBody (fir::CallOp call, const fir::KindMapping &kindMap,
105
105
GenReductionBodyTy genBodyFunc,
106
106
fir::FirOpBuilder &builder,
@@ -353,16 +353,15 @@ genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
353
353
// Return the reduction value from the function.
354
354
builder.create <mlir::func::ReturnOp>(loc, results[resultIndex]);
355
355
}
356
- using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
356
+ using MinMaxlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
357
357
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
358
358
mlir::Value, llvm::SmallVector<mlir::Value, Fortran::common::maxRank> &)>;
359
359
360
- static void
361
- genMinlocReductionLoop (fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
362
- InitValGeneratorTy initVal,
363
- MinlocBodyOpGeneratorTy genBody, unsigned rank,
364
- mlir::Type elementType, mlir::Location loc, bool hasMask,
365
- mlir::Type maskElemType, mlir::Value resultArr) {
360
+ static void genMinMaxlocReductionLoop (
361
+ fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
362
+ InitValGeneratorTy initVal, MinMaxlocBodyOpGeneratorTy genBody,
363
+ unsigned rank, mlir::Type elementType, mlir::Location loc, bool hasMask,
364
+ mlir::Type maskElemType, mlir::Value resultArr) {
366
365
367
366
mlir::IndexType idxTy = builder.getIndexType ();
368
367
@@ -751,21 +750,24 @@ static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
751
750
{boxRefType, boxType, boxType}, {});
752
751
}
753
752
754
- static void genRuntimeMinlocBody (fir::FirOpBuilder &builder,
755
- mlir::func::FuncOp &funcOp, unsigned rank,
756
- int maskRank, mlir::Type elementType,
757
- mlir::Type maskElemType,
758
- mlir::Type resultElemTy) {
759
- auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
760
- mlir::Type elementType) {
753
+ static void genRuntimeMinMaxlocBody (fir::FirOpBuilder &builder,
754
+ mlir::func::FuncOp &funcOp, bool isMax,
755
+ unsigned rank, int maskRank,
756
+ mlir::Type elementType,
757
+ mlir::Type maskElemType,
758
+ mlir::Type resultElemTy) {
759
+ auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
760
+ mlir::Type elementType) {
761
761
if (auto ty = elementType.dyn_cast <mlir::FloatType>()) {
762
762
const llvm::fltSemantics &sem = ty.getFloatSemantics ();
763
763
return builder.createRealConstant (
764
- loc, elementType, llvm::APFloat::getLargest (sem, /* Negative=*/ false ));
764
+ loc, elementType, llvm::APFloat::getLargest (sem, /* Negative=*/ isMax ));
765
765
}
766
766
unsigned bits = elementType.getIntOrFloatBitWidth ();
767
- int64_t maxInt = llvm::APInt::getSignedMaxValue (bits).getSExtValue ();
768
- return builder.createIntegerConstant (loc, elementType, maxInt);
767
+ int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue (bits)
768
+ : llvm::APInt::getSignedMaxValue (bits))
769
+ .getSExtValue ();
770
+ return builder.createIntegerConstant (loc, elementType, initValue);
769
771
};
770
772
771
773
mlir::Location loc = mlir::UnknownLoc::get (builder.getContext ());
@@ -797,18 +799,24 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
797
799
}
798
800
799
801
auto genBodyOp =
800
- [&rank, &resultArr](
801
- fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType ,
802
- mlir::Value elem1, mlir::Value elem2,
803
- llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
802
+ [&rank, &resultArr,
803
+ isMax]( fir::FirOpBuilder builder, mlir::Location loc,
804
+ mlir::Type elementType, mlir::Value elem1, mlir::Value elem2,
805
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices)
804
806
-> mlir::Value {
805
807
mlir::Value cmp;
806
808
if (elementType.isa <mlir::FloatType>()) {
807
809
cmp = builder.create <mlir::arith::CmpFOp>(
808
- loc, mlir::arith::CmpFPredicate::OLT, elem1, elem2);
810
+ loc,
811
+ isMax ? mlir::arith::CmpFPredicate::OGT
812
+ : mlir::arith::CmpFPredicate::OLT,
813
+ elem1, elem2);
809
814
} else if (elementType.isa <mlir::IntegerType>()) {
810
815
cmp = builder.create <mlir::arith::CmpIOp>(
811
- loc, mlir::arith::CmpIPredicate::slt, elem1, elem2);
816
+ loc,
817
+ isMax ? mlir::arith::CmpIPredicate::sgt
818
+ : mlir::arith::CmpIPredicate::slt,
819
+ elem1, elem2);
812
820
} else {
813
821
llvm_unreachable (" unsupported type" );
814
822
}
@@ -875,9 +883,8 @@ static void genRuntimeMinlocBody(fir::FirOpBuilder &builder,
875
883
// bit of a hack - maskRank is set to -1 for absent mask arg, so don't
876
884
// generate high level mask or element by element mask.
877
885
bool hasMask = maskRank > 0 ;
878
-
879
- genMinlocReductionLoop (builder, funcOp, init, genBodyOp, rank, elementType,
880
- loc, hasMask, maskElemType, resultArr);
886
+ genMinMaxlocReductionLoop (builder, funcOp, init, genBodyOp, rank, elementType,
887
+ loc, hasMask, maskElemType, resultArr);
881
888
}
882
889
883
890
/// Generate function type for the simplified version of RTNAME(DotProduct)
@@ -1150,8 +1157,8 @@ void SimplifyIntrinsicsPass::simplifyLogicalDim1Reduction(
1150
1157
intElementType);
1151
1158
}
1152
1159
1153
- void SimplifyIntrinsicsPass::simplifyMinlocReduction (
1154
- fir::CallOp call, const fir::KindMapping &kindMap) {
1160
+ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction (
1161
+ fir::CallOp call, const fir::KindMapping &kindMap, bool isMax ) {
1155
1162
1156
1163
mlir::Operation::operand_range args = call.getArgs ();
1157
1164
@@ -1217,11 +1224,11 @@ void SimplifyIntrinsicsPass::simplifyMinlocReduction(
1217
1224
auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
1218
1225
return genRuntimeMinlocType (builder, rank);
1219
1226
};
1220
- auto bodyGenerator = [rank, maskRank, inputType, logicalElemType,
1221
- outType ](fir::FirOpBuilder &builder,
1222
- mlir::func::FuncOp &funcOp) {
1223
- genRuntimeMinlocBody (builder, funcOp, rank, maskRank, inputType,
1224
- logicalElemType, outType);
1227
+ auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
1228
+ isMax ](fir::FirOpBuilder &builder,
1229
+ mlir::func::FuncOp &funcOp) {
1230
+ genRuntimeMinMaxlocBody (builder, funcOp, isMax , rank, maskRank, inputType,
1231
+ logicalElemType, outType);
1225
1232
};
1226
1233
1227
1234
mlir::func::FuncOp newFunc =
@@ -1367,7 +1374,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
1367
1374
return ;
1368
1375
}
1369
1376
if (funcName.starts_with (RTNAME_STRING (Minloc))) {
1370
- simplifyMinlocReduction (call, kindMap);
1377
+ simplifyMinMaxlocReduction (call, kindMap, false );
1378
+ return ;
1379
+ }
1380
+ if (funcName.starts_with (RTNAME_STRING (Maxloc))) {
1381
+ simplifyMinMaxlocReduction (call, kindMap, true );
1371
1382
return ;
1372
1383
}
1373
1384
}
0 commit comments