Skip to content

Commit a7a1d2f

Browse files
authored
Split JointMatrixMadINTEL instruction into 4 (#1833)
JointMatrixMadINTEL will stand for signed/signed Matrix type JointMatrixSUMadINTEL will stand for signed/signed Matrix type JointMatrixUSMadINTEL will stand for unsigned/signed Matrix type JointMatrixUUMadINTEL will stand for unsigned/unsigned Matrix type Spec update: intel/llvm#8175 Signed-off-by: Dmitry Sidorov dmitry.sidorov@intel.com
1 parent f0366a6 commit a7a1d2f

File tree

5 files changed

+36
-9
lines changed

5 files changed

+36
-9
lines changed

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -608,14 +608,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
608608
return TranslatedTy;
609609
}
610610

611-
// Representation in LLVM IR before the translator is a pointer array wrapped
612-
// in a structure:
613-
// %struct.__spirv_JointMatrixINTEL = type { [R x [C x [L x [S x type]]]]* }
614-
// where R = Rows, C = Columnts, L = Layout + 1, S = Scope + 1
615-
// this '+1' for the Layout and Scope is required because both of them can
616-
// be '0', but array size can not be '0'.
617-
// The result should look like SPIR-V friendly LLVM IR:
618-
// %spirv.JointMatrixINTEL._char_2_2_0_3
611+
// Representation in LLVM IR before the translator is a pointer to an opaque
612+
// structure:
613+
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
619614
// Here we check the structure name yet again. Another option would be to
620615
// check SPIR-V friendly function calls (by their name) and obtain return
621616
// or their parameter types, assuming, that the appropriate types are Matrix

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3324,6 +3324,9 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
33243324
_SPIRV_OP(JointMatrixLoad, true, 6, true)
33253325
_SPIRV_OP(JointMatrixStore, false, 5, true)
33263326
_SPIRV_OP(JointMatrixMad, true, 7)
3327+
_SPIRV_OP(JointMatrixSUMad, true, 7)
3328+
_SPIRV_OP(JointMatrixUSMad, true, 7)
3329+
_SPIRV_OP(JointMatrixUUMad, true, 7)
33273330
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
33283331
#undef _SPIRV_OP
33293332

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ _SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
99
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
1010
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
1111
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)
12+
_SPIRV_OP_INTERNAL(JointMatrixSUMadINTEL, internal::OpJointMatrixSUMadINTEL)
13+
_SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL)
14+
_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL)
1215
_SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
1316
internal::OpJointMatrixWorkItemLengthINTEL)
1417
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ enum InternalOp {
4242
IOpJointMatrixLoadINTEL = 6120,
4343
IOpJointMatrixStoreINTEL = 6121,
4444
IOpJointMatrixMadINTEL = 6122,
45+
IOpJointMatrixSUMadINTEL = 6128,
46+
IOpJointMatrixUSMadINTEL = 6129,
47+
IOpJointMatrixUUMadINTEL = 6130,
4548
IOpArithmeticFenceINTEL = 6145,
4649
IOpJointMatrixWorkItemLengthINTEL = 6410,
4750
IOpComplexFMulINTEL = 6415,
@@ -105,6 +108,9 @@ _SPIRV_OP(Op, TypeJointMatrixINTEL)
105108
_SPIRV_OP(Op, JointMatrixLoadINTEL)
106109
_SPIRV_OP(Op, JointMatrixStoreINTEL)
107110
_SPIRV_OP(Op, JointMatrixMadINTEL)
111+
_SPIRV_OP(Op, JointMatrixSUMadINTEL)
112+
_SPIRV_OP(Op, JointMatrixUSMadINTEL)
113+
_SPIRV_OP(Op, JointMatrixUUMadINTEL)
108114
_SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL)
109115
_SPIRV_OP(Capability, HWThreadQueryINTEL)
110116
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)

test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
; CHECK-SPIRV: JointMatrixLoadINTEL [[#ATy]] [[#A:]] [[#Aptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
3535
; CHECK-SPIRV: JointMatrixLoadINTEL [[#BTy]] [[#B:]] [[#Bptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
3636
; CHECK-SPIRV: JointMatrixMadINTEL [[#CTy]] [[#CMad]] [[#A]] [[#B]] [[#C]] [[#Three]]
37+
; CHECK-SPIRV: JointMatrixSUMadINTEL [[#CTy]] [[#UnusedMad1:]] [[#A]] [[#B]] [[#C]] [[#Three]]
38+
; CHECK-SPIRV: JointMatrixUSMadINTEL [[#CTy]] [[#UnusedMad2:]] [[#A]] [[#B]] [[#C]] [[#Three]]
39+
; CHECK-SPIRV: JointMatrixUUMadINTEL [[#CTy]] [[#UnusedMad3:]] [[#A]] [[#B]] [[#C]] [[#Three]]
40+
3741
; CHECK-SPIRV: JointMatrixStoreINTEL [[#Cptr:]] [[#C]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
3842
; CHECK-SPIRV: CompositeConstruct [[#CTy]] [[#Cnew:]] [[#FortyTwo]]
3943
; CHECK-SPIRV: Store [[#PtrToZero:]] [[#Zero]]
@@ -49,7 +53,11 @@
4953
; CHECK-LLVM: [[C:%.*]] = phi %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [ [[CLoaded]], %entry ], [ [[CMad:%.*]], %for.body.i ]
5054
; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* @_Z79__spirv_JointMatrixLoadINTEL_RPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
5155
; CHECK-LLVM: [[B:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS4cliii(i8 addrspace(4)* [[BPtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
52-
; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
56+
; CHECK-LLVM: [[CMad1:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
57+
; CHECK-LLVM: [[CMad2:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixSUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
58+
; CHECK-LLVM: [[CMad3:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUSMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
59+
; CHECK-LLVM: [[CMad4:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
60+
5361
; CHECK-LLVM: call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS4sPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3liii(i16 addrspace(4)* [[CPtr]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i64 [[Stride]], i32 0, i32 3, i32 0)
5462
; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42)
5563
; CHECK-LLVM: store i32 0, i32 addrspace(4)* [[StoredZero:%.*]], align 4
@@ -115,6 +123,9 @@ for.body.i: ; preds = %for.cond.i
115123
%add.ptr17.i = addrspacecast i8 addrspace(1)* %add.ptr17.i56 to i8 addrspace(4)*
116124
%call18.i = tail call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* @_Z28__spirv_JointMatrixLoadINTELIaLm16ELm2ELN5__spv12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEEPS5_mS1_S3_i(i8 addrspace(4)* %add.ptr17.i, i64 %_arg_1, i32 0, i32 3, i32 0) #3
117125
%call19.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
126+
%call20.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
127+
%call21.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
128+
%call22.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
118129
%add.i = add nuw nsw i32 %k.0.i, 16
119130
br label %for.cond.i, !llvm.loop !19
120131

@@ -142,6 +153,15 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*
142153
; Function Attrs: convergent
143154
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
144155

156+
; Function Attrs: convergent
157+
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
158+
159+
; Function Attrs: convergent
160+
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
161+
162+
; Function Attrs: convergent
163+
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1
164+
145165
; Function Attrs: convergent
146166
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIsLm2ELm2ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(i16 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1
147167

0 commit comments

Comments
 (0)