@@ -9,16 +9,18 @@ import kernel.alu.DataWidthConfig
99import kernel .utils .DebugLog
// Empty module shell: mixes in llamaConfig and declares no ports or logic.
// NOTE(review): the name looks like a typo of "matrixController" — confirm before renaming.
1010class metrixController extends Module with llamaConfig {}
1111
12+ /*
13+ * current systolic group idx (block row/col plus the block's packed value)
14+ * nk: systolic group dim, systolicSizeGen * systolicGroupSizeGen
15+ * m: left matrix rows
16+ * q: right matrix columns
17+ */
1218class currentSystolicGroupIdx (
13- val nk : Int ,
14- val m : Int ,
15- val p : Int ,
16- val q : Int
1719)(
1820 implicit config : DataWidthConfig )
1921 extends Bundle
2022 with llamaConfig {
21-
23+ val nk : Int = systolicSizeGen * systolicGroupSizeGen
2224 val row = Output (UInt (log2Ceil(m / nk).W ))
2325 val col = Output (UInt (log2Ceil(q / nk).W ))
2426 val value = Output (UInt ((nk * nk * config.inputWidth).W ))
@@ -95,7 +97,7 @@ class MatrixRestore(
9597 val numBlocksRow = m / nk
9698 val numBlocksCol = p / nk
9799
98- // 初始化输出矩阵
100+ // initialize the output matrix
99101 io.outMatrix.foreach(_ := 0 .U )
100102
101103 for (blockRow <- 0 until numBlocksRow) {
@@ -163,27 +165,22 @@ object BlockMatrixRestore {
163165 * designed for QKV generation, but has an output for the current systolic group idx
164166 */
165167class GenerationMatrixMul (
166- val k : Int ,
167- val n : Int ,
168- val m : Int ,
169- val p : Int ,
170- val q : Int ,
171168 val gemmType : GEMMDataType .Type
172169)(
173170 implicit config : DataWidthConfig )
174171 extends Module
175172 with llamaConfig
176173 with DebugLog {
177174 // param check
178- implicit val nk : Int = k * n
175+ implicit val nk : Int = systolicSizeGen * systolicGroupSizeGen
179176 require(m % nk == 0 )
180177 require(p % nk == 0 )
181178 require(q % nk == 0 )
182179
183180 val io = IO (new Bundle {
184181 val in_a = Flipped (Decoupled (Vec (m * p, UInt (config.inputWidth.W ))))
185182 val in_b = Flipped (Decoupled (Vec (p * q, UInt (config.inputWidth.W ))))
186- val current = ValidIO (new currentSystolicGroupIdx(nk, m, p, q) )
183+ val current = ValidIO (new currentSystolicGroupIdx)
187184 val reset = Input (Bool ())
188185 })
189186
@@ -312,34 +309,22 @@ class GenerationMatrixMul(
312309 * systolicSizeGen/systolicGroupSizeGen are for q,k generation, systolicSizeMul/systolicGroupSizeMul are for q,k mul.
313310 */
314311class QKMul (
315- val k1 : Int ,
316- val n1 : Int ,
317- val k2 : Int ,
318- val n2 : Int ,
319- val m : Int ,
320- val p : Int ,
321- val q : Int ,
322312 val gemmType : GEMMDataType .Type
323313)(
324314 implicit config : DataWidthConfig )
325315 extends Module
326316 with llamaConfig
327317 with DebugLog {
328318
329- val nk1 : Int = k1 * n1
330- val nk2 : Int = k2 * n2
319+ val nk1 : Int = systolicSizeGen * systolicGroupSizeGen
320+ val nk2 : Int = systolicSizeMul * systolicGroupSizeMul
331321 require(m % nk1 == 0 )
332322 require(p % nk1 == 0 )
333323 require(q % nk1 == 0 )
334324 require(m % nk2 == 0 )
335325 require(q % nk2 == 0 )
336326
337327 class QKGenerationMatrixMulWarper (
338- val k : Int ,
339- val n : Int ,
340- val m : Int ,
341- val p : Int ,
342- val q : Int ,
343328 val gemmType : GEMMDataType .Type ,
344329 val bufferSize : Int
345330 )(
@@ -351,16 +336,16 @@ class QKMul(
351336 val in_a = Flipped (Decoupled (Vec (m * p, UInt (config.inputWidth.W ))))
352337 val in_b = Flipped (Decoupled (Vec (p * q, UInt (config.inputWidth.W ))))
353338 val flush = Input (Bool ())
354- val outMatrix = Decoupled (new currentSystolicGroupIdx(nk1, m, p, q) )
339+ val outMatrix = Decoupled (new currentSystolicGroupIdx)
355340 })
356341
357- val qkGenMul = Module (new GenerationMatrixMul (k1, n1, m, p, q, gemmType))
342+ val qkGenMul = Module (new GenerationMatrixMul (gemmType))
358343 io.in_a <> qkGenMul.io.in_a
359344 io.in_b <> qkGenMul.io.in_b
360345
361346 val currentBuffer = Module (
362347 new Queue (
363- new currentSystolicGroupIdx(nk1, m, p, q) ,
348+ new currentSystolicGroupIdx,
364349 entries = bufferSize,
365350 pipe = true ,
366351 flow = false ,
@@ -388,8 +373,8 @@ class QKMul(
388373 val resetBuffer = Input (Bool ())
389374 })
390375
391- val qGen = new QKGenerationMatrixMulWarper (k1, n1, m, p, q, gemmType, bufferSizeGemm)
392- val kGen = new QKGenerationMatrixMulWarper (k2, n2, m, p, q, gemmType, bufferSizeGemm)
376+ val qGen = new QKGenerationMatrixMulWarper (gemmType, bufferSizeGemm)
377+ val kGen = new QKGenerationMatrixMulWarper (gemmType, bufferSizeGemm)
393378
394379 qGen.io.in_a <> io.inputToken
395380 qGen.io.in_b <> io.weightQ
@@ -413,3 +398,18 @@ class QKMul(
413398 }
414399
415400}
401+
/**
 * Module skeleton: the only port declared so far is `io.in`, a flipped
 * Decoupled channel of `currentSystolicGroupIdx`; no logic is attached yet.
 *
 * @param n        not referenced in the visible body; presumably a systolic dimension (TODO confirm)
 * @param poolSize not referenced in the visible body; presumably the number of pooled GEMM units (TODO confirm)
 * @param gemmType GEMM data type selector; not referenced in the visible body
 * @param config   implicit DataWidthConfig; not referenced directly in the visible body
 */
402+ class GemmPool (
403+ val n : Int ,
404+ val poolSize : Int ,
405+ val gemmType : GEMMDataType .Type
406+ )(
407+ implicit config : DataWidthConfig )
408+ extends Module
409+ with llamaConfig
410+ with DebugLog {
411+ val io = IO (new Bundle {
412+ val in = Flipped (Decoupled (new currentSystolicGroupIdx))
413+ })
414+
415+ }
0 commit comments