-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[X86] combinePMULH - combine mulhu + srl #132548
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-x86 Author: Abhishek Kaushik (abhishek-kaushik22) Changes: Fixes #132166. Full diff: https://github.com/llvm/llvm-project/pull/132548.diff 2 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 02398923ebc90..ec0af8d53b76e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54021,7 +54021,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
}
// Try to form a MULHU or MULHS node by looking for
-// (trunc (srl (mul ext, ext), 16))
+// (trunc (srl (mul ext, ext), >= 16))
// TODO: This is X86 specific because we want to be able to handle wide types
// before type legalization. But we can only do it if the vector will be
// legalized via widening/splitting. Type legalization can't handle promotion
@@ -54046,10 +54046,16 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
// First instruction should be a right shift by 16 of a multiply.
SDValue LHS, RHS;
+ APInt ShiftAmt;
if (!sd_match(Src,
- m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_SpecificInt(16))))
+ m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt))))
+ return SDValue();
+
+ if (ShiftAmt.ult(16))
return SDValue();
+ APInt AdditionalShift = (ShiftAmt - 16).trunc(16);
+
// Count leading sign/zero bits on both inputs - if there are enough then
// truncation back to vXi16 will be cheap - either as a pack/shuffle
// sequence or using AVX512 truncations. If the inputs are sext/zext then the
@@ -54087,7 +54093,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
InVT.getSizeInBits() / 16);
SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS),
DAG.getBitcast(BCVT, RHS));
- return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+ Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res));
+ return DAG.getNode(ISD::SRL, DL, VT, Res,
+ DAG.getConstant(AdditionalShift, DL, VT));
}
// Truncate back to source type.
@@ -54095,7 +54103,9 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
RHS = DAG.getNode(ISD::TRUNCATE, DL, VT, RHS);
unsigned Opc = IsSigned ? ISD::MULHS : ISD::MULHU;
- return DAG.getNode(Opc, DL, VT, LHS, RHS);
+ SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS);
+ return DAG.getNode(ISD::SRL, DL, VT, Res,
+ DAG.getConstant(AdditionalShift, DL, VT));
}
// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
diff --git a/llvm/test/CodeGen/X86/pmulh.ll b/llvm/test/CodeGen/X86/pmulh.ll
index 300da68d9a3b3..8ecc3c1575367 100644
--- a/llvm/test/CodeGen/X86/pmulh.ll
+++ b/llvm/test/CodeGen/X86/pmulh.ll
@@ -2166,3 +2166,23 @@ define <8 x i16> @sse2_pmulhu_w_const(<8 x i16> %a0, <8 x i16> %a1) {
}
declare <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16>, <8 x i16>)
+define <8 x i16> @mul_and_shift17(<8 x i16> %a, <8 x i16> %b) {
+; SSE-LABEL: mul_and_shift17:
+; SSE: # %bb.0:
+; SSE-NEXT: pmulhuw %xmm1, %xmm0
+; SSE-NEXT: psrlw $1, %xmm0
+; SSE-NEXT: retq
+;
+; AVX-LABEL: mul_and_shift17:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmulhuw %xmm1, %xmm0, %xmm0
+; AVX-NEXT: vpsrlw $1, %xmm0, %xmm0
+; AVX-NEXT: retq
+ %a.ext = zext <8 x i16> %a to <8 x i32>
+ %b.ext = zext <8 x i16> %b to <8 x i32>
+ %mul = mul <8 x i32> %a.ext, %b.ext
+ %shift = lshr <8 x i32> %mul, splat(i32 17)
+ %trunc = trunc <8 x i32> %shift to <8 x i16>
+ ret <8 x i16> %trunc
+}
+
|
m_Srl(m_Mul(m_Value(LHS), m_Value(RHS)), m_ConstInt(ShiftAmt)))) | ||
return SDValue(); | ||
|
||
if (ShiftAmt.ult(16)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe worth checking for ShiftAmt.uge(InVT.getScalarSizeInBits()) as well?
return DAG.getNode(Opc, DL, VT, LHS, RHS); | ||
SDValue Res = DAG.getNode(Opc, DL, VT, LHS, RHS); | ||
return DAG.getNode(ISD::SRL, DL, VT, Res, | ||
DAG.getConstant(AdditionalShift, DL, VT)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the IsSigned case use ISD::SRA instead of ISD::SRL?
Also, use getShiftAmountConstant rather than getConstant to build the shift amount.
return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res)); | ||
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res)); | ||
return DAG.getNode(ISD::SRL, DL, VT, Res, | ||
DAG.getConstant(AdditionalShift, DL, VT)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use getShiftAmountConstant rather than getConstant to build the shift amount.
return SDValue(); | ||
|
||
APInt AdditionalShift = (ShiftAmt - 16).trunc(16); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to do some testing to see what happens with MULHS - probably limit AdditionalShift != 0 to just the IsUnsigned case for starters?
@@ -54087,15 +54093,19 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, | |||
InVT.getSizeInBits() / 16); | |||
SDValue Res = DAG.getNode(ISD::MULHU, DL, BCVT, DAG.getBitcast(BCVT, LHS), | |||
DAG.getBitcast(BCVT, RHS)); | |||
return DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res)); | |||
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, DAG.getBitcast(InVT, Res)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this path tested?
Fixes #132166