1.0 fix for conv transpose

tejank10 committed Sep 8, 2018
1 parent d5d9441 commit e86365e

Showing 4 changed files with 60 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -5,7 +5,7 @@ module Flux
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

-export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
+export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
   Dropout, LayerNorm, BatchNorm,
   params, mapleaves, cpu, gpu
44 changes: 43 additions & 1 deletion src/layers/conv.jl
@@ -1,4 +1,4 @@
-using NNlib: conv
+using NNlib: conv, ∇conv_data

 @generated sub2(::Val{N}) where N = :(Val($(N-2)))

@@ -51,6 +51,48 @@ function Base.show(io::IO, l::Conv)
   print(io, ")")
 end

"""
ConvTranspose(size, in=>out)
ConvTranspose(size, in=>out, relu)
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct ConvTranspose{N,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end

ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
stride = 1, pad = 0, dilation = 1) where N =
ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)

@treelike ConvTranspose

function (c::ConvTranspose)(x)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
end

function Base.show(io::IO, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end

"""
MaxPool(k)
Expand Down
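For context, a minimal usage sketch of the layer added above (not part of the commit; names and defaults as defined in this diff). With a 2×2 kernel, stride 1, and no padding, each spatial dimension grows by k - 1:

using Flux

layer = ConvTranspose((2, 2), 3=>8, relu)  # 2×2 kernel, 3 input channels, 8 output channels
x = rand(10, 10, 3, 50)                    # WHCN: a batch of 50 10×10 images with 3 channels
y = layer(x)                               # size(y) == (11, 11, 8, 50)

Note the design choice in the constructor: the weight is built with `init(k..., reverse(ch)...)`, i.e. with the channel pair reversed relative to `Conv`, so that NNlib's `∇conv_data` (the adjoint of `conv` with respect to its data) can be reused directly as the forward pass.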
12 changes: 11 additions & 1 deletion src/tracker/array.jl
@@ -289,7 +289,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 # NNlib

 using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, maxpool, meanpool

 softmax(xs::TrackedArray) = track(softmax, xs)

@@ -309,6 +309,16 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
       (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
        NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))

+∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
+
+@grad ∇conv_data(x, w; kw...) =
+  ∇conv_data(data(x), data(w); kw...),
+    Δ -> nobacksies(:conv,
+      (NNlib.conv(data.((x, Δ, w))...; kw...),
+       NNlib.∇conv_filter(data.((x, Δ, w))...; kw...)))

 maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)

 @grad function maxpool(x, k; kw...)
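These new methods route `∇conv_data` calls on tracked arrays through Flux's tracker, which is what makes the `ConvTranspose` forward pass differentiable. A hedged sketch of the effect (not from the commit; it assumes the NNlib version this commit targets, where `∇conv_data` accepts the two positional arguments used in the layer code, and uses `back!` and `grad` from Flux.Tracker, as the tests below also import `grad`):

using Flux, NNlib
using Flux.Tracker: back!, grad
using NNlib: ∇conv_data

w = param(randn(2, 2, 3, 4))   # conv-orientation weight: 2×2 kernel, 3=>4 channels
Δ = param(rand(10, 10, 4, 1))  # shaped like a conv output: channels match w's last dimension
back!(sum(∇conv_data(Δ, w)))   # result is 11×11×3×1; sum reduces it to a tracked scalar
grad(w)                        # now populated via the @grad rule above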
6 changes: 5 additions & 1 deletion test/tracker.jl
@@ -1,7 +1,7 @@
 using Flux
 using Flux.Tracker, Test, NNlib
 using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
-using NNlib: conv
+using NNlib: conv, ∇conv_data
 using Printf: @sprintf
 using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
 using Statistics: mean, std
@@ -176,6 +176,10 @@ end
 @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
 @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))

+@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))

 @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))

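A note on the shapes in the new gradtest lines (an aside, not from the commit): the first argument to ∇conv_data is shaped like a conv output, so its channel count matches the weight's last (output) dimension, and the result recovers the conv input's larger spatial extent. A quick check assuming the default stride 1 and no padding:

using NNlib: ∇conv_data

# 2×2 kernel, conv channels 2=>3; a 10×10 "output" with 3 channels maps back
# to an 11×11 "input" with 2 channels: 10 + 2 - 1 == 11 per spatial dimension.
size(∇conv_data(rand(10, 10, 3, 2), randn(2, 2, 2, 3)))  # == (11, 11, 2, 2)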
