diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 62b65c6d8..04a7c5760 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -5,7 +5,7 @@ image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
 
 julia:1.3:
   extends:
-    - .julia:1.3
+    - .julia:1.4
     - .test
   tags:
     - nvidia
@@ -20,7 +20,7 @@ julia:nightly:
 
 documentation:
   extends:
-    - .julia:1.3
+    - .julia:1.4
    - .documentation
   tags:
     - nvidia
diff --git a/Project.toml b/Project.toml
index 7d4c2914d..5f636c5cb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,12 +7,14 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
+Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
 IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
@@ -30,13 +32,14 @@ DiffRules = "1.0"
 FillArrays = "0.8"
 ForwardDiff = "0"
 IRTools = "0.4"
+LoopVectorization = "0.8.15"
 MacroTools = "0.5"
 NNlib = "0.7"
 NaNMath = "0.3"
 Requires = "0.5, 1.0"
 SpecialFunctions = "0.10"
 ZygoteRules = "0.2"
-julia = "1.3"
+julia = "1.4"
 
 [extras]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/Zygote.jl b/src/Zygote.jl
index 6722ce880..0752fd803 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -11,6 +11,8 @@ using IRTools
 using MacroTools, Requires
 using MacroTools: @forward
 
+using LoopVectorization # for vmap
+
 export Params, gradient, pullback, pushforward, @code_grad
 
 include("tools/idset.jl")
diff --git a/src/lib/array.jl b/src/lib/array.jl
index e23b815cb..3d4475f09 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -1,6 +1,7 @@
 using Random, FillArrays, AbstractFFTs
 using FillArrays: AbstractFill, getindex_value
 using Base.Broadcast: broadcasted, broadcast_shape
+using Distributed: pmap
 
 @adjoint (::Type{T})(::UndefInitializer, args...) where T<:Array =
   T(undef, args...), Δ -> nothing
@@ -170,23 +171,25 @@ function unzip(tuples)
   _unzip(tuples, Val(N))
 end
 
-function ∇map(cx, f, args...)
-  ys_and_backs = map((args...) -> _pullback(cx, f, args...), args...)
-  if isempty(ys_and_backs)
-    ys_and_backs, _ -> nothing
-  else
-    ys, backs = unzip(ys_and_backs)
-    ys, function (Δ)
-      Δf_and_args_zipped = map((f, δ) -> f(δ), backs, Δ)
-      Δf_and_args = unzip(Δf_and_args_zipped)
-      Δf = reduce(accum, Δf_and_args[1])
-      (Δf, Δf_and_args[2:end]...)
+for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap),(:vmap,:∇vmap)]
+  @eval function $∇mapfunc(cx, f, args...)
+    ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
+    if isempty(ys_and_backs)
+      ys_and_backs, _ -> nothing
+    else
+      ys, backs = unzip(ys_and_backs)
+      ys, function (Δ)
+        Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), backs, Δ)
+        Δf_and_args = unzip(Δf_and_args_zipped)
+        Δf = reduce(accum, Δf_and_args[1])
+        (Δf, Δf_and_args[2:end]...)
+      end
     end
   end
-end
 
-@adjoint function map(f, args::Union{AbstractArray,Tuple}...)
-  ∇map(__context__, f, args...)
+  @eval @adjoint function $mapfunc(f, args::Union{AbstractArray,Tuple}...)
+    $∇mapfunc(__context__, f, args...)
+  end
 end
 
 function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 43eab346d..e77d73483 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -3,6 +3,7 @@ using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays,
 using Zygote: gradient
 using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
 using Base.Broadcast: broadcast_shape
+using LoopVectorization, Distributed
 
 function ngradient(f, xs::AbstractArray...)
   grads = zero.(xs)
@@ -103,7 +104,7 @@ end
 @test gradtest((w, x) -> parent(w)*x, randn(5,5)', randn(5,5))
 @test gradtest((w, x) -> parent(w)*x, transpose(randn(5,5)), randn(5,5))
 
-@testset "sum, prod, cumsum" begin 
+@testset "sum, prod, cumsum" begin
   @test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
   @test gradtest(x -> sum(abs2, x), randn(4, 3, 2))
   @test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
@@ -301,15 +302,17 @@ end
 @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
 @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
 
-@testset "map" begin
-  @test gradtest(xs -> sum(map(x -> x^2, xs)), rand(2,3))
-  @test gradtest((xss...) -> sum(map((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
-  function foo(y)
-    bar = (x) -> x*y
-    sum(map(bar, 1:5))
+for mapfunc in [map,pmap,vmap]
+  @testset "$mapfunc" begin
+    @test gradtest(xs -> sum(mapfunc(x -> x^2, xs)), rand(2,3))
+    @test gradtest((xss...) -> sum(mapfunc((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
+    function foo(y)
+      bar = (x) -> x*y
+      sum(mapfunc(bar, 1:5))
+    end
+    @test gradtest(foo, 3)
+    @test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
   end
-  @test gradtest(foo, 3)
-  @test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
 end
 
 @testset "sort" begin
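Usage sketch (illustrative, not part of the patch): after this change, Zygote generates the same pullback for Base.map, Distributed.pmap, and LoopVectorization.vmap, so gradients flow through all three map variants identically. The snippet below assumes only the public Zygote.gradient API; xs is an arbitrary example input.

using Zygote, LoopVectorization
using Distributed: pmap

xs = rand(5)

# d/dv sum(v .^ 2) = 2v, whichever map variant evaluates the kernel:
gradient(v -> sum(map(x -> x^2, v)),  xs)[1]  # ≈ 2 .* xs (Base.map)
gradient(v -> sum(vmap(x -> x^2, v)), xs)[1]  # ≈ 2 .* xs (SIMD-vectorized map)
gradient(v -> sum(pmap(x -> x^2, v)), xs)[1]  # ≈ 2 .* xs (map over worker processes)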