diff --git a/Project.toml b/Project.toml
index 4a7c8e167..4480ed2e3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,24 +1,31 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.9"
+version = "0.8.0"
 
 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
+ChainRulesCore = "0.9"
 Compat = "3.14"
 Requires = "0.5, 1.0"
+ZygoteRules = "0.2"
 julia = "1.3"
 
 [extras]
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "Statistics", "Zygote"]
+test = ["FiniteDifferences", "Random", "StableRNGs", "Statistics", "Test", "Zygote"]
diff --git a/src/NNlib.jl b/src/NNlib.jl
index d5e2d8ed4..7bf22c0ad 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -6,7 +6,7 @@ using Requires
 include("dim_helpers.jl")
 
 is_nnpack_available() = false
- 
+
 @init @require NNPACK_jll="a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee" begin
     if isdefined(NNPACK_jll, :libnnpack)
         include("nnpack/NNPACK.jl")
@@ -39,4 +39,8 @@ include("impl/depthwiseconv_im2col.jl")
 # Direct implementations of pooling
 include("impl/pooling_direct.jl")
 
+# differentiation rules
+include("chainrulescore.jl")
+include("zygoterules.jl")
+
 end # module NNlib
diff --git a/src/chainrulescore.jl b/src/chainrulescore.jl
new file mode 100644
index 000000000..3c7a60766
--- /dev/null
+++ b/src/chainrulescore.jl
@@ -0,0 +1,93 @@
+using ChainRulesCore
+
+const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}
+
+@scalar_rule(selu(x), dselu(x))
+@scalar_rule(elu(x, α), (delu(x, α), DoesNotExist()))
+@scalar_rule(σ(x::Real), Ω * (1 - Ω))
+
+function dselu(x)
+    λ = oftype(x/1, 1.0507009873554804934193349852946)
+    α = oftype(x/1, 1.6732632423543772848170429916717)
+    return λ * ifelse(x > 0, one(x), α * exp(x))
+end
+delu(x, α) = ifelse(x ≥ 0, one(x), α * exp(x))
+
+for softmax in [:softmax, :logsoftmax]
+    local ∇softmax = Symbol(:∇, softmax)
+    pullback = Symbol(softmax, :_pullback)
+
+    @eval function ChainRulesCore.rrule(::typeof($softmax), xs; dims=1)
+        $pullback(Δ) = (NO_FIELDS, @thunk($∇softmax(Δ, xs, dims=dims)))
+        return $softmax(xs; dims=dims), $pullback
+    end
+end
+
+for Dims in [:DenseConvDims, :DepthwiseConvDims, :PoolDims]
+    pullback = Symbol(Dims, :_pullback)
+    @eval function ChainRulesCore.rrule(::Type{$Dims}, args...; kwargs...)
+        $pullback(Δ) = (NO_FIELDS, ntuple(_ -> DoesNotExist(), length(args))...)
+        return $Dims(args...; kwargs...), $pullback
+    end
+end
+
+colmajor(x) = (is_strided(x) && Base.stride(x, 1) == 1) ? x : collect(x)
+
+for conv in [:conv, :depthwiseconv]
+    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
+    conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)
+
+    @eval function ChainRulesCore.rrule(::typeof($conv), x, w, cdims; kw...)
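+        # The pullback maps the output cotangent Δ back to cotangents for the
+        # input and the filter via the matching ∇*_data and ∇*_filter functions;
+        # `cdims` only encodes geometry, so its cotangent is `DoesNotExist()`.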
+        function $conv_pullback(Δ)
+            Δ = colmajor(Δ)
+            return (
+                NO_FIELDS,
+                @thunk($∇conv_data(Δ, w, cdims; kw...)),
+                @thunk($∇conv_filter(x, Δ, cdims; kw...)),
+                DoesNotExist(),
+            )
+        end
+        return $conv(x, w, cdims; kw...), $conv_pullback
+    end
+
+    @eval function ChainRulesCore.rrule(::typeof($∇conv_data), x, w, cdims; kw...)
+        function $∇conv_data_pullback(Δ)
+            Δ = colmajor(Δ)
+            return (
+                NO_FIELDS,
+                @thunk($conv(Δ, w, cdims; kw...)),
+                @thunk($∇conv_filter(Δ, x, cdims; kw...)),
+                DoesNotExist(),
+            )
+        end
+        return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback
+    end
+end
+
+for pool in [:maxpool, :meanpool]
+    ∇pool = Symbol(:∇, pool)
+    pullback = Symbol(pool, :_pullback)
+    @eval function ChainRulesCore.rrule(::typeof($pool), x, pdims::PoolDims; kw...)
+        Ω = $pool(x, pdims; kw...)
+        $pullback(Δ) = (NO_FIELDS, @thunk($∇pool(Δ, Ω, x, pdims; kw...)), DoesNotExist())
+        return Ω, $pullback
+    end
+end
+
+function ChainRulesCore.rrule(
+    ::typeof(batched_mul),
+    A::AbstractArray{S,3},
+    B::AbstractArray{T,3},
+) where {S,T}
+    function batched_mul_pullback(Δ)
+        return (
+            NO_FIELDS,
+            @thunk(batched_mul(Δ, batched_adjoint(B))),
+            @thunk(batched_mul(batched_adjoint(A), Δ)),
+        )
+    end
+    batched_mul(A, B), batched_mul_pullback
+end
diff --git a/src/zygoterules.jl b/src/zygoterules.jl
new file mode 100644
index 000000000..4d04ec3ec
--- /dev/null
+++ b/src/zygoterules.jl
@@ -0,0 +1,17 @@
+using ZygoteRules
+
+# These adjoints are a performance hack specifically for Zygote, which does not
+# handle differentiation of fused broadcasts well.
+for (f, df) in [
+    (:relu, :(x .> 0)),
+    (:selu, :(dselu.(x))),
+    (:elu, :(delu.(x))),
+    (:σ, :(conj.(Ω .* (1 .- Ω)))),
+]
+    pullback = Symbol(:broadcasted_, f, :_pullback)
+    @eval @adjoint function Base.Broadcast.broadcasted(::typeof($f), x::Numeric)
+        Ω = $f.(x)
+        $pullback(Δ) = (nothing, Δ .* $df)
+        return Ω, $pullback
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 613781e70..68715a1eb 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -19,3 +19,6 @@ end
 @testset "Softmax" begin
     include("softmax.jl")
 end
+@testset "Zygote" begin
+    include("zygote.jl")
+end
diff --git a/test/zygote.jl b/test/zygote.jl
new file mode 100644
index 000000000..bbf41261d
--- /dev/null
+++ b/test/zygote.jl
@@ -0,0 +1,122 @@
+using Zygote, NNlib
+using Random
+using LinearAlgebra
+using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
+using FiniteDifferences: grad, central_fdm
+using StableRNGs
+
+const rng = StableRNG(123)
+
+function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5, broken = false)
+    grads_zygote = gradient(f, xs...)
+    grads_fd = grad(central_fdm(5, 1), f, xs...)
+    for (g_zygote, g_fd) in zip(grads_zygote, grads_fd)
+        ok = isapprox(g_zygote, g_fd; rtol = rtol, atol = atol)
+        if broken
+            @test_broken ok
+        else
+            @test ok
+            if !ok
+                # print diagnostics to make the failing comparison easier to debug
+                display(g_zygote - g_fd)
+                @show maximum(abs, g_zygote - g_fd)
+                @show norm(g_zygote) norm(g_fd)
+                println()
+            end
+        end
+    end
+end
+
+gradtest(f, xs::AbstractArray...; kw...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kw...)
+gradtest(f, dims...; kw...) = gradtest(f, rand.(Float64, dims)...; kw...)
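+# For example, `gradtest(softmax, 5)` draws a random length-5 input and checks
+# Zygote's gradient of `x -> sum(sin.(softmax(x)))` against a 5-point central
+# finite difference; integer dims become random vectors, tuples random arrays.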
+
+gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)
+gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
+gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
+gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
+gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
+gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)
+
+# tests for https://github.com/FluxML/Zygote.jl/issues/758
+@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000])[1] ≈ [1.0507009873554805, 1.0507009873554805] rtol=1e-8
+@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
+@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
+@test gradient(x -> elu(x, 2), 1_000) == (1.,)
+@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
+gradcheck(x -> sum(selu.(x)), [100., 1_000.])
+gradcheck(x -> sum(elu.(x, 3.5)), [100., 1_000.])
+gradcheck(x -> sum(elu.(x, 3.5)), [1_000., 10_000.]) # this passes for elu, but the analogous selu check below does not
+# The finite-difference estimate is numerically unstable here, even though selu is linear for x > 0:
+# julia> ngradient(x -> sum(selu.(x)), [1_000., 10_000.])
+# ([1.0506591796875, 1.0506591796875],)
+# julia> gradient(x -> sum(selu.(x)), [1_000., 10_000.])
+# ([1.0507009873554805, 1.0507009873554805],)
+gradcheck(x -> sum(selu.(x)), [1_000., 10_000.]; broken = true)
+
+gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
+gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
+gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
+gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
+
+gradtest(x -> softmax(x).*(1:3), 3)
+gradtest(x -> softmax(x).*(1:3), (3,5))
+gradtest(x -> softmax(x, dims=2).*(1:3), (3,5))
+gradtest(x -> logsoftmax(x).*(1:3), 3)
+gradtest(x -> logsoftmax(x).*(1:3), (3,5))
+gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5))
+
+@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
+    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
+    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
+    cdims = DenseConvDims(x, w)
+    gradtest((x, w) -> conv(x, w, cdims), x, w)
+    gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055
+
+    y = conv(x, w, cdims)
+    gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
+    if spatial_rank == 3
+        gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w; broken = true)
+    else
+        gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
+    end
+
+    dcdims = DepthwiseConvDims(x, w)
+    gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
+
+    y = depthwiseconv(x, w, dcdims)
+    gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
+    if spatial_rank == 3
+        gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w; broken = true)
+    else
+        gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
+    end
+end
+
+@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
+    x = rand(rng, repeat([10], spatial_rank)..., 3, 2)
+    pdims = PoolDims(x, 2)
+    gradtest(x -> maxpool(x, pdims), x; broken=spatial_rank <= 2)
+    gradtest(x -> meanpool(x, pdims), x)
+    gradtest(x -> sum(maxpool(x, pdims)), x)
+    gradtest(x -> sum(meanpool(x, pdims)), x)
+
+    # https://github.com/FluxML/NNlib.jl/issues/188
+    k = ntuple(_ -> 2, spatial_rank) # kernel size of the pooling window, in ntuple form
+    gradtest(x -> maxpool(x, k), x; broken=spatial_rank <= 2)
+    gradtest(x -> meanpool(x, k), x)
+    gradtest(x -> sum(maxpool(x, k)), x)
+    gradtest(x -> sum(meanpool(x, k)), x)
+end
+
+@testset "batched matrix multiplication" begin
+    M, P, Q = 13, 7, 11
+    B = 3
+    gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B))
+end
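+
+# Note: the batched_mul gradients exercised above come from the rrule in
+# src/chainrulescore.jl, which applies the matrix product rule slice-wise:
+# for C = A ⊠ B, ∂A = Δ ⊠ batched_adjoint(B) and ∂B = batched_adjoint(A) ⊠ Δ.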