From e86365ed3f4e53fce37f4a0f70258956c657ab9d Mon Sep 17 00:00:00 2001
From: Tejan Karmali
Date: Sat, 8 Sep 2018 15:44:06 -0400
Subject: [PATCH] 1.0 fix for conv transpose

---
 src/Flux.jl          |  2 +-
 src/layers/conv.jl   | 44 +++++++++++++++++++++++++++++++++++++++++++-
 src/tracker/array.jl | 12 +++++++++++-
 test/tracker.jl      |  6 +++++-
 4 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/src/Flux.jl b/src/Flux.jl
index 8c959fecc1..ba6ac6e7ad 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -5,7 +5,7 @@ module Flux
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
-export Chain, Dense, RNN, LSTM, GRU, Conv, MaxPool, MeanPool,
+export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        Dropout, LayerNorm, BatchNorm,
        params, mapleaves, cpu, gpu
 
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index dbf8ccf93e..9b92a5e8dc 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -1,4 +1,4 @@
-using NNlib: conv
+using NNlib: conv, ∇conv_data
 
 @generated sub2(::Val{N}) where N = :(Val($(N-2)))
 
@@ -51,6 +51,48 @@ function Base.show(io::IO, l::Conv)
   print(io, ")")
 end
 
+"""
+    ConvTranspose(size, in=>out)
+    ConvTranspose(size, in=>out, relu)
+
+Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Data should be stored in WHCN order. In other words, a 100×100 RGB image would
+be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+"""
+struct ConvTranspose{N,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{N,Int}
+  dilation::NTuple{N,Int}
+end
+
+ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+              stride = 1, pad = 0, dilation = 1) where {T,N} =
+  ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
+
+ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+              init = initn, stride = 1, pad = 0, dilation = 1) where N =
+  ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
+                stride = stride, pad = pad, dilation = dilation)
+
+@treelike ConvTranspose
+
+function (c::ConvTranspose)(x)
+  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  σ.(∇conv_data(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+end
+
+function Base.show(io::IO, l::ConvTranspose)
+  print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
 """
     MaxPool(k)
 
diff --git a/src/tracker/array.jl b/src/tracker/array.jl
index ffa3a89eeb..da948a4e52 100644
--- a/src/tracker/array.jl
+++ b/src/tracker/array.jl
@@ -289,7 +289,7 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 
 # NNlib
 
 using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, maxpool, meanpool
 
 softmax(xs::TrackedArray) = track(softmax, xs)
 
@@ -309,6 +309,16 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
     (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
      NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
 
+∇conv_data(x::TrackedArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::AbstractArray, w::TrackedArray; kw...) = track(∇conv_data, x, w; kw...)
+∇conv_data(x::TrackedArray, w::AbstractArray; kw...) = track(∇conv_data, x, w; kw...)
+
+@grad ∇conv_data(x, w; kw...) =
+  ∇conv_data(data(x), data(w); kw...),
+  Δ -> nobacksies(:conv,
+    (NNlib.conv(data.((x, Δ, w))...; kw...),
+     NNlib.∇conv_filter(data.((x, Δ, w))...; kw...)))
+
 maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
 
 @grad function maxpool(x, k; kw...)
diff --git a/test/tracker.jl b/test/tracker.jl
index 03d14c35fa..f697e1d764 100644
--- a/test/tracker.jl
+++ b/test/tracker.jl
@@ -1,7 +1,7 @@
 using Flux
 using Flux.Tracker, Test, NNlib
 using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
-using NNlib: conv
+using NNlib: conv, ∇conv_data
 using Printf: @sprintf
 using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
 using Statistics: mean, std
@@ -176,6 +176,10 @@ end
 @test gradtest(conv, rand(10, 10, 3, 2), randn(Float64, 2, 2, 3, 2))
 @test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 3, 2))
 
+@test gradtest(∇conv_data, rand(10, 3, 2), randn(Float64, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 3, 2), randn(Float64, 2, 2, 2, 3))
+@test gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(Float64, 2, 2, 2, 2, 3))
+
 @test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
 @test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
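
For reviewers, a quick usage sketch of the new layer (not part of the patch).
This is a minimal example assuming this branch is checked out together with an
NNlib version that provides the two-argument ∇conv_data(dy, w; kw...) the layer
calls; the output size (20, 20, 8, 4) is what the shape arithmetic gives for
this particular kernel/stride choice:

    using Flux

    # 2-d transposed convolution: 3 input channels to 8 output channels,
    # 2×2 kernel with stride 2 (a common upsampling configuration).
    up = ConvTranspose((2, 2), 3 => 8, relu, stride = 2)

    x = rand(10, 10, 3, 4)   # WHCN: 10×10 spatial, 3 channels, batch of 4
    y = up(x)                # forward pass dispatches to ∇conv_data
    size(y)                  # (20, 20, 8, 4): each spatial dim maps 10 -> 20

The shape arithmetic inverts that of a plain Conv: a stride-2, 2×2 convolution
maps a 20×20 input to 10×10, so the transpose maps 10×10 back to 20×20.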
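
Likewise, a minimal sketch of exercising the new @grad rule directly, mirroring
the 4-d test added above (assumes the Tracker back!/grad API that test/tracker.jl
already imports from; shapes are illustrative):

    using Flux
    using Flux.Tracker: back!, grad, param
    using NNlib: ∇conv_data

    dy = param(rand(10, 10, 3, 2))   # plays the role of a conv output (WHCN)
    w  = param(randn(2, 2, 2, 3))    # 2×2 kernel, 2 in-channels, 3 out-channels

    y = ∇conv_data(dy, w)            # tracked transposed convolution
    back!(sum(y))                    # reverse pass through the new @grad rule

    grad(dy)                         # d(sum y)/d(dy), same size as dy
    grad(w)                          # d(sum y)/dw, same size as w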