Skip to content

Commit be851c3

Browse files
done metrixController with test
todo: after cal mul
1 parent 7b370c2 commit be851c3

File tree

4 files changed

+481
-88
lines changed

4 files changed

+481
-88
lines changed

src/main/scala/kernel/alu/Gemm.scala

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import hardfloat._
99

1010
trait GEMMAccuracyConfig {
1111
val I: Int = 8
12-
val F: Int = 24
12+
val F: Int = 0
1313
}
1414

1515
case class FPConfig(width: Int) {
@@ -26,45 +26,65 @@ object GEMMDataType extends ChiselEnum {
2626
val UI, Fxp, Fp32, Fp64 = Value
2727
}
2828

29-
class PEFxp extends Module with GEMMAccuracyConfig with DebugLog {
29+
trait DataWidthConfig {
30+
def inputWidth: Int
31+
def outputWidth: Int
32+
}
33+
34+
case object FxpConfig extends DataWidthConfig with GEMMAccuracyConfig {
35+
def inputWidth: Int = I + F
36+
def outputWidth: Int = I + F
37+
}
38+
39+
case object Fp32Config extends DataWidthConfig {
40+
def inputWidth: Int = 32
41+
def outputWidth: Int = 32
42+
}
43+
44+
case object Fp64Config extends DataWidthConfig {
45+
def inputWidth: Int = 64
46+
def outputWidth: Int = 64
47+
}
48+
49+
class PEFxp(implicit config: DataWidthConfig) extends Module with GEMMAccuracyConfig with DebugLog {
3050
val io = IO(new Bundle {
31-
val in_h = Input(UInt((I + F).W))
32-
val in_v = Input(UInt((I + F).W))
33-
val out_h = Output(UInt((I + F).W))
34-
val out_v = Output(UInt((I + F).W))
35-
val out = Output(UInt((2 * (I + F)).W))
51+
val in_h = Input(UInt(config.inputWidth.W))
52+
val in_v = Input(UInt(config.inputWidth.W))
53+
val out_h = Output(UInt(config.inputWidth.W))
54+
val out_v = Output(UInt(config.inputWidth.W))
55+
val out = Output(UInt(config.outputWidth.W))
3656
val reset = Input(Bool())
3757
})
3858

39-
val res = RegInit(0.U((2 * (I + F)).W))
59+
val res = RegInit(0.U(config.outputWidth.W))
4060

4161
when(io.reset) {
4262
res := 0.U
4363
}.otherwise {
4464
val tmp = FxpMulPure(io.in_h, io.in_v)(I, F, I, F)
45-
res := FxpAddPure(res, tmp)(I * 2, F * 2, I * 2, F * 2)
65+
res := FxpAddPure(res, tmp)(I, F, I, F)
4666
}
4767

4868
io.out_h := RegNext(io.in_h)
4969
io.out_v := RegNext(io.in_v)
5070
io.out := res
5171
}
5272

53-
class PEFp(width: Int = 32, size: Int = 4) extends Module with DebugLog {
73+
class PEFp(implicit config: DataWidthConfig) extends Module with DebugLog {
5474
val io = IO(new Bundle {
55-
val in_h = Input(UInt(width.W))
56-
val in_v = Input(UInt(width.W))
57-
val out_h = Output(UInt(width.W))
58-
val out_v = Output(UInt(width.W))
59-
val out = Output(UInt(width.W))
75+
val in_h = Input(UInt(config.inputWidth.W))
76+
val in_v = Input(UInt(config.inputWidth.W))
77+
val out_h = Output(UInt(config.inputWidth.W))
78+
val out_v = Output(UInt(config.inputWidth.W))
79+
val out = Output(UInt(config.outputWidth.W))
6080
val reset = Input(Bool())
6181
})
6282

6383
io.out_h := RegNext(io.in_h)
6484
io.out_v := RegNext(io.in_v)
6585

66-
val res = RegInit(0.U(width.W))
67-
val fpConfig = FPConfig(width)
86+
val res = RegInit(0.U(config.inputWidth.W))
87+
val fpConfig = FPConfig(config.inputWidth)
6888
val FCMAModule = Module(new fudian.FCMA(fpConfig.expWidth, fpConfig.sigWidth))
6989
FCMAModule.io.a := io.in_h
7090
FCMAModule.io.b := io.in_v
@@ -75,26 +95,6 @@ class PEFp(width: Int = 32, size: Int = 4) extends Module with DebugLog {
7595
FCMAModule.io.fflags := DontCare
7696
}
7797

78-
trait DataWidthConfig {
79-
def inputWidth: Int
80-
def outputWidth: Int
81-
}
82-
83-
case object FxpConfig extends DataWidthConfig with GEMMAccuracyConfig {
84-
def inputWidth: Int = I + F
85-
def outputWidth: Int = 2 * (I + F)
86-
}
87-
88-
case object Fp32Config extends DataWidthConfig {
89-
def inputWidth: Int = 32
90-
def outputWidth: Int = 32
91-
}
92-
93-
case object Fp64Config extends DataWidthConfig {
94-
def inputWidth: Int = 64
95-
def outputWidth: Int = 64
96-
}
97-
9898
class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: DataWidthConfig)
9999
extends Module
100100
with GEMMAccuracyConfig
@@ -110,8 +110,8 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit confi
110110
val peElements = VecInit(Seq.fill(n * n) {
111111
gemmType match {
112112
case GEMMDataType.Fxp => Module(new PEFxp).io
113-
case GEMMDataType.Fp32 => Module(new PEFp(config.inputWidth)).io
114-
case GEMMDataType.Fp64 => Module(new PEFp(config.inputWidth)).io
113+
case GEMMDataType.Fp32 => Module(new PEFp).io
114+
case GEMMDataType.Fp64 => Module(new PEFp).io
115115
case _ => throw new IllegalArgumentException("Unsupported GEMM type")
116116
}
117117
})
@@ -210,7 +210,7 @@ class GEMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: Dat
210210
sysmm.io.in_a(i) := matrixAReg(i)(p(log2Ceil(n) - 1, 0))
211211
sysmm.io.in_b(i) := matrixBReg(p(log2Ceil(n) - 1, 0))(i)
212212
}
213-
// debugLog(p"in_a${i}: ${sysmm.io.in_a(i)} in_b${i}: ${sysmm.io.in_b(i)}\t")
213+
debugLog(p"in_a${i}: ${sysmm.io.in_a(i)} in_b${i}: ${sysmm.io.in_b(i)}\t")
214214
}
215215
debugLog(p"\n")
216216
cnt.inc()
@@ -220,13 +220,14 @@ class GEMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: Dat
220220

221221
when(cnt.value === (3 * n - 1).U) {
222222
resValid := true.B
223-
when(io.out.ready) {
223+
// debugLog(p"res: ${sysmm.io.out}\n", LogLevel.DEBUG)
224+
when(resValid && io.out.ready) {
224225
resValid := false.B
225226
busy := false.B
226227
cnt.reset()
227228
sysmm.io.reset := io.reset
228229
}
229230
}
230231

231-
// debugLog(p"busy: ${busy} cnt: ${cnt.value}\n", LogLevel.DEBUG)
232+
// debugLog(p"busy: ${busy} resValid: ${resValid} cnt: ${cnt.value}\n", LogLevel.DEBUG)
232233
}

src/main/scala/models/llama3/metrixController.scala

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import chisel3.util._
66
import kernel.alu.GEMM
77
import kernel.alu.GEMMDataType
88
import kernel.alu.DataWidthConfig
9-
9+
import kernel.utils.DebugLog
1010
class metrixController extends Module with llamaConfig {}
1111

1212
class currentSystolicGroupIdx(
@@ -21,7 +21,7 @@ class currentSystolicGroupIdx(
2121

2222
val row = Output(UInt(log2Ceil(m / nk).W))
2323
val col = Output(UInt(log2Ceil(q / nk).W))
24-
val value = Output(Vec(nk * nk, UInt(config.inputWidth.W)))
24+
val value = Output(UInt((nk * nk * config.inputWidth).W))
2525
}
2626

2727
class MatrixSplit(
@@ -101,17 +101,14 @@ class MatrixRestore(
101101
for (blockRow <- 0 until numBlocksRow) {
102102
for (blockCol <- 0 until numBlocksCol) {
103103
val blockIndex = blockRow * numBlocksCol + blockCol
104-
val block = io.inBlocks(blockIndex)
104+
val block = io.inBlocks(blockIndex).asTypeOf(Vec(nk * nk, UInt(config.inputWidth.W)))
105105

106106
// Unpack the current square block
107-
for (i <- 0 until nk) {
108-
for (j <- 0 until nk) {
109-
// Compute the position in the output vector
110-
val flatIndex = (blockRow * nk + i) * p + (blockCol * nk + j)
111-
// Extract the element at that position from the packed UInt
112-
val elementPos = (nk * nk - 1 - (i * nk + j)) * config.inputWidth
113-
io.outMatrix(flatIndex) := block(elementPos + config.inputWidth - 1, elementPos)
114-
}
107+
for {
108+
i <- 0 until nk
109+
j <- 0 until nk
110+
} {
111+
io.outMatrix((blockRow * nk + i) * p + (blockCol * nk + j)) := block(i * nk + j)
115112
}
116113
}
117114
}
@@ -132,6 +129,32 @@ object MatrixRestore {
132129
}
133130
}
134131

132+
class BlockMatrixRestore(
133+
val nk: Int
134+
)(
135+
implicit config: DataWidthConfig)
136+
extends Module {
137+
val io = IO(new Bundle {
138+
val inBlocks = Input(UInt((nk * nk * config.inputWidth).W))
139+
val outMatrix = Output(Vec(nk * nk, UInt(config.inputWidth.W)))
140+
})
141+
142+
io.outMatrix := io.inBlocks.asTypeOf(Vec(nk * nk, UInt(config.inputWidth.W)))
143+
}
144+
145+
object BlockMatrixRestore {
146+
def apply(
147+
nk: Int
148+
)(inBlocks: UInt
149+
)(
150+
implicit config: DataWidthConfig
151+
): Vec[UInt] = {
152+
val newBlockMatrixRestore = Module(new BlockMatrixRestore(nk))
153+
newBlockMatrixRestore.io.inBlocks := inBlocks
154+
newBlockMatrixRestore.io.outMatrix
155+
}
156+
}
157+
135158
/*
136159
* matrix mul matrix
137160
* matrixA is [m, p]
@@ -149,7 +172,8 @@ class GenerationMatrixMul(
149172
)(
150173
implicit config: DataWidthConfig)
151174
extends Module
152-
with llamaConfig {
175+
with llamaConfig
176+
with DebugLog {
153177
// param check
154178
implicit val nk: Int = k * n
155179
require(m % nk == 0)
@@ -159,14 +183,18 @@ class GenerationMatrixMul(
159183
val io = IO(new Bundle {
160184
val in_a = Flipped(Decoupled(Vec(m * p, UInt(config.inputWidth.W))))
161185
val in_b = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
162-
val result = Decoupled(Vec(m * q, UInt(config.outputWidth.W)))
163186
val current = ValidIO(new currentSystolicGroupIdx(nk, m, p, q))
164187
val reset = Input(Bool())
165188
})
166189

167190
// reshape the input data as block => [rows, cols] [nk, nk]
168-
val matrixAReshape = RegInit(MatrixSplit(m, p, nk)(io.in_a.bits))
169-
val matrixBReshape = RegInit(MatrixSplit(p, q, nk)(io.in_b.bits))
191+
val matrixAReshape = RegInit(VecInit.fill(m / nk * p / nk)(0.U((nk * nk * config.inputWidth).W)))
192+
val matrixBReshape = RegInit(VecInit.fill(p / nk * q / nk)(0.U((nk * nk * config.inputWidth).W)))
193+
matrixAReshape := MatrixSplit(m, p, nk)(io.in_a.bits)
194+
matrixBReshape := MatrixSplit(p, q, nk)(io.in_b.bits)
195+
196+
// debugLog(p"matrixAReshape: ${matrixAReshape}\n", LogLevel.DEBUG)
197+
// debugLog(p"matrixBReshape: ${matrixBReshape}\n", LogLevel.DEBUG)
170198

171199
// systolic alu
172200
val gemmGroup = Module(new GEMM(nk, gemmType))
@@ -183,8 +211,6 @@ class GenerationMatrixMul(
183211
val readyReg = RegInit(true.B)
184212
io.in_a.ready := readyReg
185213
io.in_b.ready := readyReg
186-
val validReg = RegInit(false.B)
187-
io.result.valid := validReg
188214
val dataShapedValid = RegInit(false.B)
189215
gemmGroup.io.in_a.valid := dataShapedValid
190216
gemmGroup.io.in_b.valid := dataShapedValid
@@ -197,9 +223,6 @@ class GenerationMatrixMul(
197223
val gemmInputA = matrixAReshape(blockAIdx).asTypeOf(Vec(nk * nk, UInt(config.inputWidth.W)))
198224
val gemmInputB = matrixBReshape(blockBIdx).asTypeOf(Vec(nk * nk, UInt(config.inputWidth.W)))
199225

200-
val unShapedResult = RegInit(VecInit.fill(m * q / nk / nk)(0.U((nk * nk * config.outputWidth).W)))
201-
io.result.bits := MatrixRestore(m, q, nk)(unShapedResult)
202-
203226
gemmGroup.io.in_a.bits := gemmInputA
204227
gemmGroup.io.in_b.bits := gemmInputB
205228
gemmGroup.io.reset := false.B
@@ -229,6 +252,8 @@ class GenerationMatrixMul(
229252
when(gemmGroup.io.out.valid) {
230253
val isfinal = calTimes.inc()
231254
when(isfinal) {
255+
// When this is the last value, do not consume it
256+
gemmGroup.io.out.ready := false.B
232257
stateReg := state.collect
233258
gemmGroupReady := false.B
234259
}
@@ -242,30 +267,39 @@ class GenerationMatrixMul(
242267
// still has the last gemm block to cal
243268
when(gemmGroup.io.out.valid) {
244269
gemmGroup.io.reset := true.B
270+
gemmGroup.io.out.ready := true.B
271+
245272
// collect the result of the [rowIdx, colIdx] block
246273
val afterRowLine = gemmGroup.io.out.bits
247-
unShapedResult(rowIdx.value * cols.U + colIdx.value) := afterRowLine.asTypeOf(
248-
UInt((nk * nk * config.outputWidth).W)
249-
)
250274

251275
// send the current systolic group idx
252276
io.current.valid := true.B
253277
io.current.bits.row := rowIdx.value
254278
io.current.bits.col := colIdx.value
255-
io.current.bits.value := afterRowLine
279+
io.current.bits.value := afterRowLine.asUInt
256280

257281
val isRowEnd = colIdx.inc()
258-
stateReg := Mux(rowIdx.inc() && isRowEnd, state.done, state.cal)
282+
when(!isRowEnd) {
283+
stateReg := state.cal
284+
}.otherwise {
285+
val isAllEnd = rowIdx.inc()
286+
when(!isAllEnd) {
287+
stateReg := state.cal
288+
}.otherwise {
289+
dataShapedValid := false.B
290+
stateReg := state.done
291+
}
292+
}
259293
}
260294
}
261295

262296
is(state.done) {
263-
validReg := true.B
264-
when(io.result.ready) {
265-
validReg := false.B
266-
stateReg := state.idle
267-
readyReg := true.B
268-
}
297+
stateReg := state.idle
298+
readyReg := true.B
269299
}
270300
}
301+
302+
debugLog(
303+
p"stateReg: $stateReg,\t currentValid: ${io.current.valid},\t rowIdx: ${rowIdx.value},\t colIdx: ${colIdx.value},\t gemmValid: ${gemmGroup.io.out.valid}\n"
304+
)
271305
}

src/test/scala/kernel/alu/GEMMTest.scala

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,10 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestE
8282
}.fork {
8383
var resC = 0
8484
while (resC < arraySize) {
85+
dut.io.reset.poke(false.B)
8586
if (dut.io.out.valid.peekBoolean()) {
8687
dut.io.out.ready.poke(true.B)
88+
dut.io.reset.poke(true.B)
8789
val out = checkresult()
8890
var invalidcnt = 0
8991
for (i <- out.zip(matrixYArray(resC).flatten.toList)) {
@@ -342,27 +344,27 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestE
342344
// test(new PEFp(32, 4)).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testPEFp)
343345
// }
344346

345-
"SystolicMM basic test on Verilator" should "pass" in {
346-
implicit val fxpConfig: DataWidthConfig = FxpConfig
347-
test(new SystolicMM(4, GEMMDataType.Fxp))
348-
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMM)
349-
}
347+
// "SystolicMM basic test on Verilator" should "pass" in {
348+
// implicit val fxpConfig: DataWidthConfig = FxpConfig
349+
// test(new SystolicMM(4, GEMMDataType.Fxp))
350+
// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMM)
351+
// }
350352

351-
"SystolicMMFp basic test on Verilator" should "pass" in {
352-
implicit val fxpConfig: DataWidthConfig = Fp32Config
353-
test(new SystolicMM(4, GEMMDataType.Fp32))
354-
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMMFp)
355-
}
353+
// "SystolicMMFp basic test on Verilator" should "pass" in {
354+
// implicit val fxpConfig: DataWidthConfig = Fp32Config
355+
// test(new SystolicMM(4, GEMMDataType.Fp32))
356+
// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMMFp)
357+
// }
356358

357359
"GeMM basic test on Verilator" should "pass" in {
358360
implicit val fxpConfig: DataWidthConfig = FxpConfig
359361
test(new GEMM(6, GEMMDataType.Fxp))
360362
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMM)
361363
}
362364

363-
"GeMMFp basic test on Verilator" should "pass" in {
364-
implicit val fxpConfig: DataWidthConfig = Fp32Config
365-
test(new GEMM(6, GEMMDataType.Fp32))
366-
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMMFp)
367-
}
365+
// "GeMMFp basic test on Verilator" should "pass" in {
366+
// implicit val fxpConfig: DataWidthConfig = Fp32Config
367+
// test(new GEMM(6, GEMMDataType.Fp32))
368+
// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMMFp)
369+
// }
368370
}

0 commit comments

Comments
 (0)