Skip to content

Commit

Permalink
some printing upgrades
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 13, 2023
1 parent 674ed1a commit 9d4acba
Showing 1 changed file with 55 additions and 4 deletions.
59 changes: 55 additions & 4 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
pre, post = _show_pre_post(obj)
children = _show_children(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
println(io, " "^indent, isnothing(name) ? "" : "$name = ", pre)
if obj isa Chain{<:NamedTuple} || obj isa NamedTuple
# then we insert names -- can this be done more generically?
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
Expand All @@ -44,6 +44,11 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end

_show_pre_post(obj) = string(nameof(typeof(obj)), "("), ")"
_show_pre_post(::Chain{<:AbstractVector}) = "Chain([", "])"
_show_pre_post(::AbstractVector) = "[", "]"
_show_pre_post(::NamedTuple) = "(;", ")"

_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:

# note the covariance of tuple, using <:T causes warning or error
Expand Down Expand Up @@ -73,7 +78,7 @@ end

function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * sprint(show, layer, context=io)
str = _str * _layer_string(io, layer)
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(params(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
Expand All @@ -88,6 +93,12 @@ color=:light_black)
indent==0 || println(io)
end

_layer_string(io::IO, layer) = sprint(show, layer, context=io)
# _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray
_layer_string(::IO, a::AbstractArray) = Base.dims2string(size(a)) * " " * String(typeof(a).name.name)
# _layer_string(::IO, a::Array{T}) where T = Base.dims2string(size(a)) * " Array{$T}"
# _layer_string(::IO, a::AbstractArray{T}) where T = Base.dims2string(size(a)) * " AbstractArray{$T}"

function _big_finale(io::IO, m)
ps = params(m)
if length(ps) > 2
Expand Down Expand Up @@ -133,3 +144,43 @@ _any(f, x::Number) = f(x)
# _any(f, x) = false

_all(f, xs) = !_any(!f, xs)

#=
julia> struct Tmp2; x; y; end; Flux.@functor Tmp2
# Before, notice Array(), NamedTuple(), and values
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
Chain(
Tmp2(
Array(
Dense(2 => 3), # 9 parameters
[0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters
),
NamedTuple(
1:3, # 3 parameters
Dense(3 => 4), # 16 parameters
[0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters
),
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
# After, (; x=, y=, z=) and "3-element Array"
julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3))))
Chain(
Tmp2(
[
Dense(2 => 3), # 9 parameters
4×3 Adjoint, # 12 parameters
],
(;
x = 3-element UnitRange, # 3 parameters
y = Dense(3 => 4), # 16 parameters
z = 3-element Array, # 3 parameters
),
),
) # Total: 7 arrays, 43 parameters, 644 bytes.
=#

0 comments on commit 9d4acba

Please sign in to comment.