diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 35261abec2..5bffa7a152 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -261,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad) @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad)) end -_maxpool(x, k, pad) = maxpool(x, k; pad = pad) +_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride) -maxpool(x::TrackedArray, k; pad = map(_->0,k)) = - track(_maxpool, x, k, pad) +maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) = + track(_maxpool, x, k, pad, stride) -back_(::typeof(_maxpool), y, Δ, x, k, pad) = - back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad)) +back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) = + back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride)) -_meanpool(x, k, pad) = meanpool(x, k; pad = pad) +_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride) -meanpool(x::TrackedArray, k; pad = map(_->0,k)) = - track(_meanpool, x, k, pad) +meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) = + track(_meanpool, x, k, pad, stride) -back_(::typeof(_meanpool), y, Δ, x, k, pad) = - back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad)) +back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) = + back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride)) # Broadcasting