Skip to content

Commit 7540284

Browse files
authored
Merge pull request #119 from SymbolicML/compathelper/new_version/2025-01-05-01-23-24-852-00256837307
CompatHelper: bump compat for Zygote in [weakdeps] to 0.7, (keep existing compat)
2 parents 605e111 + 4b22d2c commit 7540284

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Optim = "0.19, 1"
3838
PrecompileTools = "1"
3939
Reexport = "1"
4040
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
41-
Zygote = "0.6"
41+
Zygote = "0.6, 0.7"
4242
julia = "1.10"
4343

4444
[extras]

src/ChainRules.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using ChainRulesCore:
77
ZeroTangent,
88
Tangent,
99
@thunk,
10+
unthunk,
1011
canonicalize
1112
using ..OperatorEnumModule: OperatorEnum
1213
using ..NodeModule: AbstractExpressionNode, with_type_parameters, tree_mapreduce
@@ -52,7 +53,8 @@ struct EvalPullback{N,A,O} <: Function
5253
end
5354

5455
# TODO: Preferable to use the primal in the pullback somehow
55-
function (e::EvalPullback)((dY, _))
56+
function (e::EvalPullback)((thunked_dY, _))
57+
dY = unthunk(thunked_dY)
5658
_, dX_constants_dY, complete = eval_grad_tree_array(
5759
e.tree, e.X, e.operators; variable=Val(:both)
5860
)

0 commit comments

Comments
 (0)