Skip to content

Commit 8f655de

Browse files
committed
Use similar() followed by fill!() for GPU compatibility
1 parent 95d7782 commit 8f655de

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/pooling.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ for backend in (Symbol(), :_direct, :_im2col)
110110
@timeit_debug to function $(Symbol("$(name)$(backend)"))(
111111
x::AbstractArray{xT,N},
112112
pdims::PoolDims; kwargs...) where {xT, N}
113-
y = zeros(xT, output_size(pdims)..., channels_out(pdims), size(x, N))
113+
y = similar(x, output_size(pdims)..., channels_out(pdims), size(x, N))
114+
fill!(y, xT(0))
114115
return $(Symbol("$(name)$(backend)!"))(y, x, pdims; kwargs...)
115116
end
116117

@@ -119,7 +120,8 @@ for backend in (Symbol(), :_direct, :_im2col)
119120
dy::AbstractArray{T,N}, y::AbstractArray{T,N},
120121
x::AbstractArray{T,N}, pdims::PoolDims;
121122
kwargs...) where {T, N}
122-
dx = zeros(T, input_size(pdims)..., channels_in(pdims), size(dy, N))
123+
dx = similar(x, input_size(pdims)..., channels_in(pdims), size(dy, N))
124+
fill!(dx, T(0))
123125
return $(Symbol("$(name)$(backend)!"))(dx, dy, y, x, pdims; kwargs...)
124126
end
125127
end

0 commit comments

Comments
 (0)