diff --git a/src/layers/show.jl b/src/layers/show.jl index 0ae14dd9ee..36980527e5 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -14,13 +14,13 @@ for T in [ end function _big_show(io::IO, obj, indent::Int=0, name=nothing) - pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")") + pre, post = _show_pre_post(obj) children = _show_children(obj) if all(_show_leaflike, children) _layer_show(io, obj, indent, name) else - println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre) - if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers) + println(io, " "^indent, isnothing(name) ? "" : "$name = ", pre) + if obj isa Chain{<:NamedTuple} || obj isa NamedTuple # then we insert names -- can this be done more generically? for k in Base.keys(obj) _big_show(io, obj[k], indent+2, k) @@ -44,6 +44,11 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) end end +_show_pre_post(obj) = string(nameof(typeof(obj)), "("), ")" +_show_pre_post(::Chain{<:AbstractVector}) = "Chain([", "])" +_show_pre_post(::AbstractVector) = "[", "]" +_show_pre_post(::NamedTuple) = "(;", ")" + _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: # note the covariance of tuple, using <:T causes warning or error @@ -73,7 +78,7 @@ end function _layer_show(io::IO, layer, indent::Int=0, name=nothing) _str = isnothing(name) ? "" : "$name = " - str = _str * sprint(show, layer, context=io) + str = _str * _layer_string(io, layer) print(io, " "^indent, str, indent==0 ? "" : ",") if !isempty(params(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) @@ -88,6 +93,12 @@ color=:light_black) indent==0 || println(io) end +_layer_string(io::IO, layer) = sprint(show, layer, context=io) +# _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray +# _layer_string(::IO, a::AbstractArray) = Base.dims2string(size(a)) * " " * String(typeof(a).name.name) +# _layer_string(::IO, a::Array{T}) where T = Base.dims2string(size(a)) * " Array{$T}" +# _layer_string(::IO, a::AbstractArray{T}) where T = Base.dims2string(size(a)) * " AbstractArray{$T}" + function _big_finale(io::IO, m) ps = params(m) if length(ps) > 2 @@ -133,3 +144,43 @@ _any(f, x::Number) = f(x) # _any(f, x) = false _all(f, xs) = !_any(!f, xs) + +#= + +julia> struct Tmp2; x; y; end; Flux.@functor Tmp2 + +# Before, notice Array(), NamedTuple(), and values + +julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3)))) +Chain( + Tmp2( + Array( + Dense(2 => 3), # 9 parameters + [0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters + ), + NamedTuple( + 1:3, # 3 parameters + Dense(3 => 4), # 16 parameters + [0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters + ), + ), +) # Total: 7 arrays, 43 parameters, 644 bytes. + +# After, (; x=, y=, z=) and "3-element Array" + +julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3)))) +Chain( + Tmp2( + [ + Dense(2 => 3), # 9 parameters + 4×3 Adjoint, # 12 parameters + ], + (; + x = 3-element UnitRange, # 3 parameters + y = Dense(3 => 4), # 16 parameters + z = 3-element Array, # 3 parameters + ), + ), +) # Total: 7 arrays, 43 parameters, 644 bytes. + +=#