Skip to content

[X86][AMX] Add AMX FP8 new APIs #115829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/BuiltinsX86_64.def
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl_internal, "V32xUsUsV256iUi", "n", "a
TARGET_BUILTIN(__builtin_ia32_tilemovrow_internal, "V16iUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
TARGET_BUILTIN(__builtin_ia32_tmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32")
TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32,amx-transpose")
// AMX-FP8 internal dot-product builtins. Each one takes three unsigned short
// operands followed by three V256i operands, returns a V256i, and requires the
// "amx-fp8" target feature.
TARGET_BUILTIN(__builtin_ia32_tdpbf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
TARGET_BUILTIN(__builtin_ia32_tdphbf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")
TARGET_BUILTIN(__builtin_ia32_tdphf8ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp8")

// AMX
TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
Expand Down
177 changes: 156 additions & 21 deletions clang/lib/Headers/amxfp8intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,81 +15,216 @@
#define __AMXFP8INTRIN_H
#ifdef __x86_64__

/// Perform the dot product of a BF8 value \a a by a BF8 value \a b accumulating
/// into a Single Precision (FP32) source/dest \a dst.
// Attribute set shared by the functions defined below: always inline, no debug
// info, and the "amx-fp8" target feature enabled for the function.
#define __DEFAULT_FN_ATTRS_FP8 \
__attribute__((__always_inline__, __nodebug__, __target__("amx-fp8")))

// Internal helper behind __tile_dpbf8ps: forwards the three unsigned short
// operands (m, n, k) and the three tile values unchanged to
// __builtin_ia32_tdpbf8ps_internal and returns the builtin's result.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dpbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
return __builtin_ia32_tdpbf8ps_internal(m, n, k, dst, src1, src2);
}

/// Perform the dot product of a BF8 value \a src1 by a BF8 value \a src2
/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
/// void _tile_dpbf8ps (__tile dst, __tile a, __tile b)
/// void __tile_dpbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
/// \endcode
///
/// \code{.operation}
/// FOR m := 0 TO dst.rows - 1
/// temp1[(dst.colsb / 4 - 1) : 0] = 0
/// FOR k := 0 TO src1.colsb / 4 - 1
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
/// ENDFOR
/// write_row_and_zero(dst, m, tmp, dst.colsb)
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPBF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
/// \param a
/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
/// \param b
/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dpbf8ps(dst, a, b) __builtin_ia32_tdpbf8ps((dst), (a), (b))
__DEFAULT_FN_ATTRS_FP8 static void
__tile_dpbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
// The helper receives src1.row, src2.col and src1.col as its m, n and k
// operands. dst->tile is passed in as the helper's dst argument and is then
// overwritten with the helper's return value.
dst->tile = _tile_dpbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
src1.tile, src2.tile);
}

/// Perform the dot product of a BF8 value \a a by an HF8 value \a b
// Internal helper behind __tile_dpbhf8ps: forwards the three unsigned short
// operands (m, n, k) and the three tile values unchanged to
// __builtin_ia32_tdpbhf8ps_internal and returns the builtin's result.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dpbhf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
return __builtin_ia32_tdpbhf8ps_internal(m, n, k, dst, src1, src2);
}

/// Perform the dot product of a BF8 value \a src1 by an HF8 value \a src2
/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
/// void _tile_dpbhf8ps (__tile dst, __tile a, __tile b)
/// void __tile_dpbhf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
/// \endcode
///
/// \code{.operation}
/// FOR m := 0 TO dst.rows - 1
/// temp1[(dst.colsb / 4 - 1) : 0] = 0
/// FOR k := 0 TO src1.colsb / 4 - 1
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
/// ENDFOR
/// write_row_and_zero(dst, m, tmp, dst.colsb)
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPBHF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
/// \param a
/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
/// \param b
/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dpbhf8ps(dst, a, b) __builtin_ia32_tdpbhf8ps((dst), (a), (b))
__DEFAULT_FN_ATTRS_FP8 static void
__tile_dpbhf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
// The helper receives src1.row, src2.col and src1.col as its m, n and k
// operands. dst->tile is passed in as the helper's dst argument and is then
// overwritten with the helper's return value.
dst->tile = _tile_dpbhf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
src1.tile, src2.tile);
}

// Internal helper behind __tile_dphbf8ps: forwards the three unsigned short
// operands (m, n, k) and the three tile values unchanged to
// __builtin_ia32_tdphbf8ps_internal and returns the builtin's result.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dphbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
return __builtin_ia32_tdphbf8ps_internal(m, n, k, dst, src1, src2);
}

/// Perform the dot product of an HF8 value \a a by a BF8 value \a b
/// Perform the dot product of an HF8 value \a src1 by a BF8 value \a src2
/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
/// void _tile_dphbf8ps (__tile dst, __tile a, __tile b)
/// void __tile_dphbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
/// \endcode
///
/// \code{.operation}
/// FOR m := 0 TO dst.rows - 1
/// temp1[(dst.colsb / 4 - 1) : 0] = 0
/// FOR k := 0 TO src1.colsb / 4 - 1
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
/// ENDFOR
/// write_row_and_zero(dst, m, tmp, dst.colsb)
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPHBF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
/// \param a
/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
/// \param b
/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dphbf8ps(dst, a, b) __builtin_ia32_tdphbf8ps((dst), (a), (b))

/// Perform the dot product of an HF8 value \a a by an HF8 value \a b
__DEFAULT_FN_ATTRS_FP8 static void
__tile_dphbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
// The helper receives src1.row, src2.col and src1.col as its m, n and k
// operands. dst->tile is passed in as the helper's dst argument and is then
// overwritten with the helper's return value.
dst->tile = _tile_dphbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
src1.tile, src2.tile);
}

// Internal helper behind __tile_dphf8ps: forwards the three unsigned short
// operands (m, n, k) and the three tile values unchanged to
// __builtin_ia32_tdphf8ps_internal and returns the builtin's result.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dphf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
return __builtin_ia32_tdphf8ps_internal(m, n, k, dst, src1, src2);
}

/// Perform the dot product of an HF8 value \a src1 by an HF8 value \a src2
/// accumulating into a Single Precision (FP32) source/dest \a dst.
///
/// \headerfile <immintrin.h>
///
/// \code
/// void _tile_dphf8ps (__tile dst, __tile a, __tile b)
/// void __tile_dphf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
/// \endcode
///
/// \code{.operation}
/// FOR m := 0 TO dst.rows - 1
/// temp1[(dst.colsb / 4 - 1) : 0] = 0
/// FOR k := 0 TO src1.colsb / 4 - 1
/// FOR n := 0 TO dst.colsb / 4 - 1
/// temp1[n] +=
/// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
/// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
/// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
/// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
/// ENDFOR
/// ENDFOR
/// FOR n := 0 TO dst.colsb / 4 - 1
/// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
/// ENDFOR
/// write_row_and_zero(dst, m, tmp, dst.colsb)
/// zero_upper_rows(dst, dst.rows)
/// zero_tileconfig_start()
/// \endcode
///
/// This intrinsic corresponds to the \c TDPHF8PS instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
/// \param a
/// \param src1
/// The 1st source tile. Max size is 1024 Bytes.
/// \param b
/// \param src2
/// The 2nd source tile. Max size is 1024 Bytes.
#define _tile_dphf8ps(dst, a, b) __builtin_ia32_tdphf8ps((dst), (a), (b))
__DEFAULT_FN_ATTRS_FP8 static void
__tile_dphf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
// The helper receives src1.row, src2.col and src1.col as its m, n and k
// operands. dst->tile is passed in as the helper's dst argument and is then
// overwritten with the helper's return value.
dst->tile = _tile_dphf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
src1.tile, src2.tile);
}

// Macro forms of the four FP8 dot-product operations. Each passes its three
// arguments, individually parenthesized, straight to the matching
// __builtin_ia32_* builtin.
#define _tile_dpbf8ps(dst, src1, src2) \
__builtin_ia32_tdpbf8ps((dst), (src1), (src2))
#define _tile_dpbhf8ps(dst, src1, src2) \
__builtin_ia32_tdpbhf8ps((dst), (src1), (src2))
#define _tile_dphbf8ps(dst, src1, src2) \
__builtin_ia32_tdphbf8ps((dst), (src1), (src2))
#define _tile_dphf8ps(dst, src1, src2) \
__builtin_ia32_tdphf8ps((dst), (src1), (src2))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#undef __DEFAULT_FN_ATTRS_FP8

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

#undef __DEFAULT_FN_ATTRS_FP8

#endif /* __x86_64__ */
#endif /* __AMXFP8INTRIN_H */
36 changes: 36 additions & 0 deletions clang/test/CodeGen/X86/amx_fp8_api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-fp8 \
// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
#include <immintrin.h>

// Calls __tile_dpbf8ps and expects the emitted IR to contain a vector-to-tile
// cast, a call to llvm.x86.tdpbf8ps.internal, and a tile-to-vector cast (in
// any order).
void test_tdpbf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
//CHECK-LABEL: @test_tdpbf8ps
//CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
//CHECK-DAG: call x86_amx @llvm.x86.tdpbf8ps.internal
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
__tile_dpbf8ps(&dst, src1, src2);
}

// Calls __tile_dpbhf8ps and expects the emitted IR to contain a vector-to-tile
// cast, a call to llvm.x86.tdpbhf8ps.internal, and a tile-to-vector cast (in
// any order).
void test_tdpbhf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
//CHECK-LABEL: @test_tdpbhf8ps
//CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
//CHECK-DAG: call x86_amx @llvm.x86.tdpbhf8ps.internal
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
__tile_dpbhf8ps(&dst, src1, src2);
}

// Calls __tile_dphbf8ps and expects the emitted IR to contain a vector-to-tile
// cast, a call to llvm.x86.tdphbf8ps.internal, and a tile-to-vector cast (in
// any order).
void test_tdphbf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
//CHECK-LABEL: @test_tdphbf8ps
//CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
//CHECK-DAG: call x86_amx @llvm.x86.tdphbf8ps.internal
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
__tile_dphbf8ps(&dst, src1, src2);
}

// Calls __tile_dphf8ps and expects the emitted IR to contain a vector-to-tile
// cast, a call to llvm.x86.tdphf8ps.internal, and a tile-to-vector cast (in
// any order).
void test_tdphf8ps(__tile1024i src1, __tile1024i src2, __tile1024i dst) {
//CHECK-LABEL: @test_tdphf8ps
//CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
//CHECK-DAG: call x86_amx @llvm.x86.tdphf8ps.internal
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
__tile_dphf8ps(&dst, src1, src2);
}

25 changes: 25 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsX86.td
Original file line number Diff line number Diff line change
Expand Up @@ -6120,6 +6120,31 @@ let TargetPrefix = "x86" in {
Intrinsic<[llvm_x86amx_ty],
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty,
llvm_x86amx_ty, llvm_x86amx_ty], []>;

// AMX-FP8 "internal" dot-product intrinsics. Each is bound to the Clang
// builtin named in its ClangBuiltin<> and shares one signature: three i16
// operands followed by three x86amx operands, returning an x86amx value.
def int_x86_tdpbf8ps_internal :
ClangBuiltin<"__builtin_ia32_tdpbf8ps_internal">,
Intrinsic<[llvm_x86amx_ty],
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
[]>;
def int_x86_tdpbhf8ps_internal :
ClangBuiltin<"__builtin_ia32_tdpbhf8ps_internal">,
Intrinsic<[llvm_x86amx_ty],
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
[]>;
def int_x86_tdphbf8ps_internal :
ClangBuiltin<"__builtin_ia32_tdphbf8ps_internal">,
Intrinsic<[llvm_x86amx_ty],
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
[]>;
def int_x86_tdphf8ps_internal :
ClangBuiltin<"__builtin_ia32_tdphf8ps_internal">,
Intrinsic<[llvm_x86amx_ty],
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty],
[]>;
}

//===----------------------------------------------------------------------===//
Expand Down
18 changes: 17 additions & 1 deletion llvm/lib/Target/X86/X86ExpandPseudo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,11 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::PTDPBF16PSV:
case X86::PTDPFP16PSV:
case X86::PTMMULTF32PSV:
case X86::PTTMMULTF32PSV: {
case X86::PTTMMULTF32PSV:
case X86::PTDPBF8PSV:
case X86::PTDPBHF8PSV:
case X86::PTDPHBF8PSV:
case X86::PTDPHF8PSV: {
MI.untieRegOperand(4);
for (unsigned i = 3; i > 0; --i)
MI.removeOperand(i);
Expand All @@ -777,6 +781,18 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::PTTMMULTF32PSV:
Opc = X86::TTMMULTF32PS;
break;
case X86::PTDPBF8PSV:
Opc = X86::TDPBF8PS;
break;
case X86::PTDPBHF8PSV:
Opc = X86::TDPBHF8PS;
break;
case X86::PTDPHBF8PSV:
Opc = X86::TDPHBF8PS;
break;
case X86::PTDPHF8PSV:
Opc = X86::TDPHF8PS;
break;

default:
llvm_unreachable("Unexpected Opcode");
Expand Down
31 changes: 31 additions & 0 deletions llvm/lib/Target/X86/X86InstrAMX.td
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,37 @@ let Predicates = [HasAMXFP8, In64BitMode] in {
[(int_x86_tdphf8ps timm:$src1, timm:$src2,
timm:$src3)]>;
}

// Pseudo instructions matched from the *_internal AMX-FP8 intrinsics. Each
// takes three GR16 operands and three TILE operands; the constraint below ties
// the TILE input $src4 to the output $dst. X86ExpandPseudo later rewrites each
// pseudo to its real opcode (e.g. PTDPBF8PSV -> TDPBF8PS).
let Constraints = "$src4 = $dst" in {
def PTDPBF8PSV : PseudoI<(outs TILE:$dst),
(ins GR16:$src1, GR16:$src2, GR16:$src3,
TILE:$src4, TILE:$src5, TILE:$src6),
[(set TILE:$dst,
(int_x86_tdpbf8ps_internal GR16:$src1,
GR16:$src2, GR16:$src3, TILE:$src4,
TILE:$src5, TILE:$src6))]>;
def PTDPBHF8PSV : PseudoI<(outs TILE:$dst),
(ins GR16:$src1, GR16:$src2, GR16:$src3,
TILE:$src4, TILE:$src5, TILE:$src6),
[(set TILE:$dst,
(int_x86_tdpbhf8ps_internal GR16:$src1,
GR16:$src2, GR16:$src3, TILE:$src4,
TILE:$src5, TILE:$src6))]>;
def PTDPHBF8PSV : PseudoI<(outs TILE:$dst),
(ins GR16:$src1, GR16:$src2, GR16:$src3,
TILE:$src4, TILE:$src5, TILE:$src6),
[(set TILE:$dst,
(int_x86_tdphbf8ps_internal GR16:$src1,
GR16:$src2, GR16:$src3, TILE:$src4,
TILE:$src5, TILE:$src6))]>;
def PTDPHF8PSV : PseudoI<(outs TILE:$dst),
(ins GR16:$src1, GR16:$src2, GR16:$src3,
TILE:$src4, TILE:$src5, TILE:$src6),
[(set TILE:$dst,
(int_x86_tdphf8ps_internal GR16:$src1,
GR16:$src2, GR16:$src3, TILE:$src4,
TILE:$src5, TILE:$src6))]>;
}
}
}

Expand Down
Loading
Loading