@@ -50,7 +50,66 @@ APInt swift::constantFoldBitOperation(APInt lhs, APInt rhs, BuiltinValueKind ID)
50
50
}
51
51
}
52
52
53
- APInt swift::constantFoldComparison (APInt lhs, APInt rhs, BuiltinValueKind ID) {
53
+ APInt swift::constantFoldComparisonFloat (APFloat lhs, APFloat rhs,
54
+ BuiltinValueKind ID) {
55
+ bool result;
56
+ bool isOrdered = !lhs.isNaN () && !rhs.isNaN ();
57
+
58
+ switch (ID) {
59
+ default :
60
+ llvm_unreachable (" Invalid float compare kind" );
61
+ // Ordered comparisons
62
+ case BuiltinValueKind::FCMP_OEQ:
63
+ result = isOrdered && lhs == rhs;
64
+ break ;
65
+ case BuiltinValueKind::FCMP_OGT:
66
+ result = isOrdered && lhs > rhs;
67
+ break ;
68
+ case BuiltinValueKind::FCMP_OGE:
69
+ result = isOrdered && lhs >= rhs;
70
+ break ;
71
+ case BuiltinValueKind::FCMP_OLT:
72
+ result = isOrdered && lhs < rhs;
73
+ break ;
74
+ case BuiltinValueKind::FCMP_OLE:
75
+ result = isOrdered && lhs <= rhs;
76
+ break ;
77
+ case BuiltinValueKind::FCMP_ONE:
78
+ result = isOrdered && lhs != rhs;
79
+ break ;
80
+ case BuiltinValueKind::FCMP_ORD:
81
+ result = isOrdered;
82
+ break ;
83
+
84
+ // Unordered comparisons
85
+ case BuiltinValueKind::FCMP_UEQ:
86
+ result = !isOrdered || lhs == rhs;
87
+ break ;
88
+ case BuiltinValueKind::FCMP_UGT:
89
+ result = !isOrdered || lhs > rhs;
90
+ break ;
91
+ case BuiltinValueKind::FCMP_UGE:
92
+ result = !isOrdered || lhs >= rhs;
93
+ break ;
94
+ case BuiltinValueKind::FCMP_ULT:
95
+ result = !isOrdered || lhs < rhs;
96
+ break ;
97
+ case BuiltinValueKind::FCMP_ULE:
98
+ result = !isOrdered || lhs <= rhs;
99
+ break ;
100
+ case BuiltinValueKind::FCMP_UNE:
101
+ result = !isOrdered || lhs != rhs;
102
+ break ;
103
+ case BuiltinValueKind::FCMP_UNO:
104
+ result = !isOrdered;
105
+ break ;
106
+ }
107
+
108
+ return APInt (1 , result);
109
+ }
110
+
111
+ APInt swift::constantFoldComparisonInt (APInt lhs, APInt rhs,
112
+ BuiltinValueKind ID) {
54
113
bool result;
55
114
switch (ID) {
56
115
default : llvm_unreachable (" Invalid integer compare kind" );
@@ -351,14 +410,235 @@ static SILValue constantFoldIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID,
351
410
return nullptr ;
352
411
}
353
412
354
- static SILValue constantFoldCompare (BuiltinInst *BI, BuiltinValueKind ID) {
413
+ static SILValue constantFoldCompareFloat (BuiltinInst *BI, BuiltinValueKind ID) {
414
+ static auto hasIEEEFloatNanBitRepr = [](const APInt val) -> bool {
415
+ auto bitWidth = val.getBitWidth ();
416
+ if (bitWidth == 32 ) {
417
+ APInt nanBitRepr =
418
+ APFloat::getNaN (llvm::APFloatBase::IEEEsingle ()).bitcastToAPInt ();
419
+ return bitWidth == nanBitRepr.getBitWidth () && val == nanBitRepr;
420
+ } else {
421
+ APInt nanBitRepr =
422
+ APFloat::getNaN (llvm::APFloatBase::IEEEdouble ()).bitcastToAPInt ();
423
+ return bitWidth == nanBitRepr.getBitWidth () && val == nanBitRepr;
424
+ }
425
+ };
426
+
427
+ static auto hasIEEEFloatPosInfBitRepr = [](const APInt val) -> bool {
428
+ auto bitWidth = val.getBitWidth ();
429
+ if (bitWidth == 32 ) {
430
+ APInt infBitRepr =
431
+ APFloat::getInf (llvm::APFloatBase::IEEEsingle ()).bitcastToAPInt ();
432
+ return bitWidth == infBitRepr.getBitWidth () && val == infBitRepr;
433
+ } else {
434
+ APInt infBitRepr =
435
+ APFloat::getInf (llvm::APFloatBase::IEEEdouble ()).bitcastToAPInt ();
436
+ return bitWidth == infBitRepr.getBitWidth () && val == infBitRepr;
437
+ }
438
+ };
439
+
440
+ OperandValueArrayRef Args = BI->getArguments ();
441
+
442
+ // Fold for floating point constant arguments.
443
+ auto *LHS = dyn_cast<FloatLiteralInst>(Args[0 ]);
444
+ auto *RHS = dyn_cast<FloatLiteralInst>(Args[1 ]);
445
+ if (LHS && RHS) {
446
+ APInt Res =
447
+ constantFoldComparisonFloat (LHS->getValue (), RHS->getValue (), ID);
448
+ SILBuilderWithScope B (BI);
449
+ return B.createIntegerLiteral (BI->getLoc (), BI->getType (), Res);
450
+ }
451
+
452
+ using namespace swift ::PatternMatch;
453
+
454
+ // Ordered comparisons with NaN always return false
455
+ SILValue Other;
456
+ IntegerLiteralInst *builtinArg;
457
+ if (match (BI, m_CombineOr (
458
+ m_BuiltinInst (BuiltinValueKind::FCMP_OEQ, // x == NaN
459
+ m_SILValue (Other),
460
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
461
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGT, // x > NaN
462
+ m_SILValue (Other),
463
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
464
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGE, // x >= NaN
465
+ m_SILValue (Other),
466
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
467
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLT, // x < NaN
468
+ m_SILValue (Other),
469
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
470
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLE, // x <= NaN
471
+ m_SILValue (Other),
472
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
473
+ m_BuiltinInst (BuiltinValueKind::FCMP_ONE, // x != NaN
474
+ m_SILValue (Other),
475
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
476
+ m_BuiltinInst (BuiltinValueKind::FCMP_OEQ, // NaN == x
477
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
478
+ m_SILValue (Other)),
479
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGT, // NaN > x
480
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
481
+ m_SILValue (Other)),
482
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGE, // NaN >= x
483
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
484
+ m_SILValue (Other)),
485
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLT, // NaN < x
486
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
487
+ m_SILValue (Other)),
488
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLE, // NaN <= x
489
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
490
+ m_SILValue (Other)),
491
+ m_BuiltinInst (BuiltinValueKind::FCMP_ONE, // NaN != x
492
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
493
+ m_SILValue (Other))))) {
494
+ APInt val = builtinArg->getValue ();
495
+ if (hasIEEEFloatNanBitRepr (val)) {
496
+ SILBuilderWithScope B (BI);
497
+ return B.createIntegerLiteral (BI->getLoc (), BI->getType (), APInt (1 , 0 ));
498
+ }
499
+ }
500
+
501
+ // Unordered comparisons with NaN always return true
502
+ if (match (BI, m_CombineOr (
503
+ m_BuiltinInst (BuiltinValueKind::FCMP_UEQ, // x == NaN
504
+ m_SILValue (Other),
505
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
506
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGT, // x > NaN
507
+ m_SILValue (Other),
508
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
509
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGE, // x >= NaN
510
+ m_SILValue (Other),
511
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
512
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULT, // x < NaN
513
+ m_SILValue (Other),
514
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
515
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULE, // x <= NaN
516
+ m_SILValue (Other),
517
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
518
+ m_BuiltinInst (BuiltinValueKind::FCMP_UNE, // x != NaN
519
+ m_SILValue (Other),
520
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
521
+ m_BuiltinInst (BuiltinValueKind::FCMP_UEQ, // NaN == x
522
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
523
+ m_SILValue (Other)),
524
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGT, // NaN > x
525
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
526
+ m_SILValue (Other)),
527
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGE, // NaN >= x
528
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
529
+ m_SILValue (Other)),
530
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULT, // NaN < x
531
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
532
+ m_SILValue (Other)),
533
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULE, // NaN <= x
534
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
535
+ m_SILValue (Other)),
536
+ m_BuiltinInst (BuiltinValueKind::FCMP_UNE, // NaN != x
537
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
538
+ m_SILValue (Other))))) {
539
+ APInt val = builtinArg->getValue ();
540
+ if (hasIEEEFloatNanBitRepr (val)) {
541
+ SILBuilderWithScope B (BI);
542
+ return B.createIntegerLiteral (BI->getLoc (), BI->getType (), APInt (1 , 1 ));
543
+ }
544
+ }
545
+
546
+ // Everything is less than or equal positive infinity
547
+ if (match (BI,
548
+ m_CombineOr (
549
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGT, // Inf > x
550
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
551
+ m_SILValue (Other)),
552
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGE, // Inf >= x
553
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
554
+ m_SILValue (Other)),
555
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLT, // x < Inf
556
+ m_SILValue (Other),
557
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
558
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLE, // x <= Inf
559
+ m_SILValue (Other),
560
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
561
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGT, // Inf > x
562
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
563
+ m_SILValue (Other)),
564
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGE, // Inf >= x
565
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
566
+ m_SILValue (Other)),
567
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULT, // x < Inf
568
+ m_SILValue (Other),
569
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
570
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULE, // x <= Inf
571
+ m_SILValue (Other),
572
+ m_BitCast (m_IntegerLiteralInst (builtinArg)))))) {
573
+ APInt val = builtinArg->getValue ();
574
+ if (hasIEEEFloatPosInfBitRepr (val)) {
575
+ SILBuilderWithScope B (BI);
576
+ return B.createIntegerLiteral (BI->getLoc (), BI->getType (), APInt (1 , 1 ));
577
+ }
578
+ }
579
+
580
+ // Positive infinity is not less than or equal to anything
581
+ if (match (BI, m_CombineOr (
582
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGT, // x > Inf
583
+ m_SILValue (Other),
584
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
585
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGE, // x >= Inf
586
+ m_SILValue (Other),
587
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
588
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLT, // Inf < x
589
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
590
+ m_SILValue (Other)),
591
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLE, // Inf <= x
592
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
593
+ m_SILValue (Other)),
594
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGT, // x > Inf
595
+ m_SILValue (Other),
596
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
597
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGE, // x >= Inf
598
+ m_SILValue (Other),
599
+ m_BitCast (m_IntegerLiteralInst (builtinArg))),
600
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULT, // Inf < x
601
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
602
+ m_SILValue (Other)),
603
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULE, // Inf <= x
604
+ m_BitCast (m_IntegerLiteralInst (builtinArg)),
605
+ m_SILValue (Other))))) {
606
+ APInt val = builtinArg->getValue ();
607
+ if (hasIEEEFloatPosInfBitRepr (val)) {
608
+ SILBuilderWithScope B (BI);
609
+ return B.createIntegerLiteral (BI->getLoc (), BI->getType (), APInt (1 , 0 ));
610
+ }
611
+ }
612
+
613
+ // Everything is less than or equal to (but not necessarily less than) MAX
614
+ // float
615
+ FloatLiteralInst *max;
616
+ if (match (BI,
617
+ m_CombineOr (
618
+ m_BuiltinInst (BuiltinValueKind::FCMP_OGE, // MAX >= x
619
+ m_FloatLiteralInst (max), m_SILValue (Other)),
620
+ m_BuiltinInst (BuiltinValueKind::FCMP_OLE, // x <= MAX
621
+ m_SILValue (Other), m_FloatLiteralInst (max)),
622
+ m_BuiltinInst (BuiltinValueKind::FCMP_UGE, // MAX >= x
623
+ m_FloatLiteralInst (max), m_SILValue (Other)),
624
+ m_BuiltinInst (BuiltinValueKind::FCMP_ULE, // x <= MAX
625
+ m_SILValue (Other), m_FloatLiteralInst (max)))) &&
626
+ max->getValue ().isLargest ()) {
627
+ SILBuilderWithScope B (BI);
628
+ return B.createIntegerLiteral (BI->getLoc (), BI->getType (), APInt (1 , 1 ));
629
+ }
630
+
631
+ return nullptr ;
632
+ }
633
+
634
+ static SILValue constantFoldCompareInt (BuiltinInst *BI, BuiltinValueKind ID) {
355
635
OperandValueArrayRef Args = BI->getArguments ();
356
636
357
637
// Fold for integer constant arguments.
358
638
auto *LHS = dyn_cast<IntegerLiteralInst>(Args[0 ]);
359
639
auto *RHS = dyn_cast<IntegerLiteralInst>(Args[1 ]);
360
640
if (LHS && RHS) {
361
- APInt Res = constantFoldComparison (LHS->getValue (), RHS->getValue (), ID);
641
+ APInt Res = constantFoldComparisonInt (LHS->getValue (), RHS->getValue (), ID);
362
642
SILBuilderWithScope B (BI);
363
643
return B.createIntegerLiteral (BI->getLoc (), BI->getType (), Res);
364
644
}
@@ -480,6 +760,17 @@ static SILValue constantFoldCompare(BuiltinInst *BI, BuiltinValueKind ID) {
480
760
return nullptr ;
481
761
}
482
762
763
+ static SILValue constantFoldCompare (BuiltinInst *BI, BuiltinValueKind ID) {
764
+ // Try folding integer comparison
765
+ if (auto result = constantFoldCompareInt (BI, ID))
766
+ return result;
767
+ // Try folding floating point comparison
768
+ if (auto result = constantFoldCompareFloat (BI, ID))
769
+ return result;
770
+ // Else, return nullptr
771
+ return nullptr ;
772
+ }
773
+
483
774
static SILValue
484
775
constantFoldAndCheckDivision (BuiltinInst *BI, BuiltinValueKind ID,
485
776
llvm::Optional<bool > &ResultsInError) {
@@ -1893,6 +2184,12 @@ ConstantFolder::processWorkList() {
1893
2184
}
1894
2185
}
1895
2186
2187
+ // If the user is a bitcast, we may be able to constant
2188
+ // fold its users.
2189
+ if (isApplyOfBuiltin (*User, BuiltinValueKind::BitCast)) {
2190
+ WorkList.insert (User);
2191
+ }
2192
+
1896
2193
// Initialize ResultsInError as a None optional.
1897
2194
//
1898
2195
// We are essentially using this optional to represent 3 states: true,
0 commit comments