Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] Prefer ValueType when defining DAG patterns (NFC) #120161

Merged
merged 1 commit into from
Dec 17, 2024

Conversation

AlexMaclean
Copy link
Member

Replace uses of register class in dag patterns with value types. These types are much more concise and in cases where a single register class maps to multiple types, they avoid the need for both.

@llvmbot
Copy link
Member

llvmbot commented Dec 16, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Replace uses of register class in dag patterns with value types. These types are much more concise and in cases where a single register class maps to multiple types, they avoid the need for both.


Patch is 176.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120161.diff

2 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+474-474)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+282-282)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a7836ccc45f476..abaf8e0b0ec1f8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -213,33 +213,33 @@ multiclass I3<string OpcStr, SDNode OpNode> {
   def i64rr :
     NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b),
               !strconcat(OpcStr, "64 \t$dst, $a, $b;"),
-              [(set Int64Regs:$dst, (OpNode Int64Regs:$a, Int64Regs:$b))]>;
+              [(set i64:$dst, (OpNode i64:$a, i64:$b))]>;
   def i64ri :
     NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b),
               !strconcat(OpcStr, "64 \t$dst, $a, $b;"),
-              [(set Int64Regs:$dst, (OpNode Int64Regs:$a, imm:$b))]>;
+              [(set i64:$dst, (OpNode i64:$a, imm:$b))]>;
   def i32rr :
     NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
               !strconcat(OpcStr, "32 \t$dst, $a, $b;"),
-              [(set Int32Regs:$dst, (OpNode (i32 Int32Regs:$a), (i32 Int32Regs:$b)))]>;
+              [(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
   def i32ri :
     NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
               !strconcat(OpcStr, "32 \t$dst, $a, $b;"),
-              [(set Int32Regs:$dst, (OpNode (i32 Int32Regs:$a), imm:$b))]>;
+              [(set i32:$dst, (OpNode i32:$a, imm:$b))]>;
   def i16rr :
     NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
               !strconcat(OpcStr, "16 \t$dst, $a, $b;"),
-              [(set Int16Regs:$dst, (OpNode Int16Regs:$a, Int16Regs:$b))]>;
+              [(set i16:$dst, (OpNode i16:$a, i16:$b))]>;
   def i16ri :
     NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, i16imm:$b),
               !strconcat(OpcStr, "16 \t$dst, $a, $b;"),
-              [(set Int16Regs:$dst, (OpNode Int16Regs:$a, (imm):$b))]>;
+              [(set i16:$dst, (OpNode i16:$a, (imm):$b))]>;
 }
 
 class I16x2<string OpcStr, SDNode OpNode> :
  NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
               !strconcat(OpcStr, "16x2 \t$dst, $a, $b;"),
-              [(set Int32Regs:$dst, (OpNode (v2i16 Int32Regs:$a), (v2i16 Int32Regs:$b)))]>,
+              [(set v2i16:$dst, (OpNode v2i16:$a, v2i16:$b))]>,
               Requires<[hasPTX<80>, hasSM<90>]>;
 
 // Template for instructions which take 3 int args.  The instructions are
@@ -249,20 +249,20 @@ multiclass ADD_SUB_INT_CARRY<string OpcStr, SDNode OpNode> {
     def i32rr :
       NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
                 !strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
-                [(set Int32Regs:$dst, (OpNode (i32 Int32Regs:$a), (i32 Int32Regs:$b)))]>;
+                [(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
     def i32ri :
       NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
                 !strconcat(OpcStr, ".s32 \t$dst, $a, $b;"),
-                [(set Int32Regs:$dst, (OpNode (i32 Int32Regs:$a), imm:$b))]>;
+                [(set i32:$dst, (OpNode i32:$a, imm:$b))]>;
     def i64rr :
       NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b),
                 !strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
-                [(set Int64Regs:$dst, (OpNode Int64Regs:$a, Int64Regs:$b))]>,
+                [(set i64:$dst, (OpNode i64:$a, i64:$b))]>,
       Requires<[hasPTX<43>]>;
     def i64ri :
       NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b),
                 !strconcat(OpcStr, ".s64 \t$dst, $a, $b;"),
-                [(set Int64Regs:$dst, (OpNode Int64Regs:$a, imm:$b))]>,
+                [(set i64:$dst, (OpNode i64:$a, imm:$b))]>,
       Requires<[hasPTX<43>]>;
   }
 }
@@ -277,72 +277,72 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
      NVPTXInst<(outs Float64Regs:$dst),
                (ins Float64Regs:$a, Float64Regs:$b),
                !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
-               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>;
+               [(set f64:$dst, (OpNode f64:$a, f64:$b))]>;
    def f64ri :
      NVPTXInst<(outs Float64Regs:$dst),
                (ins Float64Regs:$a, f64imm:$b),
                !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
-               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>;
+               [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>;
   }
    def f32rr_ftz :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, Float32Regs:$b),
                !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
                Requires<[doF32FTZ]>;
    def f32ri_ftz :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
                Requires<[doF32FTZ]>;
    def f32rr :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, Float32Regs:$b),
                !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>;
+               [(set f32:$dst, (OpNode f32:$a, f32:$b))]>;
    def f32ri :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>;
 
    def f16rr_ftz :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>,
+               [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, doF32FTZ]>;
    def f16rr :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>,
+               [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
 
    def f16x2rr_ftz :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
+               [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
    def f16x2rr :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
+               [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
    def bf16rr :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
    def bf16x2rr :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
@@ -360,161 +360,161 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
      NVPTXInst<(outs Float64Regs:$dst),
                (ins Float64Regs:$a, Float64Regs:$b),
                !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
-               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>,
+               [(set f64:$dst, (OpNode f64:$a, f64:$b))]>,
                Requires<[allowFMA]>;
    def f64ri :
      NVPTXInst<(outs Float64Regs:$dst),
                (ins Float64Regs:$a, f64imm:$b),
                !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
-               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>,
+               [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>,
                Requires<[allowFMA]>;
    def f32rr_ftz :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, Float32Regs:$b),
                !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
                Requires<[allowFMA, doF32FTZ]>;
    def f32ri_ftz :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
                Requires<[allowFMA, doF32FTZ]>;
    def f32rr :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, Float32Regs:$b),
                !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
                Requires<[allowFMA]>;
    def f32ri :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
                Requires<[allowFMA]>;
 
    def f16rr_ftz :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>,
+               [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, allowFMA, doF32FTZ]>;
    def f16rr :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>,
+               [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, allowFMA]>;
 
    def f16x2rr_ftz :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
-               [(set (v2f16 Int32Regs:$dst), (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
+               [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, allowFMA, doF32FTZ]>;
    def f16x2rr :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
+               [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, allowFMA]>;
    def bf16rr :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
                Requires<[hasBF16Math, allowFMA]>;
 
    def bf16x2rr :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
                Requires<[hasBF16Math, allowFMA]>;
    // These have strange names so we don't perturb existing mir tests.
    def _rnf64rr :
      NVPTXInst<(outs Float64Regs:$dst),
                (ins Float64Regs:$a, Float64Regs:$b),
                !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
-               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>,
+               [(set f64:$dst, (OpNode f64:$a, f64:$b))]>,
                Requires<[noFMA]>;
    def _rnf64ri :
      NVPTXInst<(outs Float64Regs:$dst),
                (ins Float64Regs:$a, f64imm:$b),
                !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
-               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>,
+               [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>,
                Requires<[noFMA]>;
    def _rnf32rr_ftz :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, Float32Regs:$b),
                !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b))]>,
                Requires<[noFMA, doF32FTZ]>;
    def _rnf32ri_ftz :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
                Requires<[noFMA, doF32FTZ]>;
    def _rnf32rr :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, Float32Regs:$b),
                !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
                Requires<[noFMA]>;
    def _rnf32ri :
      NVPTXInst<(outs Float32Regs:$dst),
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
-               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
                Requires<[noFMA]>;
    def _rnf16rr_ftz :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".rn.ftz.f16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>,
+               [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, noFMA, doF32FTZ]>;
    def _rnf16rr :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".rn.f16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>,
+               [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
                Requires<[useFP16Math, noFMA]>;
    def _rnf16x2rr_ftz :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".rn.ftz.f16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
+               [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, noFMA, doF32FTZ]>;
    def _rnf16x2rr :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
+               [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, noFMA]>;
   def _rnbf16rr_ftz :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
                Requires<[hasBF16Math, noFMA, doF32FTZ]>;
    def _rnbf16rr :
      NVPTXInst<(outs Int16Regs:$dst),
                (ins Int16Regs:$a, Int16Regs:$b),
                !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"),
-               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
                Requires<[hasBF16Math, noFMA]>;
    def _rnbf16x2rr_ftz :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
                Requires<[hasBF16Math, noFMA, doF32FTZ]>;
    def _rnbf16x2rr :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b),
                !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
-               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
                Requires<[hasBF16Math, noFMA]>;
 }
 
@@ -524,40 +524,40 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
 multiclass F2<string OpcStr, SDNode OpNode> {
    def f64 :     NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a),
                            !strconcat(OpcStr, ".f64 \t$dst, $a;"),
-                           [(set Float64Regs:$dst, (OpNode Float64Regs:$a))]>;
+                           [(set f64:$dst, (OpNode f64:$a))]>;
    def f32_ftz : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
                            !strconcat(OpcStr, ".ftz.f32 \t$dst, $a;"),
-                           [(set Float32Regs:$dst, (OpNode Float32Regs:$a))]>,
+                           [(set f32:$dst, (OpNode f32:$a))]>,
                            Requires<[doF32FTZ]>;
    def f32 :     NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a),
                            !strconcat(OpcStr, ".f32 \t$dst, $a;"),
-                           [(set Float32Regs:$dst, (OpNode Float32Regs:$a))]>;
+                           [(set f32:$dst, (OpNode f32:$a))]>;
 }
 
 multiclass F2_Support_Half<string OpcStr, SDNode OpNode> {
    def bf16 :      NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
                            !strconcat(OpcStr, ".bf16 \t$dst, $a;"),
-                           [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a)))]>,
+                           [(set bf16:$dst, (OpNode bf16:$a))]>,
                            Requires<[hasSM<80>, hasPTX<70>]>;
    def bf16x2 :    NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
                            !strconcat(OpcStr, ".bf16x2 \t$dst, $a;"),
-                           [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a)))]>,
+                           [(set v2bf16:$dst, (OpNode v2bf16:$a))]>,
                            Requires<[hasSM<80>, hasPTX<70>]>;
    def f16_ftz :   NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
                            !strconcat(OpcStr, ".ftz.f16 \t$dst, $a;"),
-                           [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a)))]>,
+                           [(set f16:$dst, (OpNode f16:$a))]>,
                            Requires<[hasSM<53>, hasPTX<65>, doF32FTZ]>;
    def f16x2_ftz : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
                            !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a;"),
-                           [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a)))]>,
+                           [(set v2f16:$dst, (OpNode v2f16:$a))]>,
              ...
[truncated]

Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM in principle.

I'm not 100% convinced that it's a NFC change. Will matching the type allow the pattern to match a constant value of the same type? If so, that would be a functional change. I think it would be mostly positive (less move-const-to-register for constants that may be used directly), but I vaguely recall that I did see ptxas not accepting constants/symbols as an argument for some instructions and required moving them into a register first.

I don't think we have robust enough test coverage for that. I'm OK with the patch, but keep an eye on it after it lands, in case it breaks something. It should be easy to undo the change for the affected instructions, if we find any.

Copy link
Contributor

@s-barannikov s-barannikov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense and looks neater, LGTM

@s-barannikov
Copy link
Contributor

I'm not 100% convinced that it's a NFC change. Will matching the type allow the pattern to match a constant value of the same type?

This shouldn't happen, but it can be easily verified by comparing the generated NVPTXGenDAGISel.inc files before / after the change.

Copy link
Contributor

@kalxr kalxr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@AlexMaclean AlexMaclean merged commit 9f231a8 into llvm:main Dec 17, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants