Skip to content

Commit 6e46ebe

Browse files
committed
chore: apply formatting
1 parent 5196dde commit 6e46ebe

File tree

3 files changed

+35
-21
lines changed

3 files changed

+35
-21
lines changed

Package.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ let package = Package(
1515
),
1616
],
1717
dependencies: [
18-
.package(url: "https://github.com/apple/swift-collections.git", from: "1.1.4")
18+
.package(url: "https://github.com/apple/swift-collections.git", from: "1.1.4"),
1919
],
2020
targets: [
2121
.target(
@@ -28,13 +28,13 @@ let package = Package(
2828
.target(
2929
name: "OrderedCollectionsDifferentiable",
3030
dependencies: [
31-
.product(name: "OrderedCollections", package: "swift-collections")
31+
.product(name: "OrderedCollections", package: "swift-collections"),
3232
]
3333
),
3434
.testTarget(
3535
name: "OrderedCollectionsDifferentiableTests",
3636
dependencies: [
37-
"OrderedCollectionsDifferentiable"
37+
"OrderedCollectionsDifferentiable",
3838
]
3939
),
4040
]

Sources/OrderedCollectionsDifferentiable/OrderedDictionary+Differentiable.swift

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import _Differentiation
44

55
extension OrderedDictionary: @retroactive Differentiable where Value: Differentiable {
66
public typealias TangentVector = OrderedDictionary<Key, Value.TangentVector>
7-
7+
88
public mutating func move(by direction: TangentVector) {
99
for (componentKey, componentDirection) in direction {
1010
func fatalMissingComponent() -> Value {
@@ -29,7 +29,8 @@ extension OrderedDictionary: @retroactive AdditiveArithmetic where Value: Additi
2929
}
3030

3131
extension OrderedDictionary where Value: Differentiable {
32-
/// Defines a derivative for `OrderedDictionary`s subscript getter enabling calls like `var value = dictionary[key]` to be differentiable
32+
/// Defines a derivative for `OrderedDictionary`s subscript getter enabling calls like `var value = dictionary[key]` to be
33+
/// differentiable
3334
@inlinable
3435
@derivative(of: subscript(_:))
3536
func _vjpSubscript(key: Key)
@@ -38,7 +39,8 @@ extension OrderedDictionary where Value: Differentiable {
3839
let keys = self.keys
3940
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
4041
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
41-
// for ordered dictionaries however we can't because the keys will be added in reverse order so the tangentvector's key order will be different from the original
42+
// for ordered dictionaries however we can't because the keys will be added in reverse order so the tangentvector's key order will
43+
// be different from the original
4244
return (
4345
value: self[key],
4446
pullback: { tangentVector in
@@ -61,4 +63,3 @@ extension OrderedDictionary where Value: Differentiable {
6163
// TODO: make `OrderedDictionary.Values` and `OrderedDictionary.Elements` differentiable
6264

6365
#endif
64-

Tests/OrderedCollectionsDifferentiableTests/OrderedDictionary+DifferentiableTests.swift

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,28 @@ struct DictionaryDifferentiationTests {
4343
#expect(vwg.value == 140.0)
4444
#expect(vwg.gradient == ["s1": 1.0, "s2": 2.0, "s3": 3.0])
4545
}
46-
47-
48-
46+
4947
@Test
5048
func testOrderedDictionaryInoutWriteMethod() {
5149
@differentiable(reverse)
52-
func combineByReplacingDictionaryValues(of mainDict: inout OrderedDictionary<String, Double>, with otherDict: OrderedDictionary<String, Double>) {
50+
func combineByReplacingDictionaryValues(
51+
of mainDict: inout OrderedDictionary<String, Double>,
52+
with otherDict: OrderedDictionary<String, Double>
53+
) {
5354
for key in withoutDerivative(at: otherDict.keys) {
54-
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
55+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt
56+
// differentiation)
5557
// swift-format-ignore: NeverForceUnwrap
5658
let otherValue = otherDict[key]!
5759
mainDict.update(at: key, with: otherValue)
5860
}
5961
}
6062

6163
@differentiable(reverse)
62-
func inoutWrapper(dictionary: OrderedDictionary<String, Double>, otherDictionary: OrderedDictionary<String, Double>) -> OrderedDictionary<String, Double> {
64+
func inoutWrapper(
65+
dictionary: OrderedDictionary<String, Double>,
66+
otherDictionary: OrderedDictionary<String, Double>
67+
) -> OrderedDictionary<String, Double> {
6368
// we wrap the `combineByReplacingDictionaryValues`
6469
var mainCopy = dictionary
6570
combineByReplacingDictionaryValues(of: &mainCopy, with: otherDictionary)
@@ -68,7 +73,7 @@ struct DictionaryDifferentiationTests {
6873

6974
let vwpb = valueWithPullback(
7075
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
71-
["s1": 2.0], //, "s2": nil, "s3": nil],
76+
["s1": 2.0], // , "s2": nil, "s3": nil],
7277
of: inoutWrapper
7378
)
7479

@@ -82,9 +87,13 @@ struct DictionaryDifferentiationTests {
8287
@Test
8388
func testInoutWriteAndSumValues() {
8489
@differentiable(reverse)
85-
func combineByReplacingDictionaryValues(of mainDict: inout OrderedDictionary<String, Double>, with otherDict: OrderedDictionary<String, Double>) {
90+
func combineByReplacingDictionaryValues(
91+
of mainDict: inout OrderedDictionary<String, Double>,
92+
with otherDict: OrderedDictionary<String, Double>
93+
) {
8694
for key in withoutDerivative(at: otherDict.keys) {
87-
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
95+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt
96+
// differentiation)
8897
// swift-format-ignore: NeverForceUnwrap
8998
let otherValue = otherDict[key]!
9099
mainDict.update(at: key, with: otherValue)
@@ -93,25 +102,29 @@ struct DictionaryDifferentiationTests {
93102

94103
@differentiable(reverse)
95104
func sumValues(of dictionary: OrderedDictionary<String, Double>) -> Double {
96-
var sum: Double = 0.0
105+
var sum = 0.0
97106
for key in withoutDerivative(at: dictionary.keys) {
98-
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt differentiation)
107+
// note that we cannot use #require here as this function cannot throw (due to current compiler constraints wrt
108+
// differentiation)
99109
// swift-format-ignore: NeverForceUnwrap
100110
sum += dictionary[key]!
101111
}
102112
return sum
103113
}
104-
@differentiable(reverse,wrt: dictionary)
114+
@differentiable(reverse, wrt: dictionary)
105115

106-
func inoutWrapperAndSum(dictionary: OrderedDictionary<String, Double>, otherDictionary: OrderedDictionary<String, Double>) -> Double {
116+
func inoutWrapperAndSum(
117+
dictionary: OrderedDictionary<String, Double>,
118+
otherDictionary: OrderedDictionary<String, Double>
119+
) -> Double {
107120
var mainCopy = dictionary
108121
combineByReplacingDictionaryValues(of: &mainCopy, with: otherDictionary)
109122
return sumValues(of: mainCopy)
110123
}
111124

112125
let vwg = valueWithGradient(
113126
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
114-
["s1": 2.0], //, "s2": nil, "s3": nil],
127+
["s1": 2.0], // , "s2": nil, "s3": nil],
115128
of: inoutWrapperAndSum
116129
)
117130

0 commit comments

Comments
 (0)