Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ name: Pull Request
on: [pull_request]

jobs:
Lint:
runs-on: macos-latest
steps:
- uses: actions/checkout@v3
- name: SwiftFormat
run: swiftformat --lint --strict . --reporter github-actions-log
test-ubuntu-latest:
name: Test Swift ${{ matrix.swift }} Ubuntu Latest
strategy:
Expand Down
94 changes: 94 additions & 0 deletions .swiftformat
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
--allman false
--binarygrouping none
--decimalgrouping none
--disable redundantSelf,trailingClosures
--elseposition next-line
--exponentcase uppercase
--guardelse same-line
# TODO: add header config
--header ""
--hexliteralcase lowercase
--indent 4
--indentcase true
--ifdef no-indent
--maxwidth 140
--self insert
--semicolons never
--swiftversion 6
--trimwhitespace always
--storedvarattrs prev-line
--computedvarattrs prev-line
--typeattributes prev-line
--wrapparameters before-first
--wrapcollections before-first
--wraparguments before-first

# Explicitly enable rules to prevent automatic opt-in
# Generated from `swiftformat --rules` on v0.47.1
--rules andOperator
--rules anyObjectProtocol
--rules blankLinesAroundMark
--rules blankLinesAtEndOfScope
--rules blankLinesAtStartOfScope
--rules blankLinesBetweenScopes
--rules braces
--rules consecutiveBlankLines
--rules consecutiveSpaces
--rules duplicateImports
--rules elseOnSameLine
--rules emptyBraces
--rules enumNamespaces
--rules fileHeader
--rules hoistPatternLet
--rules indent
--rules initCoderUnavailable
--rules leadingDelimiters
--rules linebreakAtEndOfFile
--rules linebreaks
--rules modifierOrder
--rules numberFormatting
--rules preferKeyPath
--rules redundantBackticks
--rules redundantBreak
--rules redundantExtensionACL
--rules redundantFileprivate
--rules redundantGet
--rules redundantInit
--rules redundantLet
--rules redundantLetError
--rules redundantNilInit
--rules redundantObjc
--rules redundantParens
--rules redundantPattern
--rules redundantRawValues
--rules redundantReturn
--rules redundantSelf
--rules redundantType
--rules semicolons
--rules sortImports
--rules spaceAroundBraces
--rules spaceAroundBrackets
--rules spaceAroundComments
--rules spaceAroundGenerics
--rules spaceAroundOperators
--rules spaceAroundParens
--rules spaceInsideBraces
--rules spaceInsideBrackets
--rules spaceInsideComments
--rules spaceInsideGenerics
--rules spaceInsideParens
--rules strongOutlets
--rules strongifiedSelf
--rules todos
--rules trailingClosures
--rules trailingCommas
--rules trailingSpace
--rules unusedArguments
--rules void
--rules wrap
--rules wrapArguments
--rules wrapAttributes
--rules wrapMultilineStatementBraces
--rules wrapSingleLineComments
--rules yodaConditions

12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,15 @@
## Getting Started

## Contributing
### Code Formatting
This package makes use of [SwiftFormat](https://github.com/nicklockwood/SwiftFormat?tab=readme-ov-file#command-line-tool), which you can install
from [homebrew](https://brew.sh/).

## Additional Resources
To apply formatting rules to all files, which you should do before submitting a PR, run from the root of the repository:

```sh
swiftformat .
```
Formatting is validated with the `--strict` flag on every PR

## Additional Resources
1 change: 0 additions & 1 deletion Sources/Differentiation/Array+Update.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,3 @@ extension Array where Element: Differentiable {
}

#endif

16 changes: 9 additions & 7 deletions Sources/Differentiation/ArrayDifferentiableView+Collection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,32 @@ extension Array.DifferentiableView:
public typealias Element = Array.Element
public typealias Index = Array.Index
public typealias SubSequence = Array.SubSequence

@inlinable
public subscript(position: Index) -> Element {
_read { yield base[position] }
set(newValue) { base[position] = newValue }
}

@inlinable
public subscript(bounds: Range<Index>) -> SubSequence {
get { base[bounds] }
set(newValue) { base[bounds] = newValue }
}

@inlinable
public var startIndex: Index { base.startIndex }

@inlinable
public var endIndex: Index { base.endIndex }

@inlinable
public init() { self.init(Array<Element>()) }

@inlinable
public mutating func replaceSubrange<C>(_ subrange: Range<Self.Index>, with newElements: C) where C : Collection, Self.Element == C.Element {
public mutating func replaceSubrange<C>(_ subrange: Range<Self.Index>, with newElements: C)
where C: Collection, Self.Element == C.Element
{
base.replaceSubrange(subrange, with: newElements)
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/Differentiation/Dictionary+Differentiation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ extension Dictionary where Value: Differentiable {
{
// When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
// every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
return (self[key], { tangentVector in
(self[key], { tangentVector in
if let value = tangentVector.value {
return [key: value]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct DerivativesOfNativeFunctionsTests {
// the type of a top-level min() function passed into gradient(at:of:).
@differentiable(reverse)
func minContainer(_ lhs: Float, _ rhs: Float) -> Float {
return min(lhs, rhs)
min(lhs, rhs)
}
let vwgLessThan = valueWithGradient(at: 2.0, 3.0, of: minContainer)
#expect(vwgLessThan.value == 2.0)
Expand All @@ -20,14 +20,14 @@ struct DerivativesOfNativeFunctionsTests {
#expect(vwgGreaterThan.value == -2.0)
#expect(vwgGreaterThan.gradient == (0.0, 1.0))
}

@Test
func testMax() {
// I'm using this container because the compiler can't quite determine
// the type of a top-level min() function passed into gradient(at:of:).
@differentiable(reverse)
func maxContainer(_ lhs: Float, _ rhs: Float) -> Float {
return max(lhs, rhs)
max(lhs, rhs)
}
let vwgLessThan = valueWithGradient(at: 2.0, 3.0, of: maxContainer)
#expect(vwgLessThan.value == 3.0)
Expand All @@ -36,14 +36,14 @@ struct DerivativesOfNativeFunctionsTests {
#expect(vwgGreaterThan.value == 20.0)
#expect(vwgGreaterThan.gradient == (1.0, 0.0))
}

@Test
func testAbs() {
// I'm using this container because the compiler can't quite determine
// the type of a top-level abs() function passed into gradient(at:of:).
@differentiable(reverse)
func absContainer(_ value: Float) -> Float {
return abs(value)
abs(value)
}

let vwgPositive = valueWithGradient(at: 4.0, of: absContainer)
Expand Down
36 changes: 18 additions & 18 deletions Tests/DifferentiationTests/Dictionary+DifferentiationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@ struct DictionaryDifferentiationTests {
#expect(vwg.value == 3 * aMultiplier + 7 * bMultiplier)
#expect(vwg.gradient == ["a": aMultiplier, "b": bMultiplier])
}

@Test
func testDictionaryReadAndCombineValues() {
@differentiable(reverse)
func testFunction(newValues: [String: Double]) -> Double {
return 1.0 * newValues["s1"]! +
2.0 * newValues["s2"]! +
3.0 * newValues["s3"]!
1.0 * newValues["s1"]! +
2.0 * newValues["s2"]! +
3.0 * newValues["s3"]!
}

let vwg = valueWithGradient(
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
of: testFunction
)

#expect(vwg.value == 140.0)
#expect(vwg.gradient == ["s1": 1.0, "s2": 2.0, "s3": 3.0])
}

@Test
func testDictionaryInoutWriteMethod() {
@differentiable(reverse)
Expand All @@ -51,28 +51,28 @@ struct DictionaryDifferentiationTests {
mainDict.update(at: key, with: otherValue)
}
}

@differentiable(reverse)
func inoutWrapper(dictionary: [String: Double], otherDictionary: [String: Double]) -> [String: Double] {
// we wrap the `combineByReplacingDictionaryValues`
var mainCopy = dictionary
combineByReplacingDictionaryValues(of: &mainCopy, with: otherDictionary)
return mainCopy
}

let vwpb = valueWithPullback(
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
["s1": 2.0],//, "s2": nil, "s3": nil],
["s1": 2.0], // , "s2": nil, "s3": nil],
of: inoutWrapper
)

#expect(vwpb.value == ["s1": 2.0, "s2": 20.0, "s3": 30.0])
// we need to provide a full tangentvector to the pullback hence the keys with zero entries.
#expect(vwpb.pullback(["s1": 1.0, "s2": 0.0, "s3": 0.0]) == (["s1": 0.0, "s2": 0.0, "s3": 0.0], ["s1": 1.0]))
#expect(vwpb.pullback(["s1": 0.0, "s2": 1.0, "s3": 0.0]) == (["s1": 0.0, "s2": 1.0, "s3": 0.0], ["s1": 0.0]))
#expect(vwpb.pullback(["s1": 0.0, "s2": 0.0, "s3": 1.0]) == (["s1": 0.0, "s2": 0.0, "s3": 1.0], ["s1": 0.0]))
}

@Test
func testInoutWriteAndSumValues() {
@differentiable(reverse)
Expand All @@ -82,29 +82,29 @@ struct DictionaryDifferentiationTests {
mainDict.update(at: key, with: otherValue)
}
}

@differentiable(reverse)
func sumValues(of dictionary: [String: Double]) -> Double {
var sum: Double = 0.0
var sum = 0.0
for key in withoutDerivative(at: dictionary.keys) {
sum += dictionary[key]!
}
return sum
}

@differentiable(reverse, wrt: dictionary)
func inoutWrapperAndSum(dictionary: [String: Double], otherDictionary: [String: Double]) -> Double {
var mainCopy = dictionary
combineByReplacingDictionaryValues(of: &mainCopy, with: otherDictionary)
return sumValues(of: mainCopy)
}

let vwg = valueWithGradient(
at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
["s1": 2.0],//, "s2": nil, "s3": nil],
["s1": 2.0], // , "s2": nil, "s3": nil],
of: inoutWrapperAndSum
)

#expect(vwg.value == 52.0)
#expect(vwg.gradient == (["s1": 0.0, "s2": 1.0, "s3": 1.0], ["s1": 1.0]))
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/DifferentiationTests/Dictionary+UpdateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct DictionaryUpdateTests {
let newB: Double = 7

let valAndGrad = valueWithGradient(at: dictionary, newA, newB, of: writeAndReadFromDictionary)

#expect(valAndGrad.value == newA * aMultiplier + newB * bMultiplier)
#expect(valAndGrad.gradient == (["a": 0, "b": 0], aMultiplier, bMultiplier))
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/DifferentiationTests/Foundation+VJPsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct FoundationVJPsTests {
// the type of a top-level min() function passed into gradient(at:of:).
@differentiable(reverse)
func atan2Container(_ p1: Double, _ p2: Double) -> Double {
return atan2(p1, p2)
atan2(p1, p2)
}
let vwg = valueWithGradient(at: 1.0, 1.0, of: atan2Container)
#expect(vwg.value == .pi / 4)
Expand Down
1 change: 0 additions & 1 deletion Tests/DifferentiationTests/Sequence+MinMaxVJPsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import Testing

@Suite("Sequence+MinMaxVJPs")
struct SequenceMinMaxVJPs {

let inputArray = [2.0, 1.0, 3.0]

@Test
Expand Down
Loading