diff --git a/src/main/scala/gemmini/Arithmetic.scala b/src/main/scala/gemmini/Arithmetic.scala index 7bd8d9e8..de059dff 100644 --- a/src/main/scala/gemmini/Arithmetic.scala +++ b/src/main/scala/gemmini/Arithmetic.scala @@ -8,8 +8,8 @@ import chisel3.util._ import hardfloat._ // Bundles that represent the raw bits of custom datatypes -case class Float(expWidth: Int, sigWidth: Int) extends Bundle { - val bits = UInt((expWidth + sigWidth).W) +case class Float(expWidth: Int, sigWidth: Int, isRecoded: Boolean = false) extends Bundle { + val bits = UInt((expWidth + sigWidth + (if (isRecoded) 1 else 0)).W) val bias: Int = (1 << (expWidth-1)) - 1 } @@ -245,7 +245,7 @@ object Arithmetic { } override def reciprocal[U <: Data](u: U, options: Int = 0): Option[(DecoupledIO[UInt], DecoupledIO[U])] = u match { - case Float(expWidth, sigWidth) => + case Float(expWidth, sigWidth, false) => val input = Wire(Decoupled(UInt(0.W))) val output = Wire(Decoupled(u.cloneType)) @@ -287,7 +287,7 @@ object Arithmetic { } override def mult_with_reciprocal[U <: Data](reciprocal: U): SInt = reciprocal match { - case recip @ Float(expWidth, sigWidth) => + case recip @ Float(expWidth, sigWidth, false) => def in_to_float(x: SInt) = { val in_to_rec_fn = Module(new INToRecFN(intWidth = self.getWidth, expWidth, sigWidth)) in_to_rec_fn.io.signedIn := true.B @@ -330,8 +330,8 @@ object Arithmetic { override implicit def cast(self: Float): ArithmeticOps[Float] = new ArithmeticOps(self) { override def *(t: Float): Float = { - val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits) - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits) + val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits) val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth)) t_resizer.io.in := t_rec @@ -347,16 +347,16 @@ object Arithmetic { muladder.io.a := self_rec muladder.io.b := t_rec_resized - val out = Wire(Float(self.expWidth, self.sigWidth)) - out.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out) + val out = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded)) + out.bits := (if (out.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)) out } override def mac(m1: Float, m2: Float): Float = { // Recode all operands - val m1_rec = recFNFromFN(m1.expWidth, m1.sigWidth, m1.bits) - val m2_rec = recFNFromFN(m2.expWidth, m2.sigWidth, m2.bits) - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + val m1_rec = if (m1.isRecoded) m1.bits else recFNFromFN(m1.expWidth, m1.sigWidth, m1.bits) + val m2_rec = if (m2.isRecoded) m2.bits else recFNFromFN(m2.expWidth, m2.sigWidth, m2.bits) + val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits) // Resize m1 to self's width val m1_resizer = Module(new RecFNToRecFN(m1.expWidth, m1.sigWidth, self.expWidth, self.sigWidth)) @@ -364,7 +364,7 @@ object Arithmetic { m1_resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag m1_resizer.io.detectTininess := consts.tininess_afterRounding val m1_rec_resized = m1_resizer.io.out - + // Resize m2 to self's width val m2_resizer = Module(new RecFNToRecFN(m2.expWidth, m2.sigWidth, self.expWidth, self.sigWidth)) m2_resizer.io.in := m2_rec @@ -384,8 +384,8 @@ object Arithmetic { muladder.io.c := self_rec // Convert result to standard format // TODO remove these intermediate recodings - val out = Wire(Float(self.expWidth, self.sigWidth)) - out.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out) + val out = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded)) + out.bits := (if (out.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)) out } @@ -393,8 +393,8 @@ object Arithmetic { require(self.getWidth >= t.getWidth) // This just makes it easier to write the resizing code // Recode all operands - val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits) - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits) + val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits) // Generate 1 as a float val in_to_rec_fn = Module(new INToRecFN(1, self.expWidth, self.sigWidth)) @@ -423,8 +423,8 @@ object Arithmetic { muladder.io.b := one_rec muladder.io.c := self_rec - val result = Wire(Float(self.expWidth, self.sigWidth)) - result.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out) + val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded)) + result.bits := (if (result.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)) result } @@ -436,7 +436,7 @@ object Arithmetic { override def >>(u: UInt): Float = { // Recode self - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits) // Get 2^(-u) as a recoded float val shift_exp = Wire(UInt(self.expWidth.W)) @@ -455,15 +455,15 @@ object Arithmetic { muladder.io.a := self_rec muladder.io.b := shift_rec - val result = Wire(Float(self.expWidth, self.sigWidth)) - result.bits := fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out) + val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded)) + result.bits := (if (result.isRecoded) muladder.io.out else fNFromRecFN(self.expWidth, self.sigWidth, muladder.io.out)) result } override def >(t: Float): Bool = { // Recode all operands - val t_rec = recFNFromFN(t.expWidth, t.sigWidth, t.bits) - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + val t_rec = if (t.isRecoded) t.bits else recFNFromFN(t.expWidth, t.sigWidth, t.bits) + val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits) // Resize t to self's width val t_resizer = Module(new RecFNToRecFN(t.expWidth, t.sigWidth, self.expWidth, self.sigWidth)) @@ -481,43 +481,49 @@ object Arithmetic { } override def withWidthOf(t: Float): Float = { - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits) val resizer = Module(new RecFNToRecFN(self.expWidth, self.sigWidth, t.expWidth, t.sigWidth)) resizer.io.in := self_rec resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag resizer.io.detectTininess := consts.tininess_afterRounding - val result = Wire(Float(t.expWidth, t.sigWidth)) - result.bits := fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out) + val result = Wire(Float(t.expWidth, t.sigWidth, t.isRecoded)) + result.bits := (if (result.isRecoded) resizer.io.out else fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out)) result } override def clippedToWidthOf(t: Float): Float = { // TODO check for overflow. Right now, we just assume that overflow doesn't happen - val self_rec = recFNFromFN(self.expWidth, self.sigWidth, self.bits) + val self_rec = if (self.isRecoded) self.bits else recFNFromFN(self.expWidth, self.sigWidth, self.bits) val resizer = Module(new RecFNToRecFN(self.expWidth, self.sigWidth, t.expWidth, t.sigWidth)) resizer.io.in := self_rec resizer.io.roundingMode := consts.round_near_even // consts.round_near_maxMag resizer.io.detectTininess := consts.tininess_afterRounding - val result = Wire(Float(t.expWidth, t.sigWidth)) - result.bits := fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out) + val result = Wire(Float(t.expWidth, t.sigWidth, t.isRecoded)) + result.bits := (if (result.isRecoded) resizer.io.out else fNFromRecFN(t.expWidth, t.sigWidth, resizer.io.out)) result } override def relu: Float = { - val raw = rawFloatFromFN(self.expWidth, self.sigWidth, self.bits) + val raw = if (self.isRecoded) rawFloatFromRecFN(self.expWidth, self.sigWidth, self.bits) else rawFloatFromFN(self.expWidth, self.sigWidth, self.bits) - val result = Wire(Float(self.expWidth, self.sigWidth)) + val result = Wire(Float(self.expWidth, self.sigWidth, self.isRecoded)) result.bits := Mux(!raw.isZero && raw.sign, 0.U, self.bits) result } override def zero: Float = 0.U.asTypeOf(self) - override def identity: Float = Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) - override def minimum: Float = Cat(1.U, ~(0.U(self.expWidth.W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) + override def identity: Float = { + require(!self.isRecoded) + Cat(0.U(2.W), ~(0.U((self.expWidth-1).W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) + } + override def minimum: Float = { + require(!self.isRecoded) + Cat(1.U, ~(0.U(self.expWidth.W)), 0.U((self.sigWidth-1).W)).asTypeOf(self) + } } } diff --git a/src/main/scala/gemmini/Normalizer.scala b/src/main/scala/gemmini/Normalizer.scala index c22e9af8..5c91201f 100644 --- a/src/main/scala/gemmini/Normalizer.scala +++ b/src/main/scala/gemmini/Normalizer.scala @@ -206,7 +206,7 @@ class MulPipe[T <: Data, U <: Data](scale_t: U)(implicit ev: Arithmetic[T]) }) scale_t match { - case Float(expWidth, sigWidth) => + case Float(expWidth, sigWidth, false) => val self_rec = recFNFromFN(expWidth, sigWidth, io.ins.bits.x.asUInt) val scale_rec = recFNFromFN(expWidth, sigWidth, io.ins.bits.y.asUInt) @@ -542,7 +542,7 @@ class Normalizer[T <: Data, U <: Data](max_len: Int, num_reduce_lanes: Int, num_ val exp_divider_out = Wire(Decoupled(scale_t.cloneType)) scale_t match { - case Float(expWidth, sigWidth) => + case Float(expWidth, sigWidth, false) => exp_divider_in.bits := DontCare