Skip to content

Commit

Permalink
port rule definitions to ChainRulesCore (#242)
Browse files Browse the repository at this point in the history
* port rule definitions from Zygote
  • Loading branch information
simeonschaub authored Dec 27, 2020
1 parent e117313 commit 9fc717a
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 3 deletions.
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.9"
version = "0.8.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "0.9"
Compat = "3.14"
Requires = "0.5, 1.0"
ZygoteRules = "0.2"
julia = "1.3"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Statistics", "Zygote"]
test = ["FiniteDifferences", "Random", "StableRNGs", "Statistics", "Test", "Zygote"]
6 changes: 5 additions & 1 deletion src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Requires
include("dim_helpers.jl")

is_nnpack_available() = false

@init @require NNPACK_jll="a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee" begin
if isdefined(NNPACK_jll, :libnnpack)
include("nnpack/NNPACK.jl")
Expand Down Expand Up @@ -39,4 +39,8 @@ include("impl/depthwiseconv_im2col.jl")
# Direct implementations of pooling
include("impl/pooling_direct.jl")

# differentiation rules
include("chainrulescore.jl")
include("zygoterules.jl")

end # module NNlib
90 changes: 90 additions & 0 deletions src/chainrulescore.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
using ChainRulesCore

const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

@scalar_rule(selu(x), dselu(x))
@scalar_rule(elu(x, α), (delu(x, α), DoesNotExist()))
@scalar_rule(σ(x::Real), Ω * (1 - Ω))

function dselu(x)
λ = oftype(x/1, 1.0507009873554804934193349852946)
α = oftype(x/1, 1.6732632423543772848170429916717)
return λ * ifelse(x > 0, one(x), α * exp(x))
end
delu(x, α) = ifelse(x 0, one(x), α * exp(x))

for softmax in [:softmax, :logsoftmax]
local ∇softmax = Symbol(:∇, softmax)
pullback = Symbol(softmax, :_pullback)

@eval function ChainRulesCore.rrule(::typeof($softmax), xs; dims=1)
$pullback(Δ) = (NO_FIELDS, @thunk($∇softmax(Δ, xs, dims=dims)))
return $softmax(xs; dims=dims), $pullback
end
end

for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
pullback = Symbol(Dims, :_pullback)
@eval function ChainRulesCore.rrule(::Type{$Dims}, args...; kwargs...)
$pullback(Δ) = (NO_FIELDS, ntuple(_ -> DoesNotExist(), length(args))...)
return $Dims(args...; kwargs...), $pullback
end
end

colmajor(x) = (is_strided(x) && Base.stride(x, 1) == 1) ? x : collect(x)

for conv in [:conv, :depthwiseconv]
local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)

@eval function ChainRulesCore.rrule(::typeof($conv), x, w, cdims; kw...)
function $conv_pullback(Δ)
Δ = colmajor(Δ)
return (
NO_FIELDS,
@thunk($∇conv_data(Δ, w, cdims, kw...)),
@thunk($∇conv_filter(x, Δ, cdims, kw...)),
DoesNotExist(),
)
end
return $conv(x, w, cdims; kw...), $conv_pullback
end

@eval function ChainRulesCore.rrule(::typeof($∇conv_data), x, w, cdims; kw...)
function $∇conv_data_pullback(Δ)
Δ = colmajor(Δ)
return (
NO_FIELDS,
@thunk($conv(Δ, w, cdims, kw...)),
@thunk($∇conv_filter(Δ, x, cdims, kw...)),
DoesNotExist(),
)
end
return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback
end
end

for pool in [:maxpool, :meanpool]
∇pool = Symbol(:∇, pool)
pullback = Symbol(pool, :_pullback)
@eval function ChainRulesCore.rrule(::typeof($pool), x, pdims::PoolDims; kw...)
Ω = $pool(x, pdims; kw...)
$pullback(Δ) = (NO_FIELDS, @thunk($∇pool(Δ, Ω, x, pdims; kw...)), DoesNotExist())
return Ω, $pullback
end
end

function ChainRulesCore.rrule(
::typeof(batched_mul),
A::AbstractArray{S,3},
B::AbstractArray{T,3},
) where {S,T}
function batched_mul_pullback(Δ)
return (
NO_FIELDS,
@thunk(batched_mul(Δ, batched_adjoint(B))),
@thunk(batched_mul(batched_adjoint(A), Δ)),
)
end
batched_mul(A, B), batched_mul_pullback
end
17 changes: 17 additions & 0 deletions src/zygoterules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using ZygoteRules

# This is a performance hack specifically for Zygote, because it doesn't handle fused
# broadcasts well
for (f, df) in [
(:relu, :(x .> 0)),
(:selu, :(dselu.(x))),
(:elu, :(delu.(x))),
(, :(conj.(Ω .* (1 .- Ω)))),
]
pullback = Symbol(:broadcasted_, f, :_pullback)
@eval @adjoint function Base.Broadcast.broadcasted(::typeof($f), x::Numeric)
Ω = $f.(x)
$pullback(Δ) = (nothing, Δ .* $df)
return Ω, $pullback
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ end
@testset "Softmax" begin
include("softmax.jl")
end
@testset "Zygote" begin
include("zygote.jl")
end
113 changes: 113 additions & 0 deletions test/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
using Zygote, NNlib
using Random
using LinearAlgebra
using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
using FiniteDifferences: grad, central_fdm
using StableRNGs

const rng = StableRNG(123)

function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5, broken=false)
grad_zygote = gradient(f, xs...)
grad_finite_difference = grad(central_fdm(5, 1), f, xs...)
for (grad_zygote, grad_finite_difference) in zip(grad_zygote, grad_finite_difference)
if broken
@test_broken isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
else
@test isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
if !isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
display(grad_zygote - grad_finite_difference)
@show maximum(abs, grad_zygote - grad_finite_difference)
@show norm(grad_zygote) norm(grad_finite_difference)
println()
end
end
end
end

gradtest(f, xs::AbstractArray...; kw...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kw...)
gradtest(f, dims...; kw...) = gradtest(f, rand.(Float64, dims)...; kw...)

gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)

# tests for https://github.com/FluxML/Zygote.jl/issues/758
@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000])[1] [1.0507009873554805, 1.0507009873554805] rtol=1e-8
@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
@test gradient(x -> elu(x, 2), 1_000) == (1.,)
@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
gradcheck(x->sum(selu.(x)),[100., 1_000.])
gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # for elu the tests are passing but for selu not, interesting
# numerical instability even for the linear part of such function, see:
# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0506591796875, 1.0506591796875],)
# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
# ([1.0507009873554805, 1.0507009873554805],)
@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])

gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)

gradtest(x -> softmax(x).*(1:3), 3)
gradtest(x -> softmax(x).*(1:3), (3,5))
gradtest(x -> softmax(x, dims=2).*(1:3), (3,5))
gradtest(x -> logsoftmax(x).*(1:3), 3)
gradtest(x -> logsoftmax(x).*(1:3), (3,5))
gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5))

@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
gradtest((x, w) -> conv(x, w, cdims), x, w)
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055

y = conv(x, w, cdims)
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
else
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
end

dcdims = DepthwiseConvDims(x, w)
gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)

y = depthwiseconv(x, w, dcdims)
gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
if spatial_rank == 3
@test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
else
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end
end

@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
x = rand(rng, repeat([10], spatial_rank)..., 3, 2)
pdims = PoolDims(x, 2)
gradtest(x -> maxpool(x, pdims), x; broken=spatial_rank <= 2)
gradtest(x -> meanpool(x, pdims), x)
gradtest(x -> sum(maxpool(x, pdims)), x)
gradtest(x -> sum(meanpool(x, pdims)), x)

#https://github.com/FluxML/NNlib.jl/issues/188
k = ntuple(_ -> 2, spatial_rank) # Kernel size of pool in ntuple format
gradtest(x -> maxpool(x, k), x; broken=spatial_rank <= 2)
gradtest(x -> meanpool(x, k), x)
gradtest(x -> sum(maxpool(x, k)), x)
gradtest(x -> sum(meanpool(x, k)), x)
end

@testset "batched matrix multiplication" begin
M, P, Q = 13, 7, 11
B = 3
gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B))
end

0 comments on commit 9fc717a

Please sign in to comment.