diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4578b75b5f..f852e781e8 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -29,7 +29,7 @@ Conv2D(k::NTuple{2,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; Flux.treelike(Conv2D) function (c::Conv2D)(x) - σ, b = c.σ, reshape(c.bias, 1, 1, :) + σ, b = c.σ, reshape(c.bias, 1, 1, :, 1) σ.(conv2d(x, c.weight, stride = c.stride, padding = c.pad) .+ b) end