Skip to content

Commit

Permalink
[Backport to 14] add initial support for CooperativeMatrixConstructCh…
Browse files Browse the repository at this point in the history
  • Loading branch information
VyacheslavLevytskyy authored and MrSidims committed Mar 22, 2024
1 parent 004c1fc commit e845832
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3417,6 +3417,7 @@ class SPIRVCooperativeMatrixCheckedInstructionsINTELInstBase
SPIRV##x##INTEL;
_SPIRV_OP(CooperativeMatrixLoadChecked, true, 9, true, 7)
_SPIRV_OP(CooperativeMatrixStoreChecked, false, 8, true, 8)
_SPIRV_OP(CooperativeMatrixConstructChecked, true, 8)
#undef _SPIRV_OP

class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase {
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ _SPIRV_OP_INTERNAL(CooperativeMatrixLoadCheckedINTEL,
internal::OpCooperativeMatrixLoadCheckedINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixStoreCheckedINTEL,
internal::OpCooperativeMatrixStoreCheckedINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixConstructCheckedINTEL,
internal::OpCooperativeMatrixConstructCheckedINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ enum InternalOp {
IOpArithmeticFenceINTEL = 6145,
IOpCooperativeMatrixLoadCheckedINTEL = 6193,
IOpCooperativeMatrixStoreCheckedINTEL = 6194,
IOpCooperativeMatrixConstructCheckedINTEL = 6195,
IOpJointMatrixWorkItemLengthINTEL = 6410,
IOpComplexFMulINTEL = 6415,
IOpComplexFDivINTEL = 6416,
Expand Down Expand Up @@ -190,6 +191,7 @@ _SPIRV_OP(Op, CooperativeMatrixPrefetchINTEL)
_SPIRV_OP(Capability, CooperativeMatrixCheckedInstructionsINTEL)
_SPIRV_OP(Op, CooperativeMatrixLoadCheckedINTEL)
_SPIRV_OP(Op, CooperativeMatrixStoreCheckedINTEL)
_SPIRV_OP(Op, CooperativeMatrixConstructCheckedINTEL)

_SPIRV_OP(Capability, HWThreadQueryINTEL)
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o %t.rev.ll
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o %t.rev.ll
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
Expand All @@ -24,14 +24,14 @@
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy1]] [[#Load1:]]
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
; CHECK-SPIRV: CompositeConstruct [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]]
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL

; CHECK-LLVM: call spir_func %spirv.CooperativeMatrixKHR._char_0_12_48_3 addrspace(1)* @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4siiillli(i16 addrspace(4)* %{{.*}}, i32 0, i32 0, i32 0, i64 12, i64 48, i64 %_arg_1, i32 1)
; CHECK-LLVM: call spir_func %spirv.CooperativeMatrixKHR._char_0_12_48_3 addrspace(1)* @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4siiillli(i16 addrspace(4)* %{{.*}}, i32 0, i32 0, i32 0, i64 12, i64 48, i64 %{{.*}}, i32 1)
; CHECK-LLVM: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3(%spirv.CooperativeMatrixKHR._char_0_12_48_3 addrspace(1)*
; CHECK-LLVM: call spir_func %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42)
; CHECK-LLVM: call spir_func %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(1)* @_Z46__spirv_CooperativeMatrixConstructCheckedINTELiilli(i32 4, i32 4, i64 12, i64 12, i32 %{{.*}})
; CHECK-LLVM: call spir_func %spirv.CooperativeMatrixKHR._char_2_48_12_3 addrspace(1)* @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4ciiilll(i8 addrspace(4)* %{{.*}}, i32 0, i32 0, i32 0, i64 48, i64 12, i64 1)
; CHECK-LLVM: call spir_func %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(1)* @_Z34__spirv_CooperativeMatrixMulAddKHRPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3PU3AS144__spirv_CooperativeMatrixKHR__char_2_48_12_3PU3AS143__spirv_CooperativeMatrixKHR__int_3_12_12_3i(%spirv.CooperativeMatrixKHR._char_0_12_48_3 addrspace(1)* %{{.*}}, %spirv.CooperativeMatrixKHR._char_2_48_12_3 addrspace(1)* %{{.*}}, %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(1)*
; CHECK-LLVM: call spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTELPU3AS4siiPU3AS143__spirv_CooperativeMatrixKHR__int_3_12_12_3illli(i16 addrspace(4)* %{{.*}}, %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(1)*
Expand All @@ -52,7 +52,7 @@ $_ZTSZ4mainE11matrix_test = comdat any
@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

; Function Attrs: convergent norecurse
define weak_odr dso_local spir_kernel void @_ZTSZ4mainE11matrix_test(i16 addrspace(1)* %_arg_, i64 %_arg_1, i8 addrspace(1)* %_arg_3, i8 addrspace(1)* %_arg_5) local_unnamed_addr #0 comdat !kernel_arg_buffer_location !5 !intel_reqd_sub_group_size !6 {
define weak_odr dso_local spir_kernel void @_ZTSZ4mainE11matrix_test(i16 addrspace(1)* %_arg_, i64 %_arg_1, i8 addrspace(1)* %_arg_3, i8 addrspace(1)* %_arg_5, i32 noundef %_arg_6) local_unnamed_addr #0 comdat !kernel_arg_buffer_location !5 !intel_reqd_sub_group_size !6 {
entry:
%0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId to <3 x i64> addrspace(4)*), align 32, !noalias !7
%1 = extractelement <3 x i64> %0, i64 1
Expand All @@ -79,7 +79,7 @@ entry:
%add.ptr16.i55 = getelementptr inbounds i8, i8 addrspace(1)* %_arg_5, i64 %sub5.i
%len = tail call spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(%spirv.CooperativeMatrixKHR._char_0_12_48_3 addrspace(4)* %call8.i)

%C.0.i = call spir_func %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(4)* @_Z26__spirv_CompositeConstruct(i32 42) #1
%C.0.i = call spir_func %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(4)* @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef 4, i32 noundef 4, i64 noundef 12, i64 noundef 12, i32 noundef %_arg_6) #1
%add.ptr12.i54 = getelementptr inbounds i8, i8 addrspace(1)* %add.ptr11.i53, i64 0
%add.ptr12.i = addrspacecast i8 addrspace(1)* %add.ptr12.i54 to i8 addrspace(4)*
%call13.i = tail call spir_func %spirv.CooperativeMatrixKHR._char_2_48_12_3 addrspace(4)* @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(i8 addrspace(4)* %add.ptr12.i, i32 noundef 0, i32 noundef 0, i32 noundef 0, i64 noundef 48, i64 noundef 12, i64 noundef 1) #3
Expand All @@ -98,7 +98,10 @@ entry:
}

; Function Attrs: convergent
declare dso_local spir_func noundef %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(4)* @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2
declare dso_local spir_func %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(4)* @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2

; Function Attrs: convergent
declare dso_local spir_func noundef %spirv.CooperativeMatrixKHR._int_3_12_12_3 addrspace(4)* @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef, i32 noundef, i64 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2

; Function Attrs: convergent
declare dso_local spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(%spirv.CooperativeMatrixKHR._char_0_12_48_3 addrspace(4)* noundef)
Expand Down

0 comments on commit e845832

Please sign in to comment.