@@ -90,7 +90,7 @@ macro scalar_rule(call, maybe_setup, partials...)
90
90
91
91
# Generate variables to store derivatives named
92
92
derivatives = map (keys (partials)) do i
93
- syms = map (j -> gensym (" df$(i) /dx$(j) " ), keys (inputs))
93
+ syms = map (j -> esc ( gensym (Symbol ( " df" , i, " /dx" , j)) ), keys (inputs))
94
94
return Expr (:tuple , syms... )
95
95
end
96
96
@@ -143,10 +143,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
143
143
# For consistency in code that follows we make all partials tuple expressions
144
144
partials = map (partials) do partial
145
145
if Meta. isexpr (partial, :tuple )
146
- partial
146
+ Expr ( :tuple , map (esc, partial. args) ... )
147
147
else
148
148
length (inputs) == 1 || error (" Invalid use of `@scalar_rule`" )
149
- Expr (:tuple , partial)
149
+ Expr (:tuple , esc ( partial) )
150
150
end
151
151
end
152
152
@@ -169,7 +169,7 @@ function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials)
169
169
function ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), :: Core.Typeof ($ f), $ (inputs... ))
170
170
$ (__source__)
171
171
$ (setup_stmts... )
172
- return $ (esc ( Expr (:tuple , partials... ) ))
172
+ return $ (Expr (:tuple , partials... ))
173
173
end
174
174
end
175
175
end
@@ -201,7 +201,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
201
201
$ (__source__)
202
202
$ (esc (:Ω )) = $ call
203
203
$ (setup_stmts... )
204
- $ (esc ( Expr (:tuple , partials... ) )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
204
+ $ (Expr (:tuple , partials... )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
205
205
return $ (esc (:Ω )), $ pushforward_returns
206
206
end
207
207
end
@@ -239,7 +239,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
239
239
$ (__source__)
240
240
$ (esc (:Ω )) = $ call
241
241
$ (setup_stmts... )
242
- $ (esc ( Expr (:tuple , partials... ) )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
242
+ $ (Expr (:tuple , partials... )) = ChainRulesCore. derivatives_given_output ($ (esc (:Ω )), $ f, $ (inputs... ))
243
243
return $ (esc (:Ω )), $ pullback
244
244
end
245
245
end
@@ -270,9 +270,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
270
270
# This is basically Δs ⋅ ∂s
271
271
_∂s = map (∂s) do ∂s_i
272
272
if _conj
273
- :(conj ($ ( esc ( ∂s_i)) ))
273
+ :(conj ($ ∂s_i))
274
274
else
275
- esc ( ∂s_i)
275
+ ∂s_i
276
276
end
277
277
end
278
278
0 commit comments