Skip to content

Commit

Permalink
added stride for pooling in tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
tejank10 authored and MikeInnes committed Apr 15, 2018
1 parent 0ba5ce4 commit 5cc6813
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
end

_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)

maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_maxpool, x, k, pad)
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_maxpool, x, k, pad, stride)

back_(::typeof(_maxpool), y, Δ, x, k, pad) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))

_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)

meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
track(_meanpool, x, k, pad)
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_meanpool, x, k, pad, stride)

back_(::typeof(_meanpool), y, Δ, x, k, pad) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))

# Broadcasting

Expand Down

0 comments on commit 5cc6813

Please sign in to comment.