diff --git a/Project.toml b/Project.toml
index 7997ae6d5..ec7c6ebb0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,10 +1,9 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.5.13"
+version = "0.6.0"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
-ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -15,7 +14,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -25,7 +23,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
 AbstractFFTs = "0.5"
-ArrayLayouts = "0.1, 0.2, 0.3, 0.4"
 ChainRules = "0.7.16"
 DiffRules = "1.0"
 FillArrays = "0.8, 0.9, 0.10"
@@ -33,7 +30,6 @@ ForwardDiff = "0.10"
 IRTools = "0.4"
 LoopVectorization = "0.8.15"
 MacroTools = "0.5"
-NNlib = "0.7"
 NaNMath = "0.3"
 Requires = "0.5, 1.0"
 SpecialFunctions = "0.10, 1.0"
diff --git a/src/Zygote.jl b/src/Zygote.jl
index f06f00b51..3061eab05 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -2,7 +2,6 @@ module Zygote
 
 using LinearAlgebra, Statistics
 using LinearAlgebra: copytri!, AbstractTriangular
-using ArrayLayouts: MemoryLayout, AbstractColumnMajor
 
 import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
   literal_getproperty
@@ -35,7 +34,6 @@ include("lib/base.jl")
 include("lib/array.jl")
 include("lib/buffer.jl")
 include("lib/broadcast.jl")
-include("lib/nnlib.jl")
 include("lib/forward.jl")
 include("lib/utils.jl")
 include("lib/range.jl")
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 682ae5997..15f66b8ea 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -14,7 +14,6 @@ using Base.Broadcast
 using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
-using NNlib
 
 # There's a saying that debugging code is about twice as hard as writing it in
 # the first place. So if you're as clever as you can be when writing code, how
@@ -89,11 +88,6 @@ end
 
 @adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ)
 
-@adjoint function broadcasted(::typeof(σ), x::Numeric)
-  y = σ.(x)
-  y, ȳ -> (nothing, ȳ .* conj.(y .* (1 .- y)))
-end
-
 @adjoint function broadcasted(::typeof(tanh), x::Numeric)
   y = tanh.(x)
   y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
diff --git a/src/lib/nnlib.jl b/src/lib/nnlib.jl
deleted file mode 100644
index c8039ed94..000000000
--- a/src/lib/nnlib.jl
+++ /dev/null
@@ -1,103 +0,0 @@
-using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, ∇conv_data, ∇depthwiseconv_data, maxpool, meanpool, σ, relu, batched_mul, batched_adjoint
-
-drelu(x, Δ) = ifelse(x > 0, Δ, zero(x))
-
-@adjoint function Base.Broadcast.broadcasted(::typeof(relu), x::Numeric)
-  relu.(x), Δ -> (nothing, drelu.(x, Δ))
-end
-
-function dselu(x)
-  λ = oftype(x/1, 1.0507009873554804934193349852946)
-  α = oftype(x/1, 1.6732632423543772848170429916717)
-  λ * ifelse(x > 0, one(x), α * exp(x))
-end
-
-@adjoint selu(x::Numeric) = selu(x), Δ -> (dselu(x) * Δ,)
-@adjoint function Base.Broadcast.broadcasted(::typeof(selu), x::Numeric)
-  selu.(x), Δ -> (nothing, dselu.(x) .* Δ)
-end
-
-delu(x, α) = ifelse(x ≥ 0, one(x), α * exp(x))
-
-@adjoint elu(x::Numeric, α::Numeric) = elu(x, α), Δ -> (delu.(x, α) .* Δ, nothing)
-@adjoint function Base.Broadcast.broadcasted(::typeof(elu), x::Numeric, α::Numeric)
-  elu.(x, α), Δ -> (nothing, delu.(x, α) .* Δ, nothing)
-end
-
-@adjoint function σ(x::Real)
-  y = σ(x)
-  return y, Δ -> (Δ * y * (1 - y),)
-end
-
-@adjoint softmax(xs; dims=1) = softmax(xs, dims=dims), Δ -> (∇softmax(Δ, xs, dims=dims),)
-
-@adjoint logsoftmax(xs; dims=1) = logsoftmax(xs, dims=dims), Δ -> (∇logsoftmax(Δ, xs, dims=dims),)
-
-@adjoint NNlib.DenseConvDims(args...; kwargs...) = NNlib.DenseConvDims(args...; kwargs...), _ -> nothing
-@adjoint NNlib.DepthwiseConvDims(args...; kwargs...) = NNlib.DepthwiseConvDims(args...; kwargs...), _ -> nothing
-@adjoint NNlib.PoolDims(args...; kwargs...) = NNlib.PoolDims(args...; kwargs...), _ -> nothing
-
-colmajor(x) = colmajor(MemoryLayout(typeof(x)), x)
-colmajor(_, x) = convert(Array, x)
-colmajor(::AbstractColumnMajor, x) = x
-
-
-@adjoint conv(x, w, cdims; kw...) =
-  conv(x, w, cdims; kw...),
-    Δ -> begin
-      Δ = colmajor(Δ)
-      return (
-        NNlib.∇conv_data(Δ, w, cdims; kw...),
-        NNlib.∇conv_filter(x, Δ, cdims; kw...),
-        nothing,
-      )
-    end
-
-@adjoint ∇conv_data(x, w, cdims; kw...) =
-  ∇conv_data(x, w, cdims; kw...),
-    Δ -> begin
-      Δ = colmajor(Δ)
-      return (
-        NNlib.conv(Δ, w, cdims; kw...),
-        NNlib.∇conv_filter(Δ, x, cdims; kw...),
-        nothing,
-      )
-    end
-
-@adjoint depthwiseconv(x, w, cdims; kw...) =
-  depthwiseconv(x, w, cdims; kw...),
-    Δ -> begin
-      Δ = colmajor(Δ)
-      return (
-        NNlib.∇depthwiseconv_data(Δ, w, cdims; kw...),
-        NNlib.∇depthwiseconv_filter(x, Δ, cdims; kw...),
-        nothing,
-      )
-    end
-
-@adjoint ∇depthwiseconv_data(x, w, cdims; kw...) =
-  ∇depthwiseconv_data(x, w, cdims; kw...),
-    Δ -> begin
-      Δ = colmajor(Δ)
-      return (
-        NNlib.depthwiseconv(Δ, w, cdims; kw...),
-        NNlib.∇depthwiseconv_filter(Δ, x, cdims; kw...),
-        nothing,
-      )
-    end
-
-@adjoint function maxpool(x, pdims::NNlib.PoolDims; kw...)
-  y = maxpool(x, pdims; kw...)
-  y, Δ -> (NNlib.∇maxpool(Δ, y, x, pdims; kw...), nothing)
-end
-
-@adjoint function meanpool(x, pdims::NNlib.PoolDims; kw...)
-  y = meanpool(x, pdims; kw...)
-  y, Δ -> (NNlib.∇meanpool(Δ, y, x, pdims; kw...), nothing)
-end
-
-@adjoint function batched_mul(A, B)
-  C = batched_mul(A, B)
-  C, Δ -> (batched_mul(Δ, batched_adjoint(B)), batched_mul(batched_adjoint(A), Δ))
-end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ff32b31e0..4023a5656 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1,7 +1,6 @@
-using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays,
+using Zygote, Test, Random, LinearAlgebra, Statistics, FillArrays,
   AbstractFFTs, FFTW, Distances
 using Zygote: gradient
-using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
 using Base.Broadcast: broadcast_shape
 using LoopVectorization: vmap
 using Distributed: pmap
@@ -92,37 +91,10 @@ end
 @test gradtest((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> identity.(W*x .+ b), (5,3), (2,5), 2)
 
-@test gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)
-@test gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)
-@test gradtest((x, W, b) -> selu.(W*x .+ b), 5, (2,5), 2)
-@test gradtest((x, W, b) -> selu.(W*x .+ b), (5,3), (2,5), 2)
-@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), 5, (2,5), 2)
-@test gradtest((x, W, b) -> elu.(W*x .+ b, 2), (5,3), (2,5), 2)
-
-# tests for https://github.com/FluxML/Zygote.jl/issues/758
-@test gradient(xs -> sum(selu.(xs)), [1_000, 10_000]) == ([1.0507009873554805, 1.0507009873554805],)
-@test gradient(x -> selu(x), 1_000) == (1.0507009873554805,)
-@test gradient(xs -> sum(elu.(xs, 2)), [1_000, 10_000]) == ([1., 1.],)
-@test gradient(x -> elu(x, 2), 1_000) == (1.,)
-@test gradient(x -> elu(x, 2), -1) == (2*exp(-1),)
-@test gradcheck(x->sum(selu.(x)),[100., 1_000.])
-@test gradcheck(x->sum(elu.(x, 3.5)),[100., 1_000.])
-@test gradcheck(x->sum(elu.(x, 3.5)),[1_000., 10_000.]) # for elu the tests are passing but for selu not, interesting
-# numerical instability even for the linear part of such function, see:
-# julia> ngradient(x->sum(selu.(x)),[1_000., 10_000.])
-# ([1.0506591796875, 1.0506591796875],)
-# julia> gradient(x->sum(selu.(x)),[1_000., 10_000.])
-# ([1.0507009873554805, 1.0507009873554805],)
-@test_broken gradcheck(x->sum(selu.(x)),[1_000., 10_000.])
 
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> tanh.(W*x .+ b), (5,3), (2,5), 2)
 
-@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
-@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
-@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
-@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
-
 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
 @test gradtest((w, x) -> Adjoint(w)*x, randn(10, 2), randn(10))
 @test gradtest((w, x) -> transpose(w)*x, randn(5,5), randn(5,5))
@@ -163,13 +135,6 @@ end
   @test gradtest(x -> cumsum(x, dims=3), (3,4)) # trivial
 end
 
-@test gradtest(x -> softmax(x).*(1:3), 3)
-@test gradtest(x -> softmax(x).*(1:3), (3,5))
-@test gradtest(x -> softmax(x, dims=2).*(1:3), (3,5))
-@test gradtest(x -> logsoftmax(x).*(1:3), 3)
-@test gradtest(x -> logsoftmax(x).*(1:3), (3,5))
-@test gradtest(x -> logsoftmax(x, dims=2).*(1:3), (3,5))
-
 @test gradtest(x -> x', rand(5))
 
 @test gradtest(det, (4, 4))
@@ -235,49 +200,6 @@ end
   @test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],)
 end
 
-@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
-  x = rand(repeat([5], spatial_rank)..., 3, 2)
-  w = rand(repeat([3], spatial_rank)..., 3, 3)
-  cdims = DenseConvDims(x, w)
-  @test gradtest((x, w) -> conv(x, w, cdims), x, w)
-  @test gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055
-
-  y = conv(x, w, cdims)
-  @test gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
-  if spatial_rank == 3
-    @test_broken gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
-  else
-    @test gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
-  end
-
-  dcdims = DepthwiseConvDims(x, w)
-  @test gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
-
-  y = depthwiseconv(x, w, dcdims)
-  @test gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
-  if spatial_rank == 3
-    @test_broken gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
-  else
-    @test gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
-  end
-end
-
-@testset "pooling: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)
-  x = rand(repeat([10], spatial_rank)..., 3, 2)
-  pdims = PoolDims(x, 2)
-  @test gradtest(x -> maxpool(x, pdims), x)
-  @test gradtest(x -> meanpool(x, pdims), x)
-  @test gradtest(x -> sum(maxpool(x, pdims)), x)
-  @test gradtest(x -> sum(meanpool(x, pdims)), x)
-
-  #https://github.com/FluxML/NNlib.jl/issues/188
-  k = ntuple(_ -> 2, spatial_rank) # Kernel size of pool in ntuple format
-  @test gradtest(x -> maxpool(x, k), x)
-  @test gradtest(x -> meanpool(x, k), x)
-  @test gradtest(x -> sum(maxpool(x, k)), x)
-  @test gradtest(x -> sum(meanpool(x, k)), x)
-end
-
 @test gradtest(x -> reverse(x), rand(17))
 @test gradtest(x -> reverse(x, 8), rand(17))
 @test gradtest(x -> reverse(x, 8, 13), rand(17))
@@ -523,11 +445,6 @@ end
       @test first(back(randn(rng, M, P))) isa Vector
     end
   end
-
-  @testset "batched matrix multiplication" begin
-    B = 3
-    @test gradtest(batched_mul, randn(rng, M, P, B), randn(rng, P, Q, B))
-  end
 end
 
 @testset "backsolve" begin