@@ -8,18 +8,16 @@ import java.io._
88import kernel .configs .SdpmmConfigs
99import os .write
1010
11- class SoftmaxTest extends AnyFlatSpec with ChiselScalatestTester {
11+ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestTester {
1212
1313// val bit = 64
1414// val dimV = 32
1515// val depth = 128
1616 val annos = Seq (VerilatorBackendAnnotation )
17- val wholeWidth : Int = SdpmmConfigs .bit + SdpmmConfigs .fixedPoint
18- val fractionalWidth : Int = SdpmmConfigs .fixedPoint
19- val pow2 = scala.math.pow(2 , fractionalWidth).toFloat
17+ val pow2 = scala.math.pow(2 , F ).toFloat
2018 behavior.of(" tester on exp function in chisel" )
2119 it should " exp in fixedpoint" in {
22- test(new FixedPointExp (wholeWidth, fractionalWidth) )
20+ test(new FixedPointExp )
2321 .withAnnotations(annos) { dut =>
2422 dut.reset.poke(true .B )
2523 dut.clock.step()
@@ -37,7 +35,7 @@ class SoftmaxTest extends AnyFlatSpec with ChiselScalatestTester {
3735 fork {
3836 for (value <- range) {
3937 dut.io.x.valid.poke(true .B )
40- dut.io.x.bits.poke(value.F (fractionalWidth .BP ).asSInt)
38+ dut.io.x.bits.poke(value.F (F .BP ).asSInt)
4139 dut.clock.step()
4240 }
4341
@@ -50,9 +48,9 @@ class SoftmaxTest extends AnyFlatSpec with ChiselScalatestTester {
5048 val computedValue = dut.io.exp_x.bits.peekInt().toFloat / pow2
5149 val relativeError = ((actualValue - computedValue) / actualValue).abs * 100
5250
53- println(
54- s " actualValue is $actualValue, \t computedValue is $computedValue, \t relativeError is $relativeError"
55- )
51+ // println(
52+ // s"actualValue is $actualValue,\t computedValue is $computedValue,\t relativeError is $relativeError"
53+ // )
5654 assert(relativeError < 5 )
5755
5856 dut.clock.step()
@@ -63,76 +61,61 @@ class SoftmaxTest extends AnyFlatSpec with ChiselScalatestTester {
6361 // writer.close()
6462 }
6563 }
64+ def doubleToFixedPoint (d : Float , intBits : Int , fracBits : Int ): BigInt = {
65+ // 检查数值范围
66+ val maxVal = Math .pow(2 , intBits - 1 ) - Math .pow(2 , - fracBits)
67+ val minVal = - Math .pow(2 , intBits - 1 )
68+ require(d <= maxVal && d >= minVal, s " Value $d out of range [ $minVal, $maxVal] " )
69+
70+ // 转换为定点数表示
71+ BigInt ((d * (1L << fracBits)).round)
72+ }
73+
74+ val arraySize = 4
75+ it should " pass softmax test" in {
76+ test(new Softmax (arraySize))
77+ .withAnnotations(annos) { dut =>
78+ val rseed = 4
79+ val rnd = new scala.util.Random (rseed)
80+ val testQ = Seq .tabulate(arraySize)(_ => rnd.nextFloat() * 10 )
81+
82+ val maxInQ = testQ.max
83+ val expInQ = testQ.map(x => scala.math.exp(x - maxInQ))
84+ val sumExpInQ = expInQ.sum
85+ val resultSoftmax = expInQ.map(_ / sumExpInQ)
86+
87+ println(s " testQ: ${testQ}" )
88+ println(s " resultSoftmax: ${resultSoftmax}" )
6689
67- // it should "softmax in chisel3" in {
68- // test(new Softmax)
69- // .withAnnotations(annos) { dut =>
70- // val numOfMask = SdpmmConfigs.numOfMask
71- // val testQ = Seq.tabulate(SdpmmConfigs.dim)(x => scala.util.Random.nextInt(10) + 1)
72- // val inputTimes = 1
73-
74- // val pow2 = scala.math.pow(2, SdpmmConfigs.bit - 1)
75-
76- // val mask = for (i <- 0 until inputTimes) yield {
77- // Seq.fill(2 * numOfMask)(scala.util.Random.nextInt(SdpmmConfigs.L)).distinct.take(numOfMask)
78- // }
79- // var resultSoftmax = Seq.tabulate(SdpmmConfigs.dim) { i =>
80- // val exp = scala.math.exp(testQ(i))
81- // exp / (testQ.map(x => scala.math.exp(x)).sum)
82- // }
83- // println(testQ)
84- // println(mask)
85- // val writer = new PrintWriter(new File("softmax_test_results.csv"))
86-
87- // writer.write("Input Value,Computed softmax,Actual softmax,Relative Error (%)\n")
88-
89- // dut.reset.poke(true.B)
90- // dut.clock.step()
91- // dut.reset.poke(false.B)
92- // dut.clock.step()
93- // fork {
94- // var cnt = 0
95- // while (cnt < inputTimes) {
96- // if (dut.InputPipe.ready.peekBoolean()) {
97- // dut.InputPipe.valid.poke(true.B)
98- // for (i <- 0 until numOfMask) {
99- // dut.InputPipe.bits.mask(i).poke(mask(cnt)(i).U)
100- // }
101-
102- // for (i <- 0 until SdpmmConfigs.dim) {
103- // dut.InputPipe.bits.value(i).poke(testQ(i).U)
104- // }
105- // cnt = cnt + 1
106- // } else {
107- // dut.InputPipe.valid.poke(false.B)
108- // }
109- // dut.clock.step()
110- // }
111-
112- // dut.InputPipe.valid.poke(false.B)
113- // }.fork {
114- // var cntR = 0
115- // while (cntR < inputTimes) {
116- // if (dut.OutputPipe.valid.peekBoolean()) {
117- // for (i <- 0 until numOfMask) {
118- // dut.OutputPipe.bits.mask(i).expect(mask(cntR)(i))
119- // }
120- // for (i <- 0 until SdpmmConfigs.dim) {
121- // val com = dut.OutputPipe.bits.value(i).peek().litValue.toFloat / pow2
122- // val act = resultSoftmax(i)
123- // val relativeError = ((act - com) / act).abs * 100
124- // writer.write(f"${testQ(i)}%.2f,$com%.5f,$act%.5f,$relativeError%.2f\n")
125- // }
126- // dut.OutputPipe.ready.poke(true.B)
127- // cntR = cntR + 1
128- // } else {
129- // dut.OutputPipe.ready.poke(false.B)
130- // }
131- // dut.clock.step()
132- // }
133- // dut.OutputPipe.ready.poke(false.B)
134- // }.join()
135- // writer.close()
136- // }
137- // }
90+ dut.reset.poke(true .B )
91+ dut.clock.step()
92+ dut.reset.poke(false .B )
93+ dut.clock.step()
94+
95+ fork {
96+ dut.io.x.valid.poke(true .B )
97+ for (i <- 0 until arraySize) {
98+ println(s " testQ( $i): ${doubleToFixedPoint(testQ(i), I , F )}" )
99+ dut.io.x.bits(i).poke(doubleToFixedPoint(testQ(i), I , F ).U )
100+ }
101+ dut.clock.step()
102+ dut.io.x.valid.poke(false .B )
103+ }.fork {
104+ while (! dut.io.soft_x.valid.peekBoolean()) {
105+ dut.clock.step()
106+ }
107+ for (i <- 0 until arraySize) {
108+ val computedValue = dut.io.soft_x.bits(i).peekInt().toFloat / pow2
109+ val relativeError = ((resultSoftmax(i) - computedValue) / resultSoftmax(i)).abs * 100
110+ println(
111+ s " actualValue is ${resultSoftmax(i)}, \t computedValue is $computedValue, \t relativeError is $relativeError"
112+ )
113+ // assert(relativeError < 5)
114+ }
115+ dut.clock.step()
116+ dut.io.soft_x.valid.expect(false .B )
117+ }.join()
118+
119+ }
120+ }
138121}
0 commit comments