Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 107 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
ISD::STORE, ISD::BUILD_VECTOR});
setTargetDAGCombine(ISD::SMIN);
setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::LOAD);

Expand Down Expand Up @@ -2392,6 +2393,15 @@ static bool isIntImmediate(const SDNode *N, uint64_t &Imm) {
return false;
}

// Returns true when Opcode is a target-specific (AArch64ISD) node that should
// be handled like a two-operand vector operation. SQDMULH is currently the
// only such opcode.
bool isVectorizedBinOp(unsigned Opcode) {
  return Opcode == AArch64ISD::SQDMULH;
}

// isOpcWithIntImmediate - This method tests to see if the node is a specific
// opcode and that it has an immediate integer right operand.
// If so Imm will receive the value.
Expand Down Expand Up @@ -20126,8 +20136,9 @@ static SDValue performConcatVectorsCombine(SDNode *N,
// size, combine into a binop of two concats of the source vectors. eg:
// concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c), concat(b, d))
if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
DAG.getTargetLoweringInfo().isBinOp(N0Opc) && N0->hasOneUse() &&
N1->hasOneUse()) {
(DAG.getTargetLoweringInfo().isBinOp(N0Opc) ||
isVectorizedBinOp(N0Opc)) &&
N0->hasOneUse() && N1->hasOneUse()) {
SDValue N00 = N0->getOperand(0);
SDValue N01 = N0->getOperand(1);
SDValue N10 = N1->getOperand(0);
Expand Down Expand Up @@ -20986,6 +20997,98 @@ static SDValue performBuildVectorCombine(SDNode *N,
return SDValue();
}

// A special combine for the sqdmulh family of instructions:
//
//   smin(sra(mul(sext(v0), sext(v1)), N - 1), 2^(N-1) - 1)
//     -> sext(sqdmulh(v0, v1))
//
// where N is the element width of v0/v1 (16 or 32). SQDMULH returns the
// saturated high half of 2 * v0 * v1, i.e. (v0 * v1) >> (N - 1) clamped to the
// largest signed N-bit value, which is exactly what the pattern computes as
// long as the widened multiply cannot wrap.
//
// Returns the replacement value, or an empty SDValue if N does not match.
static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
  if (N->getOpcode() != ISD::SMIN)
    return SDValue();

  EVT DestVT = N->getValueType(0);
  if (!DestVT.isVector() || DestVT.isScalableVector() ||
      DestVT.getScalarSizeInBits() > 64)
    return SDValue();

  ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
  if (!Clamp)
    return SDValue();

  // The clamp constant selects the narrow element type being saturated to.
  MVT ScalarType;
  unsigned ElemBits = 0;
  switch (Clamp->getSExtValue()) {
  case (1ULL << 15) - 1:
    ScalarType = MVT::i16;
    ElemBits = 16;
    break;
  case (1ULL << 31) - 1:
    ScalarType = MVT::i32;
    ElemBits = 32;
    break;
  default:
    return SDValue();
  }

  // The product of two N-bit signed values needs 2*N bits. If the multiply is
  // performed in a narrower type (e.g. i32 sign-extended to i48) it may wrap,
  // and the pattern is no longer equivalent to SQDMULH.
  if (DestVT.getScalarSizeInBits() < 2 * ElemBits)
    return SDValue();

  SDValue Sra = N->getOperand(0);
  if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
    return SDValue();

  // Compare the full-width constant so that an out-of-range shift amount can
  // never alias N - 1 through integer truncation.
  ConstantSDNode *ShiftC = isConstOrConstSplat(Sra.getOperand(1));
  if (!ShiftC || ShiftC->getAPIntValue() != ElemBits - 1)
    return SDValue();

  SDValue Mul = Sra.getOperand(0);
  if (Mul.getOpcode() != ISD::MUL)
    return SDValue();

  SDValue SExt0 = Mul.getOperand(0);
  SDValue SExt1 = Mul.getOperand(1);
  if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
      SExt1.getOpcode() != ISD::SIGN_EXTEND)
    return SDValue();

  SDValue V0 = SExt0.getOperand(0);
  SDValue V1 = SExt1.getOperand(0);
  EVT SrcVT = V0.getValueType();
  if (SrcVT != V1.getValueType() || SrcVT.getScalarType() != ScalarType)
    return SDValue();

  // Only whole 64-bit and 128-bit source vectors are folded. Narrower sources
  // (e.g. v2i16) must not be handled by sign-extending the lanes to a wider
  // element type: SQDMULH on the wider lanes returns the high half of a
  // product that is twice as wide, which is a different value from
  // (v0 * v1) >> (N - 1). This also rejects single-element and
  // non-power-of-two vectors.
  const uint64_t SrcBits = SrcVT.getFixedSizeInBits();
  if (SrcBits != 64 && SrcBits != 128)
    return SDValue();

  SDLoc DL(N);
  SDValue SQDMULH = DAG.getNode(AArch64ISD::SQDMULH, DL, SrcVT, V0, V1);

  // The SMIN produced a value of the wide type; re-extend the narrow result.
  return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
}

// Target DAG combine entry point for ISD::SMIN nodes. The only fold attempted
// is the SQDMULH pattern; an empty SDValue is returned when it does not apply.
static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
  return trySQDMULHCombine(N, DAG);
}

static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
Expand Down Expand Up @@ -26737,6 +26840,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performAddSubCombine(N, DCI);
case ISD::BUILD_VECTOR:
return performBuildVectorCombine(N, DCI, DAG);
case ISD::SMIN:
return performSMINCombine(N, DAG);
case ISD::TRUNCATE:
return performTruncateCombine(N, DAG, DCI);
case AArch64ISD::ANDS:
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,7 @@ def AArch64smull : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull,
[SDNPCommutative]>;
def AArch64umull : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull,
[SDNPCommutative]>;
// Selection DAG node for AArch64ISD::SQDMULH, sharing the SDT_AArch64mull
// type profile with the SMULL/UMULL nodes.
// NOTE(review): unlike AArch64smull/AArch64umull above, this node is not
// marked SDNPCommutative -- confirm whether that is intentional.
def AArch64sqdmulh : SDNode<"AArch64ISD::SQDMULH", SDT_AArch64mull>;

// Reciprocal estimates and steps.
def AArch64frecpe : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>;
Expand Down Expand Up @@ -9439,6 +9440,15 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
(EXTRACT_SUBREG V128:$Rm, dsub)),
(UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;

// Select the AArch64sqdmulh node to the vector SQDMULH instruction for each
// supported arrangement: v4i16/v2i32 in 64-bit registers and v8i16/v4i32 in
// 128-bit registers.
def : Pat<(v4i16 (AArch64sqdmulh (v4i16 V64:$Rn), (v4i16 V64:$Rm))),
(SQDMULHv4i16 V64:$Rn, V64:$Rm)>;
def : Pat<(v2i32 (AArch64sqdmulh (v2i32 V64:$Rn), (v2i32 V64:$Rm))),
(SQDMULHv2i32 V64:$Rn, V64:$Rm)>;
def : Pat<(v8i16 (AArch64sqdmulh (v8i16 V128:$Rn), (v8i16 V128:$Rm))),
(SQDMULHv8i16 V128:$Rn, V128:$Rm)>;
def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
(SQDMULHv4i32 V128:$Rn, V128:$Rm)>;

// Conversions within AdvSIMD types in the same register size are free.
// But because we need a consistent lane ordering, in big endian many
// conversions require one or more REV instructions.
Expand Down
223 changes: 223 additions & 0 deletions llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s


; <2 x i16> inputs (narrower than a 64-bit register): sext to i32, multiply,
; ashr by 15, smin with 32767, truncate back to i16.
; NOTE(review): the expected sqdmulh works on 32-bit (.2s) lanes of the
; sign-extended inputs, so it yields the high 32 bits of the doubled product
; rather than (a*b)>>15 -- confirm this expectation is really correct.
; NOTE(review): the intrinsic is spelled llvm.smin.v4i32 but is called with
; <2 x i32> operands -- presumably the suffix should be v2i32.
define <2 x i16> @saturating_2xi16(<2 x i16> %a, <2 x i16> %b) {
; CHECK-LABEL: saturating_2xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: shl v0.2s, v0.2s, #16
; CHECK-NEXT: shl v1.2s, v1.2s, #16
; CHECK-NEXT: sshr v0.2s, v0.2s, #16
; CHECK-NEXT: sshr v1.2s, v1.2s, #16
; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
; CHECK-NEXT: ret
%as = sext <2 x i16> %a to <2 x i32>
%bs = sext <2 x i16> %b to <2 x i32>
%m = mul <2 x i32> %bs, %as
%sh = ashr <2 x i32> %m, splat (i32 15)
%ma = tail call <2 x i32> @llvm.smin.v4i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
%t = trunc <2 x i32> %ma to <2 x i16>
ret <2 x i16> %t
}

; Baseline 64-bit case: the i16 saturating-multiply idiom on <4 x i16> is
; expected to become a single sqdmulh on .4h lanes.
define <4 x i16> @saturating_4xi16(<4 x i16> %a, <4 x i16> %b) {
; CHECK-LABEL: saturating_4xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v0.4h, v1.4h, v0.4h
; CHECK-NEXT: ret
%as = sext <4 x i16> %a to <4 x i32>
%bs = sext <4 x i16> %b to <4 x i32>
%m = mul <4 x i32> %bs, %as
%sh = ashr <4 x i32> %m, splat (i32 15)
%ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
%t = trunc <4 x i32> %ma to <4 x i16>
ret <4 x i16> %t
}

; 128-bit case: the i16 idiom on <8 x i16> is expected to become a single
; sqdmulh on .8h lanes.
define <8 x i16> @saturating_8xi16(<8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: saturating_8xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v0.8h, v1.8h, v0.8h
; CHECK-NEXT: ret
%as = sext <8 x i16> %a to <8 x i32>
%bs = sext <8 x i16> %b to <8 x i32>
%m = mul <8 x i32> %bs, %as
%sh = ashr <8 x i32> %m, splat (i32 15)
%ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 32767))
%t = trunc <8 x i32> %ma to <8 x i16>
ret <8 x i16> %t
}

; i32 variant of the idiom (ashr by 31, smin with 2147483647) on a 64-bit
; vector; expected to become a single sqdmulh on .2s lanes.
; NOTE(review): the intrinsic is spelled llvm.smin.v8i64 but is called with
; <2 x i64> operands -- presumably the suffix should be v2i64.
define <2 x i32> @saturating_2xi32(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: saturating_2xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
; CHECK-NEXT: ret
%as = sext <2 x i32> %a to <2 x i64>
%bs = sext <2 x i32> %b to <2 x i64>
%m = mul <2 x i64> %bs, %as
%sh = ashr <2 x i64> %m, splat (i64 31)
%ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
%t = trunc <2 x i64> %ma to <2 x i32>
ret <2 x i32> %t
}

; i32 variant of the idiom on a 128-bit vector; expected to become a single
; sqdmulh on .4s lanes.
define <4 x i32> @saturating_4xi32(<4 x i32> %a, <4 x i32> %b) {
; CHECK-LABEL: saturating_4xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
; CHECK-NEXT: ret
%as = sext <4 x i32> %a to <4 x i64>
%bs = sext <4 x i32> %b to <4 x i64>
%m = mul <4 x i64> %bs, %as
%sh = ashr <4 x i64> %m, splat (i64 31)
%ma = tail call <4 x i64> @llvm.smin.v4i64(<4 x i64> %sh, <4 x i64> splat (i64 2147483647))
%t = trunc <4 x i64> %ma to <4 x i32>
ret <4 x i32> %t
}

; 256-bit input (<8 x i32>), wider than one vector register; the expected
; output is two sqdmulh instructions on .4s lanes, one per half.
define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
; CHECK-LABEL: saturating_8xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v1.4s, v3.4s, v1.4s
; CHECK-NEXT: sqdmulh v0.4s, v2.4s, v0.4s
; CHECK-NEXT: ret
%as = sext <8 x i32> %a to <8 x i64>
%bs = sext <8 x i32> %b to <8 x i64>
%m = mul <8 x i64> %bs, %as
%sh = ashr <8 x i64> %m, splat (i64 31)
%ma = tail call <8 x i64> @llvm.smin.v8i64(<8 x i64> %sh, <8 x i64> splat (i64 2147483647))
%t = trunc <8 x i64> %ma to <8 x i32>
ret <8 x i32> %t
}

; Same i32 idiom but without the trailing trunc: the wide <2 x i64> smin result
; is returned directly, so the expected output sign-extends the sqdmulh result
; with sshll.
; NOTE(review): the intrinsic is spelled llvm.smin.v8i64 but is called with
; <2 x i64> operands -- presumably the suffix should be v2i64.
define <2 x i64> @saturating_2xi32_2xi64(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: saturating_2xi32_2xi64:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
; CHECK-NEXT: sshll v0.2d, v0.2s, #0
; CHECK-NEXT: ret
%as = sext <2 x i32> %a to <2 x i64>
%bs = sext <2 x i32> %b to <2 x i64>
%m = mul <2 x i64> %bs, %as
%sh = ashr <2 x i64> %m, splat (i64 31)
%ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
ret <2 x i64> %ma
}

; Non-power-of-two element count (<6 x i16>): the expected output uses sqdmulh
; for the low four lanes only and keeps the smull2/sshr/smin/xtn2 sequence for
; the remaining lanes.
define <6 x i16> @saturating_6xi16(<6 x i16> %a, <6 x i16> %b) {
; CHECK-LABEL: saturating_6xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: smull2 v3.4s, v1.8h, v0.8h
; CHECK-NEXT: movi v2.4s, #127, msl #8
; CHECK-NEXT: sqdmulh v0.4h, v1.4h, v0.4h
; CHECK-NEXT: sshr v3.4s, v3.4s, #15
; CHECK-NEXT: smin v2.4s, v3.4s, v2.4s
; CHECK-NEXT: xtn2 v0.8h, v2.4s
; CHECK-NEXT: ret
%as = sext <6 x i16> %a to <6 x i32>
%bs = sext <6 x i16> %b to <6 x i32>
%m = mul <6 x i32> %bs, %as
%sh = ashr <6 x i32> %m, splat (i32 15)
%ma = tail call <6 x i32> @llvm.smin.v6i32(<6 x i32> %sh, <6 x i32> splat (i32 32767))
%t = trunc <6 x i32> %ma to <6 x i16>
ret <6 x i16> %t
}

; Negative test: the smin constant is 42 rather than 32767, so no sqdmulh is
; expected and the smull/sshr/smin/xtn sequence remains.
define <4 x i16> @unsupported_saturation_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
; CHECK-LABEL: unsupported_saturation_value_v4i16:
; CHECK: // %bb.0:
; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
; CHECK-NEXT: movi v1.4s, #42
; CHECK-NEXT: sshr v0.4s, v0.4s, #15
; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
; CHECK-NEXT: xtn v0.4h, v0.4s
; CHECK-NEXT: ret
%as = sext <4 x i16> %a to <4 x i32>
%bs = sext <4 x i16> %b to <4 x i32>
%m = mul <4 x i32> %bs, %as
%sh = ashr <4 x i32> %m, splat (i32 15)
%ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 42))
%t = trunc <4 x i32> %ma to <4 x i16>
ret <4 x i16> %t
}

; Negative test: the shift amount is 3 rather than 15, so no sqdmulh is
; expected and the smull/sshr/smin/xtn sequence remains.
define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
; CHECK-LABEL: unsupported_shift_value_v4i16:
; CHECK: // %bb.0:
; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
; CHECK-NEXT: movi v1.4s, #127, msl #8
; CHECK-NEXT: sshr v0.4s, v0.4s, #3
; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
; CHECK-NEXT: xtn v0.4h, v0.4s
; CHECK-NEXT: ret
%as = sext <4 x i16> %a to <4 x i32>
%bs = sext <4 x i16> %b to <4 x i32>
%m = mul <4 x i32> %bs, %as
%sh = ashr <4 x i32> %m, splat (i32 3)
%ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
%t = trunc <4 x i32> %ma to <4 x i16>
ret <4 x i16> %t
}

; The i16 idiom performed in an i48 intermediate type on <2 x i16> inputs.
; NOTE(review): as with the other <2 x i16> case, the expected sqdmulh works on
; 32-bit (.2s) lanes of sign-extended 16-bit values, which does not look
; equivalent to (a*b)>>15 -- confirm this expectation is really correct.
; NOTE(review): the intrinsic is spelled llvm.smin.v4i32 but is called with
; <2 x i48> operands -- presumably the suffix should be v2i48.
define <2 x i16> @extend_to_illegal_type(<2 x i16> %a, <2 x i16> %b) {
; CHECK-LABEL: extend_to_illegal_type:
; CHECK: // %bb.0:
; CHECK-NEXT: shl v0.2s, v0.2s, #16
; CHECK-NEXT: shl v1.2s, v1.2s, #16
; CHECK-NEXT: sshr v0.2s, v0.2s, #16
; CHECK-NEXT: sshr v1.2s, v1.2s, #16
; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
; CHECK-NEXT: ret
%as = sext <2 x i16> %a to <2 x i48>
%bs = sext <2 x i16> %b to <2 x i48>
%m = mul <2 x i48> %bs, %as
%sh = ashr <2 x i48> %m, splat (i48 15)
%ma = tail call <2 x i48> @llvm.smin.v4i32(<2 x i48> %sh, <2 x i48> splat (i48 32767))
%t = trunc <2 x i48> %ma to <2 x i16>
ret <2 x i16> %t
}

; Negative test: the source element type is i11 rather than i16 or i32, so no
; sqdmulh is expected and the mul/sshr/smin sequence remains.
define <2 x i11> @illegal_source(<2 x i11> %a, <2 x i11> %b) {
; CHECK-LABEL: illegal_source:
; CHECK: // %bb.0:
; CHECK-NEXT: shl v0.2s, v0.2s, #21
; CHECK-NEXT: shl v1.2s, v1.2s, #21
; CHECK-NEXT: sshr v0.2s, v0.2s, #21
; CHECK-NEXT: sshr v1.2s, v1.2s, #21
; CHECK-NEXT: mul v0.2s, v1.2s, v0.2s
; CHECK-NEXT: movi v1.2s, #127, msl #8
; CHECK-NEXT: sshr v0.2s, v0.2s, #15
; CHECK-NEXT: smin v0.2s, v0.2s, v1.2s
; CHECK-NEXT: ret
%as = sext <2 x i11> %a to <2 x i32>
%bs = sext <2 x i11> %b to <2 x i32>
%m = mul <2 x i32> %bs, %as
%sh = ashr <2 x i32> %m, splat (i32 15)
%ma = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
%t = trunc <2 x i32> %ma to <2 x i11>
ret <2 x i11> %t
}
; Negative test: single-element vector (<1 x i16>); no sqdmulh is expected and
; the mul/sshr/smin sequence remains.
define <1 x i16> @saturating_1xi16(<1 x i16> %a, <1 x i16> %b) {
; CHECK-LABEL: saturating_1xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: zip1 v0.4h, v0.4h, v0.4h
; CHECK-NEXT: zip1 v1.4h, v1.4h, v0.4h
; CHECK-NEXT: shl v0.2s, v0.2s, #16
; CHECK-NEXT: sshr v0.2s, v0.2s, #16
; CHECK-NEXT: shl v1.2s, v1.2s, #16
; CHECK-NEXT: sshr v1.2s, v1.2s, #16
; CHECK-NEXT: mul v0.2s, v1.2s, v0.2s
; CHECK-NEXT: movi v1.2s, #127, msl #8
; CHECK-NEXT: sshr v0.2s, v0.2s, #15
; CHECK-NEXT: smin v0.2s, v0.2s, v1.2s
; CHECK-NEXT: uzp1 v0.4h, v0.4h, v0.4h
; CHECK-NEXT: ret
%as = sext <1 x i16> %a to <1 x i32>
%bs = sext <1 x i16> %b to <1 x i32>
%m = mul <1 x i32> %bs, %as
%sh = ashr <1 x i32> %m, splat (i32 15)
%ma = tail call <1 x i32> @llvm.smin.v1i32(<1 x i32> %sh, <1 x i32> splat (i32 32767))
%t = trunc <1 x i32> %ma to <1 x i16>
ret <1 x i16> %t
}