-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[X86][AMX] Add AMX FP8 new APIs #115829
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[X86][AMX] Add AMX FP8 new APIs #115829
Conversation
This is a follow-up to llvm#113850. Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-clang Author: Feng Zou (fzou1) Changes: This is a follow-up to #113850. Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368 Full diff: https://github.com/llvm/llvm-project/pull/115829.diff 7 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
index 25c10d39df32e2..8653fc217bdddb 100644
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -141,6 +141,10 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl_internal, "V32xUsUsV256iUi", "n", "a
TARGET_BUILTIN(__builtin_ia32_tilemovrow_internal, "V16iUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
TARGET_BUILTIN(__builtin_ia32_tmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32")
TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32,amx-transpose")
+TARGET_BUILTIN(__builtin_ia32_tdpbf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphbf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
+TARGET_BUILTIN(__builtin_ia32_tdphf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
// AMX
TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
diff --git a/clang/lib/Headers/amxfp8intrin.h b/clang/lib/Headers/amxfp8intrin.h
index 0f5ddc87e5a752..4ada936a5d40af 100644
--- a/clang/lib/Headers/amxfp8intrin.h
+++ b/clang/lib/Headers/amxfp8intrin.h
@@ -15,81 +15,214 @@
#define __AMXFP8INTRIN_H
#ifdef __x86_64__
-/// Peform the dot product of a BF8 value \a a by a BF8 value \a b accumulating
-/// into a Single Precision (FP32) source/dest \a dst.
+#define __DEFAULT_FN_ATTRS_FP8 \
+ __attribute__((__always_inline__, __nodebug__, __target__("amx-fp8")))
+
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dpbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+ _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+ return __builtin_ia32_tdpbf8ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Perform the dot product of a BF8 value \a src1 by a BF8 value \a src2
+/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
-/// void _tile_dpbf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dpbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+/// temp1[(dst.colsb / 4 - 1) : 0] = 0
+/// FOR k := 0 TO src1.colsb / 4 - 1
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// temp1[n] +=
+/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// ENDFOR
+/// ENDFOR
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+/// ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPBF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dpbf8ps(dst, a, b) __builtin_ia32_tdpbf8ps((dst), (a), (b))
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dpbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+ dst->tile = _tile_dpbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+ src1.tile, src2.tile);
+}
+
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dpbhf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+ _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+ return __builtin_ia32_tdpbhf8ps_internal(m, n, k, dst, src1, src2);
+}
-/// Perform the dot product of a BF8 value \a a by an HF8 value \a b
+/// Perform the dot product of a BF8 value \a src1 by an HF8 value \a src2
/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
-/// void _tile_dpbhf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dpbhf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+/// temp1[(dst.colsb / 4 - 1) : 0] = 0
+/// FOR k := 0 TO src1.colsb / 4 - 1
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// temp1[n] +=
+/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// ENDFOR
+/// ENDFOR
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+/// ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPBHF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dpbhf8ps(dst, a, b) __builtin_ia32_tdpbhf8ps((dst), (a), (b))
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dpbhf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+ dst->tile = _tile_dpbhf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+ src1.tile, src2.tile);
+}
-/// Perform the dot product of an HF8 value \a a by a BF8 value \a b
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dphbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+ _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+ return __builtin_ia32_tdphbf8ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Perform the dot product of an HF8 value \a src1 by a BF8 value \a src2
/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
-/// void _tile_dphbf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dphbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+/// temp1[(dst.colsb / 4 - 1) : 0] = 0
+/// FOR k := 0 TO src1.colsb / 4 - 1
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// temp1[n] +=
+/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// ENDFOR
+/// ENDFOR
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+/// ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPHBF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dphbf8ps(dst, a, b) __builtin_ia32_tdphbf8ps((dst), (a), (b))
-/// Perform the dot product of an HF8 value \a a by an HF8 value \a b
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dphbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+ dst->tile = _tile_dphbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+ src1.tile, src2.tile);
+}
+
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
+_tile_dphf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
+ _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+ return __builtin_ia32_tdphf8ps_internal(m, n, k, dst, src1, src2);
+}
+
+/// Perform the dot product of an HF8 value \a src1 by an HF8 value \a src2
/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
-/// void _tile_dphf8ps (__tile dst, __tile a, __tile b)
+/// void __tile_dphf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
+/// \endcode
+///
+/// \code{.operation}
+/// FOR m := 0 TO dst.rows - 1
+/// temp1[(dst.colsb / 4 - 1) : 0] = 0
+/// FOR k := 0 TO src1.colsb / 4 - 1
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// temp1[n] +=
+/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
+/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// ENDFOR
+/// ENDFOR
+/// FOR n := 0 TO dst.colsb / 4 - 1
+/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
+/// ENDFOR
+/// write_row_and_zero(dst, m, tmp, dst.colsb)
+/// zero_upper_rows(dst, dst.rows)
+/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPHF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
-/// \param a
+/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
-/// \param b
+/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
-#define _tile_dphf8ps(dst, a, b) __builtin_ia32_tdphf8ps((dst), (a), (b))
+__DEFAULT_FN_ATTRS_FP8 static void
+__tile_dphf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
+ dst->tile = _tile_dphf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
+ src1.tile, src2.tile);
+}
+
+#define _tile_dpbf8ps(dst, src1, src2) \
+ __builtin_ia32_tdpbf8ps((dst), (src1), (src2))
+#define _tile_dpbhf8ps(dst, src1, src2) \
+ __builtin_ia32_tdpbhf8ps((dst), (src1), (src2))
+#define _tile_dphbf8ps(dst, src1, src2) \
+ __builtin_ia32_tdphbf8ps((dst), (src1), (src2))
+#define _tile_dphf8ps(dst, src1, src2) \
+ __builtin_ia32_tdphf8ps((dst), (src1), (src2))
#endif /* __x86_64__ */
#endif /* __AMXFP8INTRIN_H */
diff --git a/clang/test/CodeGen/X86/amx_fp8_api.c b/clang/test/CodeGen/X86/amx_fp8_api.c
new file mode 100644
index 00000000000000..2a3af1b7f5cd9a
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_fp8_api.c
@@ -0,0 +1,36 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-fp8 \
+// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
+#include <immintrin.h>
+
+void test_tdpbf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+ //CHECK-LABEL: @test_tdpbf8ps
+ //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+ //CHECK-DAG: call x86_amx @llvm.x86.tdpbf8ps.internal
+ //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+ __tile_dpbf8ps(&dst, src1, src2);
+}
+
+void test_tdpbhf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+ //CHECK-LABEL: @test_tdpbhf8ps
+ //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+ //CHECK-DAG: call x86_amx @llvm.x86.tdpbhf8ps.internal
+ //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+ __tile_dpbhf8ps(&dst, src1, src2);
+}
+
+void test_tdphbf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+ //CHECK-LABEL: @test_tdphbf8ps
+ //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+ //CHECK-DAG: call x86_amx @llvm.x86.tdphbf8ps.internal
+ //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+ __tile_dphbf8ps(&dst, src1, src2);
+}
+
+void test_tdphf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
+ //CHECK-LABEL: @test_tdphf8ps
+ //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
+ //CHECK-DAG: call x86_amx @llvm.x86.tdphf8ps.internal
+ //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
+ __tile_dphf8ps(&dst, src1, src2);
+}
+
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index b2d6f44b7927a9..5211b82cc8d31b 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -6120,6 +6120,31 @@ let TargetPrefix = "x86" in {
Intrinsic<[llvm_x86amx_ty],
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty,
llvm_x86amx_ty, llvm_x86amx_ty], []>;
+
+ def int_x86_tdpbf8ps_internal :
+ ClangBuiltin<"__builtin_ia32_tdpbf8ps_internal">,
+ Intrinsic<[llvm_x86amx_ty],
+ [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+ llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+ []>;
+ def int_x86_tdpbhf8ps_internal :
+ ClangBuiltin<"__builtin_ia32_tdpbhf8ps_internal">,
+ Intrinsic<[llvm_x86amx_ty],
+ [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+ llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+ []>;
+ def int_x86_tdphbf8ps_internal :
+ ClangBuiltin<"__builtin_ia32_tdphbf8ps_internal">,
+ Intrinsic<[llvm_x86amx_ty],
+ [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+ llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+ []>;
+ def int_x86_tdphf8ps_internal :
+ ClangBuiltin<"__builtin_ia32_tdphf8ps_internal">,
+ Intrinsic<[llvm_x86amx_ty],
+ [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+ llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
+ []>;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
index 4f045d78f75fb2..b673a3766a6832 100644
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -757,7 +757,11 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::PTDPBF16PSV:
case X86::PTDPFP16PSV:
case X86::PTMMULTF32PSV:
- case X86::PTTMMULTF32PSV: {
+ case X86::PTTMMULTF32PSV:
+ case X86::PTDPBF8PSV:
+ case X86::PTDPBHF8PSV:
+ case X86::PTDPHBF8PSV:
+ case X86::PTDPHF8PSV: {
MI.untieRegOperand(4);
for (unsigned i = 3; i > 0; --i)
MI.removeOperand(i);
@@ -777,6 +781,18 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::PTTMMULTF32PSV:
Opc = X86::TTMMULTF32PS;
break;
+ case X86::PTDPBF8PSV:
+ Opc = X86::TDPBF8PS;
+ break;
+ case X86::PTDPBHF8PSV:
+ Opc = X86::TDPBHF8PS;
+ break;
+ case X86::PTDPHBF8PSV:
+ Opc = X86::TDPHBF8PS;
+ break;
+ case X86::PTDPHF8PSV:
+ Opc = X86::TDPHF8PS;
+ break;
default:
llvm_unreachable("Unexpected Opcode");
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
index 04527716e31627..da0077a990242b 100644
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -304,6 +304,37 @@ let Predicates = [HasAMXFP8, In64BitMode] in {
[(int_x86_tdphf8ps timm:$src1, timm:$src2,
timm:$src3)]>;
}
+
+ let Constraints = "$src4 = $dst" in {
+ def PTDPBF8PSV : PseudoI<(outs TILE:$dst),
+ (ins GR16:$src1, GR16:$src2, GR16:$src3,
+ TILE:$src4, TILE:$src5, TILE:$src6),
+ [(set TILE:$dst,
+ (int_x86_tdpbf8ps_internal GR16:$src1,
+ GR16:$src2, GR16:$src3, TILE:$src4,
+ TILE:$src5, TILE:$src6))]>;
+ def PTDPBHF8PSV : PseudoI<(outs TILE:$dst),
+ (ins GR16:$src1, GR16:$src2, GR16:$src3,
+ TILE:$src4, TILE:$src5, TILE:$src6),
+ [(set TILE:$dst,
+ (int_x86_tdpbhf8ps_internal GR16:$src1,
+ GR16:$src2, GR16:$src3, TILE:$src4,
+ TILE:$src5, TILE:$src6))]>;
+ def PTDPHBF8PSV : PseudoI<(outs TILE:$dst),
+ (ins GR16:$src1, GR16:$src2, GR16:$src3,
+ TILE:$src4, TILE:$src5, TILE:$src6),
+ [(set TILE:$dst,
+ (int_x86_tdphbf8ps_internal GR16:$src1,
+ GR16:$src2, GR16:$src3, TILE:$src4,
+ TILE:$src5, TILE:$src6))]>;
+ def PTDPHF8PSV : PseudoI<(outs TILE:$dst),
+ (ins GR16:$src1, GR16:$src2, GR16:$src3,
+ TILE:$src4, TILE:$src5, TILE:$src6),
+ [(set TILE:$dst,
+ (int_x86_tdphf8ps_internal GR16:$src1,
+ GR16:$src2, GR16:$src3, TILE:$src4,
+ TILE:$src5, TILE:$src6))]>;
+ }
}
}
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
index 09418c9bb74d34..25eb90cdd6d60c 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -1078,7 +1078,11 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
case X86::PTCMMRLFP16PSV:
case X86::PTTRANSPOSEDV:
case X86::PTMMULTF32PSV:
- case X86::PTTMMULTF32PSV: {
+ case X86::PTTMMULTF32PSV:
+ case X86::PTDPBF8PSV:
+ case X86::PTDPBHF8PSV:
+ case X86::PTDPHBF8PSV:
+ case X86::PTDPHF8PSV: {
MachineOperand &MO1 = MI->getOperand(1);
MachineOperand &MO2 = MI->getOperand(2);
ShapeT Shape(&MO1, &MO2, MRI);
|
You can test this locally with the following command: git-clang-format --diff f109517d153609d4a8a3a3d3d3cc06da1b629364 7301f5628cc495a18c729e75dda90ca523444b9b --extensions c,h,cpp -- clang/test/CodeGen/X86/amx_fp8_api.c clang/lib/Headers/amxfp8intrin.h llvm/lib/Target/X86/X86ExpandPseudo.cpp llvm/lib/Target/X86/X86RegisterInfo.cpp
View the diff from clang-format here.
diff --git a/clang/lib/Headers/amxfp8intrin.h b/clang/lib/Headers/amxfp8intrin.h
index 92e7989974..41d2d3749e 100644
--- a/clang/lib/Headers/amxfp8intrin.h
+++ b/clang/lib/Headers/amxfp8intrin.h
@@ -40,9 +40,12 @@ _tile_dpbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// + INT64(src1.row[m].float8[4*k+1]) *
+/// INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) *
+/// INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) *
+/// INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
@@ -89,9 +92,12 @@ _tile_dpbhf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// + INT64(src1.row[m].float8[4*k+1]) *
+/// INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) *
+/// INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) *
+/// INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
@@ -138,9 +144,12 @@ _tile_dphbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// + INT64(src1.row[m].float8[4*k+1]) *
+/// INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) *
+/// INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) *
+/// INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
@@ -188,9 +197,12 @@ _tile_dphf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
-/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
-/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
-/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
+/// + INT64(src1.row[m].float8[4*k+1]) *
+/// INT64(src2.row[k].float8[4*n+1])
+/// + INT64(src1.row[m].float8[4*k+2]) *
+/// INT64(src2.row[k].float8[4*n+2])
+/// + INT64(src1.row[m].float8[4*k+3]) *
+/// INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
|
#define _tile_dphbf8ps(dst, src1, src2) \ | ||
__builtin_ia32_tdphbf8ps((dst), (src1), (src2)) | ||
#define _tile_dphf8ps(dst, src1, src2) \ | ||
__builtin_ia32_tdphf8ps((dst), (src1), (src2)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#undef __DEFAULT_FN_ATTRS_FP8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Missing IR test? |
Sorry. Added. Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
This is a follow-up to #113850.
Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368