Skip to content

Commit 35246fa

Browse files
committed
some final changes
1 parent b64bd5a commit 35246fa

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/layers/conv.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib: conv, depthwiseconv
1+
using NNlib: conv, depthwiseconv, crosscor
22

33
@generated sub2(::Val{N}) where N = :(Val($(N-2)))
44

@@ -156,11 +156,14 @@ end
156156
CrossCor(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
157157
stride = 1, pad = 0, dilation = 1) where {T,N} =
158158
CrossCor(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)
159-
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
160-
stride = 1, pad = 0, dilation = 1) where N =
159+
160+
CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
161+
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
161162
CrossCor(param(init(k..., ch...)), param(zeros(ch[2])), σ,
162-
stride = stride, pad = pad, dilation = dilation)
163+
stride = stride, pad = pad, dilation = dilation)
164+
163165
@treelike CrossCor
166+
164167
function (c::CrossCor)(x)
165168
# TODO: breaks gpu broadcast :(
166169
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))

test/tracker.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,10 @@ end
186186
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
187187
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
188188

189-
<<<<<<< HEAD
190189
@test gradtest(crosscor, rand(10, 3, 2), randn(Float64,2, 3, 2))
191190
@test gradtest(crosscor, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
192191
@test gradtest(crosscor, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
193-
=======
194192
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 2, 3))
195-
>>>>>>> 30486f9c0394304649bbc6121bd391ef066966c3
196193

197194
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
198195
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))

0 commit comments

Comments
 (0)