Skip to content

Commit

Permalink
Fix single number (v/h)cat (#152)
Browse files Browse the repository at this point in the history
* fix single number (v/h)cat

* fix some indents

* fix tests
  • Loading branch information
mohamed82008 authored Sep 21, 2020
1 parent c14e0e2 commit 71c5ac0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 37 deletions.
60 changes: 30 additions & 30 deletions src/derivatives/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ end

@grad function vcat(xs::Union{Number, AbstractVecOrMat}...)
xs_value = value.(xs)
out_value = reduce(vcat,xs_value)
out_value = vcat(xs_value...,)
function back(Δ)
start = 0
Δs = map(xs) do xsi
if xsi isa Number
d = Δ[start+1]
else
d = Δ[start+1:start+size(xsi,1), :]
end
start += size(xsi, 1)
d
if xsi isa Number
d = Δ[start+1]
else
d = Δ[start+1:start+size(xsi,1), :]
end
start += size(xsi, 1)
d
end
return (Δs...,)
end
Expand All @@ -77,20 +77,20 @@ end

@grad function hcat(xs::Union{Number, AbstractVecOrMat}...)
xs_value = value.(xs)
out_value = reduce(hcat, xs_value)
out_value = hcat(xs_value...,)
function back(Δ)
start = 0
Δs = map(xs) do xsi
d = if ndims(xsi) == 0
Δ[start+1]
elseif ndims(xsi) == 1
Δ[:, start+1]
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
Δ[:, start+1:start+size(xsi,2), i...]
end
start += size(xsi, 2)
d
d = if ndims(xsi) == 0
Δ[start+1]
elseif ndims(xsi) == 1
Δ[:, start+1]
else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
Δ[:, start+1:start+size(xsi,2), i...]
end
start += size(xsi, 2)
d
end
return (Δs...,)
end
Expand All @@ -106,17 +106,17 @@ end
return cat(Xs_value...; dims = dims), Δ -> begin
start = ntuple(i -> 0, Val(ndims(Δ)))
Δs = map(Xs) do xs
if xs isa Number
d = Δ[start+1]
start = start .+ 1
else
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
end
d
if xs isa Number
d = Δ[start+1]
start = start .+ 1
else
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
d = reshape(Δ[xs_in_Δ...],size(xs))
start = start .+ till_xs
end
d
end
return (Δs...,)
end
Expand Down
24 changes: 17 additions & 7 deletions test/derivatives/ArrayFunctionTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,25 @@ end
@test any(iszero, track([ones(2); 0.0]))
end

function testcat(f, args::Tuple{Any, Any}, type, kwargs=NamedTuple())
function testcat(f, args::Tuple, type, kwargs=NamedTuple())
x = f(track.(args)...; kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)

x = f(track(args[1]), args[2]; kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)
if length(args) == 1
x = f(track(args[1]); kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)
else
@assert length(args) == 2
x = f(track(args[1]), args[2]; kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)

x = f(args[1], track(args[2]); kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)
x = f(args[1], track(args[2]); kwargs...)
@test x isa type
@test value(x) == f(args...; kwargs...)
end

args = (args..., args...)
x = f(track.(args)...; kwargs...)
Expand Down Expand Up @@ -64,6 +71,7 @@ end
a = rand(3,3,3)
n = rand()

testcat(cat, (n,), TrackedVector, (dims=1,))
testcat(cat, (n, n), TrackedVector, (dims=1,))
testcat(cat, (n, n), TrackedMatrix, (dims=2,))
testcat(cat, (v, n), TrackedVector, (dims=1,))
Expand All @@ -79,11 +87,13 @@ end
testcat(cat, (a, a), TrackedArray, (dims=3,))
testcat(cat, (a, m), TrackedArray, (dims=3,))

testcat(vcat, (n,), TrackedVector)
testcat(vcat, (n, n), TrackedVector)
testcat(vcat, (v, n), TrackedVector)
testcat(vcat, (n, v), TrackedVector)
testcat(vcat, (v, v), TrackedVector)

testcat(hcat, (n,), TrackedMatrix)
testcat(hcat, (n, n), TrackedMatrix)
testcat(hcat, (v, v), TrackedMatrix)
testcat(hcat, (v, m), TrackedMatrix)
Expand Down

0 comments on commit 71c5ac0

Please sign in to comment.