Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add map/broadcast/algebra/iteration/dict interface for Grads #902

Merged
merged 17 commits into from
Feb 20, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,26 @@ Zygote.checkpointed
```

`Params` and `Grads` can be copied to and from arrays using the `copy!` function.

### Operations with Grads

Map and broadcast are supported for the dictionary-like `Grads` object.
```julia
using Zygote, Test

w = rand(2)
x1 = rand(2)
x2 = rand(2)
b = rand(2)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

gs1 = gradient(() -> sum(w .* x1 .+ b), Params([w, b]))
gs2 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

# accumulate gradients
gs = gs1 .+ gs2
@test gs[w] ≈ gs1[w] + gs2[w]
@test gs[b] ≈ gs1[b] + gs2[b]

# clip gradients in-place
map!(x -> clamp!(x, -0.1, 0.1), gs, gs)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
```
45 changes: 44 additions & 1 deletion src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using InteractiveUtils
using InteractiveUtils: typesof
using Core: Typeof
import Base: copy!
import Base.Broadcast: broadcasted

mutable struct Context <: AContext
cache::Union{IdDict{Any,Any},Nothing}
Expand Down Expand Up @@ -139,7 +140,8 @@ end

Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")

@forward Grads.grads Base.getindex, Base.haskey
@forward Grads.grads Base.haskey, Base.setindex!
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
@forward Grads.params Base.length

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
Expand Down Expand Up @@ -171,6 +173,47 @@ function copy!(x::AbstractVector, gs::Grads)
x
end

# Broadcasting a function over one or more `Grads` is forwarded to `map`.
# At least one `Grads` argument is required in the signature: a bare
# `gss::Grads...` would also match the zero-argument call `broadcasted(f)` for
# *any* `f`, overriding a method that belongs to `Base.Broadcast`.
broadcasted(f, gs::Grads, gss::Grads...) = map(f, gs, gss...)

# Scalar-with-`Grads` broadcasting for `*` and `/`, in both argument orders.
# The operator is applied with `broadcast` to every gradient so that the
# result is always elementwise. In particular `a ./ gs` must not fall back to
# `a / g` with an array `g`, which is not elementwise division.
for op in (:*, :/)
  @eval broadcasted(::typeof($op), a::Number, gs::Grads) =
    _mapscalar((x, g) -> broadcast($op, x, g), a, gs)
  @eval broadcasted(::typeof($op), gs::Grads, a::Number) =
    _mapscalar((g, x) -> broadcast($op, g, x), gs, a)
end

# Out-of-place `map` over one or more `Grads`: allocate an empty `Grads` that
# tracks the same parameters as `gs1` and let `map!` fill it in.
function Base.map(f, gs1::Grads, gss::Grads...)
  return map!(f, Grads(IdDict{Any,Any}(), Params(gs1.params)), gs1, gss...)
end

"""
    map!(f, gsout::Grads, gss::Grads...)

Apply `f` parameter by parameter to the gradients stored in `gss` and write the
results into `gsout`, which is returned. A gradient that is `nothing` is passed
to `f` as an array of zeros shaped like its parameter.

Throws an `ArgumentError` if any element of `gss` does not track the same set
of parameters as `gsout`.
"""
function Base.map!(f, gsout::Grads, gss::Grads...)
  # Validate with an explicit exception: `@assert` may be compiled out, which
  # would let mismatched parameter sets through silently.
  all(issetequal(gsout.params, gs.params) for gs in gss) ||
    throw(ArgumentError("map! expects Grads objects with the same Params."))
  for p in gsout.params
    gsout[p] = f((_getformap(gs, p) for gs in gss)...)
  end
  return gsout
end

# Build a new `Grads` over the parameters of `gs` whose entry for each
# parameter is `f(gradient, xs...)`, i.e. the gradient is the first argument.
function _mapscalar(f, gs::Grads, xs...)
  result = Grads(IdDict{Any,Any}(), Params(gs.params))
  foreach(result.params) do param
    result[param] = f(_getformap(gs, param), xs...)
  end
  return result
end

# Build a new `Grads` over the parameters of `gs` whose entry for each
# parameter is `f(x, gradient, xs...)`, i.e. the gradient is the second argument.
function _mapscalar(f, x, gs::Grads, xs...)
  result = Grads(IdDict{Any,Any}(), Params(gs.params))
  foreach(result.params) do param
    result[param] = f(x, _getformap(gs, param), xs...)
  end
  return result
end

# Look up the gradient stored for `p`. When it is `nothing`, return a
# zero-filled array created with `similar(p)` instead, so callers always
# receive an array-like value.
function _getformap(gs::Grads, p)
  grad = gs[p]
  if grad === nothing
    return fill!(similar(p), 0)
  end
  return grad
end

function pullback(f, ps::Params)
cx = Context()
y, back = _pullback(cx, f)
Expand Down
56 changes: 54 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@testset "Parmas" begin
using Zygote: Grads

@testset "Params" begin
@testset "delete!" begin
w = rand(2,3)
b = rand(2)
Expand Down Expand Up @@ -55,4 +57,54 @@
@test ps1 == ps2
@test ps1 != ps3 # comparison is order dependent
end
end
end

@testset "Grads" begin
@testset "algebra" begin
w = rand(2)
x1 = rand(2)
x2 = rand(2)
b = rand(2)

gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test .- gs1 isa Grads
@test gs1 .- gs2 isa Grads
@test .+ gs1 isa Grads
@test gs1 .+ gs2 isa Grads
@test 2 .* gs1 isa Grads
@test gs1 .* 2 isa Grads
@test gs1 ./ 2 isa Grads
@test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]

gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

@test .- gs3 isa Grads
@test gs3 .- gs4 isa Grads
@test .+ gs3 isa Grads
@test gs3 .+ gs4 isa Grads
@test 2 .* gs3 isa Grads
@test gs3 .* 2 isa Grads
@test gs3 ./ 2 isa Grads
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
end
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

# Checks that `map`, `map!` and dot-call broadcasting over a `Grads` each
# return a `Grads`, and that `map!` with a mutating function updates in place.
@testset "map and broadcast" begin
w = rand(2)
x1 = rand(2)
x2 = rand(2)

gs1 = gradient(() -> sum(w .* x1), Params([w]))
# NOTE(review): `gs2` (and `x2`) are not used by any assertion below.
gs2 = gradient(() -> sum(w .* x2), Params([w]))

# out-of-place map over a single Grads yields a Grads
@test map(x -> zeros(2), gs1) isa Grads

# in-place map: gs1 is both destination and source, and clamp! mutates its arrays
@test map!(x -> clamp!(x, -1e-5, 1e-5), gs1, gs1) isa Grads
@test all(abs.(gs1[w]) .<= 1e-5)

# dot-call syntax applied to a Grads should also yield a Grads
@test (x -> zeros(2)).(gs1) isa Grads
end
end