Skip to content

[X86][AMX] Add AMX FP8 new APIs #115829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 12, 2024
Merged

[X86][AMX] Add AMX FP8 new APIs #115829

merged 2 commits into from
Nov 12, 2024

Conversation

fzou1
Copy link
Contributor

@fzou1 fzou1 commented Nov 12, 2024

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics llvm:ir labels Nov 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-clang

Author: Feng Zou (fzou1)

Changes

This is a follow-up to #113850.

Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368


Full diff: https://github.com/llvm/llvm-project/pull/115829.diff

7 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsX86_64.def (+4)
  • (modified) clang/lib/Headers/amxfp8intrin.h (+154-21)
  • (added) clang/test/CodeGen/X86/amx_fp8_api.c (+36)
  • (modified) llvm/include/llvm/IR/IntrinsicsX86.td (+25)
  • (modified) llvm/lib/Target/X86/X86ExpandPseudo.cpp (+17-1)
  • (modified) llvm/lib/Target/X86/X86InstrAMX.td (+31)
  • (modified) llvm/lib/Target/X86/X86RegisterInfo.cpp (+5-1)
diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
index 25c10d39df32e2..8653fc217bdddb 100644
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -141,6 +141,10 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl_internal, "V32xUsUsV256iUi", "n", "a
 TARGET_BUILTIN(__builtin_ia32_tilemovrow_internal, "V16iUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
 TARGET_BUILTIN(__builtin_ia32_tmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32")
 TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32,amx-transpose")
+TARGET_BUILTIN(__builtin_ia32_tdpbf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphbf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
 
 // AMX
 TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
diff --git a/clang/lib/Headers/amxfp8intrin.h b/clang/lib/Headers/amxfp8intrin.h
index 0f5ddc87e5a752..4ada936a5d40af 100644
--- a/clang/lib/Headers/amxfp8intrin.h
+++ b/clang/lib/Headers/amxfp8intrin.h
@@ -15,81 +15,214 @@
 #define __AMXFP8INTRIN_H
 #ifdef __x86_64__
 
-/// Peform the dot product of a BF8 value \a a by a BF8 value \a b accumulating
-/// into a Single Precision (FP32) source/dest \a dst.
+#define __DEFAULT_FN_ATTRS_FP8                                                 \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-fp8")))
+
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dpbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                       _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf8ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Perform the dot product of a BF8 value \a src1 by a BF8 value \a src2
+/// accumulating into a Single Precision (FP32) source/dest \a dst.
 ///
 /// \headerfile <immintrin.h>
 ///
 /// \code
-/// void _tile_dpbf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dpbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+///   temp1[(dst.colsb / 4 - 1) : 0] = 0
+///   FOR k := 0 TO src1.colsb / 4 - 1
+///     FOR n := 0 TO dst.colsb / 4 - 1
+///       temp1[n] +=
+///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///     ENDFOR
+///   ENDFOR
+///   FOR n := 0 TO dst.colsb / 4 - 1
+///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+///   ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
 /// \endcode
 ///
 /// This intrinsic corresponds to the \c TDPBF8PS instruction.
 ///
 /// \param dst
 ///    The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
 ///    The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
 ///    The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dpbf8ps(dst, a, b) __builtin_ia32_tdpbf8ps((dst), (a), (b))
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dpbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+  dst->tile = _tile_dpbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                     src1.tile, src2.tile);
+}
+
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dpbhf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbhf8ps_internal(m, n, k, dst, src1, src2);
+}
 
-/// Perform the dot product of a BF8 value \a a by an HF8 value \a b
+/// Perform the dot product of a BF8 value \a src1 by an HF8 value \a src2
 /// accumulating into a Single Precision (FP32) source/dest \a dst.
 ///
 /// \headerfile <immintrin.h>
 ///
 /// \code
-/// void _tile_dpbhf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dpbhf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+///   temp1[(dst.colsb / 4 - 1) : 0] = 0
+///   FOR k := 0 TO src1.colsb / 4 - 1
+///     FOR n := 0 TO dst.colsb / 4 - 1
+///       temp1[n] +=
+///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///     ENDFOR
+///   ENDFOR
+///   FOR n := 0 TO dst.colsb / 4 - 1
+///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+///   ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
 /// \endcode
 ///
 /// This intrinsic corresponds to the \c TDPBHF8PS instruction.
 ///
 /// \param dst
 ///    The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
 ///    The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
 ///    The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dpbhf8ps(dst, a, b) __builtin_ia32_tdpbhf8ps((dst), (a), (b))
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dpbhf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+  dst->tile = _tile_dpbhf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
 
-/// Perform the dot product of an HF8 value \a a by a BF8 value \a b
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dphbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdphbf8ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Perform the dot product of an HF8 value \a src1 by a BF8 value \a src2
 /// accumulating into a Single Precision (FP32) source/dest \a dst.
 ///
 /// \headerfile <immintrin.h>
 ///
 /// \code
-/// void _tile_dphbf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dphbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+///   temp1[(dst.colsb / 4 - 1) : 0] = 0
+///   FOR k := 0 TO src1.colsb / 4 - 1
+///     FOR n := 0 TO dst.colsb / 4 - 1
+///       temp1[n] +=
+///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///     ENDFOR
+///   ENDFOR
+///   FOR n := 0 TO dst.colsb / 4 - 1
+///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+///   ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
 /// \endcode
 ///
 /// This intrinsic corresponds to the \c TDPHBF8PS instruction.
 ///
 /// \param dst
 ///    The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
 ///    The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
 ///    The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dphbf8ps(dst, a, b) __builtin_ia32_tdphbf8ps((dst), (a), (b))
 
-/// Perform the dot product of an HF8 value \a a by an HF8 value \a b
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dphbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+  dst->tile = _tile_dphbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
+
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dphf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                       _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdphf8ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Perform the dot product of an HF8 value \a src1 by an HF8 value \a src2
 /// accumulating into a Single Precision (FP32) source/dest \a dst.
 ///
 /// \headerfile <immintrin.h>
 ///
 /// \code
-/// void _tile_dphf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dphf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+///   temp1[(dst.colsb / 4 - 1) : 0] = 0
+///   FOR k := 0 TO src1.colsb / 4 - 1
+///     FOR n := 0 TO dst.colsb / 4 - 1
+///       temp1[n] +=
+///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///     ENDFOR
+///   ENDFOR
+///   FOR n := 0 TO dst.colsb / 4 - 1
+///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+///   ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
 /// \endcode
 ///
 /// This intrinsic corresponds to the \c TDPHF8PS instruction.
 ///
 /// \param dst
 ///    The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
 ///    The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
 ///    The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dphf8ps(dst, a, b) __builtin_ia32_tdphf8ps((dst), (a), (b))
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dphf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+  dst->tile = _tile_dphf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                     src1.tile, src2.tile);
+}
+
+#define _tile_dpbf8ps(dst, src1, src2)                                         \
+  __builtin_ia32_tdpbf8ps((dst), (src1), (src2))
+#define _tile_dpbhf8ps(dst, src1, src2)                                        \
+  __builtin_ia32_tdpbhf8ps((dst), (src1), (src2))
+#define _tile_dphbf8ps(dst, src1, src2)                                        \
+  __builtin_ia32_tdphbf8ps((dst), (src1), (src2))
+#define _tile_dphf8ps(dst, src1, src2)                                         \
+  __builtin_ia32_tdphf8ps((dst), (src1), (src2))
 
 #endif /* __x86_64__ */
 #endif /* __AMXFP8INTRIN_H */
diff --git a/clang/test/CodeGen/X86/amx_fp8_api.c b/clang/test/CodeGen/X86/amx_fp8_api.c
new file mode 100644
index 00000000000000..2a3af1b7f5cd9a
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8_api.c
@@ -0,0 +1,36 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +amx-fp8  \
+// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
+#include <immintrin.h>
+
+void test_tdpbf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+  //CHECK-LABEL: @test_tdpbf8ps
+  //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+  //CHECK-DAG: call x86_amx @llvm.x86.tdpbf8ps.internal
+  //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+  __tile_dpbf8ps(&dst, src1, src2);
+}
+
+void test_tdpbhf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+  //CHECK-LABEL: @test_tdpbhf8ps
+  //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+  //CHECK-DAG: call x86_amx @llvm.x86.tdpbhf8ps.internal
+  //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+  __tile_dpbhf8ps(&dst, src1, src2);
+}
+
+void test_tdphbf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+  //CHECK-LABEL: @test_tdphbf8ps
+  //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+  //CHECK-DAG: call x86_amx @llvm.x86.tdphbf8ps.internal
+  //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+  __tile_dphbf8ps(&dst, src1, src2);
+}
+
+void test_tdphf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+  //CHECK-LABEL: @test_tdphf8ps
+  //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+  //CHECK-DAG: call x86_amx @llvm.x86.tdphf8ps.internal
+  //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+  __tile_dphf8ps(&dst, src1, src2);
+}
+
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index b2d6f44b7927a9..5211b82cc8d31b 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -6120,6 +6120,31 @@ let TargetPrefix = "x86" in {
               Intrinsic<[llvm_x86amx_ty],
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty], []>;
+
+  def int_x86_tdpbf8ps_internal :
+                ClangBuiltin<"__builtin_ia32_tdpbf8ps_internal">,
+                Intrinsic<[llvm_x86amx_ty],
+                          [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                           llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+                          []>;
+  def int_x86_tdpbhf8ps_internal :
+                ClangBuiltin<"__builtin_ia32_tdpbhf8ps_internal">,
+                Intrinsic<[llvm_x86amx_ty],
+                          [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                           llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+                          []>;
+  def int_x86_tdphbf8ps_internal :
+                ClangBuiltin<"__builtin_ia32_tdphbf8ps_internal">,
+                Intrinsic<[llvm_x86amx_ty],
+                          [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                           llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+                          []>;
+  def int_x86_tdphf8ps_internal :
+                ClangBuiltin<"__builtin_ia32_tdphf8ps_internal">,
+                Intrinsic<[llvm_x86amx_ty],
+                          [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                           llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+                          []>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
index 4f045d78f75fb2..b673a3766a6832 100644
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -757,7 +757,11 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
   case X86::PTDPBF16PSV:
   case X86::PTDPFP16PSV:
   case X86::PTMMULTF32PSV:
-  case X86::PTTMMULTF32PSV: {
+  case X86::PTTMMULTF32PSV:
+  case X86::PTDPBF8PSV:
+  case X86::PTDPBHF8PSV:
+  case X86::PTDPHBF8PSV:
+  case X86::PTDPHF8PSV: {
     MI.untieRegOperand(4);
     for (unsigned i = 3; i > 0; --i)
       MI.removeOperand(i);
@@ -777,6 +781,18 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
     case X86::PTTMMULTF32PSV:
       Opc = X86::TTMMULTF32PS;
       break;
+    case X86::PTDPBF8PSV:
+      Opc = X86::TDPBF8PS;
+      break;
+    case X86::PTDPBHF8PSV:
+      Opc = X86::TDPBHF8PS;
+      break;
+    case X86::PTDPHBF8PSV:
+      Opc = X86::TDPHBF8PS;
+      break;
+    case X86::PTDPHF8PSV:
+      Opc = X86::TDPHF8PS;
+      break;
 
     default:
       llvm_unreachable("Unexpected Opcode");
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
index 04527716e31627..da0077a990242b 100644
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -304,6 +304,37 @@ let Predicates = [HasAMXFP8, In64BitMode] in {
                               [(int_x86_tdphf8ps timm:$src1, timm:$src2,
                                 timm:$src3)]>;
     }
+
+    let Constraints = "$src4 = $dst" in {
+      def PTDPBF8PSV : PseudoI<(outs TILE:$dst),
+                               (ins GR16:$src1, GR16:$src2, GR16:$src3,
+                                    TILE:$src4, TILE:$src5, TILE:$src6),
+                               [(set TILE:$dst,
+                                (int_x86_tdpbf8ps_internal GR16:$src1,
+                                 GR16:$src2, GR16:$src3, TILE:$src4,
+                                 TILE:$src5, TILE:$src6))]>;
+      def PTDPBHF8PSV : PseudoI<(outs TILE:$dst),
+                               (ins GR16:$src1, GR16:$src2, GR16:$src3,
+                                    TILE:$src4, TILE:$src5, TILE:$src6),
+                               [(set TILE:$dst,
+                                (int_x86_tdpbhf8ps_internal GR16:$src1,
+                                 GR16:$src2, GR16:$src3, TILE:$src4,
+                                 TILE:$src5, TILE:$src6))]>;
+      def PTDPHBF8PSV : PseudoI<(outs TILE:$dst),
+                               (ins GR16:$src1, GR16:$src2, GR16:$src3,
+                                    TILE:$src4, TILE:$src5, TILE:$src6),
+                               [(set TILE:$dst,
+                                (int_x86_tdphbf8ps_internal GR16:$src1,
+                                 GR16:$src2, GR16:$src3, TILE:$src4,
+                                 TILE:$src5, TILE:$src6))]>;
+      def PTDPHF8PSV : PseudoI<(outs TILE:$dst),
+                               (ins GR16:$src1, GR16:$src2, GR16:$src3,
+                                    TILE:$src4, TILE:$src5, TILE:$src6),
+                               [(set TILE:$dst,
+                                (int_x86_tdphf8ps_internal GR16:$src1,
+                                 GR16:$src2, GR16:$src3, TILE:$src4,
+                                 TILE:$src5, TILE:$src6))]>;
+    }
   }
 }
 
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
index 09418c9bb74d34..25eb90cdd6d60c 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -1078,7 +1078,11 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
   case X86::PTCMMRLFP16PSV:
   case X86::PTTRANSPOSEDV:
   case X86::PTMMULTF32PSV:
-  case X86::PTTMMULTF32PSV: {
+  case X86::PTTMMULTF32PSV:
+  case X86::PTDPBF8PSV:
+  case X86::PTDPBHF8PSV:
+  case X86::PTDPHBF8PSV:
+  case X86::PTDPHF8PSV: {
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);
     ShapeT Shape(&MO1, &MO2, MRI);

Copy link

github-actions bot commented Nov 12, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff f109517d153609d4a8a3a3d3d3cc06da1b629364 7301f5628cc495a18c729e75dda90ca523444b9b --extensions c,h,cpp -- clang/test/CodeGen/X86/amx_fp8_api.c clang/lib/Headers/amxfp8intrin.h llvm/lib/Target/X86/X86ExpandPseudo.cpp llvm/lib/Target/X86/X86RegisterInfo.cpp
View the diff from clang-format here.
diff --git a/clang/lib/Headers/amxfp8intrin.h b/clang/lib/Headers/amxfp8intrin.h
index 92e7989974..41d2d3749e 100644
--- a/clang/lib/Headers/amxfp8intrin.h
+++ b/clang/lib/Headers/amxfp8intrin.h
@@ -40,9 +40,12 @@ _tile_dpbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
 ///     FOR n := 0 TO dst.colsb / 4 - 1
 ///       temp1[n] +=
 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///         + INT64(src1.row[m].float8[4*k+1]) *
+///         INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) *
+///         INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) *
+///         INT64(src2.row[k].float8[4*n+3])
 ///     ENDFOR
 ///   ENDFOR
 ///   FOR n := 0 TO dst.colsb / 4 - 1
@@ -89,9 +92,12 @@ _tile_dpbhf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
 ///     FOR n := 0 TO dst.colsb / 4 - 1
 ///       temp1[n] +=
 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///         + INT64(src1.row[m].float8[4*k+1]) *
+///         INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) *
+///         INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) *
+///         INT64(src2.row[k].float8[4*n+3])
 ///     ENDFOR
 ///   ENDFOR
 ///   FOR n := 0 TO dst.colsb / 4 - 1
@@ -138,9 +144,12 @@ _tile_dphbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
 ///     FOR n := 0 TO dst.colsb / 4 - 1
 ///       temp1[n] +=
 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///         + INT64(src1.row[m].float8[4*k+1]) *
+///         INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) *
+///         INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) *
+///         INT64(src2.row[k].float8[4*n+3])
 ///     ENDFOR
 ///   ENDFOR
 ///   FOR n := 0 TO dst.colsb / 4 - 1
@@ -188,9 +197,12 @@ _tile_dphf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
 ///     FOR n := 0 TO dst.colsb / 4 - 1
 ///       temp1[n] +=
 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+///         + INT64(src1.row[m].float8[4*k+1]) *
+///         INT64(src2.row[k].float8[4*n+1])
+///         + INT64(src1.row[m].float8[4*k+2]) *
+///         INT64(src2.row[k].float8[4*n+2])
+///         + INT64(src1.row[m].float8[4*k+3]) *
+///         INT64(src2.row[k].float8[4*n+3])
 ///     ENDFOR
 ///   ENDFOR
 ///   FOR n := 0 TO dst.colsb / 4 - 1

#define _tile_dphbf8ps(dst, src1, src2) \
__builtin_ia32_tdphbf8ps((dst), (src1), (src2))
#define _tile_dphf8ps(dst, src1, src2) \
__builtin_ia32_tdphf8ps((dst), (src1), (src2))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#undef __DEFAULT_FN_ATTRS_FP8

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@phoebewang
Copy link
Contributor

Missing IR test?

@fzou1
Copy link
Contributor Author

fzou1 commented Nov 12, 2024

Missing IR test?

Sorry. Added. Thanks.

Copy link
Contributor

@phoebewang phoebewang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@phoebewang phoebewang merged commit 1b63f47 into llvm:main Nov 12, 2024
7 of 8 checks passed
@fzou1 fzou1 deleted the amx-fp8-new-api branch November 12, 2024 15:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants