@@ -6515,72 +6515,75 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
65156515 llvm_unreachable (" Unexpected overflow result" );
65166516}
65176517
6518- // / Recognize and process idiom involving test for multiplication
6518+ // / Recognize and process idiom involving test for unsigned
65196519// / overflow.
65206520// /
65216521// / The caller has matched a pattern of the form:
6522+ // / I = cmp u (add(zext A, zext B), V
65226523// / I = cmp u (mul(zext A, zext B), V
65236524// / The function checks if this is a test for overflow and if so replaces
6524- // / multiplication with call to 'mul.with.overflow' intrinsic.
6525+ // / addition with call to the right intrinsic.
65256526// /
65266527// / \param I Compare instruction.
6527- // / \param MulVal Result of 'mult' instruction. It is one of the arguments of
6528+ // / \param Val Result of instruction. It is one of the arguments of
65286529// / the compare instruction. Must be of integer type.
65296530// / \param OtherVal The other argument of compare instruction.
65306531// / \returns Instruction which must replace the compare instruction, NULL if no
65316532// / replacement required.
6532- static Instruction *processUMulZExtIdiom (ICmpInst &I, Value *MulVal ,
6533- const APInt *OtherVal,
6534- InstCombinerImpl &IC) {
6533+ static Instruction *processUZExtIdiom (ICmpInst &I, Value *Val ,
6534+ const APInt *OtherVal,
6535+ InstCombinerImpl &IC) {
65356536 // Don't bother doing this transformation for pointers, don't do it for
65366537 // vectors.
6537- if (!isa<IntegerType>(MulVal ->getType ()))
6538+ if (!isa<IntegerType>(Val ->getType ()))
65386539 return nullptr ;
65396540
6540- auto *MulInstr = dyn_cast<Instruction>(MulVal );
6541- if (!MulInstr )
6541+ auto *Instr = dyn_cast<Instruction>(Val );
6542+ if (!Instr )
65426543 return nullptr ;
6543- assert (MulInstr->getOpcode () == Instruction::Mul);
65446544
6545- auto *LHS = cast<ZExtInst>(MulInstr->getOperand (0 )),
6546- *RHS = cast<ZExtInst>(MulInstr->getOperand (1 ));
6545+ unsigned Opcode = Instr->getOpcode ();
6546+ assert (Opcode == Instruction::Add || Opcode == Instruction::Mul);
6547+
6548+ auto *LHS = cast<ZExtInst>(Instr->getOperand (0 )),
6549+ *RHS = cast<ZExtInst>(Instr->getOperand (1 ));
65476550 assert (LHS->getOpcode () == Instruction::ZExt);
65486551 assert (RHS->getOpcode () == Instruction::ZExt);
65496552 Value *A = LHS->getOperand (0 ), *B = RHS->getOperand (0 );
65506553
6551- // Calculate type and width of the result produced by mul.with.overflow.
6554+ // Calculate type and width of the result produced by add/ mul.with.overflow.
65526555 Type *TyA = A->getType (), *TyB = B->getType ();
65536556 unsigned WidthA = TyA->getPrimitiveSizeInBits (),
65546557 WidthB = TyB->getPrimitiveSizeInBits ();
6555- unsigned MulWidth ;
6556- Type *MulType ;
6558+ unsigned ResultWidth ;
6559+ Type *ResultType ;
65576560 if (WidthB > WidthA) {
6558- MulWidth = WidthB;
6559- MulType = TyB;
6561+ ResultWidth = WidthB;
6562+ ResultType = TyB;
65606563 } else {
6561- MulWidth = WidthA;
6562- MulType = TyA;
6564+ ResultWidth = WidthA;
6565+ ResultType = TyA;
65636566 }
65646567
6565- // In order to replace the original mul with a narrower mul.with.overflow,
6566- // all uses must ignore upper bits of the product . The number of used low
6567- // bits must be not greater than the width of mul.with.overflow.
6568- if (MulVal ->hasNUsesOrMore (2 ))
6569- for (User *U : MulVal ->users ()) {
6568+ // In order to replace the original result with an add/ mul.with.overflow
6569+ // intrinsic, all uses must ignore upper bits of the result . The number of
6570+ // used low bits must be not greater than the width of add/ mul.with.overflow.
6571+ if (Val ->hasNUsesOrMore (2 ))
6572+ for (User *U : Val ->users ()) {
65706573 if (U == &I)
65716574 continue ;
65726575 if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6573- // Check if truncation ignores bits above MulWidth .
6576+ // Check if truncation ignores bits above ResultWidth .
65746577 unsigned TruncWidth = TI->getType ()->getPrimitiveSizeInBits ();
6575- if (TruncWidth > MulWidth )
6578+ if (TruncWidth > ResultWidth )
65766579 return nullptr ;
65776580 } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
6578- // Check if AND ignores bits above MulWidth .
6581+ // Check if AND ignores bits above ResultWidth .
65796582 if (BO->getOpcode () != Instruction::And)
65806583 return nullptr ;
65816584 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand (1 ))) {
65826585 const APInt &CVal = CI->getValue ();
6583- if (CVal.getBitWidth () - CVal.countl_zero () > MulWidth )
6586+ if (CVal.getBitWidth () - CVal.countl_zero () > ResultWidth )
65846587 return nullptr ;
65856588 } else {
65866589 // In this case we could have the operand of the binary operation
@@ -6598,9 +6601,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
65986601 switch (I.getPredicate ()) {
65996602 case ICmpInst::ICMP_UGT: {
66006603 // Recognize pattern:
6601- // mulval = mul(zext A, zext B)
6602- // cmp ugt mulval , max
6603- APInt MaxVal = APInt::getMaxValue (MulWidth );
6604+ // val = add/ mul(zext A, zext B)
6605+ // cmp ugt val , max
6606+ APInt MaxVal = APInt::getMaxValue (ResultWidth );
66046607 MaxVal = MaxVal.zext (OtherVal->getBitWidth ());
66056608 if (MaxVal.eq (*OtherVal))
66066609 break ; // Recognized
@@ -6609,9 +6612,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66096612
66106613 case ICmpInst::ICMP_ULT: {
66116614 // Recognize pattern:
6612- // mulval = mul(zext A, zext B)
6613- // cmp ule mulval , max + 1
6614- APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), MulWidth );
6615+ // val = add/ mul(zext A, zext B)
6616+ // cmp ule val , max + 1
6617+ APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), ResultWidth );
66156618 if (MaxVal.eq (*OtherVal))
66166619 break ; // Recognized
66176620 return nullptr ;
@@ -6622,38 +6625,42 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66226625 }
66236626
66246627 InstCombiner::BuilderTy &Builder = IC.Builder ;
6625- Builder.SetInsertPoint (MulInstr);
6626-
6627- // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
6628- Value *MulA = A, *MulB = B;
6629- if (WidthA < MulWidth)
6630- MulA = Builder.CreateZExt (A, MulType);
6631- if (WidthB < MulWidth)
6632- MulB = Builder.CreateZExt (B, MulType);
6633- CallInst *Call =
6634- Builder.CreateIntrinsic (Intrinsic::umul_with_overflow, MulType,
6635- {MulA, MulB}, /* FMFSource=*/ nullptr , " umul" );
6636- IC.addToWorklist (MulInstr);
6637-
6638- // If there are uses of mul result other than the comparison, we know that
6628+ Builder.SetInsertPoint (Instr);
6629+
6630+ // Replace: add/mul(zext A, zext B) --> add/mul.with.overflow(A, B)
6631+ Value *ResultA = A, *ResultB = B;
6632+ if (WidthA < ResultWidth)
6633+ ResultA = Builder.CreateZExt (A, ResultType);
6634+ if (WidthB < ResultWidth)
6635+ ResultB = Builder.CreateZExt (B, ResultType);
6636+ CallInst *Call = Builder.CreateIntrinsic (
6637+ Opcode == Instruction::Add ? Intrinsic::uadd_with_overflow
6638+ : Intrinsic::umul_with_overflow,
6639+ ResultType, {ResultA, ResultB}, /* FMFSource=*/ nullptr ,
6640+ Intrinsic::uadd_with_overflow ? " uadd" : " umul" );
6641+ IC.addToWorklist (Instr);
6642+
6643+ // If there are uses of add result other than the comparison, we know that
66396644 // they are truncation or binary AND. Change them to use result of
6640- // mul.with.overflow and adjust properly mask/size.
6641- if (MulVal->hasNUsesOrMore (2 )) {
6642- Value *Mul = Builder.CreateExtractValue (Call, 0 , " umul.value" );
6643- for (User *U : make_early_inc_range (MulVal->users ())) {
6645+ // add/mul.with.overflow and adjust properly mask/size.
6646+ if (Val->hasNUsesOrMore (2 )) {
6647+ Value *Extract = Builder.CreateExtractValue (
6648+ Call, 0 , Instruction::Add ? " uadd.value" : " umul.value" );
6649+ for (User *U : make_early_inc_range (Val->users ())) {
66446650 if (U == &I)
66456651 continue ;
66466652 if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6647- if (TI->getType ()->getPrimitiveSizeInBits () == MulWidth )
6648- IC.replaceInstUsesWith (*TI, Mul );
6653+ if (TI->getType ()->getPrimitiveSizeInBits () == ResultWidth )
6654+ IC.replaceInstUsesWith (*TI, Extract );
66496655 else
6650- TI->setOperand (0 , Mul );
6656+ TI->setOperand (0 , Extract );
66516657 } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
66526658 assert (BO->getOpcode () == Instruction::And);
6653- // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
6659+ // Replace (Extract & mask) --> zext (add/mul.with.overflow &
6660+ // short_mask)
66546661 ConstantInt *CI = cast<ConstantInt>(BO->getOperand (1 ));
6655- APInt ShortMask = CI->getValue ().trunc (MulWidth );
6656- Value *ShortAnd = Builder.CreateAnd (Mul , ShortMask);
6662+ APInt ShortMask = CI->getValue ().trunc (ResultWidth );
6663+ Value *ShortAnd = Builder.CreateAnd (Extract , ShortMask);
66576664 Value *Zext = Builder.CreateZExt (ShortAnd, BO->getType ());
66586665 IC.replaceInstUsesWith (*BO, Zext);
66596666 } else {
@@ -7078,7 +7085,7 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
70787085 // icmp eq X, (zext (icmp ne X, 0)) --> X == 0 || X == 1
70797086 // icmp ne X, (zext (icmp ne X, 0)) --> X != 0 && X != 1
70807087 // icmp eq X, (sext (icmp ne X, 0)) --> X == 0 || X == -1
7081- // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X = = -1
7088+ // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X ! = -1
70827089 return CreateRangeCheck ();
70837090 }
70847091 } else if (IsSExt ? C->isAllOnes () : C->isOne ()) {
@@ -7791,10 +7798,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
77917798 }
77927799 }
77937800
7801+ // (zext X) + (zext Y) --> llvm.uadd.with.overflow.
77947802 // (zext X) * (zext Y) --> llvm.umul.with.overflow.
7795- if (match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) &&
7803+ if ((match (Op0, m_NUWAdd (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) ||
7804+ match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y))))) &&
77967805 match (Op1, m_APInt (C))) {
7797- if (Instruction *R = processUMulZExtIdiom (I, Op0, C, *this ))
7806+ if (Instruction *R = processUZExtIdiom (I, Op0, C, *this ))
77987807 return R;
77997808 }
78007809
0 commit comments