From cab5f2621268138fb0e64603a948d6b5050d3d7c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 1 Apr 2022 08:49:23 +0200 Subject: [PATCH] assertion num channels compatible with groups --- src/layers/conv.jl | 4 ++++ test/layers/conv.jl | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f98f3a0d50..142a129f11 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -128,6 +128,8 @@ julia> Flux.params(c1) |> length """ function Conv(w::AbstractArray{T,N}, b = true, σ = identity; stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N} + + @assert size(w, N) % groups == 0 "Output channel dimension must be divisible by groups." stride = expand(Val(N-2), stride) dilation = expand(Val(N-2), dilation) pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride) @@ -155,6 +157,8 @@ distribution. function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init = glorot_uniform, groups = 1) where N cin, cout = ch + @assert cin % groups == 0 "Input channel dimension must be divisible by groups." + @assert cout % groups == 0 "Output channel dimension must be divisible by groups." init(filter..., cin÷groups, cout) end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index eb7d13be1c..019f3fd603 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -81,6 +81,10 @@ end c = Conv((3,4,5), 100 => 25, groups = 5) @test size(c.weight) == (3,4,5, 20, 25) @test size(c(ip)) == (8,8,8, 25, 2) + + # Test that we cannot ask for non-integer multiplication factors + @test_throws AssertionError Conv((2, 2), 3=>10, groups=2) + @test_throws AssertionError Conv((2, 2), 2=>9, groups=2) end end