Skip to content

Commit

Permalink
assertion num channels compatible with groups
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Apr 1, 2022
1 parent 450cb2e commit cab5f26
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ julia> Flux.params(c1) |> length
"""
function Conv(w::AbstractArray{T,N}, b = true, σ = identity;
stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N}

@assert size(w, N) % groups == 0 "Output channel dimension must be divisible by groups."
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
Expand Down Expand Up @@ -155,6 +157,8 @@ distribution.
function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform, groups = 1) where N
cin, cout = ch
@assert cin % groups == 0 "Input channel dimension must be divisible by groups."
@assert cout % groups == 0 "Output channel dimension must be divisible by groups."
init(filter..., cin÷groups, cout)
end

Expand Down
4 changes: 4 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ end
c = Conv((3,4,5), 100 => 25, groups = 5)
@test size(c.weight) == (3,4,5, 20, 25)
@test size(c(ip)) == (8,8,8, 25, 2)

# Test that we cannot ask for non-integer multiplication factors
@test_throws AssertionError Conv((2, 2), 3=>10, groups=2)
@test_throws AssertionError Conv((2, 2), 2=>9, groups=2)
end
end

Expand Down

0 comments on commit cab5f26

Please sign in to comment.