Skip to content

Commit 9c82c14

Browse files
need continues design for DAC25
1 parent 58ab56a commit 9c82c14

File tree

9 files changed

+113
-11
lines changed

9 files changed

+113
-11
lines changed

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ test:
88
verilog:
99
mkdir -p $(CHISEL_BUILD_DIR)
1010
#mill -i chiselVitisTemplate.runMain --mainClass vitisrtlkernel.VitisRTLKernelVerilog -td $(CHISEL_BUILD_DIR)
11-
sbt run
11+
# sbt run
12+
sbt "runMain vitisrtlkernel.VitisRTLKernelVerilog"
13+
14+
feature:
15+
sbt -J-Xmx50G "runMain kernel.NewFeatureTest"
1216

1317
help:
1418
mill -i __.runMain --mainClass vitisrtlkernel.VitisRTLKernelVerilog --help

build.sbt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ lazy val root = (project in file("."))
2525
.dependsOn(fputil)
2626
.settings(
2727
name := "transformer_MM",
28+
// fork := true,
29+
// javaOptions += "-Xmx50G",
2830
commonChiselSettings
2931
)
3032

3133
lazy val fputil = (project in file("fputil/src/main/scala")).settings(
3234
name := "fputil",
3335
commonChiselSettings
34-
)
36+
)

src/main/scala/kernel/alu/Gemm.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import chisel3.util._
55
import kernel.utils.DebugLog
66

77
trait GEMMAccuracyConfig {
8-
val I: Int = 8
9-
val F: Int = 24
8+
val I: Int = 4
9+
val F: Int = 12
1010
}
1111

1212
class PEFxp extends Module with GEMMAccuracyConfig with DebugLog {
@@ -40,6 +40,10 @@ class GEMM(val n: Int = 4) extends Module with GEMMAccuracyConfig with DebugLog
4040
val InputB = IO(Flipped(Decoupled(Vec(n, Vec(n, UInt((I + F).W))))))
4141
val OutputPipe = IO(Decoupled(Vec(n * n, UInt((2 * (I + F)).W))))
4242

43+
// accumulate mode
44+
val accMode = IO(Input(Bool()))
45+
val accReg = RegInit(false.B)
46+
4347
val dataValid = InputA.valid && InputB.valid
4448

4549
val busy = RegInit(false.B)

src/main/scala/kernel/alu/Softmax.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import kernel.utils.PipeValue
88
import kernel.utils.common
99

1010
trait SoftmaxAccuracy {
11-
val I: Int = 8
12-
val F: Int = 16
11+
val I: Int = 4
12+
val F: Int = 12
13+
14+
val maxDivNum = 16
1315
}
1416

1517
class FixedPointExp extends Module with SoftmaxAccuracy with DebugLog {
@@ -43,6 +45,7 @@ class FixedPointExp extends Module with SoftmaxAccuracy with DebugLog {
4345
io.exp_x.valid := ShiftRegister(io.x.valid, expDelay)
4446
}
4547

48+
// TODO: the number of divModule should be limited, max is 16
4649
class Softmax(val arraySize: Int = 4) extends Module with SoftmaxAccuracy with DebugLog {
4750
val io = IO(new Bundle {
4851
val x = Input(Valid(Vec(arraySize, UInt((I + F).W))))
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package models.llama3
2+
3+
class attentionLayer {
4+
5+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package models.llama3.common
2+
3+
trait llamaConfig {
4+
val dim = 4096
5+
val n_layers = 32
6+
val n_heads: Int = 32
7+
8+
// head_dim is the dimension of each head
9+
val head_dim: Int = dim / n_heads
10+
val maxN: Int = 8 * 1024
11+
val minN: Int = 65
12+
13+
// fixed-point accuracy
14+
val fx_int = 4
15+
val fx_frac = 12
16+
17+
// UInt width
18+
val bits = 16
19+
20+
// systolic array size
21+
val systolicSize = 64
22+
val systolicGroupSize = 1
23+
24+
// DAC for zb, stream for heads
25+
val stream_size = 8
26+
27+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package models.llama3
2+
3+
import common.llamaConfig
4+
import chisel3._
5+
import chisel3.util._
6+
import kernel.alu.GEMM
7+
import kernel.utils.ForwardingMemory
8+
class metrixController extends Module with llamaConfig {}
9+
10+
/*
11+
* matrix mul matrix
12+
* matrixA is [inputN, dim]
13+
* matrixB is [dim, head_dim]
14+
* matrixC is [inputN, head_dim]
15+
*/
16+
class QKVGenerationMul extends Module with llamaConfig {
17+
val io = IO(new Bundle {
18+
val matrixAPart = Input(Vec(minN, Vec(dim, UInt(bits.W))))
19+
})
20+
}
21+
22+
class SystolicGroup extends Module with llamaConfig {
23+
val io = IO(new Bundle {
24+
val matrixAVec = Flipped(Decoupled(Vec(systolicGroupSize, Vec(systolicSize, Vec(systolicSize, UInt(bits.W))))))
25+
val matrixBVec = Flipped(Decoupled(Vec(systolicGroupSize, Vec(systolicSize, Vec(systolicSize, UInt(bits.W))))))
26+
val matrixCVec = Decoupled(Vec(systolicGroupSize, Vec(systolicSize * systolicSize, UInt(bits.W))))
27+
})
28+
29+
val gemmRow = for (i <- 0 until systolicGroupSize) yield Module(new GEMM(64))
30+
31+
val matrixAValid = io.matrixAVec.valid
32+
val matrixBValid = io.matrixBVec.valid
33+
34+
val matrixCValid = gemmRow.map(_.OutputPipe.valid).reduce(_ && _)
35+
36+
for (i <- 0 until systolicGroupSize) {
37+
gemmRow(i).InputA.bits := io.matrixAVec.bits(i)
38+
gemmRow(i).InputA.valid := matrixAValid
39+
gemmRow(i).InputB.bits := io.matrixBVec.bits(i)
40+
gemmRow(i).InputB.valid := matrixBValid
41+
gemmRow(i).accMode := false.B
42+
io.matrixCVec.bits(i) := gemmRow(i).OutputPipe.bits
43+
gemmRow(i).OutputPipe.ready := io.matrixCVec.ready
44+
}
45+
46+
io.matrixCVec.valid := matrixCValid
47+
48+
}
49+
50+
class GEMMController(val x: Int, val y: Int) extends Module with llamaConfig {
51+
assert(x % (systolicGroupSize * systolicSize) == 0 && y % systolicSize == 0)
52+
}

src/test/scala/kernel/alu/AverageModuleSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class NormalizedModuleSpec extends AnyFlatSpec with ChiselScalatestTester {
158158
behavior.of("NormalizedModule")
159159

160160
it should "calculate the standard deviation correctly" in {
161-
test(new NormalizedModule(WII = 8, WIF = 8, WOI = 8, WOF = 16, ArraySize = 4))
161+
test(new NormalizedModule(WII = 4, WIF = 12, WOI = 4, WOF = 12, ArraySize = 4))
162162
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut =>
163163
// 测试数据
164164
val testData = Seq(

src/test/scala/kernel/alu/SoftmaxTest.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
1313
// val bit = 64
1414
// val dimV = 32
1515
// val depth = 128
16+
val FF = 24
1617
val annos = Seq(VerilatorBackendAnnotation)
17-
val pow2 = scala.math.pow(2, F).toFloat
18+
val pow2 = scala.math.pow(2, FF).toFloat
1819
behavior.of("tester on exp function in chisel")
1920
it should "exp in fixedpoint" in {
2021
test(new FixedPointExp)
@@ -24,7 +25,7 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
2425
dut.reset.poke(false.B)
2526
dut.clock.step()
2627
//generate a range number from -10.5 to 0.0 step 0.5
27-
val range = BigDecimal(-9.0) to BigDecimal(0.0) by BigDecimal(0.5)
28+
val range = BigDecimal(-7.0) to BigDecimal(0.0) by BigDecimal(0.5)
2829

2930
// val writer = new PrintWriter(new File("test_results.csv"))
3031

@@ -71,13 +72,13 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
7172
BigInt((d * (1L << fracBits)).round)
7273
}
7374

74-
val arraySize = 4
75+
val arraySize = 4096
7576
it should "pass softmax test" in {
7677
test(new Softmax(arraySize))
7778
.withAnnotations(annos) { dut =>
7879
val rseed = 4
7980
val rnd = new scala.util.Random(rseed)
80-
val testQ = Seq.tabulate(arraySize)(_ => rnd.nextFloat() * 10)
81+
val testQ = Seq.tabulate(arraySize)(_ => rnd.nextFloat())
8182

8283
val maxInQ = testQ.max
8384
val expInQ = testQ.map(x => scala.math.exp(x - maxInQ))
@@ -86,6 +87,7 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
8687

8788
println(s"testQ: ${testQ}")
8889
println(s"resultSoftmax: ${resultSoftmax}")
90+
println(s"maxInResult: ${resultSoftmax.max}")
8991

9092
dut.reset.poke(true.B)
9193
dut.clock.step()
@@ -101,8 +103,10 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
101103
dut.clock.step()
102104
dut.io.x.valid.poke(false.B)
103105
}.fork {
106+
var clk = 0
104107
while (!dut.io.soft_x.valid.peekBoolean()) {
105108
dut.clock.step()
109+
clk += 1
106110
}
107111
for (i <- 0 until arraySize) {
108112
val computedValue = dut.io.soft_x.bits(i).peekInt().toFloat / pow2
@@ -112,6 +116,7 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
112116
)
113117
// assert(relativeError < 5)
114118
}
119+
println(s"clk: $clk")
115120
dut.clock.step()
116121
dut.io.soft_x.valid.expect(false.B)
117122
}.join()

0 commit comments

Comments
 (0)