Skip to content

Commit 02c2e04

Browse files
committed
[sil-optimizer] Add FP comparison support in constant folder
1 parent 031f112 commit 02c2e04

File tree

5 files changed

+746
-9
lines changed

5 files changed

+746
-9
lines changed

include/swift/SILOptimizer/Utils/ConstantFolding.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,16 @@ class SILOptFunctionBuilder;
3131
/// The \p ID must be the ID of a binary bit-operation builtin.
3232
APInt constantFoldBitOperation(APInt lhs, APInt rhs, BuiltinValueKind ID);
3333

34+
/// Evaluates the constant result of a floating point comparison.
35+
///
36+
/// The \p ID must be the ID of a floating point comparison builtin (one of the FCMP_* kinds).
37+
APInt constantFoldComparisonFloat(APFloat lhs, APFloat rhs,
38+
BuiltinValueKind ID);
39+
3440
/// Evaluates the constant result of an integer comparison.
3541
///
3642
/// The \p ID must be the ID of an integer builtin operation.
37-
APInt constantFoldComparison(APInt lhs, APInt rhs, BuiltinValueKind ID);
43+
APInt constantFoldComparisonInt(APInt lhs, APInt rhs, BuiltinValueKind ID);
3844

3945
/// Evaluates the constant result of a binary operation with overflow.
4046
///

lib/SILOptimizer/Utils/ConstantFolding.cpp

Lines changed: 300 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,66 @@ APInt swift::constantFoldBitOperation(APInt lhs, APInt rhs, BuiltinValueKind ID)
5050
}
5151
}
5252

53-
APInt swift::constantFoldComparison(APInt lhs, APInt rhs, BuiltinValueKind ID) {
53+
/// Evaluates a floating point comparison of two constants.
///
/// \param lhs left-hand operand of the comparison.
/// \param rhs right-hand operand of the comparison.
/// \param ID  the comparison predicate; must be one of the FCMP_O* / FCMP_U*
///            kinds handled below, anything else is a programming error.
/// \returns a 1-bit APInt holding the truth value of the comparison.
APInt swift::constantFoldComparisonFloat(APFloat lhs, APFloat rhs,
                                         BuiltinValueKind ID) {
  // A pair of operands is "unordered" as soon as one of them is NaN.
  const bool unordered = lhs.isNaN() || rhs.isNaN();

  // Wraps a truth value into the 1-bit integer the callers expect.
  auto asBit = [](bool value) { return APInt(1, value); };

  switch (ID) {
  // Ordered predicates: false whenever a NaN is involved, otherwise the plain
  // relational result.
  case BuiltinValueKind::FCMP_OEQ:
    return asBit(!unordered && lhs == rhs);
  case BuiltinValueKind::FCMP_OGT:
    return asBit(!unordered && lhs > rhs);
  case BuiltinValueKind::FCMP_OGE:
    return asBit(!unordered && lhs >= rhs);
  case BuiltinValueKind::FCMP_OLT:
    return asBit(!unordered && lhs < rhs);
  case BuiltinValueKind::FCMP_OLE:
    return asBit(!unordered && lhs <= rhs);
  case BuiltinValueKind::FCMP_ONE:
    return asBit(!unordered && lhs != rhs);
  case BuiltinValueKind::FCMP_ORD:
    return asBit(!unordered);

  // Unordered predicates: true whenever a NaN is involved, otherwise the
  // plain relational result.
  case BuiltinValueKind::FCMP_UEQ:
    return asBit(unordered || lhs == rhs);
  case BuiltinValueKind::FCMP_UGT:
    return asBit(unordered || lhs > rhs);
  case BuiltinValueKind::FCMP_UGE:
    return asBit(unordered || lhs >= rhs);
  case BuiltinValueKind::FCMP_ULT:
    return asBit(unordered || lhs < rhs);
  case BuiltinValueKind::FCMP_ULE:
    return asBit(unordered || lhs <= rhs);
  case BuiltinValueKind::FCMP_UNE:
    return asBit(unordered || lhs != rhs);
  case BuiltinValueKind::FCMP_UNO:
    return asBit(unordered);

  default:
    llvm_unreachable("Invalid float compare kind");
  }
}
110+
111+
APInt swift::constantFoldComparisonInt(APInt lhs, APInt rhs,
112+
BuiltinValueKind ID) {
54113
bool result;
55114
switch (ID) {
56115
default: llvm_unreachable("Invalid integer compare kind");
@@ -351,14 +410,235 @@ static SILValue constantFoldIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID,
351410
return nullptr;
352411
}
353412

354-
static SILValue constantFoldCompare(BuiltinInst *BI, BuiltinValueKind ID) {
413+
/// Tries to fold a floating point comparison builtin into a 1-bit integer
/// literal.
///
/// Three situations are handled:
///  * both operands are float literals: the comparison is evaluated directly;
///  * one operand is a bitcast of an integer literal whose bits encode a NaN:
///    ordered predicates fold to false, unordered predicates fold to true;
///  * one operand is a bitcast of an integer literal whose bits encode
///    +infinity: only those predicates whose result is the same for every
///    possible value of the other operand (including NaN and +infinity) are
///    folded.
///
/// \param BI the comparison builtin to fold.
/// \param ID the builtin kind of \p BI.
/// \returns the folded literal, or a null SILValue if nothing could be folded.
static SILValue constantFoldCompareFloat(BuiltinInst *BI, BuiltinValueKind ID) {
  // Returns true if \p val holds the bits of an IEEE single/double NaN.
  // NaN has many encodings (either sign, any non-zero payload), so the bits
  // are decoded instead of being compared against one canonical pattern.
  // Widths other than 32 and 64 are not recognized.
  auto hasIEEEFloatNanBitRepr = [](const APInt &val) -> bool {
    switch (val.getBitWidth()) {
    case 32:
      return APFloat(llvm::APFloatBase::IEEEsingle(), val).isNaN();
    case 64:
      return APFloat(llvm::APFloatBase::IEEEdouble(), val).isNaN();
    default:
      return false;
    }
  };

  // Returns true if \p val holds the bits of IEEE single/double +infinity.
  // Widths other than 32 and 64 are not recognized.
  auto hasIEEEFloatPosInfBitRepr = [](const APInt &val) -> bool {
    switch (val.getBitWidth()) {
    case 32: {
      APFloat decoded(llvm::APFloatBase::IEEEsingle(), val);
      return decoded.isInfinity() && !decoded.isNegative();
    }
    case 64: {
      APFloat decoded(llvm::APFloatBase::IEEEdouble(), val);
      return decoded.isInfinity() && !decoded.isNegative();
    }
    default:
      return false;
    }
  };

  // Materializes the folded truth value as a literal of the builtin's type.
  auto createResult = [&](bool value) -> SILValue {
    SILBuilderWithScope B(BI);
    return B.createIntegerLiteral(BI->getLoc(), BI->getType(),
                                  APInt(1, value));
  };

  OperandValueArrayRef Args = BI->getArguments();

  // Fold for floating point constant arguments.
  auto *LHS = dyn_cast<FloatLiteralInst>(Args[0]);
  auto *RHS = dyn_cast<FloatLiteralInst>(Args[1]);
  if (LHS && RHS) {
    APInt Res =
        constantFoldComparisonFloat(LHS->getValue(), RHS->getValue(), ID);
    SILBuilderWithScope B(BI);
    return B.createIntegerLiteral(BI->getLoc(), BI->getType(), Res);
  }

  using namespace swift::PatternMatch;

  // Ordered comparisons with NaN always return false
  SILValue Other;
  IntegerLiteralInst *builtinArg;
  if (match(BI, m_CombineOr(
                    m_BuiltinInst(BuiltinValueKind::FCMP_OEQ, // x == NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OGT, // x > NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OGE, // x >= NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OLT, // x < NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OLE, // x <= NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_ONE, // x != NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OEQ, // NaN == x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OGT, // NaN > x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OGE, // NaN >= x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OLT, // NaN < x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_OLE, // NaN <= x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_ONE, // NaN != x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other))))) {
    if (hasIEEEFloatNanBitRepr(builtinArg->getValue()))
      return createResult(false);
  }

  // Unordered comparisons with NaN always return true
  if (match(BI, m_CombineOr(
                    m_BuiltinInst(BuiltinValueKind::FCMP_UEQ, // x == NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_UGT, // x > NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_UGE, // x >= NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_ULT, // x < NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_ULE, // x <= NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_UNE, // x != NaN
                                  m_SILValue(Other),
                                  m_BitCast(m_IntegerLiteralInst(builtinArg))),
                    m_BuiltinInst(BuiltinValueKind::FCMP_UEQ, // NaN == x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_UGT, // NaN > x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_UGE, // NaN >= x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_ULT, // NaN < x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_ULE, // NaN <= x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other)),
                    m_BuiltinInst(BuiltinValueKind::FCMP_UNE, // NaN != x
                                  m_BitCast(m_IntegerLiteralInst(builtinArg)),
                                  m_SILValue(Other))))) {
    if (hasIEEEFloatNanBitRepr(builtinArg->getValue()))
      return createResult(true);
  }

  // Comparisons against +infinity. The other operand is unknown and may be
  // NaN or +infinity at run time, so only predicates whose result does not
  // depend on it may be folded.
  //
  // `Inf >= x` / `x <= Inf` hold for every non-NaN x, and the unordered forms
  // also yield true when x is NaN, so the unordered forms are always true.
  // (The ordered forms are false for NaN; the strict forms `Inf > x` /
  // `x < Inf` are false for x == Inf. Neither may be folded.)
  if (match(BI,
            m_CombineOr(
                m_BuiltinInst(BuiltinValueKind::FCMP_UGE, // Inf >= x
                              m_BitCast(m_IntegerLiteralInst(builtinArg)),
                              m_SILValue(Other)),
                m_BuiltinInst(BuiltinValueKind::FCMP_ULE, // x <= Inf
                              m_SILValue(Other),
                              m_BitCast(m_IntegerLiteralInst(builtinArg)))))) {
    if (hasIEEEFloatPosInfBitRepr(builtinArg->getValue()))
      return createResult(true);
  }

  // `x > Inf` / `Inf < x` hold for no x, and the ordered forms also yield
  // false when x is NaN, so the ordered forms are always false.
  // (The unordered forms are true for NaN; the non-strict forms `x >= Inf` /
  // `Inf <= x` are true for x == Inf. Neither may be folded.)
  if (match(BI,
            m_CombineOr(
                m_BuiltinInst(BuiltinValueKind::FCMP_OGT, // x > Inf
                              m_SILValue(Other),
                              m_BitCast(m_IntegerLiteralInst(builtinArg))),
                m_BuiltinInst(BuiltinValueKind::FCMP_OLT, // Inf < x
                              m_BitCast(m_IntegerLiteralInst(builtinArg)),
                              m_SILValue(Other))))) {
    if (hasIEEEFloatPosInfBitRepr(builtinArg->getValue()))
      return createResult(false);
  }

  // Note: comparisons against the largest finite value are deliberately not
  // folded. `x <= MAX` is false when x is +infinity (and, for the ordered
  // predicates, when x is NaN), so the result depends on the unknown operand.

  return nullptr;
}
633+
634+
static SILValue constantFoldCompareInt(BuiltinInst *BI, BuiltinValueKind ID) {
355635
OperandValueArrayRef Args = BI->getArguments();
356636

357637
// Fold for integer constant arguments.
358638
auto *LHS = dyn_cast<IntegerLiteralInst>(Args[0]);
359639
auto *RHS = dyn_cast<IntegerLiteralInst>(Args[1]);
360640
if (LHS && RHS) {
361-
APInt Res = constantFoldComparison(LHS->getValue(), RHS->getValue(), ID);
641+
APInt Res = constantFoldComparisonInt(LHS->getValue(), RHS->getValue(), ID);
362642
SILBuilderWithScope B(BI);
363643
return B.createIntegerLiteral(BI->getLoc(), BI->getType(), Res);
364644
}
@@ -480,6 +760,17 @@ static SILValue constantFoldCompare(BuiltinInst *BI, BuiltinValueKind ID) {
480760
return nullptr;
481761
}
482762

763+
/// Folds a comparison builtin, trying the integer folder first and the
/// floating point folder second.
///
/// \returns the folded value, or a null SILValue if neither folder applied.
static SILValue constantFoldCompare(BuiltinInst *BI, BuiltinValueKind ID) {
  // The integer folder gets the first attempt.
  SILValue folded = constantFoldCompareInt(BI, ID);

  // Only consult the floating point folder when the integer one gave up.
  if (!folded)
    folded = constantFoldCompareFloat(BI, ID);

  // Still null here if neither folder succeeded.
  return folded;
}
773+
483774
static SILValue
484775
constantFoldAndCheckDivision(BuiltinInst *BI, BuiltinValueKind ID,
485776
llvm::Optional<bool> &ResultsInError) {
@@ -1893,6 +2184,12 @@ ConstantFolder::processWorkList() {
18932184
}
18942185
}
18952186

2187+
// If the user is a bitcast, we may be able to constant
2188+
// fold its users.
2189+
if (isApplyOfBuiltin(*User, BuiltinValueKind::BitCast)) {
2190+
WorkList.insert(User);
2191+
}
2192+
18962193
// Initialize ResultsInError as a None optional.
18972194
//
18982195
// We are essentially using this optional to represent 3 states: true,

lib/SILOptimizer/Utils/PerformanceInlinerUtils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@ case BuiltinValueKind::id:
190190
IntConst lhs = getIntConst(Args[0], depth);
191191
IntConst rhs = getIntConst(Args[1], depth);
192192
if (lhs.isValid && rhs.isValid) {
193-
return IntConst(constantFoldComparison(lhs.value, rhs.value,
194-
Builtin.ID),
195-
lhs.isFromCaller || rhs.isFromCaller);
193+
return IntConst(
194+
constantFoldComparisonInt(lhs.value, rhs.value, Builtin.ID),
195+
lhs.isFromCaller || rhs.isFromCaller);
196196
}
197197
break;
198198
}

test/AutoDiff/SILOptimizer/vjp_and_pullback_inlining.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ func caller_of_simple_vjp() -> Float {
4040
@_silgen_name("pb_with_control_flow")
4141
func pb_with_control_flow(_ x: Float) -> Float {
4242
if (x > 0) {
43-
return sin(x) * cos(x)
43+
let a = x * x;
44+
let b = x + x;
45+
let c = x * a;
46+
let d = a + b;
47+
let e = b * c;
48+
return a * b / c + d - e ;
4449
} else {
4550
return sin(x) + cos(x)
4651
}
@@ -55,7 +60,6 @@ func caller_of_pb_with_control_flow() -> Float {
5560
// CHECK: decision {{{.*}}, b=70, {{.*}}} pb_with_control_flowTJpSpSr
5661
// CHECK-NEXT: "pb_with_control_flowTJpSpSr" inlined into "caller_of_pb_with_control_flow"
5762

58-
5963
@differentiable(reverse)
6064
func double(x: Float) -> Float {
6165
return x + x

0 commit comments

Comments
 (0)