Skip to content

Commit

Permalink
incorporate sirius recoding skips in the gemmini mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
vikramjain236 authored and richardyrh committed Nov 4, 2024
1 parent 0328277 commit 2916cde
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 36 deletions.
15 changes: 14 additions & 1 deletion src/main/scala/gemmini/AccumulatorScale.scala
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ object AccumulatorScale {
}

def iexp[T <: Data](q: T, qln2: T, qln2_inv: T, qb: T, qc: T)(implicit ev: Arithmetic[T]): T = {
/*
import ev._
val zero = q.zero
Expand All @@ -406,6 +407,18 @@ object AccumulatorScale {
val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q)
// we dont want a rounding shift
// TODO: z overflow
(q_poly_iexp.asUInt.do_>>(z_iexp_saturated.asUInt)).asTypeOf(q)
(q_poly_iexp.asUInt.do_>>(z_iexp_saturated.asUInt)).asTypeOf(q) */
import ev._

val zero = q.zero
val one = q.identity
def neg(x: T) = zero-x

val q_sign = Mux(q.zero > q, neg(one), one)
val q_abs = Mux(q.zero > q, neg(q), q)
val q_clipped = Mux(q_abs > neg(qb), neg(qb), q_abs)
val q_poly = qc.mac(q_clipped + qb, q_clipped + qb).withWidthOf(q)
val q_erf = (q_sign * q_poly).withWidthOf(q)
(q * (q_erf + qc)).withWidthOf(q)
}}

6 changes: 6 additions & 0 deletions src/main/scala/gemmini/Configs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ object GemminiConfigs {
val defaultConfig = GemminiArrayConfig[SInt, Float, Float](
// Datatypes
inputType = SInt(8.W),
weightType = SInt(8.W),
accType = SInt(32.W),

spatialArrayInputType = SInt(8.W),
spatialArrayWeightType = SInt(8.W),
spatialArrayOutputType = SInt(20.W),

// Spatial array size options
Expand Down Expand Up @@ -165,7 +168,10 @@ object GemminiConfigs {

val dummyConfig = GemminiArrayConfig[DummySInt, Float, Float](
inputType = DummySInt(8),
weightType = DummySInt(8),
accType = DummySInt(32),
spatialArrayInputType = DummySInt(8),
spatialArrayWeightType = DummySInt(8),
spatialArrayOutputType = DummySInt(20),
tileRows = defaultConfig.tileRows,
tileColumns = defaultConfig.tileColumns,
Expand Down
14 changes: 9 additions & 5 deletions src/main/scala/gemmini/ConfigsFP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ object GemminiFPConfigs {
use_dedicated_tl_port = false,
use_shared_ext_mem = false,
inputType = Float(8, 24),
spatialArrayOutputType = Float(8, 24),
weightType = Float(8, 24),
accType = Float(8, 24),

spatialArrayInputType = Float(8, 24),
spatialArrayWeightType = Float(8, 24),
spatialArrayOutputType = Float(8, 24),

mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_shared = false,
Expand Down Expand Up @@ -81,28 +85,28 @@ object GemminiFPConfigs {
)

//FP32 Single Precision Configuration
val FP32DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 24), spatialArrayOutputType = Float(8, 24), accType = Float(8, 24),
val FP32DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 24), weightType = Float(8, 24), accType = Float(8, 24), spatialArrayInputType = Float(8, 24), spatialArrayWeightType = Float(8, 24), spatialArrayOutputType = Float(8, 24),
tile_latency = 2,
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
)

//FP16 Half Precision Configuration
val FP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), spatialArrayOutputType = Float(5, 11), accType = Float(8, 24),
val FP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), weightType = Float(5, 11), accType = Float(8, 24), spatialArrayInputType = Float(5, 11), spatialArrayWeightType = Float(5, 11), spatialArrayOutputType = Float(5, 11),
tile_latency = 2,
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
)

//Bfloat16 Brain-half Precision Configuration
val BF16DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 8), spatialArrayOutputType = Float(8, 8), accType = Float(8, 24),
val BF16DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 8), weightType = Float(8, 8), accType = Float(8, 24), spatialArrayInputType = Float(8, 8), spatialArrayWeightType = Float(8, 8), spatialArrayOutputType = Float(8, 8),
tile_latency = 2,
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
)

//Bfloat16 Brain-half Precision Configuration 8x8 array
val BF16Default8Config = defaultFPConfig.copy(inputType = Float(8, 8), spatialArrayOutputType = Float(8, 8), accType = Float(8, 24),
val BF16Default8Config = defaultFPConfig.copy(inputType = Float(8, 8), weightType = Float(8, 8), accType = Float(8, 24), spatialArrayInputType = Float(8, 8), spatialArrayWeightType = Float(8, 8), spatialArrayOutputType = Float(8, 8),
meshRows = 8, meshColumns = 8,
tile_latency = 2,
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/gemmini/DSEConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ object DSEBaseConfig {
dma_buswidth = 128, // TODO get this from SystemBusKey
aligned_to = 16,
inputType = SInt(8.W),
spatialArrayOutputType = SInt(19.W),
weightType = SInt(8.W),
accType = SInt(32.W),
spatialArrayInputType = SInt(8.W),
spatialArrayWeightType = SInt(8.W),
spatialArrayOutputType = SInt(19.W),
mvin_scale_args = None,
mvin_scale_acc_args = None,
mvin_scale_shared = false,
Expand Down
14 changes: 7 additions & 7 deletions src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val cntl = mesh_cntl_signals_q.io.deq.bits

// Instantiate the actual mesh
val mesh = Module(new MeshWithDelays(inputType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay,
val mesh = Module(new MeshWithDelays(spatialArrayInputType, spatialArrayWeightType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay,
tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks))

mesh.io.a.valid := false.B
Expand Down Expand Up @@ -833,9 +833,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val dataB_unpadded = MuxCase(readData(cntl.b_bank), Seq(cntl.accumulate_zeros -> 0.U, cntl.b_read_from_acc -> accReadData(cntl.b_bank_acc)))
val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))

val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)})
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)})
val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)})
val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))

// Pop responses off the scratchpad io ports
when (mesh_cntl_signals_q.io.deq.fire) {
Expand Down Expand Up @@ -879,9 +879,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
mesh.io.b.valid := cntl.b_fire && dataB_valid && n_a_r_not_v && n_d_r_not_v
mesh.io.d.valid := cntl.d_fire && dataD_valid && n_a_r_not_v && n_b_r_not_v

mesh.io.a.bits := dataA.asTypeOf(Vec(meshRows, Vec(tileRows, inputType)))
mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
mesh.io.a.bits := dataA.asTypeOf(Vec(meshRows, Vec(tileRows, spatialArrayInputType)))
mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, spatialArrayWeightType)))
mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, spatialArrayWeightType)))

mesh.io.req.valid := mesh_cntl_signals_q.io.deq.fire && (cntl.a_fire || cntl.b_fire || cntl.d_fire)

Expand Down
6 changes: 5 additions & 1 deletion src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
opcodes: OpcodeSet = OpcodeSet.custom3,

inputType: T,
spatialArrayOutputType: T,
weightType: T,
accType: T,
spatialArrayInputType: T,
spatialArrayWeightType: T,
spatialArrayOutputType: T,

dataflow: Dataflow.Value = Dataflow.BOTH,

Expand Down Expand Up @@ -98,6 +101,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](

headerFileName: String = "gemmini_params.h"
) {
require(inputType.getWidth == weightType.getWidth)
val sp_width = meshColumns * tileColumns * inputType.getWidth
val sp_bank_entries = sp_capacity match {
case CapacityInKilobytes(kb) => kb * 1024 * 8 / (sp_banks * sp_width)
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/gemmini/Mesh.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ import chisel3.experimental._
* @param meshRows
* @param meshColumns
*/
class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T,
class Mesh[T <: Data : Arithmetic](inputType: T, weightType: T, outputType: T, accType: T,
df: Dataflow.Value, tree_reduction: Boolean, tile_latency: Int,
max_simultaneous_matmuls: Int, output_delay: Int,
val tileRows: Int, val tileColumns: Int,
val meshRows: Int, val meshColumns: Int) extends Module {
val io = IO(new Bundle {
val in_a = Input(Vec(meshRows, Vec(tileRows, inputType)))
val in_b = Input(Vec(meshColumns, Vec(tileColumns, inputType)))
val in_d = Input(Vec(meshColumns, Vec(tileColumns, inputType)))
val in_b = Input(Vec(meshColumns, Vec(tileColumns, weightType)))
val in_d = Input(Vec(meshColumns, Vec(tileColumns, weightType))) // TODO should this be weightType, inputType, or something like max(inputType, weightType)?
val in_control = Input(Vec(meshColumns, Vec(tileColumns, new PEControl(accType))))
val in_id = Input(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W)))) // The unique id of this particular matmul
val in_last = Input(Vec(meshColumns, Vec(tileColumns, Bool())))
Expand All @@ -36,7 +36,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T,
})

// mesh(r)(c) => Tile at row r, column c
val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, tree_reduction, max_simultaneous_matmuls, tileRows, tileColumns)))
val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, weightType, outputType, accType, df, tree_reduction, max_simultaneous_matmuls, tileRows, tileColumns)))
val meshT = mesh.transpose

def pipe[T <: Data](valid: Bool, t: T, latency: Int): T = {
Expand Down
12 changes: 6 additions & 6 deletions src/main/scala/gemmini/MeshWithDelays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ class MeshWithDelaysResp[T <: Data: Arithmetic, TagT <: TagQueueTag with Data](o
// TODO make all inputs go straight into registers to help with physical design

class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data]
(inputType: T, val outputType: T, accType: T,
(val inputType: T, val weightType: T, val outputType: T, accType: T,
tagType: U, df: Dataflow.Value, tree_reduction: Boolean, tile_latency: Int, output_delay: Int,
tileRows: Int, tileColumns: Int, meshRows: Int, meshColumns: Int,
leftBanks: Int, upBanks: Int, outBanks: Int = 1, n_simultaneous_matmuls: Int = -1)
extends Module {

val A_TYPE = Vec(meshRows, Vec(tileRows, inputType))
val B_TYPE = Vec(meshColumns, Vec(tileColumns, inputType))
val C_TYPE = Vec(meshColumns, Vec(tileColumns, outputType))
val D_TYPE = Vec(meshColumns, Vec(tileColumns, inputType))
val B_TYPE = Vec(meshColumns, Vec(tileColumns, weightType)) // TODO should this be weightType, inputType, or something like max(inputType, weightType)?
val C_TYPE = Vec(meshColumns, Vec(tileColumns, outputType))
val D_TYPE = Vec(meshColumns, Vec(tileColumns, weightType)) // TODO should this be weightType, inputType, or something like max(inputType, weightType)?
val S_TYPE = Vec(meshColumns, Vec(tileColumns, new PEControl(accType)))

assert(meshRows*tileRows == meshColumns*tileColumns)
Expand Down Expand Up @@ -67,7 +67,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data]
val tags_in_progress = Output(Vec(tagqlen, tagType))
})

def shifted[T <: Data](x: Vec[Vec[T]], banks: Int, reverse: Boolean = false) = {
def shifted[T <: Data](x: Vec[Vec[T]], banks: Int, reverse: Boolean = false): Seq[Vec[T]] = {
assert(x.size % banks == 0, "cannot bank without clean divisors")

val banked_len = x.size / banks
Expand Down Expand Up @@ -164,7 +164,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data]
val transposer_out = VecInit(transposer.io.outCol.bits.grouped(tileRows).map(t => VecInit(t)).toSeq)

// Wire up mesh's IO to this module's IO
val mesh = Module(new Mesh(inputType, outputType, accType, df, tree_reduction, tile_latency, max_simultaneous_matmuls, output_delay, tileRows, tileColumns, meshRows, meshColumns))
val mesh = Module(new Mesh(inputType, weightType, outputType, accType, df, tree_reduction, tile_latency, max_simultaneous_matmuls, output_delay, tileRows, tileColumns, meshRows, meshColumns))

// TODO wire only to *_buf here, instead of io.*.bits
val a_shifter_in = WireInit(Mux(a_is_from_transposer, transposer_out.asTypeOf(A_TYPE), a_buf))
Expand Down
18 changes: 9 additions & 9 deletions src/main/scala/gemmini/PE.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ class PEControl[T <: Data : Arithmetic](accType: T) extends Bundle {

}

class MacUnit[T <: Data](inputType: T, cType: T, dType: T) (implicit ev: Arithmetic[T]) extends Module {
class MacUnit[T <: Data](inputType: T, weightType: T, cType: T, dType: T) (implicit ev: Arithmetic[T]) extends Module {
import ev._
val io = IO(new Bundle {
val in_a = Input(inputType)
val in_b = Input(inputType)
val in_b = Input(weightType)
val in_c = Input(cType)
val out_d = Output(dType)
})
Expand All @@ -28,7 +28,7 @@ class MacUnit[T <: Data](inputType: T, cType: T, dType: T) (implicit ev: Arithme
* A PE implementing a MAC operation. Configured as fully combinational when integrated into a Mesh.
* @param width Data width of operands
*/
class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, max_simultaneous_matmuls: Int)
class PE[T <: Data](inputType: T, weightType: T, outputType: T, accType: T, df: Dataflow.Value, max_simultaneous_matmuls: Int)
(implicit ev: Arithmetic[T]) extends Module { // Debugging variables
import ev._

Expand Down Expand Up @@ -61,7 +61,7 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
// elaboration/synthesis tools often fail to consolidate and de-duplicate
// MAC units. To force mac circuitry to be re-used, we create a "mac_unit"
// module here which just performs a single MAC operation
val mac_unit = Module(new MacUnit(inputType,
val mac_unit = Module(new MacUnit(inputType, weightType,
if (df == Dataflow.WS) outputType else accType, outputType))

val a = io.in_a
Expand Down Expand Up @@ -103,28 +103,28 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
when(prop === PROPAGATE) {
io.out_c := (c1 >> shift_offset).clippedToWidthOf(outputType)
io.out_b := b
mac_unit.io.in_b := b.asTypeOf(inputType)
mac_unit.io.in_b := b.asTypeOf(weightType)
mac_unit.io.in_c := c2
c2 := mac_unit.io.out_d
c1 := d.withWidthOf(cType)
}.otherwise {
io.out_c := (c2 >> shift_offset).clippedToWidthOf(outputType)
io.out_b := b
mac_unit.io.in_b := b.asTypeOf(inputType)
mac_unit.io.in_b := b.asTypeOf(weightType)
mac_unit.io.in_c := c1
c1 := mac_unit.io.out_d
c2 := d.withWidthOf(cType)
}
}.elsewhen ((df == Dataflow.WS).B || ((df == Dataflow.BOTH).B && dataflow === WEIGHT_STATIONARY)) {
when(prop === PROPAGATE) {
io.out_c := c1
mac_unit.io.in_b := c2.asTypeOf(inputType)
mac_unit.io.in_b := c2.asTypeOf(weightType)
mac_unit.io.in_c := b
io.out_b := mac_unit.io.out_d
c1 := d
}.otherwise {
io.out_c := c2
mac_unit.io.in_b := c1.asTypeOf(inputType)
mac_unit.io.in_b := c1.asTypeOf(weightType)
mac_unit.io.in_c := b
io.out_b := mac_unit.io.out_d
c2 := d
Expand All @@ -134,7 +134,7 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
//assert(false.B, "unknown dataflow")
io.out_c := DontCare
io.out_b := DontCare
mac_unit.io.in_b := b.asTypeOf(inputType)
mac_unit.io.in_b := b.asTypeOf(weightType)
mac_unit.io.in_c := c2
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/gemmini/Tile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import Util._
* @param rows Number of PEs on each row
* @param columns Number of PEs on each column
*/
class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, tree_reduction: Boolean, max_simultaneous_matmuls: Int, val rows: Int, val columns: Int)(implicit ev: Arithmetic[T]) extends Module {
class Tile[T <: Data](inputType: T, weightType: T, outputType: T, accType: T, df: Dataflow.Value, tree_reduction: Boolean, max_simultaneous_matmuls: Int, val rows: Int, val columns: Int)(implicit ev: Arithmetic[T]) extends Module {
val io = IO(new Bundle {
val in_a = Input(Vec(rows, inputType))
val in_b = Input(Vec(columns, outputType)) // This is the output of the tile next to it
Expand All @@ -39,7 +39,7 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu

import ev._

val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, max_simultaneous_matmuls)))
val tile = Seq.fill(rows, columns)(Module(new PE(inputType, weightType, outputType, accType, df, max_simultaneous_matmuls)))
val tileT = tile.transpose

// TODO: abstract hori/vert broadcast, all these connections look the same
Expand Down

0 comments on commit 2916cde

Please sign in to comment.