@@ -21,7 +21,7 @@ case class FPConfig(width: Int) {
2121 fpParams.getOrElse(width, throw new IllegalArgumentException (s " Unsupported floating point width: $width" ))
2222}
2323
24- object GEMMType extends ChiselEnum {
24+ object GEMMDataType extends ChiselEnum {
2525 // UInt, FixedPoint, FloatPoint(32), FloatPoint(64)
2626 val UI, Fxp, Fp32, Fp64 = Value
2727}
@@ -75,84 +75,6 @@ class PEFp(width: Int = 32, size: Int = 4) extends Module with DebugLog {
7575 FCMAModule .io.fflags := DontCare
7676}
7777
78- // Compute A * B, where A and B are both square (n x n) matrices.
79- class GEMM (val n : Int = 4 , val gemmType : GEMMType .Type )(implicit config : DataTypeConfig )
80- extends Module
81- with GEMMAccuracyConfig
82- with DebugLog {
83-
84- val io = IO (new Bundle {
85- val in_a = Flipped (Decoupled (Vec (n, Vec (n, UInt (config.inputWidth.W )))))
86- val in_b = Flipped (Decoupled (Vec (n, Vec (n, UInt (config.inputWidth.W )))))
87- val out = Decoupled (Vec (n * n, UInt (config.outputWidth.W )))
88- val reset = Input (Bool ())
89- })
90-
91- // accumulate mode
92- val accMode = IO (Input (Bool ()))
93- val accReg = RegInit (false .B )
94-
95- val dataValid = io.in_a.valid && io.in_b.valid
96-
97- val busy = RegInit (false .B )
98-
99- io.in_a.ready := ! busy
100- io.in_b.ready := ! busy
101-
102- val matrixAReg = RegInit (VecInit .fill(n)(VecInit .fill(n)(0 .U (config.inputWidth.W ))))
103- val matrixBReg = RegInit (VecInit .fill(n)(VecInit .fill(n)(0 .U (config.inputWidth.W ))))
104-
105- val sysmm = Module (new SystolicMM (n, gemmType))
106- sysmm.io.reset := false .B
107- for (i <- 0 until n) {
108- sysmm.io.in_a(i) := 0 .U
109- sysmm.io.in_b(i) := 0 .U
110- }
111-
112- when(dataValid) {
113- for (i <- 0 until n) {
114- for (j <- 0 until n) {
115- matrixAReg(i)(j) := io.in_a.bits(i)(j)
116- matrixBReg(i)(j) := io.in_b.bits(i)(j)
117- }
118- }
119- busy := true .B
120- }
121-
122- val resValid = RegInit (false .B )
123- io.out.valid := resValid
124- io.out.bits := sysmm.io.out
125-
126- val cnt = Counter (3 * n)
127- when(busy && cnt.value < (2 * n).U ) {
128- for (i <- 0 until n) {
129- val temp = cnt.value >= i.U
130- val p = Mux (temp, cnt.value - i.U , 0 .U )
131- when(temp && p < n.U ) {
132- sysmm.io.in_a(i) := matrixAReg(i)(p(log2Ceil(n) - 1 , 0 ))
133- sysmm.io.in_b(i) := matrixBReg(p(log2Ceil(n) - 1 , 0 ))(i)
134- }
135- // debugLog(p"in_a${i}: ${sysmm.io.in_a(i)} in_b${i}: ${sysmm.io.in_b(i)}\t")
136- }
137- debugLog(p " \n " )
138- cnt.inc()
139- }.elsewhen(busy && cnt.value < (3 * n - 1 ).U ) {
140- cnt.inc()
141- }
142-
143- when(cnt.value === (3 * n - 1 ).U ) {
144- resValid := true .B
145- when(io.out.ready) {
146- resValid := false .B
147- busy := false .B
148- cnt.reset()
149- sysmm.io.reset := true .B
150- }
151- }
152-
153- // debugLog(p"busy: ${busy} cnt: ${cnt.value}\n", LogLevel.DEBUG)
154- }
155-
15678trait DataTypeConfig {
15779 def inputWidth : Int
15880 def outputWidth : Int
@@ -173,7 +95,7 @@ case object Fp64Config extends DataTypeConfig {
17395 def outputWidth : Int = 64
17496}
17597
176- class SystolicMM (val n : Int = 4 , val gemmType : GEMMType .Type )(implicit config : DataTypeConfig )
98+ class SystolicMM (val n : Int = 4 , val gemmType : GEMMDataType .Type )(implicit config : DataTypeConfig )
17799 extends Module
178100 with GEMMAccuracyConfig
179101 with DebugLog {
@@ -187,16 +109,14 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMType.Type)(implicit config: D
187109
188110 val peElements = VecInit (Seq .fill(n * n) {
189111 gemmType match {
190- case GEMMType .Fxp => Module (new PEFxp ).io
191- case GEMMType .Fp32 => Module (new PEFp (config.inputWidth)).io
192- case GEMMType .Fp64 => Module (new PEFp (config.inputWidth)).io
112+ case GEMMDataType .Fxp => Module (new PEFxp ).io
113+ case GEMMDataType .Fp32 => Module (new PEFp (config.inputWidth)).io
114+ case GEMMDataType .Fp64 => Module (new PEFp (config.inputWidth)).io
193115 case _ => throw new IllegalArgumentException (" Unsupported GEMM type" )
194116 }
195117 })
196118
197- for (i <- 0 until n * n) {
198- peElements(i).reset := io.reset
199- }
119+ peElements.foreach(_.reset := io.reset)
200120
201121 val h_wires = Wire (Vec ((n - 1 ) * n, UInt (config.inputWidth.W )))
202122 val v_wires = Wire (Vec (n * (n - 1 ), UInt (config.inputWidth.W )))
@@ -237,34 +157,80 @@ class SystolicMM(val n: Int = 4, val gemmType: GEMMType.Type)(implicit config: D
237157 }
238158}
239159
240- // each ProcElem (PE) is mapped to each element in a NxN output matrix
241- class ProcElem (val bits : Int = 8 ) extends Module {
242- val io = IO (new Bundle {
243- // input from horizontal direction
244- val in_h = Input (UInt (bits.W ))
245- // input from vertical direction
246- val in_v = Input (UInt (bits.W ))
247- // output to horizontal direction
248- val out_h = Output (UInt ((bits * 2 ).W ))
249- // output to vertical direction
250- val out_v = Output (UInt ((bits * 2 ).W ))
251- // the result after N cycles once this receives the first actual data
252- val out = Output (UInt ((bits * 2 ).W ))
160+ // Compute A * B, where A and B are both square (n x n) matrices.
161+ class GEMM (val n : Int = 4 , val gemmType : GEMMDataType .Type )(implicit config : DataTypeConfig )
162+ extends Module
163+ with GEMMAccuracyConfig
164+ with DebugLog {
253165
166+ val io = IO (new Bundle {
167+ val in_a = Flipped (Decoupled (Vec (n, Vec (n, UInt (config.inputWidth.W )))))
168+ val in_b = Flipped (Decoupled (Vec (n, Vec (n, UInt (config.inputWidth.W )))))
169+ val out = Decoupled (Vec (n * n, UInt (config.outputWidth.W )))
254170 val reset = Input (Bool ())
255171 })
256172
257- val res = RegInit (0 .U ((bits * 2 ).W ))
173+ // accumulate mode
174+ val accMode = IO (Input (Bool ()))
175+ val accReg = RegInit (false .B )
258176
259- when(io.reset) {
260- res := 0 .U
177+ val dataValid = io.in_a.valid && io.in_b.valid
178+
179+ val busy = RegInit (false .B )
180+
181+ io.in_a.ready := ! busy
182+ io.in_b.ready := ! busy
183+
184+ val matrixAReg = RegInit (VecInit .fill(n)(VecInit .fill(n)(0 .U (config.inputWidth.W ))))
185+ val matrixBReg = RegInit (VecInit .fill(n)(VecInit .fill(n)(0 .U (config.inputWidth.W ))))
186+
187+ val sysmm = Module (new SystolicMM (n, gemmType))
188+ sysmm.io.reset := false .B
189+ for (i <- 0 until n) {
190+ sysmm.io.in_a(i) := 0 .U
191+ sysmm.io.in_b(i) := 0 .U
261192 }
262- // this is the main computation part
263- res := res + (io.in_h * io.in_v)
264193
265- // inputs are delayed one cycle to next PEs
266- io.out_h := RegNext (io.in_h)
267- io.out_v := RegNext (io.in_v)
194+ when(dataValid) {
195+ for (i <- 0 until n) {
196+ for (j <- 0 until n) {
197+ matrixAReg(i)(j) := io.in_a.bits(i)(j)
198+ matrixBReg(i)(j) := io.in_b.bits(i)(j)
199+ }
200+ }
201+ busy := true .B
202+ }
268203
269- io.out := res
204+ val resValid = RegInit (false .B )
205+ io.out.valid := resValid
206+ io.out.bits := sysmm.io.out
207+
208+ val cnt = Counter (3 * n)
209+ when(busy && cnt.value < (2 * n).U ) {
210+ for (i <- 0 until n) {
211+ val temp = cnt.value >= i.U
212+ val p = Mux (temp, cnt.value - i.U , 0 .U )
213+ when(temp && p < n.U ) {
214+ sysmm.io.in_a(i) := matrixAReg(i)(p(log2Ceil(n) - 1 , 0 ))
215+ sysmm.io.in_b(i) := matrixBReg(p(log2Ceil(n) - 1 , 0 ))(i)
216+ }
217+ // debugLog(p"in_a${i}: ${sysmm.io.in_a(i)} in_b${i}: ${sysmm.io.in_b(i)}\t")
218+ }
219+ debugLog(p " \n " )
220+ cnt.inc()
221+ }.elsewhen(busy && cnt.value < (3 * n - 1 ).U ) {
222+ cnt.inc()
223+ }
224+
225+ when(cnt.value === (3 * n - 1 ).U ) {
226+ resValid := true .B
227+ when(io.out.ready) {
228+ resValid := false .B
229+ busy := false .B
230+ cnt.reset()
231+ sysmm.io.reset := true .B
232+ }
233+ }
234+
235+ // debugLog(p"busy: ${busy} cnt: ${cnt.value}\n", LogLevel.DEBUG)
270236}
0 commit comments