Skip to content

Commit ee711c5

Browse files
fix gemm precision
1 parent 01dd47b commit ee711c5

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/main/scala/kernel/alu/Gemm.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import chisel3.util._
55
import kernel.utils.DebugLog
66

77
trait GEMMAccuracyConfig {
8-
val I: Int = 4
9-
val F: Int = 12
8+
val I: Int = 8
9+
val F: Int = 24
1010
}
1111

1212
class PEFxp extends Module with GEMMAccuracyConfig with DebugLog {

src/test/scala/kernel/alu/GEMMTest.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester {
1212
for (c <- b.transpose) yield r.zip(c).map(Function.tupled(_ * _)).reduceLeft(_ + _)
1313
}
1414
}
15+
16+
val precision = 0.001f
1517
// n * n
1618
def matInit(n: Int): Array[Array[Float]] = {
1719
val rseed = System.currentTimeMillis().toInt
@@ -84,7 +86,7 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester {
8486
val out = checkresult()
8587
var invalidcnt = 0
8688
for (i <- out.zip(matrixYArray(resC).flatten.toList)) {
87-
if (math.abs(i._1 - i._2) > 0.0001) {
89+
if (math.abs(i._1 - i._2) > precision) {
8890
println("Error: " + i._1 + " " + i._2)
8991
invalidcnt += 1
9092
}
@@ -143,7 +145,7 @@ class GEMMTest extends AnyFlatSpec with ChiselScalatestTester {
143145

144146
var invalidcnt = 0
145147
for (i <- output.zip(y.flatten.toList)) {
146-
if (math.abs(i._1 - i._2) > 0.0001) {
148+
if (math.abs(i._1 - i._2) > precision) {
147149
println("Error: " + i._1 + " " + i._2)
148150
invalidcnt += 1
149151
}

0 commit comments

Comments
 (0)