Skip to content

[PatternMatch] Do not accept undef elements in m_AllOnes() and friends #88217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 5 additions & 30 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {

/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
/// For fixed width vector constants, poison elements are ignored.
template <typename Predicate, typename ConstantVal>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
Expand All @@ -364,19 +364,19 @@ struct cstval_pred_ty : public Predicate {
// Non-splat vector constant: check each element for a match.
unsigned NumElts = FVTy->getNumElements();
assert(NumElts != 0 && "Constant vector with no elements?");
bool HasNonUndefElements = false;
bool HasNonPoisonElements = false;
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
if (isa<UndefValue>(Elt))
if (isa<PoisonValue>(Elt))
continue;
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
return false;
HasNonUndefElements = true;
HasNonPoisonElements = true;
}
return HasNonUndefElements;
return HasNonPoisonElements;
}
}
return false;
Expand Down Expand Up @@ -2587,31 +2587,6 @@ m_Not(const ValTy &V) {
return m_c_Xor(m_AllOnes(), V);
}

template <typename ValTy> struct NotForbidUndef_match {
ValTy Val;
NotForbidUndef_match(const ValTy &V) : Val(V) {}

template <typename OpTy> bool match(OpTy *V) {
// We do not use m_c_Xor because that could match an arbitrary APInt that is
// not -1 as C and then fail to match the other operand if it is -1.
// This code should still work even when both operands are constants.
Value *X;
const APInt *C;
if (m_Xor(m_Value(X), m_APIntForbidUndef(C)).match(V) && C->isAllOnes())
return Val.match(X);
if (m_Xor(m_APIntForbidUndef(C), m_Value(X)).match(V) && C->isAllOnes())
return Val.match(X);
return false;
}
};

/// Matches a bitwise 'not' as 'xor V, -1' or 'xor -1, V'. For vectors, the
/// constant value must be composed of only -1 scalar elements.
template <typename ValTy>
inline NotForbidUndef_match<ValTy> m_NotForbidUndef(const ValTy &V) {
return NotForbidUndef_match<ValTy>(V);
}

/// Matches an SMin with LHS and RHS in either order.
template <typename LHS, typename RHS>
inline MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>
Expand Down
23 changes: 10 additions & 13 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,7 @@ static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,

// -1 >>a X --> -1
// (-1 << X) a>> X --> -1
// Do not return Op0 because it may contain undef elements if it's a vector.
// We could return the original -1 constant to preserve poison elements.
if (match(Op0, m_AllOnes()) ||
match(Op0, m_Shl(m_AllOnes(), m_Specific(Op1))))
return Constant::getAllOnesValue(Op0->getType());
Expand Down Expand Up @@ -2281,7 +2281,7 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// (B ^ ~A) | (A & B) --> B ^ ~A
// (~A ^ B) | (B & A) --> ~A ^ B
// (B ^ ~A) | (B & A) --> B ^ ~A
if (match(X, m_c_Xor(m_NotForbidUndef(m_Value(A)), m_Value(B))) &&
if (match(X, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return X;

Expand All @@ -2298,31 +2298,29 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// (B & ~A) | ~(A | B) --> ~A
// (B & ~A) | ~(B | A) --> ~A
Value *NotA;
if (match(X,
m_c_And(m_CombineAnd(m_Value(NotA), m_NotForbidUndef(m_Value(A))),
m_Value(B))) &&
if (match(X, m_c_And(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))),
m_Value(B))) &&
match(Y, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
return NotA;
// The same is true of Logical And
// TODO: This could share the logic of the version above if there was a
// version of LogicalAnd that allowed more than just i1 types.
if (match(X, m_c_LogicalAnd(
m_CombineAnd(m_Value(NotA), m_NotForbidUndef(m_Value(A))),
m_Value(B))) &&
if (match(X, m_c_LogicalAnd(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))),
m_Value(B))) &&
match(Y, m_Not(m_c_LogicalOr(m_Specific(A), m_Specific(B)))))
return NotA;

// ~(A ^ B) | (A & B) --> ~(A ^ B)
// ~(A ^ B) | (B & A) --> ~(A ^ B)
Value *NotAB;
if (match(X, m_CombineAnd(m_NotForbidUndef(m_Xor(m_Value(A), m_Value(B))),
if (match(X, m_CombineAnd(m_Not(m_Xor(m_Value(A), m_Value(B))),
m_Value(NotAB))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return NotAB;

// ~(A & B) | (A ^ B) --> ~(A & B)
// ~(A & B) | (B ^ A) --> ~(A & B)
if (match(X, m_CombineAnd(m_NotForbidUndef(m_And(m_Value(A), m_Value(B))),
if (match(X, m_CombineAnd(m_Not(m_And(m_Value(A), m_Value(B))),
m_Value(NotAB))) &&
match(Y, m_c_Xor(m_Specific(A), m_Specific(B))))
return NotAB;
Expand Down Expand Up @@ -2552,9 +2550,8 @@ static Value *simplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// The 'not' op must contain a complete -1 operand (poison elements for
// vectors are OK: the result may be at least as poisonous as the inputs).
Value *NotA;
if (match(X,
m_c_Or(m_CombineAnd(m_NotForbidUndef(m_Value(A)), m_Value(NotA)),
m_Value(B))) &&
if (match(X, m_c_Or(m_CombineAnd(m_Not(m_Value(A)), m_Value(NotA)),
m_Value(B))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return NotA;

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ bool Constant::isElementWiseEqual(Value *Y) const {
Constant *C0 = ConstantExpr::getBitCast(const_cast<Constant *>(this), IntTy);
Constant *C1 = ConstantExpr::getBitCast(cast<Constant>(Y), IntTy);
Constant *CmpEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1);
return isa<UndefValue>(CmpEq) || match(CmpEq, m_One());
return isa<PoisonValue>(CmpEq) || match(CmpEq, m_One());
}

static bool
Expand Down
12 changes: 3 additions & 9 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2538,6 +2538,8 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
}
}

// and(shl(zext(X), Y), SignMask) -> and(sext(X), SignMask)
// where Y is a valid shift amount.
if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))),
m_SignMask())) &&
match(Y, m_SpecificInt_ICMP(
Expand All @@ -2546,15 +2548,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
Ty->getScalarSizeInBits() -
X->getType()->getScalarSizeInBits())))) {
auto *SExt = Builder.CreateSExt(X, Ty, X->getName() + ".signext");
auto *SanitizedSignMask = cast<Constant>(Op1);
// We must be careful with the undef elements of the sign bit mask, however:
// the mask elt can be undef iff the shift amount for that lane was undef,
// otherwise we need to sanitize undef masks to zero.
SanitizedSignMask = Constant::replaceUndefsWith(
SanitizedSignMask, ConstantInt::getNullValue(Ty->getScalarType()));
SanitizedSignMask =
Constant::mergeUndefsWith(SanitizedSignMask, cast<Constant>(Y));
return BinaryOperator::CreateAnd(SExt, SanitizedSignMask);
return BinaryOperator::CreateAnd(SExt, Op1);
}

if (Instruction *Z = narrowMaskedBinOp(I))
Expand Down
30 changes: 15 additions & 15 deletions llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2032,23 +2032,23 @@ define <4 x i64> @avx2_psrlv_q_256_allbig(<4 x i64> %v) {
ret <4 x i64> %1
}

; The shift amount is 0 (the undef lane could be 0), so we return the unshifted input.
; The shift amount is 0 (the poison lane could be 0), so we return the unshifted input.

define <2 x i64> @avx2_psrlv_q_128_undef(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_128_undef(
define <2 x i64> @avx2_psrlv_q_128_poison(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_128_poison(
; CHECK-NEXT: ret <2 x i64> [[V:%.*]]
;
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 undef, i64 1
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 poison, i64 1
%2 = tail call <2 x i64> @llvm.x86.avx2.psrlv.q(<2 x i64> %v, <2 x i64> %1)
ret <2 x i64> %2
}

define <4 x i64> @avx2_psrlv_q_256_undef(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_256_undef(
; CHECK-NEXT: [[TMP1:%.*]] = lshr <4 x i64> [[V:%.*]], <i64 undef, i64 8, i64 16, i64 31>
define <4 x i64> @avx2_psrlv_q_256_poison(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_256_poison(
; CHECK-NEXT: [[TMP1:%.*]] = lshr <4 x i64> [[V:%.*]], <i64 poison, i64 8, i64 16, i64 31>
; CHECK-NEXT: ret <4 x i64> [[TMP1]]
;
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 undef, i64 0
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 poison, i64 0
%2 = tail call <4 x i64> @llvm.x86.avx2.psrlv.q.256(<4 x i64> %v, <4 x i64> %1)
ret <4 x i64> %2
}
Expand Down Expand Up @@ -2435,21 +2435,21 @@ define <4 x i64> @avx2_psllv_q_256_allbig(<4 x i64> %v) {

; The shift amount is 0 (the poison lane could be 0), so we return the unshifted input.

define <2 x i64> @avx2_psllv_q_128_undef(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_128_undef(
define <2 x i64> @avx2_psllv_q_128_poison(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_128_poison(
; CHECK-NEXT: ret <2 x i64> [[V:%.*]]
;
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 undef, i64 1
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 poison, i64 1
%2 = tail call <2 x i64> @llvm.x86.avx2.psllv.q(<2 x i64> %v, <2 x i64> %1)
ret <2 x i64> %2
}

define <4 x i64> @avx2_psllv_q_256_undef(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_256_undef(
; CHECK-NEXT: [[TMP1:%.*]] = shl <4 x i64> [[V:%.*]], <i64 undef, i64 8, i64 16, i64 31>
define <4 x i64> @avx2_psllv_q_256_poison(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_256_poison(
; CHECK-NEXT: [[TMP1:%.*]] = shl <4 x i64> [[V:%.*]], <i64 poison, i64 8, i64 16, i64 31>
; CHECK-NEXT: ret <4 x i64> [[TMP1]]
;
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 undef, i64 0
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 poison, i64 0
%2 = tail call <4 x i64> @llvm.x86.avx2.psllv.q.256(<4 x i64> %v, <4 x i64> %1)
ret <4 x i64> %2
}
Expand Down
16 changes: 8 additions & 8 deletions llvm/test/Transforms/InstCombine/abs-1.ll
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ define <2 x i8> @abs_canonical_2(<2 x i8> %x) {
ret <2 x i8> %abs
}

; Even if a constant has undef elements.
; Even if a constant has poison elements.

define <2 x i8> @abs_canonical_2_vec_undef_elts(<2 x i8> %x) {
; CHECK-LABEL: @abs_canonical_2_vec_undef_elts(
define <2 x i8> @abs_canonical_2_vec_poison_elts(<2 x i8> %x) {
; CHECK-LABEL: @abs_canonical_2_vec_poison_elts(
; CHECK-NEXT: [[ABS:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[X:%.*]], i1 false)
; CHECK-NEXT: ret <2 x i8> [[ABS]]
;
%cmp = icmp sgt <2 x i8> %x, <i8 undef, i8 -1>
%cmp = icmp sgt <2 x i8> %x, <i8 poison, i8 -1>
%neg = sub <2 x i8> zeroinitializer, %x
%abs = select <2 x i1> %cmp, <2 x i8> %x, <2 x i8> %neg
ret <2 x i8> %abs
Expand Down Expand Up @@ -208,15 +208,15 @@ define <2 x i8> @nabs_canonical_2(<2 x i8> %x) {
ret <2 x i8> %abs
}

; Even if a constant has undef elements.
; Even if a constant has poison elements.

define <2 x i8> @nabs_canonical_2_vec_undef_elts(<2 x i8> %x) {
; CHECK-LABEL: @nabs_canonical_2_vec_undef_elts(
define <2 x i8> @nabs_canonical_2_vec_poison_elts(<2 x i8> %x) {
; CHECK-LABEL: @nabs_canonical_2_vec_poison_elts(
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[X:%.*]], i1 false)
; CHECK-NEXT: [[ABS:%.*]] = sub <2 x i8> zeroinitializer, [[TMP1]]
; CHECK-NEXT: ret <2 x i8> [[ABS]]
;
%cmp = icmp sgt <2 x i8> %x, <i8 -1, i8 undef>
%cmp = icmp sgt <2 x i8> %x, <i8 -1, i8 poison>
%neg = sub <2 x i8> zeroinitializer, %x
%abs = select <2 x i1> %cmp, <2 x i8> %neg, <2 x i8> %x
ret <2 x i8> %abs
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/Transforms/InstCombine/add-mask-neg.ll
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,16 @@ define <2 x i32> @dec_mask_neg_v2i32(<2 x i32> %X) {
ret <2 x i32> %dec
}

define <2 x i32> @dec_mask_neg_v2i32_undef(<2 x i32> %X) {
; CHECK-LABEL: @dec_mask_neg_v2i32_undef(
define <2 x i32> @dec_mask_neg_v2i32_poison(<2 x i32> %X) {
; CHECK-LABEL: @dec_mask_neg_v2i32_poison(
; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i32> [[X:%.*]], <i32 -1, i32 -1>
; CHECK-NEXT: [[TMP2:%.*]] = xor <2 x i32> [[X]], <i32 -1, i32 -1>
; CHECK-NEXT: [[DEC:%.*]] = and <2 x i32> [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret <2 x i32> [[DEC]]
;
%neg = sub <2 x i32> zeroinitializer, %X
%mask = and <2 x i32> %neg, %X
%dec = add <2 x i32> %mask, <i32 -1, i32 undef>
%dec = add <2 x i32> %mask, <i32 -1, i32 poison>
ret <2 x i32> %dec
}

Expand Down
28 changes: 14 additions & 14 deletions llvm/test/Transforms/InstCombine/add.ll
Original file line number Diff line number Diff line change
Expand Up @@ -150,24 +150,24 @@ define i32 @test5_add_nsw(i32 %A, i32 %B) {
ret i32 %D
}

define <2 x i8> @neg_op0_vec_undef_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_op0_vec_undef_elt(
define <2 x i8> @neg_op0_vec_poison_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_op0_vec_poison_elt(
; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%nega = sub <2 x i8> <i8 0, i8 undef>, %a
%nega = sub <2 x i8> <i8 0, i8 poison>, %a
%r = add <2 x i8> %nega, %b
ret <2 x i8> %r
}

define <2 x i8> @neg_neg_vec_undef_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_neg_vec_undef_elt(
define <2 x i8> @neg_neg_vec_poison_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_neg_vec_poison_elt(
; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i8> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> zeroinitializer, [[TMP1]]
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%nega = sub <2 x i8> <i8 undef, i8 0>, %a
%negb = sub <2 x i8> <i8 undef, i8 0>, %b
%nega = sub <2 x i8> <i8 poison, i8 0>, %a
%negb = sub <2 x i8> <i8 poison, i8 0>, %b
%r = add <2 x i8> %nega, %negb
ret <2 x i8> %r
}
Expand Down Expand Up @@ -1196,14 +1196,14 @@ define <2 x i32> @test44_vec_non_matching(<2 x i32> %A) {
ret <2 x i32> %C
}

define <2 x i32> @test44_vec_undef(<2 x i32> %A) {
; CHECK-LABEL: @test44_vec_undef(
; CHECK-NEXT: [[B:%.*]] = or <2 x i32> [[A:%.*]], <i32 123, i32 undef>
; CHECK-NEXT: [[C:%.*]] = add <2 x i32> [[B]], <i32 -123, i32 undef>
define <2 x i32> @test44_vec_poison(<2 x i32> %A) {
; CHECK-LABEL: @test44_vec_poison(
; CHECK-NEXT: [[B:%.*]] = or <2 x i32> [[A:%.*]], <i32 123, i32 poison>
; CHECK-NEXT: [[C:%.*]] = add nsw <2 x i32> [[B]], <i32 -123, i32 poison>
; CHECK-NEXT: ret <2 x i32> [[C]]
;
%B = or <2 x i32> %A, <i32 123, i32 undef>
%C = add <2 x i32> %B, <i32 -123, i32 undef>
%B = or <2 x i32> %A, <i32 123, i32 poison>
%C = add <2 x i32> %B, <i32 -123, i32 poison>
ret <2 x i32> %C
}

Expand Down Expand Up @@ -2983,7 +2983,7 @@ define i8 @signum_i8_i8_use3(i8 %x) {
ret i8 %r
}

; poison/undef is ok to propagate in shift amount
; poison is ok to propagate in shift amount
; complexity canonicalization guarantees that shift is op0 of add

define <2 x i5> @signum_v2i5_v2i5(<2 x i5> %x) {
Expand Down
Loading