port rule definitions to ChainRulesCore #242

Merged: 16 commits, Dec 27, 2020
Changes from 13 commits
10 changes: 8 additions & 2 deletions Project.toml
@@ -1,24 +1,30 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.9"
version = "0.8.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "0.9"
Compat = "3.14"
Requires = "0.5, 1.0"
ZygoteRules = "0.2"
julia = "1.3"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Statistics", "Zygote"]
test = ["FiniteDifferences", "Random", "Statistics", "Test", "Zygote"]
6 changes: 5 additions & 1 deletion src/NNlib.jl
@@ -6,7 +6,7 @@ using Requires
include("dim_helpers.jl")

is_nnpack_available() = false

@init @require NNPACK_jll="a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee" begin
    if isdefined(NNPACK_jll, :libnnpack)
        include("nnpack/NNPACK.jl")
@@ -39,4 +39,8 @@ include("impl/depthwiseconv_im2col.jl")
# Direct implementations of pooling
include("impl/pooling_direct.jl")

# differentiation rules
include("chainrulescore.jl")
include("zygoterules.jl")

end # module NNlib
90 changes: 90 additions & 0 deletions src/chainrulescore.jl
@@ -0,0 +1,90 @@
using ChainRulesCore

const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

@scalar_rule(selu(x), dselu(x))
@scalar_rule(elu(x, α), (delu(x, α), DoesNotExist()))
@scalar_rule(σ(x::Real), Ω * (1 - Ω))

function dselu(x)
    λ = oftype(x/1, 1.0507009873554804934193349852946)
    α = oftype(x/1, 1.6732632423543772848170429916717)
    return λ * ifelse(x > 0, one(x), α * exp(x))
end
delu(x, α = 1) = ifelse(x ≥ 0, one(x), α * exp(x))  # α defaults to 1, matching elu
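For context, a minimal sketch (not part of the diff) of how these derivative helpers can be sanity-checked against a centered finite difference of selu/elu away from the kink at zero; fdiff is a hypothetical helper, and dselu/delu are the internal functions defined above:

```julia
using NNlib

# hypothetical helper: centered finite difference
fdiff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

for x in (-0.7, 0.3)
    @assert isapprox(NNlib.dselu(x), fdiff(NNlib.selu, x); atol = 1e-6)
    @assert isapprox(NNlib.delu(x, 2.0), fdiff(z -> NNlib.elu(z, 2.0), x); atol = 1e-6)
end
```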

for softmax in [:softmax, :logsoftmax]
    local ∇softmax = Symbol(:∇, softmax)
    pullback = Symbol(softmax, :_pullback)

    @eval function ChainRulesCore.rrule(::typeof($softmax), xs; dims=1)
        $pullback(Δ) = (NO_FIELDS, @thunk($∇softmax(Δ, xs, dims=dims)))
        return $softmax(xs; dims=dims), $pullback
    end
end
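As a quick illustration (a sketch, not from the PR) of how the generated rule is consumed, assuming the ChainRulesCore 0.9 rrule/unthunk API used throughout this file:

```julia
using NNlib, ChainRulesCore

xs = randn(3, 5)
y, softmax_pullback = ChainRulesCore.rrule(softmax, xs; dims = 1)
_, x̄ = softmax_pullback(ones(size(y)))   # x̄ is a lazy @thunk
@assert unthunk(x̄) ≈ NNlib.∇softmax(ones(size(y)), xs, dims = 1)
```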

for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
    pullback = Symbol(Dims, :_pullback)
    @eval function ChainRulesCore.rrule(::Type{$Dims}, args...; kwargs...)
        $pullback(Δ) = (NO_FIELDS, ntuple(_ -> DoesNotExist(), length(args))...)
        return $Dims(args...; kwargs...), $pullback
    end
end
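The dims constructors themselves are treated as non-differentiable: the pullback simply returns DoesNotExist() for every positional argument. A small sketch (assuming the ChainRulesCore 0.9 API and arrays shaped for a 2-d convolution):

```julia
using NNlib, ChainRulesCore

x, w = rand(5, 5, 3, 2), rand(3, 3, 3, 4)
cdims, cdims_pullback = ChainRulesCore.rrule(DenseConvDims, x, w)
cdims_pullback(nothing)  # (NO_FIELDS, DoesNotExist(), DoesNotExist())
```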

colmajor(x) = (is_strided(x) && Base.stride(x, 1) == 1) ? x : collect(x)

for conv in [:conv, :depthwiseconv]
    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
    conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)

    @eval function ChainRulesCore.rrule(::typeof($conv), x, w, cdims; kw...)
        function $conv_pullback(Δ)
            Δ = colmajor(Δ)
            return (
                NO_FIELDS,
                @thunk($∇conv_data(Δ, w, cdims, kw...)),
                @thunk($∇conv_filter(x, Δ, cdims, kw...)),
                DoesNotExist(),
            )
        end
        return $conv(x, w, cdims; kw...), $conv_pullback
    end

    @eval function ChainRulesCore.rrule(::typeof($∇conv_data), x, w, cdims; kw...)
        function $∇conv_data_pullback(Δ)
            Δ = colmajor(Δ)
            return (
                NO_FIELDS,
                @thunk($conv(Δ, w, cdims, kw...)),
                @thunk($∇conv_filter(Δ, x, cdims, kw...)),
                DoesNotExist(),
            )
        end
        return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback
    end
end
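A sketch (not part of the diff) of the conv rule in isolation; under the assumption that no keyword arguments are passed, the thunks reduce to plain ∇conv_data / ∇conv_filter calls:

```julia
using NNlib, ChainRulesCore

x, w = rand(5, 5, 3, 2), rand(3, 3, 3, 4)
cdims = DenseConvDims(x, w)
y, conv_pullback = ChainRulesCore.rrule(conv, x, w, cdims)
Δ = ones(size(y))
_, x̄, w̄, _ = conv_pullback(Δ)          # the trailing DoesNotExist() is for cdims
@assert unthunk(x̄) ≈ ∇conv_data(Δ, w, cdims)
@assert unthunk(w̄) ≈ ∇conv_filter(x, Δ, cdims)
```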

for pool in [:maxpool, :meanpool]
    ∇pool = Symbol(:∇, pool)
    pullback = Symbol(pool, :_pullback)
    @eval function ChainRulesCore.rrule(::typeof($pool), x, pdims::PoolDims; kw...)
        Ω = $pool(x, pdims; kw...)
        $pullback(Δ) = (NO_FIELDS, @thunk($∇pool(Δ, Ω, x, pdims; kw...)), DoesNotExist())
        return Ω, $pullback
    end
end
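Sketch only: the pooling pullbacks need both the pooled output Ω and the input x, which is why Ω is computed once inside the rrule and captured by the closure.

```julia
using NNlib, ChainRulesCore

x = rand(10, 10, 3, 2)
pdims = PoolDims(x, 2)
Ω, maxpool_pullback = ChainRulesCore.rrule(maxpool, x, pdims)
_, x̄, _ = maxpool_pullback(ones(size(Ω)))
@assert size(unthunk(x̄)) == size(x)   # gradient has the shape of the input
```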

function ChainRulesCore.rrule(
    ::typeof(batched_mul),
    A::AbstractArray{S,3},
    B::AbstractArray{T,3},
) where {S,T}
    function batched_mul_pullback(Δ)
        return (
            NO_FIELDS,
            @thunk(batched_mul(Δ, batched_adjoint(B))),
            @thunk(batched_mul(batched_adjoint(A), Δ)),
        )
    end
    batched_mul(A, B), batched_mul_pullback
end
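The batched_mul rule mirrors the familiar matrix-product identities Ā = Δ Bᴴ and B̄ = Aᴴ Δ, applied slice-wise over the batch dimension. A sketch of checking one batch slice by hand (not part of the PR):

```julia
using NNlib, ChainRulesCore

A, B = randn(4, 3, 2), randn(3, 5, 2)
C, batched_mul_pullback = ChainRulesCore.rrule(batched_mul, A, B)
Δ = ones(size(C))
_, Ā, B̄ = batched_mul_pullback(Δ)
@assert unthunk(Ā)[:, :, 1] ≈ Δ[:, :, 1] * B[:, :, 1]'
@assert unthunk(B̄)[:, :, 1] ≈ A[:, :, 1]' * Δ[:, :, 1]
```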
17 changes: 17 additions & 0 deletions src/zygoterules.jl
@@ -0,0 +1,17 @@
using ZygoteRules

# This is a performance hack specifically for Zygote, because it doesn't handle fused
# broadcasts well
for (f, df) in [
    (:relu, :(x .> 0)),
    (:selu, :(dselu.(x))),
    (:elu, :(delu.(x))),
    (:σ, :(conj.(Ω .* (1 .- Ω)))),
]
    pullback = Symbol(:broadcasted_, f, :_pullback)
    @eval @adjoint function Base.Broadcast.broadcasted(::typeof($f), x::Numeric)
        Ω = $f.(x)
        $pullback(Δ) = (nothing, Δ .* $df)
        return Ω, $pullback
    end
end
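For illustration (a sketch, not part of the diff): with these broadcast adjoints in place, Zygote differentiates an unfused activation broadcast without going through its generic broadcasting machinery, e.g.:

```julia
using Zygote, NNlib

Zygote.gradient(x -> sum(relu.(x)), [-1.0, 2.0])   # expected: ([0.0, 1.0],)
Zygote.gradient(x -> sum(σ.(x)), [0.0])            # expected: ([0.25],)
```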
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -19,3 +19,6 @@ end
@testset "Softmax" begin
include("softmax.jl")
end
@testset "Zygote" begin
include("zygote.jl")
end
107 changes: 107 additions & 0 deletions test/zygote.jl
@@ -0,0 +1,107 @@
using Zygote, NNlib
using Random
using LinearAlgebra
using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
using FiniteDifferences: grad, central_fdm

function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5)
    grad_zygote = gradient(f, xs...)
    grad_finite_difference = grad(central_fdm(5, 1), f, xs...)
    #return all(isapprox.(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol))
    for (grad_zygote, grad_finite_difference) in zip(grad_zygote, grad_finite_difference)
        @test isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
        if !isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
            display(grad_zygote - grad_finite_difference)
            @show maximum(grad_zygote - grad_finite_difference)
            @show norm(grad_zygote) norm(grad_finite_difference)
            println()
        end
    end
end

gradtest(f, xs::AbstractArray...; kw...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kw...)
gradtest(f, dims...; kw...) = gradtest(f, rand.(Float64, dims)...; kw...)

gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)

# tests for https://github.com/FluxML/Zygote.jl/issues/758
@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000])[1] ≈ [1.0507009873554805, 1.0507009873554805] rtol=1e-8
@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
@test gradient(x -> elu(x, 2), 1_000) == (1.,)
@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
gradcheck(x -> sum(selu.(x)), [100., 1_000.])
gradcheck(x -> sum(elu.(x, 3.5)), [100., 1_000.])
gradcheck(x -> sum(elu.(x, 3.5)), [1_000., 10_000.]) # for elu these tests pass, but for selu they do not:
# the finite-difference estimate is numerically unstable even in the linear regime of selu, see:
# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0506591796875, 1.0506591796875],)
# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0507009873554805, 1.0507009873554805],)
@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])

gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

gradtest(x -> softmax(x).*(1:3), 3)
gradtest(x -> softmax(x).*(1:3), (3,5))
gradtest(x -> softmax(x, dims=2).*(1:3), (3,5))
gradtest(x -> logsoftmax(x).*(1:3), 3)
gradtest(x -> logsoftmax(x).*(1:3), (3,5))
gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5))

@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(repeat([5], spatial_rank)..., 3, 2)
w = rand(repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
gradtest((x, w) -> conv(x, w, cdims), x, w)
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055

y = conv(x, w, cdims)
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
else
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
end

dcdims = DepthwiseConvDims(x, w)
gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)

y = depthwiseconv(x, w, dcdims)
gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
else
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end
end

@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
x = rand(repeat([10], spatial_rank)..., 3, 2)
pdims = PoolDims(x, 2)
gradtest(x -> maxpool(x, pdims), x)
gradtest(x -> meanpool(x, pdims), x)
gradtest(x -> sum(maxpool(x, pdims)), x)
gradtest(x -> sum(meanpool(x, pdims)), x)

#https://github.com/FluxML/NNlib.jl/issues/188
k = ntuple(_ -> 2, spatial_rank) # Kernel size of pool in ntuple format
gradtest(x -> maxpool(x, k), x)
gradtest(x -> meanpool(x, k), x)
gradtest(x -> sum(maxpool(x, k)), x)
gradtest(x -> sum(meanpool(x, k)), x)
end

@testset "batched matrix multiplication" begin
rng, M, P, Q = MersenneTwister(123456), 13, 7, 11
B = 3
gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B))
end