Skip to content

[RFC][mlir] Conditional support for fast-math attributes. #125620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
llvm::cast<mlir::SymbolRefAttr>(callee));
setOperand(0, llvm::cast<mlir::Value>(callee));
}

/// Always allow FastMathFlags on fir.call operations.
/// This is required so that the call site's FastMathFlags can be
/// propagated to the operations resulting from inlining (if any)
/// of a fir.call (see the SimplifyIntrinsics pass).
/// We could analyze the arguments' data types to see whether any of
/// them are floating point types, but this is unreliable. For example,
/// the runtime calls mostly take !fir.box<none> arguments,
/// and tracking them back to their definitions may not be easy.
/// TODO: this should be restricted to fir.runtime calls,
/// because FastMathFlags for user calls must come
/// from the function body, not the call site.
bool isArithFastMathApplicable() {
return true;
}
}];
}

Expand Down Expand Up @@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
}

static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);

/// Always allow FastMathFlags on fir.cmpc.
/// The operation does not produce a floating point result, but
/// LLVM currently relies on the fast-math flags attached
/// to floating point comparisons.
/// This override can be removed whenever LLVM stops doing so.
bool isArithFastMathApplicable() {
return true;
}
}];
}

Expand Down Expand Up @@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
static bool isPointerCompatible(mlir::Type ty);
static bool canBeConverted(mlir::Type inType, mlir::Type outType);
static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);

// FIXME: fir.convert should support ArithFastMathInterface.
}];
let hasCanonicalizer = 1;
}
Expand Down
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
/// Scalar integer or a sequence of integers (via boxed array or expr).
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);

/// Return true iff FastMathFlagsAttr is applicable
/// to the given HLFIR dialect operation that supports
/// ArithFastMathInterface.
bool isArithFastMathApplicable(mlir::Operation *op);

} // namespace hlfir

#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
54 changes: 54 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
Expand Down Expand Up @@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
Expand All @@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
Expand All @@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
Expand All @@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_SetLengthOp : hlfir_Op<"set_length",
Expand Down Expand Up @@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_DotProductOp : hlfir_Op<"dot_product",
Expand All @@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MatmulOp : hlfir_Op<"matmul",
Expand Down Expand Up @@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
let hasCanonicalizeMethod = 1;

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_TransposeOp : hlfir_Op<"transpose",
Expand Down Expand Up @@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Report whether fast-math flags are applicable to this operation
/// by deferring to the shared hlfir::isArithFastMathApplicable() helper.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_CShiftOp
Expand Down
4 changes: 1 addition & 3 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,

void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
if (fmi) {
// TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
// For now set the attribute by the name.
if (fmi && fmi.isArithFastMathApplicable()) {
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
if (fastMathFlags != mlir::arith::FastMathFlags::none)
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
Expand Down
9 changes: 7 additions & 2 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,10 +589,15 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
call, resultTys, adaptor.getOperands(),
auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
call.getLoc(), resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
auto fmi =
mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
if (!fmi.isFastmathApplicable())
llvmCall.setFastmathFlags(mlir::LLVM::FastmathFlags::none);
rewriter.replaceOp(call, llvmCall);
return mlir::success();
}
};
Expand Down
17 changes: 17 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
mlir::Type elementType = getFortranElementType(unwrappedType);
return mlir::isa<mlir::IntegerType>(elementType);
}

bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
  // Predicate shared by the result and operand scans: strip the value's
  // type down to its Fortran element type and ask ArithFastMathInterface
  // whether that element type is compatible.
  auto hasCompatibleElementType = [](mlir::Value value) {
    return mlir::arith::ArithFastMathInterface::isCompatibleType(
        getFortranElementType(value.getType()));
  };

  // The operation qualifies as soon as any result or, failing that, any
  // operand satisfies the predicate. Results are inspected first and the
  // operand scan is skipped when a result already matches.
  return llvm::any_of(op->getResults(), hasCompatibleElementType) ||
         llvm::any_of(op->getOperands(), hasCompatibleElementType);
}
2 changes: 1 addition & 1 deletion flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
%45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
%46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
%47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
%47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
%48 = llvm.mlir.constant(9 : i32) : i32
%49 = llvm.mlir.zero : !llvm.ptr
%50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Fir/tbaa.fir
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ module {
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
Expand Down Expand Up @@ -188,8 +188,8 @@ module {
// CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
Expand Down
36 changes: 31 additions & 5 deletions flang/test/HLFIR/dot_product-lowering.fir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ func.func @_QPdot_product1(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name =
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product1Elhs"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product1Eres"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product1Erhs"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
%3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
%3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<none>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
hlfir.assign %3 to %1#0 : i32, !fir.ref<i32>
return
}
Expand All @@ -29,7 +29,7 @@ func.func @_QPdot_product2(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.b
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product2Elhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product2Eres"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product2Erhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
%3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
%3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<none>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
hlfir.assign %3 to %1#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
return
}
Expand Down Expand Up @@ -58,7 +58,7 @@ func.func @_QPdot_product2(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.b
%c5_0 = arith.constant 5 : index
%3 = fir.shape %c5_0 : (index) -> !fir.shape<1>
%4:2 = hlfir.declare %arg1(%3) {uniq_name = "_QFdot_product3Erhs"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
%5 = hlfir.dot_product %1#0 %4#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>) -> i32
%5 = hlfir.dot_product %1#0 %4#0 {fastmath = #arith.fastmath<none>} : (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>) -> i32
hlfir.assign %5 to %2#0 : i32, !fir.ref<i32>
return
}
Expand Down Expand Up @@ -86,7 +86,7 @@ func.func @_QPdot_product4(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.b
%temp = fir.alloca !fir.logical<4>
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product2Elhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
%1:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product2Erhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
%2 = hlfir.dot_product %0#0 %1#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
%2 = hlfir.dot_product %0#0 %1#0 {fastmath = #arith.fastmath<none>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
fir.store %2 to %temp : !fir.ref<!fir.logical<4>>
return
}
Expand All @@ -98,8 +98,34 @@ func.func @_QPdot_product4(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.b
// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFdot_product2Erhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_3]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_4]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
// CHECK: %[[VAL_12:.*]] = fir.call @_FortranADotProductLogical(%[[VAL_9]], %[[VAL_10]], %{{.*}}, %{{.*}}) fastmath<contract> : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> i1
// CHECK: %[[VAL_12:.*]] = fir.call @_FortranADotProductLogical(%[[VAL_9]], %[[VAL_10]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> i1
// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_12]] : (i1) -> !fir.logical<4>
// CHECK: fir.store %[[VAL_13]] to %[[VAL_2]] : !fir.ref<!fir.logical<4>>
// CHECK: return
// CHECK: }

// floating point dot_product
func.func @_QPdot_product5(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "lhs"}, %arg1: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "rhs"}, %arg2: !fir.ref<f32> {fir.bindc_name = "res"}) {
// Real-typed operands and result: the contract fast-math flag on
// hlfir.dot_product is expected to be kept on the generated runtime call.
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product5Elhs"} : (!fir.box<!fir.array<?xf32>>) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product5Eres"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product5Erhs"} : (!fir.box<!fir.array<?xf32>>) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
%3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>) -> f32
hlfir.assign %3 to %1#0 : f32, !fir.ref<f32>
return
}
// CHECK-LABEL: func.func @_QPdot_product5(
// CHECK: %[[ARG0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "lhs"}
// CHECK: %[[ARG1:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "rhs"}
// CHECK: %[[ARG2:.*]]: !fir.ref<f32> {fir.bindc_name = "res"}
// CHECK-DAG: %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
// CHECK-DAG: %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
// CHECK-DAG: %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]

// CHECK-DAG: %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
// CHECK-DAG: %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>

// CHECK: %[[RES_VAL:.*]] = fir.call @_FortranADotProductReal4(%[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) fastmath<contract>
// CHECK-NEXT: hlfir.assign %[[RES_VAL]] to %[[RES_VAR]]#0 : f32, !fir.ref<f32>
// CHECK-NEXT: return
// CHECK-NEXT: }

4 changes: 2 additions & 2 deletions flang/test/HLFIR/matmul-lowering.fir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lh
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFmatmul1Elhs"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFmatmul1Eres"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFmatmul1Erhs"} : (!fir.box<!fir.array<?x?xi32>>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>)
%3 = hlfir.matmul %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<?x?xi32>
%3 = hlfir.matmul %0#0 %2#0 {fastmath = #arith.fastmath<none>} : (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<?x?xi32>
hlfir.assign %3 to %1#0 : !hlfir.expr<?x?xi32>, !fir.box<!fir.array<?x?xi32>>
hlfir.destroy %3 : !hlfir.expr<?x?xi32>
return
Expand All @@ -29,7 +29,7 @@ func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lh
// CHECK: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK-DAG: %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
// CHECK-DAG: %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<none>
// CHECK: fir.call @_FortranAMatmulInteger4Integer4(%[[RET_ARG]], %[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]]) fastmath<contract>
// CHECK: fir.call @_FortranAMatmulInteger4Integer4(%[[RET_ARG]], %[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])

// CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]]
// CHECK-DAG: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]]
Expand Down
2 changes: 1 addition & 1 deletion flang/test/HLFIR/maxloc-elemental.fir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
hlfir.yield_element %11 : !fir.logical<4>
}
%7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
%7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<none>} : (!fir.box<!fir.array<?xi32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box<!fir.array<?xi32>>
hlfir.destroy %7 : !hlfir.expr<1xi32>
hlfir.destroy %6 : !hlfir.expr<?x!fir.logical<4>>
Expand Down
Loading