@@ -6,7 +6,7 @@ import chisel3.util._
66import kernel .alu .GEMM
77import kernel .alu .GEMMDataType
88import kernel .alu .DataWidthConfig
9-
9+ import kernel . utils . DebugLog
/** Placeholder module: mixes in `llamaConfig` but declares no ports or logic of its own. */
class metrixController extends Module with llamaConfig
1111
1212class currentSystolicGroupIdx (
@@ -21,7 +21,7 @@ class currentSystolicGroupIdx(
2121
2222 val row = Output (UInt (log2Ceil(m / nk).W ))
2323 val col = Output (UInt (log2Ceil(q / nk).W ))
24- val value = Output (Vec ( nk * nk, UInt ( config.inputWidth. W ) ))
24+ val value = Output (UInt (( nk * nk * config.inputWidth). W ))
2525}
2626
2727class MatrixSplit (
@@ -101,17 +101,14 @@ class MatrixRestore(
101101 for (blockRow <- 0 until numBlocksRow) {
102102 for (blockCol <- 0 until numBlocksCol) {
103103 val blockIndex = blockRow * numBlocksCol + blockCol
104- val block = io.inBlocks(blockIndex)
104+ val block = io.inBlocks(blockIndex).asTypeOf( Vec (nk * nk, UInt (config.inputWidth. W )))
105105
106106 // Unpack the current square block
107- for (i <- 0 until nk) {
108- for (j <- 0 until nk) {
109- // 计算在输出向量中的位置
110- val flatIndex = (blockRow * nk + i) * p + (blockCol * nk + j)
111- // 从打包的UInt中提取对应位置的元素
112- val elementPos = (nk * nk - 1 - (i * nk + j)) * config.inputWidth
113- io.outMatrix(flatIndex) := block(elementPos + config.inputWidth - 1 , elementPos)
114- }
107+ for {
108+ i <- 0 until nk
109+ j <- 0 until nk
110+ } {
111+ io.outMatrix((blockRow * nk + i) * p + (blockCol * nk + j)) := block(i * nk + j)
115112 }
116113 }
117114 }
@@ -132,6 +129,32 @@ object MatrixRestore {
132129 }
133130}
134131
/** Unpacks one flattened block into a vector of its individual elements.
  *
  * The input is a single `UInt` that is `nk * nk * config.inputWidth` bits wide; the output is
  * those same bits reinterpreted with `asTypeOf` as `nk * nk` elements, each
  * `config.inputWidth` bits wide. The module body is a single connection and holds no state.
  *
  * NOTE(review): which element lands in the least-significant bits is decided by Chisel's
  * `asTypeOf` semantics — confirm it matches the packing order used by the producer.
  *
  * @param nk     block dimension; one block carries `nk * nk` elements
  * @param config implicit width configuration providing `inputWidth`
  */
class BlockMatrixRestore(val nk: Int)(implicit config: DataWidthConfig) extends Module {
  // Element count of a single block.
  private val elemCount = nk * nk

  // Builds a fresh Chisel type for the unpacked view each time it is needed.
  private def unpackedType = Vec(elemCount, UInt(config.inputWidth.W))

  val io = IO(new Bundle {
    val inBlocks  = Input(UInt((elemCount * config.inputWidth).W))
    val outMatrix = Output(unpackedType)
  })

  // Pure bit reinterpretation of the packed input; no arithmetic is performed.
  io.outMatrix := io.inBlocks.asTypeOf(unpackedType)
}
144+
/** Factory for [[BlockMatrixRestore]] that hides the module instantiation and port wiring. */
object BlockMatrixRestore {

  /** Instantiates a [[BlockMatrixRestore]], drives its input, and hands back its output port.
    *
    * @param nk       block dimension forwarded to the module constructor
    * @param inBlocks packed block connected to the module's `inBlocks` input
    * @param config   implicit width configuration forwarded to the module
    * @return the module's `outMatrix` port
    */
  def apply(nk: Int)(inBlocks: UInt)(implicit config: DataWidthConfig): Vec[UInt] = {
    // The local name is kept as-is: Chisel may derive the generated instance name from it.
    val newBlockMatrixRestore = Module(new BlockMatrixRestore(nk))
    newBlockMatrixRestore.io.inBlocks := inBlocks
    newBlockMatrixRestore.io.outMatrix
  }
}
157+
135158/*
136159 * matrix mul matrix
137160 * matrixA is [m, p]
@@ -149,7 +172,8 @@ class GenerationMatrixMul(
149172)(
150173 implicit config : DataWidthConfig )
151174 extends Module
152- with llamaConfig {
175+ with llamaConfig
176+ with DebugLog {
153177 // param check
154178 implicit val nk : Int = k * n
155179 require(m % nk == 0 )
@@ -159,14 +183,18 @@ class GenerationMatrixMul(
159183 val io = IO (new Bundle {
160184 val in_a = Flipped (Decoupled (Vec (m * p, UInt (config.inputWidth.W ))))
161185 val in_b = Flipped (Decoupled (Vec (p * q, UInt (config.inputWidth.W ))))
162- val result = Decoupled (Vec (m * q, UInt (config.outputWidth.W )))
163186 val current = ValidIO (new currentSystolicGroupIdx(nk, m, p, q))
164187 val reset = Input (Bool ())
165188 })
166189
167190 // reshape the input data as block => [rows, cols] [nk, nk]
168- val matrixAReshape = RegInit (MatrixSplit (m, p, nk)(io.in_a.bits))
169- val matrixBReshape = RegInit (MatrixSplit (p, q, nk)(io.in_b.bits))
191+ val matrixAReshape = RegInit (VecInit .fill(m / nk * p / nk)(0 .U ((nk * nk * config.inputWidth).W )))
192+ val matrixBReshape = RegInit (VecInit .fill(p / nk * q / nk)(0 .U ((nk * nk * config.inputWidth).W )))
193+ matrixAReshape := MatrixSplit (m, p, nk)(io.in_a.bits)
194+ matrixBReshape := MatrixSplit (p, q, nk)(io.in_b.bits)
195+
196+ // debugLog(p"matrixAReshape: ${matrixAReshape}\n", LogLevel.DEBUG)
197+ // debugLog(p"matrixBReshape: ${matrixBReshape}\n", LogLevel.DEBUG)
170198
171199 // systolic alu
172200 val gemmGroup = Module (new GEMM (nk, gemmType))
@@ -183,8 +211,6 @@ class GenerationMatrixMul(
183211 val readyReg = RegInit (true .B )
184212 io.in_a.ready := readyReg
185213 io.in_b.ready := readyReg
186- val validReg = RegInit (false .B )
187- io.result.valid := validReg
188214 val dataShapedValid = RegInit (false .B )
189215 gemmGroup.io.in_a.valid := dataShapedValid
190216 gemmGroup.io.in_b.valid := dataShapedValid
@@ -197,9 +223,6 @@ class GenerationMatrixMul(
197223 val gemmInputA = matrixAReshape(blockAIdx).asTypeOf(Vec (nk * nk, UInt (config.inputWidth.W )))
198224 val gemmInputB = matrixBReshape(blockBIdx).asTypeOf(Vec (nk * nk, UInt (config.inputWidth.W )))
199225
200- val unShapedResult = RegInit (VecInit .fill(m * q / nk / nk)(0 .U ((nk * nk * config.outputWidth).W )))
201- io.result.bits := MatrixRestore (m, q, nk)(unShapedResult)
202-
203226 gemmGroup.io.in_a.bits := gemmInputA
204227 gemmGroup.io.in_b.bits := gemmInputB
205228 gemmGroup.io.reset := false .B
@@ -229,6 +252,8 @@ class GenerationMatrixMul(
229252 when(gemmGroup.io.out.valid) {
230253 val isfinal = calTimes.inc()
231254 when(isfinal) {
255+ // When this is the last value, do not consume it
256+ gemmGroup.io.out.ready := false .B
232257 stateReg := state.collect
233258 gemmGroupReady := false .B
234259 }
@@ -242,30 +267,39 @@ class GenerationMatrixMul(
242267 // still has the last gemm block to cal
243268 when(gemmGroup.io.out.valid) {
244269 gemmGroup.io.reset := true .B
270+ gemmGroup.io.out.ready := true .B
271+
245272 // collect the result of the [rowIdx, colIdx] block
246273 val afterRowLine = gemmGroup.io.out.bits
247- unShapedResult(rowIdx.value * cols.U + colIdx.value) := afterRowLine.asTypeOf(
248- UInt ((nk * nk * config.outputWidth).W )
249- )
250274
251275 // send the current systolic group idx
252276 io.current.valid := true .B
253277 io.current.bits.row := rowIdx.value
254278 io.current.bits.col := colIdx.value
255- io.current.bits.value := afterRowLine
279+ io.current.bits.value := afterRowLine.asUInt
256280
257281 val isRowEnd = colIdx.inc()
258- stateReg := Mux (rowIdx.inc() && isRowEnd, state.done, state.cal)
282+ when(! isRowEnd) {
283+ stateReg := state.cal
284+ }.otherwise {
285+ val isAllEnd = rowIdx.inc()
286+ when(! isAllEnd) {
287+ stateReg := state.cal
288+ }.otherwise {
289+ dataShapedValid := false .B
290+ stateReg := state.done
291+ }
292+ }
259293 }
260294 }
261295
262296 is(state.done) {
263- validReg := true .B
264- when(io.result.ready) {
265- validReg := false .B
266- stateReg := state.idle
267- readyReg := true .B
268- }
297+ stateReg := state.idle
298+ readyReg := true .B
269299 }
270300 }
301+
302+ debugLog(
303+ p " stateReg: $stateReg, \t currentValid: ${io.current.valid}, \t rowIdx: ${rowIdx.value}, \t colIdx: ${colIdx.value}, \t gemmValid: ${gemmGroup.io.out.valid}\n "
304+ )
271305}
0 commit comments