Hi, I'm working on generating a SPIRV module with the joint matrix extension using TVM, and running it with the OpenCL runtime. I got the simplest case (m8n8k16 matmul, fp16) working, but I get an incorrect result when I extend that test case to m8n16k16 and let a single subgroup do two independent JointMatrixMadINTEL operations.
The SPIRV module in question is attached below and its text representation is posted at https://gist.github.com/masahi/9e219607dd80939397b1cffe91a08f4e
As you can see in the text disassembly, there are two JointMatrixMadINTEL:
...
7 JointMatrixMadINTEL 41 184 175 180 43 40
7 JointMatrixMadINTEL 41 185 175 183 44 40
...
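To make the two instructions easier to compare, here is a minimal sketch that splits each line into named operands. The operand layout it uses (word count, opcode name, result type, result id, A, B, C, scope) is an assumption about how the text format maps onto the instruction, not something confirmed in this issue, and `parse_mad` is a hypothetical helper.

```python
# Operand layout assumed here (not confirmed by the issue):
# <word count> JointMatrixMadINTEL <result type> <result id> <A> <B> <C> <scope>
FIELDS = ("result_type", "result_id", "a", "b", "c", "scope")


def parse_mad(line):
    """Split one JointMatrixMadINTEL text line into named operand ids."""
    tokens = line.split()
    if len(tokens) != 8 or tokens[1] != "JointMatrixMadINTEL":
        raise ValueError(f"not a 7-word JointMatrixMadINTEL line: {line!r}")
    return dict(zip(FIELDS, (int(t) for t in tokens[2:])))


first = parse_mad("7 JointMatrixMadINTEL 41 184 175 180 43 40")
second = parse_mad("7 JointMatrixMadINTEL 41 185 175 183 44 40")

# Report which operands the two instructions share and which differ.
for name in FIELDS:
    relation = "same" if first[name] == second[name] else "differs"
    print(f"{name:12s} {first[name]:4d} {second[name]:4d}  {relation}")
```

Under that assumed layout, both instructions read the same A operand (175) and differ only in the result id, the B operand (180 vs. 183), and the C operand (43 vs. 44).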
The first one is supposed to compute the left half of the 8x16 output matrix, while the second one does the right half. Here is my current situation:
- The left half of the output matrix is computed correctly
- The right half is not; it is actually filled with the values of the C matrix in A x B + C.
Since the same A matrix is used for both multiply-adds and the first one computes the correct output, the only explanation I can see for this result is that the B matrix in the second multiply-add is all zeros. But I'm pretty sure that B cannot be all zeros in my test case.
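The reasoning above can be checked with a small host-side reference. This is only a sketch in plain Python, not the TVM or SPIRV code: it uses small integers instead of fp16 so that comparisons are exact, and the helper names (`mad`, `cols`, `tiled_mad`) are hypothetical. It computes the m8n16k16 result as two independent m8n8k16 multiply-adds, then zeroes the right half of B to reproduce the observed symptom.

```python
import random

M, N, K, TILE_N = 8, 16, 16, 8  # m8n16k16 split into two m8n8k16 tiles


def mad(a, b, c):
    """Return a x b + c for matrices stored as lists of rows."""
    return [
        [c[i][j] + sum(a[i][k] * b[k][j] for k in range(len(b)))
         for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def cols(mat, start, width):
    """Return the column slice [start, start + width) of a matrix."""
    return [row[start:start + width] for row in mat]


def tiled_mad(a, b, c):
    """Compute a x b + c as two independent column-half multiply-adds."""
    left = mad(a, cols(b, 0, TILE_N), cols(c, 0, TILE_N))
    right = mad(a, cols(b, TILE_N, TILE_N), cols(c, TILE_N, TILE_N))
    return [l + r for l, r in zip(left, right)]


random.seed(0)
# Small nonzero integers keep the comparison exact (the real test uses fp16).
a = [[random.randint(1, 3) for _ in range(K)] for _ in range(M)]
b = [[random.randint(1, 3) for _ in range(N)] for _ in range(K)]
c = [[random.randint(1, 3) for _ in range(N)] for _ in range(M)]

expected = mad(a, b, c)
assert tiled_mad(a, b, c) == expected  # the split itself is sound

# Zero the right half of B to imitate the hypothesised failure.
b_bad = [row[:TILE_N] + [0] * TILE_N for row in b]
observed = tiled_mad(a, b_bad, c)

# Compare the left half with the reference and the right half with C.
print(cols(observed, 0, TILE_N) == cols(expected, 0, TILE_N))
print(cols(observed, TILE_N, TILE_N) == cols(c, TILE_N, TILE_N))
```

With an all-zero right half of B, the left half still matches the reference and the right half collapses to C, which is the pattern described above. This shows that a zero B operand would produce the symptom; it does not rule out other causes such as the second multiply-add reading B from the wrong location.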
I've been trying to figure out what's going on, and one thing I noticed is that the generated ISA (dumped using https://github.com/intel/opencl-intercept-layer) contains three dpas instructions like the ones below, while the SPIRV module only has two JointMatrixMadINTEL. Is this expected, and if so, why three?
...
dpas.8x8 (8|M0) r109:f null:f r109:hf r53.0:hf {Atomic}
dpas.8x8 (8|M0) r109:f null:f r77:hf r53.0:hf {Atomic}
dpas.8x8 (8|M0) r117:f null:f r101:hf r53.0:hf {$1}
...