Skip to content

Commit b3fbdf7

Browse files
MrSidimssys-ce-bb
authored andcommitted
Enable BFloat16 and TensorFloat32 conversions for cooperative matrices (#2213)
Previously added scalar/vector ConvertFToBF16INTEL, ConvertBF16ToFINTEL and RoundFToTF32INTEL conversions are now enabled for cooperative matrix type under SPV_INTEL_joint_matrix extension following the spec: https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc Note, joint matrices are not allowed as input/output for these conversions as it is being deprecated. Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@1010efc
1 parent ea86400 commit b3fbdf7

File tree

3 files changed

+183
-2
lines changed

3 files changed

+183
-2
lines changed

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3294,10 +3294,17 @@ template <Op OC>
32943294
class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
32953295
protected:
32963296
SPIRVCapVec getRequiredCapability() const override {
3297+
SPIRVType *ResCompTy = this->getType();
3298+
if (ResCompTy->isTypeCooperativeMatrixKHR())
3299+
return getVec(internal::CapabilityBfloat16ConversionINTEL,
3300+
internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
32973301
return getVec(internal::CapabilityBfloat16ConversionINTEL);
32983302
}
32993303

33003304
std::optional<ExtensionID> getRequiredExtension() const override {
3305+
SPIRVType *ResCompTy = this->getType();
3306+
if (ResCompTy->isTypeCooperativeMatrixKHR())
3307+
this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
33013308
return ExtensionID::SPV_INTEL_bfloat16_conversion;
33023309
}
33033310

@@ -3326,8 +3333,25 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
33263333
}
33273334

33283335
auto InstName = OpCodeNameMap::map(OC);
3329-
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
3336+
auto *Module = this->getModule();
3337+
SPIRVErrorLog &SPVErrLog = Module->getErrorLog();
33303338

3339+
// Cooperative matrix type is allowed as input/output of the instruction
3340+
// if SPV_INTEL_joint_matrix is enabled
3341+
if (ResCompTy->isTypeCooperativeMatrixKHR()) {
3342+
SPVErrLog.checkError(
3343+
Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix),
3344+
SPIRVEC_InvalidInstruction,
3345+
InstName + "\nCan be used with "
3346+
"cooperative matrices only when SPV_INTEL_joint_matrix is "
3347+
"enabled\n");
3348+
assert(InCompTy->isTypeCooperativeMatrixKHR() &&
3349+
"Input must also be a cooperative matrix");
3350+
ResCompTy = static_cast<SPIRVTypeCooperativeMatrixKHR *>(ResCompTy)
3351+
->getCompType();
3352+
InCompTy =
3353+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
3354+
}
33313355
if (OC == internal::OpConvertFToBF16INTEL) {
33323356
SPVErrLog.checkError(
33333357
ResCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction,
@@ -3679,10 +3703,17 @@ template <Op OC>
36793703
class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
36803704
protected:
36813705
SPIRVCapVec getRequiredCapability() const override {
3706+
SPIRVType *ResCompTy = this->getType();
3707+
if (ResCompTy->isTypeCooperativeMatrixKHR())
3708+
return getVec(internal::CapabilityTensorFloat32RoundingINTEL,
3709+
internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
36823710
return getVec(internal::CapabilityTensorFloat32RoundingINTEL);
36833711
}
36843712

36853713
std::optional<ExtensionID> getRequiredExtension() const override {
3714+
SPIRVType *ResCompTy = this->getType();
3715+
if (ResCompTy->isTypeCooperativeMatrixKHR())
3716+
this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
36863717
return ExtensionID::SPV_INTEL_tensor_float32_conversion;
36873718
}
36883719

@@ -3711,7 +3742,25 @@ class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
37113742
}
37123743

37133744
auto InstName = OpCodeNameMap::map(OC);
3714-
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
3745+
auto *Module = this->getModule();
3746+
SPIRVErrorLog &SPVErrLog = Module->getErrorLog();
3747+
3748+
// Cooperative matrix type is allowed as input/output of the instruction
3749+
// if SPV_INTEL_joint_matrix is enabled
3750+
if (ResCompTy->isTypeCooperativeMatrixKHR()) {
3751+
SPVErrLog.checkError(
3752+
Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix),
3753+
SPIRVEC_InvalidInstruction,
3754+
InstName + "\nCan be used with "
3755+
"cooperative matrices only when SPV_INTEL_joint_matrix is "
3756+
"enabled\n");
3757+
assert(InCompTy->isTypeCooperativeMatrixKHR() &&
3758+
"Input must also be a cooperative matrix");
3759+
ResCompTy = static_cast<SPIRVTypeCooperativeMatrixKHR *>(ResCompTy)
3760+
->getCompType();
3761+
InCompTy =
3762+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
3763+
}
37153764

37163765
SPVErrLog.checkError(
37173766
ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix,+SPV_INTEL_bfloat16_conversion -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-OCL-IR
8+
9+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc --spirv-target-env=SPV-IR
10+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-SPV-IR
11+
12+
; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_bfloat16_conversion 2>&1 \
13+
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR
14+
15+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
16+
; CHECK-ERROR-NEXT: ConvertFToBF16INTEL
17+
; CHECK-ERROR-NEXT: Can be used with cooperative matrices only when SPV_INTEL_joint_matrix is enabled
18+
19+
; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
20+
; CHECK-SPIRV-DAG: Capability Bfloat16ConversionINTEL
21+
; CHECK-SPIRV-DAG: Capability JointMatrixBF16ComponentTypeINTEL
22+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_bfloat16_conversion"
23+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
24+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
25+
; CHECK-SPIRV-DAG: TypeInt [[#ShortTy:]] 16 0
26+
; CHECK-SPIRV-DAG: TypeFloat [[#FP32Ty:]] 32
27+
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP32MatTy:]] [[#FP32Ty]]
28+
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#ShortMatTy:]] [[#ShortTy]]
29+
; CHECK-SPIRV: CompositeConstruct [[#FP32MatTy]] [[#FP32Mat:]]
30+
; CHECK-SPIRV: ConvertFToBF16INTEL [[#ShortMatTy]] [[#]] [[#FP32Mat]]
31+
; CHECK-SPIRV: CompositeConstruct [[#ShortMatTy]] [[#ShortMat:]]
32+
; CHECK-SPIRV: ConvertBF16ToFINTEL [[#FP32MatTy]] [[#]] [[#ShortMat]]
33+
34+
; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
35+
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
36+
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
37+
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
38+
39+
40+
; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
41+
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
42+
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
43+
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
44+
45+
46+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
47+
target triple = "spir64-unknown-unknown"
48+
49+
define void @convert_f_to_bf() {
50+
entry:
51+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
52+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
53+
ret void
54+
}
55+
56+
define void @convert_bf_to_f() {
57+
entry:
58+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 0)
59+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %0)
60+
ret void
61+
}
62+
63+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)
64+
65+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 noundef)
66+
67+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
68+
69+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) noundef)
70+
71+
!llvm.module.flags = !{!0, !1, !2, !3, !4}
72+
!llvm.ident = !{!5}
73+
74+
!0 = !{i32 7, !"Dwarf Version", i32 4}
75+
!1 = !{i32 1, !"wchar_size", i32 4}
76+
!2 = !{i32 8, !"PIC Level", i32 2}
77+
!3 = !{i32 7, !"PIE Level", i32 2}
78+
!4 = !{i32 7, !"uwtable", i32 2}
79+
!5 = !{!"clang version 17.0.0"}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix,+SPV_INTEL_tensor_float32_conversion -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_tensor_float32_conversion 2>&1 \
10+
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR
11+
12+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
13+
; CHECK-ERROR-NEXT: RoundFToTF32INTEL
14+
; CHECK-ERROR-NEXT: Can be used with cooperative matrices only when SPV_INTEL_joint_matrix is enabled
15+
16+
; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
17+
; CHECK-SPIRV-DAG: Capability TensorFloat32RoundingINTEL
18+
; CHECK-SPIRV-DAG: Capability JointMatrixTF32ComponentTypeINTEL
19+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_tensor_float32_conversion"
20+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
21+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
22+
; CHECK-SPIRV-DAG: TypeFloat [[#FP32Ty:]] 32
23+
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP32MatTy:]] [[#FP32Ty]]
24+
; CHECK-SPIRV: CompositeConstruct [[#FP32MatTy]] [[#FP32Mat:]]
25+
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32MatTy]] [[#]] [[#FP32Mat]]
26+
27+
; CHECK-LLVM: %[[#Mat:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
28+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#Mat]])
29+
30+
31+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
32+
target triple = "spir64-unknown-unknown"
33+
34+
define void @convert_f_to_tf() {
35+
entry:
36+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
37+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
38+
ret void
39+
}
40+
41+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)
42+
43+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
44+
45+
!llvm.module.flags = !{!0, !1, !2, !3, !4}
46+
!llvm.ident = !{!5}
47+
48+
!0 = !{i32 7, !"Dwarf Version", i32 4}
49+
!1 = !{i32 1, !"wchar_size", i32 4}
50+
!2 = !{i32 8, !"PIC Level", i32 2}
51+
!3 = !{i32 7, !"PIE Level", i32 2}
52+
!4 = !{i32 7, !"uwtable", i32 2}
53+
!5 = !{!"clang version 17.0.0"}

0 commit comments

Comments
 (0)