Skip to content

Commit

Permalink
Try #728:
Browse files — browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Jul 10, 2020
2 parents 469e34b + cbdb28d commit eda2d9b
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 26 deletions.
4 changes: 2 additions & 2 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

julia:1.3:
extends:
- .julia:1.3
- .julia:1.4
- .test
tags:
- nvidia
Expand All @@ -20,7 +20,7 @@ julia:nightly:

documentation:
extends:
- .julia:1.3
- .julia:1.4
- .documentation
tags:
- nvidia
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Expand All @@ -30,13 +32,14 @@ DiffRules = "1.0"
FillArrays = "0.8"
ForwardDiff = "0"
IRTools = "0.4"
LoopVectorization = "0.8.15"
MacroTools = "0.5"
NNlib = "0.7"
NaNMath = "0.3"
Requires = "0.5, 1.0"
SpecialFunctions = "0.10"
ZygoteRules = "0.2"
julia = "1.3"
julia = "1.4"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
2 changes: 2 additions & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using IRTools
using MacroTools, Requires
using MacroTools: @forward

using LoopVectorization # for vmap

export Params, gradient, pullback, pushforward, @code_grad

include("tools/idset.jl")
Expand Down
31 changes: 17 additions & 14 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Random, FillArrays, AbstractFFTs
using FillArrays: AbstractFill, getindex_value
using Base.Broadcast: broadcasted, broadcast_shape
using Distributed: pmap

@adjoint (::Type{T})(::UndefInitializer, args...) where T<:Array = T(undef, args...), Δ -> nothing

Expand Down Expand Up @@ -170,23 +171,25 @@ function unzip(tuples)
_unzip(tuples, Val(N))
end

# Generate the reverse-mode rules for `map`, `pmap` (Distributed) and `vmap`
# (LoopVectorization) from a single template: each `$∇mapfunc` maps `_pullback`
# over the arguments, then the returned closure maps each stored pullback over
# the cotangent `Δ` and accumulates the gradient w.r.t. `f` with `accum`.
# NOTE(review): the scraped diff interleaved the removed pre-commit `∇map`
# definition with this loop; this is the reconstructed post-commit code.
for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap),(:vmap,:∇vmap)]
  @eval function $∇mapfunc(cx, f, args...)
    # Forward pass: one (y, back) pair per element.
    ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
    if isempty(ys_and_backs)
      # Nothing was mapped, so there is nothing to differentiate.
      ys_and_backs, _ -> nothing
    else
      ys, backs = unzip(ys_and_backs)
      ys, function (Δ)
        # Reverse pass: apply each element's pullback to its cotangent.
        Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), backs, Δ)
        Δf_and_args = unzip(Δf_and_args_zipped)
        # Gradients w.r.t. the (possibly closed-over) function are summed.
        Δf = reduce(accum, Δf_and_args[1])
        (Δf, Δf_and_args[2:end]...)
      end
    end
  end

  @eval @adjoint function $mapfunc(f, args::Union{AbstractArray,Tuple}...)
    $∇mapfunc(__context__, f, args...)
  end
end

function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
Expand Down
21 changes: 12 additions & 9 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays,
using Zygote: gradient
using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
using Base.Broadcast: broadcast_shape
using LoopVectorization, Distributed

function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
Expand Down Expand Up @@ -103,7 +104,7 @@ end
@test gradtest((w, x) -> parent(w)*x, randn(5,5)', randn(5,5))
@test gradtest((w, x) -> parent(w)*x, transpose(randn(5,5)), randn(5,5))

@testset "sum, prod, cumsum" begin
@testset "sum, prod, cumsum" begin
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(abs2, x), randn(4, 3, 2))
@test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
Expand Down Expand Up @@ -301,15 +302,17 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

# Run the same gradient checks against map, pmap and vmap: the adjoint for all
# three is generated from one template in src/lib/array.jl, so each variant
# gets its own @testset.
# NOTE(review): the scraped diff interleaved the removed `@testset "map"` lines
# with this loop; this is the reconstructed post-commit code.
for mapfunc in [map,pmap,vmap]
  @testset "$mapfunc" begin
    @test gradtest(xs -> sum(mapfunc(x -> x^2, xs)), rand(2,3))
    @test gradtest((xss...) -> sum(mapfunc((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
    # A closure over `y` exercises the `Δf = reduce(accum, ...)` path.
    function foo(y)
      bar = (x) -> x*y
      sum(mapfunc(bar, 1:5))
    end
    @test gradtest(foo, 3)
    @test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
  end
end

@testset "sort" begin
Expand Down

0 comments on commit eda2d9b

Please sign in to comment.