Skip to content

[NVPTX] support packed f32 instructions for sm_100+ #126337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,17 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
// We only care about 16x2 as it's the only real vector type we
// need to deal with.
MVT VT = Vector.getSimpleValueType();
if (!Isv2x16VT(VT))
if (!isPackedVectorTy(VT) || VT.getVectorNumElements() != 2)
return false;

unsigned Opcode;
if (VT.is32BitVector())
Opcode = NVPTX::I32toV2I16;
else if (VT.is64BitVector())
Opcode = NVPTX::I64toV2I32;
else
llvm_unreachable("Unhandled packed type");

// Find and record all uses of this vector that extract element 0 or 1.
SmallVector<SDNode *, 4> E0, E1;
for (auto *U : Vector.getNode()->users()) {
Expand All @@ -484,11 +493,11 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
if (E0.empty() || E1.empty())
return false;

// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
// into f16,f16 SplitF16x2(V)
// Merge (EltTy extractelt(V, 0), EltTy extractelt(V,1))
// into EltTy,EltTy Split[EltTy]x2(V)
MVT EltVT = VT.getVectorElementType();
SDNode *ScatterOp =
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
CurDAG->getMachineNode(Opcode, SDLoc(N), EltVT, EltVT, Vector);
for (auto *Node : E0)
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
for (auto *Node : E1)
Expand Down Expand Up @@ -1004,6 +1013,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
case MVT::i32:
case MVT::f32:
return Opcode_i32;
case MVT::v2f32:
case MVT::i64:
case MVT::f64:
return Opcode_i64;
Expand Down
237 changes: 148 additions & 89 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Large diffs are not rendered by default.

33 changes: 32 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;

def True : Predicate<"true">;

Expand Down Expand Up @@ -220,6 +221,7 @@ def BF16RT : RegTyInfo<bf16, B16, bf16imm, fpimm, supports_imm = 0>;

def F16X2RT : RegTyInfo<v2f16, B32, ?, ?, supports_imm = 0>;
def BF16X2RT : RegTyInfo<v2bf16, B32, ?, ?, supports_imm = 0>;
def F32X2RT : RegTyInfo<v2f32, B64, ?, ?, supports_imm = 0>;


// This class provides a basic wrapper around an NVPTXInst that abstracts the
Expand Down Expand Up @@ -451,6 +453,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
[(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
Requires<[useFP16Math]>;

def f32x2rr_ftz :
BasicNVPTXInst<(outs B64:$dst),
(ins B64:$a, B64:$b),
op_str # ".ftz.f32x2",
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
Requires<[hasF32x2Instructions, doF32FTZ]>;
def f32x2rr :
BasicNVPTXInst<(outs B64:$dst),
(ins B64:$a, B64:$b),
op_str # ".f32x2",
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
Requires<[hasF32x2Instructions]>;
def f16x2rr_ftz :
BasicNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
Expand Down Expand Up @@ -829,6 +843,9 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
(SELP_b32rr $a, $b, $p)>;
}

def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
(SELP_b64rr $a, $b, $p)>;

//-----------------------------------
// Test Instructions
//-----------------------------------
Expand Down Expand Up @@ -1345,6 +1362,8 @@ defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
defm FMA32 : FMA<"fma.rn.f32", F32RT>;
defm FMA32x2_ftz : FMA<"fma.rn.ftz.f32x2", F32X2RT, [hasF32x2Instructions, doF32FTZ]>;
defm FMA32x2 : FMA<"fma.rn.f32x2", F32X2RT, [hasF32x2Instructions]>;
defm FMA64 : FMA<"fma.rn.f64", F64RT>;

// sin/cos
Expand Down Expand Up @@ -2585,6 +2604,7 @@ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H $s)>;
def: Pat<(i32 (sext (extractelt v2i16:$src, 0))),
(CVT_INREG_s32_s16 $src)>;

// Handle extracting one element from the pair (32-bit types)
foreach vt = [v2f16, v2bf16, v2i16] in {
def : Pat<(extractelt vt:$src, 0), (I32toI16L_Sink $src)>, Requires<[hasPTX<71>]>;
def : Pat<(extractelt vt:$src, 1), (I32toI16H_Sink $src)>, Requires<[hasPTX<71>]>;
Expand All @@ -2596,10 +2616,21 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
(V2I16toI32 $a, $b)>;
}

// Same thing for the 64-bit type v2f32.
foreach vt = [v2f32] in {
def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;

def : Pat<(extractelt vt:$src, 0), (I64toI32L $src)>;
def : Pat<(extractelt vt:$src, 1), (I64toI32H $src)>;

def : Pat<(vt (build_vector vt.ElementType:$a, vt.ElementType:$b)),
(V2I32toI64 $a, $b)>;
}

def: Pat<(v2i16 (scalar_to_vector i16:$a)),
(CVT_u32_u16 $a, CvtNONE)>;


def nvptx_build_vector : SDNode<"NVPTXISD::BUILD_VECTOR", SDTypeProfile<1, 2, []>, []>;

def : Pat<(i64 (nvptx_build_vector i32:$a, i32:$b)),
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def B16 : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
def B32 : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
(add (sequence "R%u", 0, 4),
VRFrame32, VRFrameLocal32)>;
def B64 : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
def B64 : NVPTXRegClass<[i64, v2f32, f64], 64, (add (sequence "RL%u", 0, 4),
VRFrame64, VRFrameLocal64)>;
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
def B128 : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;

Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {

return HasTcgen05 && PTXVersion >= 86;
}
// Returns true when the target supports the packed f32x2 arithmetic
// instructions (two f32 lanes in one 64-bit register) introduced with the
// Blackwell family: requires sm_100 or newer and PTX ISA 8.6+.
bool hasF32x2Instructions() const {
return SmVersion >= 100 && PTXVersion >= 86;
}

// TMA G2S copy with cta_group::1/2 support
bool hasCpAsyncBulkTensorCTAGroupSupport() const {
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
Insert = false;
}
}
if (Insert && Isv2x16VT(VT)) {
// Can be built in a single mov
if (Insert && isPackedVectorTy(VT) && VT.is32BitVector()) {
// Can be built in a single 32-bit mov (64-bit regs are emulated in SASS
// with 2x 32-bit regs)
Cost += 1;
Insert = false;
}
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {

bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);

inline bool Isv2x16VT(EVT VT) {
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
// Returns true if VT is one of the NVPTX "packed" vector types, i.e. a short
// vector that is stored whole in a single scalar register (v4i8, v2f16,
// v2bf16 and v2i16 in a 32-bit register; v2f32 in a 64-bit register).
inline bool isPackedVectorTy(EVT VT) {
return (VT == MVT::v4i8 || VT == MVT::v2f16 || VT == MVT::v2bf16 ||
VT == MVT::v2i16 || VT == MVT::v2f32);
}

// Returns true if VT is the element type of one of the packed vector types
// accepted by isPackedVectorTy (i8, f16, bf16, i16 or f32).
inline bool isPackedElementTy(EVT VT) {
return (VT == MVT::i8 || VT == MVT::f16 || VT == MVT::bf16 ||
VT == MVT::i16 || VT == MVT::f32);
}

inline bool shouldPassAsArray(Type *Ty) {
Expand Down
111 changes: 80 additions & 31 deletions llvm/test/CodeGen/NVPTX/aggregate-return.ll
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_35 | FileCheck %s
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_35 | %ptxas-verify %}

Expand All @@ -7,57 +8,105 @@ declare [2 x float] @bara([2 x float] %input)
declare {float, float} @bars({float, float} %input)

define void @test_v2f32(<2 x float> %input, ptr %output) {
; CHECK-LABEL: @test_v2f32
; CHECK-LABEL: test_v2f32(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_0];
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .align 8 .b8 param0[8];
; CHECK-NEXT: st.param.b64 [param0], %rd1;
; CHECK-NEXT: .param .align 8 .b8 retval0[8];
; CHECK-NEXT: call.uni (retval0), barv, (param0);
; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
; CHECK-NEXT: } // callseq 0
; CHECK-NEXT: ld.param.b64 %rd4, [test_v2f32_param_1];
; CHECK-NEXT: st.b64 [%rd4], %rd2;
; CHECK-NEXT: ret;
%call = tail call <2 x float> @barv(<2 x float> %input)
; CHECK: .param .align 8 .b8 retval0[8];
; CHECK: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
store <2 x float> %call, ptr %output, align 8
; CHECK: st.v2.b32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
ret void
}

define void @test_v3f32(<3 x float> %input, ptr %output) {
; CHECK-LABEL: @test_v3f32
;
; CHECK-LABEL: test_v3f32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v3f32_param_0];
; CHECK-NEXT: ld.param.b32 %r3, [test_v3f32_param_0+8];
; CHECK-NEXT: { // callseq 1, 0
; CHECK-NEXT: .param .align 16 .b8 param0[16];
; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
; CHECK-NEXT: st.param.b32 [param0+8], %r3;
; CHECK-NEXT: .param .align 16 .b8 retval0[16];
; CHECK-NEXT: call.uni (retval0), barv3, (param0);
; CHECK-NEXT: ld.param.v2.b32 {%r4, %r5}, [retval0];
; CHECK-NEXT: ld.param.b32 %r6, [retval0+8];
; CHECK-NEXT: } // callseq 1
; CHECK-NEXT: ld.param.b64 %rd1, [test_v3f32_param_1];
; CHECK-NEXT: st.b32 [%rd1+8], %r6;
; CHECK-NEXT: st.v2.b32 [%rd1], {%r4, %r5};
; CHECK-NEXT: ret;
%call = tail call <3 x float> @barv3(<3 x float> %input)
; CHECK: .param .align 16 .b8 retval0[16];
; CHECK-DAG: ld.param.v2.b32 {[[E0:%r[0-9]+]], [[E1:%r[0-9]+]]}, [retval0];
; CHECK-DAG: ld.param.b32 [[E2:%r[0-9]+]], [retval0+8];
; Make sure we don't load more values than we need to.
; CHECK-NOT: ld.param.b32 [[E3:%r[0-9]+]], [retval0+12];
store <3 x float> %call, ptr %output, align 8
; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
; -- This is suboptimal. We should do st.v2.f32 instead
; of combining 2xf32 into i64.
; CHECK-DAG: st.b64 [{{%rd[0-9]}}],
; CHECK: ret;
ret void
}

define void @test_a2f32([2 x float] %input, ptr %output) {
; CHECK-LABEL: @test_a2f32
; CHECK-LABEL: test_a2f32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<7>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_a2f32_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_a2f32_param_0+4];
; CHECK-NEXT: { // callseq 2, 0
; CHECK-NEXT: .param .align 4 .b8 param0[8];
; CHECK-NEXT: st.param.b32 [param0], %r1;
; CHECK-NEXT: st.param.b32 [param0+4], %r2;
; CHECK-NEXT: .param .align 4 .b8 retval0[8];
; CHECK-NEXT: call.uni (retval0), bara, (param0);
; CHECK-NEXT: ld.param.b32 %r3, [retval0];
; CHECK-NEXT: ld.param.b32 %r4, [retval0+4];
; CHECK-NEXT: } // callseq 2
; CHECK-NEXT: ld.param.b64 %rd1, [test_a2f32_param_1];
; CHECK-NEXT: st.b32 [%rd1+4], %r4;
; CHECK-NEXT: st.b32 [%rd1], %r3;
; CHECK-NEXT: ret;
%call = tail call [2 x float] @bara([2 x float] %input)
; CHECK: .param .align 4 .b8 retval0[8];
; CHECK-DAG: ld.param.b32 [[ELEMA1:%r[0-9]+]], [retval0];
; CHECK-DAG: ld.param.b32 [[ELEMA2:%r[0-9]+]], [retval0+4];
store [2 x float] %call, ptr %output, align 4
; CHECK: }
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMA1]]
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}+4], [[ELEMA2]]
ret void
; CHECK: ret
}

define void @test_s2f32({float, float} %input, ptr %output) {
; CHECK-LABEL: @test_s2f32
; CHECK-LABEL: test_s2f32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<7>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_s2f32_param_0];
; CHECK-NEXT: ld.param.b32 %r2, [test_s2f32_param_0+4];
; CHECK-NEXT: { // callseq 3, 0
; CHECK-NEXT: .param .align 4 .b8 param0[8];
; CHECK-NEXT: st.param.b32 [param0], %r1;
; CHECK-NEXT: st.param.b32 [param0+4], %r2;
; CHECK-NEXT: .param .align 4 .b8 retval0[8];
; CHECK-NEXT: call.uni (retval0), bars, (param0);
; CHECK-NEXT: ld.param.b32 %r3, [retval0];
; CHECK-NEXT: ld.param.b32 %r4, [retval0+4];
; CHECK-NEXT: } // callseq 3
; CHECK-NEXT: ld.param.b64 %rd1, [test_s2f32_param_1];
; CHECK-NEXT: st.b32 [%rd1+4], %r4;
; CHECK-NEXT: st.b32 [%rd1], %r3;
; CHECK-NEXT: ret;
%call = tail call {float, float} @bars({float, float} %input)
; CHECK: .param .align 4 .b8 retval0[8];
; CHECK-DAG: ld.param.b32 [[ELEMS1:%r[0-9]+]], [retval0];
; CHECK-DAG: ld.param.b32 [[ELEMS2:%r[0-9]+]], [retval0+4];
store {float, float} %call, ptr %output, align 4
; CHECK: }
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMS1]]
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}+4], [[ELEMS2]]
ret void
; CHECK: ret
}
Loading
Loading