diff --git a/src/destructure.jl b/src/destructure.jl index baea015..bc83ff1 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -35,9 +35,22 @@ end """ destructure!(model) -> vector, reconstructor + destructure!(vector, model) -> vector, reconstructor -This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model. -Requires that all trainable parameters in the model be mutable arrays! +These are variants of [`destructure`](@ref), returning a reconstruction function +which mutates the original model, instead of making a new one. +The second method also mutates an existing flat vector. + +They require that all trainable parameters in the model be mutable arrays, +else `re!` will give an error. + +!!! warning "Gradients" + Despite using mutation, they should be safe to use within Zygote, + with the important caveat that you must use the model returned, `m2 = re!(v)`, not the original. + Even though `m2 === m`, for Zygote to trace what results are used where, it has to see + the returned object being used. + If you discard `m2` and call for example `Flux.mse(m(x), y)` with the original model `m`, + Zygote will give silently wrong results. # Example ```jldoctest @@ -51,6 +64,9 @@ true julia> m (x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos)) + +julia> v2, re2! = destructure!(rand(4), m) # works the same way +([3.0, 5.0, 7.0, 9.0], Restructure!(NamedTuple, ..., 4)) ``` """ function destructure!(x) @@ -58,10 +74,10 @@ function destructure!(x) flat, Restructure!(x, off, len) end -# function destructure!(flat::AbstractVector, x) -# flat, off, len = _flatten!(flat, x) -# flat, Restructure!(x, off, len) -# end +function destructure!(flat::AbstractVector, x) + flat, off, len = _flatten!(flat, x) + flat, Restructure!(x, off, len) +end """ Restructure(Model, ..., length) @@ -115,17 +131,17 @@ function _flatten(x) isempty(arrays) && return Bool[], off, 0 return reduce(vcat, arrays), off, len[] end -# function _flatten!(flat, x) -# isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case -# len = Ref(0) -# off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y -# o = len[] -# copyto!(flat, o, _vec(y)) -# len[] = o + length(y) -# o -# end -# flat, off, len[] -# end +function _flatten!(flat, x) + isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case + len = Ref(0) + off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y + o = len[] + copyto!(flat, o+1, _vec(y)) + len[] = o + length(y) + o + end + flat, off, len[] +end struct TrainableStructWalk <: AbstractWalk end