Skip to content

Commit 198dae6

Browse files
committed
feat: add in existing code
This copies in parts of our existing Differentiation library code. No tests yet since those are written using XCTest will add those in a separate MR.
1 parent 4265867 commit 198dae6

File tree

5 files changed

+372
-0
lines changed

5 files changed

+372
-0
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#if canImport(_Differentiation)
2+
3+
import _Differentiation
4+
5+
public extension Array where Element: Differentiable {
6+
// Note: a compiler bug (SR-15530: https://bugs.swift.org/browse/SR-15530)
7+
// causes the vjpUpdate functions to be associated with the wrong
8+
// base functions unless you exactly align the function signatures in the
9+
// @derivative(of:) attribute and make sure the @inlinable attribute is the
10+
// same.
11+
12+
/// This function defines a derivative for AutoDiff to use when update() is called. It's not meant to be called directly in most
13+
/// situations.
14+
///
15+
/// - Parameters:
16+
/// - index: The index to update the value at.
17+
/// - newValue: The value to write.
18+
/// - Returns: The object, plus the pullback.
19+
@inlinable
20+
@derivative(of: update(at:with:))
21+
mutating func vjpUpdate(
22+
at index: Int,
23+
with newValue: Element
24+
) -> (value: Void, pullback: (inout TangentVector) -> (Element.TangentVector)) {
25+
update(at: index, with: newValue)
26+
let forwardCount = self.count
27+
return ((), { tangentVector in
28+
// manual zero tangent initialization
29+
if tangentVector.base.count < forwardCount {
30+
tangentVector.base = .init(repeating: .zero, count: forwardCount)
31+
}
32+
let dElement = tangentVector[index]
33+
tangentVector.base[index] = .zero
34+
return dElement
35+
})
36+
}
37+
}
38+
#endif
39+
40+
public extension Array {
41+
/// A functional version of `Array.subscript.modify`.
42+
/// Differentiation does yet not support `Array.subscript.modify` because
43+
/// it is a coroutine.
44+
@inlinable
45+
#if canImport(_Differentiation)
46+
@differentiable(reverse where Element: Differentiable)
47+
#endif
48+
mutating func update(at index: Int, with newValue: Element) {
49+
self[index] = newValue
50+
}
51+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#if canImport(_Differentiation)
2+
3+
import _Differentiation
4+
import Foundation
5+
6+
// -------------------------------------------------------------------------
7+
// derivatives for native functions
8+
// -------------------------------------------------------------------------
9+
10+
/// For min(): "Returns: The lesser of `x` and `y`. If `x` is equal to `y`, returns `x`."
11+
/// https://github.com/apple/swift/blob/main/stdlib/public/core/Algorithm.swift#L18
12+
@inlinable
13+
@derivative(of: min)
14+
public func minVJP<T: Comparable & Differentiable>(
15+
_ lhs: T,
16+
_ rhs: T
17+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
18+
func pullback(_ tangentVector: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
19+
if lhs <= rhs {
20+
return (tangentVector, .zero)
21+
}
22+
else {
23+
return (.zero, tangentVector)
24+
}
25+
}
26+
return (value: min(lhs, rhs), pullback: pullback)
27+
}
28+
29+
/// For max(): "Returns: The greater of `x` and `y`. If `x` is equal to `y`, returns `y`."
30+
/// https://github.com/apple/swift/blob/main/stdlib/public/core/Algorithm.swift#L52
31+
@inlinable
32+
@derivative(of: max)
33+
public func maxVJP<T: Comparable & Differentiable>(
34+
_ lhs: T,
35+
_ rhs: T
36+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
37+
func pullback(_ tangentVector: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
38+
if lhs < rhs {
39+
return (.zero, tangentVector)
40+
}
41+
else {
42+
return (tangentVector, .zero)
43+
}
44+
}
45+
return (value: max(lhs, rhs), pullback: pullback)
46+
}
47+
48+
/// To differentiate ``abs``
49+
@inlinable
50+
@derivative(of: abs)
51+
public func absVJP<T: Comparable & SignedNumeric & Differentiable>(_ value: T)
52+
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
53+
{
54+
func pullback(_ tangentVector: T.TangentVector) -> T.TangentVector {
55+
if value < 0 {
56+
return .zero - tangentVector
57+
}
58+
else {
59+
return tangentVector
60+
}
61+
}
62+
return (value: abs(value), pullback: pullback)
63+
}
64+
65+
/// Differentiation of ``atan2``
66+
@derivative(of: atan2(_:_:))
67+
public func vjpAtan2(
68+
y: Double, x: Double
69+
) -> (value: Double, pullback: (Double) -> (Double, Double)) {
70+
(
71+
value: atan2(y, x),
72+
pullback: { ($0 * x / (x * x + y * y), -$0 * y / (x * x + y * y)) }
73+
)
74+
}
75+
76+
#endif
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#if canImport(_Differentiation)
2+
3+
import _Differentiation
4+
5+
// copied and modified from
6+
// https://github.com/borglab/SwiftFusion/blob/main/Sources/SwiftFusion/Core/Dictionary+Differentiable.swift
7+
// and
8+
// https://bugs.swift.org/browse/TF-1193
9+
10+
/// This file makes `Dictionary` differentiable.
11+
///
12+
/// Note: This will eventually be moved into the Swift standard library. Once it is in the
13+
/// standard library, we can delete it from this repository.
14+
/// Implements the `Differentiable` requirements.
15+
extension Dictionary: Differentiable where Value: Differentiable {
16+
public typealias TangentVector = [Key: Value.TangentVector]
17+
public mutating func move(by direction: TangentVector) {
18+
for (componentKey, componentDirection) in direction {
19+
func fatalMissingComponent() -> Value {
20+
preconditionFailure("missing component \(componentKey) in moved Dictionary")
21+
}
22+
self[componentKey, default: fatalMissingComponent()].move(by: componentDirection)
23+
}
24+
}
25+
26+
public var zeroTangentVectorInitializer: () -> TangentVector {
27+
let listOfKeys = keys // capturing only what's needed, not the entire self, in order to not waste memory
28+
func initializer() -> Self.TangentVector {
29+
return listOfKeys.reduce(into: [Key: Value.TangentVector]()) { $0[$1] = Value.TangentVector.zero }
30+
}
31+
return initializer
32+
}
33+
}
34+
35+
/// Implements the `AdditiveArithmetic` requirements.
36+
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
37+
public static func + (_ lhs: Self, _ rhs: Self) -> Self {
38+
lhs.merging(rhs, uniquingKeysWith: +)
39+
}
40+
41+
public static func - (_ lhs: Self, _ rhs: Self) -> Self {
42+
lhs.merging(rhs.mapValues { .zero - $0 }, uniquingKeysWith: +)
43+
}
44+
45+
public static var zero: Self { [:] }
46+
}
47+
48+
// attempt to make builtin subscript differentiable:
49+
// https://bugs.swift.org/browse/TF-1193
50+
// https://github.com/apple/swift/pull/32614/
51+
// https://github.com/borglab/SwiftFusion/blob/main/Sources/SwiftFusion/Core/Dictionary+Differentiable.swift
52+
53+
extension Dictionary where Value: Differentiable {
54+
// get
55+
// swiftformat:disable:next typeSugar
56+
// periphery:ignore
57+
@usableFromInline
58+
@derivative(of: subscript(_:))
59+
func vjpSubscriptGet(key: Key)
60+
-> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector)
61+
{
62+
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
63+
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
64+
return (self[key], { tangentVector in
65+
if let value = tangentVector.value {
66+
return [key: value]
67+
}
68+
else {
69+
return .zero
70+
}
71+
})
72+
}
73+
}
74+
75+
public extension Dictionary where Value: Differentiable {
76+
// make a manual update(at: with:) since https://bugs.swift.org/browse/TF-1277 affects dictionary as well, making @derivative(of:
77+
// subscript(_:).set) useless
78+
/// manual update function replacing `subscript(_:).set` since that cannot be made differentiable (might now be possible)
79+
@differentiable(reverse)
80+
mutating func set(_ key: Key, to newValue: Value) {
81+
self[key] = newValue
82+
}
83+
84+
/// derivative of above set function. Ideally this would just be the derivative of `subscript(_:).set`
85+
@derivative(of: set)
86+
mutating func vjpUpdated(
87+
_ key: Key,
88+
to newValue: Value
89+
) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
90+
set(key, to: newValue)
91+
92+
let forwardCount = count
93+
let forwardKeys = keys // may be heavy to capture all of these, not sure how to do without them though
94+
95+
return ((), { tangentVector in
96+
// manual zero tangent initialization
97+
if tangentVector.count < forwardCount {
98+
tangentVector = Self.TangentVector()
99+
forwardKeys.forEach { tangentVector[$0] = .zero }
100+
}
101+
102+
if let dElement = tangentVector[key] {
103+
tangentVector[key] = .zero
104+
return dElement
105+
}
106+
else { // should this fail?
107+
tangentVector[key] = .zero
108+
return .zero
109+
}
110+
})
111+
}
112+
}
113+
114+
#endif
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#if canImport(_Differentiation)
2+
3+
import _Differentiation
4+
5+
// TODO: This will soon be upstreamed to the swift toolchain after which it will be removed from this package.
6+
7+
extension Optional where Wrapped: Differentiable {
8+
@inlinable
9+
@differentiable(reverse, wrt: self)
10+
public func differentiableMap<Result: Differentiable>(
11+
_ body: @differentiable(reverse) (Wrapped) -> Result
12+
) -> Optional<Result> {
13+
map(body)
14+
}
15+
16+
@inlinable
17+
@derivative(of: differentiableMap)
18+
internal func _vjpDifferentiableMap<Result: Differentiable>(
19+
_ body: @differentiable(reverse) (Wrapped) -> Result
20+
) -> (
21+
value: Optional<Result>,
22+
pullback: (Optional<Result>.TangentVector) -> Optional.TangentVector
23+
) {
24+
let vwpb = self.map { valueWithPullback(at: $0, of: body) }
25+
let bodyPullback = vwpb?.pullback
26+
27+
func pullback(_ vec: Optional<Result>.TangentVector) -> Optional.TangentVector {
28+
guard let value = vec.value, let bodyPullback else { return .init(.none) }
29+
return .init(.some(bodyPullback(value)))
30+
}
31+
32+
return (value: vwpb?.value, pullback: pullback)
33+
}
34+
35+
@inlinable
36+
@derivative(of: differentiableMap)
37+
internal func _jvpDifferentiableMap<Result: Differentiable>(
38+
_ body: @differentiable(reverse) (Wrapped) -> Result
39+
) -> (
40+
value: Optional<Result>,
41+
differential: (Optional.TangentVector) -> Optional<Result>.TangentVector
42+
) {
43+
let vwpb = self.map { valueWithDifferential(at: $0, of: body) }
44+
let bodyDifferential = vwpb?.differential
45+
46+
func differential(_ vec: Optional.TangentVector) -> Optional<Result>.TangentVector {
47+
guard let value = vec.value, let bodyDifferential else { return .init(.none) }
48+
return .init(bodyDifferential(value))
49+
}
50+
51+
return (value: vwpb?.value, differential: differential)
52+
}
53+
}
54+
55+
#endif
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#if canImport(_Differentiation)
2+
3+
import _Differentiation
4+
5+
public extension Sequence where
6+
Self: Collection, // we constrain to conform to collection cause otherwise we can't access any values by index
7+
Self: Differentiable,
8+
Self.TangentVector: RangeReplaceableCollection, // we constrain the tangentvector to be able to create a value and write to it
9+
Self.TangentVector.Element == Element.TangentVector,
10+
Element: Differentiable,
11+
Element: Comparable
12+
{
13+
// Match Self.Index with Self.TangentVector index so we can use them across both types.
14+
// The reason we are doing the where clause here rather than at the extension declaration
15+
// level is because of the DocC crash: https://github.com/swiftlang/swift/issues/75258
16+
/// To differentiate ``Swift/Sequence/max``
17+
@derivative(of: max)
18+
func vjpMax() -> (
19+
value: Element?,
20+
pullback: (Element?.TangentVector) -> (Self.TangentVector)
21+
) where Self.Index == Self.TangentVector.Index {
22+
let index = withoutDerivative(at: self.indices.max { self[$0] < self[$1] }) // we grab the index of the element with the max value
23+
return (
24+
value: index.map { self[$0] }, // if the index is nil, we return nil otherwise we grab the value at the index
25+
pullback: { vector in
26+
var dSelf = Self
27+
.TangentVector(
28+
repeating: .zero,
29+
count: self
30+
.count
31+
) // we create a zero tangentvector we need `RangeReplaceableCollection` conformance in order to do this
32+
if let vectorValue = vector.value,
33+
let index = index
34+
{
35+
// if an index was found and our tangentvector's value is non nil we set the value at index of our tangentvector to the
36+
// provided tangentvector value
37+
dSelf
38+
.replaceSubrange(
39+
index ..< dSelf.index(after: index),
40+
with: [vectorValue]
41+
) // we use `RangeReplaceableCollection`'s method here in order to not have to also constrain our TangentVector to
42+
// `MutableCollection`
43+
}
44+
return dSelf // return the tangentvector
45+
}
46+
)
47+
}
48+
49+
// Match Self.Index with Self.TangentVector index so we can use them across both types.
50+
// The reason we are doing the where clause here rather than at the extension declaration
51+
// level is because of the DocC crash: https://github.com/swiftlang/swift/issues/75258
52+
/// To differentiate ``Swift/Sequence/min``
53+
@derivative(of: min)
54+
func vjpMin() -> (
55+
value: Element?,
56+
pullback: (Element?.TangentVector) -> (Self.TangentVector)
57+
) where Self.Index == Self.TangentVector.Index {
58+
let index = withoutDerivative(at: self.indices.min { self[$0] < self[$1] }) // we grab the index of the element with the max value
59+
return (
60+
value: index.map { self[$0] }, // if the index is nil, we return nil otherwise we grab the value at the index
61+
pullback: { vector in
62+
var dSelf = Self.TangentVector(repeating: .zero, count: self.count) // we create a zero tangentvector
63+
if let vectorValue = vector.value,
64+
let index = index
65+
{
66+
// if an index was found and our tangentvector's value is non nil we set the value at index of our tangentvector to the
67+
// provided tangentvector value
68+
dSelf.replaceSubrange(index ..< dSelf.index(after: index), with: [vectorValue])
69+
}
70+
return dSelf // return the tangentvector
71+
}
72+
)
73+
}
74+
}
75+
76+
#endif

0 commit comments

Comments
 (0)