Skip to content

Commit 0881304

Browse files
committed
Fix complex abs with nnan/ninf.
The current logic tests for inf/inf and 0/0 inputs using a NaN check. This doesn't work with fastmath flags. With nnan and ninf, we can just check for a 0 maximum. With only nnan, we have to check for both cases separately.
1 parent bfa8150 commit 0881304

File tree

2 files changed

+94
-23
lines changed

2 files changed

+94
-23
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
4343
Value ratio = b.create<arith::DivFOp>(min, max, fmf);
4444
Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
4545
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
46-
Value result;
4746

4847
if (fn == AbsFn::rsqrt) {
4948
ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
5049
min = b.create<math::RsqrtOp>(min, fmf);
5150
max = b.create<math::RsqrtOp>(max, fmf);
5251
}
5352

53+
Value result;
5454
if (fn == AbsFn::sqrt) {
5555
Value quarter = b.create<arith::ConstantOp>(
5656
real.getType(), b.getFloatAttr(real.getType(), 0.25));
@@ -63,6 +63,40 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
6363
result = b.create<arith::MulFOp>(max, sqrt, fmf);
6464
}
6565

66+
if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
67+
arith::FastMathFlags::ninf)) {
68+
// We only need to handle the 0/0 case here.
69+
Value zero = b.create<arith::ConstantOp>(
70+
real.getType(), b.getFloatAttr(real.getType(), 0.0));
71+
Value maxIsZero =
72+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, zero);
73+
return b.create<arith::SelectOp>(maxIsZero, min, result);
74+
}
75+
76+
if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan)) {
77+
Value zero = b.create<arith::ConstantOp>(
78+
real.getType(), b.getFloatAttr(real.getType(), 0.0));
79+
Value inf = b.create<arith::ConstantOp>(
80+
real.getType(),
81+
b.getFloatAttr(
82+
real.getType(),
83+
APFloat::getInf(
84+
cast<FloatType>(real.getType()).getFloatSemantics())));
85+
Value maxIsInf =
86+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, inf, fmf);
87+
Value minIsInf =
88+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, min, inf, fmf);
89+
// We need to handle inf/inf and 0/0 specially. The former is inf, the
90+
// latter is 0. Both produce poison in the division.
91+
Value resultIsInf = b.create<arith::AndIOp>(maxIsInf, minIsInf);
92+
Value resultIsZero =
93+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, zero);
94+
result = b.create<arith::SelectOp>(resultIsInf, inf, result);
95+
result = b.create<arith::SelectOp>(resultIsZero, zero, result);
96+
return result;
97+
}
98+
99+
// This handles both inf/inf and 0/0.
66100
Value isNaN =
67101
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
68102
return b.create<arith::SelectOp>(isNaN, min, result);

mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
2-
// RUN: FileCheck %s --dump-input=always
2+
// RUN: FileCheck %s
33

44
// CHECK-LABEL: func @complex_abs
55
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -709,9 +709,10 @@ func.func @complex_sqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
709709
// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
710710
// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,ninf> : f32
711711
// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,ninf> : f32
712-
// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
713-
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,ninf> : f32
714-
// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
712+
// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
713+
// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00 : f32
714+
// CHECK: %[[IS_POISON:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,ninf> : f32
715+
// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_POISON]], %[[MIN]], %[[SQRT_ABS_OR_POISON]] : f32
715716
// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,ninf> : f32
716717
// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,ninf> : f32
717718
// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,ninf> : f32
@@ -823,9 +824,15 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
823824
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
824825
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
825826
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
826-
// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
827-
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
828-
// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
827+
// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
828+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00
829+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
830+
// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
831+
// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
832+
// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
833+
// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO]] fastmath<nnan,contract> : f32
834+
// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
835+
// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO]], %[[ABS_OR_INF]] : f32
829836
// CHECK: return %[[ABS]] : f32
830837

831838
// -----
@@ -922,9 +929,15 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
922929
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
923930
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
924931
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
925-
// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
926-
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
927-
// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
932+
// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
933+
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00
934+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
935+
// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
936+
// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
937+
// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
938+
// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO]] fastmath<nnan,contract> : f32
939+
// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
940+
// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO]], %[[ABS_OR_INF]] : f32
928941
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[ABS]] fastmath<nnan,contract> : f32
929942
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
930943
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -1304,9 +1317,15 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
13041317
// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
13051318
// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
13061319
// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
1307-
// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
1308-
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
1309-
// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
1320+
// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
1321+
// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
1322+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
1323+
// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
1324+
// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
1325+
// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
1326+
// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
1327+
// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[SQRT_ABS_OR_POISON]] : f32
1328+
// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[ABS_OR_INF]] : f32
13101329
// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
13111330
// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
13121331
// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
@@ -1543,9 +1562,15 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
15431562
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
15441563
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
15451564
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
1546-
// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
1547-
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
1548-
// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
1565+
// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
1566+
// CHECK: %[[ZERO_3:.*]] = arith.constant 0.000000e+00
1567+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
1568+
// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
1569+
// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
1570+
// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
1571+
// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_3]] fastmath<nnan,contract> : f32
1572+
// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
1573+
// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_3]], %[[ABS_OR_INF]] : f32
15491574
// CHECK: %[[VAR436:.*]] = math.log %[[ABS]] fastmath<nnan,contract> : f32
15501575
// CHECK: %[[VAR437:.*]] = complex.re %[[VAR415]] : complex<f32>
15511576
// CHECK: %[[VAR438:.*]] = complex.im %[[VAR415]] : complex<f32>
@@ -1784,9 +1809,15 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
17841809
// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
17851810
// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
17861811
// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
1787-
// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
1788-
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
1789-
// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
1812+
// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
1813+
// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
1814+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
1815+
// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
1816+
// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
1817+
// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
1818+
// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
1819+
// CHECK: %[[SQRT_ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[SQRT_ABS_OR_POISON]] : f32
1820+
// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[SQRT_ABS_OR_INF]] : f32
17901821
// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
17911822
// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
17921823
// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
@@ -1890,9 +1921,15 @@ func.func @complex_sign_with_fmf(%arg: complex<f32>) -> complex<f32> {
18901921
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
18911922
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
18921923
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
1893-
// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
1894-
// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
1895-
// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
1924+
// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
1925+
// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
1926+
// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
1927+
// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
1928+
// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
1929+
// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
1930+
// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
1931+
// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
1932+
// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[ABS_OR_INF]] : f32
18961933
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[ABS]] fastmath<nnan,contract> : f32
18971934
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[ABS]] fastmath<nnan,contract> : f32
18981935
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>

0 commit comments

Comments
 (0)