-
-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
port rule definitions to ChainRulesCore (#242)
* port rule definitions from Zygote
- Loading branch information
1 parent
e117313
commit 9fc717a
Showing
6 changed files
with
237 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,31 @@ | ||
name = "NNlib" | ||
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
version = "0.7.9" | ||
version = "0.8.0" | ||
|
||
[deps] | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" | ||
|
||
[compat] | ||
ChainRulesCore = "0.9" | ||
Compat = "3.14" | ||
Requires = "0.5, 1.0" | ||
ZygoteRules = "0.2" | ||
julia = "1.3" | ||
|
||
[extras] | ||
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[targets] | ||
test = ["Test", "Statistics", "Zygote"] | ||
test = ["FiniteDifferences", "Random", "StableRNGs", "Statistics", "Test", "Zygote"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
using ChainRulesCore | ||
|
||
# Either a single number, or an array whose elements are all numbers of some common
# supertype `T <: Number`.
const Numeric = Union{T, AbstractArray{<:T}} where {T<:Number}
|
||
# Scalar differentiation rules for the activation functions.
# `selu`: the derivative is supplied by `dselu`.
@scalar_rule selu(x) dselu(x)
# `elu`: the derivative with respect to `x` is supplied by `delu`; the rule reports no
# derivative with respect to `α` (`DoesNotExist()`).
@scalar_rule elu(x, α) (delu(x, α), DoesNotExist())
# `σ` (real inputs only): the derivative is expressed through the primal output `Ω`.
@scalar_rule σ(x::Real) Ω * (1 - Ω)
|
||
"""
    dselu(x)

Return `λ` when `x > 0` and `λ * α * exp(x)` otherwise, where `λ` and `α` are the two
constants below, converted to the floating-point type of `x / 1`. This is the derivative
expression registered for `selu` in the scalar rule.
"""
function dselu(x)
    T = typeof(x / 1)
    scale = convert(T, 1.0507009873554804934193349852946)
    alpha = convert(T, 1.6732632423543772848170429916717)
    # `ifelse` evaluates both branches; the `exp` value is simply discarded for `x > 0`.
    slope = ifelse(x > 0, one(x), alpha * exp(x))
    return scale * slope
end
"""
    delu(x, α=one(x))

Return `one(x)` when `x ≥ 0` and `α * exp(x)` otherwise. This is the derivative
expression with respect to `x` that is registered for `elu(x, α)` in the scalar rule.

`α` defaults to `one(x)` so that one-argument calls such as `delu.(x)` are valid; without
the default they raise a `MethodError`.
NOTE(review): the default assumes `elu`'s own default `α` is one — confirm against the
definition of `elu`.
"""
delu(x, α=one(x)) = ifelse(x ≥ 0, one(x), α * exp(x))
|
||
# Generate an `rrule` for `softmax` and for `logsoftmax`. Each rule returns the primal
# result together with a pullback that lazily (via `@thunk`) calls the matching
# `∇softmax` / `∇logsoftmax` function with the incoming cotangent, the input and `dims`.
for fname in (:softmax, :logsoftmax)
    local grad_fn = Symbol(:∇, fname)
    local pb_name = Symbol(fname, :_pullback)

    @eval function ChainRulesCore.rrule(::typeof($fname), xs; dims=1)
        Ω = $fname(xs; dims=dims)
        $pb_name(Δ) = (NO_FIELDS, @thunk($grad_fn(Δ, xs, dims=dims)))
        return Ω, $pb_name
    end
end
|
||
# Constructing a dims descriptor is treated as non-differentiable: the pullback returns
# `DoesNotExist()` for every positional constructor argument.
for dims_type in (:DenseConvDims, :DepthwiseConvDims, :PoolDims)
    local pb_name = Symbol(dims_type, :_pullback)
    @eval function ChainRulesCore.rrule(::Type{$dims_type}, args...; kwargs...)
        # One `DoesNotExist()` per positional argument.
        no_grads = map(_ -> DoesNotExist(), args)
        $pb_name(Δ) = (NO_FIELDS, no_grads...)
        return $dims_type(args...; kwargs...), $pb_name
    end
end
|
||
"""
    colmajor(x)

Return `x` itself when `is_strided(x)` holds and its stride along the first dimension is
1; otherwise return `collect(x)`.
"""
function colmajor(x)
    if is_strided(x) && Base.stride(x, 1) == 1
        return x
    end
    return collect(x)
end
|
||
# Generate `rrule`s for `conv` / `depthwiseconv` and for the matching `∇…_data` functions.
#
# Keyword arguments given to the primal call are forwarded to the gradient kernels *as
# keywords* (`; kw...`). Splatting them positionally (`kw...`) would pass `Pair`s as extra
# positional arguments and fail with a `MethodError` whenever a keyword is supplied.
for conv in (:conv, :depthwiseconv)
    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])
    local conv_pullback, ∇conv_data_pullback = Symbol.([conv, ∇conv_data], :_pullback)

    @eval function ChainRulesCore.rrule(::typeof($conv), x, w, cdims; kw...)
        function $conv_pullback(Δ)
            # Use a fresh name instead of reassigning `Δ`, which the thunks capture.
            Δc = colmajor(Δ)
            return (
                NO_FIELDS,
                @thunk($∇conv_data(Δc, w, cdims; kw...)),
                @thunk($∇conv_filter(x, Δc, cdims; kw...)),
                DoesNotExist(),  # no derivative with respect to `cdims`
            )
        end
        return $conv(x, w, cdims; kw...), $conv_pullback
    end

    @eval function ChainRulesCore.rrule(::typeof($∇conv_data), x, w, cdims; kw...)
        function $∇conv_data_pullback(Δ)
            Δc = colmajor(Δ)
            return (
                NO_FIELDS,
                @thunk($conv(Δc, w, cdims; kw...)),
                @thunk($∇conv_filter(Δc, x, cdims; kw...)),
                DoesNotExist(),  # no derivative with respect to `cdims`
            )
        end
        return $∇conv_data(x, w, cdims; kw...), $∇conv_data_pullback
    end
end
|
||
# Generate an `rrule` for `maxpool` and for `meanpool`. The pullback lazily calls the
# matching `∇maxpool` / `∇meanpool` with the cotangent, the primal output, the input and
# the pooling dims; no derivative is reported for `pdims`.
for pool_fn in (:maxpool, :meanpool)
    local grad_fn = Symbol(:∇, pool_fn)
    local pb_name = Symbol(pool_fn, :_pullback)
    @eval function ChainRulesCore.rrule(::typeof($pool_fn), x, pdims::PoolDims; kw...)
        Ω = $pool_fn(x, pdims; kw...)
        function $pb_name(Δ)
            ∂x = @thunk($grad_fn(Δ, Ω, x, pdims; kw...))
            return (NO_FIELDS, ∂x, DoesNotExist())
        end
        return Ω, $pb_name
    end
end
|
||
# Reverse rule for `batched_mul` on two 3-dimensional arrays. The cotangent for `A` is
# `batched_mul(Δ, batched_adjoint(B))` and the one for `B` is
# `batched_mul(batched_adjoint(A), Δ)`; both are wrapped in thunks.
function ChainRulesCore.rrule(
    ::typeof(batched_mul),
    A::AbstractArray{S,3},
    B::AbstractArray{T,3},
) where {S,T}
    Ω = batched_mul(A, B)
    function batched_mul_pullback(Δ)
        ∂A = @thunk(batched_mul(Δ, batched_adjoint(B)))
        ∂B = @thunk(batched_mul(batched_adjoint(A), Δ))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return Ω, batched_mul_pullback
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
using ZygoteRules | ||
|
||
# Performance workaround specific to Zygote, which does not handle fused broadcasts well:
# define explicit adjoints for broadcasting each activation over a `Numeric` input.
#
# Each table entry pairs an activation with the broadcast expression of its derivative,
# written in terms of the input `x` and/or the primal output `Ω`.
for (act, deriv) in (
    (:relu, :(x .> 0)),
    (:selu, :(dselu.(x))),
    # `delu` takes `(x, α)`, so α must be passed explicitly; calling `delu.(x)` raises a
    # `MethodError`. This adjoint only matches the one-argument broadcast `elu.(x)`.
    # NOTE(review): α = 1 assumes that is `elu`'s default — confirm against `elu`.
    (:elu, :(delu.(x, 1))),
    (:σ, :(conj.(Ω .* (1 .- Ω)))),
)
    local pb_name = Symbol(:broadcasted_, act, :_pullback)
    @eval @adjoint function Base.Broadcast.broadcasted(::typeof($act), x::Numeric)
        Ω = $act.(x)
        # `nothing` is the gradient for the activation function itself.
        $pb_name(Δ) = (nothing, Δ .* $deriv)
        return Ω, $pb_name
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,6 @@ end | |
@testset "Softmax" begin | ||
include("softmax.jl") | ||
end | ||
# Run the tests defined in zygote.jl inside their own "Zygote" test set.
@testset "Zygote" begin | ||
include("zygote.jl") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
using Zygote, NNlib | ||
using Random | ||
using LinearAlgebra | ||
using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul | ||
using FiniteDifferences: grad, central_fdm | ||
using StableRNGs | ||
|
||
const rng = StableRNG(123) | ||
|
||
"""
    gradcheck(f, xs...; rtol=1e-5, atol=1e-5, broken=false)

Compare the Zygote gradient of `f` at `xs...` with a finite-difference estimate
(`central_fdm(5, 1)`), one argument at a time, using `isapprox` with the given tolerances.

With `broken=true` every comparison is recorded with `@test_broken`. Otherwise each one is
recorded with `@test`, and for a mismatch the difference, its largest absolute entry and
both norms are printed. Returns `nothing`.
"""
function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5, broken=false)
    zygote_grads = gradient(f, xs...)
    fd_grads = grad(central_fdm(5, 1), f, xs...)
    for (grad_zygote, grad_finite_difference) in zip(zygote_grads, fd_grads)
        if broken
            @test_broken isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
            continue
        end
        @test isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol)
        # Nothing more to report when the two gradients agree.
        isapprox(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol) && continue
        display(grad_zygote - grad_finite_difference)
        @show maximum(abs, grad_zygote - grad_finite_difference)
        @show norm(grad_zygote) norm(grad_finite_difference)
        println()
    end
    return nothing
end
|
||
"""
    gradtest(f, xs::AbstractArray...; kw...)
    gradtest(f, dims...; kw...)

Run `gradcheck` on the scalar function `sum(sin.(f(xs...)))`, forwarding `kw...`.

In the second form each entry of `dims` (an integer or a tuple of integers) is turned
into a `Float64` array of that size. The arrays are drawn from the file's seeded `rng`
rather than the global RNG, so the generated inputs are reproducible between runs.
"""
function gradtest(f, xs::AbstractArray...; kw...)
    return gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kw...)
end
function gradtest(f, dims...; kw...)
    return gradtest(f, map(d -> rand(rng, Float64, d), dims)...; kw...)
end
|
||
# Gradient checks for relu / selu / elu applied to `W*x .+ b`, with vector and matrix `x`.
gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2) | ||
gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2) | ||
gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2) | ||
gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2) | ||
gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2) | ||
gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2) | ||
|
||
# Regression tests for https://github.com/FluxML/Zygote.jl/issues/758: | ||
# gradients of selu / elu at large integer and floating-point inputs.
@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000])[1] ≈ [1.0507009873554805, 1.0507009873554805] rtol=1e-8 | ||
@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,) | ||
@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],) | ||
@test gradient(x -> elu(x, 2), 1_000) == (1.,) | ||
@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),) | ||
gradcheck(x->sum(selu.(x)),[100., 1_000.]) | ||
gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.]) | ||
gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # this check passes for elu, but the same inputs do not pass for selu (see below) | ||
# The numerical (finite-difference) gradient is inaccurate even on the linear part of | ||
# selu, as this REPL session shows: | ||
# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.]) | ||
# ([1.0506591796875, 1.0506591796875],) | ||
# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.]) | ||
# ([1.0507009873554805, 1.0507009873554805],) | ||
# NOTE(review): `gradcheck` has no explicit return value, so `@test_broken` wraps a
# non-Bool here while the `@test`s inside `gradcheck` still run; `broken=true` may be
# what was intended — confirm.
@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.]) | ||
|
||
# Gradient checks for σ and logσ applied to `W*x .+ b`, with vector and matrix `x`.
gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) | ||
gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) | ||
gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2) | ||
gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2) | ||
|
||
# Gradient checks for softmax / logsoftmax, with the default `dims` and with `dims=2`;
# the output is weighted by `1:3` before being reduced inside `gradtest`.
gradtest(x -> softmax(x).*(1:3), 3) | ||
gradtest(x -> softmax(x).*(1:3), (3,5)) | ||
gradtest(x -> softmax(x, dims=2).*(1:3), (3,5)) | ||
gradtest(x -> logsoftmax(x).*(1:3), 3) | ||
gradtest(x -> logsoftmax(x).*(1:3), (3,5)) | ||
gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5)) | ||
|
||
# Gradient checks for `conv`, `∇conv_data`, `depthwiseconv` and `∇depthwiseconv_data` with
# 1, 2 and 3 spatial dimensions. The `sum(∇…_data(...))` checks are marked broken for the
# 3-dimensional case.
@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    # `spatial(n)` is the size `n` repeated once per spatial dimension.
    spatial(n) = ntuple(_ -> n, spatial_rank)
    x = rand(rng, spatial(5)..., 3, 2)
    w = rand(rng, spatial(3)..., 3, 3)

    # Dense convolution.
    cdims = DenseConvDims(x, w)
    gradtest((x, w) -> conv(x, w, cdims), x, w)
    gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055

    y = conv(x, w, cdims)
    gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
    if spatial_rank == 3
        @test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
    else
        gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
    end

    # Depthwise convolution.
    dcdims = DepthwiseConvDims(x, w)
    gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)

    y = depthwiseconv(x, w, dcdims)
    gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
    if spatial_rank == 3
        @test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
    else
        gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
    end
end
|
||
# Gradient checks for `maxpool` / `meanpool` with 1 and 2 spatial dimensions, using both a
# `PoolDims` descriptor and a plain kernel-size tuple.
@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
    x = rand(rng, ntuple(_ -> 10, spatial_rank)..., 3, 2)
    # The un-summed maxpool checks are flagged as broken for ranks up to 2.
    maxpool_broken = spatial_rank <= 2

    pdims = PoolDims(x, 2)
    gradtest(x -> maxpool(x, pdims), x; broken=maxpool_broken)
    gradtest(x -> meanpool(x, pdims), x)
    gradtest(x -> sum(maxpool(x, pdims)), x)
    gradtest(x -> sum(meanpool(x, pdims)), x)

    # https://github.com/FluxML/NNlib.jl/issues/188
    k = ntuple(_ -> 2, spatial_rank)  # pooling kernel size as a tuple
    gradtest(x -> maxpool(x, k), x; broken=maxpool_broken)
    gradtest(x -> meanpool(x, k), x)
    gradtest(x -> sum(maxpool(x, k)), x)
    gradtest(x -> sum(meanpool(x, k)), x)
end
|
||
# Gradient check for `batched_mul` on an (M × P × B) and a (P × Q × B) random array.
@testset "batched matrix multiplication" begin
    M, P, Q = 13, 7, 11
    B = 3
    lhs = randn(rng, M, P, B)
    rhs = randn(rng, P, Q, B)
    gradtest(batched_mul, lhs, rhs)
end