show(::Chain) #1467

Merged Jul 10, 2021 · 25 commits · changes shown from all commits
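The headline feature — a multi-line `MIME"text/plain"` display for containers, with per-layer parameter counts — lives in the new `src/layers/show.jl` and is visible on this page only through doctest updates. A sketch of the resulting REPL output (the byte total assumes Float32 parameters plus per-array overhead, consistent with the `Parallel` doctest below; exact numbers may vary):

```julia
julia> using Flux

julia> Chain(Dense(10, 5, relu), Dense(5, 2))
Chain(
  Dense(10, 5, relu), # 55 parameters
  Dense(5, 2), # 12 parameters
) # Total: 4 arrays, 67 parameters, 524 bytes.
```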
2 changes: 1 addition & 1 deletion docs/src/utilities.md
@@ -28,7 +28,7 @@ To change the default on an applicable layer, pass the desired function with the

```jldoctest; setup = :(using Flux)
julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
Conv((3, 3), 1=>8, relu)
Conv((3, 3), 1 => 8, relu) # 80 parameters
```

```@docs
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -43,6 +43,7 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/show.jl")

include("outputsize.jl")

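Note the new `include("layers/show.jl")` above: the file itself is not among the hunks on this page. A minimal sketch of the mechanism it introduces — the `_pretty_show` helper below is a purely illustrative stand-in, not Flux's actual implementation:

```julia
# Illustrative sketch; the real src/layers/show.jl is not shown in this diff.
using Flux

function _pretty_show(io::IO, m::Flux.Chain)  # hypothetical helper name
  println(io, "Chain(")
  for layer in m.layers
    n = sum(length, Flux.params(layer); init = 0)  # parameters in this layer
    print(io, "  ", layer, ",")
    n > 0 && print(io, " # ", n, " parameters")
    println(io)
  end
  print(io, ")")
end

# Rich display at the REPL uses the multi-line form;
# print/string still go through the compact 2-arg show.
Base.show(io::IO, ::MIME"text/plain", m::Flux.Chain) = _pretty_show(io, m)
```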
3 changes: 1 addition & 2 deletions src/functor.jl
@@ -1,8 +1,7 @@
import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: @functor, functor, fmap
import Functors
import Functors: Functors, @functor, functor, fmap, isleaf

trainable(m) = functor(m)[1]

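`isleaf` joins the explicit imports here; presumably the new printing code uses it to decide when to stop recursing into a model's children (an assumption — its caller is in the new show file, not in this hunk). Its behaviour, for reference:

```julia
julia> using Flux, Functors

julia> Functors.isleaf(rand(Float32, 3)) # plain arrays are leaves
true

julia> Functors.isleaf(Dense(3, 2)) # layers have children to recurse into
false
```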
14 changes: 10 additions & 4 deletions src/layers/basic.jl
@@ -89,7 +89,7 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided
# Examples
```jldoctest
julia> d = Dense(5, 2)
Dense(5, 2)
Dense(5, 2) # 12 parameters

julia> d(rand(Float32, 5, 64)) |> size
(2, 64)
@@ -98,7 +98,7 @@ julia> d(rand(Float32, 5, 1, 1, 64)) |> size # treated as three batch dimensions
(2, 1, 1, 64)

julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix
Dense(5, 2, tanh; bias=false)
Dense(5, 2, tanh; bias=false) # 10 parameters

julia> d1(ones(5))
2-element Vector{Float64}:
@@ -395,7 +395,11 @@ julia> size(model(rand(3)))
(17,)

julia> model = Parallel(+, Dense(10, 2), Dense(5, 2))
Parallel(+, Dense(10, 2), Dense(5, 2))
Parallel(
  +,
  Dense(10, 2), # 22 parameters
  Dense(5, 2), # 12 parameters
) # Total: 4 arrays, 34 parameters, 392 bytes.

julia> size(model(rand(10), rand(5)))
(2,)
@@ -417,8 +421,10 @@ Parallel(connection, layers...) = Parallel(connection, layers)
Base.getindex(m::Parallel, i::Integer) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)

trainable(m::Parallel) = (m.connection, m.layers...)

function Base.show(io::IO, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
join(io, m.layers, ", ")
print(io, ")")
end
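The added `trainable(m::Parallel)` flattens the connection and the layers into one tuple, where the default `trainable(m) = functor(m)[1]` from src/functor.jl would return the two fields as a `NamedTuple`. A quick check of the new return value:

```julia
julia> using Flux

julia> m = Parallel(+, Dense(10, 2), Dense(5, 2));

julia> Flux.trainable(m)
(+, Dense(10, 2), Dense(5, 2))
```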
71 changes: 41 additions & 30 deletions src/layers/conv.jl
@@ -67,7 +67,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images

julia> lay = Conv((5,5), 3 => 7, relu; bias=false)
Conv((5, 5), 3=>7, relu)
Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters

julia> lay(xs) |> size
(96, 96, 7, 50)
@@ -98,7 +98,7 @@ end
Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation])

Constructs a convolutional layer with the given weight and bias.
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)`
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3 => 7, relu)`
method.

# Examples
@@ -108,7 +108,7 @@ julia> weight = rand(3, 4, 5);
julia> bias = zeros(5);

julia> c1 = Conv(weight, bias, sigmoid) # expects 1 spatial dimension
Conv((3,), 4=>5, σ)
Conv((3,), 4 => 5, σ) # 65 parameters

julia> c1(randn(100, 4, 64)) |> size
(98, 5, 64)
@@ -134,7 +134,7 @@ function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
end

"""
convfilter(filter::Tuple, in=>out)
convfilter(filter::Tuple, in => out)

Constructs a standard convolutional weight matrix with given `filter` and
channels from `in` to `out`.
@@ -159,11 +159,18 @@ end

function Base.show(io::IO, l::Conv)
print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
_print_conv_opt(io, l)
print(io, ")")
end

function _print_conv_opt(io::IO, l)
l.σ == identity || print(io, ", ", l.σ)
all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride))
all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation))
l.bias == Zeros() && print(io, ", bias=false")
end
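`_print_conv_opt` prints only non-default settings, in a fixed order (activation, pad, stride, dilation, bias), so the compact form stays short. For example, per the code above (the trailing parameter-count comment comes from the new top-level display, not from this 2-arg `show`):

```julia
julia> Conv((3, 3), 1 => 8, relu; stride=2, bias=false)
Conv((3, 3), 1 => 8, relu, stride=2, bias=false) # 72 parameters
```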

"""
ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
@@ -184,15 +191,15 @@ See also [`Conv`](@ref) for more detailed description of keywords.
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images

julia> lay = ConvTranspose((5,5), 3 => 7, relu)
ConvTranspose((5, 5), 3=>7, relu)
ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters

julia> lay(xs) |> size
(104, 104, 7, 50)

julia> ConvTranspose((5,5), 3=>7, stride=2)(xs) |> size
julia> ConvTranspose((5,5), 3 => 7, stride=2)(xs) |> size
(203, 203, 7, 50)

julia> ConvTranspose((5,5), 3=>7, stride=3, pad=SamePad())(xs) |> size
julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size
(300, 300, 7, 50)
```
"""
@@ -209,7 +216,7 @@ end
ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method.
Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method.
"""
function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -255,8 +262,8 @@ end

function Base.show(io::IO, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight, ndims(l.weight)), " => ", size(l.weight, ndims(l.weight)-1))
_print_conv_opt(io, l)
print(io, ")")
end

@@ -266,7 +273,7 @@ function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilat
end

"""
DepthwiseConv(filter, in=>out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])

Depthwise convolutional layer. `filter` is a tuple of integers
specifying the size of the convolutional kernel, while
@@ -284,7 +291,7 @@ See also [`Conv`](@ref) for more detailed description of keywords.
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images

julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false)
DepthwiseConv((5, 5), 3=>6, relu)
DepthwiseConv((5, 5), 3 => 6, relu, bias=false) # 150 parameters

julia> lay(xs) |> size
(96, 96, 6, 50)
@@ -306,7 +313,7 @@ end
DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method.
Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method.
"""
function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -327,7 +334,7 @@ end
@functor DepthwiseConv

"""
depthwiseconvfilter(filter::Tuple, in=>out)
depthwiseconvfilter(filter::Tuple, in => out)

Constructs a depthwise convolutional weight array defined by `filter` and channels
from `in` to `out`.
@@ -348,8 +355,8 @@ end

function Base.show(io::IO, l::DepthwiseConv)
print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end]))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[end-1:end]))
_print_conv_opt(io, l)
print(io, ")")
end

@@ -372,12 +379,12 @@ See also [`Conv`](@ref) for more detailed description of keywords.
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images

julia> lay = CrossCor((5,5), 3 => 6, relu; bias=false)
CrossCor((5, 5), 3=>6, relu)
CrossCor((5, 5), 3 => 6, relu, bias=false) # 450 parameters

julia> lay(xs) |> size
(96, 96, 6, 50)

julia> CrossCor((5,5), 3=>7, stride=3, pad=(2,0))(xs) |> size
julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size
(34, 32, 7, 50)
```
"""
@@ -394,7 +401,7 @@ end
CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method.
Accepts the same keywords as the `CrossCor((4,4), 3 => 7, relu)` method.
"""
function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -429,8 +436,8 @@ end

function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
_print_conv_opt(io, l)
print(io, ")")
end

@@ -529,8 +536,7 @@ See also [`MaxPool`](@ref), [`GlobalMeanPool`](@ref).
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50);

julia> m = Chain(Conv((3,3), 3=>7), GlobalMaxPool())
Chain(Conv((3, 3), 3=>7), GlobalMaxPool())
julia> m = Chain(Conv((3,3), 3 => 7), GlobalMaxPool());

julia> m(xs) |> size
(1, 1, 7, 50)
@@ -567,8 +573,7 @@ by performing mean pooling on the complete (w,h)-shaped feature maps.
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50);

julia> m = Chain(Conv((3,3), 3=>7), GlobalMeanPool())
Chain(Conv((3, 3), 3=>7), GlobalMeanPool())
julia> m = Chain(Conv((3,3), 3 => 7), GlobalMeanPool());

julia> m(xs) |> size
(1, 1, 7, 50)
@@ -611,8 +616,11 @@ See also [`Conv`](@ref), [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`Global
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50); # batch of 50 RGB images

julia> m = Chain(Conv((5, 5), 3=>7, pad=SamePad()), MaxPool((5, 5), pad=SamePad()))
Chain(Conv((5, 5), 3=>7), MaxPool((5, 5), pad=2))
julia> m = Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad()))
Chain(
  Conv((5, 5), 3 => 7, pad=2), # 532 parameters
  MaxPool((5, 5), pad=2),
)

julia> m[1](xs) |> size
(100, 100, 7, 50)
@@ -674,7 +682,10 @@ See also [`Conv`](@ref), [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref).
julia> xs = rand(Float32, 100, 100, 3, 50);

julia> m = Chain(Conv((5,5), 3 => 7), MeanPool((5,5), pad=SamePad()))
Chain(Conv((5, 5), 3=>7), MeanPool((5, 5), pad=2))
Chain(
  Conv((5, 5), 3 => 7), # 532 parameters
  MeanPool((5, 5), pad=2),
)

julia> m[1](xs) |> size
(96, 96, 7, 50)
5 changes: 3 additions & 2 deletions src/layers/normalise.jl
@@ -278,7 +278,7 @@ testmode!(m::BatchNorm, mode=true) =

function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(l.chs)")
l.λ == identity || print(io, ", $(l.λ)")
(l.λ == identity) || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
print(io, ")")
end
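Since this is the compact 2-arg `show`, `repr` exercises it directly; per the method above, only non-default settings appear:

```julia
julia> repr(BatchNorm(3, relu; affine=false))
"BatchNorm(3, relu, affine=false)"
```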
@@ -443,8 +443,9 @@ testmode!(m::GroupNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::GroupNorm)
# print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
print(io, "GroupNorm($(l.chs), $(l.G)")
l.λ == identity || print(io, ", $(l.λ)")
l.λ == identity || print(io, ", ", l.λ)
hasaffine(l) || print(io, ", affine=false")
print(io, ")")
end