Skip to content

Commit 123050b

Browse files
add gemmtest with paralleltest
1 parent 71443ab commit 123050b

File tree

3 files changed

+156
-116
lines changed

3 files changed

+156
-116
lines changed

build.sbt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ lazy val commonChiselSettings = Seq(
99
libraryDependencies ++= Seq(
1010
"org.chipsalliance" %% "chisel" % chiselVersion,
1111
"edu.berkeley.cs" %% "chiseltest" % "6.0.0",
12+
"org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4"
1213
),
1314
resolvers += "huaweiyun".at("https://repo.huaweicloud.com/repository/maven/"),
1415
scalacOptions ++= Seq(
@@ -49,4 +50,4 @@ lazy val hardfloat = Project("hardfloat", file("dependencies/hardfloat/hardfloat
4950
.settings(
5051
Compile / scalaSource := baseDirectory.value / "main" / "scala",
5152
Compile / resourceDirectory := baseDirectory.value / "main" / "resources"
52-
)
53+
)

src/main/scala/kernel/alu/Gemm.scala

Lines changed: 74 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ case class FPConfig(width: Int) {
2121
fpParams.getOrElse(width, throw new IllegalArgumentException(s"Unsupported floating point width: $width"))
2222
}
2323

24-
object GEMMType extends ChiselEnum {
24+
object GEMMDataType extends ChiselEnum {
2525
// Element data types: unsigned integer (UI), fixed point (Fxp), 32-bit floating point (Fp32), 64-bit floating point (Fp64)
2626
val UI, Fxp, Fp32, Fp64 = Value
2727
}
@@ -75,84 +75,6 @@ class PEFp(width: Int = 32, size: Int = 4) extends Module with DebugLog {
7575
FCMAModule.io.fflags := DontCare
7676
}
7777

78-
// Computes A * B, where A and B are both n x n square matrices.
79-
class GEMM(val n: Int = 4, val gemmType: GEMMType.Type)(implicit config: DataTypeConfig)
80-
extends Module
81-
with GEMMAccuracyConfig
82-
with DebugLog {
83-
84-
val io = IO(new Bundle {
85-
val in_a = Flipped(Decoupled(Vec(n, Vec(n, UInt(config.inputWidth.W)))))
86-
val in_b = Flipped(Decoupled(Vec(n, Vec(n, UInt(config.inputWidth.W)))))
87-
val out = Decoupled(Vec(n * n, UInt(config.outputWidth.W)))
88-
val reset = Input(Bool())
89-
})
90-
91-
// accumulate mode
92-
val accMode = IO(Input(Bool()))
93-
val accReg = RegInit(false.B)
94-
95-
val dataValid = io.in_a.valid && io.in_b.valid
96-
97-
val busy = RegInit(false.B)
98-
99-
io.in_a.ready := !busy
100-
io.in_b.ready := !busy
101-
102-
val matrixAReg = RegInit(VecInit.fill(n)(VecInit.fill(n)(0.U(config.inputWidth.W))))
103-
val matrixBReg = RegInit(VecInit.fill(n)(VecInit.fill(n)(0.U(config.inputWidth.W))))
104-
105-
val sysmm = Module(new SystolicMM(n, gemmType))
106-
sysmm.io.reset := false.B
107-
for (i <- 0 until n) {
108-
sysmm.io.in_a(i) := 0.U
109-
sysmm.io.in_b(i) := 0.U
110-
}
111-
112-
when(dataValid) {
113-
for (i <- 0 until n) {
114-
for (j <- 0 until n) {
115-
matrixAReg(i)(j) := io.in_a.bits(i)(j)
116-
matrixBReg(i)(j) := io.in_b.bits(i)(j)
117-
}
118-
}
119-
busy := true.B
120-
}
121-
122-
val resValid = RegInit(false.B)
123-
io.out.valid := resValid
124-
io.out.bits := sysmm.io.out
125-
126-
val cnt = Counter(3 * n)
127-
when(busy && cnt.value < (2 * n).U) {
128-
for (i <- 0 until n) {
129-
val temp = cnt.value >= i.U
130-
val p = Mux(temp, cnt.value - i.U, 0.U)
131-
when(temp && p < n.U) {
132-
sysmm.io.in_a(i) := matrixAReg(i)(p(log2Ceil(n) - 1, 0))
133-
sysmm.io.in_b(i) := matrixBReg(p(log2Ceil(n) - 1, 0))(i)
134-
}
135-
// debugLog(p"in_a${i}: ${sysmm.io.in_a(i)} in_b${i}: ${sysmm.io.in_b(i)}\t")
136-
}
137-
debugLog(p"\n")
138-
cnt.inc()
139-
}.elsewhen(busy && cnt.value < (3 * n - 1).U) {
140-
cnt.inc()
141-
}
142-
143-
when(cnt.value === (3 * n - 1).U) {
144-
resValid := true.B
145-
when(io.out.ready) {
146-
resValid := false.B
147-
busy := false.B
148-
cnt.reset()
149-
sysmm.io.reset := true.B
150-
}
151-
}
152-
153-
// debugLog(p"busy: ${busy} cnt: ${cnt.value}\n", LogLevel.DEBUG)
154-
}
155-
15678
trait DataTypeConfig {
15779
def inputWidth: Int
15880
def outputWidth: Int
@@ -173,7 +95,7 @@ case object Fp64Config extends DataTypeConfig {
17395
def outputWidth: Int = 64
17496
}
17597

176-
class SystolicMM(val n: Int = 4, val gemmType: GEMMType.Type)(implicit config: DataTypeConfig)
98+
class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: DataTypeConfig)
17799
extends Module
178100
with GEMMAccuracyConfig
179101
with DebugLog {
@@ -187,16 +109,14 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMType.Type)(implicit config: D
187109

188110
val peElements = VecInit(Seq.fill(n * n) {
189111
gemmType match {
190-
case GEMMType.Fxp => Module(new PEFxp).io
191-
case GEMMType.Fp32 => Module(new PEFp(config.inputWidth)).io
192-
case GEMMType.Fp64 => Module(new PEFp(config.inputWidth)).io
112+
case GEMMDataType.Fxp => Module(new PEFxp).io
113+
case GEMMDataType.Fp32 => Module(new PEFp(config.inputWidth)).io
114+
case GEMMDataType.Fp64 => Module(new PEFp(config.inputWidth)).io
193115
case _ => throw new IllegalArgumentException("Unsupported GEMM type")
194116
}
195117
})
196118

197-
for (i <- 0 until n * n) {
198-
peElements(i).reset := io.reset
199-
}
119+
peElements.foreach(_.reset := io.reset)
200120

201121
val h_wires = Wire(Vec((n - 1) * n, UInt(config.inputWidth.W)))
202122
val v_wires = Wire(Vec(n * (n - 1), UInt(config.inputWidth.W)))
@@ -237,34 +157,80 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMType.Type)(implicit config: D
237157
}
238158
}
239159

240-
// each ProcElem (PE) is mapped to each element in a NxN output matrix
241-
class ProcElem(val bits: Int = 8) extends Module {
242-
val io = IO(new Bundle {
243-
// input from horizontal direction
244-
val in_h = Input(UInt(bits.W))
245-
// input from vertical direction
246-
val in_v = Input(UInt(bits.W))
247-
// output to horizontal direction
248-
val out_h = Output(UInt((bits * 2).W))
249-
// output to vertical direction
250-
val out_v = Output(UInt((bits * 2).W))
251-
// the result after N cycles once this receives the first actual data
252-
val out = Output(UInt((bits * 2).W))
160+
// Computes A * B, where A and B are both n x n square matrices.
161+
class GEMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: DataTypeConfig)
162+
extends Module
163+
with GEMMAccuracyConfig
164+
with DebugLog {
253165

166+
val io = IO(new Bundle {
167+
val in_a = Flipped(Decoupled(Vec(n, Vec(n, UInt(config.inputWidth.W)))))
168+
val in_b = Flipped(Decoupled(Vec(n, Vec(n, UInt(config.inputWidth.W)))))
169+
val out = Decoupled(Vec(n * n, UInt(config.outputWidth.W)))
254170
val reset = Input(Bool())
255171
})
256172

257-
val res = RegInit(0.U((bits * 2).W))
173+
// accumulate mode
174+
val accMode = IO(Input(Bool()))
175+
val accReg = RegInit(false.B)
258176

259-
when(io.reset) {
260-
res := 0.U
177+
val dataValid = io.in_a.valid && io.in_b.valid
178+
179+
val busy = RegInit(false.B)
180+
181+
io.in_a.ready := !busy
182+
io.in_b.ready := !busy
183+
184+
val matrixAReg = RegInit(VecInit.fill(n)(VecInit.fill(n)(0.U(config.inputWidth.W))))
185+
val matrixBReg = RegInit(VecInit.fill(n)(VecInit.fill(n)(0.U(config.inputWidth.W))))
186+
187+
val sysmm = Module(new SystolicMM(n, gemmType))
188+
sysmm.io.reset := false.B
189+
for (i <- 0 until n) {
190+
sysmm.io.in_a(i) := 0.U
191+
sysmm.io.in_b(i) := 0.U
261192
}
262-
// this is the main computation part
263-
res := res + (io.in_h * io.in_v)
264193

265-
// inputs are delayed one cycle to next PEs
266-
io.out_h := RegNext(io.in_h)
267-
io.out_v := RegNext(io.in_v)
194+
when(dataValid) {
195+
for (i <- 0 until n) {
196+
for (j <- 0 until n) {
197+
matrixAReg(i)(j) := io.in_a.bits(i)(j)
198+
matrixBReg(i)(j) := io.in_b.bits(i)(j)
199+
}
200+
}
201+
busy := true.B
202+
}
268203

269-
io.out := res
204+
val resValid = RegInit(false.B)
205+
io.out.valid := resValid
206+
io.out.bits := sysmm.io.out
207+
208+
val cnt = Counter(3 * n)
209+
when(busy && cnt.value < (2 * n).U) {
210+
for (i <- 0 until n) {
211+
val temp = cnt.value >= i.U
212+
val p = Mux(temp, cnt.value - i.U, 0.U)
213+
when(temp && p < n.U) {
214+
sysmm.io.in_a(i) := matrixAReg(i)(p(log2Ceil(n) - 1, 0))
215+
sysmm.io.in_b(i) := matrixBReg(p(log2Ceil(n) - 1, 0))(i)
216+
}
217+
// debugLog(p"in_a${i}: ${sysmm.io.in_a(i)} in_b${i}: ${sysmm.io.in_b(i)}\t")
218+
}
219+
debugLog(p"\n")
220+
cnt.inc()
221+
}.elsewhen(busy && cnt.value < (3 * n - 1).U) {
222+
cnt.inc()
223+
}
224+
225+
when(cnt.value === (3 * n - 1).U) {
226+
resValid := true.B
227+
when(io.out.ready) {
228+
resValid := false.B
229+
busy := false.B
230+
cnt.reset()
231+
sysmm.io.reset := true.B
232+
}
233+
}
234+
235+
// debugLog(p"busy: ${busy} cnt: ${cnt.value}\n", LogLevel.DEBUG)
270236
}

src/test/scala/kernel/alu/GEMMTest.scala

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import chisel3._
44
import chiseltest._
55
import org.scalatest.Tag
66
import org.scalatest.flatspec.AnyFlatSpec
7+
import org.scalatest.ParallelTestExecution
78

8-
class GEMMTest extends AnyFlatSpec with ChiselScalatestTester {
9+
class GEMMTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution {
910

1011
def mmul(a: Array[Array[Float]], b: Array[Array[Float]]): Array[Array[Float]] = {
1112
for (r <- a) yield {
@@ -104,6 +105,70 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester {
104105

105106
}
106107

108+
private def testGEMMFp(dut: GEMM) = {
109+
val n = dut.n
110+
111+
val arraySize = 10
112+
val matrixAArray = Array.tabulate(arraySize)(_ => matInit(n))
113+
val matrixBArray = Array.tabulate(arraySize)(_ => matInit(n))
114+
val matrixYArray = matrixAArray.zip(matrixBArray).map {
115+
case (a, b) => mmul(a, b)
116+
}
117+
118+
def checkresult(): List[Float] = {
119+
val ret = for (j <- 0 until n * n) yield {
120+
val out = java.lang.Float.intBitsToFloat(dut.io.out.bits(j).peekInt().toInt)
121+
print(f"${out}%.4f ")
122+
out.toFloat // out is already a Float decoded from the peeked raw bits, so toFloat is a no-op
123+
}
124+
println()
125+
ret.toList
126+
}
127+
128+
fork {
129+
var c = 0;
130+
while (c < arraySize) {
131+
if (dut.io.in_a.ready.peekBoolean() && dut.io.in_b.ready.peekBoolean()) {
132+
dut.io.in_a.valid.poke(true.B)
133+
dut.io.in_b.valid.poke(true.B)
134+
for (i <- 0 until n) {
135+
for (j <- 0 until n) {
136+
dut.io.in_a.bits(i)(j).poke(BigInt(java.lang.Float.floatToRawIntBits(matrixAArray(c)(i)(j)).toBinaryString, 2).U)
137+
dut.io.in_b.bits(i)(j).poke(BigInt(java.lang.Float.floatToRawIntBits(matrixBArray(c)(i)(j)).toBinaryString, 2).U)
138+
}
139+
}
140+
c += 1
141+
} else {
142+
dut.io.in_a.valid.poke(false.B)
143+
dut.io.in_b.valid.poke(false.B)
144+
}
145+
dut.clock.step()
146+
}
147+
}.fork {
148+
var resC = 0
149+
while (resC < arraySize) {
150+
if (dut.io.out.valid.peekBoolean()) {
151+
dut.io.out.ready.poke(true.B)
152+
val out = checkresult()
153+
var invalidcnt = 0
154+
for (i <- out.zip(matrixYArray(resC).flatten.toList)) {
155+
if (math.abs(i._1 - i._2) > precision) {
156+
println("Error: " + i._1 + " " + i._2)
157+
invalidcnt += 1
158+
}
159+
}
160+
if (invalidcnt == 0) println("GEMM Verification passed!")
161+
assert(invalidcnt == 0)
162+
resC += 1
163+
} else {
164+
dut.io.out.ready.poke(false.B)
165+
}
166+
dut.clock.step()
167+
}
168+
169+
}.join()
170+
171+
}
107172
private def testSystolicMM(dut: SystolicMM): Unit = {
108173
val n = dut.n
109174
val a = matInit(n)
@@ -161,7 +226,6 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester {
161226
val a = matInit(n)
162227
val b = matInit(n)
163228
val y = mmul(a, b)
164-
println(s"type: ${dut.gemmType}")
165229
printmat(a)
166230
printmat(b)
167231
printmat(y)
@@ -276,15 +340,24 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester {
276340

277341
"SystolicMM basic test on Verilator" should "pass" in {
278342
implicit val fxpConfig: DataTypeConfig = FxpConfig
279-
test(new SystolicMM(4, GEMMType.Fxp)).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMM)
343+
test(new SystolicMM(4, GEMMDataType.Fxp))
344+
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMM)
280345
}
281346

282347
"SystolicMMFp basic test on Verilator" should "pass" in {
283348
implicit val fxpConfig: DataTypeConfig = Fp32Config
284-
test(new SystolicMM(4, GEMMType.Fp32))
349+
test(new SystolicMM(4, GEMMDataType.Fp32))
285350
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMMFp)
286351
}
287-
// "GeMM basic test on Verilator" should "pass" in {
288-
// test(new GEMM(6)).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMM)
289-
// }
352+
"GeMM basic test on Verilator" should "pass" in {
353+
implicit val fxpConfig: DataTypeConfig = FxpConfig
354+
test(new GEMM(6, GEMMDataType.Fxp))
355+
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMM)
356+
}
357+
358+
"GeMMFp basic test on Verilator" should "pass" in {
359+
implicit val fxpConfig: DataTypeConfig = Fp32Config
360+
test(new GEMM(6, GEMMDataType.Fp32))
361+
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMMFp)
362+
}
290363
}

0 commit comments

Comments
 (0)