Skip to content

Commit e65c9ed

Browse files
committed
esc fixes
1 parent 9f683c8 commit e65c9ed

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/rule_definition_tools.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ macro scalar_rule(call, maybe_setup, partials...)
9090

9191
# Generate variables to store derivatives named
9292
derivatives = map(keys(partials)) do i
93-
syms = map(j -> gensym("df$(i)/dx$(j)"), keys(inputs))
93+
syms = map(j -> esc(gensym(Symbol("df", i, "/dx", j))), keys(inputs))
9494
return Expr(:tuple, syms...)
9595
end
9696

@@ -143,10 +143,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
143143
# For consistency in code that follows we make all partials tuple expressions
144144
partials = map(partials) do partial
145145
if Meta.isexpr(partial, :tuple)
146-
partial
146+
Expr(:tuple, map(esc, partial.args)...)
147147
else
148148
length(inputs) == 1 || error("Invalid use of `@scalar_rule`")
149-
Expr(:tuple, partial)
149+
Expr(:tuple, esc(partial))
150150
end
151151
end
152152

@@ -169,7 +169,7 @@ function scalar_derivative_expr(__source__, f, setup_stmts, inputs, partials)
169169
function ChainRulesCore.derivatives_given_output($(esc()), ::Core.Typeof($f), $(inputs...))
170170
$(__source__)
171171
$(setup_stmts...)
172-
return $(esc(Expr(:tuple, partials...)))
172+
return $(Expr(:tuple, partials...))
173173
end
174174
end
175175
end
@@ -201,7 +201,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
201201
$(__source__)
202202
$(esc()) = $call
203203
$(setup_stmts...)
204-
$(esc(Expr(:tuple, partials...))) = ChainRulesCore.derivatives_given_output($(esc()), $f, $(inputs...))
204+
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc()), $f, $(inputs...))
205205
return $(esc()), $pushforward_returns
206206
end
207207
end
@@ -239,7 +239,7 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
239239
$(__source__)
240240
$(esc()) = $call
241241
$(setup_stmts...)
242-
$(esc(Expr(:tuple, partials...))) = ChainRulesCore.derivatives_given_output($(esc()), $f, $(inputs...))
242+
$(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output($(esc()), $f, $(inputs...))
243243
return $(esc()), $pullback
244244
end
245245
end
@@ -270,9 +270,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
270270
# This is basically Δs ⋅ ∂s
271271
_∂s = map(∂s) do ∂s_i
272272
if _conj
273-
:(conj($(esc(∂s_i))))
273+
:(conj($∂s_i))
274274
else
275-
esc(∂s_i)
275+
∂s_i
276276
end
277277
end
278278

0 commit comments

Comments
 (0)