[RISCV] Initial codegen support for zvqdotq extension #137039

Merged: 5 commits, May 7, 2025.
124 changes: 121 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6971,7 +6971,7 @@ static bool hasPassthruOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6995,7 +6995,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18101,6 +18101,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getBuildVector(VT, DL, RHSOps));
}

static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
const SDLoc &DL, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
RISCVISD::VQDOTSU_VL == Opc);
MVT VT = Op0.getSimpleValueType();
assert(VT == Op1.getSimpleValueType() &&
VT.getVectorElementType() == MVT::i32);

assert(VT.isFixedLengthVector());
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
SDValue Passthru = convertToScalableVector(
ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);

auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
{Op0, Op1, Passthru, Mask, VL, PolicyOp});
return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
}

static MVT getQDOTXResultType(MVT OpVT) {
ElementCount OpEC = OpVT.getVectorElementCount();
assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
}

static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
// Note: We intentionally do not check the legality of the reduction type.
// We want to handle the m4/m8 *src* types, and thus need to let illegal
// intermediate types flow through here.
if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
return SDValue();

// reduce (zext a) <--> reduce (mul zext a, zext 1)
// reduce (sext a) <--> reduce (mul sext a, sext 1)
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
if (A.getValueType().getVectorElementType() != MVT::i8 ||
!TLI.isTypeLegal(A.getValueType()))
return SDValue();

MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
A = DAG.getBitcast(ResVT, A);
SDValue B = DAG.getConstant(0x01010101, DL, ResVT);

bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
}

// mul (sext, sext) -> vqdot
// mul (zext, zext) -> vqdotu
// mul (sext, zext) -> vqdotsu
// mul (zext, sext) -> vqdotsu (swapped)
// TODO: Improve .vx handling - we end up with a sub-vector insert
// which confuses the splat pattern matching. Also, match vqdotus.vx
if (InVec.getOpcode() != ISD::MUL)
Member:
just curious, what about left shifts?

Collaborator Author:
Annoyingly complicated, possible future work.

The problem is that we have to expand the shift as a multiply by 2^N, and the range of shift amounts we can handle is very limited due to the input being an i8.

Collaborator Author:
p.s. I don't have a real case which benefits from this; if you do, please share and I'll re-prioritize.
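
To make the constraint discussed above concrete: a left shift by N would have to be rewritten as a multiply by 2^N, and that constant would have to fit in the i8 lane used for the other dot-product operand, so only a handful of shift amounts are representable at all. A minimal standalone sketch of that limit (illustration only, not part of this patch):

    // Illustration only (hypothetical, not from the patch): a shift-by-N of an
    // i8 element could reuse the vqdot path only if the rewritten multiplier
    // 2^N fits in the i8 lane of the other operand, capping N at 6 for the
    // signed form and 7 for the unsigned forms.
    #include <cstdint>
    #include <cstdio>

    int main() {
      for (int n = 0; n <= 8; ++n) {
        int32_t factor = 1 << n; // the multiplier a shift-by-n would become
        bool fits_signed = factor <= INT8_MAX;    // candidate for vqdot.vv
        bool fits_unsigned = factor <= UINT8_MAX; // candidate for vqdotu/vqdotsu.vv
        std::printf("shl %d -> mul by %3d  signed i8: %s  unsigned i8: %s\n", n,
                    factor, fits_signed ? "yes" : "no",
                    fits_unsigned ? "yes" : "no");
      }
      return 0;
    }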

return SDValue();

SDValue A = InVec.getOperand(0);
SDValue B = InVec.getOperand(1);
unsigned Opc = 0;
if (A.getOpcode() == B.getOpcode()) {
if (A.getOpcode() == ISD::SIGN_EXTEND)
Opc = RISCVISD::VQDOT_VL;
else if (A.getOpcode() == ISD::ZERO_EXTEND)
Opc = RISCVISD::VQDOTU_VL;
else
return SDValue();
} else {
if (B.getOpcode() != ISD::ZERO_EXTEND)
std::swap(A, B);
if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
Opc = RISCVISD::VQDOTSU_VL;
}
assert(Opc);
Member:
will this ever get triggered?

Collaborator Author:
It shouldn't with the current code structure, but asserts exist to check assumptions?


if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
!TLI.isTypeLegal(A.getValueType()))
return SDValue();

MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
A = DAG.getBitcast(ResVT, A.getOperand(0));
B = DAG.getBitcast(ResVT, B.getOperand(0));
return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
}

static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
if (!Subtarget.hasStdExtZvqdotq())
return SDValue();

SDLoc DL(N);
EVT VT = N->getValueType(0);
SDValue InVec = N->getOperand(0);
if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
return SDValue();
}

static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
@@ -19878,8 +19990,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,

return SDValue();
}
- case ISD::CTPOP:
+ case ISD::VECREDUCE_ADD:
+ if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
+ return V;
+ [[fallthrough]];
+ case ISD::CTPOP:
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
return V;
break;
@@ -22401,6 +22516,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(RI_VUNZIP2A_VL)
NODE_NAME_CASE(RI_VUNZIP2B_VL)
NODE_NAME_CASE(RI_VEXTRACT)
NODE_NAME_CASE(VQDOT_VL)
NODE_NAME_CASE(VQDOTU_VL)
NODE_NAME_CASE(VQDOTSU_VL)
NODE_NAME_CASE(READ_CSR)
NODE_NAME_CASE(WRITE_CSR)
NODE_NAME_CASE(SWAP_CSR)
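
For context on what the combine above targets, here is a hedged source-level sketch (not part of the patch; the function names are invented) of loops whose vectorized forms can produce the vecreduce.add patterns it recognizes: a widening byte sum matches the zext case, which the combine rewrites as a vqdotu against a splat of 1 in every byte (the 0x01010101 constant), and a mixed-sign byte dot product matches the mul(sext, zext) case that maps to vqdotsu.

    // Hedged illustration: C++ loops whose vectorized IR can take the shapes
    // handled by foldReduceOperandViaVQDOT.  Names are made up for the example.
    #include <cstdint>

    // vecreduce.add(zext <N x i8> to <N x i32>) -- the zero-extend case.
    int32_t byte_sum(const uint8_t *a, int n) {
      int32_t sum = 0;
      for (int i = 0; i < n; ++i)
        sum += a[i];
      return sum;
    }

    // vecreduce.add(mul(sext <N x i8>, zext <N x i8>)) -- the mixed-sign case
    // that maps to vqdotsu.vv.
    int32_t dot_su(const int8_t *a, const uint8_t *b, int n) {
      int32_t sum = 0;
      for (int i = 0; i < n; ++i)
        sum += int32_t(a[i]) * int32_t(b[i]);
      return sum;
    }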
7 changes: 6 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -416,7 +416,12 @@ enum NodeType : unsigned {
RI_VUNZIP2A_VL,
RI_VUNZIP2B_VL,

- LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
+ // zvqdot instructions with additional passthru, mask and VL operands
+ VQDOT_VL,
+ VQDOTU_VL,
+ VQDOTSU_VL,
+
+ LAST_VL_VECTOR_OP = VQDOTSU_VL,

// XRivosVisni
// VEXTRACT matches the semantics of ri.vextract.x.v. The result is always
31 changes: 31 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
@@ -26,3 +26,34 @@ let Predicates = [HasStdExtZvqdotq] in {
def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
} // Predicates = [HasStdExtZvqdotq]


def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;

multiclass VPseudoVQDOT_VV_VX {
foreach m = MxSet<32>.m in {
defm "" : VPseudoBinaryV_VV<m>,
SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
forcePassthruRead=true>;
defm "" : VPseudoBinaryV_VX<m>,
SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
forcePassthruRead=true>;
}
}

// TODO: Add pseudo and patterns for vqdotus.vx
// TODO: Add isCommutable for VQDOT and VQDOTU
let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
hasSideEffects = 0 in {
defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
}

defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
