Skip to content

Commit 1f07c13

Browse files
committed
feature: Make OrderedDictionary.Values conform to Differentiable
1 parent ace9ad1 commit 1f07c13

File tree

4 files changed

+136
-10
lines changed

4 files changed

+136
-10
lines changed

Sources/OrderedCollectionsDifferentiable/OrderedDictionary+Differentiable.swift

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ import _Differentiation
55
extension OrderedDictionary: @retroactive Differentiable where Value: Differentiable {
66
public typealias TangentVector = OrderedDictionary<Key, Value.TangentVector>
77

8-
public mutating func move(by direction: TangentVector) {
9-
for (componentKey, componentDirection) in direction {
8+
public mutating func move(by offset: TangentVector) {
9+
for (key, tangentValue) in offset {
1010
func fatalMissingComponent() -> Value {
11-
preconditionFailure("missing component \(componentKey) in moved OrderedDictionary")
11+
preconditionFailure("missing entry for key \(key) in moved OrderedDictionary")
1212
}
13-
self[componentKey, default: fatalMissingComponent()].move(by: componentDirection)
13+
self[key, default: fatalMissingComponent()].move(by: tangentValue)
1414
}
1515
}
1616
}
@@ -33,9 +33,10 @@ extension OrderedDictionary where Value: Differentiable {
3333
/// differentiable
3434
@inlinable
3535
@derivative(of: subscript(_:))
36-
func _vjpSubscript(key: Key)
37-
-> (value: Value?, pullback: (Optional<Value>.TangentVector) -> OrderedDictionary<Key, Value>.TangentVector)
38-
{
36+
func _vjpSubscript(key: Key) -> (
37+
value: Value?,
38+
pullback: (Optional<Value>.TangentVector) -> OrderedDictionary<Key, Value>.TangentVector
39+
) {
3940
let keys = self.keys
4041
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
4142
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
@@ -58,8 +59,26 @@ extension OrderedDictionary where Value: Differentiable {
5859
}
5960
)
6061
}
62+
63+
@derivative(of: values)
64+
@inlinable
65+
@inline(__always)
66+
public func _vjpValues() -> (value: Values, pullback: (Values.TangentVector) -> OrderedDictionary<Key, Value>.TangentVector) {
67+
let keys = self.keys
68+
return (
69+
value: self.values,
70+
pullback: { v in
71+
var dict = OrderedDictionary<Key, Value>.TangentVector()
72+
dict.reserveCapacity(keys.count)
73+
for (key, tangentValue) in zip(keys, v.base) {
74+
dict[key] = tangentValue
75+
}
76+
return dict
77+
}
78+
)
79+
}
6180
}
6281

63-
// TODO: make `OrderedDictionary.Values` and `OrderedDictionary.Elements` differentiable
82+
// TODO: make `OrderedDictionary.Elements` differentiable
6483

6584
#endif
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#if canImport(_Differentiation)
2+
3+
import _Differentiation
4+
5+
extension OrderedDictionary.Values: @retroactive Differentiable where Value: Differentiable {
6+
public typealias TangentVector = Array<Value.TangentVector>.TangentVector
7+
8+
public mutating func move(by offset: OrderedDictionary<Key, Value>.Values.TangentVector) {
9+
for (i, j) in zip(self.indices, offset.base.indices) {
10+
self[i].move(by: offset.base[j])
11+
}
12+
}
13+
}
14+
15+
extension OrderedDictionary.Values where Value: Differentiable {
16+
/// Defines a derivative for `OrderedDictionary`s subscript getter enabling calls like `var value = dictionary[key]` to be
17+
/// differentiable
18+
@inlinable
19+
@derivative(of: subscript(_:))
20+
func _vjpSubscript(position: Int) -> (
21+
value: Value,
22+
pullback: (Value.TangentVector) -> OrderedDictionary<Key, Value>.Values.TangentVector
23+
) {
24+
let count = self.count
25+
return (
26+
value: self[position],
27+
pullback: { tangentVector in
28+
var vector = Array<Value.TangentVector>.TangentVector(.init(repeating: .zero, count: count))
29+
vector.base[position] = tangentVector
30+
return vector
31+
}
32+
)
33+
}
34+
35+
@derivative(of: elements)
36+
@inlinable
37+
func _vjpElements() -> (
38+
value: Array<Value>,
39+
pullback: (Array<Value>.TangentVector) -> OrderedDictionary.Values.TangentVector
40+
) {
41+
(
42+
value: self.elements,
43+
pullback: { v in v }
44+
)
45+
}
46+
}
47+
48+
#endif

Tests/OrderedCollectionsDifferentiableTests/OrderedDictionary+DifferentiableTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import OrderedCollectionsDifferentiable
44
import Testing
55

6-
@Suite("Dictionary+Differentiation")
7-
struct DictionaryDifferentiationTests {
6+
@Suite("OrderedDictionary+Differentiation")
7+
struct OrderedDictionaryDifferentiationTests {
88
@Test
99
func testSubscriptGet() throws {
1010
let dictionary: OrderedDictionary<String, Double> = ["a": 3, "b": 7]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#if canImport(_Differentiation)
2+
3+
import OrderedCollectionsDifferentiable
4+
import Testing
5+
6+
@Suite("OrderedDictionary.Values+Differentiable")
7+
struct OrderedDictionaryValuesTests {
8+
@Test
9+
func dictionaryValuesMemberTest() {
10+
@differentiable(reverse)
11+
func values(dict: OrderedDictionary<String, Double>) -> OrderedDictionary<String, Double>.Values {
12+
dict.values
13+
}
14+
15+
let dict: OrderedDictionary<String, Double> = ["a": 1.0, "b": 2.0]
16+
let vwpb = valueWithPullback(at: dict, of: values)
17+
18+
#expect(vwpb.value == dict.values)
19+
let pullback = vwpb.pullback([0.0, 1.0])
20+
21+
#expect(pullback == ["a": 0.0, "b": 1.0])
22+
}
23+
24+
@Test
25+
func dictionaryValuesMemberElementsTest() {
26+
@differentiable(reverse)
27+
func elements(dict: OrderedDictionary<String, Double>) -> [Double] {
28+
dict.values.elements
29+
}
30+
31+
let dict: OrderedDictionary<String, Double> = ["a": 1.0, "b": 2.0]
32+
let vwpb = valueWithPullback(at: dict, of: elements)
33+
34+
#expect(vwpb.value == [1.0, 2.0])
35+
let pullback = vwpb.pullback([0.0, 1.0])
36+
37+
#expect(pullback == ["a": 0.0, "b": 1.0])
38+
}
39+
40+
@Test
41+
func dictionaryValuesMemberSubscriptTest() {
42+
@differentiable(reverse)
43+
func values(dict: OrderedDictionary<String, Double>, index: Int) -> Double {
44+
dict.values[index]
45+
}
46+
47+
let dict: OrderedDictionary<String, Double> = ["a": 1.0, "b": 2.0]
48+
let vwpb = valueWithPullback(at: dict, of: { dict in
49+
values(dict: dict, index: 1)
50+
})
51+
52+
#expect(vwpb.value == 2.0)
53+
let pullback = vwpb.pullback(1.0)
54+
55+
#expect(pullback == ["a": 0.0, "b": 1.0])
56+
}
57+
}
58+
59+
#endif

0 commit comments

Comments
 (0)