@@ -303,3 +303,113 @@ class GenerationMatrixMul(
303303 p " stateReg: $stateReg, \t currentValid: ${io.current.valid}, \t rowIdx: ${rowIdx.value}, \t colIdx: ${colIdx.value}, \t gemmValid: ${gemmGroup.io.out.valid}\n "
304304 )
305305}
306+
307+ /*
308+ * using two GenerationMatrixMul Modules(as QKGEN) to do q,k generation simultaneously.
309+ * using another GenerationMatrixMul Module(as QKMUL) to do q,k mul.
310+ * the output of QKMUL is the final result.
311+ * using the output of QKGEN to Stitch the final result.
312+ * the k1,n1 are for q,k generation, the k2,n2 are for q,k mul.
313+ */
314+ class QKMul (
315+ val k1 : Int ,
316+ val n1 : Int ,
317+ val k2 : Int ,
318+ val n2 : Int ,
319+ val m : Int ,
320+ val p : Int ,
321+ val q : Int ,
322+ val gemmType : GEMMDataType .Type
323+ )(
324+ implicit config : DataWidthConfig )
325+ extends Module
326+ with llamaConfig
327+ with DebugLog {
328+
329+ val nk1 : Int = k1 * n1
330+ val nk2 : Int = k2 * n2
331+ require(m % nk1 == 0 )
332+ require(p % nk1 == 0 )
333+ require(q % nk1 == 0 )
334+ require(m % nk2 == 0 )
335+ require(q % nk2 == 0 )
336+
337+ class QKGenerationMatrixMulWarper (
338+ val k : Int ,
339+ val n : Int ,
340+ val m : Int ,
341+ val p : Int ,
342+ val q : Int ,
343+ val gemmType : GEMMDataType .Type ,
344+ val bufferSize : Int
345+ )(
346+ implicit config : DataWidthConfig )
347+ extends Module
348+ with llamaConfig
349+ with DebugLog {
350+ val io = IO (new Bundle {
351+ val in_a = Flipped (Decoupled (Vec (m * p, UInt (config.inputWidth.W ))))
352+ val in_b = Flipped (Decoupled (Vec (p * q, UInt (config.inputWidth.W ))))
353+ val flush = Input (Bool ())
354+ val outMatrix = Decoupled (new currentSystolicGroupIdx(nk1, m, p, q))
355+ })
356+
357+ val qkGenMul = Module (new GenerationMatrixMul (k1, n1, m, p, q, gemmType))
358+ io.in_a <> qkGenMul.io.in_a
359+ io.in_b <> qkGenMul.io.in_b
360+
361+ val currentBuffer = Module (
362+ new Queue (
363+ new currentSystolicGroupIdx(nk1, m, p, q),
364+ entries = bufferSize,
365+ pipe = true ,
366+ flow = false ,
367+ useSyncReadMem = false ,
368+ hasFlush = true
369+ )
370+ )
371+
372+ // hasFlush must be true
373+ currentBuffer.io.flush.get := io.flush
374+
375+ // ATTENTION: we assert the size of the buffer is huge enough to hold the current systolic group output
376+ // we ignore the ready signal of the enq
377+ currentBuffer.io.enq.bits := qkGenMul.io.current.bits
378+ currentBuffer.io.enq.valid := qkGenMul.io.current.valid
379+
380+ io.outMatrix <> currentBuffer.io.deq
381+ }
382+
383+ val io = IO (new Bundle {
384+ val inputToken = Flipped (Decoupled (Vec (m * p, UInt (config.inputWidth.W ))))
385+ val weightQ = Flipped (Decoupled (Vec (p * q, UInt (config.inputWidth.W ))))
386+ val weightK = Flipped (Decoupled (Vec (p * q, UInt (config.inputWidth.W ))))
387+ val score = Decoupled (Vec (m * q, UInt (config.inputWidth.W )))
388+ val resetBuffer = Input (Bool ())
389+ })
390+
391+ val qGen = new QKGenerationMatrixMulWarper (k1, n1, m, p, q, gemmType, bufferSizeGemm)
392+ val kGen = new QKGenerationMatrixMulWarper (k2, n2, m, p, q, gemmType, bufferSizeGemm)
393+
394+ qGen.io.in_a <> io.inputToken
395+ qGen.io.in_b <> io.weightQ
396+ kGen.io.in_a <> io.inputToken
397+ kGen.io.in_b <> io.weightQ
398+
399+ qGen.io.flush := io.resetBuffer
400+ kGen.io.flush := io.resetBuffer
401+
402+ // final result idx
403+ val rowIdx = RegInit (0 .U (log2Ceil(m / nk2).W ))
404+ val colIdx = RegInit (0 .U (log2Ceil(m / nk2).W ))
405+ val resValid = RegInit (false .B )
406+ io.score.valid := resValid
407+
408+ val scoreValue = RegInit (VecInit .fill(m * q)(0 .U (config.outputWidth.W )))
409+ io.score.bits := scoreValue
410+
411+ when(resValid && io.score.ready) {
412+ resValid := false .B
413+ }
414+
415+ }
0 commit comments