Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add map/broadcast/algebra/iteration/dict interface for Grads #902

Merged
merged 17 commits into from
Feb 20, 2021
Merged
37 changes: 36 additions & 1 deletion src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using InteractiveUtils
using InteractiveUtils: typesof
using Core: Typeof
import Base: copy!
import Base.Broadcast: broadcasted

mutable struct Context <: AContext
cache::Union{IdDict{Any,Any},Nothing}
Expand Down Expand Up @@ -139,7 +140,8 @@ end

# Compact display for `Grads`: the contents are deliberately elided.
function Base.show(io::IO, ps::Grads)
    print(io, "Grads(...)")
    return nothing
end

@forward Grads.grads Base.getindex, Base.haskey
@forward Grads.grads Base.haskey
@forward Grads.params Base.length

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
Expand Down Expand Up @@ -171,6 +173,39 @@ function copy!(x::AbstractVector, gs::Grads)
x
end

# # Gradient Algebra: unary
# for op in (:+, :-)
# @eval broadcasted(::typeof($op), gs::Grads) = map($op, gs)
# end

# # Gradient Algebra: binary
# for op in (:+, :-)
# @eval broadcasted(::typeof($op), gs1::Grads, gs2::Grads) = map($op, gs1, gs2)
# end

# Broadcasting over `Grads` is forwarded to `map`, which applies `f` parameter-wise.
# At least one leading `Grads` argument is required: a bare `gss::Grads...` signature
# also matches the zero-argument call `broadcasted(f)`, which would hijack a Base
# method that involves no `Grads` at all (type piracy).
broadcasted(f, gs1::Grads, gss::Grads...) = map(f, gs1, gss...)
broadcasted(f, gs1::Grads, xs...) = map(f, gs1, xs...)
# `a .* gs`: `map` expects the `Grads` as its first collection argument, so the
# operands are swapped (the scalar ends up on the right of `*`).
broadcasted(::typeof(*), a::Number, gs::Grads) = map(*, gs, a)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

"""
    map(f, gs1::Grads, gss::Grads...)

Apply `f` parameter-wise to the gradients stored in `gs1` and `gss`, returning a new
`Grads` that shares the parameters of `gs1`.

For every parameter `p` in `gs1.params`, the result holds
`f(gs1[p], gs2[p], ...)`.

# Throws
- `ArgumentError` if any of `gss` does not have the same set of parameters as `gs1`.
"""
function Base.map(f, gs1::Grads, gss::Grads...)
    # Explicit check rather than `@assert`: this validates caller input, and
    # assertions may be disabled at some optimization levels.
    all(gs -> issetequal(gs1.params, gs.params), gss) ||
        throw(ArgumentError("map expects Grads objects with the same Params."))
    ps = Params(gs1.params)
    grads = IdDict{Any,Any}()
    for p in ps
        grads[p] = f(gs1[p], (gs[p] for gs in gss)...)
    end
    return Grads(grads, ps)
end

"""
    map(f, gs1::Grads, xs...)

Apply `f` to each gradient stored in `gs1`, passing the extra arguments `xs`
unchanged to every call, and collect the results in a new `Grads` built over
`Params(gs1.params)`.

For every parameter `p`, the result holds `f(gs1[p], xs...)`.
"""
function Base.map(f, gs1::Grads, xs...)
    params = Params(gs1.params)
    out = IdDict{Any,Any}()
    foreach(params) do p
        out[p] = f(gs1[p], xs...)
    end
    return Grads(out, params)
end

function pullback(f, ps::Params)
cx = Context()
y, back = _pullback(cx, f)
Expand Down
37 changes: 36 additions & 1 deletion test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@testset "Parmas" begin
using Zygote: Grads

@testset "Params" begin
@testset "delete!" begin
w = rand(2,3)
b = rand(2)
Expand Down Expand Up @@ -55,4 +57,37 @@
@test ps1 == ps2
@test ps1 != ps3 # comparison is order dependent
end
end

@testset "Grads" begin
    @testset "algebra" begin
        w = rand(2)
        x1 = rand(2)
        x2 = rand(2)
        r = rand(2)

        # d/dw sum(w .* x) == x, so gs1[w] ≈ x1 and gs2[w] ≈ x2.
        gs1 = gradient(() -> sum(w .* x1), Params([w]))
        gs2 = gradient(() -> sum(w .* x2), Params([w]))

        # Every broadcasted operation must return a `Grads` ...
        @test .- gs1 isa Grads
        @test gs1 .- gs2 isa Grads
        @test .+ gs1 isa Grads
        @test gs1 .+ gs2 isa Grads
        @test 2 .* gs1 isa Grads
        @test gs1 .* 2 isa Grads
        @test gs1 ./ 2 isa Grads
        @test gs1 .+ r isa Grads

        # ... and carry the expected per-parameter values.
        @test (.- gs1)[w] ≈ -x1
        @test (gs1 .- gs2)[w] ≈ x1 .- x2
        @test (.+ gs1)[w] ≈ x1
        @test (gs1 .+ gs2)[w] ≈ x1 .+ x2
        @test (2 .* gs1)[w] ≈ 2 .* x1
        @test (gs1 .* 2)[w] ≈ 2 .* x1
        @test (gs1 ./ 2)[w] ≈ x1 ./ 2
        @test (gs1 .+ r)[w] ≈ x1 .+ r
    end

    @testset "map and broadcast" begin
        w = rand(2)
        x1 = rand(2)
        x2 = rand(2)

        gs1 = gradient(() -> sum(w .* x1), Params([w]))
        gs2 = gradient(() -> sum(w .* x2), Params([w]))

        @test map(x -> zeros(2), gs1) isa Grads
        @test (x -> zeros(2)).(gs1) isa Grads

        # Values produced by `map` and by broadcasting an arbitrary function.
        @test map(x -> zeros(2), gs1)[w] == zeros(2)
        @test (x -> zeros(2)).(gs1)[w] == zeros(2)
        @test map(+, gs1, gs2)[w] ≈ x1 .+ x2
    end
end