Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions Sources/Differentiation/Array+Update.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#if canImport(_Differentiation)

import _Differentiation

#endif

extension Array {
/// A Differentiable alternative to `Array.subscript.modify`.
/// Differentiation does not yet support `Array.subscript.modify` because it is a coroutine.
#if canImport(_Differentiation)
@differentiable(reverse where Element: Differentiable)
#endif
@inlinable
public mutating func update(at index: Int, with newValue: Element) {
self[index] = newValue
}
}

#if canImport(_Differentiation)

extension Array where Element: Differentiable {
/// This function defines a derivative for AutoDiff to use when update() is called. It's not meant to be called directly in most
/// situations.
///
/// - Parameters:
/// - index: The index to update the value at.
/// - newValue: The value to write.
/// - Returns: The object, plus the pullback.
@derivative(of: update(at:with:))
@inlinable
public mutating func _vjpUpdate(
at index: Int,
with newValue: Element
) -> (value: Void, pullback: (inout TangentVector) -> (Element.TangentVector)) {
update(at: index, with: newValue)
let forwardCount = self.count
return ((), { tangentVector in
// manual zero tangent initialization
if tangentVector.base.count < forwardCount {
tangentVector.base = .init(repeating: .zero, count: forwardCount)
}
let dElement = tangentVector[index]
tangentVector.base[index] = .zero
return dElement
})
}
}

#endif

45 changes: 45 additions & 0 deletions Sources/Differentiation/ArrayDifferentiableView+Collection.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#if canImport(_Differentiation)

import _Differentiation

extension Array.DifferentiableView:
@retroactive Sequence,
@retroactive Collection,
@retroactive RangeReplaceableCollection,
@retroactive RandomAccessCollection,
@retroactive BidirectionalCollection,
@retroactive MutableCollection
where Element: Differentiable
{
public typealias Element = Array.Element
public typealias Index = Array.Index
public typealias SubSequence = Array.SubSequence

@inlinable
public subscript(position: Index) -> Element {
_read { yield base[position] }
set(newValue) { base[position] = newValue }
}

@inlinable
public subscript(bounds: Range<Index>) -> SubSequence {
get { base[bounds] }
set(newValue) { base[bounds] = newValue }
}

@inlinable
public var startIndex: Index { base.startIndex }

@inlinable
public var endIndex: Index { base.endIndex }

@inlinable
public init() { self.init(Array<Element>()) }

@inlinable
public mutating func replaceSubrange<C>(_ subrange: Range<Self.Index>, with newElements: C) where C : Collection, Self.Element == C.Element {
base.replaceSubrange(subrange, with: newElements)
}
}

#endif
60 changes: 60 additions & 0 deletions Sources/Differentiation/DerivativesOfNativeFunctions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#if canImport(_Differentiation)

import _Differentiation

/// For min(): "Returns: The lesser of `x` and `y`. If `x` is equal to `y`, returns `x`."
/// https://github.com/apple/swift/blob/main/stdlib/public/core/Algorithm.swift#L18
@inlinable
@derivative(of: min(_:_:))
public func _vjpMin<T: Comparable & Differentiable>(
_ lhs: T,
_ rhs: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
func pullback(_ tangentVector: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
if lhs <= rhs {
return (tangentVector, .zero)
}
else {
return (.zero, tangentVector)
}
}
return (value: min(lhs, rhs), pullback: pullback)
}

/// For max(): "Returns: The greater of `x` and `y`. If `x` is equal to `y`, returns `y`."
/// https://github.com/apple/swift/blob/main/stdlib/public/core/Algorithm.swift#L52
@inlinable
@derivative(of: max(_:_:))
public func _vjpMax<T: Comparable & Differentiable>(
_ lhs: T,
_ rhs: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
func pullback(_ tangentVector: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
if lhs < rhs {
return (.zero, tangentVector)
}
else {
return (tangentVector, .zero)
}
}
return (value: max(lhs, rhs), pullback: pullback)
}

/// To differentiate ``abs``
@inlinable
@derivative(of: abs(_:))
public func _vjpAbs<T: Comparable & SignedNumeric & Differentiable>(_ value: T)
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
{
func pullback(_ tangentVector: T.TangentVector) -> T.TangentVector {
if value < 0 {
return .zero - tangentVector
}
else {
return tangentVector
}
}
return (value: abs(value), pullback: pullback)
}

#endif
62 changes: 62 additions & 0 deletions Sources/Differentiation/Dictionary+Differentiation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#if canImport(_Differentiation)

import _Differentiation

// copied and modified from
// https://github.com/borglab/SwiftFusion/blob/main/Sources/SwiftFusion/Core/Dictionary+Differentiable.swift
// and
// https://bugs.swift.org/browse/TF-1193

extension Dictionary: Differentiable where Value: Differentiable {
public typealias TangentVector = [Key: Value.TangentVector]
public mutating func move(by direction: TangentVector) {
for (componentKey, componentDirection) in direction {
func fatalMissingComponent() -> Value {
preconditionFailure("missing component \(componentKey) in moved Dictionary")
}
self[componentKey, default: fatalMissingComponent()].move(by: componentDirection)
}
}

public var zeroTangentVectorInitializer: () -> TangentVector {
let listOfKeys = keys // capturing only what's needed, not the entire self, in order to not waste memory
func initializer() -> Self.TangentVector {
return listOfKeys.reduce(into: [Key: Value.TangentVector]()) { $0[$1] = Value.TangentVector.zero }
}
return initializer
}
}

/// Implements the `AdditiveArithmetic` requirements.
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
public static func + (_ lhs: Self, _ rhs: Self) -> Self {
lhs.merging(rhs, uniquingKeysWith: +)
}

public static func - (_ lhs: Self, _ rhs: Self) -> Self {
lhs.merging(rhs.mapValues { .zero - $0 }, uniquingKeysWith: +)
}

public static var zero: Self { [:] }
}

extension Dictionary where Value: Differentiable {
/// Defines a derivative for `Dictionary`s subscript getter enabling calls like `var value = dictionary[key]` to be differentiable
@inlinable
@derivative(of: subscript(_:))
public func _vjpSubscript(key: Key)
-> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector)
{
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
return (self[key], { tangentVector in
if let value = tangentVector.value {
return [key: value]
}
else {
return .zero
}
})
}
}
#endif
62 changes: 62 additions & 0 deletions Sources/Differentiation/Dictionary+Update.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#if canImport(_Differentiation)

import _Differentiation

#endif

extension Dictionary {
/// A Differentiable alternative to `Dictionary.subscript.modify`
/// Differentiation does not yet support `Dictionary.subscript.modify` because it is a coroutine.
#if canImport(_Differentiation)
@differentiable(reverse where Value: Differentiable)
#endif
@inlinable
public mutating func update(at key: Key, with newValue: Value) {
self[key] = newValue
}
}

#if canImport(_Differentiation)

extension Dictionary where Value: Differentiable {
/// This function defines a derivative for AutoDiff to use when update() is called. It's not meant to be called directly in most
/// situations.
///
/// - Parameters:
/// - key: The key to update the value at.
/// - newValue: The value to write.
/// - Returns: The object, plus the pullback.
@derivative(of: update(at:with:))
@inlinable
public mutating func _vjpUpdate(
at key: Key,
with newValue: Value // TODO: this should be optional?
) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
update(at: key, with: newValue)

let forwardCount = count
let forwardKeys = keys // may be heavy to capture all of these, not sure how to do without them though

return ((), { tangentVector in
// manual zero tangent initialization
// TODO: Should we consider missing keys as a complete tangentvector with zero values for those keys?
if tangentVector.count < forwardCount { // TODO: is this the correct check keys could still differ
tangentVector = Self.TangentVector() // TODO: should we be replacing this or merging
for key in forwardKeys {
tangentVector[key] = .zero
}
}

if let dElement = tangentVector[key] {
tangentVector[key] = .zero
return dElement
}
else { // should this fail?
tangentVector[key] = .zero
return .zero
}
})
}
}

#endif
17 changes: 17 additions & 0 deletions Sources/Differentiation/Foundation+VJPs.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#if canImport(_Differentiation)

import _Differentiation
import Foundation

/// Differentiation of ``atan2``
@derivative(of: atan2(_:_:))
public func _vjpAtan2(
y: Double, x: Double
) -> (value: Double, pullback: (Double) -> (Double, Double)) {
(
value: atan2(y, x),
pullback: { ($0 * x / (x * x + y * y), -$0 * y / (x * x + y * y)) }
)
}

#endif
53 changes: 53 additions & 0 deletions Sources/Differentiation/Optional+DifferentiableMap.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#if canImport(_Differentiation)

import _Differentiation

extension Optional where Wrapped: Differentiable {
@inlinable
@differentiable(reverse, wrt: self)
public func differentiableMap<Result: Differentiable>(
_ body: @differentiable(reverse) (Wrapped) -> Result
) -> Optional<Result> {
map(body)
}

@inlinable
@derivative(of: differentiableMap)
internal func _vjpDifferentiableMap<Result: Differentiable>(
_ body: @differentiable(reverse) (Wrapped) -> Result
) -> (
value: Optional<Result>,
pullback: (Optional<Result>.TangentVector) -> Optional.TangentVector
) {
let vwpb = self.map { valueWithPullback(at: $0, of: body) }
let bodyPullback = vwpb?.pullback

func pullback(_ vec: Optional<Result>.TangentVector) -> Optional.TangentVector {
guard let value = vec.value, let bodyPullback else { return .init(.none) }
return .init(.some(bodyPullback(value)))
}

return (value: vwpb?.value, pullback: pullback)
}

@inlinable
@derivative(of: differentiableMap)
internal func _jvpDifferentiableMap<Result: Differentiable>(
_ body: @differentiable(reverse) (Wrapped) -> Result
) -> (
value: Optional<Result>,
differential: (Optional.TangentVector) -> Optional<Result>.TangentVector
) {
let vwpb = self.map { valueWithDifferential(at: $0, of: body) }
let bodyDifferential = vwpb?.differential

func differential(_ vec: Optional.TangentVector) -> Optional<Result>.TangentVector {
guard let value = vec.value, let bodyDifferential else { return .init(.none) }
return .init(bodyDifferential(value))
}

return (value: vwpb?.value, differential: differential)
}
}

#endif
Loading