Skip to content

Commit

Permalink
Try #728:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Jul 9, 2020
2 parents 469e34b + 58ee77a commit 7a80201
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 26 deletions.
4 changes: 2 additions & 2 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

julia:1.3:
extends:
- .julia:1.3
- .julia:1.4
- .test
tags:
- nvidia
Expand All @@ -20,7 +20,7 @@ julia:nightly:

documentation:
extends:
- .julia:1.3
- .julia:1.4
- .documentation
tags:
- nvidia
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Expand All @@ -36,7 +37,7 @@ NaNMath = "0.3"
Requires = "0.5, 1.0"
SpecialFunctions = "0.10"
ZygoteRules = "0.2"
julia = "1.3"
julia = "1.4"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
2 changes: 2 additions & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using IRTools
using MacroTools, Requires
using MacroTools: @forward

using LoopVectorization # for vmap

export Params, gradient, pullback, pushforward, @code_grad

include("tools/idset.jl")
Expand Down
30 changes: 16 additions & 14 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,25 @@ function unzip(tuples)
_unzip(tuples, Val(N))
end

function ∇map(cx, f, args...)
ys_and_backs = map((args...) -> _pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> nothing
else
ys, backs = unzip(ys_and_backs)
ys, function (Δ)
Δf_and_args_zipped = map((f, δ) -> f(δ), backs, Δ)
Δf_and_args = unzip(Δf_and_args_zipped)
Δf = reduce(accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap),(:vmap,:∇vmap)]
@eval function $∇mapfunc(cx, f, args...)
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> nothing
else
ys, backs = unzip(ys_and_backs)
ys, function (Δ)
Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), backs, Δ)
Δf_and_args = unzip(Δf_and_args_zipped)
Δf = reduce(accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
end
end
end

@adjoint function map(f, args::Union{AbstractArray,Tuple}...)
∇map(__context__, f, args...)
@eval @adjoint function $mapfunc(f, args::Union{AbstractArray,Tuple}...)
$∇mapfunc(__context__, f, args...)
end
end

function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
Expand Down
21 changes: 12 additions & 9 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays,
using Zygote: gradient
using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
using Base.Broadcast: broadcast_shape
using LoopVectorization

function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
Expand Down Expand Up @@ -103,7 +104,7 @@ end
@test gradtest((w, x) -> parent(w)*x, randn(5,5)', randn(5,5))
@test gradtest((w, x) -> parent(w)*x, transpose(randn(5,5)), randn(5,5))

@testset "sum, prod, cumsum" begin
@testset "sum, prod, cumsum" begin
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(abs2, x), randn(4, 3, 2))
@test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
Expand Down Expand Up @@ -301,15 +302,17 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

@testset "map" begin
@test gradtest(xs -> sum(map(x -> x^2, xs)), rand(2,3))
@test gradtest((xss...) -> sum(map((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
function foo(y)
bar = (x) -> x*y
sum(map(bar, 1:5))
for mapfunc in [map,pmap,vmap]
@testset "$mapfunc" begin
@test gradtest(xs -> sum(mapfunc(x -> x^2, xs)), rand(2,3))
@test gradtest((xss...) -> sum(mapfunc((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
function foo(y)
bar = (x) -> x*y
sum(mapfunc(bar, 1:5))
end
@test gradtest(foo, 3)
@test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
end
@test gradtest(foo, 3)
@test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
end

@testset "sort" begin
Expand Down

0 comments on commit 7a80201

Please sign in to comment.