Skip to content

Commit

Permalink
restore 2-arg version, and add scary warning
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 11, 2024
1 parent 058a25b commit de381dd
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,22 @@ end

"""
destructure!(model) -> vector, reconstructor
destructure!(vector, model) -> vector, reconstructor
This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model.
Requires that all trainable parameters in the model be mutable arrays!
These are variants of [`destructure`](@ref), returning a reconstruction function
which mutates the original model, instead of making a new one.
The second method also mutates an existing flat vector.
They require that all trainable parameters in the model be mutable arrays,
else `re!` will give an error.
!!! warning "Gradients"
Despite using mutation, they should be safe to use within Zygote,
with the important caveat that you must use the model returned, `m2 = re!(v)`, not the original.
Even though `m2 === m`, for Zygote to trace what results are used where, it has to see
the returned object being used.
If you discard `m2` and call for example `Flux.mse(m(x), y)` with the original model `m`,
Zygote will give silently wrong results.
# Example
```jldoctest
Expand All @@ -51,17 +64,20 @@ true
julia> m
(x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos))
julia> v2, re2! = destructure!(rand(4), m) # works the same way
([3.0, 5.0, 7.0, 9.0], Restructure!(NamedTuple, ..., 4))
```
"""
function destructure!(x)
flat, off, len = _flatten(x)
flat, Restructure!(x, off, len)
end

# function destructure!(flat::AbstractVector, x)
# flat, off, len = _flatten!(flat, x)
# flat, Restructure!(x, off, len)
# end
function destructure!(flat::AbstractVector, x)
flat, off, len = _flatten!(flat, x)
flat, Restructure!(x, off, len)
end

"""
Restructure(Model, ..., length)
Expand Down Expand Up @@ -115,17 +131,17 @@ function _flatten(x)
isempty(arrays) && return Bool[], off, 0
return reduce(vcat, arrays), off, len[]
end
# function _flatten!(flat, x)
# isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case
# len = Ref(0)
# off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
# o = len[]
# copyto!(flat, o, _vec(y))
# len[] = o + length(y)
# o
# end
# flat, off, len[]
# end
function _flatten!(flat, x)
isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
o = len[]
copyto!(flat, o+1, _vec(y))
len[] = o + length(y)
o
end
flat, off, len[]
end

struct TrainableStructWalk <: AbstractWalk end

Expand Down

0 comments on commit de381dd

Please sign in to comment.