@@ -75,27 +75,27 @@ class PEFp(width: Int = 32, size: Int = 4) extends Module with DebugLog {
7575 FCMAModule .io.fflags := DontCare
7676}
7777
78- trait DataTypeConfig {
78+ trait DataWidthConfig {
7979 def inputWidth : Int
8080 def outputWidth : Int
8181}
8282
83- case object FxpConfig extends DataTypeConfig with GEMMAccuracyConfig {
83+ case object FxpConfig extends DataWidthConfig with GEMMAccuracyConfig {
8484 def inputWidth : Int = I + F
8585 def outputWidth : Int = 2 * (I + F )
8686}
8787
88- case object Fp32Config extends DataTypeConfig {
88+ case object Fp32Config extends DataWidthConfig {
8989 def inputWidth : Int = 32
9090 def outputWidth : Int = 32
9191}
9292
93- case object Fp64Config extends DataTypeConfig {
93+ case object Fp64Config extends DataWidthConfig {
9494 def inputWidth : Int = 64
9595 def outputWidth : Int = 64
9696}
9797
98- class SystolicMM (val n : Int = 4 , val gemmType : GEMMDataType .Type )(implicit config : DataTypeConfig )
98+ class SystolicMM (val n : Int = 4 , val gemmType : GEMMDataType .Type )(implicit config : DataWidthConfig )
9999 extends Module
100100 with GEMMAccuracyConfig
101101 with DebugLog {
@@ -112,7 +112,7 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit confi
112112 case GEMMDataType .Fxp => Module (new PEFxp ).io
113113 case GEMMDataType .Fp32 => Module (new PEFp (config.inputWidth)).io
114114 case GEMMDataType .Fp64 => Module (new PEFp (config.inputWidth)).io
115- case _ => throw new IllegalArgumentException (" Unsupported GEMM type" )
115+ case _ => throw new IllegalArgumentException (" Unsupported GEMM type" )
116116 }
117117 })
118118
@@ -158,14 +158,14 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit confi
158158}
159159
160160// Compute A * B, where A and B are both square matrix.
161- class GEMM (val n : Int = 4 , val gemmType : GEMMDataType .Type )(implicit config : DataTypeConfig )
161+ class GEMM (val n : Int = 4 , val gemmType : GEMMDataType .Type )(implicit config : DataWidthConfig )
162162 extends Module
163163 with GEMMAccuracyConfig
164164 with DebugLog {
165165
166166 val io = IO (new Bundle {
167- val in_a = Flipped (Decoupled (Vec (n, Vec ( n, UInt (config.inputWidth.W ) ))))
168- val in_b = Flipped (Decoupled (Vec (n, Vec ( n, UInt (config.inputWidth.W ) ))))
167+ val in_a = Flipped (Decoupled (Vec (n * n, UInt (config.inputWidth.W ))))
168+ val in_b = Flipped (Decoupled (Vec (n * n, UInt (config.inputWidth.W ))))
169169 val out = Decoupled (Vec (n * n, UInt (config.outputWidth.W )))
170170 val reset = Input (Bool ())
171171 })
@@ -194,8 +194,8 @@ class GEMM(val n: Int = 4, val gemmType: GEMMDataType.Type)(implicit config: Dat
194194 when(dataValid) {
195195 for (i <- 0 until n) {
196196 for (j <- 0 until n) {
197- matrixAReg(i)(j) := io.in_a.bits(i)( j)
198- matrixBReg(i)(j) := io.in_b.bits(i)( j)
197+ matrixAReg(i)(j) := io.in_a.bits(i * n + j)
198+ matrixBReg(i)(j) := io.in_b.bits(i * n + j)
199199 }
200200 }
201201 busy := true .B
0 commit comments