Skip to content

Commit 4259585

Browse files
committed
add some tests and fix incorrect derivative
1 parent 448b9ce commit 4259585

File tree

4 files changed

+262
-5
lines changed

4 files changed

+262
-5
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# swift-numerics-differentiable
22

3+
This package attempts to add more Differentiable capabilities to the existing [swift-numerics](https://github.com/apple/swift-numerics) package. Every target in swift-numerics has a Differentiable counterpart that `@_exported import`s the original module such that when you import `NumericsDifferentiable` you will also get all the contents of the `Numerics` module from swift-numerics.
4+
5+
## RealModule Differentiable
6+
- Registers derivatives to the `Float` and `Double` conformances to `ElementaryFunctions` and `RealFunctions` from swift-numerics.
7+
- Conforms all `SIMD{n}` types to `ElementaryFunctions` and adds most of the protocol requirements from `RealFunctions` as well (`signGamma` is not implementable)
8+
- Registers derivatives for all the provided `ElementaryFunctions` and `RealFunctions` implementations on SIMD{n}
9+
- Tries to leverage Apple's `simd` framework to accelerate these operations where possible on Apple platforms.
10+
311
## Contributing
412
### Code Formatting
513
This package makes use of [SwiftFormat](https://github.com/nicklockwood/SwiftFormat?tab=readme-ov-file#command-line-tool), which you can install

Sources/CodeGeneratorExecutable/CodeGenerator.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import Foundation
33
@main
44
struct CodeGenerator {
55
static func main() throws {
6-
// Use swift-argument-parser or just CommandLine, here we just imply that 2 paths are passed in: input and output
76
guard CommandLine.arguments.count == 2 else {
87
throw CodeGeneratorError.invalidArguments
98
}

Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,32 @@ enum RealFunctionsDerivativesGenerator {
88
// MARK: ElementaryFunctions derivatives
99
extension \(type) {
1010
@derivative(of: exp)
11+
@_transparent
1112
public static func _vjpExp(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
1213
let value = exp(x)
1314
return (value: value, pullback: { v in v * value })
1415
}
1516
1617
@derivative(of: expMinusOne)
18+
@_transparent
1719
public static func _vjpExpMinusOne(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
1820
return (value: expMinusOne(x), pullback: { v in v * exp(x) })
1921
}
2022
2123
@derivative(of: cosh)
24+
@_transparent
2225
public static func _vjpCosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
2326
(value: cosh(x), pullback: { v in sinh(x) })
2427
}
2528
2629
@derivative(of: sinh)
30+
@_transparent
2731
public static func _vjpSinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
2832
(value: sinh(x), pullback: { v in cosh(x) })
2933
}
3034
3135
@derivative(of: tanh)
36+
@_transparent
3237
public static func _vjpTanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
3338
(
3439
value: tanh(x),
@@ -40,16 +45,19 @@ enum RealFunctionsDerivativesGenerator {
4045
}
4146
4247
@derivative(of: cos)
48+
@_transparent
4349
public static func _vjpCos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
4450
(value: cos(x), pullback: { v in -v * sin(x) })
4551
}
4652
4753
@derivative(of: sin)
54+
@_transparent
4855
public static func _vjpSin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
4956
(value: sin(x), pullback: { v in v * cos(x) })
5057
}
5158
5259
@derivative(of: tan)
60+
@_transparent
5361
public static func _vjpTan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
5462
(
5563
value: tan(x),
@@ -61,65 +69,77 @@ enum RealFunctionsDerivativesGenerator {
6169
}
6270
6371
@derivative(of: log(_:))
72+
@_transparent
6473
public static func _vjpLog(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
6574
(value: log(x), pullback: { v in v / x })
6675
}
6776
6877
@derivative(of: acosh)
78+
@_transparent
6979
public static func _vjpAcosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
7080
// only valid for x > 1
7181
return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) })
7282
}
7383
7484
@derivative(of: asinh)
85+
@_transparent
7586
public static func _vjpAsinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
7687
(value: asinh(x), pullback: { v in v / sqrt(x * x + 1) })
7788
}
7889
7990
@derivative(of: atanh)
91+
@_transparent
8092
public static func _vjpAtanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
8193
(value: atanh(x), pullback: { v in v / (1 - x * x) })
8294
}
8395
8496
@derivative(of: acos)
97+
@_transparent
8598
public static func _vjpAcos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
86-
(value: acos(x), pullback: { v in -v / (1 - x * x) })
99+
(value: acos(x), pullback: { v in -v / .sqrt(1 - x * x) })
87100
}
88101
89102
@derivative(of: asin)
103+
@_transparent
90104
public static func _vjpAsin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
91-
(value: asin(x), pullback: { v in v / (1 - x * x) })
105+
(value: asin(x), pullback: { v in v / .sqrt(1 - x * x) })
92106
}
93107
94108
@derivative(of: atan)
109+
@_transparent
95110
public static func _vjpAtan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
96111
(value: atan(x), pullback: { v in v / (x * x + 1) })
97112
}
98113
99114
@derivative(of: log(onePlus:))
115+
@_transparent
100116
public static func _vjpLog(onePlus x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
101117
(value: log(onePlus: x), pullback: { v in v / (1 + x) })
102118
}
103119
104120
@derivative(of: pow)
121+
@_transparent
105122
public static func _vjpPow(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
106123
let value = pow(x, y)
107124
// pullback wrt y is not defined for (x < 0) and (x = 0, y = 0)
108125
return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) })
109126
}
110127
111128
@derivative(of: pow)
129+
@_transparent
112130
public static func _vjpPow(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) {
113131
(value: pow(x, n), pullback: { v in v * \(floatingPointType)(n) * pow(x, n - 1) })
114132
}
115133
116134
@derivative(of: sqrt)
135+
@_transparent
117136
public static func _vjpSqrt(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
118137
let value = sqrt(x)
119138
return (value: value, pullback: { v in v / (2 * value) })
120139
}
121140
122141
@derivative(of: root)
142+
@_transparent
123143
public static func _vjpRoot(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) {
124144
let value = root(x, n)
125145
return (value: value, pullback: { v in v * value / (x * \(floatingPointType)(n)) })
@@ -129,48 +149,57 @@ enum RealFunctionsDerivativesGenerator {
129149
// MARK: RealFunctions derivatives
130150
extension \(type) {
131151
@derivative(of: erf)
152+
@_transparent
132153
public static func _vjpErf(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
133154
(value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) })
134155
}
135156
136157
@derivative(of: erfc)
158+
@_transparent
137159
public static func _vjpErfc(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
138160
(value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) })
139161
}
140162
141163
@derivative(of: exp2)
164+
@_transparent
142165
public static func _vjpExp2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
143166
let value = exp2(x)
144167
return (value, { v in v * value * .log(2) })
145168
}
146169
147170
@derivative(of: exp10)
171+
@_transparent
148172
public static func _vjpExp10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
149173
let value = exp10(x)
150174
return (value, { v in v * value * .log(10) })
151175
}
152176
153177
@derivative(of: gamma)
178+
@_transparent
154179
public static func _vjpGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
155180
fatalError("unimplemented")
156181
}
157182
158183
@derivative(of: log2)
184+
@_transparent
159185
public static func _vjpLog2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
160186
(value: log2(x), pullback: { v in v / (.log(2) * x) })
161187
}
162188
163189
@derivative(of: log10)
190+
@_transparent
164191
public static func _vjpLog10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
165192
(value: log10(x), pullback: { v in v / (.log(10) * x) })
166193
}
167194
168195
@derivative(of: logGamma)
196+
@_transparent
169197
public static func _vjpLogGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
170198
fatalError("unimplemented")
171199
}
172200
173201
@derivative(of: atan2)
202+
@_transparent
174203
public static func _vjpAtan2(y: \(type), x: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
175204
(
176205
value: atan2(y: y, x: x),
@@ -182,6 +211,7 @@ enum RealFunctionsDerivativesGenerator {
182211
}
183212
184213
@derivative(of: hypot)
214+
@_transparent
185215
public static func _vjpHypot(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
186216
(
187217
value: hypot(x, y),
@@ -196,6 +226,7 @@ enum RealFunctionsDerivativesGenerator {
196226
// MARK: FloatingPoint functions derivatives
197227
extension \(type) {
198228
@derivative(of: abs)
229+
@_transparent
199230
public static func _vjpAbs(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
200231
\({
201232
if type == floatingPointType {

0 commit comments

Comments
 (0)