-
Notifications
You must be signed in to change notification settings - Fork 74
Closed
Description
Thanks for fixing #2429! That fix uncovered a few more errors that prevent Enzyme from working with DynamicPPL@0.37. It will take me a while to track them all down; I'll post each one as soon as I manage to minimise it, but here's the first.
Full error: https://gist.github.com/penelopeysm/a8b06038bd1aea542f9dfca619974115
(This errors only on Julia 1.11; Julia 1.10 works fine.)
import Enzyme: Enzyme, Forward, Reverse, set_runtime_activity, Const
# Accumulator wrapping a single value `lp`. `nm` maps it to the key :SymLP.
# NOTE(review): presumably a stand-in for DynamicPPL's log-prior accumulator — confirm.
struct LP{T}
lp::T
end
# Accumulator wrapping a single value `ll`. `nm` maps it to the key :SymLL.
# NOTE(review): presumably a stand-in for DynamicPPL's log-likelihood accumulator — confirm.
struct LL{T}
ll::T
end
# Key under which each accumulator kind is stored in an AT2's NamedTuple.
# The AT2 tuple constructor uses these symbols as the NamedTuple field names.
nm(::LP) = :SymLP
nm(::LL) = :SymLL
# Combine two accumulators of the same kind by adding their wrapped values.
# Used as the reduction operator in `gacc` (via `foldl`). No method is defined
# here for mixing kinds (e.g. LP with LL).
cmbn(acc::LP, acc2::LP) = LP(acc.lp + acc2.lp)
cmbn(acc::LL, acc2::LL) = LL(acc.ll + acc2.ll)
# Collection of accumulators stored in the NamedTuple `nt`, keyed by `nm` of each entry.
# The type parameter `N` is not stored in a field; the tuple constructor sets it
# to the number of accumulators.
struct AT2{N,T<:NamedTuple}
nt::T
end
# Build an AT2 from a tuple of accumulators. Each element is stored under the key
# `nm(element)`, and `N` is taken from the tuple's length.
# e.g. AT2((LP(0.0), LL(0.0))).nt == (SymLP = LP(0.0), SymLL = LL(0.0))
function AT2(t::T) where {N,T<:NTuple{N,Any}}
names = map(nm, t)
nt = NamedTuple{names}(t)
return AT2{N,typeof(nt)}(nt)
end
# Rebuild an AT2 from an existing NamedTuple. The incoming field names are discarded:
# the values are splatted into a plain tuple, and the tuple method recomputes the keys via `nm`.
AT2(nt::NamedTuple) = AT2(tuple(nt...))
# Minimal container holding only the accumulators in `accs` (an AT2 in this script).
# NOTE(review): presumably a stand-in for DynamicPPL's VarInfo — confirm.
struct V0{T}
accs::T
end
# Global V0 holding one LP and one LL accumulator, both starting at 0.0.
# The NamedTuple needs at least two entries for the Enzyme failure to reproduce.
vi = V0(AT2((; SymLP = LP(0.0), SymLL = LL(0.0))))
# Wrapper pairing a V0 (`varinfo`) with a vector of additional accumulator sets.
# `accs_vec` is a mutable Vector: `mdlf` overwrites its first element in place.
# NOTE(review): the name and its use only in the `Threads.nthreads() > 1` branch of
# `evlt` suggest a stand-in for DynamicPPL's thread-safe VarInfo — confirm.
struct TSVI4{V,L<:AT2}
varinfo::V
accs_vec::Vector{L}
end
# Wrap `vi` with exactly one extra accumulator set. That set is rebuilt from
# `vi.accs.nt`, so it starts with the same keys and values as vi's own accumulators.
function TSVI4(vi)
return TSVI4(vi, [AT2(vi.accs.nt)])
end
# Return the accumulator named `accname_inner`, combined with `cmbn` across all sets.
# The fold is a left fold that starts from the wrapped V0's own accumulator and then
# folds in the matching entry of every set in `vi.accs_vec`.
# The name arrives as a `Val`, so the NamedTuple lookup key is carried in the type.
function gacc(vi::TSVI4, ::Val{accname_inner}) where {accname_inner}
main_acc = vi.varinfo.accs.nt[accname_inner]
other_accs = map(accs -> accs.nt[accname_inner], vi.accs_vec)
return foldl(cmbn, other_accs; init=main_acc)
end
# Identity function, deliberately kept @noinline (original note: "must be noinline").
# NOTE(review): presumably inlining it makes the Enzyme error stop reproducing — confirm.
@noinline id(x) = x
# Accumulate `x`: LP accumulators get `x` added to their value;
# LL accumulators are returned unchanged.
acas(acc::LP, x) = LP(acc.lp + x)
acas(acc::LL, x) = acc
# Apply `acas(acc, x)` to every accumulator in the first extra set `tsvi.accs_vec[1]`.
# The rebuilt AT2 is stored back into `accs_vec[1]`, mutating the vector in place.
# Returns the same `tsvi` object.
function mdlf(x, tsvi)
new_nt = map(acc -> acas(acc, x), tsvi.accs_vec[1].nt)
tsvi.accs_vec[1] = AT2(new_nt)
return tsvi
end
# For every key of the wrapped V0's accumulators, combine that accumulator across all
# sets via `gacc`. The results are returned as a NamedTuple, with keys recomputed by `nm`
# through the AT2 constructor.
function gacs(tsvi)
# Wrap each key in `Val` so `gacc` receives it as a type parameter.
accname_vals = map(Val, keys(tsvi.varinfo.accs.nt))
return AT2(map(anv -> gacc(tsvi, anv), accname_vals)).nt
end
# Add `x` to the SymLP accumulator and return the accumulators as a
# NamedTuple with keys SymLP and SymLL.
# With more than one thread it goes through the TSVI4 path (TSVI4 -> id -> mdlf -> gacs).
# Otherwise it updates a copy of `vi.accs.nt` directly.
# NOTE(review): `Threads.nthreads()` is checked at runtime, so presumably both branches
# are compiled and seen by Enzyme even in a 1-thread session — confirm.
function evlt(x, vi)
return if Threads.nthreads() > 1
tsvi = TSVI4(vi)
tsvi = mdlf(x, id(tsvi))
gacs(tsvi)
else
nt = vi.accs.nt
new_val = LP(nt[:SymLP].lp + x)
# The returned NamedTuple also needs at least two entries for the failure to reproduce.
(; SymLP = new_val, SymLL = nt[:SymLL])
end
end
# Scalar objective passed to Enzyme. It evaluates `evlt` with the first element of
# `x` and returns the resulting SymLP value.
function f(x, vi)
nt = evlt(x[1], vi)
return nt.SymLP.lp
end
# Driver: first a plain evaluation of `f`, then a reverse-mode gradient with
# runtime activity enabled. `vi` is marked Const (not differentiated).
# NOTE(review): presumably the gradient call is the one that errors on Julia 1.11 — confirm.
params = [0.5]
f(params, vi)
Enzyme.gradient(set_runtime_activity(Reverse), f, params, Const(vi))
Setup:
julia> versioninfo()
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (arm64-apple-darwin24.0.0)
CPU: 10 × Apple M1 Pro
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
(ppl) pkg> st
Status `~/ppl/Project.toml`
[7da242da] Enzyme v0.13.68
Metadata
Metadata
Assignees
Labels
No labels