Skip to content

Commit 7210317

Browse files
remove unused FPU and add new GEMM input io
1 parent 4603e3d commit 7210317

File tree

4 files changed

+24
-204
lines changed

4 files changed

+24
-204
lines changed

src/main/scala/kernel/alu/FPU.scala

Lines changed: 0 additions & 166 deletions
This file was deleted.

src/main/scala/kernel/alu/Gemm.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,27 +75,27 @@ class PEFp(width: Int = 32, size: Int = 4) extends Module with DebugLog {
7575
FCMAModule.io.fflags := DontCare
7676
}
7777

78-
trait DataTypeConfig {
78+
trait DataWidthConfig {
7979
def inputWidth: Int
8080
def outputWidth: Int
8181
}
8282

83-
case object FxpConfig extends DataTypeConfig with GEMMAccuracyConfig {
83+
case object FxpConfig extends DataWidthConfig with GEMMAccuracyConfig {
8484
def inputWidth: Int = I + F
8585
def outputWidth: Int = 2 * (I + F)
8686
}
8787

88-
case object Fp32Config extends DataTypeConfig {
88+
case object Fp32Config extends DataWidthConfig {
8989
def inputWidth: Int = 32
9090
def outputWidth: Int = 32
9191
}
9292

93-
case object Fp64Config extends DataTypeConfig {
93+
case object Fp64Config extends DataWidthConfig {
9494
def inputWidth: Int = 64
9595
def outputWidth: Int = 64
9696
}
9797

98-
class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: DataTypeConfig)
98+
class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: DataWidthConfig)
9999
extends Module
100100
with GEMMAccuracyConfig
101101
with DebugLog {
@@ -112,7 +112,7 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit confi
112112
case GEMMDataType.Fxp => Module(new PEFxp).io
113113
case GEMMDataType.Fp32 => Module(new PEFp(config.inputWidth)).io
114114
case GEMMDataType.Fp64 => Module(new PEFp(config.inputWidth)).io
115-
case _ => throw new IllegalArgumentException("Unsupported GEMM type")
115+
case _ => throw new IllegalArgumentException("Unsupported GEMM type")
116116
}
117117
})
118118

@@ -158,14 +158,14 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit confi
158158
}
159159

160160
// Compute A * B, where A and B are both square matrix.
161-
class GEMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: DataTypeConfig)
161+
class GEMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: DataWidthConfig)
162162
extends Module
163163
with GEMMAccuracyConfig
164164
with DebugLog {
165165

166166
val io = IO(new Bundle {
167-
val in_a = Flipped(Decoupled(Vec(n, Vec(n, UInt(config.inputWidth.W)))))
168-
val in_b = Flipped(Decoupled(Vec(n, Vec(n, UInt(config.inputWidth.W)))))
167+
val in_a = Flipped(Decoupled(Vec(n * n, UInt(config.inputWidth.W))))
168+
val in_b = Flipped(Decoupled(Vec(n * n, UInt(config.inputWidth.W))))
169169
val out = Decoupled(Vec(n * n, UInt(config.outputWidth.W)))
170170
val reset = Input(Bool())
171171
})
@@ -194,8 +194,8 @@ class GEMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: Dat
194194
when(dataValid) {
195195
for (i <- 0 until n) {
196196
for (j <- 0 until n) {
197-
matrixAReg(i)(j) := io.in_a.bits(i)(j)
198-
matrixBReg(i)(j) := io.in_b.bits(i)(j)
197+
matrixAReg(i)(j) := io.in_a.bits(i * n + j)
198+
matrixBReg(i)(j) := io.in_b.bits(i * n + j)
199199
}
200200
}
201201
busy := true.B

src/test/scala/kernel/alu/FPUTest.scala

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/test/scala/kernel/alu/GEMMTest.scala

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestE
6868
dut.io.in_b.valid.poke(true.B)
6969
for (i <- 0 until n) {
7070
for (j <- 0 until n) {
71-
dut.io.in_a.bits(i)(j).poke(doubleToFixedPoint(matrixAArray(c)(i)(j), dut.I, dut.F))
72-
dut.io.in_b.bits(i)(j).poke(doubleToFixedPoint(matrixBArray(c)(i)(j), dut.I, dut.F))
71+
dut.io.in_a.bits(i * n + j).poke(doubleToFixedPoint(matrixAArray(c)(i)(j), dut.I, dut.F))
72+
dut.io.in_b.bits(i * n + j).poke(doubleToFixedPoint(matrixBArray(c)(i)(j), dut.I, dut.F))
7373
}
7474
}
7575
c += 1
@@ -133,8 +133,12 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestE
133133
dut.io.in_b.valid.poke(true.B)
134134
for (i <- 0 until n) {
135135
for (j <- 0 until n) {
136-
dut.io.in_a.bits(i)(j).poke(BigInt(java.lang.Float.floatToRawIntBits(matrixAArray(c)(i)(j)).toBinaryString, 2).U)
137-
dut.io.in_b.bits(i)(j).poke(BigInt(java.lang.Float.floatToRawIntBits(matrixBArray(c)(i)(j)).toBinaryString, 2).U)
136+
dut.io.in_a
137+
.bits(i * n + j)
138+
.poke(BigInt(java.lang.Float.floatToRawIntBits(matrixAArray(c)(i)(j)).toBinaryString, 2).U)
139+
dut.io.in_b
140+
.bits(i * n + j)
141+
.poke(BigInt(java.lang.Float.floatToRawIntBits(matrixBArray(c)(i)(j)).toBinaryString, 2).U)
138142
}
139143
}
140144
c += 1
@@ -339,24 +343,25 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestE
339343
// }
340344

341345
"SystolicMM basic test on Verilator" should "pass" in {
342-
implicit val fxpConfig: DataTypeConfig = FxpConfig
346+
implicit val fxpConfig: DataWidthConfig = FxpConfig
343347
test(new SystolicMM(4, GEMMDataType.Fxp))
344348
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMM)
345349
}
346350

347351
"SystolicMMFp basic test on Verilator" should "pass" in {
348-
implicit val fxpConfig: DataTypeConfig = Fp32Config
352+
implicit val fxpConfig: DataWidthConfig = Fp32Config
349353
test(new SystolicMM(4, GEMMDataType.Fp32))
350354
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testSystolicMMFp)
351355
}
356+
352357
"GeMM basic test on Verilator" should "pass" in {
353-
implicit val fxpConfig: DataTypeConfig = FxpConfig
358+
implicit val fxpConfig: DataWidthConfig = FxpConfig
354359
test(new GEMM(6, GEMMDataType.Fxp))
355360
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMM)
356361
}
357362

358363
"GeMMFp basic test on Verilator" should "pass" in {
359-
implicit val fxpConfig: DataTypeConfig = Fp32Config
364+
implicit val fxpConfig: DataWidthConfig = Fp32Config
360365
test(new GEMM(6, GEMMDataType.Fp32))
361366
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation))(testGEMMFp)
362367
}

0 commit comments

Comments
 (0)