Skip to content

[NVPTX] Misc table-gen cleanup (NFC) #142877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 82 additions & 114 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;

def True : Predicate<"true">;
def False : Predicate<"false">;

class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
Expand Down Expand Up @@ -257,6 +256,11 @@ def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
// "prmt.b32${mode}">;
// ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
//
// * BasicFlagsNVPTXInst<(outs Int64Regs:$state),
// (ins ADDR:$addr),
// (ins),
// "mbarrier.arrive.b64">;
// ---> "mbarrier.arrive.b64 \t$state, [$addr];"
//
class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmstr,
list<dag> pattern = []>
: NVPTXInst<
Expand All @@ -274,7 +278,11 @@ class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmst
!if(!or(!empty(ins_dag), !empty(outs_dag)), "", ", "),
!interleave(
!foreach(i, !range(!size(ins_dag)),
"$" # !getdagname(ins_dag, i)),
!if(!eq(!cast<string>(!getdagarg<DAGOperand>(ins_dag, i)), "ADDR"),
"[$" # !getdagname(ins_dag, i) # "]",
"$" # !getdagname(ins_dag, i)
)
),
", "))),
";"),
pattern>;
Expand Down Expand Up @@ -956,31 +964,17 @@ def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;

// Matchers for signed, unsigned mul.wide ISD nodes.
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)),
(MULWIDES32 $a, $b)>,
Requires<[doMulWide]>;
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)),
(MULWIDES32Imm $a, imm:$b)>,
Requires<[doMulWide]>;
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)),
(MULWIDEU32 $a, $b)>,
Requires<[doMulWide]>;
def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)),
(MULWIDEU32Imm $a, imm:$b)>,
Requires<[doMulWide]>;
let Predicates = [doMulWide] in {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice.

def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;

def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)),
(MULWIDES64 $a, $b)>,
Requires<[doMulWide]>;
def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)),
(MULWIDES64Imm $a, imm:$b)>,
Requires<[doMulWide]>;
def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)),
(MULWIDEU64 $a, $b)>,
Requires<[doMulWide]>;
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)),
(MULWIDEU64Imm $a, imm:$b)>,
Requires<[doMulWide]>;
def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
}

// Predicates used for converting some patterns to mul.wide.
def SInt32Const : PatLeaf<(imm), [{
Expand Down Expand Up @@ -1106,18 +1100,12 @@ defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
}

def INEG16 :
BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
"neg.s16",
[(set i16:$dst, (ineg i16:$src))]>;
def INEG32 :
BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src),
"neg.s32",
[(set i32:$dst, (ineg i32:$src))]>;
def INEG64 :
BasicNVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
"neg.s64",
[(set i64:$dst, (ineg i64:$src))]>;
foreach t = [I16RT, I32RT, I64RT] in {
def NEG_S # t.Size :
BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
"neg.s" # t.Size,
[(set t.Ty:$dst, (ineg t.Ty:$src))]>;
}

//-----------------------------------
// Floating Point Arithmetic
Expand Down Expand Up @@ -1538,7 +1526,7 @@ def bfi : SDNode<"NVPTXISD::BFI", SDTBFI>;

def SDTPRMT :
SDTypeProfile<1, 4, [SDTCisVT<0, i32>, SDTCisVT<1, i32>,
SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>,]>;
SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;

multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
Expand Down Expand Up @@ -1961,15 +1949,15 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
// f16 -> pred
def : Pat<(i1 (OpNode f16:$a, f16:$b)),
(SETP_f16rr $a, $b, ModeFTZ)>,
Requires<[useFP16Math,doF32FTZ]>;
Requires<[useFP16Math, doF32FTZ]>;
def : Pat<(i1 (OpNode f16:$a, f16:$b)),
(SETP_f16rr $a, $b, Mode)>,
Requires<[useFP16Math]>;

// bf16 -> pred
def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
(SETP_bf16rr $a, $b, ModeFTZ)>,
Requires<[hasBF16Math,doF32FTZ]>;
Requires<[hasBF16Math, doF32FTZ]>;
def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
(SETP_bf16rr $a, $b, Mode)>,
Requires<[hasBF16Math]>;
Expand Down Expand Up @@ -2497,24 +2485,20 @@ def : Pat<(f16 (uint_to_fp i32:$a)), (CVT_f16_u32 $a, CvtRN)>;
def : Pat<(f16 (uint_to_fp i64:$a)), (CVT_f16_u64 $a, CvtRN)>;

// sint -> bf16
def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
let Predicates = [hasPTX<78>, hasSM<90>] in {
def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>;
def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>;
def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>;
def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>;
}

// uint -> bf16
def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>,
Requires<[hasPTX<78>, hasSM<90>]>;
let Predicates = [hasPTX<78>, hasSM<90>] in {
def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>;
def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>;
def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>;
}

// sint -> f32
def : Pat<(f32 (sint_to_fp i1:$a)), (CVT_f32_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
Expand Down Expand Up @@ -2565,27 +2549,25 @@ def : Pat<(i16 (fp_to_uint bf16:$a)), (CVT_u16_bf16 $a, CvtRZI)>;
def : Pat<(i32 (fp_to_uint bf16:$a)), (CVT_u32_bf16 $a, CvtRZI)>;
def : Pat<(i64 (fp_to_uint bf16:$a)), (CVT_u64_bf16 $a, CvtRZI)>;
// f32 -> sint
def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>,
Requires<[doF32FTZ]>;
let Predicates = [doF32FTZ] in {
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>;
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>;
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>;
}
def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI)>;
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>,
Requires<[doF32FTZ]>;
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI)>;
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>,
Requires<[doF32FTZ]>;
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI)>;

// f32 -> uint
let Predicates = [doF32FTZ] in {
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>;
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>;
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>;
}
def : Pat<(i1 (fp_to_uint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>,
Requires<[doF32FTZ]>;
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI)>;
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>,
Requires<[doF32FTZ]>;
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI)>;
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>,
Requires<[doF32FTZ]>;
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI)>;

// f64 -> sint
Expand Down Expand Up @@ -2707,28 +2689,24 @@ let hasSideEffects = false in {

// PTX 7.1 lets you avoid a temp register and just use _ as a "sink" for the
// unused high/low part.
def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high),
(ins Int32Regs:$s),
"mov.b32 \t{{_, $high}}, $s;",
[]>, Requires<[hasPTX<71>]>;
def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low),
(ins Int32Regs:$s),
"mov.b32 \t{{$low, _}}, $s;",
[]>, Requires<[hasPTX<71>]>;
def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high),
(ins Int64Regs:$s),
"mov.b64 \t{{_, $high}}, $s;",
[]>, Requires<[hasPTX<71>]>;
def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low),
(ins Int64Regs:$s),
"mov.b64 \t{{$low, _}}, $s;",
[]>, Requires<[hasPTX<71>]>;
let Predicates = [hasPTX<71>] in {
def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high), (ins Int32Regs:$s),
"mov.b32 \t{{_, $high}}, $s;", []>;
def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low), (ins Int32Regs:$s),
"mov.b32 \t{{$low, _}}, $s;", []>;
def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high), (ins Int64Regs:$s),
"mov.b64 \t{{_, $high}}, $s;", []>;
def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low), (ins Int64Regs:$s),
"mov.b64 \t{{$low, _}}, $s;", []>;
}
}

def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
let Predicates = [hasPTX<71>] in {
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
}

// Fall back to the old way if we don't have PTX 7.1.
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H $s)>;
Expand Down Expand Up @@ -3061,29 +3039,19 @@ def stacksave :
SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
[SDNPHasChain, SDNPSideEffect]>;

def STACKRESTORE_32 :
BasicNVPTXInst<(outs), (ins Int32Regs:$ptr),
"stackrestore.u32",
[(stackrestore i32:$ptr)]>,
Requires<[hasPTX<73>, hasSM<52>]>;

def STACKSAVE_32 :
BasicNVPTXInst<(outs Int32Regs:$dst), (ins),
"stacksave.u32",
[(set i32:$dst, (i32 stacksave))]>,
Requires<[hasPTX<73>, hasSM<52>]>;

def STACKRESTORE_64 :
BasicNVPTXInst<(outs), (ins Int64Regs:$ptr),
"stackrestore.u64",
[(stackrestore i64:$ptr)]>,
Requires<[hasPTX<73>, hasSM<52>]>;

def STACKSAVE_64 :
BasicNVPTXInst<(outs Int64Regs:$dst), (ins),
"stacksave.u64",
[(set i64:$dst, (i64 stacksave))]>,
Requires<[hasPTX<73>, hasSM<52>]>;
let Predicates = [hasPTX<73>, hasSM<52>] in {
foreach t = [I32RT, I64RT] in {
def STACKRESTORE_ # t.Size :
BasicNVPTXInst<(outs), (ins t.RC:$ptr),
"stackrestore.u" # t.Size,
[(stackrestore t.Ty:$ptr)]>;

def STACKSAVE_ # t.Size :
BasicNVPTXInst<(outs t.RC:$dst), (ins),
"stacksave.u" # t.Size,
[(set t.Ty:$dst, (t.Ty stacksave))]>;
}
}

include "NVPTXIntrinsics.td"

Expand Down Expand Up @@ -3124,7 +3092,7 @@ def : Pat <
////////////////////////////////////////////////////////////////////////////////

class NVPTXFenceInst<string scope, string sem, Predicate ptx>:
NVPTXInst<(outs), (ins), "fence."#sem#"."#scope#";", []>,
BasicNVPTXInst<(outs), (ins), "fence."#sem#"."#scope>,
Requires<[ptx, hasSM<70>]>;

foreach scope = ["sys", "gpu", "cluster", "cta"] in {
Expand Down
Loading
Loading