Skip to content

[RISCV] Merge GPRPair and GPRF64Pair #116094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,16 +497,10 @@ struct RISCVOperand final : public MCParsedAsmOperand {
RISCVMCRegisterClasses[RISCV::GPRF32RegClassID].contains(Reg.RegNum);
}

bool isGPRF64Pair() const {
return Kind == KindTy::Register &&
RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID].contains(
Reg.RegNum);
}

bool isGPRAsFPR() const { return isGPR() && Reg.IsGPRAsFPR; }
bool isGPRAsFPR16() const { return isGPRF16() && Reg.IsGPRAsFPR; }
bool isGPRAsFPR32() const { return isGPRF32() && Reg.IsGPRAsFPR; }
bool isGPRPairAsFPR64() const { return isGPRF64Pair() && Reg.IsGPRAsFPR; }
bool isGPRPairAsFPR64() const { return isGPRPair() && Reg.IsGPRAsFPR; }

static bool evaluateConstantImm(const MCExpr *Expr, int64_t &Imm,
RISCVMCExpr::VariantKind &VK) {
Expand Down Expand Up @@ -2405,7 +2399,7 @@ ParseStatus RISCVAsmParser::parseGPRPairAsFPR64(OperandVector &Operands) {
const MCRegisterInfo *RI = getContext().getRegisterInfo();
MCRegister Pair = RI->getMatchingSuperReg(
Reg, RISCV::sub_gpr_even,
&RISCVMCRegisterClasses[RISCV::GPRF64PairRegClassID]);
&RISCVMCRegisterClasses[RISCV::GPRPairRegClassID]);
Operands.push_back(RISCVOperand::createReg(Pair, S, E, /*isGPRAsFPR=*/true));
return ParseStatus::Success;
}
Expand Down
17 changes: 6 additions & 11 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,20 +958,14 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
assert((!Subtarget->is64Bit() || Opcode == RISCVISD::BuildGPRPair) &&
"BuildPairF64 only handled here on rv32i_zdinx");

int RegClassID = (Opcode == RISCVISD::BuildGPRPair)
? RISCV::GPRPairRegClassID
: RISCV::GPRF64PairRegClassID;
MVT OutType = (Opcode == RISCVISD::BuildGPRPair) ? MVT::Untyped : MVT::f64;

SDValue Ops[] = {
CurDAG->getTargetConstant(RegClassID, DL, MVT::i32),
CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32),
Node->getOperand(0),
CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32),
Node->getOperand(1),
CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)};

SDNode *N =
CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, OutType, Ops);
SDNode *N = CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, VT, Ops);
ReplaceNode(Node, N);
return;
}
Expand All @@ -982,14 +976,15 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
"SplitF64 only handled here on rv32i_zdinx");

if (!SDValue(Node, 0).use_empty()) {
SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL, VT,
SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL,
Node->getValueType(0),
Node->getOperand(0));
ReplaceUses(SDValue(Node, 0), Lo);
}

if (!SDValue(Node, 1).use_empty()) {
SDValue Hi = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_odd, DL, VT,
Node->getOperand(0));
SDValue Hi = CurDAG->getTargetExtractSubreg(
RISCV::sub_gpr_odd, DL, Node->getValueType(1), Node->getOperand(0));
ReplaceUses(SDValue(Node, 1), Hi);
}

Expand Down
51 changes: 34 additions & 17 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.is64Bit())
addRegisterClass(MVT::f64, &RISCV::GPRRegClass);
else
addRegisterClass(MVT::f64, &RISCV::GPRF64PairRegClass);
addRegisterClass(MVT::f64, &RISCV::GPRPairRegClass);
}

static const MVT::SimpleValueType BoolVecVTs[] = {
Expand Down Expand Up @@ -20507,7 +20507,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32NoX0RegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
case 'f':
if (VT == MVT::f16) {
Expand All @@ -20524,14 +20524,14 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (Subtarget.hasStdExtD())
return std::make_pair(0U, &RISCV::FPR64RegClass);
if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
}
break;
case 'R':
if (VT == MVT::f64 && !Subtarget.is64Bit() && Subtarget.hasStdExtZdinx())
return std::make_pair(0U, &RISCV::GPRF64PairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
default:
break;
Expand Down Expand Up @@ -20570,7 +20570,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32CRegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
return std::make_pair(0U, &RISCV::GPRPairCRegClass);
if (!VT.isVector())
return std::make_pair(0U, &RISCV::GPRCRegClass);
} else if (Constraint == "cf") {
Expand All @@ -20588,7 +20588,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (Subtarget.hasStdExtD())
return std::make_pair(0U, &RISCV::FPR64CRegClass);
if (Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRF64PairCRegClass);
return std::make_pair(0U, &RISCV::GPRPairCRegClass);
if (Subtarget.hasStdExtZdinx() && Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRCRegClass);
}
Expand Down Expand Up @@ -20752,7 +20752,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
// Subtarget into account.
if (Res.second == &RISCV::GPRF16RegClass ||
Res.second == &RISCV::GPRF32RegClass ||
Res.second == &RISCV::GPRF64PairRegClass)
Res.second == &RISCV::GPRPairRegClass)
return std::make_pair(Res.first, &RISCV::GPRRegClass);

return Res;
Expand Down Expand Up @@ -21379,12 +21379,19 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();

if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
MVT PairVT = Subtarget.is64Bit() ? MVT::i128 : MVT::i64;
if ((ValueVT == PairVT ||
(!Subtarget.is64Bit() && Subtarget.hasStdExtZdinx() &&
ValueVT == MVT::f64)) &&
NumParts == 1 && PartVT == MVT::Untyped) {
// Pairs in Inline Assembly
// Pairs in Inline Assembly, f64 in Inline assembly on rv32_zdinx
MVT XLenVT = Subtarget.getXLenVT();
if (ValueVT == MVT::f64)
Val = DAG.getBitcast(MVT::i64, Val);
auto [Lo, Hi] = DAG.SplitScalar(Val, DL, XLenVT, XLenVT);
Parts[0] = DAG.getNode(RISCVISD::BuildGPRPair, DL, MVT::Untyped, Lo, Hi);
// Always creating an MVT::Untyped part, so always use
// RISCVISD::BuildGPRPair.
Parts[0] = DAG.getNode(RISCVISD::BuildGPRPair, DL, PartVT, Lo, Hi);
return true;
}

Expand All @@ -21396,7 +21403,7 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
DAG.getConstant(0xFFFF0000, DL, MVT::i32));
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
Parts[0] = Val;
return true;
}
Expand Down Expand Up @@ -21465,14 +21472,24 @@ SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();

if (ValueVT == (Subtarget.is64Bit() ? MVT::i128 : MVT::i64) &&
MVT PairVT = Subtarget.is64Bit() ? MVT::i128 : MVT::i64;
if ((ValueVT == PairVT ||
(!Subtarget.is64Bit() && Subtarget.hasStdExtZdinx() &&
ValueVT == MVT::f64)) &&
NumParts == 1 && PartVT == MVT::Untyped) {
// Pairs in Inline Assembly
// Pairs in Inline Assembly, f64 in Inline assembly on rv32_zdinx
MVT XLenVT = Subtarget.getXLenVT();
SDValue Res = DAG.getNode(RISCVISD::SplitGPRPair, DL,
DAG.getVTList(XLenVT, XLenVT), Parts[0]);
return DAG.getNode(ISD::BUILD_PAIR, DL, ValueVT, Res.getValue(0),
Res.getValue(1));

SDValue Val = Parts[0];
// Always starting with an MVT::Untyped part, so always use
// RISCVISD::SplitGPRPair
Val = DAG.getNode(RISCVISD::SplitGPRPair, DL, DAG.getVTList(XLenVT, XLenVT),
Val);
Val = DAG.getNode(ISD::BUILD_PAIR, DL, PairVT, Val.getValue(0),
Val.getValue(1));
if (ValueVT == MVT::f64)
Val = DAG.getBitcast(ValueVT, Val);
return Val;
}

if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
Expand Down
18 changes: 9 additions & 9 deletions llvm/lib/Target/RISCV/RISCVInstrInfoD.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def FPR64INX : RegisterOperand<GPR> {
let DecoderMethod = "DecodeGPRRegisterClass";
}

def FPR64IN32X : RegisterOperand<GPRF64Pair> {
def FPR64IN32X : RegisterOperand<GPRPair> {
let ParserMatchClass = GPRPairAsFPR;
}

Expand Down Expand Up @@ -457,16 +457,16 @@ def : PatSetCC<FPR64INX, any_fsetccs, SETOLE, FLE_D_INX, f64>;

let Predicates = [HasStdExtZdinx, IsRV32] in {
// Match signaling FEQ_D
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETEQ)),
(AND (XLenVT (FLE_D_IN32X $rs1, $rs2)),
(XLenVT (FLE_D_IN32X $rs2, $rs1)))>;
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETOEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETOEQ)),
(AND (XLenVT (FLE_D_IN32X $rs1, $rs2)),
(XLenVT (FLE_D_IN32X $rs2, $rs1)))>;
// If both operands are the same, use a single FLE.
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETEQ)),
(FLE_D_IN32X $rs1, $rs1)>;
def : Pat<(XLenVT (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETOEQ)),
def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETOEQ)),
(FLE_D_IN32X $rs1, $rs1)>;

def : PatSetCC<FPR64IN32X, any_fsetccs, SETLT, FLT_D_IN32X, f64>;
Expand Down Expand Up @@ -523,15 +523,15 @@ def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64>;

/// Loads
let isCall = 0, mayLoad = 1, mayStore = 0, Size = 8, isCodeGenOnly = 1 in
def PseudoRV32ZdinxLD : Pseudo<(outs GPRF64Pair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
def PseudoRV32ZdinxLD : Pseudo<(outs GPRPair:$dst), (ins GPR:$rs1, simm12:$imm12), []>;
def : Pat<(f64 (load (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12))),
(PseudoRV32ZdinxLD GPR:$rs1, simm12:$imm12)>;

/// Stores
let isCall = 0, mayLoad = 0, mayStore = 1, Size = 8, isCodeGenOnly = 1 in
def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRF64Pair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
def : Pat<(store (f64 GPRF64Pair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
(PseudoRV32ZdinxSD GPRF64Pair:$rs2, GPR:$rs1, simm12:$imm12)>;
def PseudoRV32ZdinxSD : Pseudo<(outs), (ins GPRPair:$rs2, GPRNoX0:$rs1, simm12:$imm12), []>;
def : Pat<(store (f64 GPRPair:$rs2), (AddrRegImmINX (XLenVT GPR:$rs1), simm12:$imm12)),
(PseudoRV32ZdinxSD GPRPair:$rs2, GPR:$rs1, simm12:$imm12)>;
} // Predicates = [HasStdExtZdinx, IsRV32]

let Predicates = [HasStdExtD, IsRV32] in {
Expand Down
22 changes: 3 additions & 19 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ let RegAltNameIndices = [ABIRegAltName] in {

let RegInfos = XLenPairRI,
DecoderMethod = "DecodeGPRPairRegisterClass" in {
def GPRPair : RISCVRegisterClass<[XLenPairVT], 64, (add
def GPRPair : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X16_X17,
X6_X7,
X28_X29, X30_X31,
Expand All @@ -334,11 +334,11 @@ def GPRPair : RISCVRegisterClass<[XLenPairVT], 64, (add
X0_Pair, X2_X3, X4_X5
)>;

def GPRPairNoX0 : RISCVRegisterClass<[XLenPairVT], 64, (sub GPRPair, X0_Pair)>;
def GPRPairNoX0 : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (sub GPRPair, X0_Pair)>;
} // let RegInfos = XLenPairRI, DecoderMethod = "DecodeGPRPairRegisterClass"

let RegInfos = XLenPairRI in
def GPRPairC : RISCVRegisterClass<[XLenPairVT], 64, (add
def GPRPairC : RISCVRegisterClass<[XLenPairVT, XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X8_X9
)>;

Expand Down Expand Up @@ -464,22 +464,6 @@ def GPRF32C : RISCVRegisterClass<[f32], 32, (add (sequence "X%u_W", 10, 15),
(sequence "X%u_W", 8, 9))>;
def GPRF32NoX0 : RISCVRegisterClass<[f32], 32, (sub GPRF32, X0_W)>;

let DecoderMethod = "DecodeGPRPairRegisterClass" in
def GPRF64Pair : RISCVRegisterClass<[XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X16_X17,
X6_X7,
X28_X29, X30_X31,
X8_X9,
X18_X19, X20_X21, X22_X23, X24_X25, X26_X27,
X0_Pair, X2_X3, X4_X5
)>;

def GPRF64PairC : RISCVRegisterClass<[XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X8_X9
)>;

def GPRF64PairNoX0 : RISCVRegisterClass<[XLenPairFVT], 64, (sub GPRF64Pair, X0_Pair)>;

//===----------------------------------------------------------------------===//
// Vector type mapping to LLVM types.
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions llvm/test/CodeGen/RISCV/zdinx-asm-constraint.ll
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,42 @@ entry:
ret void
}

define dso_local void @zdinx_asm_inout(ptr nocapture noundef writeonly %a, double noundef %b) nounwind {
; CHECK-LABEL: zdinx_asm_inout:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: mv a2, a1
; CHECK-NEXT: #APP
; CHECK-NEXT: fmv.d a2, a2
; CHECK-NEXT: #NO_APP
; CHECK-NEXT: sw a2, 8(a0)
; CHECK-NEXT: sw a3, 12(a0)
; CHECK-NEXT: ret
entry:
%arrayidx = getelementptr inbounds double, ptr %a, i32 1
%0 = tail call double asm "fsgnj.d $0, $1, $1", "=r,0"(double %b)
store double %0, ptr %arrayidx, align 8
ret void
}

define dso_local void @zdinx_asm_Pr_inout(ptr nocapture noundef writeonly %a, double noundef %b) nounwind {
; CHECK-LABEL: zdinx_asm_Pr_inout:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: mv a3, a2
; CHECK-NEXT: mv a2, a1
; CHECK-NEXT: #APP
; CHECK-NEXT: fabs.d a2, a2
; CHECK-NEXT: #NO_APP
; CHECK-NEXT: sw a2, 8(a0)
; CHECK-NEXT: sw a3, 12(a0)
; CHECK-NEXT: ret
entry:
%arrayidx = getelementptr inbounds double, ptr %a, i32 1
%0 = tail call double asm "fsgnjx.d $0, $1, $1", "=R,0"(double %b)
store double %0, ptr %arrayidx, align 8
ret void
}

define dso_local void @zfinx_asm(ptr nocapture noundef writeonly %a, float noundef %b, float noundef %c) nounwind {
; CHECK-LABEL: zfinx_asm:
; CHECK: # %bb.0: # %entry
Expand Down
Loading