Gradients from keyword arguments dropped #567

Open
@mcabbott

The rules for accumulate and foldl don't compute anything for the init keyword. This can lead to silently wrong gradients, which is bad. Maybe this is a bigger problem than I realised.

It would be better to return @not_implemented for init. Arguably every keyword argument in every rule should get that, or something like it, if that is possible. Better still would be to return the true gradient, which IIRC these rules already know. Can this be done?

Example from here: https://discourse.julialang.org/t/how-to-efficiently-build-ad-compatible-matrices-line-by-line/74632/17

function f(K, xi, d)
    x = xi
    for i = 2:d
        x = hcat(x, K*x[:, i-1])
    end
    return x
end

K = rand(3,3)
xi = rand(3,1)
f(K, xi, 50)

function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

This gives the following. The Fill(1.0, 3, 1) in the second result is the contribution from hcat alone:

julia> using Zygote

julia> gradient(sum∘f, K, xi, 10)
([63.45016309970954 50.40609159573776 101.36588271461751; 23.572874731387856 18.379315265377535 35.224999619160954; 31.033286457367566 24.176359057416636 46.03455941092244], [48.455853839178204; 14.765466919614408; 18.845362109436827;;], nothing)

julia> gradient(sum∘f2, K, xi, 10) # NB the gradient for init=xi is missing!
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], Fill(1.0, 3, 1), nothing)
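
One way to see which of the two xi gradients is right, without any AD package, is a finite-difference check. The sketch below is not from the original post. fd_gradient is a hypothetical helper written for this illustration, and the closed form assumes that the columns of f2(K, xi, d) are K^j * xi for j = 0, ..., d-1, so the gradient of the sum with respect to xi is the sum of (K')^j * ones(3, 1). The j = 0 term is the all-ones block that hcat contributes by itself.

```julia
# Same function as f2 above, repeated so this snippet runs on its own.
function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

# Hypothetical helper (not part of any package):
# central finite-difference gradient of a scalar function g at x.
function fd_gradient(g, x; h=1e-6)
    grad = similar(x, Float64)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        grad[k] = (g(xp) - g(xm)) / (2h)
    end
    return grad
end

K = rand(3, 3)
xi = rand(3, 1)
d = 10

# Numerical gradient of sum∘f2 with respect to xi only.
g_fd = fd_gradient(x -> sum(f2(K, x, d)), xi)

# Closed form under the assumption stated above.
Kt = Matrix(K')
g_exact = sum(Kt^j * ones(3, 1) for j in 0:d-1)

println(isapprox(g_fd, g_exact; rtol=1e-4))      # finite differences vs closed form
println(isapprox(g_fd, ones(3, 1); rtol=1e-4))   # compare with the all-ones answer
```

If the assumption holds, the first comparison should agree and the second should not. That matches the first Zygote result above, and it shows that the all-ones value from the second result accounts for only the hcat term.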

Labels: bug