From 93e1de7bfb9aa45b5a08d98ea82497d95d3be2fd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 25 Oct 2024 22:22:36 -0400 Subject: [PATCH] Some small printing upgrades (#2344) * some printing upgrades * print eltype too * move one line to solve order-of-loading issue * better fix * tests, and Fix1 --- Project.toml | 2 +- src/layers/show.jl | 73 +++++++++++++++++++++++++++++++++++++++++---- test/layers/show.jl | 8 ++++- 3 files changed, 76 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 18dafa1234..c347b9def2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.22" +version = "0.14.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/layers/show.jl b/src/layers/show.jl index a03ddf3754..29f1aaec19 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -20,16 +20,16 @@ function _macro_big_show(ex) end function _big_show(io::IO, obj, indent::Int=0, name=nothing) - pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")") + pre, post = _show_pre_post(obj) children = _show_children(obj) if all(_show_leaflike, children) # This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids, # but once all layers use @layer, they stop the recursion by defining a method for _big_show. _layer_show(io, obj, indent, name) else - println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre) - if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers) - # then we insert names -- can this be done more generically? + println(io, " "^indent, isnothing(name) ? "" : "$name = ", pre) + if obj isa Chain{<:NamedTuple} || obj isa NamedTuple + # then we insert names -- can this be done more generically? for k in Base.keys(obj) _big_show(io, obj[k], indent+2, k) end @@ -52,6 +52,20 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) end end +for Fix in (:Fix1, :Fix2) + pre = string(Fix, "(") + @eval function _big_show(io::IO, obj::Base.$Fix, indent::Int=0, name=nothing) + println(io, " "^indent, isnothing(name) ? "" : "$name = ", $pre) + _big_show(io, obj.f, indent+2) + _big_show(io, obj.x, indent+2) + println(io, " "^indent, ")", ",") + end +end + +_show_pre_post(obj) = string(nameof(typeof(obj)), "("), ")" +_show_pre_post(::AbstractVector) = "[", "]" +_show_pre_post(::NamedTuple) = "(;", ")" + _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: # note the covariance of tuple, using <:T causes warning or error @@ -88,7 +102,7 @@ end function _layer_show(io::IO, layer, indent::Int=0, name=nothing) _str = isnothing(name) ? "" : "$name = " - str = _str * sprint(show, layer, context=io) + str = _str * _layer_string(io, layer) print(io, " "^indent, str, indent==0 ? "" : ",") if !isempty(params(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) @@ -103,6 +117,15 @@ color=:light_black) indent==0 || println(io) end +_layer_string(io::IO, layer) = sprint(show, layer, context=io) +# _layer_string(::IO, a::AbstractArray) = summary(layer) # sometimes too long e.g. CuArray +function _layer_string(::IO, a::AbstractArray) + full = string(typeof(a)) + comma = findfirst(',', full) + short = isnothing(comma) ? full : full[1:comma] * "...}" + Base.dims2string(size(a)) * " " * short +end + function _big_finale(io::IO, m) ps = params(m) if length(ps) > 2 @@ -150,3 +173,43 @@ _any(f, x::Number) = f(x) # _any(f, x) = false _all(f, xs) = !_any(!f, xs) + +#= + +julia> struct Tmp2; x; y; end; Flux.@functor Tmp2 + +# Before, notice Array(), NamedTuple(), and values + +julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3)))) +Chain( + Tmp2( + Array( + Dense(2 => 3), # 9 parameters + [0.351978391016603 0.6408681372462821 -1.326533184688648; 0.09481930831795712 1.430103476272605 0.7250467613675332; 2.03372151428719 -0.015879812799495713 1.9499692162118236; -1.6346846180722918 -0.8364610153059454 -1.2907265737483433], # 12 parameters + ), + NamedTuple( + 1:3, # 3 parameters + Dense(3 => 4), # 16 parameters + [0.9666158193429335, 0.01613900990539574, 0.0205920186127464], # 3 parameters + ), + ), +) # Total: 7 arrays, 43 parameters, 644 bytes. + +# After, (; x=, y=, z=) and "3-element Array" + +julia> Chain(Tmp2([Dense(2,3), randn(3,4)'], (x=1:3, y=Dense(3,4), z=rand(3)))) +Chain( + Tmp2( + [ + Dense(2 => 3), # 9 parameters + 4×3 Adjoint, # 12 parameters + ], + (; + x = 3-element UnitRange, # 3 parameters + y = Dense(3 => 4), # 16 parameters + z = 3-element Array, # 3 parameters + ), + ), +) # Total: 7 arrays, 43 parameters, 644 bytes. + +=# diff --git a/test/layers/show.jl b/test/layers/show.jl index 95ddca0571..375303c7ec 100644 --- a/test/layers/show.jl +++ b/test/layers/show.jl @@ -71,7 +71,13 @@ end # Functors@0.3 marks transposed matrices non-leaf, shouldn't affect printing: adjoint_chain = repr("text/plain", Chain([Dense([1 2; 3 4]')])) @test occursin("Dense(2 => 2)", adjoint_chain) - @test occursin("Chain([", adjoint_chain) + @test occursin("Chain(", adjoint_chain) + @test occursin("[", adjoint_chain) + + # New printing of arrays, and Fix1 + fix_chain = repr("text/plain", Chain(Base.Fix1(*, rand32(22,33)), softmax)) + @test occursin("Fix1(", fix_chain) + @test occursin("22×33 Matrix{Float32}", fix_chain) end # Bug when no children, https://github.com/FluxML/Flux.jl/issues/2208