
Commit 8ee6af1

bors[bot] and ashridh committed

Merge #762

762: CrossCor layer r=avik-pal a=ayush-1506

Same as #423 (which I can no longer edit, since I lost access to that GitHub account).

Co-authored-by: ayush-1506 <ayush.shridhar1506@gmail.com>

2 parents 308b199 + 98a027a, commit 8ee6af1
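For context: cross-correlation is convolution without the kernel flip, which is the operation most deep-learning frameworks actually compute under the name "convolution". A minimal sketch of the distinction on 1-D arrays (not part of this commit; `conv1d` and `xcorr1d` are hypothetical helpers):

    # Convolution reverses the kernel before the sliding dot product;
    # cross-correlation slides the kernel as-is.
    conv1d(x, w)  = [sum(x[i:i+length(w)-1] .* reverse(w)) for i in 1:length(x)-length(w)+1]
    xcorr1d(x, w) = [sum(x[i:i+length(w)-1] .* w)          for i in 1:length(x)-length(w)+1]

    x = [1.0, 2.0, 3.0, 4.0]
    w = [1.0, 0.0, -1.0]
    conv1d(x, w)   # [2.0, 2.0]   (kernel flipped)
    xcorr1d(x, w)  # [-2.0, -2.0] (kernel applied as-is)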

6 files changed: 99 additions, 1 deletion

NEWS.md (1 addition, 0 deletions)

@@ -17,6 +17,7 @@
 * [Data.Iris](https://github.com/FluxML/Flux.jl/pull/652) makes Fisher's Iris dataset available with `Iris.labels` and `Iris.features`.
 * New [InstanceNorm](https://github.com/FluxML/Flux.jl/pull/634), as popularized by [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022).
 * New [GroupNorm](https://github.com/FluxML/Flux.jl/pull/696), as described in [Group Normalization](https://arxiv.org/abs/1803.08494).
+* New [CrossCor](https://github.com/FluxML/Flux.jl/pull/762).

 AD Changes:

docs/src/models/layers.md (1 addition, 0 deletions)

@@ -17,6 +17,7 @@ MaxPool
 MeanPool
 DepthwiseConv
 ConvTranspose
+CrossCor
 ```

 ## Recurrent Layers

src/Flux.jl (1 addition, 1 deletion)

@@ -6,7 +6,7 @@ using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward

-export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
+export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
   DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
   params, mapleaves, cpu, gpu, f32, f64

src/layers/conv.jl (70 additions, 0 deletions)

@@ -198,6 +198,76 @@ end

 (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))

+"""
+    CrossCor(size, in=>out)
+    CrossCor(size, in=>out, relu)
+
+Standard cross-correlation layer. `size` should be a tuple like `(2, 2)`.
+`in` and `out` specify the number of input and output channels respectively.
+
+Example: applying a CrossCor layer to a 1-channel input using a 2×2 window size,
+giving a 16-channel output. The output is activated with ReLU.
+
+    size = (2, 2)
+    in = 1
+    out = 16
+    CrossCor(size, in=>out, relu)
+
+Data should be stored in WHCN order (width, height, # channels, # batches).
+In other words, a 100×100 RGB image would be a `100×100×3×1` array,
+and a batch of 50 would be a `100×100×3×50` array.
+
+Takes the keyword arguments `pad`, `stride` and `dilation`.
+"""
+struct CrossCor{N,M,F,A,V}
+  σ::F
+  weight::A
+  bias::V
+  stride::NTuple{N,Int}
+  pad::NTuple{M,Int}
+  dilation::NTuple{N,Int}
+end
+
+function CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
+                  stride = 1, pad = 0, dilation = 1) where {T,N}
+  stride = expand(Val(N-2), stride)
+  pad = expand(Val(2*(N-2)), pad)
+  dilation = expand(Val(N-2), dilation)
+  return CrossCor(σ, w, b, stride, pad, dilation)
+end
+
+CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+         init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
+  CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
+           stride = stride, pad = pad, dilation = dilation)
+
+@treelike CrossCor
+
+function crosscor(x, w, ddims::DenseConvDims)
+  ddims = DenseConvDims(ddims, F=true)
+  return conv(x, w, ddims)
+end
+
+function (c::CrossCor)(x::AbstractArray)
+  # TODO: breaks gpu broadcast :(
+  # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
+  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
+  cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
+  σ.(crosscor(x, c.weight, cdims) .+ b)
+end
+
+function Base.show(io::IO, l::CrossCor)
+  print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
+  print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))

 """
     MaxPool(k)
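Implementation note: the forward pass reuses NNlib's `conv`, but rebuilds the `DenseConvDims` with `F=true` (the flipkernel flag), so the kernel is applied without the flip that convolution would perform. A hedged usage sketch of the new layer (assumes Flux at this commit; weights are randomly initialised, so only the shapes are meaningful):

    using Flux

    layer = CrossCor((2, 2), 1=>16, relu)  # 2×2 window, 1 input channel, 16 output channels
    x = rand(Float32, 28, 28, 1, 1)        # WHCN: a 28×28 single-channel image, batch of 1
    y = layer(x)
    size(y)                                # (27, 27, 16, 1) with default stride 1 and no padding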

test/cuda/cuda.jl (4 additions, 0 deletions)

@@ -36,6 +36,10 @@ c = gpu(Conv((2,2),3=>4))
 l = c(gpu(rand(10,10,3,2)))
 Flux.back!(sum(l))

+c = gpu(CrossCor((2,2),3=>4))
+l = c(gpu(rand(10,10,3,2)))
+Flux.back!(sum(l))
+
 end

 @testset "onecold gpu" begin

test/layers/conv.jl (22 additions, 0 deletions)

@@ -56,6 +56,27 @@ end
   @test size(x_hat) == size(x)
 end

+@testset "CrossCor" begin
+  x = rand(Float32, 28, 28, 1, 1)
+  w = rand(2,2,1,1)
+  y = CrossCor(w, [0.0])
+
+  @test sum(w .* x[1:2, 1:2, :, :]) == y(x)[1, 1, 1, 1]
+
+  r = zeros(Float32, 28, 28, 1, 5)
+  m = Chain(
+    CrossCor((2, 2), 1=>16, relu),
+    MaxPool((2,2)),
+    CrossCor((2, 2), 16=>8, relu),
+    MaxPool((2,2)),
+    x -> reshape(x, :, size(x, 4)),
+    Dense(288, 10), softmax)
+
+  @test size(m(r)) == (10, 5)
+  @test y(x) != Conv(w, [0.0])(x)
+  @test CrossCor(w[end:-1:1, end:-1:1, :, :], [0.0])(x) == Conv(w, [0.0])(x)
+end
+
 @testset "Conv with non quadratic window #700" begin
   data = zeros(Float32, 7,7,1,1)
   data[4,4,1,1] = 1
@@ -81,3 +102,4 @@ end
     true
 end
 end
+
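The final assertion above checks the defining identity: reversing the kernel along both spatial dimensions turns cross-correlation into convolution. The `!=` test before it holds because a random kernel is almost surely not flip-symmetric; as a sketch (an illustrative assumption, not part of the test suite), the two layers coincide for a kernel that is invariant under a 180° spatial flip:

    # An all-equal kernel is flip-symmetric, so Conv and CrossCor agree on it.
    wsym = ones(2, 2, 1, 1)
    x = rand(Float32, 5, 5, 1, 1)
    CrossCor(wsym, [0.0])(x) ≈ Conv(wsym, [0.0])(x)  # true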
