Skip to content

Commit f375f01

Browse files
shrink params
1 parent de76334 commit f375f01

File tree

3 files changed

+134
-35
lines changed

3 files changed

+134
-35
lines changed

src/main/scala/kernel/alu/Softmax.scala

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import chisel3.util._
55
import kernel.utils.DebugLog
66
import fixedpoint._
77
import kernel.utils.PipeValue
8+
import math.pow
89
import kernel.utils.common
910

1011
trait SoftmaxAccuracy {
@@ -76,4 +77,96 @@ class Softmax(val arraySize: Int = 4) extends Module with SoftmaxAccuracy with D
7677
}
7778

7879
io.soft_x <> Pipe(softTmp.map(_.valid).reduce(_ & _), VecInit(softTmp.map(_.bits)), 0)
79-
}
80+
}
81+
82+
// from Sanger, need test
83+
class ExpUnitFixPoint(width: Int, point: Int, lut_bits: Int, append_bits: Int) extends Module {
84+
val v_width = width + append_bits
85+
val v_point = point + append_bits
86+
val fpType = FixedPoint(width.W, point.BP)
87+
val vType = FixedPoint(v_width.W, v_point.BP)
88+
val io = IO(new Bundle {
89+
val in_value = Input(fpType)
90+
val out_exp = Output(fpType)
91+
})
92+
93+
val x = Wire(UInt(width.W))
94+
val y = Wire(UInt(v_width.W))
95+
val z1 = Wire(vType)
96+
val z2 = Wire(vType)
97+
98+
val s = Reg(fpType)
99+
100+
val u = Wire(UInt((width - point).W))
101+
val v = Wire(vType)
102+
103+
val testers =
104+
Range.BigDecimal(0.0, 1.0, pow(2.0, -point)).map((a) => pow(2.0, a.toDouble) - a)
105+
val d_value =
106+
(testers.reduce((a, b) => if (a > b) a else b) +
107+
testers.reduce((a, b) => if (a < b) a else b)) / 2.0
108+
109+
val d_fixed = FixedPoint.fromBigDecimal(d_value, v_width.W, v_point.BP)
110+
val d_wire = Wire(vType)
111+
if (lut_bits == 0)
112+
d_wire := d_fixed
113+
else {
114+
val lut_in = Range(0, 1 << lut_bits)
115+
val lut_out =
116+
lut_in
117+
.map((x) => x / pow(2.0, lut_bits))
118+
.map((x) => {
119+
val r = Range
120+
.BigDecimal(x, x + pow(2.0, -lut_bits), pow(2.0, -lut_bits))
121+
.map((y) => pow(2.0, y.toDouble) - y)
122+
(r.reduce((a, b) => if (a > b) a else b) +
123+
r.reduce((a, b) => if (a < b) a else b)) / 2.0
124+
})
125+
.map((x) =>
126+
FixedPoint
127+
.fromBigDecimal(x, v_width.W, v_point.BP)
128+
)
129+
// val lut_mem = Mem(lut_in.length, vType)
130+
// for (i <- 0 until lut_out.length)
131+
// lut_mem(i.U) := lut_out(i)
132+
133+
val v_bits = Wire(UInt(lut_bits.W))
134+
v_bits := v.asUInt(v_point - 1, v_point - lut_bits)
135+
136+
var w = when(v_bits === lut_in(0).U) {
137+
d_wire := lut_out(0)
138+
}
139+
for (i <- 1 until lut_in.size)
140+
w = w.elsewhen(v_bits === lut_in(i).U) {
141+
d_wire := lut_out(i)
142+
}
143+
w.otherwise {
144+
d_wire := DontCare
145+
}
146+
// d_wire := lut_mem(v_bits)
147+
}
148+
// println(d_fixed)
149+
150+
x := io.in_value.asUInt
151+
y := (x << append_bits) + (x << (append_bits - 1)) - (x << (append_bits - 4));
152+
153+
u := y(v_width - 1, v_point)
154+
v := Cat(0.U((v_width - v_point).W), y(v_point - 1, 0))
155+
.asFixedPoint(v_point.BP)
156+
157+
z1 := v + d_wire
158+
z2 := z1 << u;
159+
160+
// printf(
161+
// "x:%b y:%b u:%b v:%b d:%b z1:%b z2:%b\n",
162+
// x,
163+
// y,
164+
// u,
165+
// v.asUInt(),
166+
// d_wire.asUInt(),
167+
// z1.asUInt(),
168+
// z2.asUInt()
169+
// )
170+
171+
io.out_exp := z2
172+
}

src/main/scala/models/llama3/common/llamaConfig.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ trait llamaConfig {
55
val n_layers = 32
66
val n_heads: Int = 32
77

8+
val m = 16
9+
val p = 8
10+
val q = 24
11+
812
// head_dim is the dimension of each head
913
val head_dim: Int = dim / n_heads
1014
val maxN: Int = 8 * 1024
@@ -18,8 +22,10 @@ trait llamaConfig {
1822
val bits = 16
1923

2024
// systolic array size
21-
val systolicSize = 16
22-
val systolicGroupSize = 1
25+
val systolicSizeGen = 4
26+
val systolicGroupSizeGen = 1
27+
val systolicSizeMul = 4
28+
val systolicGroupSizeMul = 1
2329

2430
// DAC for zb, stream for heads
2531
val stream_size = 8

src/main/scala/models/llama3/metrixController.scala

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ import kernel.alu.DataWidthConfig
99
import kernel.utils.DebugLog
1010
class metrixController extends Module with llamaConfig {}
1111

12+
/*
13+
* current systolic group idx
14+
* @param nk: systolic group dim
15+
* @param m: left matrix rows
16+
* @param q: right matrix rows
17+
*/
1218
class currentSystolicGroupIdx(
13-
val nk: Int,
14-
val m: Int,
15-
val p: Int,
16-
val q: Int
1719
)(
1820
implicit config: DataWidthConfig)
1921
extends Bundle
2022
with llamaConfig {
21-
23+
val nk: Int = systolicSizeGen * systolicGroupSizeGen
2224
val row = Output(UInt(log2Ceil(m / nk).W))
2325
val col = Output(UInt(log2Ceil(q / nk).W))
2426
val value = Output(UInt((nk * nk * config.inputWidth).W))
@@ -95,7 +97,7 @@ class MatrixRestore(
9597
val numBlocksRow = m / nk
9698
val numBlocksCol = p / nk
9799

98-
// 初始化输出矩阵
100+
// initialize the output matrix
99101
io.outMatrix.foreach(_ := 0.U)
100102

101103
for (blockRow <- 0 until numBlocksRow) {
@@ -163,27 +165,22 @@ object BlockMatrixRestore {
163165
* designed for QKV generation, but has a output for current systolic group idx
164166
*/
165167
class GenerationMatrixMul(
166-
val k: Int,
167-
val n: Int,
168-
val m: Int,
169-
val p: Int,
170-
val q: Int,
171168
val gemmType: GEMMDataType.Type
172169
)(
173170
implicit config: DataWidthConfig)
174171
extends Module
175172
with llamaConfig
176173
with DebugLog {
177174
// param check
178-
implicit val nk: Int = k * n
175+
implicit val nk: Int = systolicSizeGen * systolicGroupSizeGen
179176
require(m % nk == 0)
180177
require(p % nk == 0)
181178
require(q % nk == 0)
182179

183180
val io = IO(new Bundle {
184181
val in_a = Flipped(Decoupled(Vec(m * p, UInt(config.inputWidth.W))))
185182
val in_b = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
186-
val current = ValidIO(new currentSystolicGroupIdx(nk, m, p, q))
183+
val current = ValidIO(new currentSystolicGroupIdx)
187184
val reset = Input(Bool())
188185
})
189186

@@ -312,34 +309,22 @@ class GenerationMatrixMul(
312309
* the k1,n1 are for q,k generation, the k2,n2 are for q,k mul.
313310
*/
314311
class QKMul(
315-
val k1: Int,
316-
val n1: Int,
317-
val k2: Int,
318-
val n2: Int,
319-
val m: Int,
320-
val p: Int,
321-
val q: Int,
322312
val gemmType: GEMMDataType.Type
323313
)(
324314
implicit config: DataWidthConfig)
325315
extends Module
326316
with llamaConfig
327317
with DebugLog {
328318

329-
val nk1: Int = k1 * n1
330-
val nk2: Int = k2 * n2
319+
val nk1: Int = systolicSizeGen * systolicGroupSizeGen
320+
val nk2: Int = systolicSizeMul * systolicGroupSizeMul
331321
require(m % nk1 == 0)
332322
require(p % nk1 == 0)
333323
require(q % nk1 == 0)
334324
require(m % nk2 == 0)
335325
require(q % nk2 == 0)
336326

337327
class QKGenerationMatrixMulWarper(
338-
val k: Int,
339-
val n: Int,
340-
val m: Int,
341-
val p: Int,
342-
val q: Int,
343328
val gemmType: GEMMDataType.Type,
344329
val bufferSize: Int
345330
)(
@@ -351,16 +336,16 @@ class QKMul(
351336
val in_a = Flipped(Decoupled(Vec(m * p, UInt(config.inputWidth.W))))
352337
val in_b = Flipped(Decoupled(Vec(p * q, UInt(config.inputWidth.W))))
353338
val flush = Input(Bool())
354-
val outMatrix = Decoupled(new currentSystolicGroupIdx(nk1, m, p, q))
339+
val outMatrix = Decoupled(new currentSystolicGroupIdx)
355340
})
356341

357-
val qkGenMul = Module(new GenerationMatrixMul(k1, n1, m, p, q, gemmType))
342+
val qkGenMul = Module(new GenerationMatrixMul(gemmType))
358343
io.in_a <> qkGenMul.io.in_a
359344
io.in_b <> qkGenMul.io.in_b
360345

361346
val currentBuffer = Module(
362347
new Queue(
363-
new currentSystolicGroupIdx(nk1, m, p, q),
348+
new currentSystolicGroupIdx,
364349
entries = bufferSize,
365350
pipe = true,
366351
flow = false,
@@ -388,8 +373,8 @@ class QKMul(
388373
val resetBuffer = Input(Bool())
389374
})
390375

391-
val qGen = new QKGenerationMatrixMulWarper(k1, n1, m, p, q, gemmType, bufferSizeGemm)
392-
val kGen = new QKGenerationMatrixMulWarper(k2, n2, m, p, q, gemmType, bufferSizeGemm)
376+
val qGen = new QKGenerationMatrixMulWarper(gemmType, bufferSizeGemm)
377+
val kGen = new QKGenerationMatrixMulWarper(gemmType, bufferSizeGemm)
393378

394379
qGen.io.in_a <> io.inputToken
395380
qGen.io.in_b <> io.weightQ
@@ -413,3 +398,18 @@ class QKMul(
413398
}
414399

415400
}
401+
402+
class GemmPool(
403+
val n: Int,
404+
val poolSize: Int,
405+
val gemmType: GEMMDataType.Type
406+
)(
407+
implicit config: DataWidthConfig)
408+
extends Module
409+
with llamaConfig
410+
with DebugLog {
411+
val io = IO(new Bundle {
412+
val in = Flipped(Decoupled(new currentSystolicGroupIdx))
413+
})
414+
415+
}

0 commit comments

Comments
 (0)