Skip to content

Commit

Permalink
Merge pull request #2365 from FluxML/bc/convtrans-dims-typestable
Browse files Browse the repository at this point in the history
Restore type stability of `conv_transpose_dims`
  • Loading branch information
ToucheSir authored Dec 31, 2023
2 parents df468ba + 2fe393e commit 1af3f4d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,9 @@ end

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
calc_dim(xsz, wsz, stride, dilation, pad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2)
I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad)
C_in = size(c.weight)[end-1] * c.groups
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
Expand Down

0 comments on commit 1af3f4d

Please sign in to comment.