Skip to content

Commit 7b20d1d

Browse files
committed
Refactor. Fix FMA for subnormal products
1 parent c06476a commit 7b20d1d

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

Makefile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ P32toI32 \
5151
P32toI64 \
5252
P64toI32 \
5353
P64toI64 \
54+
P16toUI32 \
55+
P16toUI64 \
56+
P32toUI32 \
57+
P32toUI64 \
58+
P64toUI32 \
59+
P64toUI64 \
5460
I32toP16 \
5561
I64toP16 \
5662
I32toP32 \

src/main/scala/PositExtractor.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ class PositExtractor(val totalBits: Int, val es: Int) extends Module with HasHar
2626
val extractedExponent =
2727
if (es > 0) expFrac(totalBits - 1, totalBits - es)
2828
else 0.U
29-
val frac = expFrac << es
29+
val frac = (expFrac << es)(totalBits - 1, totalBits - maxFractionBits)
3030

3131
io.out.sign := sign
3232
io.out.isZero := isZero(io.in)
3333
io.out.isNaR := isNaR(io.in)
3434
io.out.exponent := ((regime << es) | extractedExponent).asSInt
35-
io.out.fraction := Cat(1.U, frac(totalBits - 1, totalBits - maxFractionBits))
35+
io.out.fraction := Cat(1.U, frac)
3636
}

src/main/scala/PositFMA.scala

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class PositFMA(val totalBits: Int, val es: Int) extends Module with HasHardPosit
3131
val productSign = num1.sign ^ num2.sign ^ io.negate
3232
val addendSign = num3.sign ^ io.negate ^ io.sub
3333

34-
val productExponent = num1.exponent + num2.exponent
34+
val productExponent = num1.exponent +& num2.exponent
3535
val productFraction =
3636
WireInit(UInt(maxMultiplierFractionBits.W), num1.fraction * num2.fraction)
3737

@@ -40,13 +40,14 @@ class PositFMA(val totalBits: Int, val es: Int) extends Module with HasHardPosit
4040
val normProductExponent = productExponent + Mux(prodOverflow, 1.S, 0.S)
4141
val prodStickyBit = Mux(prodOverflow, productFraction(0), false.B)
4242

43-
val addendFraction = (num3.fraction << maxFractionBits).asUInt
43+
val addendIsZero = num3.isZero
44+
val addendFraction = Mux(!addendIsZero, (num3.fraction << maxFractionBits).asUInt, 0.U)
4445
val addendExponent = num3.exponent
4546

4647
val isAddendLargerThanProduct =
47-
(addendExponent > normProductExponent) |
48-
(addendExponent === normProductExponent &&
49-
(addendFraction > normProductFraction))
48+
~addendIsZero &
49+
((addendExponent > normProductExponent) |
50+
(addendExponent === normProductExponent && (addendFraction > normProductFraction)))
5051

5152
val largeExp = Mux(isAddendLargerThanProduct, addendExponent, normProductExponent)
5253
val largeFrac = Mux(isAddendLargerThanProduct, addendFraction, normProductFraction)
@@ -57,28 +58,29 @@ class PositFMA(val totalBits: Int, val es: Int) extends Module with HasHardPosit
5758
val smallSign = Mux(isAddendLargerThanProduct, productSign, addendSign)
5859

5960
val expDiff = (largeExp - smallExp).asUInt()
61+
val ShftInBound = expDiff < maxMultiplierFractionBits.U
6062
val shiftedSmallFrac =
6163
Mux(expDiff < maxMultiplierFractionBits.U, smallFrac >> expDiff, 0.U)
6264
val smallFracStickyBit = (smallFrac & ((1.U << expDiff) - 1.U)).orR()
6365

6466
val isAddition = ~(largeSign ^ smallSign)
65-
val signedSmallerFraction =
67+
val signedSmallerFrac =
6668
Mux(isAddition, shiftedSmallFrac, ~shiftedSmallFrac + 1.U)
6769
val fmaFraction =
68-
WireInit(UInt(maxMultiplierFractionBits.W), largeFrac +& signedSmallerFraction)
70+
WireInit(UInt(maxMultiplierFractionBits.W), largeFrac +& signedSmallerFrac)
6971

70-
val sumOverflow = fmaFraction(maxMultiplierFractionBits - 1)
72+
val fmaOverflow = isAddition & fmaFraction(maxMultiplierFractionBits - 1)
7173
val adjFmaFraction =
72-
Mux(isAddition, fmaFraction >> sumOverflow.asUInt(), fmaFraction(maxMultiplierFractionBits - 2, 0))
73-
val adjFmaExponent = largeExp + Mux(isAddition & sumOverflow, 1.S, 0.S)
74-
val sumStickyBit = Mux(isAddition & sumOverflow, fmaFraction(0), false.B)
74+
Mux(fmaOverflow, fmaFraction(maxMultiplierFractionBits - 1, 1), fmaFraction(maxMultiplierFractionBits - 2, 0))
75+
val adjFmaExponent = largeExp + Mux(fmaOverflow, 1.S, 0.S)
76+
val sumStickyBit = Mux(fmaOverflow, fmaFraction(0), false.B)
7577

7678
val normalizationFactor = MuxCase(0.S, Array.range(0, maxMultiplierFractionBits - 2).map(index => {
7779
(adjFmaFraction(maxMultiplierFractionBits - 2, maxMultiplierFractionBits - index - 2) === 1.U) -> index.S
7880
}))
7981

8082
val normFmaExponent = adjFmaExponent - normalizationFactor
81-
val normFmaFraction = adjFmaFraction << normalizationFactor.asUInt()
83+
val normFmaFraction = (adjFmaFraction << normalizationFactor.asUInt())(maxMultiplierFractionBits - 1, 0)
8284

8385
val result = Wire(new unpackedPosit(totalBits, es))
8486
result.isNaR := num1.isNaR || num2.isNaR || num3.isNaR

0 commit comments

Comments
 (0)