Skip to content

Commit

Permalink
Add error checking for cooperative matrix use and scope parameters (#…
Browse files Browse the repository at this point in the history
…2223)

Use should be: MatrixA, MatrixB or Accumulator.
Scope must be at max Invocation (others are not supported
by the translator).

Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
  • Loading branch information
MrSidims authored Nov 23, 2023
1 parent 0888551 commit f18e64d
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 216 deletions.
18 changes: 18 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,22 @@ void SPIRVTypeCooperativeMatrixKHR::decode(std::istream &I) {
Decoder >> Id >> CompType >> Args;
}

void SPIRVTypeCooperativeMatrixKHR::validate() const {
SPIRVEntry::validate();
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
SPIRVConstant *UseConst = static_cast<SPIRVConstant *>(this->getUse());
auto InstName = OpCodeNameMap::map(OC);
uint64_t UseValue = UseConst->getZExtIntValue();
SPVErrLog.checkError(
(UseValue <= CooperativeMatrixUseMatrixAccumulatorKHR),
SPIRVEC_InvalidInstruction,
InstName + "\nIncorrect Use parameter, should be MatrixA, MatrixB or "
"Accumulator\n");
SPIRVConstant *ScopeConst = static_cast<SPIRVConstant *>(this->getScope());
uint64_t ScopeValue = ScopeConst->getZExtIntValue();
SPVErrLog.checkError((ScopeValue <= ScopeInvocation),
SPIRVEC_InvalidInstruction,
InstName + "\nUnsupported Scope parameter\n");
}

} // namespace SPIRV
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,9 @@ class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
SPIRVType *CompType;
std::vector<SPIRVValue *> Args;

protected:
void validate() const override;

public:
const static Op OC = OpTypeCooperativeMatrixKHR;
const static SPIRVWord FixedWC = 7;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,42 +31,42 @@
; CHECK-SPIRV: CompositeConstruct [[#ShortMatTy]] [[#ShortMat:]]
; CHECK-SPIRV: ConvertBF16ToFINTEL [[#FP32MatTy]] [[#]] [[#ShortMat]]

; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
; CHECK-OCL-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z32intel_convert_bfloat16_as_ushortPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_2(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %[[#FP32Matrix]])
; CHECK-OCL-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructs(i16 0)
; CHECK-OCL-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z31intel_convert_as_bfloat16_floatPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_2(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %[[#ShortMatrix]])


; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_3(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %[[#FP32Matrix]])
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructs(i16 0)
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_3(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %[[#ShortMatrix]])
; CHECK-SPV-IR: %[[#FP32Matrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructf(float 0.000000e+00)
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTELPU3AS145__spirv_CooperativeMatrixKHR__float_3_12_12_2(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %[[#FP32Matrix]])
; CHECK-SPV-IR: %[[#ShortMatrix:]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructs(i16 0)
; CHECK-SPV-IR: call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTELPU3AS145__spirv_CooperativeMatrixKHR__short_3_12_12_2(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %[[#ShortMatrix]])


target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "spir64-unknown-unknown"

define void @convert_f_to_bf() {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
%call = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) %0)
ret void
}

define void @convert_bf_to_f() {
entry:
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 0)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) %0)
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructInt16(i16 0)
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) %0)
ret void
}

declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructFloat(float noundef)

declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt16(i16 noundef)
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructInt16(i16 noundef)

declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) @_Z27__spirv_ConvertFToBF16INTEL(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef)

declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 3) noundef)
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z27__spirv_ConvertBF16ToFINTEL(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 12, 2) noundef)

!llvm.module.flags = !{!0, !1, !2, !3, !4}
!llvm.ident = !{!5}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
; CHECK-SPIRV: CooperativeMatrixApplyFunctionINTEL [[#MatTy]] [[#Apply:]] [[#Ptr]] [[#Mat]]
; CHECK-SPIRV: CooperativeMatrixStoreKHR [[#]] [[#Apply]]

; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16"
; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_0"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) %[[Mat]])
; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_8_16_0_0liii"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) %[[Apply]], i64 32, i32 0, i32 3, i32 0)
; CHECK-LLVM: %[[Mat:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @"_Z26__spirv_CompositeConstructP38class.sycl::_V1::ext::oneapi::bfloat16"
; CHECK-LLVM: %[[Apply:[%0-9a-z.]+]] = call spir_func target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @"_Z43__spirv_CooperativeMatrixApplyFunctionINTELPU3AS477class.sycl::_V1::ext::oneapi::experimental::matrix::helper::reference_wrapperPU3AS144__spirv_CooperativeMatrixKHR__short_3_8_16_0"(ptr addrspace(4) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) %[[Mat]])
; CHECK-LLVM: call spir_func void @"_Z33__spirv_CooperativeMatrixStoreKHRPU3AS138class.sycl::_V1::ext::oneapi::bfloat16PU3AS144__spirv_CooperativeMatrixKHR__short_3_8_16_0il"(ptr addrspace(1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) %[[Apply]], i32 0, i64 0)

; ModuleID = 'matrix_apply.bc'
source_filename = "../llvm/sycl/test-e2e/Matrix/joint_matrix_apply_bf16.cpp"
Expand Down Expand Up @@ -93,14 +93,14 @@ entry:
%call.i.i = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) %ref.tmp6.ascast.i)
call void @llvm.lifetime.start.p0(i64 2, ptr nonnull %agg.tmp.i17)
store i16 %call.i.i, ptr %agg.tmp.i17, align 2
%call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17)
%call.i18 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z26__spirv_CompositeConstruct(ptr noundef nonnull byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2 %agg.tmp.i17)
call void @llvm.lifetime.end.p0(i64 2, ptr nonnull %agg.tmp.i17)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %ref.tmp6.i)
%lambda.i = getelementptr inbounds %class.anon.0, ptr addrspace(4) %__SYCLKernel.ascast, i64 0, i32 1
%ref.tmp.ascast.i21 = addrspacecast ptr %ref.tmp.i20 to ptr addrspace(4)
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp.i20)
store ptr addrspace(4) %lambda.i, ptr %ref.tmp.i20, align 8
%call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef %call.i18)
%call.i22 = call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8) %ref.tmp.ascast.i21, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef %call.i18)
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp.i20)
%6 = load ptr addrspace(1), ptr %0, align 8
%7 = load i64, ptr %__SYCLKernel, align 8
Expand All @@ -114,7 +114,7 @@ entry:
%add.ptr.i43 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i.i, i64 %mul12.i
%div14.i = and i64 %sub5.i, -16
%add.ptr.i44 = getelementptr inbounds %"class.sycl::_V1::ext::oneapi::bfloat16", ptr addrspace(1) %add.ptr.i43, i64 %div14.i
call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef %call.i22, i64 noundef 32, i32 noundef 0, i32 noundef 3, i32 noundef 0)
call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef %add.ptr.i44, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef %call.i22, i32 noundef 0, i64 noundef 0)
call void @llvm.lifetime.end.p0(i64 64, ptr nonnull %__SYCLKernel)
ret void
}
Expand All @@ -126,16 +126,16 @@ declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)

; Function Attrs: convergent nounwind
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z26__spirv_CompositeConstruct(ptr noundef byval(%"class.sycl::_V1::ext::oneapi::bfloat16") align 2) local_unnamed_addr

; Function Attrs: convergent nounwind
declare dso_local spir_func zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4)) local_unnamed_addr

; Function Attrs: convergent nounwind
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef) local_unnamed_addr
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) @_Z43__spirv_CooperativeMatrixApplyFunctionINTEL(ptr addrspace(4) noundef align 8 dereferenceable(8), target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef) local_unnamed_addr

; Function Attrs: convergent nounwind
declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 8, 16, 0, 0) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHRPU3AS4iPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_3ili(ptr addrspace(1) noundef, target("spirv.CooperativeMatrixKHR", i16, 3, 8, 16, 0) noundef, i32 noundef, i64 noundef) local_unnamed_addr

!llvm.module.flags = !{!0, !1}
!opencl.spir.version = !{!2}
Expand Down
Loading

0 comments on commit f18e64d

Please sign in to comment.