Skip to content

Commit de76334

Browse files
todo: gemm-gemm flow is to finish
1 parent be851c3 commit de76334

File tree

4 files changed

+124
-4
lines changed

4 files changed

+124
-4
lines changed

src/main/scala/kernel/alu/Softmax.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ class FixedPointExp extends Module with SoftmaxAccuracy with DebugLog {
4545

4646
class Softmax(val arraySize: Int = 4) extends Module with SoftmaxAccuracy with DebugLog {
4747
val io = IO(new Bundle {
48-
val x = Input(Valid(Vec(arraySize, UInt((I + F).W))))
49-
val soft_x = Valid(Vec(arraySize, UInt((I + F).W)))
48+
// val x = Input(Valid(Vec(arraySize, UInt((I + F).W))))
49+
val x = Flipped(Decoupled(Vec(arraySize, UInt((I + F).W))))
50+
val soft_x = Decoupled(Vec(arraySize, UInt((I + F).W)))
5051
})
5152

5253
// first find the max value of x

src/main/scala/models/llama3/common/llamaConfig.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@ trait llamaConfig {
2424
// DAC for zb, stream for heads
2525
val stream_size = 8
2626

27+
// buffer size for gemm-gemm pipeline
28+
val bufferSizeGemm = 32
29+
2730
}

src/main/scala/models/llama3/metrixController.scala

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,113 @@ class GenerationMatrixMul(
303303
p"stateReg: $stateReg,\t currentValid: ${io.current.valid},\t rowIdx: ${rowIdx.value},\t colIdx: ${colIdx.value},\t gemmValid: ${gemmGroup.io.out.valid}\n"
304304
)
305305
}
306+
307+
/*
308+
* using two GenerationMatrixMul Modules(as QKGEN) to do q,k generation simultaneously.
309+
* using another GenerationMatrixMul Module(as QKMUL) to do q,k mul.
310+
* the output of QKMUL is the final result.
311+
* using the output of QKGEN to Stitch the final result.
312+
* the k1,n1 are for q,k generation, the k2,n2 are for q,k mul.
313+
*/
314+
class QKMul(
315+
val k1: Int,
316+
val n1: Int,
317+
val k2: Int,
318+
val n2: Int,
319+
val m: Int,
320+
val p: Int,
321+
val q: Int,
322+
val gemmType: GEMMDataType.Type
323+
)(
324+
implicit config: DataWidthConfig)
325+
extends Module
326+
with llamaConfig
327+
with DebugLog {
328+
329+
val nk1: Int = k1 * n1
330+
val nk2: Int = k2 * n2
331+
require(m % nk1 == 0)
332+
require(p % nk1 == 0)
333+
require(q % nk1 == 0)
334+
require(m % nk2 == 0)
335+
require(q % nk2 == 0)
336+
337+
class QKGenerationMatrixMulWarper(
338+
val k: Int,
339+
val n: Int,
340+
val m: Int,
341+
val p: Int,
342+
val q: Int,
343+
val gemmType: GEMMDataType.Type,
344+
val bufferSize: Int
345+
)(
346+
implicit config: DataWidthConfig)
347+
extends Module
348+
with llamaConfig
349+
with DebugLog {
350+
val io = IO(new Bundle {
351+
val in_a = Flipped(Decoupled(Vec(m * p, UInt(config.inputWidth.W))))
352+
val in_b = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
353+
val flush = Input(Bool())
354+
val outMatrix = Decoupled(new currentSystolicGroupIdx(nk1, m, p, q))
355+
})
356+
357+
val qkGenMul = Module(new GenerationMatrixMul(k1, n1, m, p, q, gemmType))
358+
io.in_a <> qkGenMul.io.in_a
359+
io.in_b <> qkGenMul.io.in_b
360+
361+
val currentBuffer = Module(
362+
new Queue(
363+
new currentSystolicGroupIdx(nk1, m, p, q),
364+
entries = bufferSize,
365+
pipe = true,
366+
flow = false,
367+
useSyncReadMem = false,
368+
hasFlush = true
369+
)
370+
)
371+
372+
// hasFlush must be true
373+
currentBuffer.io.flush.get := io.flush
374+
375+
// ATTENTION: we assert the size of the buffer is huge enough to hold the current systolic group output
376+
// we ignore the ready signal of the enq
377+
currentBuffer.io.enq.bits := qkGenMul.io.current.bits
378+
currentBuffer.io.enq.valid := qkGenMul.io.current.valid
379+
380+
io.outMatrix <> currentBuffer.io.deq
381+
}
382+
383+
val io = IO(new Bundle {
384+
val inputToken = Flipped(Decoupled(Vec(m * p, UInt(config.inputWidth.W))))
385+
val weightQ = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
386+
val weightK = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
387+
val score = Decoupled(Vec(m * q, UInt(config.inputWidth.W)))
388+
val resetBuffer = Input(Bool())
389+
})
390+
391+
val qGen = new QKGenerationMatrixMulWarper(k1, n1, m, p, q, gemmType, bufferSizeGemm)
392+
val kGen = new QKGenerationMatrixMulWarper(k2, n2, m, p, q, gemmType, bufferSizeGemm)
393+
394+
qGen.io.in_a <> io.inputToken
395+
qGen.io.in_b <> io.weightQ
396+
kGen.io.in_a <> io.inputToken
397+
kGen.io.in_b <> io.weightQ
398+
399+
qGen.io.flush := io.resetBuffer
400+
kGen.io.flush := io.resetBuffer
401+
402+
// final result idx
403+
val rowIdx = RegInit(0.U(log2Ceil(m / nk2).W))
404+
val colIdx = RegInit(0.U(log2Ceil(m / nk2).W))
405+
val resValid = RegInit(false.B)
406+
io.score.valid := resValid
407+
408+
val scoreValue = RegInit(VecInit.fill(m * q)(0.U(config.outputWidth.W)))
409+
io.score.bits := scoreValue
410+
411+
when(resValid && io.score.ready) {
412+
resValid := false.B
413+
}
414+
415+
}

src/test/scala/models/llama3/metrixControllerTest.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class metrixControllerTest extends AnyFlatSpec with ChiselScalatestTester with P
4242
Array.fill(rows, cols)(
4343
numeric.fromInt(
4444
// r.nextInt(math.pow(2, config.inputWidth).toInt) - math.pow(2, config.inputWidth - 1).toInt
45-
r.nextInt(16) - 8
45+
r.nextInt(4) - 2
4646
)
4747
)
4848
case c if c == classOf[Float] =>
@@ -260,6 +260,8 @@ class metrixControllerTest extends AnyFlatSpec with ChiselScalatestTester with P
260260
val in_a = Flipped(Decoupled(Vec(m * p, UInt(config.inputWidth.W))))
261261
val in_b = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
262262
val outMatrix = Valid(Vec(nk * nk, UInt(config.inputWidth.W)))
263+
val rowIdx = Output(UInt(config.inputWidth.W))
264+
val colIdx = Output(UInt(config.inputWidth.W))
263265
})
264266

265267
val metrixController = Module(new GenerationMatrixMul(k, n, m, p, q, gemmType))
@@ -270,6 +272,8 @@ class metrixControllerTest extends AnyFlatSpec with ChiselScalatestTester with P
270272
matrixRestore.io.inBlocks := metrixController.io.current.bits.value
271273
io.outMatrix.bits := matrixRestore.io.outMatrix
272274
io.outMatrix.valid := metrixController.io.current.valid
275+
io.rowIdx := metrixController.io.current.bits.row
276+
io.colIdx := metrixController.io.current.bits.col
273277
}
274278

275279
private def testMetrixController[T: Numeric: ClassTag](
@@ -328,6 +332,8 @@ class metrixControllerTest extends AnyFlatSpec with ChiselScalatestTester with P
328332
}
329333
// println(s"emptyRes: ${emptyRes.mkString(", ")}")
330334
// assert(emptyRes.sameElements(finalMatrix))
335+
println(s"rowIdx: ${dut.io.rowIdx.peekInt()}")
336+
println(s"colIdx: ${dut.io.colIdx.peekInt()}")
331337
printmat(emptyRes, nk, nk)
332338
}
333339
dut.clock.step()
@@ -350,7 +356,7 @@ class metrixControllerTest extends AnyFlatSpec with ChiselScalatestTester with P
350356

351357
"GenerationMatrixMul" should "correctly multiply matrices" in {
352358
implicit val config: DataWidthConfig = FxpConfig
353-
test(new MetrixControllerWarper(k = 1, n = 2, m = 4, p = 4, q = 4, GEMMDataType.Fxp))
359+
test(new MetrixControllerWarper(k = 1, n = 2, m = 4, p = 6, q = 8, GEMMDataType.Fxp))
354360
.withAnnotations(Seq(VerilatorBackendAnnotation))(testMetrixController[Int])
355361
}
356362
}

0 commit comments

Comments
 (0)