@@ -4,52 +4,268 @@ import common.llamaConfig
44import chisel3 ._
55import chisel3 .util ._
66import kernel .alu .GEMM
7- import kernel .utils .ForwardingMemory
7+ import kernel .alu .GEMMDataType
8+ import kernel .alu .DataWidthConfig
9+
810class metrixController extends Module with llamaConfig {}
911
12+ class currentSystolicGroupIdx (
13+ val nk : Int ,
14+ val m : Int ,
15+ val p : Int ,
16+ val q : Int
17+ )(
18+ implicit config : DataWidthConfig )
19+ extends Bundle
20+ with llamaConfig {
21+
22+ val row = Output (UInt (log2Ceil(m / nk).W ))
23+ val col = Output (UInt (log2Ceil(q / nk).W ))
24+ val value = Output (Vec (nk * nk, UInt (config.inputWidth.W )))
25+ }
26+
27+ class MatrixSplit (
28+ val m : Int ,
29+ val p : Int ,
30+ val nk : Int
31+ )(
32+ implicit config : DataWidthConfig )
33+ extends Module {
34+ require(m % nk == 0 && p % nk == 0 , " m and p must be divisible by nk" )
35+
36+ val io = IO (new Bundle {
37+ val inMatrix = Input (Vec (m * p, UInt (config.inputWidth.W )))
38+ val outBlocks = Output (Vec (m / nk * p / nk, UInt ((nk * nk * config.inputWidth).W )))
39+ })
40+
41+ val numBlocksRow = m / nk
42+ val numBlocksCol = p / nk
43+
44+ for (blockRow <- 0 until numBlocksRow) {
45+ for (blockCol <- 0 until numBlocksCol) {
46+ val blockIndex = blockRow * numBlocksCol + blockCol
47+
48+ // 收集当前方阵块的所有元素
49+ val blockElements = for {
50+ i <- 0 until nk
51+ j <- 0 until nk
52+ } yield {
53+ // 计算在一维输入向量中的索引
54+ val flatIndex = (blockRow * nk + i) * p + (blockCol * nk + j)
55+ io.inMatrix(flatIndex)
56+ }
57+
58+ // 将方阵元素连接成一个UInt
59+ io.outBlocks(blockIndex) := VecInit (blockElements).asUInt
60+ }
61+ }
62+ }
63+
64+ object MatrixSplit {
65+ def apply (
66+ m : Int ,
67+ p : Int ,
68+ nk : Int
69+ )(inMatrix : Vec [UInt ]
70+ )(
71+ implicit config : DataWidthConfig
72+ ): Vec [UInt ] = {
73+ val newMatrixSplit = Module (new MatrixSplit (m, p, nk))
74+ newMatrixSplit.io.inMatrix := inMatrix
75+ newMatrixSplit.io.outBlocks
76+ }
77+ }
78+
79+ class MatrixRestore (
80+ val m : Int ,
81+ val p : Int ,
82+ val nk : Int
83+ )(
84+ implicit config : DataWidthConfig )
85+ extends Module {
86+ require(m % nk == 0 && p % nk == 0 , " m and p must be divisible by nk" )
87+
88+ val io = IO (new Bundle {
89+ // 输入是打包的方阵序列
90+ val inBlocks = Input (Vec (m / nk * p / nk, UInt ((nk * nk * config.inputWidth).W )))
91+ // 输出是一维向量表示的矩阵
92+ val outMatrix = Output (Vec (m * p, UInt (config.inputWidth.W )))
93+ })
94+
95+ val numBlocksRow = m / nk
96+ val numBlocksCol = p / nk
97+
98+ // 初始化输出矩阵
99+ io.outMatrix.foreach(_ := 0 .U )
100+
101+ for (blockRow <- 0 until numBlocksRow) {
102+ for (blockCol <- 0 until numBlocksCol) {
103+ val blockIndex = blockRow * numBlocksCol + blockCol
104+ val block = io.inBlocks(blockIndex)
105+
106+ // 解包当前方阵块
107+ for (i <- 0 until nk) {
108+ for (j <- 0 until nk) {
109+ // 计算在输出向量中的位置
110+ val flatIndex = (blockRow * nk + i) * p + (blockCol * nk + j)
111+ // 从打包的UInt中提取对应位置的元素
112+ val elementPos = (nk * nk - 1 - (i * nk + j)) * config.inputWidth
113+ io.outMatrix(flatIndex) := block(elementPos + config.inputWidth - 1 , elementPos)
114+ }
115+ }
116+ }
117+ }
118+ }
119+
120+ object MatrixRestore {
121+ def apply (
122+ m : Int ,
123+ p : Int ,
124+ nk : Int
125+ )(inBlocks : Vec [UInt ]
126+ )(
127+ implicit config : DataWidthConfig
128+ ): Vec [UInt ] = {
129+ val newMatrixRestore = Module (new MatrixRestore (m, p, nk))
130+ newMatrixRestore.io.inBlocks := inBlocks
131+ newMatrixRestore.io.outMatrix
132+ }
133+ }
134+
10135/*
11136 * matrix mul matrix
12- * matrixA is [inputN, dim]
13- * matrixB is [dim, head_dim]
14- * matrixC is [inputN, head_dim]
137+ * matrixA is [m, p]
138+ * matrixB is [p, q]
139+ * use k^2 systolic-groups with dim as n to do the matrix mul
140+ * designed for QKV generation, but has a output for current systolic group idx
15141 */
16- class QKVGenerationMul extends Module with llamaConfig {
142+ class GenerationMatrixMul (
143+ val k : Int ,
144+ val n : Int ,
145+ val m : Int ,
146+ val p : Int ,
147+ val q : Int ,
148+ val gemmType : GEMMDataType .Type
149+ )(
150+ implicit config : DataWidthConfig )
151+ extends Module
152+ with llamaConfig {
153+ // param check
154+ implicit val nk : Int = k * n
155+ require(m % nk == 0 )
156+ require(p % nk == 0 )
157+ require(q % nk == 0 )
158+
17159 val io = IO (new Bundle {
18- val matrixAPart = Input (Vec (minN, Vec (dim, UInt (bits.W ))))
160+ val in_a = Flipped (Decoupled (Vec (m * p, UInt (config.inputWidth.W ))))
161+ val in_b = Flipped (Decoupled (Vec (p * q, UInt (config.inputWidth.W ))))
162+ val result = Decoupled (Vec (m * q, UInt (config.outputWidth.W )))
163+ val current = ValidIO (new currentSystolicGroupIdx(nk, m, p, q))
164+ val reset = Input (Bool ())
19165 })
20- }
21166
22- // class SystolicGroup extends Module with llamaConfig {
23- // val io = IO(new Bundle {
24- // val matrixAVec = Flipped(Decoupled(Vec(systolicGroupSize, Vec(systolicSize, Vec(systolicSize, UInt(bits.W))))))
25- // val matrixBVec = Flipped(Decoupled(Vec(systolicGroupSize, Vec(systolicSize, Vec(systolicSize, UInt(bits.W))))))
26- // val matrixCVec = Decoupled(Vec(systolicGroupSize, Vec(systolicSize * systolicSize, UInt(bits.W))))
27- // })
167+ // reshape the input data as block => [rows, cols] [nk, nk]
168+ val matrixAReshape = RegInit (MatrixSplit (m, p, nk)(io.in_a.bits))
169+ val matrixBReshape = RegInit (MatrixSplit (p, q, nk)(io.in_b.bits))
170+
171+ // systolic alu
172+ val gemmGroup = Module (new GEMM (nk, gemmType))
173+
174+ // systolic group idx
175+ val rows = m / nk
176+ val cols = q / nk
177+ val middle = p / nk
178+ val rowIdx = Counter (rows)
179+ val colIdx = Counter (cols)
180+ val calTimes = Counter (middle)
181+
182+ val dataValid = io.in_a.valid && io.in_b.valid
183+ val readyReg = RegInit (true .B )
184+ io.in_a.ready := readyReg
185+ io.in_b.ready := readyReg
186+ val validReg = RegInit (false .B )
187+ io.result.valid := validReg
188+ val dataShapedValid = RegInit (false .B )
189+ gemmGroup.io.in_a.valid := dataShapedValid
190+ gemmGroup.io.in_b.valid := dataShapedValid
191+
192+ io.current.valid := false .B
193+ io.current.bits := DontCare
194+
195+ val blockAIdx = rowIdx.value * middle.U + calTimes.value
196+ val blockBIdx = calTimes.value * cols.U + colIdx.value
197+ val gemmInputA = matrixAReshape(blockAIdx).asTypeOf(Vec (nk * nk, UInt (config.inputWidth.W )))
198+ val gemmInputB = matrixBReshape(blockBIdx).asTypeOf(Vec (nk * nk, UInt (config.inputWidth.W )))
199+
200+ val unShapedResult = RegInit (VecInit .fill(m * q / nk / nk)(0 .U ((nk * nk * config.outputWidth).W )))
201+ io.result.bits := MatrixRestore (m, q, nk)(unShapedResult)
202+
203+ gemmGroup.io.in_a.bits := gemmInputA
204+ gemmGroup.io.in_b.bits := gemmInputB
205+ gemmGroup.io.reset := false .B
28206
29- // val gemmRow = for (i <- 0 until systolicGroupSize) yield Module(new GEMM(16))
207+ val gemmGroupReady = RegInit (false .B )
208+ gemmGroup.io.out.ready := gemmGroupReady
30209
31- // val matrixAValid = io.matrixAVec.valid
32- // val matrixBValid = io.matrixBVec.valid
210+ object state extends ChiselEnum {
211+ val idle, cal, collect, done = Value
212+ }
213+ val stateReg = RegInit (state.idle)
33214
34- // io.matrixAVec.ready := gemmRow.map(_.InputA.ready).reduce(_ && _)
35- // io.matrixBVec.ready := gemmRow.map(_.InputB.ready).reduce(_ && _)
215+ switch(stateReg) {
216+ is(state.idle) {
217+ when(dataValid) {
218+ dataShapedValid := true .B
219+ stateReg := state.cal
220+ readyReg := false .B
221+ }
222+ }
36223
37- // val matrixCValid = gemmRow.map(_.OutputPipe.valid).reduce(_ && _)
224+ is(state.cal) {
225+ // acc mode
226+ gemmGroup.io.reset := false .B
227+ // when a gemm block is done, io.current will send data
228+ gemmGroup.io.out.ready := true .B
229+ when(gemmGroup.io.out.valid) {
230+ val isfinal = calTimes.inc()
231+ when(isfinal) {
232+ stateReg := state.collect
233+ gemmGroupReady := false .B
234+ }
235+ }
236+ }
38237
39- // for (i <- 0 until systolicGroupSize) {
40- // gemmRow(i).InputA.bits := io.matrixAVec.bits(i)
41- // gemmRow(i).InputA.valid := matrixAValid
42- // gemmRow(i).InputB.bits := io.matrixBVec.bits(i)
43- // gemmRow(i).InputB.valid := matrixBValid
44- // gemmRow(i).accMode := false.B
45- // io.matrixCVec.bits(i) := gemmRow(i).OutputPipe.bits
46- // gemmRow(i).OutputPipe.ready := io.matrixCVec.ready
47- // }
238+ is(state.collect) {
239+ // collect mode
240+ gemmGroup.io.reset := false .B
241+ gemmGroupReady := true .B
242+ // still has the last gemm block to cal
243+ when(gemmGroup.io.out.valid) {
244+ gemmGroup.io.reset := true .B
245+ // collect the result of the [rowIdx, colIdx] block
246+ val afterRowLine = gemmGroup.io.out.bits
247+ unShapedResult(rowIdx.value * cols.U + colIdx.value) := afterRowLine.asTypeOf(
248+ UInt ((nk * nk * config.outputWidth).W )
249+ )
48250
49- // io.matrixCVec.valid := matrixCValid
251+ // send the current systolic group idx
252+ io.current.valid := true .B
253+ io.current.bits.row := rowIdx.value
254+ io.current.bits.col := colIdx.value
255+ io.current.bits.value := afterRowLine
50256
51- // }
257+ val isRowEnd = colIdx.inc()
258+ stateReg := Mux (rowIdx.inc() && isRowEnd, state.done, state.cal)
259+ }
260+ }
52261
53- class GEMMController (val x : Int , val y : Int ) extends Module with llamaConfig {
54- assert(x % (systolicGroupSize * systolicSize) == 0 && y % systolicSize == 0 )
262+ is(state.done) {
263+ validReg := true .B
264+ when(io.result.ready) {
265+ validReg := false .B
266+ stateReg := state.idle
267+ readyReg := true .B
268+ }
269+ }
270+ }
55271}
0 commit comments