port rule definitions to ChainRulesCore #242

Merged 16 commits on Dec 27, 2020
Changes from 2 commits
6 changes: 5 additions & 1 deletion Project.toml
@@ -3,6 +3,8 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.7"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -11,13 +13,15 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9"
Compat = "3.14"
Requires = "0.5, 1.0"
julia = "1.3"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Zygote"]
test = ["Random", "Test", "Zygote"]
3 changes: 3 additions & 0 deletions src/NNlib.jl
@@ -39,4 +39,7 @@ include("impl/depthwiseconv_im2col.jl")
# Direct implementations of pooling
include("impl/pooling_direct.jl")

# differentiation rules
include("chainrulescore.jl")

end # module NNlib
129 changes: 129 additions & 0 deletions src/chainrulescore.jl
@@ -0,0 +1,129 @@
using ChainRulesCore
using ArrayLayouts: MemoryLayout, AbstractColumnMajor

const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

drelu(x, Δ) = ifelse(x > 0, Δ, zero(x))

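# The rules below are written against the fused `broadcasted(f, x)` call, so AD
# backends can treat a whole activation broadcast as a single primitive instead
# of differentiating it element by element.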
function ChainRulesCore.rrule(
    ::typeof(Base.Broadcast.broadcasted),
    ::typeof(relu),
    x::Numeric,
)
    broadcasted_relu_pullback(Δ) = (NO_FIELDS, NO_FIELDS, @thunk(drelu.(x, Δ)))
    return relu.(x), broadcasted_relu_pullback
end

function dselu(x)
    λ = oftype(x/1, 1.0507009873554804934193349852946)
    α = oftype(x/1, 1.6732632423543772848170429916717)
    return λ * ifelse(x > 0, one(x), α * exp(x))
end

@scalar_rule(selu(x), dselu(x))
function ChainRulesCore.rrule(
    ::typeof(Base.Broadcast.broadcasted),
    ::typeof(selu),
    x::Numeric,
)
    broadcasted_selu_pullback(Δ) = (NO_FIELDS, NO_FIELDS, @thunk(dselu.(x) .* Δ))
    return selu.(x), broadcasted_selu_pullback
end

delu(x, α) = ifelse(x ≥ 0, one(x), α * exp(x))

@scalar_rule(elu(x, α), (delu(x, α), DoesNotExist()))
function ChainRulesCore.rrule(
    ::typeof(Base.Broadcast.broadcasted),
    ::typeof(elu),
    x::Numeric,
    α::Numeric,
)
    broadcasted_elu_pullback(Δ) = (NO_FIELDS, NO_FIELDS, @thunk(delu.(x, α) .* Δ), DoesNotExist())
    return elu.(x, α), broadcasted_elu_pullback
end

@scalar_rule(σ(x::Real), Ω * (1 - Ω))
function ChainRulesCore.rrule(
    ::typeof(Base.Broadcast.broadcasted),
    ::typeof(σ),
    x::Numeric,
)
    Ω = σ.(x)
    broadcasted_σ_pullback(Δ) = (NO_FIELDS, NO_FIELDS, @thunk(Δ .* conj.(Ω .* (1 .- Ω))))
    return Ω, broadcasted_σ_pullback
end

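# softmax and logsoftmax share the same rrule shape, so generate both with @eval.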
for softmax in [:softmax, :logsoftmax]
    local ∇softmax = Symbol(:∇, softmax)
    pullback = Symbol(softmax, :_pullback)

    @eval function ChainRulesCore.rrule(::typeof($softmax), xs; dims=1)
        $pullback(Δ) = (NO_FIELDS, @thunk($∇softmax(Δ, xs, dims=dims)))
        return $softmax(xs; dims=dims), $pullback
    end
end

for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
    pullback = Symbol(Dims, :_pullback)
    @eval function ChainRulesCore.rrule(::Type{$Dims}, args...; kwargs...)
        $pullback(Δ) = (NO_FIELDS, ntuple(_ -> DoesNotExist(), length(args))...)
        return $Dims(args...; kwargs...), $pullback
    end
end

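# The im2col convolution kernels assume dense column-major storage, so incoming
# cotangents are normalised with `colmajor` before the backward kernels run.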
colmajor(x) = colmajor(MemoryLayout(typeof(x)), x)
colmajor(_, x) = convert(Array, x)
colmajor(::AbstractColumnMajor, x) = x

for conv in [:conv, :depthwiseconv]
    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
    conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)

    @eval function ChainRulesCore.rrule(::typeof($conv), x, w, cdims; kw...)
        function $conv_pullback(Δ)
            Δ = colmajor(Δ)
            return (
                NO_FIELDS,
                @thunk($∇conv_data(Δ, w, cdims; kw...)),
                @thunk($∇conv_filter(x, Δ, cdims; kw...)),
                DoesNotExist(),
            )
        end
        return $conv(x, w, cdims; kw...), $conv_pullback
    end

    @eval function ChainRulesCore.rrule(::typeof($∇conv_data), x, w, cdims; kw...)
        function $∇conv_data_pullback(Δ)
            Δ = colmajor(Δ)
            return (
                NO_FIELDS,
                @thunk($conv(Δ, w, cdims; kw...)),
                @thunk($∇conv_filter(Δ, x, cdims; kw...)),
                DoesNotExist(),
            )
        end
        return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback
    end
end

for pool in [:maxpool, :meanpool]
    ∇pool = Symbol(:∇, pool)
    pullback = Symbol(pool, :_pullback)
    @eval function ChainRulesCore.rrule(::typeof($pool), x, pdims::PoolDims; kw...)
        Ω = $pool(x, pdims; kw...)
        $pullback(Δ) = (NO_FIELDS, @thunk($∇pool(Δ, Ω, x, pdims; kw...)), DoesNotExist())
        return Ω, $pullback
    end
end

function ChainRulesCore.rrule(::typeof(batched_mul), A, B)
    function batched_mul_pullback(Δ)
        return (
            NO_FIELDS,
            @thunk(batched_mul(Δ, batched_adjoint(B))),
            @thunk(batched_mul(batched_adjoint(A), Δ)),
        )
    end
    return batched_mul(A, B), batched_mul_pullback
end
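Aside for readers trying these rules by hand (not part of the diff): under the ChainRulesCore 0.9 convention targeted here, rrule(f, args...) returns the primal value together with a pullback whose first entry is the cotangent for f itself (NO_FIELDS for plain functions), and gradients come back wrapped in thunks. A minimal sketch, with x and the seed cotangent invented for illustration:

using ChainRulesCore, NNlib

x = randn(3, 5)
y, pullback = ChainRulesCore.rrule(softmax, x)  # primal value plus pullback
_, x̄ = pullback(one.(y))                        # first slot is NO_FIELDS for softmax itself
x̄ = ChainRulesCore.unthunk(x̄)                  # @thunk delays the work; unthunk forces it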
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,5 +1,6 @@
using NNlib, Test

include("zygote.jl")
include("activation.jl")
include("conv.jl")
include("batchedmul.jl")
111 changes: 111 additions & 0 deletions test/zygote.jl
@@ -0,0 +1,111 @@
using Zygote, NNlib
using Random
using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul

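# Ground-truth gradients via central finite differences: perturb each entry by
# ±δ/2 and record the symmetric slope (f(x + δ/2) - f(x - δ/2)) / δ.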
function ngradient(f, xs::AbstractArray...)
    grads = zero.(xs)
    for (x, Δ) in zip(xs, grads), i in 1:length(x)
        δ = sqrt(eps())
        tmp = x[i]
        x[i] = tmp - δ/2
        y1 = f(xs...)
        x[i] = tmp + δ/2
        y2 = f(xs...)
        x[i] = tmp
        Δ[i] = (y2-y1)/δ
    end
    return grads
end

function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5)
    grad_zygote = gradient(f, xs...)
    grad_finite_difference = ngradient(f, xs...)
    return all(isapprox.(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol))
end

gradtest(f, xs::AbstractArray...; kw...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kw...)
gradtest(f, dims...; kw...) = gradtest(f, rand.(Float64, dims)...; kw...)
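# Composing with sum(sin.(...)) gives a scalar output whose sensitivity varies
# per element, so errors that a constant all-ones cotangent would mask still show up.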

@test gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)

# tests for https://github.com/FluxML/Zygote.jl/issues/758
@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000]) == ([1.0507009873554805, 1.0507009873554805],)
@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
@test gradient(x -> elu(x, 2), 1_000) == (1.,)
@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
@test gradcheck(x->sum(selu.(x)),[100., 1_000.])
@test gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
@test gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # elu passes at these magnitudes, but the analogous selu check below does not
# numerical instability even for the linear part of such function, see:
# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0506591796875, 1.0506591796875],)
# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0507009873554805, 1.0507009873554805],)
@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])

@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(x -> softmax(x, dims=2).*(1:3), (3,5))
@test gradtest(x -> logsoftmax(x).*(1:3), 3)
@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
@test gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5))

@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(repeat([5], spatial_rank)..., 3, 2)
w = rand(repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
@test gradtest((x, w) -> conv(x, w, cdims), x, w)
@test gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055

y = conv(x, w, cdims)
@test gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
else
@test gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
end

dcdims = DepthwiseConvDims(x, w)
@test gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)

y = depthwiseconv(x, w, dcdims)
@test gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
else
@test gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end
end

@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
x = rand(repeat([10], spatial_rank)..., 3, 2)
pdims = PoolDims(x, 2)
@test gradtest(x -> maxpool(x, pdims), x)
@test gradtest(x -> meanpool(x, pdims), x)
@test gradtest(x -> sum(maxpool(x, pdims)), x)
@test gradtest(x -> sum(meanpool(x, pdims)), x)

#https://github.com/FluxML/NNlib.jl/issues/188
k = ntuple(_ -> 2, spatial_rank) # Kernel size of pool in ntuple format
@test gradtest(x -> maxpool(x, k), x)
@test gradtest(x -> meanpool(x, k), x)
@test gradtest(x -> sum(maxpool(x, k)), x)
@test gradtest(x -> sum(meanpool(x, k)), x)
end

@testset "batched matrix multiplication" begin
rng, M, P, Q = MersenneTwister(123456), 13, 7, 11
B = 3
@test gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B))
end