Skip to content

Commit

Permalink
add recoded arithmetic from sirius
Browse files Browse the repository at this point in the history
  • Loading branch information
richardyrh committed Aug 6, 2024
1 parent 38d1020 commit 0328277
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 35 deletions.
72 changes: 39 additions & 33 deletions src/main/scala/gemmini/Arithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import chisel3.util._
import hardfloat._

// Bundles that represent the raw bits of custom datatypes
case class Float(expWidth: Int, sigWidth: Int) extends Bundle {
val bits = UInt((expWidth + sigWidth).W)
case class Float(expWidth: Int, sigWidth: Int, isRecoded: Boolean = false) extends Bundle {
val bits = UInt((expWidth + sigWidth + (if (isRecoded) 1 else 0)).W)

val bias: Int = (1 << (expWidth-1)) - 1
}
Expand Down Expand Up @@ -245,7 +245,7 @@ object Arithmetic {
}

override def reciprocal[U <: Data](u: U, options: Int = 0): Option[(DecoupledIO[UInt], DecoupledIO[U])] = u match {
case Float(expWidth, sigWidth) =>
case Float(expWidth, sigWidth, false) =>
val input = Wire(Decoupled(UInt(0.W)))
val output = Wire(Decoupled(u.cloneType))

Expand Down Expand Up @@ -287,7 +287,7 @@ object Arithmetic {
}

override def mult_with_reciprocal[U <: Data](reciprocal: U): SInt = reciprocal match {
case recip @ Float(expWidth, sigWidth) =>
case recip @ Float(expWidth, sigWidth, false) =>
def in_to_float(x: SInt) = {
val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth))
in_to_rec_fn.io.signedIn := true.B
Expand Down Expand Up @@ -330,8 +330,8 @@ object Arithmetic {

override implicit def cast(self: Float): ArithmeticOps[Float] = new ArithmeticOps(self) {
override def *(t: Float): Float = {
val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth))
t_resizer.io.in := t_rec
Expand All @@ -347,24 +347,24 @@ object Arithmetic {
muladder.io.a := self_rec
muladder.io.b := t_rec_resized

val out = Wire(Float(self.expWidth, self.sigWidth))
out.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val out = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
out.bits := (if (out.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
out
}

override def mac(m1: Float, m2: Float): Float = {
// Recode all operands
val m1_rec = recFNFromFN(m1.expWidth, m1.sigWidth, m1.bits)
val m2_rec = recFNFromFN(m2.expWidth, m2.sigWidth, m2.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val m1_rec = if (m1.isRecoded) m1.bits else recFNFromFN(m1.expWidth, m1.sigWidth, m1.bits)
val m2_rec = if (m2.isRecoded) m2.bits else recFNFromFN(m2.expWidth, m2.sigWidth, m2.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Resize m1 to self's width
val m1_resizer = Module(new RecFNToRecFN(m1.expWidth, m1.sigWidth, self.expWidth, self.sigWidth))
m1_resizer.io.in := m1_rec
m1_resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag
m1_resizer.io.detectTininess := consts.tininess_afterRounding
val m1_rec_resized = m1_resizer.io.out

// Resize m2 to self's width
val m2_resizer = Module(new RecFNToRecFN(m2.expWidth, m2.sigWidth, self.expWidth, self.sigWidth))
m2_resizer.io.in := m2_rec
Expand All @@ -384,17 +384,17 @@ object Arithmetic {
muladder.io.c := self_rec

// Convert result to standard format // TODO remove these intermediate recodings
val out = Wire(Float(self.expWidth, self.sigWidth))
out.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val out = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
out.bits := (if (out.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
out
}

override def +(t: Float): Float = {
require(self.getWidth >= t.getWidth) // This just makes it easier to write the resizing code

// Recode all operands
val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Generate 1 as a float
val in_to_rec_fn = Module(new INToRecFN(1, self.expWidth, self.sigWidth))
Expand Down Expand Up @@ -423,8 +423,8 @@ object Arithmetic {
muladder.io.b := one_rec
muladder.io.c := self_rec

val result = Wire(Float(self.expWidth, self.sigWidth))
result.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
result.bits := (if (result.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
result
}

Expand All @@ -436,7 +436,7 @@ object Arithmetic {

override def >>(u: UInt): Float = {
// Recode self
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Get 2^(-u) as a recoded float
val shift_exp = Wire(UInt(self.expWidth.W))
Expand All @@ -455,15 +455,15 @@ object Arithmetic {
muladder.io.a := self_rec
muladder.io.b := shift_rec

val result = Wire(Float(self.expWidth, self.sigWidth))
result.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)
val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
result.bits := (if (result.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out))
result
}

override def >(t: Float): Bool = {
// Recode all operands
val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

// Resize t to self's width
val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth))
Expand All @@ -481,43 +481,49 @@ object Arithmetic {
}

override def withWidthOf(t: Float): Float = {
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

val resizer = Module(new RecFNToRecFN(self.expWidth, self.sigWidth, t.expWidth, t.sigWidth))
resizer.io.in := self_rec
resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag
resizer.io.detectTininess := consts.tininess_afterRounding

val result = Wire(Float(t.expWidth, t.sigWidth))
result.bits := fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out)
val result = Wire(Float(t.expWidth, t.sigWidth, t.isRecoded))
result.bits := (if (result.isRecoded) resizer.io.out else fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out))
result
}

override def clippedToWidthOf(t: Float): Float = {
// TODO check for overflow. Right now, we just assume that overflow doesn't happen
val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits)
val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits)

val resizer = Module(new RecFNToRecFN(self.expWidth, self.sigWidth, t.expWidth, t.sigWidth))
resizer.io.in := self_rec
resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag
resizer.io.detectTininess := consts.tininess_afterRounding

val result = Wire(Float(t.expWidth, t.sigWidth))
result.bits := fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out)
val result = Wire(Float(t.expWidth, t.sigWidth, t.isRecoded))
result.bits := (if (result.isRecoded) resizer.io.out else fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out))
result
}

override def relu: Float = {
val raw = rawFloatFromFN(self.expWidth, self.sigWidth, self.bits)
val raw = if (self.isRecoded) rawFloatFromRecFN(self.expWidth, self.sigWidth, self.bits) else rawFloatFromFN(self.expWidth, self.sigWidth, self.bits)

val result = Wire(Float(self.expWidth, self.sigWidth))
val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded))
result.bits := Mux(!raw.isZero && raw.sign, 0.U, self.bits)
result
}

override def zero: Float = 0.U.asTypeOf(self)
override def identity: Float = Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
override def minimum: Float = Cat(1.U, ~(0.U(self.expWidth.W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
override def identity: Float = {
require(!self.isRecoded)
Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
}
override def minimum: Float = {
require(!self.isRecoded)
Cat(1.U, ~(0.U(self.expWidth.W)), 0.U((self.sigWidth-1).W)).asTypeOf(self)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/gemmini/Normalizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class MulPipe[T <: Data, U <: Data](scale_t: U)(implicit ev: Arithmetic[T])
})

scale_t match {
case Float(expWidth, sigWidth) =>
case Float(expWidth, sigWidth, false) =>
val self_rec = recFNFromFN(expWidth, sigWidth, io.ins.bits.x.asUInt)
val scale_rec = recFNFromFN(expWidth, sigWidth, io.ins.bits.y.asUInt)

Expand Down Expand Up @@ -542,7 +542,7 @@ class Normalizer[T <: Data, U <: Data](max_len: Int, num_reduce_lanes: Int, num_
val exp_divider_out = Wire(Decoupled(scale_t.cloneType))

scale_t match {
case Float(expWidth, sigWidth) =>
case Float(expWidth, sigWidth, false) =>

exp_divider_in.bits := DontCare

Expand Down

0 comments on commit 0328277

Please sign in to comment.