Skip to content

Commit eff5d52

Browse files
committed
ci: apply lint
There's some autodiff related bugs in swift-format unfortunately. We might have to use NickLockwood's version since that doesn't have these.
1 parent 26a834b commit eff5d52

11 files changed

+92
-94
lines changed

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ let package = Package(
99
.library(
1010
name: "Differentiation",
1111
targets: ["Differentiation"]
12-
),
12+
)
1313
],
1414
targets: [
1515
.target(name: "Differentiation"),

Sources/Differentiation/Array+Update.swift

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,19 @@ extension Array where Element: Differentiable {
3434
) -> (value: Void, pullback: (inout TangentVector) -> (Element.TangentVector)) {
3535
update(at: index, with: newValue)
3636
let forwardCount = self.count
37-
return ((), { tangentVector in
38-
// manual zero tangent initialization
39-
if tangentVector.base.count < forwardCount {
40-
tangentVector.base = .init(repeating: .zero, count: forwardCount)
37+
return (
38+
(),
39+
{ tangentVector in
40+
// manual zero tangent initialization
41+
if tangentVector.base.count < forwardCount {
42+
tangentVector.base = .init(repeating: .zero, count: forwardCount)
43+
}
44+
let dElement = tangentVector[index]
45+
tangentVector.base[index] = .zero
46+
return dElement
4147
}
42-
let dElement = tangentVector[index]
43-
tangentVector.base[index] = .zero
44-
return dElement
45-
})
48+
)
4649
}
4750
}
4851

4952
#endif
50-

Sources/Differentiation/ArrayDifferentiableView+Collection.swift

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,35 @@ extension Array.DifferentiableView:
99
@retroactive RandomAccessCollection,
1010
@retroactive BidirectionalCollection,
1111
@retroactive MutableCollection
12-
where Element: Differentiable
13-
{
12+
where Element: Differentiable {
1413
public typealias Element = Array.Element
1514
public typealias Index = Array.Index
1615
public typealias SubSequence = Array.SubSequence
17-
16+
1817
@inlinable
1918
public subscript(position: Index) -> Element {
2019
_read { yield base[position] }
2120
set(newValue) { base[position] = newValue }
2221
}
23-
22+
2423
@inlinable
2524
public subscript(bounds: Range<Index>) -> SubSequence {
2625
get { base[bounds] }
2726
set(newValue) { base[bounds] = newValue }
2827
}
29-
28+
3029
@inlinable
3130
public var startIndex: Index { base.startIndex }
32-
31+
3332
@inlinable
3433
public var endIndex: Index { base.endIndex }
35-
34+
3635
@inlinable
3736
public init() { self.init(Array<Element>()) }
38-
37+
3938
@inlinable
40-
public mutating func replaceSubrange<C>(_ subrange: Range<Self.Index>, with newElements: C) where C : Collection, Self.Element == C.Element {
39+
public mutating func replaceSubrange<C>(_ subrange: Range<Self.Index>, with newElements: C)
40+
where C: Collection, Self.Element == C.Element {
4141
base.replaceSubrange(subrange, with: newElements)
4242
}
4343
}

Sources/Differentiation/DerivativesOfNativeFunctions.swift

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@ public func _vjpMin<T: Comparable & Differentiable>(
1111
_ rhs: T
1212
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
1313
func pullback(_ tangentVector: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
14-
if lhs <= rhs {
15-
return (tangentVector, .zero)
16-
}
17-
else {
14+
guard lhs <= rhs else {
1815
return (.zero, tangentVector)
1916
}
17+
return (tangentVector, .zero)
2018
}
2119
return (value: min(lhs, rhs), pullback: pullback)
2220
}
@@ -30,12 +28,10 @@ public func _vjpMax<T: Comparable & Differentiable>(
3028
_ rhs: T
3129
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
3230
func pullback(_ tangentVector: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
33-
if lhs < rhs {
34-
return (.zero, tangentVector)
35-
}
36-
else {
31+
guard lhs < rhs else {
3732
return (tangentVector, .zero)
3833
}
34+
return (.zero, tangentVector)
3935
}
4036
return (value: max(lhs, rhs), pullback: pullback)
4137
}
@@ -47,12 +43,10 @@ public func _vjpAbs<T: Comparable & SignedNumeric & Differentiable>(_ value: T)
4743
-> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
4844
{
4945
func pullback(_ tangentVector: T.TangentVector) -> T.TangentVector {
50-
if value < 0 {
51-
return .zero - tangentVector
52-
}
53-
else {
46+
guard value < 0 else {
5447
return tangentVector
5548
}
49+
return .zero - tangentVector
5650
}
5751
return (value: abs(value), pullback: pullback)
5852
}

Sources/Differentiation/Dictionary+Differentiation.swift

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ extension Dictionary: @retroactive Differentiable where Value: Differentiable {
1919
}
2020

2121
public var zeroTangentVectorInitializer: () -> TangentVector {
22-
let listOfKeys = keys // capturing only what's needed, not the entire self, in order to not waste memory
22+
let listOfKeys = keys // capturing only what's needed, not the entire self, in order to not waste memory
2323
func initializer() -> Self.TangentVector {
2424
return listOfKeys.reduce(into: [Key: Value.TangentVector]()) { $0[$1] = Value.TangentVector.zero }
2525
}
@@ -49,14 +49,15 @@ extension Dictionary where Value: Differentiable {
4949
{
5050
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
5151
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
52-
return (self[key], { tangentVector in
53-
if let value = tangentVector.value {
52+
(
53+
self[key],
54+
{ tangentVector in
55+
guard let value = tangentVector.value else {
56+
return .zero
57+
}
5458
return [key: value]
5559
}
56-
else {
57-
return .zero
58-
}
59-
})
60+
)
6061
}
6162
}
6263
#endif

Sources/Differentiation/Dictionary+Update.swift

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,32 +30,33 @@ extension Dictionary where Value: Differentiable {
3030
@inlinable
3131
public mutating func _vjpUpdate(
3232
at key: Key,
33-
with newValue: Value // TODO: this should be optional?
33+
with newValue: Value // TODO: this should be optional?
3434
) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
3535
update(at: key, with: newValue)
3636

3737
let forwardCount = count
38-
let forwardKeys = keys // may be heavy to capture all of these, not sure how to do without them though
39-
40-
return ((), { tangentVector in
41-
// manual zero tangent initialization
42-
// TODO: Should we consider missing keys as a complete tangentvector with zero values for those keys?
43-
if tangentVector.count < forwardCount { // TODO: is this the correct check keys could still differ
44-
tangentVector = Self.TangentVector() // TODO: should we be replacing this or merging
45-
for key in forwardKeys {
46-
tangentVector[key] = .zero
38+
let forwardKeys = keys // may be heavy to capture all of these, not sure how to do without them though
39+
40+
return (
41+
(),
42+
{ tangentVector in
43+
// manual zero tangent initialization
44+
// TODO: Should we consider missing keys as a complete tangentvector with zero values for those keys?
45+
if tangentVector.count < forwardCount { // TODO: is this the correct check keys could still differ
46+
tangentVector = Self.TangentVector() // TODO: should we be replacing this or merging
47+
for key in forwardKeys {
48+
tangentVector[key] = .zero
49+
}
4750
}
48-
}
4951

50-
if let dElement = tangentVector[key] {
52+
guard let dElement = tangentVector[key] else { // should this fail?
53+
tangentVector[key] = .zero
54+
return .zero
55+
}
5156
tangentVector[key] = .zero
5257
return dElement
5358
}
54-
else { // should this fail?
55-
tangentVector[key] = .zero
56-
return .zero
57-
}
58-
})
59+
)
5960
}
6061
}
6162

Sources/Differentiation/Sequence+MinMaxVJPs.swift

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import _Differentiation
44

5-
extension Sequence where
6-
Self: Collection, // we constrain to conform to collection cause otherwise we can't access any values by index
5+
extension Sequence
6+
where
7+
Self: Collection, // we constrain to conform to collection cause otherwise we can't access any values by index
78
Self: Differentiable,
8-
Self.TangentVector: RangeReplaceableCollection, // we constrain the tangentvector to be able to create a value and write to it
9+
Self.TangentVector: RangeReplaceableCollection, // we constrain the tangentvector to be able to create a value and write to it
910
Self.TangentVector.Element == Element.TangentVector,
1011
Element: Differentiable,
1112
Element: Comparable
@@ -19,29 +20,30 @@ extension Sequence where
1920
value: Element?,
2021
pullback: (Element?.TangentVector) -> (Self.TangentVector)
2122
) where Self.Index == Self.TangentVector.Index {
22-
let index = withoutDerivative(at: self.indices.max { self[$0] < self[$1] }) // we grab the index of the element with the max value
23+
let index = withoutDerivative(at: self.indices.max { self[$0] < self[$1] }) // we grab the index of the element with the max value
2324
return (
24-
value: index.map { self[$0] }, // if the index is nil, we return nil otherwise we grab the value at the index
25+
value: index.map { self[$0] }, // if the index is nil, we return nil otherwise we grab the value at the index
2526
pullback: { vector in
26-
var dSelf = Self
27+
var dSelf =
28+
Self
2729
.TangentVector(
2830
repeating: .zero,
2931
count: self
3032
.count
31-
) // we create a zero tangentvector we need `RangeReplaceableCollection` conformance in order to do this
33+
) // we create a zero tangentvector we need `RangeReplaceableCollection` conformance in order to do this
3234
if let vectorValue = vector.value,
33-
let index = index
35+
let index = index
3436
{
3537
// if an index was found and our tangentvector's value is non nil we set the value at index of our tangentvector to the
3638
// provided tangentvector value
3739
dSelf
3840
.replaceSubrange(
3941
index ..< dSelf.index(after: index),
4042
with: [vectorValue]
41-
) // we use `RangeReplaceableCollection`'s method here in order to not have to also constrain our TangentVector to
43+
) // we use `RangeReplaceableCollection`'s method here in order to not have to also constrain our TangentVector to
4244
// `MutableCollection`
4345
}
44-
return dSelf // return the tangentvector
46+
return dSelf // return the tangentvector
4547
}
4648
)
4749
}
@@ -55,19 +57,19 @@ extension Sequence where
5557
value: Element?,
5658
pullback: (Element?.TangentVector) -> (Self.TangentVector)
5759
) where Self.Index == Self.TangentVector.Index {
58-
let index = withoutDerivative(at: self.indices.min { self[$0] < self[$1] }) // we grab the index of the element with the max value
60+
let index = withoutDerivative(at: self.indices.min { self[$0] < self[$1] }) // we grab the index of the element with the max value
5961
return (
60-
value: index.map { self[$0] }, // if the index is nil, we return nil otherwise we grab the value at the index
62+
value: index.map { self[$0] }, // if the index is nil, we return nil otherwise we grab the value at the index
6163
pullback: { vector in
62-
var dSelf = Self.TangentVector(repeating: .zero, count: self.count) // we create a zero tangentvector
64+
var dSelf = Self.TangentVector(repeating: .zero, count: self.count) // we create a zero tangentvector
6365
if let vectorValue = vector.value,
64-
let index = index
66+
let index = index
6567
{
6668
// if an index was found and our tangentvector's value is non nil we set the value at index of our tangentvector to the
6769
// provided tangentvector value
6870
dSelf.replaceSubrange(index ..< dSelf.index(after: index), with: [vectorValue])
6971
}
70-
return dSelf // return the tangentvector
72+
return dSelf // return the tangentvector
7173
}
7274
)
7375
}

Tests/DifferentiationTests/DerivativesOfNativeFunctionsTests.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct DerivativesOfNativeFunctionsTests {
1111
// the type of a top-level min() function passed into gradient(at:of:).
1212
@differentiable(reverse)
1313
func minContainer(_ lhs: Float, _ rhs: Float) -> Float {
14-
return min(lhs, rhs)
14+
min(lhs, rhs)
1515
}
1616
let vwgLessThan = valueWithGradient(at: 2.0, 3.0, of: minContainer)
1717
#expect(vwgLessThan.value == 2.0)
@@ -20,14 +20,14 @@ struct DerivativesOfNativeFunctionsTests {
2020
#expect(vwgGreaterThan.value == -2.0)
2121
#expect(vwgGreaterThan.gradient == (0.0, 1.0))
2222
}
23-
23+
2424
@Test
2525
func testMax() {
2626
// I'm using this container because the compiler can't quite determine
2727
// the type of a top-level min() function passed into gradient(at:of:).
2828
@differentiable(reverse)
2929
func maxContainer(_ lhs: Float, _ rhs: Float) -> Float {
30-
return max(lhs, rhs)
30+
max(lhs, rhs)
3131
}
3232
let vwgLessThan = valueWithGradient(at: 2.0, 3.0, of: maxContainer)
3333
#expect(vwgLessThan.value == 3.0)
@@ -36,14 +36,14 @@ struct DerivativesOfNativeFunctionsTests {
3636
#expect(vwgGreaterThan.value == 20.0)
3737
#expect(vwgGreaterThan.gradient == (1.0, 0.0))
3838
}
39-
39+
4040
@Test
4141
func testAbs() {
4242
// I'm using this container because the compiler can't quite determine
4343
// the type of a top-level abs() function passed into gradient(at:of:).
4444
@differentiable(reverse)
4545
func absContainer(_ value: Float) -> Float {
46-
return abs(value)
46+
abs(value)
4747
}
4848

4949
let vwgPositive = valueWithGradient(at: 4.0, of: absContainer)

0 commit comments

Comments
 (0)