[NVPTX] support packed f32 instructions for sm_100+ #126337
✅ With the latest revision this PR passed the C/C++ code formatter.
@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

Changes

This adds support for lowering fadd, fsub, fmul, and fma to sm_100+ packed-f32 instructions (e.g. add.rn.f32x2 Int64Reg, Int64Reg). Rather than legalizing v2f32, we handle these four instructions ad hoc, so that codegen remains the same unless these instructions are present. We also introduce some DAGCombiner rules to simplify bitwise packing/unpacking to use mov, and to reduce redundant movs.

In this PR I didn't implement support for alternative rounding modes, as that was lower priority. If there's sufficient demand, I can add that to this PR. Otherwise we can leave that for later.

Patch is 127.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126337.diff

10 Files Affected:
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 42a5fbec95174e..394428594b9870 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -115,6 +115,9 @@ def SDTPtrAddOp : SDTypeProfile<1, 2, [ // ptradd
def SDTIntBinOp : SDTypeProfile<1, 2, [ // add, and, or, xor, udiv, etc.
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>
]>;
+def SDTIntTernaryOp : SDTypeProfile<1, 3, [ // fma32x2
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisInt<0>
+]>;
def SDTIntShiftOp : SDTypeProfile<1, 2, [ // shl, sra, srl
SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
]>;
@@ -818,6 +821,10 @@ def step_vector : SDNode<"ISD::STEP_VECTOR", SDTypeProfile<1, 1,
def scalar_to_vector : SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>,
[]>;
+def build_pair : SDNode<"ISD::BUILD_PAIR", SDTypeProfile<1, 2,
+ [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>, []>;
+
+
// vector_extract/vector_insert are deprecated. extractelt/insertelt
// are preferred.
def vector_extract : SDNode<"ISD::EXTRACT_VECTOR_ELT",
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec654e0f3f200f..3a39f6dab0c85f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -190,6 +190,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
SelectI128toV2I64(N);
return;
}
+ if (N->getOperand(1).getValueType() == MVT::i64 &&
+ N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32) {
+ // {f32,f32} = mov i64
+ SelectI64ToV2F32(N);
+ return;
+ }
break;
}
case ISD::FADD:
@@ -2765,6 +2771,19 @@ void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
ReplaceNode(N, Mov);
}
+void NVPTXDAGToDAGISel::SelectI64ToV2F32(SDNode *N) {
+ SDValue Ch = N->getOperand(0);
+ SDValue Src = N->getOperand(1);
+ assert(N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32 &&
+ "expected {f32,f32} = CopyFromReg i64");
+ SDLoc DL(N);
+
+ SDNode *Mov = CurDAG->getMachineNode(NVPTX::I64toV2F32, DL,
+ {MVT::f32, MVT::f32, Ch.getValueType()},
+ {Src, Ch});
+ ReplaceNode(N, Mov);
+}
+
/// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
/// conversion from \p SrcTy to \p DestTy.
unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 8dc6bc86c68281..703a80f74e90c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -91,6 +91,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
void SelectV2I64toI128(SDNode *N);
void SelectI128toV2I64(SDNode *N);
+ void SelectI64ToV2F32(SDNode *N);
void SelectCpAsyncBulkG2S(SDNode *N);
void SelectCpAsyncBulkS2G(SDNode *N);
void SelectCpAsyncBulkPrefetchL2(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 58ad92a8934a66..1e417f23fdb099 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -866,6 +866,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
// (would be) Library functions.
+ if (STI.hasF32x2Instructions()) {
+ // Handle custom lowering for: v2f32 = OP v2f32, v2f32
+ for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
+ setOperationAction(Op, MVT::v2f32, Custom);
+ // Handle custom lowering for: f32 = extract_vector_elt v2f32
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+ // Combine:
+ // i64 = or (i64 = zero_extend X, i64 = shl (i64 = any_extend Y, 32))
+ // -> i64 = build_pair (X, Y)
+ setTargetDAGCombine(ISD::OR);
+ // i32 = truncate (i64 = srl (i64 = build_pair (X, Y), 32))
+ // -> i32 Y
+ setTargetDAGCombine(ISD::TRUNCATE);
+ // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+ // -> i64 X
+ setTargetDAGCombine(ISD::BUILD_PAIR);
+ }
+
// These map to conversion instructions for scalar FP types.
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUNDEVEN, ISD::FTRUNC}) {
@@ -1066,6 +1084,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::STACKSAVE)
MAKE_CASE(NVPTXISD::SETP_F16X2)
MAKE_CASE(NVPTXISD::SETP_BF16X2)
+ MAKE_CASE(NVPTXISD::FADD_F32X2)
+ MAKE_CASE(NVPTXISD::FSUB_F32X2)
+ MAKE_CASE(NVPTXISD::FMUL_F32X2)
+ MAKE_CASE(NVPTXISD::FMA_F32X2)
MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
@@ -2207,6 +2229,30 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
}
+ if (VectorVT == MVT::v2f32) {
+ auto GetOperand = [&DAG, &DL](SDValue Op, SDValue Index) {
+ if (const auto *ConstIdx = dyn_cast<ConstantSDNode>(Index))
+ return Op.getOperand(ConstIdx->getZExtValue());
+ SDValue E0 = Op.getOperand(0);
+ SDValue E1 = Op.getOperand(1);
+ return DAG.getSelectCC(DL, Index, DAG.getIntPtrConstant(0, DL), E0, E1,
+ ISD::CondCode::SETEQ);
+ };
+ if (SDValue Pair = Vector.getOperand(0);
+ Vector.getOpcode() == ISD::BITCAST &&
+ Pair.getOpcode() == ISD::BUILD_PAIR) {
+ // peek through v2f32 = bitcast (i64 = build_pair (i32 A, i32 B))
+ // where A:i32, B:i32 = CopyFromReg (i64 = F32X2 Operation ...)
+ return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(),
+ GetOperand(Pair, Index));
+ }
+ if (Vector.getOpcode() == ISD::BUILD_VECTOR)
+ return GetOperand(Vector, Index);
+
+ // Otherwise, let SelectionDAG expand the operand.
+ return SDValue();
+ }
+
// Constant index will be matched by tablegen.
if (isa<ConstantSDNode>(Index.getNode()))
return Op;
@@ -4573,26 +4619,109 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
return SDValue();
}
+// If {Lo, Hi} = <packed f32x2 val>, returns that value
+static SDValue peekThroughF32x2Copy(const SDValue &Lo, const SDValue &Hi) {
+ if (Lo.getValueType() != MVT::f32 || Lo.getOpcode() != ISD::CopyFromReg ||
+ Lo.getNode() != Hi.getNode() || Lo == Hi)
+ return SDValue();
+
+ SDNode *CopyF = Lo.getNode();
+ SDNode *CopyT = CopyF->getOperand(0).getNode();
+ if (CopyT->getOpcode() != ISD::CopyToReg)
+ return SDValue();
+
+ // check the two registers are the same
+ if (cast<RegisterSDNode>(CopyF->getOperand(1))->getReg() !=
+ cast<RegisterSDNode>(CopyT->getOperand(1))->getReg())
+ return SDValue();
+
+ SDValue OrigV = CopyT->getOperand(2);
+ if (OrigV.getValueType() != MVT::i64)
+ return SDValue();
+ return OrigV;
+}
+
+static SDValue
+PerformPackedF32StoreCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ // rewrite stores of packed f32 values
+ auto *MemN = cast<MemSDNode>(N);
+ if (MemN->getMemoryVT() == MVT::f32) {
+ std::optional<NVPTXISD::NodeType> NewOpcode;
+ switch (MemN->getOpcode()) {
+ case NVPTXISD::StoreRetvalV2:
+ NewOpcode = NVPTXISD::StoreRetval;
+ break;
+ case NVPTXISD::StoreRetvalV4:
+ NewOpcode = NVPTXISD::StoreRetvalV2;
+ break;
+ case NVPTXISD::StoreParamV2:
+ NewOpcode = NVPTXISD::StoreParam;
+ break;
+ case NVPTXISD::StoreParamV4:
+ NewOpcode = NVPTXISD::StoreParamV2;
+ break;
+ }
+
+ if (NewOpcode) {
+ SmallVector<SDValue> NewOps = {N->getOperand(0), N->getOperand(1)};
+ unsigned NumPacked = 0;
+
+ // gather all packed operands
+ for (unsigned I = 2, E = MemN->getNumOperands(); I < E; I += 2) {
+ if (SDValue Packed = peekThroughF32x2Copy(MemN->getOperand(I),
+ MemN->getOperand(I + 1))) {
+ NewOps.push_back(Packed);
+ ++NumPacked;
+ } else {
+ NumPacked = 0;
+ break;
+ }
+ }
+
+ if (NumPacked) {
+ return DCI.DAG.getMemIntrinsicNode(
+ *NewOpcode, SDLoc(N), N->getVTList(), NewOps, MVT::i64,
+ MemN->getPointerInfo(), MemN->getAlign(),
+ MachineMemOperand::MOStore);
+ }
+ }
+ }
+ return SDValue();
+}
+
static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
- std::size_t Back) {
+ std::size_t Back,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
if (all_of(N->ops().drop_front(Front).drop_back(Back),
[](const SDUse &U) { return U.get()->isUndef(); }))
// Operand 0 is the previous value in the chain. Cannot return EntryToken
// as the previous value will become unused and eliminated later.
return N->getOperand(0);
+ if (SDValue V = PerformPackedF32StoreCombine(N, DCI, OptLevel))
+ return V;
+
return SDValue();
}
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
// Operands from the 3rd to the 2nd last one are the values to be stored.
// {Chain, ArgID, Offset, Val, Glue}
- return PerformStoreCombineHelper(N, 3, 1);
+ return PerformStoreCombineHelper(N, 3, 1, DCI, OptLevel);
}
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
// Operands from the 2nd to the last one are the values to be stored
- return PerformStoreCombineHelper(N, 2, 0);
+ return PerformStoreCombineHelper(N, 2, 0, DCI, OptLevel);
}
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5055,10 +5184,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
IsPTXVectorType(VectorVT.getSimpleVT()))
return SDValue(); // Native vector loads already combine nicely w/
// extract_vector_elt.
- // Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
+ // Don't mess with singletons or v2*16, v4i8, v8i8, or v2f32 types, we already
// handle them OK.
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
- VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
+ VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8 || VectorVT == MVT::v2f32)
return SDValue();
// Don't mess with undef values as sra may be simplified to 0, not undef.
@@ -5188,6 +5317,78 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
}
+static SDValue PerformORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+
+ // i64 = or (i64 = zero_extend A, i64 = shl (i64 = any_extend B, 32))
+ // -> i64 = build_pair (A, B)
+ if (N->getValueType(0) == MVT::i64 && Op0.getOpcode() == ISD::ZERO_EXTEND &&
+ Op1.getOpcode() == ISD::SHL) {
+ SDValue SHLOp0 = Op1.getOperand(0);
+ SDValue SHLOp1 = Op1.getOperand(1);
+ if (const auto *Const = dyn_cast<ConstantSDNode>(SHLOp1);
+ Const && Const->getZExtValue() == 32 &&
+ SHLOp0.getOpcode() == ISD::ANY_EXTEND) {
+ SDLoc DL(N);
+ return DCI.DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+ {Op0.getOperand(0), SHLOp0.getOperand(0)});
+ }
+ }
+ return SDValue();
+}
+
+static SDValue PerformTRUNCATECombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue Op = N->getOperand(0);
+ if (Op.getOpcode() == ISD::SRL) {
+ SDValue SrlOp = Op.getOperand(0);
+ SDValue SrlSh = Op.getOperand(1);
+ // i32 = truncate (i64 = srl (i64 build_pair (A, B), 32))
+ // -> i32 A
+ if (const auto *Const = dyn_cast<ConstantSDNode>(SrlSh);
+ Const && Const->getZExtValue() == 32) {
+ if (SrlOp.getOpcode() == ISD::BUILD_PAIR)
+ return SrlOp.getOperand(1);
+ }
+ }
+
+ return SDValue();
+}
+
+static SDValue PerformBUILD_PAIRCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ EVT ToVT = N->getValueType(0);
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+ // -> i64 X
+ if (ToVT == MVT::i64 && Op0.getOpcode() == ISD::CopyFromReg &&
+ Op1.getNode() == Op0.getNode() && Op0 != Op1) {
+ SDValue CFRChain = Op0.getOperand(0);
+ Register Reg = cast<RegisterSDNode>(Op0.getOperand(1))->getReg();
+ if (CFRChain.getOpcode() == ISD::CopyToReg &&
+ cast<RegisterSDNode>(CFRChain.getOperand(1))->getReg() == Reg) {
+ SDValue Value = CFRChain.getOperand(2);
+ return Value;
+ }
+ }
+
+ return SDValue();
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5211,17 +5412,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
- return PerformStoreRetvalCombine(N);
+ return PerformStoreRetvalCombine(N, DCI, OptLevel);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N);
+ return PerformStoreParamCombine(N, DCI, OptLevel);
case ISD::EXTRACT_VECTOR_ELT:
return PerformEXTRACTCombine(N, DCI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
+ case ISD::OR:
+ return PerformORCombine(N, DCI, OptLevel);
+ case ISD::TRUNCATE:
+ return PerformTRUNCATECombine(N, DCI, OptLevel);
+ case ISD::BUILD_PAIR:
+ return PerformBUILD_PAIRCombine(N, DCI, OptLevel);
}
return SDValue();
}
@@ -5478,6 +5685,59 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
Results.push_back(NewValue.getValue(3));
}
+static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
+ SmallVectorImpl<SDValue> &Results) {
+ SDLoc DL(N);
+ EVT OldResultTy = N->getValueType(0); // <2 x float>
+ assert(OldResultTy == MVT::v2f32 && "Unexpected result type for F32x2 op!");
+
+ SmallVector<SDValue> NewOps;
+
+ // whether we use FTZ (TODO)
+
+ // replace with NVPTX F32x2 op:
+ unsigned Opcode;
+ switch (N->getOpcode()) {
+ case ISD::FADD:
+ Opcode = NVPTXISD::FADD_F32X2;
+ break;
+ case ISD::FSUB:
+ Opcode = NVPTXISD::FSUB_F32X2;
+ break;
+ case ISD::FMUL:
+ Opcode = NVPTXISD::FMUL_F32X2;
+ break;
+ case ISD::FMA:
+ Opcode = NVPTXISD::FMA_F32X2;
+ break;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+
+ // bitcast operands: <2 x float> -> i64
+ for (const SDValue &Op : N->ops())
+ NewOps.push_back(DAG.getNode(ISD::BITCAST, DL, MVT::i64, Op));
+
+ SDValue Chain = DAG.getEntryNode();
+
+ // break packed result into two f32 registers for later instructions that may
+ // access element #0 or #1
+ SDValue NewValue = DAG.getNode(Opcode, DL, MVT::i64, NewOps);
+ MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+ Register DestReg = RegInfo.createVirtualRegister(
+ DAG.getTargetLoweringInfo().getRegClassFor(MVT::i64));
+ SDValue RegCopy = DAG.getCopyToReg(Chain, DL, DestReg, NewValue);
+ SDValue Explode = DAG.getNode(ISD::CopyFromReg, DL,
+ {MVT::f32, MVT::f32, Chain.getValueType()},
+ {RegCopy, DAG.getRegister(DestReg, MVT::i64)});
+ // cast i64 result of new op back to <2 x float>
+ Results.push_back(DAG.getBitcast(
+ OldResultTy,
+ DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+ {DAG.getBitcast(MVT::i32, Explode.getValue(0)),
+ DAG.getBitcast(MVT::i32, Explode.getValue(1))})));
+}
+
void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
@@ -5495,6 +5755,12 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
+ case ISD::FMA:
+ ReplaceF32x2Op(N, DAG, Results);
+ return;
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5adf69d621552f..8fd4ded42a238a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -55,6 +55,10 @@ enum NodeType : unsigned {
FSHR_CLAMP,
MUL_WIDE_SIGNED,
MUL_WIDE_UNSIGNED,
+ FADD_F32X2,
+ FMUL_F32X2,
+ FSUB_F32X2,
+ FMA_F32X2,
SETP_F16X2,
SETP_BF16X2,
BFE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 7d9697e40e6aba..b0eb9bbbb2456a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -165,6 +165,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
+def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
def True : Predicate<"true">;
def False : Predicate<"false">;
@@ -2638,13 +2639,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
NVPTXInst<(outs), (ins regclass:$a), "$a",
[(LastCallArg (i32 0), vt:$a)]>;
-def CallArgI64 : CallArgInst<Int64Regs>;
+def CallArgI64 : CallArgInstVT<Int64Regs, i64>;
def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
def CallArgF64 : CallArgInst<Float64Regs>;
def CallArgF32 : CallArgInst<Float32Regs>;
-def LastCallArgI64 : LastCallArgInst<Int64Regs>;
+def LastCallArgI64 : LastCallArgInstVT<Int64Regs, i64>;
def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
def LastCallArgF64 : LastCallArgInst<Float64Regs>;
@@ -3371,6 +3372,9 @@ let hasSideEffects = false in {
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
(ins Float32Regs:$s1, Float32Regs:$s2),
"mov.b64 \t$d, {{$s1, $s2}};", []>;
+ def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
+ (ins Float32Regs:$s1, Float32Regs:$s2),
+ "mov.b64 \t$d, {{$s1, $s2}};", []>;
// unpack a larger int register to a set of smaller int registers
def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -3383,6 +3387,9 @@ let hasSideEffects = false in {
def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
(ins Int64Regs:$s),
"mov.b64 \t{{$d1, $d2}}, $...
[truncated]
Rather than legalizing v2f32, we handle these four instructions ad hoc, so that codegen remains the same unless these instructions are present.

Supporting v2f32 (similar to how we support v2f16, for example) would be a cleaner and more extensible way to implement this change.

@AlexMaclean what led me to this implementation (and I did try it the other way) is that v2f16 and v2bf16 are supported by many kinds of instructions, so it makes more sense to legalize these types than v2f32. My concern is whether this feature should change code that uses f32 vectors but avoids these operations. Legalizing this type requires me to change some things in how we lower instructions, like loads and stores (for example, if we don't want [...]). See the test cases for more examples. If this is not a concern, then I can go with the other implementation.
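For readers following the thread, here is a rough sketch of the two approaches being weighed. The first fragment mirrors the custom-lowering hook in the diff above; the second is only an illustration of what full legalization might look like, and its register-class choice and Expand markings are assumptions modeled on how v2f16/v2bf16 are handled today, not code from this patch.

```cpp
// Sketch only -- imagined as living in the NVPTXTargetLowering constructor.

// Approach 1 (this revision): v2f32 stays an illegal type. Only the four ops
// that have packed f32x2 instructions get a custom hook, so every other use
// of <2 x float> keeps scalarizing exactly as before.
if (STI.hasF32x2Instructions())
  for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
    setOperationAction(Op, MVT::v2f32, Custom);

// Approach 2 (the reviewer's suggestion, assumed shape): make v2f32 a legal
// type carried in a 64-bit register, the way v2f16 lives in a 32-bit
// register, then explicitly Expand whatever has no packed instruction so it
// still falls back to scalar f32 code.
addRegisterClass(MVT::v2f32, &NVPTX::Int64RegsRegClass);
setOperationAction(ISD::FDIV, MVT::v2f32, Expand); // no packed divide
setOperationAction(ISD::FREM, MVT::v2f32, Expand);
```

The trade-off described above is visible here: approach 1 is self-contained, while approach 2 changes the type's legality globally and therefore touches load/store lowering for any v2f32 code, even code that never reaches the new instructions.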
Okay, so I'm in the process of reworking this PR to legalize v2f32.
Hi, I've updated the patch to legalize v2f32. The patch probably needs more coverage and I'd appreciate suggestions in that direction.
Pinging reviewers: @AlexMaclean, @Artem-B, @arsenm, @justinfargnoli, @durga4github
Requires us to lower EXTRACT_VECTOR_ELT as well.
Also update the test cases.
Now that v2f32 is legal, this node will go straight to instruction selection. Instead, we want to break it up into two nodes, which can be handled better in instruction selection, since the final instruction (cvt.[b]f16x2.f32) takes two f32 arguments.
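As an aside, here is a minimal sketch of the split described in this commit note, following the DAG pattern quoted in the review thread further down; the function name and exact call site are assumptions, not the patch's code.

```cpp
// v2[b]f16 = fp_round (v2f32 A)
//   -> build_vector (fp_round (extractelt A, 0)), (fp_round (extractelt A, 1))
// so that instruction selection later sees two scalar f32 inputs, matching
// what cvt.[b]f16x2.f32 consumes.
static SDValue splitFPRoundOfV2F32(SDNode *N, SelectionDAG &DAG) {
  SDLoc DL(N);
  EVT ResVT = N->getValueType(0);        // v2f16 or v2bf16
  EVT EltVT = ResVT.getVectorElementType();
  SDValue A = N->getOperand(0);          // the v2f32 source
  SDValue TruncFlag = N->getOperand(1);  // fp_round's truncation operand
  SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, A,
                           DAG.getVectorIdxConstant(0, DL));
  SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, A,
                           DAG.getVectorIdxConstant(1, DL));
  SDValue R0 = DAG.getNode(ISD::FP_ROUND, DL, EltVT, E0, TruncFlag);
  SDValue R1 = DAG.getNode(ISD::FP_ROUND, DL, EltVT, E1, TruncFlag);
  return DAG.getBuildVector(ResVT, DL, {R0, R1});
}
```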
Fixes test/CodeGen/Generic/vector.ll
Ensures ld.b64 and st.b64 for v2f32. Also remove -O3 in f32x2-instructions.ll test.
Do this to reduce the amount of packing movs.
…ParamV2

To reduce the number of unpacking movs when the element type is i64 but all uses are of unpacked f32s.
Handle more loads, including ones with multiple proxy registers:
- i64 = LOAD
- i64 = LoadParam
- v2f32,v2f32 = LoadParamV2
Also update the test cases. Because this is an optimization, it is not triggered for some of these tests that compile with no optimizations.
Support ld.global.nc.b64/ldu.global.b64 for v2f32 and ld.global.nc.b32/ldu.global.b32 for v2f16/v2bf16/v2i16/v4i8. Update test cases.
Fold i64->v2f32 bitcasts on the results of a NVPTXISD::Load* op.
Split unaligned stores and loads of v2f32. Add DAGCombiner rules for:
- target-independent stores that store a v2f32 BUILD_VECTOR. We scalarize the value and rewrite the store.
Fix test cases.
For fp-contract:
- test folding of fma.f32x2
- bump SM version to 100
For ldg-invariant:
- test proper splitting of loads on vectors of f32
Change selp.u16 -> selp.b16 and
Left some comments but will need to review more thoroughly. This change is quite large and complex. Would it be possible to break it up in any way, such as adding some of the DAG combine rules in a separate change?
MVT VT = Vector.getSimpleValueType();
if (!Isv2x16VT(VT))
unsigned NewOpcode = 0;
Can the = 0 be removed here?
switch (StoreVT.getSimpleVT().SimpleTy) {
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
case MVT::v2f32:
I'm seeing something like this in lots of places (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32). Let's add a helper function in NVPTXUtilities.h and use that throughout.
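Something along these lines is presumably what's being asked for; the helper name and its exact home in NVPTXUtilities.h are illustrative only, not part of the patch.

```cpp
// NVPTXUtilities.h (sketch): one predicate for the small vector types that
// NVPTX keeps packed inside a single scalar register, instead of open-coding
// the same disjunction at every call site.
#include "llvm/CodeGen/ValueTypes.h" // EVT, MVT

inline bool isPackedVectorTy(EVT VT) {
  return Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32;
}
```

Checks such as the switch quoted above would then reduce to a single isPackedVectorTy(StoreVT) call.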
@@ -2536,6 +2559,10 @@ foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
def: Pat<(vt (ProxyReg vt:$src)), (ProxyRegI32 $src)>;
}

def: Pat<(v2f32 (ProxyReg v2f32:$src)), (ProxyRegI64 $src)>;

def: Pat<(v2f32 (bitconvert i64:$src)), (ProxyRegI64 $src)>;
Can we just use (v2f32 Int64Regs:$src) here?
@@ -2439,13 +2462,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
NVPTXInst<(outs), (ins regclass:$a), "$a",
[(LastCallArg (i32 0), vt:$a)]>;

def CallArgI64 : CallArgInst<Int64Regs>;
def CallArgI64 : CallArgInstVT<Int64Regs, i64>;
I think we may as well just remove the non-VT variants of these instructions. It would also be nice to use the RegTyInfo wrapper here.
(ins Int64Regs:$s),
"{{ .reg .b32 tmp; mov.b64 {$low, tmp}, $s; }}",
[]>;
def I64toF32HS : NVPTXInst<(outs Float32Regs:$high),
Let's be consistent with the rest of these and use "_Sink".
if (N->getOpcode() == ISD::LOAD || | ||
N->getOpcode() == ISD::INTRINSIC_W_CHAIN) { |
Why special case these? Can we just use the untyped variant in all cases?
static SDValue PerformLoadCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
  auto *MemN = cast<MemSDNode>(N);
  // only operate on vectors of f32s / i64s
Why is this?
return {{NewLD, LoadChain}};
}

static SDValue PerformLoadCombine(SDNode *N,
Could you add a comment here explaining what this function is attempting to accomplish?
// v2[b]f16 = fp_round (v2f32 A)
// -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
//               ([b]f16 = fp_round (extractelt A, 1)))
Why is this required?
This adds support for lowering fadd, fsub, fmul, and fma to sm_100+ packed-f32 instructions¹ (e.g. add.rn.f32x2 Int64Reg, Int64Reg).

Rather than legalizing v2f32, we handle these four instructions ad hoc, so that codegen remains the same unless these instructions are present. We also introduce some DAGCombiner rules to simplify bitwise packing/unpacking to use mov, and to reduce redundant movs.

In this PR I didn't implement support for alternative rounding modes, as that was lower priority. If there's sufficient demand, I can add that to this PR. Otherwise we can leave that for later.

Footnotes

¹ Introduced in PTX 8.6: https://docs.nvidia.com/cuda/parallel-thread-execution/#changes-in-ptx-isa-version-8-6
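For readers wondering how the new instructions are gated, the TableGen predicate hasF32x2Instructions added in the diff is backed by a subtarget query whose definition is not part of the truncated patch. A plausible sketch, with the thresholds assumed from the PR title (sm_100+) and the PTX ISA 8.6 footnote above, would be:

```cpp
// NVPTXSubtarget.h (assumed sketch -- the real definition is not shown in the
// truncated diff, and the exact SM/PTX thresholds are guesses based on the PR
// title and the PTX 8.6 footnote).
bool hasF32x2Instructions() const {
  return SmVersion >= 100 && PTXVersion >= 86;
}
```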