Skip to content

Commit 8caa407

Browse files
committed
use isApproximatelyEqual to compare
1 parent bb000dc commit 8caa407

File tree

1 file changed

+74
-69
lines changed

1 file changed

+74
-69
lines changed

Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift

Lines changed: 74 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,224 +4,229 @@ import Testing
44

55
@Suite
66
struct TestRegisteredDerivatives {
7-
// These tests are more about the derivatives being correctly registered than the correct values. Should probably use a form of
8-
// `.isApproximatelyEqual(to:)` for the results in the future but that doesn't combine too well with a SIMD vector comparison.
7+
// These tests are more about the derivatives being correctly registered than the correct values. We're only checking the first value for simds since we run the computation on the same values and only checking the first result makes using `.isApproximatelyEqual(to:)` a lot easier.
98
@Test
109
func
1110
testExp()
1211
{
1312
let vwpb = valueWithPullback(at: 2.0, of: Float.exp)
14-
#expect(vwpb.value == 7.38905609893065)
15-
#expect(vwpb.pullback(1) == 7.38905609893065)
13+
#expect(vwpb.value.isApproximatelyEqual(to: 7.38905609893065))
14+
#expect(vwpb.pullback(1).isApproximatelyEqual(to: 7.38905609893065))
1615
}
1716

1817
@Test
1918
func testExpMinusOne() {
2019
let vwpb = valueWithPullback(at: 2.0, of: Double.expMinusOne(_:))
21-
#expect(vwpb.value == 6.38905609893065)
22-
#expect(vwpb.pullback(1) == 7.38905609893065)
20+
#expect(vwpb.value.isApproximatelyEqual(to: 6.38905609893065))
21+
#expect(vwpb.pullback(1).isApproximatelyEqual(to: 7.38905609893065))
2322
}
2423

2524
@Test
2625
func testCosh() {
2726
let vwpb = valueWithPullback(at: SIMD2<Float>(repeating: 2.0), of: SIMD2<Float>.cosh)
28-
#expect(vwpb.value == .init(repeating: 3.7621956))
29-
#expect(vwpb.pullback(.one) == .init(repeating: 3.6268604))
27+
#expect(vwpb.value[0].isApproximatelyEqual(to: 3.7621956))
28+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 3.6268604))
3029
}
3130

3231
@Test
3332
func testSinh() {
3433
let vwpb = valueWithPullback(at: SIMD4<Float>(repeating: 2.0), of: SIMD4<Float>.sinh)
35-
#expect(vwpb.value == .init(repeating: 3.6268604))
36-
#expect(vwpb.pullback(.one) == .init(repeating: 3.7621956))
34+
#expect(vwpb.value[0].isApproximatelyEqual(to: 3.6268604))
35+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 3.7621956))
3736
}
3837

3938
@Test
4039
func testTanh() {
4140
let vwpb = valueWithPullback(at: SIMD8<Float>(repeating: 2.0), of: SIMD8<Float>.tanh)
42-
#expect(vwpb.value == .init(repeating: 0.9640276))
43-
#expect(vwpb.pullback(.one) == .init(repeating: 0.07065083))
41+
#expect(vwpb.value[0].isApproximatelyEqual(to: 0.9640276))
42+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.07065083))
4443
}
4544

4645
@Test
4746
func testCos() {
4847
let vwpb = valueWithPullback(at: SIMD16<Float>(repeating: .pi / 2), of: SIMD16<Float>.cos)
49-
#expect(vwpb.value == .init(repeating: 7.54979E-08))
50-
#expect(vwpb.pullback(.one) == .init(repeating: -1))
48+
#expect(vwpb.value[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne))
49+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -1))
5150
}
5251

5352
@Test
5453
func testSin() {
5554
let vwpb = valueWithPullback(at: SIMD32<Float>(repeating: .pi / 2), of: SIMD32<Float>.sin)
56-
#expect(vwpb.value == .init(repeating: 1))
57-
#expect(vwpb.pullback(.one) == .init(repeating: 7.54979E-08))
55+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1))
56+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne))
5857
}
5958

6059
@Test
6160
func testTan() {
6261
let vwpb = valueWithPullback(at: SIMD64<Float>(repeating: .pi / 4), of: SIMD64<Float>.tan)
63-
#expect(vwpb.value == .init(repeating: 0.99999994))
64-
#expect(vwpb.pullback(.one) == .init(repeating: 1.9999998))
62+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1))
63+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 2))
6564
}
6665

6766
@Test
6867
func testLog() {
6968
let vwpb = valueWithPullback(at: SIMD2<Double>(repeating: 2), of: SIMD2<Double>.log(_:))
70-
#expect(vwpb.value == SIMD2<Double>(repeating: 0.6931471805599453))
71-
#expect(vwpb.pullback(SIMD2<Double>.one) == .init(repeating: 0.5))
69+
#expect(vwpb.value[0].isApproximatelyEqual(to: 0.6931471805599453))
70+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.5))
7271
}
7372

7473
@Test
7574
func testLogOnePlus() {
7675
let vwpb = valueWithPullback(at: SIMD4<Double>(repeating: 3), of: SIMD4<Double>.log(onePlus:))
77-
#expect(vwpb.value == .init(repeating: 1.3862943611198906))
78-
#expect(vwpb.pullback(.one) == .init(repeating: 0.25))
76+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1.3862943611198906))
77+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.25))
7978
}
8079

8180
@Test
8281
func testAcosh() {
8382
let vwpb = valueWithPullback(at: SIMD8<Double>(repeating: 2), of: SIMD8<Double>.acosh)
84-
#expect(vwpb.value == .init(repeating: 1.3169578969248166))
85-
#expect(vwpb.pullback(.one) == .init(repeating: 1 / .sqrt(3)))
83+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1.3169578969248166))
84+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .sqrt(3)))
8685
}
8786

8887
@Test
8988
func testAsinh() {
9089
let vwpb = valueWithPullback(at: SIMD16<Double>(repeating: 2), of: SIMD16<Double>.asinh)
91-
#expect(vwpb.value == .init(repeating: 1.4436354751788103))
92-
#expect(vwpb.pullback(.one) == .init(repeating: 1 / .sqrt(5)))
90+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1.4436354751788103))
91+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .sqrt(5)))
9392
}
9493

9594
@Test
9695
func testAtanh() {
9796
let vwpb = valueWithPullback(at: SIMD32<Double>(repeating: 0.5), of: SIMD32<Double>.atanh)
98-
#expect(vwpb.value == .init(repeating: 0.5493061443340549))
99-
#expect(vwpb.pullback(.one) == .init(repeating: 4 / 3))
97+
#expect(vwpb.value[0].isApproximatelyEqual(to: 0.5493061443340549))
98+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 4 / 3))
10099
}
101100

102101
@Test
103-
func testaCos() {
104-
let vwpb = valueWithPullback(at: SIMD64<Double>(repeating: 0.5), of: SIMD64<Double>.acos)
105-
#expect(vwpb.value == .init(repeating: 1.0471975511965976))
106-
#expect(vwpb.pullback(.one) == .init(repeating: -1.1547005383792517))
102+
func testAcos() {
103+
let vwpb = valueWithPullback(at: SIMD64<Double>(repeating: 1 / .sqrt(2)), of: SIMD64<Double>.acos)
104+
#expect(vwpb.value[0].isApproximatelyEqual(to: .pi / 4))
105+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -.sqrt(2)))
107106
}
108107

109108
@Test
110-
func testaSin() {
111-
let vwpb = valueWithPullback(at: 0.5, of: Float.asin)
112-
#expect(vwpb.value == 0.5235988)
113-
#expect(vwpb.pullback(1) == 1.1547005383792517)
109+
func testAsin() {
110+
let vwpb = valueWithPullback(at: 1 / .sqrt(2), of: Float.asin)
111+
#expect(vwpb.value.isApproximatelyEqual(to: .pi / 4))
112+
#expect(vwpb.pullback(1).isApproximatelyEqual(to: .sqrt(2)))
114113
}
115114

116115
@Test
117-
func testaTan() {
116+
func testAtan() {
118117
let vwpb = valueWithPullback(at: 0.5, of: Double.atan)
119-
#expect(vwpb.value == 0.46364760900080615)
120-
#expect(vwpb.pullback(1) == 0.8)
118+
#expect(vwpb.value.isApproximatelyEqual(to: 0.46364760900080615))
119+
#expect(vwpb.pullback(1).isApproximatelyEqual(to: 0.8))
121120
}
122121

123122
@Test
124123
func testPow() {
125124
let vwpb = valueWithPullback(at: SIMD2<Float>(repeating: 0.5), SIMD2<Float>(repeating: 2), of: SIMD2<Float>.pow(_:_:))
126-
#expect(vwpb.value == .init(repeating: 0.25))
127-
#expect(vwpb.pullback(.one) == (.init(repeating: 1.0), .init(repeating: -0.1732868)))
125+
#expect(vwpb.value[0].isApproximatelyEqual(to: 0.25))
126+
let gradient = vwpb.pullback(.one)
127+
#expect(gradient.0[0].isApproximatelyEqual(to: 1.0))
128+
#expect(gradient.1[0].isApproximatelyEqual(to: -0.1732868))
128129
}
129130

130131
@Test
131132
func testPowInt() {
132133
let vwpb = valueWithPullback(at: SIMD4<Float>(repeating: 0.5), of: { x in SIMD4<Float>.pow(x, 2) })
133-
#expect(vwpb.value == .init(repeating: 0.25))
134-
#expect(vwpb.pullback(.one) == .init(repeating: 1.0))
134+
#expect(vwpb.value[0].isApproximatelyEqual(to: 0.25))
135+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1))
135136
}
136137

137138
@Test
138139
func testSqrt() {
139140
let vwpb = valueWithPullback(at: SIMD8<Float>(repeating: 4), of: SIMD8<Float>.sqrt)
140-
#expect(vwpb.value == .init(repeating: 2))
141-
#expect(vwpb.pullback(.one) == .init(repeating: 0.25))
141+
#expect(vwpb.value[0].isApproximatelyEqual(to: 2))
142+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.25))
142143
}
143144

144145
@Test
145146
func testRoot() {
146147
let vwpb = valueWithPullback(at: SIMD16<Float>(repeating: 16), of: { x in SIMD16<Float>.root(x, 4) })
147-
#expect(vwpb.value == .init(repeating: 2))
148-
#expect(vwpb.pullback(.one) == .init(repeating: 1 / 32))
148+
#expect(vwpb.value[0].isApproximatelyEqual(to: 2))
149+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / 32))
149150
}
150151

151152
@Test
152153
func testAtan2() {
153154
let vwpb = valueWithPullback(at: SIMD32<Float>(repeating: 1), SIMD32<Float>(repeating: 0), of: SIMD32<Float>.atan2)
154-
#expect(vwpb.value == .init(repeating: 1.5707964)) // .pi / 2
155-
#expect(vwpb.pullback(.one) == (.init(repeating: 0), .init(repeating: -1)))
155+
#expect(vwpb.value[0].isApproximatelyEqual(to: .pi / 2))
156+
let gradient = vwpb.pullback(.one)
157+
#expect(gradient.0[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne))
158+
#expect(gradient.1[0].isApproximatelyEqual(to: -1))
156159
}
157160

158161
@Test
159162
func testErf() {
160163
let vwpb = valueWithPullback(at: SIMD64<Float>(repeating: 0.5), of: SIMD64<Float>.erf)
161-
#expect(vwpb.value == .init(repeating: 0.5204999))
162-
#expect(vwpb.pullback(.one) == .init(repeating: 0.87878263))
164+
#expect(vwpb.value[0].isApproximatelyEqual(to: 0.5204998778))
165+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.8787825789354449))
163166
}
164167

165168
@Test
166169
func testErfc() {
167170
let vwpb = valueWithPullback(at: SIMD2<Double>(repeating: 0.5), of: SIMD2<Double>.erfc)
168-
#expect(vwpb.value == .init(repeating: 0.4795001221869535))
169-
#expect(vwpb.pullback(.one) == .init(repeating: -0.8787825789354449))
171+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1 - 0.5204998778))
172+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -0.8787825789354449))
170173
}
171174

172175
@Test
173176
func testExp2() {
174177
let vwpb = valueWithPullback(at: SIMD4<Double>(repeating: 2), of: SIMD4<Double>.exp2)
175-
#expect(vwpb.value == .init(repeating: 4))
176-
#expect(vwpb.pullback(.one) == .init(repeating: 4 * .log(2)))
178+
#expect(vwpb.value[0].isApproximatelyEqual(to: 4))
179+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 4 * .log(2)))
177180
}
178181

179182
@Test
180183
func testExp10() {
181184
let vwpb = valueWithPullback(at: SIMD8<Double>(repeating: 2), of: SIMD8<Double>.exp10)
182-
#expect(vwpb.value == .init(repeating: 100))
183-
#expect(vwpb.pullback(.one) == .init(repeating: 100 * .log(10)))
185+
#expect(vwpb.value[0].isApproximatelyEqual(to: 100))
186+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 100 * .log(10)))
184187
}
185188

186189
@Test
187190
func testHypot() {
188191
let vwpb = valueWithPullback(at: SIMD16<Double>(repeating: 3), SIMD16<Double>(repeating: 4), of: SIMD16<Double>.hypot)
189-
#expect(vwpb.value == .init(repeating: 5))
190-
#expect(vwpb.pullback(.one) == (.init(repeating: 3 / 5), .init(repeating: 4 / 5)))
192+
#expect(vwpb.value[0].isApproximatelyEqual(to: 5))
193+
let gradient = vwpb.pullback(.one)
194+
#expect(gradient.0[0].isApproximatelyEqual(to: 3 / 5))
195+
#expect(gradient.1[0].isApproximatelyEqual(to: 4 / 5))
191196
}
192197

193198
@Test(.disabled("derivative not implemented"))
194199
func testGamma() {
195200
let vwpb = valueWithPullback(at: SIMD32<Double>(repeating: 2), of: SIMD32<Double>.gamma)
196-
#expect(vwpb.value == .init(repeating: 1))
197-
#expect(vwpb.pullback(.one) == .init(repeating: 0))
201+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1))
202+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne))
198203
}
199204

200205
@Test
201206
func testLog2() {
202207
let vwpb = valueWithPullback(at: SIMD64<Double>(repeating: 2), of: SIMD64<Double>.log2)
203-
#expect(vwpb.value == .init(repeating: 1))
204-
#expect(vwpb.pullback(.one) == .init(repeating: 1 / .log(4)))
208+
#expect(vwpb.value[0].isApproximatelyEqual(to: 1))
209+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .log(4)))
205210
}
206211

207212
@Test
208213
func testLog10() {
209214
let vwpb = valueWithPullback(at: 2.0, of: Float.log10)
210-
#expect(vwpb.value == 0.30103) // .log(2) / .log(10)
211-
#expect(vwpb.pullback(1) == 1 / .log(100))
215+
#expect(vwpb.value.isApproximatelyEqual(to: .log(2) / .log(10)))
216+
#expect(vwpb.pullback(1).isApproximatelyEqual(to: 1 / .log(100)))
212217
}
213218

214219
@Test(.disabled("derivative not implemented"))
215220
func testLogGamma() {
216221
let vwpb = valueWithPullback(at: 2, of: Double.logGamma)
217-
#expect(vwpb.value == 0)
218-
#expect(vwpb.pullback(1) == 0)
222+
#expect(vwpb.value.isApproximatelyEqual(to: 0))
223+
#expect(vwpb.pullback(1).isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne))
219224
}
220225

221226
@Test
222227
func testAbs() {
223228
let vwpb = valueWithPullback(at: SIMD2<Float>(repeating: -2), of: SIMD2<Float>.abs)
224-
#expect(vwpb.value == .init(repeating: 2))
225-
#expect(vwpb.pullback(.one) == .init(repeating: -1))
229+
#expect(vwpb.value[0].isApproximatelyEqual(to: 2))
230+
#expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -1))
226231
}
227232
}

0 commit comments

Comments
 (0)