Skip to content

Commit

Permalink
Enable BFloat16 and TensorFloat32 conversions for cooperative matrices (
Browse files Browse the repository at this point in the history
#2213)

Previously added scalar/vector ConvertFToBF16INTEL, ConvertBF16ToFINTEL
and RoundFToTF32INTEL conversions are now enabled for cooperative matrix
type under SPV_INTEL_joint_matrix extension following the spec:
https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc

Note, joint matrices are not allowed as input/output for these
conversions as it is being deprecated.

Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
  • Loading branch information
MrSidims authored Nov 16, 2023
1 parent ae2d38b commit 1010efc
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 2 deletions.
53 changes: 51 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3294,10 +3294,17 @@ template <Op OC>
class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
return getVec(internal::CapabilityBfloat16ConversionINTEL,
internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
return getVec(internal::CapabilityBfloat16ConversionINTEL);
}

std::optional<ExtensionID> getRequiredExtension() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
return ExtensionID::SPV_INTEL_bfloat16_conversion;
}

Expand Down Expand Up @@ -3326,8 +3333,25 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
}

auto InstName = OpCodeNameMap::map(OC);
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
auto *Module = this->getModule();
SPIRVErrorLog &SPVErrLog = Module->getErrorLog();

// Cooperative matrix type is allowed as input/output of the instruction
// if SPV_INTEL_joint_matrix is enabled
if (ResCompTy->isTypeCooperativeMatrixKHR()) {
SPVErrLog.checkError(
Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix),
SPIRVEC_InvalidInstruction,
InstName + "\nCan be used with "
"cooperative matrices only when SPV_INTEL_joint_matrix is "
"enabled\n");
assert(InCompTy->isTypeCooperativeMatrixKHR() &&
"Input must also be a cooperative matrix");
ResCompTy = static_cast<SPIRVTypeCooperativeMatrixKHR *>(ResCompTy)
->getCompType();
InCompTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
}
if (OC == internal::OpConvertFToBF16INTEL) {
SPVErrLog.checkError(
ResCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction,
Expand Down Expand Up @@ -3679,10 +3703,17 @@ template <Op OC>
class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
return getVec(internal::CapabilityTensorFloat32RoundingINTEL,
internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
return getVec(internal::CapabilityTensorFloat32RoundingINTEL);
}

std::optional<ExtensionID> getRequiredExtension() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
return ExtensionID::SPV_INTEL_tensor_float32_conversion;
}

Expand Down Expand Up @@ -3711,7 +3742,25 @@ class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
}

auto InstName = OpCodeNameMap::map(OC);
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
auto *Module = this->getModule();
SPIRVErrorLog &SPVErrLog = Module->getErrorLog();

// Cooperative matrix type is allowed as input/output of the instruction
// if SPV_INTEL_joint_matrix is enabled
if (ResCompTy->isTypeCooperativeMatrixKHR()) {
SPVErrLog.checkError(
Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix),
SPIRVEC_InvalidInstruction,
InstName + "\nCan be used with "
"cooperative matrices only when SPV_INTEL_joint_matrix is "
"enabled\n");
assert(InCompTy->isTypeCooperativeMatrixKHR() &&
"Input must also be a cooperative matrix");
ResCompTy = static_cast<SPIRVTypeCooperativeMatrixKHR *>(ResCompTy)
->getCompType();
InCompTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
}

SPVErrLog.checkError(
ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix,+SPV_INTEL_bfloat16_conversion -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-OCL-IR

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc --spirv-target-env=SPV-IR
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-SPV-IR

; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_bfloat16_conversion 2>&1 \
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
; CHECK-ERROR-NEXT: ConvertFToBF16INTEL
; CHECK-ERROR-NEXT: Can be used with cooperative matrices only when SPV_INTEL_joint_matrix is enabled

; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
; CHECK-SPIRV-DAG: Capability Bfloat16ConversionINTEL
; CHECK-SPIRV-DAG: Capability JointMatrixBF16ComponentTypeINTEL
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_bfloat16_conversion"
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV-DAG: TypeInt [[#ShortTy:]] 16 0
; CHECK-SPIRV-DAG: TypeFloat [[#FP32Ty:]] 32
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP32MatTy:]] [[#FP32Ty]]
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#ShortMatTy:]] [[#ShortTy]]
; CHECK-SPIRV: CompositeConstruct [[#FP32MatTy]] [[#FP32Mat:]]
; CHECK-SPIRV: ConvertFToBF16INTEL [[#ShortMatTy]] [[#]] [[#FP32Mat]]
; CHECK-SPIRV: CompositeConstruct [[#ShortMatTy]] [[#ShortMat:]]
; CHECK-SPIRV: ConvertBF16ToFINTEL [[#FP32MatTy]] [[#]] [[#ShortMat]]

; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])


; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])


target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "spir64-unknown-unknown"

define void @convert_f_to_bf() {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
ret void
}

define void @convert_bf_to_f() {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 0)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %0)
ret void
}

declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)

declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 noundef)

declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)

declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) noundef)

!llvm.module.flags = !{!0, !1, !2, !3, !4}
!llvm.ident = !{!5}

!0 = !{i32 7, !"Dwarf Version", i32 4}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{i32 8, !"PIC Level", i32 2}
!3 = !{i32 7, !"PIE Level", i32 2}
!4 = !{i32 7, !"uwtable", i32 2}
!5 = !{!"clang version 17.0.0"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix,+SPV_INTEL_tensor_float32_conversion -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM

; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_tensor_float32_conversion 2>&1 \
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
; CHECK-ERROR-NEXT: RoundFToTF32INTEL
; CHECK-ERROR-NEXT: Can be used with cooperative matrices only when SPV_INTEL_joint_matrix is enabled

; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
; CHECK-SPIRV-DAG: Capability TensorFloat32RoundingINTEL
; CHECK-SPIRV-DAG: Capability JointMatrixTF32ComponentTypeINTEL
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_tensor_float32_conversion"
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV-DAG: TypeFloat [[#FP32Ty:]] 32
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#FP32MatTy:]] [[#FP32Ty]]
; CHECK-SPIRV: CompositeConstruct [[#FP32MatTy]] [[#FP32Mat:]]
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32MatTy]] [[#]] [[#FP32Mat]]

; CHECK-LLVM: %[[#Mat:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#Mat]])


target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "spir64-unknown-unknown"

define void @convert_f_to_tf() {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
ret void
}

declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)

declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z25__spirv_RoundFToTF32INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)

!llvm.module.flags = !{!0, !1, !2, !3, !4}
!llvm.ident = !{!5}

!0 = !{i32 7, !"Dwarf Version", i32 4}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{i32 8, !"PIC Level", i32 2}
!3 = !{i32 7, !"PIE Level", i32 2}
!4 = !{i32 7, !"uwtable", i32 2}
!5 = !{!"clang version 17.0.0"}

0 comments on commit 1010efc

Please sign in to comment.