Simplify @compact printing #20

Draft
wants to merge 3 commits into
base: master

mcabbott (Member) commented Aug 29, 2023

This simplifies @compact so that it no longer stores a string for each keyword argument. Like real keyword arguments, the expressions are evaluated once, and the resulting NamedTuple is all that remains.
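To illustrate the "run once, keep only the NamedTuple" idea, here is a minimal sketch in plain Julia. It is not the Fluxperimental implementation: `ToyCompact` and `toycompact` are hypothetical names invented for this example, and the toy passes the variables to the function explicitly instead of using a macro.

```julia
# Hypothetical stand-in for a @compact-style layer: it stores only the
# already-evaluated keyword values, never the source text that produced them.
struct ToyCompact{F,NT<:NamedTuple}
    fun::F
    variables::NT
end

# Keyword expressions are evaluated at the call site, exactly once;
# `(; kw...)` collects the resulting values into a NamedTuple.
toycompact(fun; kw...) = ToyCompact(fun, (; kw...))

# Calling the layer hands the stored variables to the wrapped function.
(m::ToyCompact)(x) = m.fun(m.variables, x)

layer = let n = 4
    toycompact(w = ones(Float32, 2, n) ./ 2, pow = 1 + 1) do vars, x
        (vars.w * x) .^ vars.pow
    end
end

layer.variables.pow          # 2, the value; the expression `1 + 1` is gone
layer(ones(Float32, 4))      # each row of w sums to 2, squared gives 4
```

Since only values survive, anything printed about `w` or `pow` has to be derived from those values, which is why the constructor string can no longer be shown.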

In place of the stored string, this PR proposes to print something like summary(w) for any array. This may need to be trimmed a bit, since for CuArray etc. it can get quite long. It could be trimmed all the way to 32×26 AbstractMatrix{Float32}. This change should probably be made in Flux anyway (along with other fixes to how show handles arrays), so it is just prototyped piratically here for now. [Edit: now https://github.com/FluxML/Flux.jl/pull/2344]
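As a rough sketch of the two printing options mentioned above, the snippet below compares Base's `summary` with a trimmed variant. `short_summary` is a hypothetical helper written for this example; it is not part of Flux or Fluxperimental, and the real change may look different.

```julia
w = randn(Float32, 32, 26)

summary(w)   # "32×26 Matrix{Float32}" on Julia 1.6 or later

# Hypothetical trimmed summary: keeps the size and element type but replaces
# the concrete container type with the matching abstract alias, so that long
# wrapper or GPU array type names do not appear in the printout.
function short_summary(x::AbstractArray{T,N}) where {T,N}
    N == 1 && return string(length(x), "-element AbstractVector{", T, "}")
    dims = join(size(x), "×")
    N == 2 && return string(dims, " AbstractMatrix{", T, "}")
    return string(dims, " AbstractArray{", T, ", ", N, "}")
end

short_summary(w)               # "32×26 AbstractMatrix{Float32}"
short_summary(view(w, :, 1))   # "32-element AbstractVector{Float32}"
```

The trimmed form loses the information about where the array lives (CPU or GPU), which is the trade-off behind the "maybe this needs to be trimmed" remark.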

julia> using Flux; using Fluxperimental: @compact

julia> let n = 26, d = 32  # previously the constructor string was printed, although it cannot run again
         Chain(
           @compact(w=randn32(d, n)./=sqrt(n), pow=1+1) do (x, y)
             @views (w[:,x] .+ w[:,y]).^pow
           end,
           Dense(d => 3),
         )
       end
Chain(
  @compact(
    w = 32×26 Matrix{Float32},          # 832 parameters
    pow = 2,
  ) do (x, y) 
      #= REPL[54]:4 =# @views (w[:, x] .+ w[:, y]) .^ pow
  end,
  Dense(32 => 3),                       # 99 parameters
)                   # Total: 3 arrays, 931 parameters, 4.011 KiB.

julia> ans((1,2))
3-element Vector{Float32}:
  0.30514276
 -0.24690717
 -0.26537845

julia> let n = 3, m = 4  # Scale previously triggered the "non-layer" printing path. Note captured m still printed.
         @compact(s = Flux.Scale(max(2n, m))) do x
           s(x ./ m)
         end
       end
@compact(
  s = Scale(6),                         # 12 parameters
) do x 
    s(x ./ m)
end
