Skip to content

Commit 1502bce

Browse files
MrSidimssys-ce-bb
authored andcommitted
Add error checking for cooperative matrix use and scope parameters (#2223)
Use should be: MatrixA, MatrixB or Accumulator. Scope must be at max Invocation (others are not supported by the translator). Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@f18e64d
1 parent 7d336aa commit 1502bce

13 files changed

+271
-216
lines changed

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,4 +336,22 @@ void SPIRVTypeCooperativeMatrixKHR::decode(std::istream &I) {
336336
Decoder >> Id >> CompType >> Args;
337337
}
338338

339+
void SPIRVTypeCooperativeMatrixKHR::validate() const {
340+
SPIRVEntry::validate();
341+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
342+
SPIRVConstant *UseConst = static_cast<SPIRVConstant *>(this->getUse());
343+
auto InstName = OpCodeNameMap::map(OC);
344+
uint64_t UseValue = UseConst->getZExtIntValue();
345+
SPVErrLog.checkError(
346+
(UseValue <= CooperativeMatrixUseMatrixAccumulatorKHR),
347+
SPIRVEC_InvalidInstruction,
348+
InstName + "\nIncorrect Use parameter, should be MatrixA, MatrixB or "
349+
"Accumulator\n");
350+
SPIRVConstant *ScopeConst = static_cast<SPIRVConstant *>(this->getScope());
351+
uint64_t ScopeValue = ScopeConst->getZExtIntValue();
352+
SPVErrLog.checkError((ScopeValue <= ScopeInvocation),
353+
SPIRVEC_InvalidInstruction,
354+
InstName + "\nUnsupported Scope parameter\n");
355+
}
356+
339357
} // namespace SPIRV

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,9 @@ class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
11241124
SPIRVType *CompType;
11251125
std::vector<SPIRVValue *> Args;
11261126

1127+
protected:
1128+
void validate() const override;
1129+
11271130
public:
11281131
const static Op OC = OpTypeCooperativeMatrixKHR;
11291132
const static SPIRVWord FixedWC = 7;

llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/bf16_conversion_instructions.ll

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,42 +31,42 @@
3131
; CHECK-SPIRV: CompositeConstruct [[#ShortMatTy]] [[#ShortMat:]]
3232
; CHECK-SPIRV: ConvertBF16ToFINTEL [[#FP32MatTy]] [[#]] [[#ShortMat]]
3333

34-
; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
35-
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
36-
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
37-
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
34+
; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
35+
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_2(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %[[#FP32Matrix]])
36+
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructs(i16 0)
37+
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_2(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %[[#ShortMatrix]])
3838

3939

40-
; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
41-
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
42-
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
43-
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
40+
; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
41+
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_2(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %[[#FP32Matrix]])
42+
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructs(i16 0)
43+
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_2(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %[[#ShortMatrix]])
4444

4545

4646
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
4747
target triple = "spir64-unknown-unknown"
4848

4949
define void @convert_f_to_bf() {
5050
entry:
51-
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
52-
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
51+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
52+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %0)
5353
ret void
5454
}
5555

5656
define void @convert_bf_to_f() {
5757
entry:
58-
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 0)
59-
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %0)
58+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructInt16(i16 0)
59+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %0)
6060
ret void
6161
}
6262

63-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)
63+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructFloat(float noundef)
6464

65-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 noundef)
65+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructInt16(i16 noundef)
6666

67-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
67+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef)
6868

69-
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) noundef)
69+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) noundef)
7070

7171
!llvm.module.flags = !{!0, !1, !2, !3, !4}
7272
!llvm.ident = !{!5}

llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_apply.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
; CHECK-SPIRV: CooperativeMatrixApplyFunctionINTEL [[#MatTy]] [[#Apply:]] [[#Ptr]] [[#Mat]]
1919
; CHECK-SPIRV: CooperativeMatrixStoreKHR [[#]] [[#Apply]]
2020

21-
; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16"
22-
; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_0"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) %[[Mat]])
23-
; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_0liii"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) %[[Apply]], i64 32, i32 0, i32 3, i32 0)
21+
; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16"
22+
; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_3_8_16_0"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) %[[Mat]])
23+
; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_3_8_16_0il"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) %[[Apply]], i32 0, i64 0)
2424

2525
; ModuleID = 'matrix_apply.bc'
2626
source_filename = "../llvm/sycl/test-e2e/Matrix/joint_matrix_apply_bf16.cpp"
@@ -93,14 +93,14 @@ entry:
9393
%call.i.i = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) %ref.tmp6.ascast.i)
9494
call void @llvm.lifetime.start.p0(i64 2, ptr nonnull %agg.tmp.i17)
9595
store i16 %call.i.i, ptr %agg.tmp.i17, align 2
96-
%call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17)
96+
%call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17)
9797
call void @llvm.lifetime.end.p0(i64 2, ptr nonnull %agg.tmp.i17)
9898
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %ref.tmp6.i)
9999
%lambda.i = getelementptr inbounds %class.anon.0, ptr addrspace(4) %__SYCLKernel.ascast, i64 0, i32 1
100100
%ref.tmp.ascast.i21 = addrspacecast ptr %ref.tmp.i20 to ptr addrspace(4)
101101
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp.i20)
102102
store ptr addrspace(4) %lambda.i, ptr %ref.tmp.i20, align 8
103-
%call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef %call.i18)
103+
%call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef %call.i18)
104104
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp.i20)
105105
%6 = load ptr addrspace(1), ptr %0, align 8
106106
%7 = load i64, ptr %__SYCLKernel, align 8
@@ -114,7 +114,7 @@ entry:
114114
%add.ptr.i43 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i.i, i64 %mul12.i
115115
%div14.i = and i64 %sub5.i, -16
116116
%add.ptr.i44 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i43, i64 %div14.i
117-
call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef %call.i22, i64 noundef 32, i32 noundef 0, i32 noundef 3, i32 noundef 0)
117+
call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef %call.i22, i32 noundef 0, i64 noundef 0)
118118
call void @llvm.lifetime.end.p0(i64 64, ptr nonnull %__SYCLKernel)
119119
ret void
120120
}
@@ -126,16 +126,16 @@ declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
126126
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
127127

128128
; Function Attrs: convergent nounwind
129-
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr
129+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr
130130

131131
; Function Attrs: convergent nounwind
132132
declare dso_local spir_func zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4)) local_unnamed_addr
133133

134134
; Function Attrs: convergent nounwind
135-
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef) local_unnamed_addr
135+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef) local_unnamed_addr
136136

137137
; Function Attrs: convergent nounwind
138-
declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
138+
declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef, i32 noundef, i64 noundef) local_unnamed_addr
139139

140140
!llvm.module.flags = !{!0, !1}
141141
!opencl.spir.version = !{!2}

0 commit comments

Comments
 (0)