Skip to content

Commit 7b370c2

Browse files
add gemm using prefix pqm
1 parent 16777a0 commit 7b370c2

File tree

1 file changed

+248
-32
lines changed

1 file changed

+248
-32
lines changed

src/main/scala/models/llama3/metrixController.scala

Lines changed: 248 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,52 +4,268 @@ import common.llamaConfig
44
import chisel3._
55
import chisel3.util._
66
import kernel.alu.GEMM
7-
import kernel.utils.ForwardingMemory
7+
import kernel.alu.GEMMDataType
8+
import kernel.alu.DataWidthConfig
9+
810
class metrixController extends Module with llamaConfig {}
911

12+
class currentSystolicGroupIdx(
13+
val nk: Int,
14+
val m: Int,
15+
val p: Int,
16+
val q: Int
17+
)(
18+
implicit config: DataWidthConfig)
19+
extends Bundle
20+
with llamaConfig {
21+
22+
val row = Output(UInt(log2Ceil(m / nk).W))
23+
val col = Output(UInt(log2Ceil(q / nk).W))
24+
val value = Output(Vec(nk * nk, UInt(config.inputWidth.W)))
25+
}
26+
27+
class MatrixSplit(
28+
val m: Int,
29+
val p: Int,
30+
val nk: Int
31+
)(
32+
implicit config: DataWidthConfig)
33+
extends Module {
34+
require(m % nk == 0 && p % nk == 0, "m and p must be divisible by nk")
35+
36+
val io = IO(new Bundle {
37+
val inMatrix = Input(Vec(m * p, UInt(config.inputWidth.W)))
38+
val outBlocks = Output(Vec(m / nk * p / nk, UInt((nk * nk * config.inputWidth).W)))
39+
})
40+
41+
val numBlocksRow = m / nk
42+
val numBlocksCol = p / nk
43+
44+
for (blockRow <- 0 until numBlocksRow) {
45+
for (blockCol <- 0 until numBlocksCol) {
46+
val blockIndex = blockRow * numBlocksCol + blockCol
47+
48+
// 收集当前方阵块的所有元素
49+
val blockElements = for {
50+
i <- 0 until nk
51+
j <- 0 until nk
52+
} yield {
53+
// 计算在一维输入向量中的索引
54+
val flatIndex = (blockRow * nk + i) * p + (blockCol * nk + j)
55+
io.inMatrix(flatIndex)
56+
}
57+
58+
// 将方阵元素连接成一个UInt
59+
io.outBlocks(blockIndex) := VecInit(blockElements).asUInt
60+
}
61+
}
62+
}
63+
64+
object MatrixSplit {
65+
def apply(
66+
m: Int,
67+
p: Int,
68+
nk: Int
69+
)(inMatrix: Vec[UInt]
70+
)(
71+
implicit config: DataWidthConfig
72+
): Vec[UInt] = {
73+
val newMatrixSplit = Module(new MatrixSplit(m, p, nk))
74+
newMatrixSplit.io.inMatrix := inMatrix
75+
newMatrixSplit.io.outBlocks
76+
}
77+
}
78+
79+
class MatrixRestore(
80+
val m: Int,
81+
val p: Int,
82+
val nk: Int
83+
)(
84+
implicit config: DataWidthConfig)
85+
extends Module {
86+
require(m % nk == 0 && p % nk == 0, "m and p must be divisible by nk")
87+
88+
val io = IO(new Bundle {
89+
// 输入是打包的方阵序列
90+
val inBlocks = Input(Vec(m / nk * p / nk, UInt((nk * nk * config.inputWidth).W)))
91+
// 输出是一维向量表示的矩阵
92+
val outMatrix = Output(Vec(m * p, UInt(config.inputWidth.W)))
93+
})
94+
95+
val numBlocksRow = m / nk
96+
val numBlocksCol = p / nk
97+
98+
// 初始化输出矩阵
99+
io.outMatrix.foreach(_ := 0.U)
100+
101+
for (blockRow <- 0 until numBlocksRow) {
102+
for (blockCol <- 0 until numBlocksCol) {
103+
val blockIndex = blockRow * numBlocksCol + blockCol
104+
val block = io.inBlocks(blockIndex)
105+
106+
// 解包当前方阵块
107+
for (i <- 0 until nk) {
108+
for (j <- 0 until nk) {
109+
// 计算在输出向量中的位置
110+
val flatIndex = (blockRow * nk + i) * p + (blockCol * nk + j)
111+
// 从打包的UInt中提取对应位置的元素
112+
val elementPos = (nk * nk - 1 - (i * nk + j)) * config.inputWidth
113+
io.outMatrix(flatIndex) := block(elementPos + config.inputWidth - 1, elementPos)
114+
}
115+
}
116+
}
117+
}
118+
}
119+
120+
object MatrixRestore {
121+
def apply(
122+
m: Int,
123+
p: Int,
124+
nk: Int
125+
)(inBlocks: Vec[UInt]
126+
)(
127+
implicit config: DataWidthConfig
128+
): Vec[UInt] = {
129+
val newMatrixRestore = Module(new MatrixRestore(m, p, nk))
130+
newMatrixRestore.io.inBlocks := inBlocks
131+
newMatrixRestore.io.outMatrix
132+
}
133+
}
134+
10135
/*
11136
* matrix mul matrix
12-
* matrixA is [inputN, dim]
13-
* matrixB is [dim, head_dim]
14-
* matrixC is [inputN, head_dim]
137+
* matrixA is [m, p]
138+
* matrixB is [p, q]
139+
* use k^2 systolic-groups with dim as n to do the matrix mul
140+
* designed for QKV generation, but has a output for current systolic group idx
15141
*/
16-
class QKVGenerationMul extends Module with llamaConfig {
142+
class GenerationMatrixMul(
143+
val k: Int,
144+
val n: Int,
145+
val m: Int,
146+
val p: Int,
147+
val q: Int,
148+
val gemmType: GEMMDataType.Type
149+
)(
150+
implicit config: DataWidthConfig)
151+
extends Module
152+
with llamaConfig {
153+
// param check
154+
implicit val nk: Int = k * n
155+
require(m % nk == 0)
156+
require(p % nk == 0)
157+
require(q % nk == 0)
158+
17159
val io = IO(new Bundle {
18-
val matrixAPart = Input(Vec(minN, Vec(dim, UInt(bits.W))))
160+
val in_a = Flipped(Decoupled(Vec(m * p, UInt(config.inputWidth.W))))
161+
val in_b = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
162+
val result = Decoupled(Vec(m * q, UInt(config.outputWidth.W)))
163+
val current = ValidIO(new currentSystolicGroupIdx(nk, m, p, q))
164+
val reset = Input(Bool())
19165
})
20-
}
21166

22-
// class SystolicGroup extends Module with llamaConfig {
23-
// val io = IO(new Bundle {
24-
// val matrixAVec = Flipped(Decoupled(Vec(systolicGroupSize, Vec(systolicSize, Vec(systolicSize, UInt(bits.W))))))
25-
// val matrixBVec = Flipped(Decoupled(Vec(systolicGroupSize, Vec(systolicSize, Vec(systolicSize, UInt(bits.W))))))
26-
// val matrixCVec = Decoupled(Vec(systolicGroupSize, Vec(systolicSize * systolicSize, UInt(bits.W))))
27-
// })
167+
// reshape the input data as block => [rows, cols] [nk, nk]
168+
val matrixAReshape = RegInit(MatrixSplit(m, p, nk)(io.in_a.bits))
169+
val matrixBReshape = RegInit(MatrixSplit(p, q, nk)(io.in_b.bits))
170+
171+
// systolic alu
172+
val gemmGroup = Module(new GEMM(nk, gemmType))
173+
174+
// systolic group idx
175+
val rows = m / nk
176+
val cols = q / nk
177+
val middle = p / nk
178+
val rowIdx = Counter(rows)
179+
val colIdx = Counter(cols)
180+
val calTimes = Counter(middle)
181+
182+
val dataValid = io.in_a.valid && io.in_b.valid
183+
val readyReg = RegInit(true.B)
184+
io.in_a.ready := readyReg
185+
io.in_b.ready := readyReg
186+
val validReg = RegInit(false.B)
187+
io.result.valid := validReg
188+
val dataShapedValid = RegInit(false.B)
189+
gemmGroup.io.in_a.valid := dataShapedValid
190+
gemmGroup.io.in_b.valid := dataShapedValid
191+
192+
io.current.valid := false.B
193+
io.current.bits := DontCare
194+
195+
val blockAIdx = rowIdx.value * middle.U + calTimes.value
196+
val blockBIdx = calTimes.value * cols.U + colIdx.value
197+
val gemmInputA = matrixAReshape(blockAIdx).asTypeOf(Vec(nk * nk, UInt(config.inputWidth.W)))
198+
val gemmInputB = matrixBReshape(blockBIdx).asTypeOf(Vec(nk * nk, UInt(config.inputWidth.W)))
199+
200+
val unShapedResult = RegInit(VecInit.fill(m * q / nk / nk)(0.U((nk * nk * config.outputWidth).W)))
201+
io.result.bits := MatrixRestore(m, q, nk)(unShapedResult)
202+
203+
gemmGroup.io.in_a.bits := gemmInputA
204+
gemmGroup.io.in_b.bits := gemmInputB
205+
gemmGroup.io.reset := false.B
28206

29-
// val gemmRow = for (i <- 0 until systolicGroupSize) yield Module(new GEMM(16))
207+
val gemmGroupReady = RegInit(false.B)
208+
gemmGroup.io.out.ready := gemmGroupReady
30209

31-
// val matrixAValid = io.matrixAVec.valid
32-
// val matrixBValid = io.matrixBVec.valid
210+
object state extends ChiselEnum {
211+
val idle, cal, collect, done = Value
212+
}
213+
val stateReg = RegInit(state.idle)
33214

34-
// io.matrixAVec.ready := gemmRow.map(_.InputA.ready).reduce(_ && _)
35-
// io.matrixBVec.ready := gemmRow.map(_.InputB.ready).reduce(_ && _)
215+
switch(stateReg) {
216+
is(state.idle) {
217+
when(dataValid) {
218+
dataShapedValid := true.B
219+
stateReg := state.cal
220+
readyReg := false.B
221+
}
222+
}
36223

37-
// val matrixCValid = gemmRow.map(_.OutputPipe.valid).reduce(_ && _)
224+
is(state.cal) {
225+
// acc mode
226+
gemmGroup.io.reset := false.B
227+
// when a gemm block is done, io.current will send data
228+
gemmGroup.io.out.ready := true.B
229+
when(gemmGroup.io.out.valid) {
230+
val isfinal = calTimes.inc()
231+
when(isfinal) {
232+
stateReg := state.collect
233+
gemmGroupReady := false.B
234+
}
235+
}
236+
}
38237

39-
// for (i <- 0 until systolicGroupSize) {
40-
// gemmRow(i).InputA.bits := io.matrixAVec.bits(i)
41-
// gemmRow(i).InputA.valid := matrixAValid
42-
// gemmRow(i).InputB.bits := io.matrixBVec.bits(i)
43-
// gemmRow(i).InputB.valid := matrixBValid
44-
// gemmRow(i).accMode := false.B
45-
// io.matrixCVec.bits(i) := gemmRow(i).OutputPipe.bits
46-
// gemmRow(i).OutputPipe.ready := io.matrixCVec.ready
47-
// }
238+
is(state.collect) {
239+
// collect mode
240+
gemmGroup.io.reset := false.B
241+
gemmGroupReady := true.B
242+
// still has the last gemm block to cal
243+
when(gemmGroup.io.out.valid) {
244+
gemmGroup.io.reset := true.B
245+
// collect the result of the [rowIdx, colIdx] block
246+
val afterRowLine = gemmGroup.io.out.bits
247+
unShapedResult(rowIdx.value * cols.U + colIdx.value) := afterRowLine.asTypeOf(
248+
UInt((nk * nk * config.outputWidth).W)
249+
)
48250

49-
// io.matrixCVec.valid := matrixCValid
251+
// send the current systolic group idx
252+
io.current.valid := true.B
253+
io.current.bits.row := rowIdx.value
254+
io.current.bits.col := colIdx.value
255+
io.current.bits.value := afterRowLine
50256

51-
// }
257+
val isRowEnd = colIdx.inc()
258+
stateReg := Mux(rowIdx.inc() && isRowEnd, state.done, state.cal)
259+
}
260+
}
52261

53-
class GEMMController(val x: Int, val y: Int) extends Module with llamaConfig {
54-
assert(x % (systolicGroupSize * systolicSize) == 0 && y % systolicSize == 0)
262+
is(state.done) {
263+
validReg := true.B
264+
when(io.result.ready) {
265+
validReg := false.B
266+
stateReg := state.idle
267+
readyReg := true.B
268+
}
269+
}
270+
}
55271
}

0 commit comments

Comments
 (0)