@@ -32,27 +32,83 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
32
32
}
33
33
}
34
34
35
- func TestTemperatureAndSoftmax (t * testing.T ) {
36
- input := []float32 {1 , 4 , - 2 , 0 }
35
+ func TestTemperature (t * testing.T ) {
36
+ input := []float32 {1.0 , 4.0 , - 2.0 , 0. 0 }
37
37
got := temperature (toTokens (input ), 0.5 )
38
+ want := []float32 {2.0 , 8.0 , - 4.0 , 0.0 }
39
+ compareLogits (t , "temperature(0.5)" , want , got )
38
40
39
- // Check probabilities sum to 1
40
- var sum float32
41
- for _ , token := range got {
42
- sum += token . value
43
- }
44
- if math . Abs ( float64 ( sum - 1.0 )) > 1e-6 {
45
- t . Errorf ( "probabilities don't sum to 1: got %f" , sum )
46
- }
41
+ got = temperature ( toTokens ( input ), 1.0 )
42
+ want = [] float32 { 1.0 , 4.0 , - 2.0 , 0.0 }
43
+ compareLogits ( t , "temperature(1)" , want , got )
44
+
45
+ got = temperature ( toTokens ( input ), 0.0 )
46
+ want = [] float32 { 1e7 , 4e7 , - 2e7 , 0.0 }
47
+ compareLogits ( t , "temperature(0)" , want , got )
48
+ }
47
49
48
- got = temperature (toTokens (input ), 1 )
49
- // Check probabilities sum to 1
50
- sum = 0.0
51
- for _ , token := range got {
52
- sum += token .value
50
+ func TestSoftmax (t * testing.T ) {
51
+ tests := []struct {
52
+ name string
53
+ input []float32
54
+ expected []float32
55
+ }{
56
+ {
57
+ name : "correctness softmax" ,
58
+ input : []float32 {1 , - 2 , 3 , 0 },
59
+ expected : []float32 {0.113550 , 0.005653 , 0.839024 , 0.041773 },
60
+ },
61
+ {
62
+ name : "normal distribution" ,
63
+ input : []float32 {0.026986899 , 0.043722924 , 0.036774673 , 0.27755088 , 0.0046718004 , 0.08582123 , 0.20409796 , 0.00412893 , 0.15720603 , 0.045046154 , 0.0030491839 , 0.01681367 },
64
+ },
65
+ {
66
+ name : "single value" ,
67
+ input : []float32 {1.0 },
68
+ },
69
+ {
70
+ name : "identical values" ,
71
+ input : []float32 {0.9 , 0.9 , 0.9 },
72
+ },
73
+ {
74
+ name : "large values" ,
75
+ input : []float32 {1000.0 , 2000.0 , 3000.0 },
76
+ },
77
+ {
78
+ name : "small values" ,
79
+ input : []float32 {1e-6 , 2e-6 , 3e-6 },
80
+ },
81
+ {
82
+ name : "negative values" ,
83
+ input : []float32 {- 1.0 , - 2.0 , - 3.0 },
84
+ },
85
+ {
86
+ name : "mixed values" ,
87
+ input : []float32 {- 100.0 , 0.0 , 100.0 },
88
+ },
53
89
}
54
- if math .Abs (float64 (sum - 1.0 )) > 1e-6 {
55
- t .Errorf ("probabilities don't sum to 1: got %f" , sum )
90
+
91
+ for _ , tt := range tests {
92
+ t .Run (tt .name , func (t * testing.T ) {
93
+ got := softmax (toTokens (tt .input ))
94
+
95
+ if tt .expected != nil {
96
+ compareLogits (t , tt .name , tt .expected , got )
97
+ return
98
+ }
99
+
100
+ // Check probabilities sum to 1
101
+ var sum float32
102
+ for _ , token := range got {
103
+ sum += token .value
104
+ if token .value < 0 || token .value > 1 {
105
+ t .Errorf ("probability out of range [0,1]: got %f" , token .value )
106
+ }
107
+ }
108
+ if math .Abs (float64 (sum - 1.0 )) > 1e-6 {
109
+ t .Errorf ("probabilities don't sum to 1: got %f" , sum )
110
+ }
111
+ })
56
112
}
57
113
}
58
114
@@ -97,7 +153,7 @@ func TestTopP(t *testing.T) {
97
153
tokens := toTokens (input )
98
154
99
155
// First apply temperature and softmax to get probabilities
100
- tokens = temperature (tokens , 1 )
156
+ tokens = softmax (tokens )
101
157
tokens = topK (tokens , 20 )
102
158
103
159
// Then apply topP
@@ -115,7 +171,7 @@ func TestMinP(t *testing.T) {
115
171
tokens := toTokens (input )
116
172
117
173
// First apply temperature and softmax
118
- tokens = temperature (tokens , 1 )
174
+ tokens = softmax (tokens )
119
175
120
176
// Then apply minP
121
177
got := minP (tokens , 0.2 )
@@ -163,6 +219,14 @@ func BenchmarkTransforms(b *testing.B) {
163
219
}
164
220
})
165
221
222
+ b .Run ("Softmax" , func (b * testing.B ) {
223
+ b .ResetTimer ()
224
+ for b .Loop () {
225
+ copy (tokensCopy , tokens )
226
+ softmax (tokensCopy )
227
+ }
228
+ })
229
+
166
230
b .Run ("TopK" , func (b * testing.B ) {
167
231
b .ResetTimer ()
168
232
for b .Loop () {
0 commit comments