Skip to content

Commit 161fa86

Browse files
committed
conflicts
1 parent 38b307b commit 161fa86

File tree

3 files changed

+13
-524
lines changed

3 files changed

+13
-524
lines changed

src/layers/conv.jl

Lines changed: 12 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
<<<<<<< HEAD
2-
using NNlib: conv, ∇conv_data, depthwiseconv, crossconv
3-
=======
4-
using NNlib: conv, depthwiseconv, crosscor
5-
>>>>>>> some final changes
1+
using NNlib: conv, ∇conv_data, depthwiseconv, crosscor
62

73
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
84

@@ -73,8 +69,6 @@ end
7369
"""
7470
ConvTranspose(size, in=>out)
7571
ConvTranspose(size, in=>out, relu)
76-
CrossCor(size, in=>out)
77-
CrossCor(size, in=>out, relu)
7872
7973
Standard convolutional transpose layer. `size` should be a tuple like `(2, 2)`.
8074
`in` and `out` specify the number of input and output channels respectively.
@@ -83,7 +77,6 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
8377
Takes the keyword arguments `pad`, `stride` and `dilation`.
8478
"""
8579
struct ConvTranspose{N,F,A,V}
86-
struct CrossCor{N,F,A,V}
8780
σ::F
8881
weight::A
8982
bias::V
@@ -173,8 +166,8 @@ function Base.show(io::IO, l::DepthwiseConv)
173166
end
174167

175168
"""
176-
CrossConv(size, in=>out)
177-
CrossConv(size, in=>out, relu)
169+
CrossCor(size, in=>out)
170+
CrossCor(size, in=>out, relu)
178171
179172
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
180173
`in` and `out` specify the number of input and output channels respectively.
@@ -197,8 +190,8 @@ CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
197190
stride = 1, pad = 0, dilation = 1) where {T,N} =
198191
CrossCor(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
199192

200-
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
201-
stride = 1, pad = 0, dilation = 1) where N =
193+
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
194+
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
202195
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
203196
stride = stride, pad = pad, dilation = dilation)
204197

@@ -218,6 +211,12 @@ function Base.show(io::IO, l::CrossCor)
218211
print(io, ")")
219212
end
220213

214+
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
215+
invoke(a, Tuple{AbstractArray}, x)
216+
217+
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
218+
a(T.(x))
219+
221220
"""
222221
MaxPool(k)
223222
@@ -260,48 +259,4 @@ MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N =
260259

261260
function Base.show(io::IO, m::MeanPool)
262261
print(io, "MeanPool(", m.k, ", pad = ", m.pad, ", stride = ", m.stride, ")")
263-
end
264-
<<<<<<< HEAD
265-
=======
266-
267-
"""
268-
CrossCor(size, in=>out)
269-
CrossCor(size, in=>out, relu)
270-
Standard cross convolutional layer. `size` should be a tuple like `(2, 2)`.
271-
`in` and `out` specify the number of input and output channels respectively.
272-
Data should be stored in WHCN order. In other words, a 100×100 RGB image would
273-
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
274-
Takes the keyword arguments `pad`, `stride` and `dilation`.
275-
"""
276-
struct CrossCor{N,F,A,V}
277-
σ::F
278-
weight::A
279-
bias::V
280-
stride::NTuple{N,Int}
281-
pad::NTuple{N,Int}
282-
dilation::NTuple{N,Int}
283-
end
284-
CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
285-
stride = 1, pad = 0, dilation = 1) where {T,N} =
286-
CrossCor(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
287-
288-
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
289-
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
290-
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
291-
stride = stride, pad = pad, dilation = dilation)
292-
293-
@treelike CrossCor
294-
295-
function (c::CrossCor)(x)
296-
# TODO: breaks gpu broadcast :(
297-
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
298-
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
299-
σ.(crosscor(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
300-
end
301-
function Base.show(io::IO, l::CrossCor)
302-
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
303-
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
304-
l.σ == identity || print(io, ", ", l.σ)
305-
print(io, ")")
306-
end
307-
>>>>>>> some final changes
262+
end

0 commit comments

Comments
 (0)