Skip to content

Commit 29aed2c

Browse files
committed
[InstCombine] Detect uadd with overflow idiom
Change processUMulZExtIdiom to also support adds, since the idiom is the same, except with add instead of mul. Alive2: https://alive2.llvm.org/ce/z/SsB4AK
1 parent 2f3f0f6 commit 29aed2c

File tree

3 files changed

+78
-80
lines changed

3 files changed

+78
-80
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 70 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6515,72 +6515,75 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
65156515
llvm_unreachable("Unexpected overflow result");
65166516
}
65176517

6518-
/// Recognize and process idiom involving test for multiplication
6518+
/// Recognize and process idiom involving test for unsigned
65196519
/// overflow.
65206520
///
65216521
/// The caller has matched a pattern of the form:
6522+
/// I = cmp u (add(zext A, zext B)), V
65226523
/// I = cmp u (mul(zext A, zext B)), V
65236524
/// The function checks if this is a test for overflow and if so replaces
6524-
/// multiplication with call to 'mul.with.overflow' intrinsic.
6525+
/// the addition or multiplication with a call to the matching
/// 'uadd.with.overflow' or 'umul.with.overflow' intrinsic.
65256526
///
65266527
/// \param I Compare instruction.
6527-
/// \param MulVal Result of 'mult' instruction. It is one of the arguments of
6528+
/// \param Val Result of the 'add' or 'mul' instruction. It is one of the arguments of
65286529
/// the compare instruction. Must be of integer type.
65296530
/// \param OtherVal The other argument of compare instruction.
65306531
/// \returns Instruction which must replace the compare instruction, NULL if no
65316532
/// replacement required.
6532-
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
6533-
const APInt *OtherVal,
6534-
InstCombinerImpl &IC) {
6533+
static Instruction *processUZExtIdiom(ICmpInst &I, Value *Val,
6534+
const APInt *OtherVal,
6535+
InstCombinerImpl &IC) {
65356536
// Don't bother doing this transformation for pointers, don't do it for
65366537
// vectors.
6537-
if (!isa<IntegerType>(MulVal->getType()))
6538+
if (!isa<IntegerType>(Val->getType()))
65386539
return nullptr;
65396540

6540-
auto *MulInstr = dyn_cast<Instruction>(MulVal);
6541-
if (!MulInstr)
6541+
auto *Instr = dyn_cast<Instruction>(Val);
6542+
if (!Instr)
65426543
return nullptr;
6543-
assert(MulInstr->getOpcode() == Instruction::Mul);
65446544

6545-
auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)),
6546-
*RHS = cast<ZExtInst>(MulInstr->getOperand(1));
6545+
unsigned Opcode = Instr->getOpcode();
6546+
assert(Opcode == Instruction::Add || Opcode == Instruction::Mul);
6547+
6548+
auto *LHS = cast<ZExtInst>(Instr->getOperand(0)),
6549+
*RHS = cast<ZExtInst>(Instr->getOperand(1));
65476550
assert(LHS->getOpcode() == Instruction::ZExt);
65486551
assert(RHS->getOpcode() == Instruction::ZExt);
65496552
Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
65506553

6551-
// Calculate type and width of the result produced by mul.with.overflow.
6554+
// Calculate type and width of the result produced by add/mul.with.overflow.
65526555
Type *TyA = A->getType(), *TyB = B->getType();
65536556
unsigned WidthA = TyA->getPrimitiveSizeInBits(),
65546557
WidthB = TyB->getPrimitiveSizeInBits();
6555-
unsigned MulWidth;
6556-
Type *MulType;
6558+
unsigned ResultWidth;
6559+
Type *ResultType;
65576560
if (WidthB > WidthA) {
6558-
MulWidth = WidthB;
6559-
MulType = TyB;
6561+
ResultWidth = WidthB;
6562+
ResultType = TyB;
65606563
} else {
6561-
MulWidth = WidthA;
6562-
MulType = TyA;
6564+
ResultWidth = WidthA;
6565+
ResultType = TyA;
65636566
}
65646567

6565-
// In order to replace the original mul with a narrower mul.with.overflow,
6566-
// all uses must ignore upper bits of the product. The number of used low
6567-
// bits must be not greater than the width of mul.with.overflow.
6568-
if (MulVal->hasNUsesOrMore(2))
6569-
for (User *U : MulVal->users()) {
6568+
// In order to replace the original result with an add/mul.with.overflow
6569+
// intrinsic, all uses must ignore upper bits of the result. The number of
6570+
// used low bits must be not greater than the width of add/mul.with.overflow.
6571+
if (Val->hasNUsesOrMore(2))
6572+
for (User *U : Val->users()) {
65706573
if (U == &I)
65716574
continue;
65726575
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6573-
// Check if truncation ignores bits above MulWidth.
6576+
// Check if truncation ignores bits above ResultWidth.
65746577
unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
6575-
if (TruncWidth > MulWidth)
6578+
if (TruncWidth > ResultWidth)
65766579
return nullptr;
65776580
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
6578-
// Check if AND ignores bits above MulWidth.
6581+
// Check if AND ignores bits above ResultWidth.
65796582
if (BO->getOpcode() != Instruction::And)
65806583
return nullptr;
65816584
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
65826585
const APInt &CVal = CI->getValue();
6583-
if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
6586+
if (CVal.getBitWidth() - CVal.countl_zero() > ResultWidth)
65846587
return nullptr;
65856588
} else {
65866589
// In this case we could have the operand of the binary operation
@@ -6598,9 +6601,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
65986601
switch (I.getPredicate()) {
65996602
case ICmpInst::ICMP_UGT: {
66006603
// Recognize pattern:
6601-
// mulval = mul(zext A, zext B)
6602-
// cmp ugt mulval, max
6603-
APInt MaxVal = APInt::getMaxValue(MulWidth);
6604+
// val = add/mul(zext A, zext B)
6605+
// cmp ugt val, max
6606+
APInt MaxVal = APInt::getMaxValue(ResultWidth);
66046607
MaxVal = MaxVal.zext(OtherVal->getBitWidth());
66056608
if (MaxVal.eq(*OtherVal))
66066609
break; // Recognized
@@ -6609,9 +6612,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66096612

66106613
case ICmpInst::ICMP_ULT: {
66116614
// Recognize pattern:
6612-
// mulval = mul(zext A, zext B)
6613-
// cmp ule mulval, max + 1
6614-
APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
6615+
// val = add/mul(zext A, zext B)
6616+
// cmp ult val, max + 1
6617+
APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), ResultWidth);
66156618
if (MaxVal.eq(*OtherVal))
66166619
break; // Recognized
66176620
return nullptr;
@@ -6622,38 +6625,42 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66226625
}
66236626

66246627
InstCombiner::BuilderTy &Builder = IC.Builder;
6625-
Builder.SetInsertPoint(MulInstr);
6626-
6627-
// Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
6628-
Value *MulA = A, *MulB = B;
6629-
if (WidthA < MulWidth)
6630-
MulA = Builder.CreateZExt(A, MulType);
6631-
if (WidthB < MulWidth)
6632-
MulB = Builder.CreateZExt(B, MulType);
6633-
CallInst *Call =
6634-
Builder.CreateIntrinsic(Intrinsic::umul_with_overflow, MulType,
6635-
{MulA, MulB}, /*FMFSource=*/nullptr, "umul");
6636-
IC.addToWorklist(MulInstr);
6637-
6638-
// If there are uses of mul result other than the comparison, we know that
6628+
Builder.SetInsertPoint(Instr);
6629+
6630+
// Replace: add/mul(zext A, zext B) --> add/mul.with.overflow(A, B)
6631+
Value *ResultA = A, *ResultB = B;
6632+
if (WidthA < ResultWidth)
6633+
ResultA = Builder.CreateZExt(A, ResultType);
6634+
if (WidthB < ResultWidth)
6635+
ResultB = Builder.CreateZExt(B, ResultType);
6636+
CallInst *Call = Builder.CreateIntrinsic(
6637+
Opcode == Instruction::Add ? Intrinsic::uadd_with_overflow
6638+
: Intrinsic::umul_with_overflow,
6639+
ResultType, {ResultA, ResultB}, /*FMFSource=*/nullptr,
6640+
Intrinsic::uadd_with_overflow ? "uadd" : "umul");
6641+
IC.addToWorklist(Instr);
6642+
6643+
// If there are uses of add/mul result other than the comparison, we know that
66396644
// they are truncation or binary AND. Change them to use result of
6640-
// mul.with.overflow and adjust properly mask/size.
6641-
if (MulVal->hasNUsesOrMore(2)) {
6642-
Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
6643-
for (User *U : make_early_inc_range(MulVal->users())) {
6645+
// add/mul.with.overflow and adjust properly mask/size.
6646+
if (Val->hasNUsesOrMore(2)) {
6647+
Value *Extract = Builder.CreateExtractValue(
6648+
Call, 0, Instruction::Add ? "uadd.value" : "umul.value");
6649+
for (User *U : make_early_inc_range(Val->users())) {
66446650
if (U == &I)
66456651
continue;
66466652
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6647-
if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
6648-
IC.replaceInstUsesWith(*TI, Mul);
6653+
if (TI->getType()->getPrimitiveSizeInBits() == ResultWidth)
6654+
IC.replaceInstUsesWith(*TI, Extract);
66496655
else
6650-
TI->setOperand(0, Mul);
6656+
TI->setOperand(0, Extract);
66516657
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
66526658
assert(BO->getOpcode() == Instruction::And);
6653-
// Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
6659+
// Replace (Extract & mask) --> zext (add/mul.with.overflow &
6660+
// short_mask)
66546661
ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
6655-
APInt ShortMask = CI->getValue().trunc(MulWidth);
6656-
Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
6662+
APInt ShortMask = CI->getValue().trunc(ResultWidth);
6663+
Value *ShortAnd = Builder.CreateAnd(Extract, ShortMask);
66576664
Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
66586665
IC.replaceInstUsesWith(*BO, Zext);
66596666
} else {
@@ -7078,7 +7085,7 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
70787085
// icmp eq X, (zext (icmp ne X, 0)) --> X == 0 || X == 1
70797086
// icmp ne X, (zext (icmp ne X, 0)) --> X != 0 && X != 1
70807087
// icmp eq X, (sext (icmp ne X, 0)) --> X == 0 || X == -1
7081-
// icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X == -1
7088+
// icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X != -1
70827089
return CreateRangeCheck();
70837090
}
70847091
} else if (IsSExt ? C->isAllOnes() : C->isOne()) {
@@ -7791,10 +7798,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
77917798
}
77927799
}
77937800

7801+
// (zext X) + (zext Y) --> llvm.uadd.with.overflow.
77947802
// (zext X) * (zext Y) --> llvm.umul.with.overflow.
7795-
if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
7803+
if ((match(Op0, m_NUWAdd(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) ||
7804+
match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y))))) &&
77967805
match(Op1, m_APInt(C))) {
7797-
if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this))
7806+
if (Instruction *R = processUZExtIdiom(I, Op0, C, *this))
77987807
return R;
77997808
}
78007809

llvm/test/Transforms/InstCombine/saturating-add-sub.ll

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2352,11 +2352,7 @@ define i8 @fold_add_umax_to_usub_multiuse(i8 %a) {
23522352

23532353
define i32 @uadd_with_zext(i32 %x, i32 %y) {
23542354
; CHECK-LABEL: @uadd_with_zext(
2355-
; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
2356-
; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
2357-
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
2358-
; CHECK-NEXT: [[COND1:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
2359-
; CHECK-NEXT: [[COND:%.*]] = trunc nuw i64 [[COND1]] to i32
2355+
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
23602356
; CHECK-NEXT: ret i32 [[COND]]
23612357
;
23622358
%conv = zext i32 %x to i64
@@ -2370,13 +2366,9 @@ define i32 @uadd_with_zext(i32 %x, i32 %y) {
23702366

23712367
define i32 @uadd_with_zext_multi_use(i32 %x, i32 %y) {
23722368
; CHECK-LABEL: @uadd_with_zext_multi_use(
2373-
; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
2374-
; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
2375-
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
2376-
; CHECK-NEXT: [[TRUNCADD:%.*]] = trunc i64 [[ADD]] to i32
2369+
; CHECK-NEXT: [[TRUNCADD:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
23772370
; CHECK-NEXT: call void @usei32(i32 [[TRUNCADD]])
2378-
; CHECK-NEXT: [[COND1:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
2379-
; CHECK-NEXT: [[COND:%.*]] = trunc nuw i64 [[COND1]] to i32
2371+
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Y]])
23802372
; CHECK-NEXT: ret i32 [[COND]]
23812373
;
23822374
%conv = zext i32 %x to i64

llvm/test/Transforms/InstCombine/uadd-with-overflow.ll

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,8 @@ define { <2 x i32>, <2 x i1> } @fold_simple_splat_constant_with_or_fail(<2 x i32
150150

151151
define i32 @uadd_with_zext(i32 %x, i32 %y) {
152152
; CHECK-LABEL: @uadd_with_zext(
153-
; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
154-
; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
155-
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
156-
; CHECK-NEXT: [[CMP:%.*]] = icmp samesign ugt i64 [[ADD]], 4294967295
153+
; CHECK-NEXT: [[UADD:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
154+
; CHECK-NEXT: [[CMP:%.*]] = extractvalue { i32, i1 } [[UADD]], 1
157155
; CHECK-NEXT: [[COND:%.*]] = zext i1 [[CMP]] to i32
158156
; CHECK-NEXT: ret i32 [[COND]]
159157
;
@@ -167,10 +165,9 @@ define i32 @uadd_with_zext(i32 %x, i32 %y) {
167165

168166
define i32 @uadd_with_zext_inverse(i32 %x, i32 %y) {
169167
; CHECK-LABEL: @uadd_with_zext_inverse(
170-
; CHECK-NEXT: [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
171-
; CHECK-NEXT: [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
172-
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
173-
; CHECK-NEXT: [[CMP:%.*]] = icmp samesign ult i64 [[ADD]], 4294967296
168+
; CHECK-NEXT: [[UADD:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
169+
; CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i32, i1 } [[UADD]], 1
170+
; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true
174171
; CHECK-NEXT: [[COND:%.*]] = zext i1 [[CMP]] to i32
175172
; CHECK-NEXT: ret i32 [[COND]]
176173
;

0 commit comments

Comments
 (0)