Skip to content

Commit 1ffb211

Browse files
committed
feat: add Differentiable conformance to InlineArray
1 parent 6160c0c commit 1ffb211

File tree

4 files changed

+255
-3
lines changed

4 files changed

+255
-3
lines changed

.github/workflows/pull_request.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
name: Test Swift ${{ matrix.swift }} Ubuntu Latest
1414
strategy:
1515
matrix:
16-
swift: ["6.0.3", "6.1", "6.2"]
16+
swift: ["6.2"]
1717
runs-on: ubuntu-latest
1818
container: swift:${{ matrix.swift }}
1919
steps:
@@ -24,7 +24,7 @@ jobs:
2424
name: Test Swift ${{ matrix.swift }} macOS
2525
strategy:
2626
matrix:
27-
swift: ["6.0.3", "6.1", "6.2"]
27+
swift: ["6.2"]
2828
runs-on: macos-15
2929
steps:
3030
- uses: actions/checkout@v4

Package.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
// swift-tools-version: 6.0
1+
// swift-tools-version: 6.2
22
// The swift-tools-version declares the minimum version of Swift required to build this package.
33

44
import PackageDescription
55

66
let package = Package(
77
name: "swift-differentiation",
8+
platforms: [
9+
// we only support the latest versions of OSes as `@available(...)` is not yet supported for differentiation.
10+
.macOS("26"),
11+
.iOS("26"),
12+
],
813
products: [
914
.library(
1015
name: "Differentiation",
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#if canImport(_Differentiation)
2+
3+
import _Differentiation
4+
5+
extension InlineArray: @retroactive Differentiable where Element: Differentiable {
6+
public typealias TangentVector = InlineArray<count, Element.TangentVector>
7+
8+
@inlinable
9+
public mutating func move(by offset: TangentVector) {
10+
for i in self.indices {
11+
self[i].move(by: offset[i])
12+
}
13+
}
14+
15+
@derivative(of: init)
16+
@_alwaysEmitIntoClient
17+
public static func _vjpInit(repeating value: Element) -> (value: Self, pullback: (TangentVector) -> Element.TangentVector) {
18+
(
19+
value: Self(repeating: value),
20+
pullback: { v in
21+
var result: Element.TangentVector = .zero
22+
for i in v.indices {
23+
result += v[i]
24+
}
25+
return result
26+
}
27+
)
28+
}
29+
30+
@inlinable
31+
public func read(_ i: Index) -> Element {
32+
self[i]
33+
}
34+
35+
@derivative(of: read)
36+
@inlinable
37+
public func _vjpRead(_ i: Index) -> (value: Element, pullback: (Element.TangentVector) -> TangentVector) {
38+
(
39+
value: self[i],
40+
pullback: { v in
41+
var array = InlineArray<count, Element.TangentVector>(repeating: .zero)
42+
array[i] = v
43+
return array
44+
}
45+
)
46+
}
47+
48+
@inlinable
49+
public mutating func update(at i: Index, with value: Element) {
50+
self[i] = value
51+
}
52+
53+
@derivative(of: update)
54+
@inlinable
55+
public mutating func _vjpUpdate(
56+
at i: Index,
57+
with value: Element
58+
) -> (value: Void, pullback: (inout TangentVector) -> Element.TangentVector) {
59+
self[i] = value
60+
return (
61+
value: (),
62+
pullback: { (v: inout TangentVector) in
63+
let dElement = v[i]
64+
v[i] = Element.TangentVector.zero
65+
return dElement
66+
}
67+
)
68+
}
69+
}
70+
71+
extension InlineArray: @retroactive AdditiveArithmetic where Element: AdditiveArithmetic {
72+
@inlinable
73+
public static var zero: InlineArray<count, Element> {
74+
.init(repeating: .zero)
75+
}
76+
77+
@inlinable
78+
public static func + (lhs: InlineArray<count, Element>, rhs: InlineArray<count, Element>) -> InlineArray<count, Element> {
79+
InlineArray<count, Element> { lhs[$0] + rhs[$0] }
80+
}
81+
82+
@inlinable
83+
public static func - (lhs: InlineArray<count, Element>, rhs: InlineArray<count, Element>) -> InlineArray<count, Element> {
84+
InlineArray<count, Element> { lhs[$0] - rhs[$0] }
85+
}
86+
}
87+
88+
extension InlineArray where Element: Differentiable & AdditiveArithmetic {
89+
@derivative(of: +)
90+
@inlinable
91+
public static func _vjpAdd(
92+
lhs: InlineArray<count, Element>,
93+
rhs: InlineArray<count, Element>
94+
) -> (
95+
value: InlineArray<count, Element>,
96+
pullback: (InlineArray<count, Element.TangentVector>) -> (
97+
InlineArray<count, Element.TangentVector>,
98+
InlineArray<count, Element.TangentVector>
99+
)
100+
) {
101+
(
102+
value: lhs + rhs,
103+
pullback: { v in
104+
(v, v)
105+
}
106+
)
107+
}
108+
109+
@derivative(of: -)
110+
@inlinable
111+
public static func _vjpSubtract(
112+
lhs: InlineArray<count, Element>,
113+
rhs: InlineArray<count, Element>
114+
) -> (
115+
value: InlineArray<count, Element>,
116+
pullback: (InlineArray<count, Element.TangentVector>) -> (
117+
InlineArray<count, Element.TangentVector>,
118+
InlineArray<count, Element.TangentVector>
119+
)
120+
) {
121+
(
122+
value: lhs - rhs,
123+
pullback: { v in
124+
(v, .zero - v)
125+
}
126+
)
127+
}
128+
}
129+
130+
// Temporary conformance to `Equatable` as this will eventually land in the stdlib
131+
extension InlineArray: @retroactive Equatable where Element: Equatable {
132+
@inlinable
133+
public static func == (lhs: InlineArray<count, Element>, rhs: InlineArray<count, Element>) -> Bool {
134+
for i in lhs.indices {
135+
if lhs[i] != rhs[i] { return false }
136+
}
137+
return true
138+
}
139+
}
140+
141+
#endif
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import Differentiation
2+
import Testing
3+
4+
#if canImport(_Differentiation)
5+
6+
@Suite
7+
struct InlineArrayTests {
8+
// Test that the zero additive arithmetic gives an array of zeros
9+
@Test("zero produces repeating zero elements")
10+
func zeroProducesZeros() {
11+
let z = InlineArray<2, Double>.zero
12+
#expect(z[0] == 0.0)
13+
#expect(z[1] == 0.0)
14+
}
15+
16+
// Test + and - work elementwise
17+
@Test("additive arithmetic + and − are elementwise")
18+
func additiveArithmeticAddSubtract() {
19+
let a = InlineArray<2, Double>(repeating: 1.5) // [1.5, 1.5]
20+
let b: InlineArray<2, Double> = [2.0, 3.0]
21+
let sum = a + b
22+
#expect(sum[0] == 3.5)
23+
#expect(sum[1] == 4.5)
24+
let diff = b - a
25+
#expect(diff[0] == 0.5)
26+
#expect(diff[1] == 1.5)
27+
}
28+
29+
// Test differentiable init(repeating:)
30+
@Test("vjp of init(repeating:) aggregates tangent inputs correctly")
31+
func testVJPInitRepeating() {
32+
// For differentiable init(repeating:), the pullback should sum all elements of the tangent vector
33+
let repeated = InlineArray<2, Double>(repeating: 4.0)
34+
// forward run
35+
// Now test pullback: apply VJP
36+
// The API for using VJP: call `valueWithPullback` or similar
37+
let (value, pullback) = valueWithPullback(at: 4.0, of: { value in InlineArray<2, Double>(repeating: value) })
38+
// value should equal what init(repeating:) produces
39+
#expect(value == repeated)
40+
41+
// construct some tangent vector
42+
let tv: InlineArray<2, Double> = [10.0, 20.0]
43+
// apply pullback
44+
let back = pullback(tv)
45+
// Should equal sum of elements, i.e. 10 + 20 == 30, as Double’s tangent
46+
#expect(back == 30.0)
47+
}
48+
49+
@Test("vjp of read is correct")
50+
func testVJPRead() {
51+
let arr: InlineArray<2, Double> = [5.0, 7.0]
52+
let index = 1
53+
let (value, pullback) = valueWithPullback(at: arr, of: { value in value.read(index) })
54+
#expect(value == 7.0)
55+
// Tangent vector for output
56+
let outTangent = 3.0
57+
let backVec = pullback(outTangent) // this returns a T2
58+
// It should have zero except at that index where it's outTangent
59+
#expect(backVec[0] == 0.0)
60+
#expect(backVec[1] == 3.0)
61+
}
62+
63+
@Test("vjp of update mutating works")
64+
func testVJPUpdate() {
65+
let arr: InlineArray<2, Double> = [1.0, 2.0]
66+
let index = 0
67+
let newValue = 100.0
68+
69+
// Apply the derivative via VJP of update
70+
// Because update is mutating, the pullback signature is a bit different
71+
// Use the manual _vjpUpdate
72+
let (value, pullback) = valueWithPullback(
73+
at: arr, newValue,
74+
of: { arr, newValue in
75+
var arr = arr
76+
arr.update(at: index, with: newValue)
77+
return arr
78+
}
79+
)
80+
// After update, arr[0] should be newValue
81+
#expect(value[0] == 100.0)
82+
#expect(value[1] == 2.0)
83+
84+
// Suppose we have a tangent vector v for the whole array
85+
let tangent: InlineArray<2, Double> = [10.0, 20.0]
86+
// Pullback should take and zero out the tangent component at `index`, returning the old tangent at that index
87+
let result = pullback(tangent)
88+
#expect(result.1 == 10.0)
89+
// After pullback, tangent[0] should be zero, tangent[1] remains 20
90+
#expect(result.0[0] == 0.0)
91+
#expect(result.0[1] == 20.0)
92+
}
93+
94+
// You could test move(by:) on the tangent vector space
95+
@Test("move(by:) translates elements correctly")
96+
func testMoveBy() {
97+
var arr: InlineArray<2, Double> = [1.0, 2.0]
98+
let offset: InlineArray<2, Double> = [1.0, 2.0]
99+
arr.move(by: offset)
100+
// After move, arr should be [1+1, 2+2] == [2,4]
101+
#expect(arr[0] == 2.0)
102+
#expect(arr[1] == 4.0)
103+
}
104+
}
105+
106+
#endif

0 commit comments

Comments
 (0)