@@ -13,8 +13,9 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
1313// val bit = 64
1414// val dimV = 32
1515// val depth = 128
16+ val FF = 24
1617 val annos = Seq (VerilatorBackendAnnotation )
17- val pow2 = scala.math.pow(2 , F ).toFloat
18+ val pow2 = scala.math.pow(2 , FF ).toFloat
1819 behavior.of(" tester on exp function in chisel" )
1920 it should " exp in fixedpoint" in {
2021 test(new FixedPointExp )
@@ -24,7 +25,7 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
2425 dut.reset.poke(false .B )
2526 dut.clock.step()
2627 // generate a range number from -10.5 to 0.0 step 0.5
27- val range = BigDecimal (- 9 .0 ) to BigDecimal (0.0 ) by BigDecimal (0.5 )
28+ val range = BigDecimal (- 7 .0 ) to BigDecimal (0.0 ) by BigDecimal (0.5 )
2829
2930 // val writer = new PrintWriter(new File("test_results.csv"))
3031
@@ -71,13 +72,13 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
7172 BigInt ((d * (1L << fracBits)).round)
7273 }
7374
74- val arraySize = 4
75+ val arraySize = 4096
7576 it should " pass softmax test" in {
7677 test(new Softmax (arraySize))
7778 .withAnnotations(annos) { dut =>
7879 val rseed = 4
7980 val rnd = new scala.util.Random (rseed)
80- val testQ = Seq .tabulate(arraySize)(_ => rnd.nextFloat() * 10 )
81+ val testQ = Seq .tabulate(arraySize)(_ => rnd.nextFloat())
8182
8283 val maxInQ = testQ.max
8384 val expInQ = testQ.map(x => scala.math.exp(x - maxInQ))
@@ -86,6 +87,7 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
8687
8788 println(s " testQ: ${testQ}" )
8889 println(s " resultSoftmax: ${resultSoftmax}" )
90+ println(s " maxInResult: ${resultSoftmax.max}" )
8991
9092 dut.reset.poke(true .B )
9193 dut.clock.step()
@@ -101,8 +103,10 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
101103 dut.clock.step()
102104 dut.io.x.valid.poke(false .B )
103105 }.fork {
106+ var clk = 0
104107 while (! dut.io.soft_x.valid.peekBoolean()) {
105108 dut.clock.step()
109+ clk += 1
106110 }
107111 for (i <- 0 until arraySize) {
108112 val computedValue = dut.io.soft_x.bits(i).peekInt().toFloat / pow2
@@ -112,6 +116,7 @@ class SoftmaxTest extends AnyFlatSpec with SoftmaxAccuracy with ChiselScalatestT
112116 )
113117 // assert(relativeError < 5)
114118 }
119+ println(s " clk: $clk" )
115120 dut.clock.step()
116121 dut.io.soft_x.valid.expect(false .B )
117122 }.join()
0 commit comments