Skip to content

Commit e082434

Browse files
add openXiangshan Fudian into lib and use it in GEMM
1 parent 52b129d commit e082434

File tree

14 files changed

+381
-159
lines changed

14 files changed

+381
-159
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ project/plugins/project/
351351
*.war
352352
*.ear
353353

354+
!lib/fudian.jar
354355
!lib/fixedpoint.jar
355356
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
356357
hs_err_pid*

.gitmodules

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
[submodule "depencies/fputil"]
2-
path = depencies/fputil
3-
url = git@github.com:CodingPlatelets/fp-division-pipelined.git
4-
[submodule "depencies/hardfloat"]
5-
path = depencies/hardfloat
1+
[submodule "dependencies/hardfloat"]
2+
path = dependencies/hardfloat
63
url = git@github.com:CodingPlatelets/berkeley-hardfloat.git
4+
[submodule "dependencies/fputil"]
5+
path = dependencies/fputil
6+
url = git@github.com:CodingPlatelets/fp-division-pipelined.git

build.sbt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ lazy val root = (project in file("."))
3131
commonChiselSettings
3232
)
3333

34-
lazy val fputil = Project("fputil", file("depencies/fputil/src"))
34+
lazy val fputil = Project("fputil", file("dependencies/fputil/src"))
3535
.settings(
3636
name := "fputil",
3737
commonChiselSettings
@@ -41,7 +41,7 @@ lazy val fputil = Project("fputil", file("depencies/fputil/src"))
4141
Compile / resourceDirectory := baseDirectory.value / "main" / "resources"
4242
)
4343

44-
lazy val hardfloat = Project("hardfloat", file("depencies/hardfloat/hardfloat/src"))
44+
lazy val hardfloat = Project("hardfloat", file("dependencies/hardfloat/hardfloat/src"))
4545
.settings(
4646
name := "hardfloat",
4747
commonChiselSettings

build.sc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ trait HasChisel extends SbtModule {
3838
}
3939

4040
object fputil extends HasChisel {
41-
override def millSourcePath = os.pwd / "depencies" / "fputil"
41+
override def millSourcePath = os.pwd / "dependencies" / "fputil"
4242
}
4343

4444
trait transformer_MMModule extends ScalaModule {

depencies/fputil

Lines changed: 0 additions & 1 deletion
This file was deleted.

depencies/hardfloat

Lines changed: 0 additions & 1 deletion
This file was deleted.

dependencies/fputil

Submodule fputil added at 787313e

dependencies/hardfloat

Submodule hardfloat added at 26eb654

lib/fudian.jar

486 KB
Binary file not shown.

src/main/scala/kernel/alu/FPU.scala

Lines changed: 146 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,53 +3,164 @@ package kernel.alu
33
import chisel3._
44
import chisel3.util._
55
import hardfloat._
6-
object FPU {
7-
def equivRecFN(expWidth: Int, sigWidth: Int, a: UInt, b: UInt) = {
8-
val top4A = a(expWidth + sigWidth, expWidth + sigWidth - 3)
9-
val top4B = b(expWidth + sigWidth, expWidth + sigWidth - 3)
10-
Mux(
11-
(top4A(2, 0) === 0.U) || (top4A(2, 0) === 7.U),
12-
(top4A === top4B) && (a(sigWidth - 2, 0) === b(sigWidth - 2, 0)),
13-
Mux((top4A(2, 0) === 6.U), (top4A === top4B), (a === b))
6+
7+
case class FType(exp: Int, sig: Int) {
8+
def ieeeWidth = exp + sig
9+
def recodedWidth = ieeeWidth + 1
10+
11+
def ieeeQNaN = ((BigInt(1) << (ieeeWidth - 1)) - (BigInt(1) << (sig - 2))).U(ieeeWidth.W)
12+
def qNaN = ((BigInt(7) << (exp + sig - 3)) + (BigInt(1) << (sig - 2))).U(recodedWidth.W)
13+
def isNaN(x: UInt) = x(sig + exp - 1, sig + exp - 3).andR
14+
def isSNaN(x: UInt) = isNaN(x) && !x(sig - 2)
15+
16+
def classify(x: UInt) = {
17+
val sign = x(sig + exp)
18+
val code = x(exp + sig - 1, exp + sig - 3)
19+
val codeHi = code(2, 1)
20+
val isSpecial = codeHi === 3.U
21+
22+
val isHighSubnormalIn = x(exp + sig - 3, sig - 1) < 2.U
23+
val isSubnormal = code === 1.U || codeHi === 1.U && isHighSubnormalIn
24+
val isNormal = codeHi === 1.U && !isHighSubnormalIn || codeHi === 2.U
25+
val isZero = code === 0.U
26+
val isInf = isSpecial && !code(0)
27+
val isNaN = code.andR
28+
val isSNaN = isNaN && !x(sig - 2)
29+
val isQNaN = isNaN && x(sig - 2)
30+
31+
Cat(
32+
isQNaN,
33+
isSNaN,
34+
isInf && !sign,
35+
isNormal && !sign,
36+
isSubnormal && !sign,
37+
isZero && !sign,
38+
isZero && sign,
39+
isSubnormal && sign,
40+
isNormal && sign,
41+
isInf && sign
1442
)
1543
}
44+
45+
// convert between formats, ignoring rounding, range, NaN
46+
def unsafeConvert(x: UInt, to: FType) = if (this == to) x
47+
else {
48+
val sign = x(sig + exp)
49+
val fractIn = x(sig - 2, 0)
50+
val expIn = x(sig + exp - 1, sig - 1)
51+
val fractOut = fractIn << to.sig >> sig
52+
val expOut = {
53+
val expCode = expIn(exp, exp - 2)
54+
val commonCase = (expIn + (1 << to.exp).U) - (1 << exp).U
55+
Mux(expCode === 0.U || expCode >= 6.U, Cat(expCode, commonCase(to.exp - 3, 0)), commonCase(to.exp, 0))
56+
}
57+
Cat(sign, expOut, fractOut)
58+
}
59+
60+
private def ieeeBundle = {
61+
val expWidth = exp
62+
class IEEEBundle extends Bundle {
63+
val sign = Bool()
64+
val exp = UInt(expWidth.W)
65+
val sig = UInt((ieeeWidth - expWidth - 1).W)
66+
}
67+
new IEEEBundle
68+
}
69+
70+
def unpackIEEE(x: UInt) = x.asTypeOf(ieeeBundle)
71+
72+
def recode(x: UInt) = hardfloat.recFNFromFN(exp, sig, x)
73+
def ieee(x: UInt) = hardfloat.fNFromRecFN(exp, sig, x)
1674
}
1775

18-
class ValExec_MulRecFN(expWidth: Int, sigWidth: Int) extends Module {
76+
case class FPUParams(
77+
minFLen: Int = 32,
78+
fLen: Int = 64,
79+
divSqrt: Boolean = true,
80+
sfmaLatency: Int = 3,
81+
dfmaLatency: Int = 4,
82+
fpmuLatency: Int = 2,
83+
ifpuLatency: Int = 2)
84+
85+
object FPConstants {
86+
val RM_SZ = 3
87+
val FLAGS_SZ = 5
88+
}
89+
90+
class FPInput(implicit p: FPUParams) extends Bundle {
91+
val rm = Bits(FPConstants.RM_SZ.W)
92+
val fmaCmd = Bits(2.W)
93+
val typ = Bits(2.W)
94+
val fmt = Bits(2.W)
95+
val in1 = Bits((p.fLen + 1).W)
96+
val in2 = Bits((p.fLen + 1).W)
97+
val in3 = Bits((p.fLen + 1).W)
98+
}
99+
100+
class FPResult(implicit p: FPUParams) {
101+
val data = Bits((p.fLen + 1).W)
102+
val exc = Bits(FPConstants.FLAGS_SZ.W)
103+
}
104+
105+
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
106+
override def desiredName = s"MulAddRecFNPipe_l${latency}_e${expWidth}_s${sigWidth}"
107+
require(latency <= 2)
108+
19109
val io = IO(new Bundle {
20-
val a = Input(Bits((expWidth + sigWidth).W))
21-
val b = Input(Bits((expWidth + sigWidth).W))
110+
val validin = Input(Bool())
111+
val op = Input(Bits(2.W))
112+
val a = Input(Bits((expWidth + sigWidth + 1).W))
113+
val b = Input(Bits((expWidth + sigWidth + 1).W))
114+
val c = Input(Bits((expWidth + sigWidth + 1).W))
22115
val roundingMode = Input(UInt(3.W))
23116
val detectTininess = Input(UInt(1.W))
117+
val out = Output(Bits((expWidth + sigWidth + 1).W))
118+
val exceptionFlags = Output(Bits(5.W))
119+
val validout = Output(Bool())
120+
})
24121

25-
val expected = new Bundle {
26-
val out = Input(Bits((expWidth + sigWidth).W))
27-
val exceptionFlags = Input(Bits(5.W))
28-
val recOut = Output(Bits((expWidth + sigWidth + 1).W))
29-
}
122+
//------------------------------------------------------------------------
123+
//------------------------------------------------------------------------
30124

31-
val actual = new Bundle {
32-
val out = Output(Bits((expWidth + sigWidth + 1).W))
33-
val exceptionFlags = Output(Bits(5.W))
34-
}
125+
val mulAddRecFNToRaw_preMul = Module(new hardfloat.MulAddRecFNToRaw_preMul(expWidth, sigWidth))
126+
val mulAddRecFNToRaw_postMul = Module(new hardfloat.MulAddRecFNToRaw_postMul(expWidth, sigWidth))
35127

36-
val check = Output(Bool())
37-
val pass = Output(Bool())
38-
})
128+
mulAddRecFNToRaw_preMul.io.op := io.op
129+
mulAddRecFNToRaw_preMul.io.a := io.a
130+
mulAddRecFNToRaw_preMul.io.b := io.b
131+
mulAddRecFNToRaw_preMul.io.c := io.c
132+
133+
val mulAddResult =
134+
(mulAddRecFNToRaw_preMul.io.mulAddA *
135+
mulAddRecFNToRaw_preMul.io.mulAddB) +&
136+
mulAddRecFNToRaw_preMul.io.mulAddC
137+
138+
val valid_stage0 = Wire(Bool())
139+
val roundingMode_stage0 = Wire(UInt(3.W))
140+
val detectTininess_stage0 = Wire(UInt(1.W))
141+
142+
val postmul_regs = if (latency > 0) 1 else 0
143+
mulAddRecFNToRaw_postMul.io.fromPreMul := Pipe(io.validin, mulAddRecFNToRaw_preMul.io.toPostMul, postmul_regs).bits
144+
mulAddRecFNToRaw_postMul.io.mulAddResult := Pipe(io.validin, mulAddResult, postmul_regs).bits
145+
mulAddRecFNToRaw_postMul.io.roundingMode := Pipe(io.validin, io.roundingMode, postmul_regs).bits
146+
roundingMode_stage0 := Pipe(io.validin, io.roundingMode, postmul_regs).bits
147+
detectTininess_stage0 := Pipe(io.validin, io.detectTininess, postmul_regs).bits
148+
valid_stage0 := Pipe(io.validin, false.B, postmul_regs).valid
149+
150+
//------------------------------------------------------------------------
151+
//------------------------------------------------------------------------
39152

40-
val mulRecFN = Module(new MulRecFN(expWidth, sigWidth))
41-
mulRecFN.io.a := recFNFromFN(expWidth, sigWidth, io.a)
42-
mulRecFN.io.b := recFNFromFN(expWidth, sigWidth, io.b)
43-
mulRecFN.io.roundingMode := io.roundingMode
44-
mulRecFN.io.detectTininess := io.detectTininess
153+
val roundRawFNToRecFN = Module(new hardfloat.RoundRawFNToRecFN(expWidth, sigWidth, 0))
45154

46-
io.expected.recOut := recFNFromFN(expWidth, sigWidth, io.expected.out)
155+
val round_regs = if (latency == 2) 1 else 0
156+
roundRawFNToRecFN.io.invalidExc := Pipe(valid_stage0, mulAddRecFNToRaw_postMul.io.invalidExc, round_regs).bits
157+
roundRawFNToRecFN.io.in := Pipe(valid_stage0, mulAddRecFNToRaw_postMul.io.rawOut, round_regs).bits
158+
roundRawFNToRecFN.io.roundingMode := Pipe(valid_stage0, roundingMode_stage0, round_regs).bits
159+
roundRawFNToRecFN.io.detectTininess := Pipe(valid_stage0, detectTininess_stage0, round_regs).bits
160+
io.validout := Pipe(valid_stage0, false.B, round_regs).valid
47161

48-
io.actual.out := mulRecFN.io.out
49-
io.actual.exceptionFlags := mulRecFN.io.exceptionFlags
162+
roundRawFNToRecFN.io.infiniteExc := false.B
50163

51-
io.check := true.B
52-
io.pass :=
53-
FPU.equivRecFN(expWidth, sigWidth, io.actual.out, io.expected.recOut) &&
54-
(io.actual.exceptionFlags === io.expected.exceptionFlags)
164+
io.out := roundRawFNToRecFN.io.out
165+
io.exceptionFlags := roundRawFNToRecFN.io.exceptionFlags
55166
}

0 commit comments

Comments
 (0)