Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding summarise_array to show.jl #2593

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

murrellb
Copy link

@murrellb murrellb commented Mar 1, 2025

I'm not sure if you'll want this, and maybe there is a better way to do what this does, but:

This PR adds two lines that, if I haven't missed anything, should make no difference to the current behavior of Base.show. But it allows the user some customization of the model/layer display.

The particular problem this solves for me is "being able to easily inspect summary stats of the model weights at the individual tensor level". With the change in this PR, I can define eg. this:

Flux.summarise_array(a) = "; L=$(length(a)):σ=$(round(std(a), digits = 3))"

and then when I display my model in the REPL, I can see the details I wanted for each tensor in each layer, in the context of the model structure:
image

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@mcabbott
Copy link
Member

mcabbott commented Mar 3, 2025

Something like this might be useful, thanks for digging into the somewhat messy show code...

At present _nan_show(io, trainables(layer)) (a few lines down) aims to warn you about Inf/NaN, and all-zero. But maybe just always printing something about the values would be better? σ = NaN would convey almost the same information. Is std the best one number, or would say norm(W) be better?

I'm a little reluctant to add more functions we encourage people to overload.

I wonder if it should only do the first (say) 5 parameter arrays, so that you will never get 10 lines of noise.

@murrellb
Copy link
Author

murrellb commented Mar 3, 2025

If I were to pick one number it would be std, but std is NaN when taken over a single value, and if you aren't aware of this you might go looking for bugs that aren't there. So maybe the "population" std (which divides by N instead of N-1)? norm is also a defensible choice.

And yes, I think standardly reporting something like this is good. But if you do it, please do it in a way that is easy to overload, (even if we don't encourage people to do so)? Sometimes I care about the std, but sometimes also the mean, the extrema, the number of params < 0, etc.

One option would be to support this with a more explicit call where the user passes in the array function directly, like summarise(model, array_printing_function = f) and have show call that with the default? I'm not sure I'd be able to figure that out though...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants