∇getindex mutates, causing issues with higher order AD over getindex. #820

Closed
jlmaccal opened this issue Nov 7, 2020 · 19 comments · Fixed by #1328
Labels
second order zygote over zygote, or otherwise

Comments

@jlmaccal

jlmaccal commented Nov 7, 2020

I'm new to Flux/Zygote/Julia. I'm trying to develop a model that looks something like the example below.

I have a network that produces two outputs, A and B. The gradient of A with respect to the inputs is part of my loss function, along with other terms that depend on B. I've just summed things here for simplicity, but my actual model produces the same error.

using Flux
using Zygote

net = Chain(
    Dense(2, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 4),
    x -> (A=x[1, :], B=x[2:end, :])
)
θ, builder = Flux.destructure(net)

x = randn(Float32, 2, 16)

function predict(θ, x, builder)
    net = builder(θ)
    results, pullback = Zygote.pullback(net, x)
    A = results.A
    B = results.B
    ∇A = pullback((A=ones(eltype(A), size(A)), B=nothing))[1]
    a = sum(∇A; dims=1)
    b = sum(B; dims=1)
    return a + b
end

Zygote.gradient(θ -> sum(abs2, predict(θ, x, builder)), θ)

The error is ERROR: LoadError: Mutating arrays is not supported, which comes from the x -> (A=x[1, :], B=x[2:end, :]) line in the network, but I don't understand where the mutation is coming from.

I gather from a number of issues here and threads on Discourse that higher-order gradients are not well supported, but there isn't much documentation around this. As a new user, it would be extremely helpful if there was some kind of documentation / guidance about how to work around this.

On a related Discourse thread @ChrisRackauckas suggested using another AD, like ReverseDiff, but I can't figure out how to get the gradient that I want. Any guidance would be appreciated.

ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#368#369")(::Nothing) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/lib/array.jl:61
 [3] (::Zygote.var"#2255#back#370"{Zygote.var"#368#369"})(::Nothing) at /Users/jlmaccal/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] materialize! at ./broadcast.jl:848 [inlined]
 [5] materialize! at ./broadcast.jl:845 [inlined]
 [6] materialize! at ./broadcast.jl:841 [inlined]
 [7] (::typeof(∂(materialize!)))(::Nothing) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [8] #356 at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/lib/array.jl:42 [inlined]
 [9] (::typeof(∂(λ)))(::Tuple{Array{Float32,2},Nothing,Nothing}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [10] #2209#back at /Users/jlmaccal/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [11] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float32,2},Nothing,Nothing}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [12] #11 at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:9 [inlined]
 [13] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [14] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [15] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [16] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [17] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [18] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [19] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [20] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [21] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [22] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined]
 [23] (::typeof(∂(λ)))(::Tuple{Nothing,Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [24] Chain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:38 [inlined]
 [25] (::typeof(∂(λ)))(::Tuple{Nothing,Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [26] #41 at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45 [inlined]
 [27] (::typeof(∂(λ)))(::Tuple{Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [28] predict at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:20 [inlined]
 [29] (::typeof(∂(predict)))(::Array{Float32,2}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [30] #13 at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:26 [inlined]
 [31] (::typeof(∂(#13)))(::Float32) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [32] (::Zygote.var"#41#42"{typeof(∂(#13))})(::Float32) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45
 [33] gradient(::Function, ::Array{Float32,1}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [34] top-level scope at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:26
 [35] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
 [36] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at ./essentials.jl:710
 [37] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:709
 [38] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:83
 [39] (::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:45
 [40] withpath(::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool}, ::String) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:118
 [41] (::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:43
 [42] hideprompt(::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:36
 [43] repl_runcode_request(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.ReplRunCodeRequestParams) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:23
 [44] dispatch_msg(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.JSONRPC.MsgDispatcher, ::Dict{String,Any}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/JSONRPC/src/typed.jl:66
 [45] macro expansion at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/VSCodeServer.jl:95 [inlined]
 [46] (::VSCodeServer.var"#61#63"{Bool,String})() at ./task.jl:356
in expression starting at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch.jl:26
@ChrisRackauckas
Member

Your example can be reduced even more:

using Flux
using Zygote

net = Chain(
    Dense(2, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 4),
    x -> x[1, :]
)
θ, builder = Flux.destructure(net)

x = randn(Float32, 2, 16)

function predict(θ, x, builder)
    net = builder(θ)
    A, pullback = Zygote.pullback(net, x)
    ∇A = pullback(ones(eltype(A), size(A)))[1]
    a = sum(∇A)
end

Zygote.gradient(θ -> predict(θ, x, builder), θ)

The issue is just that you can't nest Zygote. If you make the outer differentiation use ReverseDiff.jl and the inner one use Zygote.jl, you're fine though.
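For concreteness, a minimal sketch of that combination on the reduced example above (ReverseDiff for the outer gradient, Zygote for the inner pullback); whether it runs cleanly depends on the operations involved, as later comments in this thread show:

using ReverseDiff

# Outer differentiation via ReverseDiff's tape-based AD; the inner
# Zygote.pullback inside predict is left unchanged.
grad = ReverseDiff.gradient(θ -> predict(θ, x, builder), θ)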

@jlmaccal
Author

jlmaccal commented Nov 8, 2020

@ChrisRackauckas Thanks for the reply. I can see now that many of the pullbacks defined for common operations, like getindex, mutate in place, which prevents Zygote from being nested.

Do you know if there is any plan to address this? It seems to me that loss functions that use the gradients of a scalar field could be quite common when modelling physical systems with conserved quantities, as with Hamiltonian NNs.

I have tried combining various ADs as you suggest, but I keep running into problems. I give some examples of things I've tried below. My apologies for their length, but I feel it might help to show what I'm actually trying to accomplish.

For context, I am trying to model the power dissipated by a non-equilibrium thermodynamic system undergoing some control protocol. This is modelled with two components. The first is a conservative term that depends on the directional derivative of a scalar field (the free energy). The second is a dissipative term that depends on a positive definite friction tensor. In the spirit of SciML, I'm trying to model both the free energy and the friction tensor using NNs.

@jlmaccal
Author

jlmaccal commented Nov 8, 2020

Here is one approach I tried using ForwardDiff to calculate the inner directional derivative and Zygote for the outer gradient.

If I don't include DiffEqFlux, I get an error about no method matching *(::NamedTuple..., which is addressed by one of the adjoint definitions in DiffEqFlux.

However, this does not give me the correct gradient. Instead, all of the gradients for the ξnet parameters are zero. If I change return Pcons + Pdiss to return Pdiss, then I get the correct (or at least non-zero) gradient for the ξnet parameters, but the Fnet gradients are zero.

using Flux
using Zygote
using DiffEqFlux
using ForwardDiff
using NNlib
using LinearAlgebra
using Statistics

export predictpower, create_Fnetwork, create_ξnetwork, combine_networks

struct Builder{R1,R2}
    re1::R1
    re2::R2
    n::Int64
end

(builder::Builder)(p) = begin
    p1 = p[1:builder.n]
    p2 = p[(builder.n + 1):end]
    return (builder.re1(p1), builder.re2(p2))
end
    
function combine_models(m1, m2)
    p1, re1 = Flux.destructure(m1)
    p2, re2 = Flux.destructure(m2)
    n = size(p1)[1]
    p = [p1; p2]
    builder = Builder(re1, re2, n)
    return (p, builder)
end

struct DirectionalDerivative{F, V}
    f::F
    direction::V
end
const DD = DirectionalDerivative

function (dd::DD)(pt)
    let dd=dd
        ForwardDiff.derivative(0) do h
            dd.f(pt + h * dd.direction)
        end
    end
end

function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2

    # Unpack the inputs
    Fnet, ξnet = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]

    # Compute the directional derivative dλ⋅∇F
    Pcons = DD(Fnet, dλ)(λ)

    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])

    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ
    ξ = ξnet(λ)
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)

    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end

function create_Fnetwork(controldim, hiddendim, hiddenlayers)
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, 1)
    return Chain(initial, layers..., final)
end

function create_ξnetwork(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, componentdim)
    posdef = VecToPosDef(componentdim, controldim)
    return Chain(initial, layers..., final, posdef)
end

function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        return (F, ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end

"""
VecToPosDef(indim, n)

Convert a vector to a positive definite matrix.

Take `indims` dimensional batch of vectors and convert to
a batch of `(n, n)`` positive definite matrices. The dimensions
must much sch that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as elements of lower triangular
matrix. The diagonal elements are exponentated to unsure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64

    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim

    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)

    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end

    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end

    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end



# Test code

Fnet = create_Fnetwork(2, 128, 2)
ξnet = create_ξnetwork(2, 128, 2)
θ, builder = combine_models(Fnet, ξnet)

x = randn(Float32, 4, 128)

function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end

grad = Zygote.gradient(p -> loss(x, p, builder), θ)[1]
grad = getindex.(ForwardDiff.partials.(grad),1)

@jlmaccal
Author

jlmaccal commented Nov 8, 2020

This version tries to use a single network with Zygote for the inner gradient and ReverseDiff for the outer.

It fails with (full traceback below): ERROR: LoadError: ArgumentError: indexed assignment with a single value to many locations is not supported; perhaps use broadcasting .= instead?.

using Flux
using Zygote
using ReverseDiff
using NNlib
using LinearAlgebra
using Statistics  # for mean, used in the loss below

export predictpower, create_Fnetwork, create_ξnetwork, combine_networks

function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2

    # Unpack the inputs
    net = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]

    # Forward pass
    results, pullback = Zygote.pullback(net, λ)
    F = results.F
    ξ = results.ξ
    ∇F = pullback((F = ones(eltype(F), size(F)), ξ = nothing))[1]
    ∇F = reshape(∇F, 1, :)

    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])

    # Compute the conservative part of the power
    Pcons = batched_mul(dλ, ∇F)

    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ  (ξ was already extracted from the forward pass above)
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)

    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end

function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        return (F, ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end

"""
VecToPosDef(indim, n)

Convert a vector to a positive definite matrix.

Take `indims` dimensional batch of vectors and convert to
a batch of `(n, n)`` positive definite matrices. The dimensions
must much sch that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as elements of lower triangular
matrix. The diagonal elements are exponentated to unsure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64

    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim

    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)

    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end

    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end

    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end


# Test it
net = create_network(2, 128, 2)
θ, builder = Flux.destructure(net)

x = randn(Float32, 4, 128)

function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end

grad = ReverseDiff.gradient(θ -> loss(x, θ, builder), θ)

Here is the traceback:

ERROR: LoadError: ArgumentError: indexed assignment with a single value to many locations is not supported; perhaps use broadcasting `.=` instead?
Stacktrace:
 [1] setindex_shape_check(::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Int64) at ./indices.jl:258
 [2] macro expansion at ./multidimensional.jl:795 [inlined]
 [3] _unsafe_setindex!(::IndexLinear, ::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},3}, ::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Base.Slice{Base.OneTo{Int64}}) at ./multidimensional.jl:789
 [4] _setindex! at ./multidimensional.jl:785 [inlined]
 [5] setindex! at ./abstractarray.jl:1153 [inlined]
 [6] setindex!(::Zygote.Buffer{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}},3}}, ::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle,Nothing,typeof(exp),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}, ::Int64, ::Int64, ::Function) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/tools/buffer.jl:51
 [7] adjoint at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/lib/buffer.jl:15 [inlined]
 [8] _pullback at /Users/jlmaccal/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [9] VecToPosDef at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:100 [inlined]
 [10] _pullback(::Zygote.Context, ::VecToPosDef, ::ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [11] output at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:51 [inlined]
 [12] applychain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:36 [inlined] (repeats 5 times)
 [13] Chain at /Users/jlmaccal/.julia/packages/Flux/q3zeA/src/layers/basic.jl:38 [inlined]
 [14] _pullback at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:38 [inlined]
 [15] pullback(::Chain{Tuple{Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(relu),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},Dense{typeof(identity),ReverseDiff.TrackedArray{Float32,Float32,2,Array{Float32,2},Array{Float32,2}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},var"#output#13"{VecToPosDef}}}, ::Array{Float32,2}) at /Users/jlmaccal/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:44
 [16] predictpower(::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Flux.var"#34#36"{Chain{Tuple{Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},var"#output#13"{VecToPosDef}}}}) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:20
 [17] loss(::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Function) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:124
 [18] (::var"#14#15")(::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
 [19] ReverseDiff.GradientTape(::var"#14#15", ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /Users/jlmaccal/.julia/packages/ReverseDiff/jFRo1/src/api/tape.jl:199
 [20] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /Users/jlmaccal/.julia/packages/ReverseDiff/jFRo1/src/api/gradients.jl:22 (repeats 2 times)
 [21] top-level scope at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128
 [22] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
 [23] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at ./essentials.jl:710
 [24] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:709
 [25] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:83
 [26] (::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:45
 [27] withpath(::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool}, ::String) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:118
 [28] (::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool})() at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:43
 [29] hideprompt(::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/repl.jl:36
 [30] repl_runcode_request(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.ReplRunCodeRequestParams) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/eval.jl:23
 [31] dispatch_msg(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.JSONRPC.MsgDispatcher, ::Dict{String,Any}) at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/JSONRPC/src/typed.jl:66
 [32] macro expansion at /Users/jlmaccal/.vscode/extensions/julialang.language-julia-1.0.8/scripts/packages/VSCodeServer/src/VSCodeServer.jl:95 [inlined]
 [33] (::VSCodeServer.var"#61#63"{Bool,String})() at ./task.jl:356
in expression starting at /Users/jlmaccal/Source/Julia/NonEqOpt/scratch2.jl:128

@DhairyaLGandhi
Member

DhairyaLGandhi commented Nov 11, 2020

@ChrisRackauckas the MWE can actually be reduced to only include the x -> x[1, :], since that is what causes the mutation error:

julia> gradient(rand(3,3)) do p
         gradient(p) do p
           sum(p[1, :])
         end[1] |> sum
       end

@axsk

axsk commented Mar 25, 2021

I'm stuck with the same problem; how did you solve yours?

https://discourse.julialang.org/t/flux-higher-order-derivatives-and-forward-mode/38805/3
seems to have a similar problem, but the workaround there (extracting the NN parameters and optimizing externally) doesn't work in my case, since I want to stay within the Flux training framework.

I tried integrating that approach to write a Zygote.@adjoint but could not work out how to mangle the closures.

@axsk

axsk commented Mar 25, 2021

So what is the actual problem preventing Zygote from computing higher order derivatives?

@DhairyaLGandhi
Member

So the "issue" is that zygote uses mutation on the adjoint of the getindex. Hmm, let me think about if we can handle it better

@axsk

axsk commented Mar 25, 2021

I actually thought the problem lay elsewhere, but using sum instead of getindex seems to work here.

x = rand(2)
m = Chain(Dense(2,1))

Flux.gradient(params(m)) do
    gradient(m, x) |> sum |> sum
end

Edit:
loss(f, x) = sum(abs2, Flux.gradient(x -> f(x) |> sum, x) |> sum) indeed works for me as desired (although I should probably comment on the use of the sums in the source :D).
So relieved, I already thought I had to switch to JAX
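
Spelled out, this working pattern looks something like the sketch below (a toy model; the inner sum makes the network output scalar so the inner gradient is defined, and the outer sum of abs2 turns that gradient into a scalar loss):

using Flux, Zygote

m = Chain(Dense(2, 1))
x = rand(Float32, 2)

# inner gradient w.r.t. the input; sum(f(w)) makes the output scalar
loss(f, z) = sum(abs2, gradient(w -> sum(f(w)), z)[1])

# outer gradient w.r.t. the model parameters
gs = gradient(() -> loss(m, x), Flux.params(m))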

@DhairyaLGandhi
Member

Yeah, I think that should be fine, but it's less generally correct to do, I think.

@axsk

axsk commented Mar 25, 2021

A cleaner way is to extract the gradient by tuple destructuring (is it called that?):
dx, = gradient(m, x)

@ChrisRackauckas
Member

The issue is that:

∇getindex(x::AbstractArray, inds) = dy -> begin
  if inds isa NTuple{<:Any, Integer}
    dx = _zero(x, typeof(dy))
    dx[inds...] = dy
  else
    dx = _zero(x, eltype(dy))
    dxv = view(dx, inds...)
    dxv .= accum.(dxv, _droplike(dy, dxv))
  end
  return (dx, map(_->nothing, inds)...)
end

These mutate. I would suggest splitting that into two separate dispatches and trying to come up with schemes that are just broadcasts or filters. If that's not easy to do, then I think a dispatch on just arrays (to avoid CuArrays) that just loops would be nice and fix the problem for most non-GPU users.
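
As an illustration of the broadcast/filter idea for the integer-index branch, a hypothetical non-mutating variant (the name ∇getindex_nomut and its exact shape are mine, not Zygote code):

# Build the "zero everywhere except at inds" gradient with a
# comprehension over all indices instead of setindex! mutation.
function ∇getindex_nomut(x::AbstractArray, inds::Integer...)
    I = CartesianIndex(inds...)
    return dy -> [i == I ? dy : zero(dy) for i in CartesianIndices(x)]
end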

@Janssena

Janssena commented Apr 5, 2021

I believe I'm facing a similar issue, where I need to use the Jacobian of my prediction function with respect to an array of random variables in the loss function.

Here's my code:

# produces θ for pred function
ann = Flux.Chain(
    Flux.Dense(input, 32, Flux.tanh),
    Flux.Dense(32, 32),
    Flux.Dense(32, 3),
);

function dAdt(dA, A, p, t)
    a, b, c = p
    dA[1] = -a * b * A[1] 
    dA[2] = c * b * A[1] - a * A[2]
end

function pred(θ, η, t, callback)
    p = θ .* exp.(η)
    prob = diffeq.ODEProblem(dAdt, [0., 0.], (-.1, maximum(t)), p)
    sol = diffeq.solve(prob, diffeq.Tsit5(), saveat=t, tstops=[0.], callback=callback, sensealg=des.ForwardDiffSensitivity())
    return sol[2 , :] # A[2] corresponds to y measurements
end 

∂pred_∂η(θ, η, time, callback) = Zygote.jacobian(eta -> pred(θ, eta, time, callback), η)

# p == 3x3 correlation matrix
function Obj(x, y, p, times, callbacks)
    if !isposdef(p)
        return Inf
    end

    N = length(times) # equal to the number of observations in dataset
    θ = ann(x')
    η = zeros(size(p, 1)) # test

    loss = 0.

    for i in 1:N
        ŷ = pred(θ[:, i], η, times[i], callbacks[i])
        residuals = y[i] - ŷ
        jac_eta = ∂pred_∂η(θ[:, i], η, times[i], callbacks[i]) # line 1
        loss = mean(residuals) + mean(jac_eta * p * jac_eta') # line 2
    end
    
    return loss
end

grad = Zygote.gradient(() -> Obj(x, y, p, times, callbacks), Flux.params(ann)) # error mutating arrays

Removing line 1 and changing line 2 to loss = mean(residuals) runs fine, but calculating the Jacobian results in the mutating-arrays error in Zygote.
Is anyone working on implementing the above comment by Chris, or is there some way I can help with this? I'm not that experienced with the Zygote codebase, but I'm trying to solve the above issue.

@ChrisRackauckas
Member

#77 is a solution that could be used.

@axsk

axsk commented Apr 9, 2021

So out of #77 we would just need the @adjoint ∇getindex part to circumvent the setindex! call, is that correct?

@ChrisRackauckas
Member

Yes

@axsk

axsk commented Apr 9, 2021

This is my take on Keno's approach:

∇getindex(x::AbstractArray, inds) = dy -> (_zerosetindex(x, inds, dy), map(_->nothing, inds)...)

function _zerosetindex(x, inds::NTuple{<:Any, Integer}, dy)
  dx = _zero(x, typeof(dy))
  dx[inds...] = dy
  dx
end

function _zerosetindex(x, inds, dy)
  dx = _zero(x, eltype(dy))
  dxv = view(dx, inds...)
  dxv .= accum.(dxv, _droplike(dy, dxv))
  dx
end

@adjoint function _zerosetindex(x, inds, dy)
  _zerosetindex(x, inds, dy), ddx -> (nothing, nothing, ddx[inds...])
end

Keno's tests seem to run through as well. Should I put up a PR?
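
If that adjoint lands, the reduced MWE from earlier in the thread makes a natural regression test; something like this should then run instead of throwing the mutation error:

using Zygote

# second-order gradient through getindex; fails with
# "Mutating arrays is not supported" before the fix
gradient(rand(3, 3)) do p
    sum(gradient(q -> sum(q[1, :]), p)[1])
end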

@ChrisRackauckas
Member

I think that would be great!

@DhairyaLGandhi
Member

We'll want to test this with GPUs and check for performance.
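
For context, such a check might look like the sketch below (hypothetical; assumes a CUDA-capable machine and that the proposed adjoint is loaded):

using CUDA, Zygote
using BenchmarkTools

xs = CUDA.rand(32, 32)

# nested gradient through getindex, on the GPU
f(p) = sum(gradient(q -> sum(q[1, :]), p)[1])

@btime gradient($f, $xs)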

@ChrisRackauckas ChrisRackauckas changed the title Issue with mutation when computing gradients ∇getindex mutates, causing issues with higher order AD over getindex. Jun 6, 2021
@mcabbott mcabbott added the second order zygote over zygote, or otherwise label Jul 22, 2022