@@ -8,27 +8,32 @@ enum RealFunctionsDerivativesGenerator {
88 // MARK: ElementaryFunctions derivatives
99 extension \( type) {
1010 @derivative(of: exp)
11+ @_transparent
1112 public static func _vjpExp(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
1213 let value = exp(x)
1314 return (value: value, pullback: { v in v * value })
1415 }
1516
1617 @derivative(of: expMinusOne)
18+ @_transparent
1719 public static func _vjpExpMinusOne(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
1820 return (value: expMinusOne(x), pullback: { v in v * exp(x) })
1921 }
2022
2123 @derivative(of: cosh)
24+ @_transparent
2225 public static func _vjpCosh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
2326 (value: cosh(x), pullback: { v in sinh(x) })
2427 }
2528
2629 @derivative(of: sinh)
30+ @_transparent
2731 public static func _vjpSinh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
2832 (value: sinh(x), pullback: { v in cosh(x) })
2933 }
3034
3135 @derivative(of: tanh)
36+ @_transparent
3237 public static func _vjpTanh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
3338 (
3439 value: tanh(x),
@@ -40,16 +45,19 @@ enum RealFunctionsDerivativesGenerator {
4045 }
4146
4247 @derivative(of: cos)
48+ @_transparent
4349 public static func _vjpCos(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
4450 (value: cos(x), pullback: { v in -v * sin(x) })
4551 }
4652
4753 @derivative(of: sin)
54+ @_transparent
4855 public static func _vjpSin(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
4956 (value: sin(x), pullback: { v in v * cos(x) })
5057 }
5158
5259 @derivative(of: tan)
60+ @_transparent
5361 public static func _vjpTan(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
5462 (
5563 value: tan(x),
@@ -61,65 +69,77 @@ enum RealFunctionsDerivativesGenerator {
6169 }
6270
6371 @derivative(of: log(_:))
72+ @_transparent
6473 public static func _vjpLog(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
6574 (value: log(x), pullback: { v in v / x })
6675 }
6776
6877 @derivative(of: acosh)
78+ @_transparent
6979 public static func _vjpAcosh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
7080 // only valid for x > 1
7181 return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) })
7282 }
7383
7484 @derivative(of: asinh)
85+ @_transparent
7586 public static func _vjpAsinh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
7687 (value: asinh(x), pullback: { v in v / sqrt(x * x + 1) })
7788 }
7889
7990 @derivative(of: atanh)
91+ @_transparent
8092 public static func _vjpAtanh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
8193 (value: atanh(x), pullback: { v in v / (1 - x * x) })
8294 }
8395
8496 @derivative(of: acos)
97+ @_transparent
8598 public static func _vjpAcos(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
86- (value: acos(x), pullback: { v in -v / (1 - x * x) })
99+ (value: acos(x), pullback: { v in -v / .sqrt (1 - x * x) })
87100 }
88101
89102 @derivative(of: asin)
103+ @_transparent
90104 public static func _vjpAsin(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
91- (value: asin(x), pullback: { v in v / (1 - x * x) })
105+ (value: asin(x), pullback: { v in v / .sqrt (1 - x * x) })
92106 }
93107
94108 @derivative(of: atan)
109+ @_transparent
95110 public static func _vjpAtan(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
96111 (value: atan(x), pullback: { v in v / (x * x + 1) })
97112 }
98113
99114 @derivative(of: log(onePlus:))
115+ @_transparent
100116 public static func _vjpLog(onePlus x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
101117 (value: log(onePlus: x), pullback: { v in v / (1 + x) })
102118 }
103119
104120 @derivative(of: pow)
121+ @_transparent
105122 public static func _vjpPow(_ x: \( type) , _ y: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> ( \( type) , \( type) )) {
106123 let value = pow(x, y)
107124 // pullback wrt y is not defined for (x < 0) and (x = 0, y = 0)
108125 return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) })
109126 }
110127
111128 @derivative(of: pow)
129+ @_transparent
112130 public static func _vjpPow(_ x: \( type) , _ n: Int) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
113131 (value: pow(x, n), pullback: { v in v * \( floatingPointType) (n) * pow(x, n - 1) })
114132 }
115133
116134 @derivative(of: sqrt)
135+ @_transparent
117136 public static func _vjpSqrt(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
118137 let value = sqrt(x)
119138 return (value: value, pullback: { v in v / (2 * value) })
120139 }
121140
122141 @derivative(of: root)
142+ @_transparent
123143 public static func _vjpRoot(_ x: \( type) , _ n: Int) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
124144 let value = root(x, n)
125145 return (value: value, pullback: { v in v * value / (x * \( floatingPointType) (n)) })
@@ -129,48 +149,57 @@ enum RealFunctionsDerivativesGenerator {
129149 // MARK: RealFunctions derivatives
130150 extension \( type) {
131151 @derivative(of: erf)
152+ @_transparent
132153 public static func _vjpErf(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
133154 (value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt( \( floatingPointType) .pi) })
134155 }
135156
136157 @derivative(of: erfc)
158+ @_transparent
137159 public static func _vjpErfc(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
138160 (value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt( \( floatingPointType) .pi) })
139161 }
140162
141163 @derivative(of: exp2)
164+ @_transparent
142165 public static func _vjpExp2(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
143166 let value = exp2(x)
144167 return (value, { v in v * value * .log(2) })
145168 }
146169
147170 @derivative(of: exp10)
171+ @_transparent
148172 public static func _vjpExp10(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
149173 let value = exp10(x)
150174 return (value, { v in v * value * .log(10) })
151175 }
152176
153177 @derivative(of: gamma)
178+ @_transparent
154179 public static func _vjpGamma(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
155180 fatalError( " unimplemented " )
156181 }
157182
158183 @derivative(of: log2)
184+ @_transparent
159185 public static func _vjpLog2(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
160186 (value: log2(x), pullback: { v in v / (.log(2) * x) })
161187 }
162188
163189 @derivative(of: log10)
190+ @_transparent
164191 public static func _vjpLog10(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
165192 (value: log10(x), pullback: { v in v / (.log(10) * x) })
166193 }
167194
168195 @derivative(of: logGamma)
196+ @_transparent
169197 public static func _vjpLogGamma(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
170198 fatalError( " unimplemented " )
171199 }
172200
173201 @derivative(of: atan2)
202+ @_transparent
174203 public static func _vjpAtan2(y: \( type) , x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> ( \( type) , \( type) )) {
175204 (
176205 value: atan2(y: y, x: x),
@@ -182,6 +211,7 @@ enum RealFunctionsDerivativesGenerator {
182211 }
183212
184213 @derivative(of: hypot)
214+ @_transparent
185215 public static func _vjpHypot(_ x: \( type) , _ y: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> ( \( type) , \( type) )) {
186216 (
187217 value: hypot(x, y),
@@ -196,6 +226,7 @@ enum RealFunctionsDerivativesGenerator {
196226 // MARK: FloatingPoint functions derivatives
197227 extension \( type) {
198228 @derivative(of: abs)
229+ @_transparent
199230 public static func _vjpAbs(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
200231 \( {
201232 if type == floatingPointType {
0 commit comments