Skip to content

Commit 9a5e216

Browse files
author
Jaap Wijnen
committed
test: convert to swift testing
1 parent d14f396 commit 9a5e216

File tree

4 files changed

+160
-6
lines changed

4 files changed

+160
-6
lines changed

Tests/OrderedCollectionsDifferentiableTests/OrderedCollectionsDifferentiableTests.swift

Lines changed: 0 additions & 6 deletions
This file was deleted.
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#if canImport(_Differentiation)
2+
3+
import OrderedCollectionsDifferentiable
4+
import Testing
5+
6+
@Suite("Dictionary+Differentiation")
7+
struct DictionaryDifferentiationTests {
8+
@Test
9+
func testSubscriptGet() throws {
10+
let dictionary: OrderedDictionary<String, Double> = ["a": 3, "b": 7]
11+
12+
let aMultiplier: Double = 13
13+
let bMultiplier: Double = 17
14+
15+
func readFromDictionary(d: OrderedDictionary<String, Double>) -> Double {
16+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
17+
// swift-format-ignore: NeverForceUnwrap
18+
let a = d["a"]! * aMultiplier
19+
let b = d["b"]! * bMultiplier
20+
return a + b
21+
}
22+
23+
let vwg = valueWithGradient(at: dictionary, of: readFromDictionary)
24+
25+
#expect(vwg.value == 3 * aMultiplier + 7 * bMultiplier)
26+
#expect(vwg.gradient == ["a": aMultiplier, "b": bMultiplier])
27+
}
28+
29+
@Test
30+
func testOrderedDictionaryReadAndCombineValues() {
31+
@differentiable(reverse)
32+
func testFunction(newValues: OrderedDictionary<String, Double>) -> Double {
33+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
34+
// swift-format-ignore: NeverForceUnwrap
35+
1.0 * newValues["s1"]! + 2.0 * newValues["s2"]! + 3.0 * newValues["s3"]!
36+
}
37+
38+
let vwg = valueWithGradient(
39+
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
40+
of: testFunction
41+
)
42+
43+
#expect(vwg.value == 140.0)
44+
#expect(vwg.gradient == ["s1": 1.0, "s2": 2.0, "s3": 3.0])
45+
}
46+
47+
48+
49+
@Test
50+
func testOrderedDictionaryInoutWriteMethod() {
51+
@differentiable(reverse)
52+
func combineByReplacingDictionaryValues(of mainDict: inout OrderedDictionary<String, Double>, with otherDict: OrderedDictionary<String, Double>) {
53+
for key in withoutDerivative(at: otherDict.keys) {
54+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
55+
// swift-format-ignore: NeverForceUnwrap
56+
let otherValue = otherDict[key]!
57+
mainDict.update(at: key, with: otherValue)
58+
}
59+
}
60+
61+
@differentiable(reverse)
62+
func inoutWrapper(dictionary: OrderedDictionary<String, Double>, otherDictionary: OrderedDictionary<String, Double>) -> OrderedDictionary<String, Double> {
63+
// we wrap the `combineByReplacingDictionaryValues`
64+
var mainCopy = dictionary
65+
combineByReplacingDictionaryValues(of: &mainCopy, with: otherDictionary)
66+
return mainCopy
67+
}
68+
69+
let vwpb = valueWithPullback(
70+
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
71+
["s1": 2.0], //, "s2": nil, "s3": nil],
72+
of: inoutWrapper
73+
)
74+
75+
#expect(vwpb.value == ["s1": 2.0, "s2": 20.0, "s3": 30.0])
76+
// we need to provide a full tangentvector to the pullback hence the keys with zero entries.
77+
#expect(vwpb.pullback(["s1": 1.0, "s2": 0.0, "s3": 0.0]) == (["s1": 0.0, "s2": 0.0, "s3": 0.0], ["s1": 1.0]))
78+
#expect(vwpb.pullback(["s1": 0.0, "s2": 1.0, "s3": 0.0]) == (["s1": 0.0, "s2": 1.0, "s3": 0.0], ["s1": 0.0]))
79+
#expect(vwpb.pullback(["s1": 0.0, "s2": 0.0, "s3": 1.0]) == (["s1": 0.0, "s2": 0.0, "s3": 1.0], ["s1": 0.0]))
80+
}
81+
82+
@Test
83+
func testInoutWriteAndSumValues() {
84+
@differentiable(reverse)
85+
func combineByReplacingDictionaryValues(of mainDict: inout OrderedDictionary<String, Double>, with otherDict: OrderedDictionary<String, Double>) {
86+
for key in withoutDerivative(at: otherDict.keys) {
87+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
88+
// swift-format-ignore: NeverForceUnwrap
89+
let otherValue = otherDict[key]!
90+
mainDict.update(at: key, with: otherValue)
91+
}
92+
}
93+
94+
@differentiable(reverse)
95+
func sumValues(of dictionary: OrderedDictionary<String, Double>) -> Double {
96+
var sum: Double = 0.0
97+
for key in withoutDerivative(at: dictionary.keys) {
98+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
99+
// swift-format-ignore: NeverForceUnwrap
100+
sum += dictionary[key]!
101+
}
102+
return sum
103+
}
104+
@differentiable(reverse,wrt: dictionary)
105+
106+
func inoutWrapperAndSum(dictionary: OrderedDictionary<String, Double>, otherDictionary: OrderedDictionary<String, Double>) -> Double {
107+
var mainCopy = dictionary
108+
combineByReplacingDictionaryValues(of: &mainCopy, with: otherDictionary)
109+
return sumValues(of: mainCopy)
110+
}
111+
112+
let vwg = valueWithGradient(
113+
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
114+
["s1": 2.0], //, "s2": nil, "s3": nil],
115+
of: inoutWrapperAndSum
116+
)
117+
118+
#expect(vwg.value == 52.0)
119+
#expect(vwg.gradient == (["s1": 0.0, "s2": 1.0, "s3": 1.0], ["s1": 1.0]))
120+
}
121+
}
122+
123+
#endif
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#if canImport(_Differentiation)
2+
3+
import OrderedCollectionsDifferentiable
4+
import Testing
5+
6+
@Suite("OrderedDictionary+Update")
7+
struct OrderedDictionaryUpdateTests {
8+
@Test
9+
func testUpdateWithValue() throws {
10+
let dictionary: OrderedDictionary<String, Double> = ["a": 1, "b": 1]
11+
12+
let aMultiplier: Double = 13
13+
let bMultiplier: Double = 17
14+
15+
func writeAndReadFromDictionary(dict: OrderedDictionary<String, Double>, newA: Double, newB: Double) -> Double {
16+
var dict = dict
17+
dict.update(at: "a", with: newA)
18+
dict.update(at: "b", with: newB)
19+
20+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
21+
// swift-format-ignore: NeverForceUnwrap
22+
let a = dict["a"]! * aMultiplier
23+
let b = dict["b"]! * bMultiplier
24+
return a + b
25+
}
26+
27+
let newA: Double = 3
28+
let newB: Double = 7
29+
30+
let valAndGrad = valueWithGradient(at: dictionary, newA, newB, of: writeAndReadFromDictionary)
31+
print(valAndGrad.gradient)
32+
#expect(valAndGrad.value == newA * aMultiplier + newB * bMultiplier)
33+
#expect(valAndGrad.gradient == (["a": 0, "b": 0], aMultiplier, bMultiplier))
34+
}
35+
}
36+
37+
#endif

Tests/OrderedCollectionsDifferentiableTests/OrderedDictionaryDifferentiableTests.swift

Whitespace-only changes.

0 commit comments

Comments
 (0)