
show(::Chain) #1467

Merged: 25 commits merged into FluxML:master on Jul 10, 2021
Conversation

@mcabbott (Member) commented Jan 15, 2021

[Screenshot: the proposed multi-line Chain printout with per-layer parameter counts, 2021-01-15]

Not quite done, RFC?

@ToucheSir (Member)

👍 for information density.

Is there perhaps a more general implementation lurking in here though? For example, I would've expected the first Chain in the second SkipConnection to be expanded and all parameters therein counted. Being able to operate on all non-leaf functors (e.g. those with sublayers) instead of special-casing Union{Chain, Parallel, SkipConnection} would also help justify the additional implementation complexity.
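
(For illustration, a minimal sketch of the "non-leaf functor" test being suggested here; the helper name is made up, and it assumes the Functors behaviour of the time, where plain arrays and functions count as leaves.)

using Flux, Functors

# Hypothetical helper: treat x as container-like if any of its functored
# children is itself a non-leaf, i.e. it has sublayers of its own.
has_sublayers(x) = any(!Functors.isleaf, Functors.children(x))

# has_sublayers(Dense(2, 3))                        # false: children are arrays + activation
# has_sublayers(SkipConnection(Dense(2, 2), +))     # true: one child is a Dense
# has_sublayers(Chain(Dense(2, 2), Dense(2, 2)))    # true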

@mcabbott (Member Author) commented Jan 15, 2021

My resnet here is from FluxML/model-zoo#221, and it was missing an @functor Combinator, so params(m[7].connection) came back empty. After that (and fixing bugs...) those are printed with # 74_112 parameters -- I have updated the picture.

They are still printed on one line, because it's not smart enough to recurse into unknown structs. But I agree it ought to be possible to expand everything, by the same functor magic?
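
(For context, the fix mentioned above is just the usual @functor opt-in. A rough sketch follows; the fields and forward pass are made up, not the real model-zoo Combinator.)

using Flux

struct Combinator
    conv   # hypothetical fields -- see FluxML/model-zoo#221 for the real layer
    norm
end
Flux.@functor Combinator   # without this line, params(::Combinator) comes back empty

(c::Combinator)(x, y) = c.norm(c.conv(x) .+ y)   # hypothetical forward pass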

@mcabbott (Member Author) commented Jan 15, 2021

A more recursive version:

[Screenshot: the more recursive version, 2021-01-16]

One catch is that, if this wants to be copy-paste-able, it assumes that the order of children of a layer is the order of arguments of its constructor. Maybe that's OK, maybe Functors even assumes the same thing? (If there are any anonymous functions in the chain, then it won't work when copy-pasted anyway.)

Variant without wasting 3 lines for ) ) ):
[Screenshot: the variant without the closing-bracket lines, 2021-01-15]

@CarloLucibello (Member)

This looks very nice. Probably the show methods for Grads and Params should go to Zygote to avoid piracy.
I wonder if we should provide a general describe (or something similar, summarize?) method that works also on custom layers, and have show fallback to that for the layers we define in Flux.

@mcabbott (Member Author)

My first picture above expanded only things from inside Flux, but I now think it can probably guess well enough to just work elsewhere, although I'd welcome weird examples to test it against. You could certainly break it if you wished to, but then your custom container-like layer will print badly, which isn't the end of the world, and this only happens when it is printed inside a Chain etc. from Flux. It can also be corrected by overloading _big_show for your type.

It could all be tied to some other function instead of show, but then it will rarely get used. The only downside is that you have to scroll a bit after julia> m = resnet(50) if you forget the semicolon. (On the other hand, you currently have to scroll a lot after params(m) if you forget it.)

Maybe the other caveat is that scanning arrays for Inf/NaN might be too slow. Anyone have a huge model to try that out against?

Could certainly move Params / Grads methods over to Zygote, once settled down.

@CarloLucibello (Member)

> It could all be tied to some other function instead of show, but then it will rarely get used.

If you are suggesting to overload show for generic types, that clearly cannot be done. If we want to support printing a custom Autoencoder struct, I see no way around defining a describe function.
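
(Roughly what that might look like, as a sketch only: describe is hypothetical, and _big_show is this PR's internal printer, whose exact signature may differ.)

using Flux

describe(io::IO, model) = Flux._big_show(io, model)   # generic recursive printing, any functored model
describe(model) = describe(stdout, model)

# Flux-owned layers would route their REPL display through it ...
Base.show(io::IO, ::MIME"text/plain", c::Chain) = describe(io, c)
# ... and a user's custom Autoencoder could opt in with the same one-liner.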

@DhairyaLGandhi (Member)

Not sure we want to do a very complicated show method for Params; it's not much different from a summary of the elements of the vector.

@mcabbott (Member Author)

I hope you guys don't think I'm proposing to overload show(::Any)! Instead, the 3-arg show of Chain is the entry point to the recursive printing, which uses Functors:

  • If all children of x are leaf types (or things like tuples of arrays) then it concludes x is a layer and prints it as usual.
  • Otherwise, it concludes it's a container (like Chain, SkipConnection -- these work without special casing) and it prints the name before proceeding to each of trainable(x).

So if you define some custom type, you will never see this unless you put it inside Chain etc. (Unless you explicitly opt in by defining show(..., ::Autoencoder) = Flux._big_show(...).) And then, only if your type has non-leaf children will it be printed some way other than what you (or Base) defines. Only if, in addition, your type constructor does not accept trainable(x)... as arguments, will the printing not be something you could copy-paste and run. (Which will not affect how it functions, of course, only the printing. And since you are targeting Flux, you will probably notice, and can easily correct this. And printing isn't guaranteed to be runnable, try x->x^2.)
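
(Spelled out, that explicit opt-in would look something like the sketch below; Autoencoder is a stand-in, and the exact _big_show signature is an internal detail of this PR which may differ.)

using Flux

struct Autoencoder
    encoder
    decoder
end
Flux.@functor Autoencoder

# Opt in to the recursive tree printing at the REPL:
Base.show(io::IO, ::MIME"text/plain", m::Autoencoder) = Flux._big_show(io, m)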

My initial idea (again) was to expand only the Flux-owned containers (like Chain, SkipConnection). But the generic version seems simpler, and the worst-case scenario isn't all that scary. I've been trying to find weird models to test this against, and so far (for instance the resnet linked above) it seems to work just fine. Anyone have good links to things to try?

Am not so sure what we should do for Params. The current printing is 100 pages of stuff you can't learn much from; surely there are many directions in which you could improve it. Ditto Grads(...), which prints next to nothing.
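
(One possible direction for Params, purely as a sketch of what "improve" could mean: a one-line summary per array rather than dumping the contents. Exact output depends on Julia/Flux versions.)

julia> using Flux

julia> p = Flux.params(Chain(Dense(784, 64), Dense(64, 10)));

julia> summary.(collect(p))
4-element Vector{String}:
 "64×784 Matrix{Float32}"
 "64-element Vector{Float32}"
 "10×64 Matrix{Float32}"
 "10-element Vector{Float32}"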

@DhairyaLGandhi (Member)

I feel like the show function is way too large and complex to be considered "simple". I would also recommend removing the changes made to other layers; that seems out of scope for this PR.

@DhairyaLGandhi (Member)

I also don't think custom layers or whatever should have to add more code to opt out of these assumptions.

@mcabbott (Member Author)

> changes made to other layers. Seems out of scope

BatchNorm and GroupNorm were for some reason printing as if they took a keyword argument, which they don't. I'm not sure how that got in; it seems like a minor bug-fix. I can move it to a different PR if you feel strongly about it.

@darsnack (Member) commented Feb 7, 2021

I'm curious: what remains for this to go from "draft" to "PR"? This seems like a super convenient utility to have.

@mcabbott (Member Author) commented Feb 7, 2021

One concern is whether doing any(isnan, x) etc. recursively on big models is going to be too slow? If it is, it could be removed.
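
(A quick way to gauge that cost for anyone with a big model handy; only a rough sketch, and the numbers are of course machine-dependent.)

julia> using Flux, BenchmarkTools

julia> m = Chain([Dense(1024, 1024, relu) for _ in 1:50]...);   # roughly 52M parameters

julia> @btime any(p -> any(!isfinite, p), Flux.params($m));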

There are also a lot of broken doctests; I can go through and fix them, but I'd prefer to do it only once (unless someone knows an automated way).

@DhairyaLGandhi (Member) commented Feb 7, 2021

This seems like it doesn't really add much over printing layers on new lines. I don't think the big_show/compact subsystem (doesn't Julia already have one of those? pretty sure it does) is the right kind of abstraction, since individual composite layers (Parallel, SkipConnection) might want to define how they are represented, and most standard layers are pretty compact as it is. Could we simplify the function substantially by removing that?

underscorise(n::Integer) =
join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')

function _nan_show(io::IO, x)
Review comment (Member) on the diff above:

Why is this needed? Layers typically don't show their arrays, and custom layers should define their own show; I don't want to take control of how they show their params.
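
(As an aside, worked by hand from the underscorise definition quoted above -- it just groups digits in threes:)

julia> underscorise(51018)
"51_018"

julia> underscorise(74112)
"74_112"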

@darsnack (Member)

More love for this PR becoming a reality:

julia> m = resnet50()
Chain(Conv((7, 7), 3=>64), BatchNorm(64, λ = relu), MaxPool((3, 3), pad=1, stride=2), Parallel(+, Chain(Conv((1, 1), 64=>64), BatchNorm(64, λ = relu), Conv((3, 3), 64=>64), BatchNorm(64, λ = relu), Conv((1, 1), 64=>256), BatchNorm(256, λ = relu)), Chain(Conv((1, 1), 64=>256), BatchNorm(256, λ = relu))), Parallel(+, Chain(Conv((1, 1), 256=>64), BatchNorm(64, λ = relu), Conv((3, 3), 64=>64), BatchNorm(64, λ = relu), Conv((1, 1), 64=>256), BatchNorm(256, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 256=>64), BatchNorm(64, λ = relu), Conv((3, 3), 64=>64), BatchNorm(64, λ = relu), Conv((1, 1), 64=>256), BatchNorm(256, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 256=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), Chain(Conv((1, 1), 256=>512), BatchNorm(512, λ = relu))), Parallel(+, Chain(Conv((1, 1), 512=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 512=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 512=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 512=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), Chain(Conv((1, 1), 512=>1024), BatchNorm(1024, λ = relu))), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>512), BatchNorm(512, λ = relu), Conv((3, 3), 512=>512), BatchNorm(512, λ = relu), Conv((1, 1), 512=>2048), BatchNorm(2048, λ = relu)), Chain(Conv((1, 1), 1024=>2048), BatchNorm(2048, λ = relu))), Parallel(+, Chain(Conv((1, 1), 2048=>512), BatchNorm(512, λ = relu), Conv((3, 3), 512=>512), BatchNorm(512, λ = relu), Conv((1, 1), 512=>2048), BatchNorm(2048, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 2048=>512), BatchNorm(512, λ = relu), Conv((3, 3), 512=>512), BatchNorm(512, λ = relu), Conv((1, 1), 512=>2048), BatchNorm(2048, λ = relu)), identity), AdaptiveMeanPool((1, 1)), flatten, Dense(2048, 1000))

[Screenshot: the same model printed with this PR, 2021-02-22]

@CarloLucibello (Member)

This is very useful indeed. I'd like to have a display function that handles generic models as well, but this is definitely a great start, and I'd love to see it merged soon.

@ToucheSir (Member)

Would it be possible to split out a PR that just does pretty-printing of generic models? That appears to be the most welcomed and least controversial part of this.

@mcabbott force-pushed the show2 branch 2 times, most recently from 2a1b189 to 9ad56f4, on May 20, 2021.
@mcabbott (Member Author)

I still think we should do this, FWIW. Once you decide to make Chain et al. "unwrap" and print layers on different lines, then you may as well (1) use the space to the right to print potentially useful information, and (2) make it generic, since everything uses Functors & this is so easy. Further, I think (3) it really should be the default show; why print something awful most of the time unless you know some magic command? Also, it's much more likely to rot if it's hidden away somewhere.

@logankilpatrick's question at https://stackoverflow.com/questions/68143133/model-summary-in-flux-jl links to https://github.com/sksq96/pytorch-summary which is a more elaborate display, including sizes. We could build something like that using outputsize. However, I think it's going to require you to supply the input size, so it can't be the default printing. Ideally they would share code & aesthetics.
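
(Flux.outputsize already covers the size part; a sketch of how the two ideas could meet, where showsize is only a hypothetical name:)

julia> using Flux

julia> m = Chain(Dense(784, 64, relu), Dense(64, 10));

julia> Flux.outputsize(m, (784, 1))   # needs the input size, hence not usable as the default show
(10, 1)

# A showsize(m, inputsize) could thread such sizes through the same tree
# printing, adding an output-size column per layer.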

@mcabbott marked this pull request as ready for review on June 30, 2021.
@logankilpatrick (Member)

I am not sure if one of Flux's design goals is to have parity with PyTorch, but if that is a goal, we should have a summary function.

@ToucheSir (Member) commented Jul 2, 2021

> why print something awful most of the time unless you know some magic command?

I agree, but as-is we don't have a good mechanism for actually getting the parameters that should be counted for a given layer. params is incorrect because trainable only considers params that can get gradients (thus undercounting for BatchNorm and co). Using functor or children might work, but it may also overcount; e.g., looking again at BatchNorm, do we want ϵ and momentum to be counted?

Hence my recommendation to get the tree printing out the door first. People can start benefiting from it immediately while we work out the edge cases of the parameter display (e.g. can we remove "parameters" from each line and put it as a top-level header comment? Lots of bikeshedding potential).

@mcabbott (Member Author) commented Jul 2, 2021

BatchNorm is weird. But perhaps we should just pick some number for it & be done with it? If someone has Python handy, we could see what PyTorch's summary prints.

I'd completely forgotten, but the current state uses params; some options that would be easy to make recursive automatically are:

julia> b = BatchNorm(10);  # affine=true

julia> sum(length, params(b))
20

julia> sum(length, trainable(b))
20

julia> sum(x -> x isa Numeric ? length(x) : 0, functor(b)[1])
45

julia> sum(x -> x isa AbstractArray ? length(x) : 0, functor(b)[1])
40

If necessary we could add some overloadable function for this purpose, and special-case BatchNorm. (And friends?) But it would be nice to avoid that at least on the first go.
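
(If it ever did become necessary, the overloadable hook could be as small as the sketch below -- entirely hypothetical, not part of this PR, and assuming affine=true so that β and γ exist.)

using Flux

_paramcount(layer) = sum(length, Flux.params(layer); init = 0)
_paramcount(b::BatchNorm) = length(b.β) + length(b.γ)   # count only the trainable pair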

@ToucheSir (Member) commented Jul 2, 2021

PyTorch punts on the question because it doesn't have any summary functionality built in, but torchinfo calls Module.named_parameters, which functions the same as trainable for BatchNorm. I still think that behaviour is wrong (the buffers are still hanging around and using up precious device memory, after all), but I guess it's an argument for not holding up the parameter counting part of this PR.

@mcabbott (Member Author) commented Jul 2, 2021

Maybe that's an argument for wanting a column for size in memory, in addition to counting params? Although perhaps that can be part of the showsize(m, size) function which you get by crossing this with outputsize, which will surely need a slightly more elaborate display to squeeze the sizes in.

julia> rsizeof(x) = isempty(fieldnames(typeof(x))) ? sizeof(x) : 
         sizeof(x) + sum(rsizeof(getfield(x,n)) for n in fieldnames(typeof(x)));

julia> Base.format_bytes(rsizeof(b))
"234 bytes"

julia> Base.format_bytes(rsizeof(Tuple(params(b))))
"96 bytes"

(Is there a built-in rsizeof BTW? This PR's last-line summary should probably print this rather than what it now has, from sum(sizeof, params(m)).)

@ToucheSir (Member) commented Jul 2, 2021

How (tf.)Keras and PyTorch Lightning handle this is by reporting separate totals for trainable and non-trainable parameters. Lightning reports the same as torchinfo because it uses .parameters **, but TF manages to capture both:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(1,)),
  tf.keras.layers.BatchNormalization()
])
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_4 (Flatten)          (None, 1)                 0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1)                 4         
=================================================================
Total params: 4
Trainable params: 2
Non-trainable params: 2
_________________________________________________________________

I think our equivalent would be using functor vs using trainable then, perhaps with exclusions for anything that isn't an AbstractArray.
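
(Concretely, with the numbers from the options listed earlier; a sketch only, since field sets and trainable definitions vary between versions:)

julia> using Flux, Functors

julia> b = BatchNorm(10);

julia> arrays(xs) = filter(x -> x isa AbstractArray, collect(xs));

julia> n_trainable = sum(length, arrays(Flux.trainable(b)))
20

julia> n_total = sum(length, arrays(Functors.functor(b)[1]))
40

julia> n_total - n_trainable   # the non-trainable μ and σ² buffers
20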

** PyTorch does have a mechanism for registering non-trainable parameters (register_buffer), but it's not picked up by parameters (which works just like Flux.params AFAICT). For a more accurate count, PyTorch Lightning should probably be using state_dict.

@mcabbott (Member Author) commented Jul 3, 2021

Thanks for digging!

If we copy that, then separately summarising params and (functor minus params) at the end should be easy. Perhaps we need to bikeshed a nice way to put that on a few lines. If we stick with the "comments, in words" look, then perhaps just something like this:

[Screenshot: proposed end-of-printout summary lines, 2021-07-02]

@ToucheSir (Member)

How about including non-trainable params in the per-layer counts too? The word "parameters" could be replaced with a header comment.

julia> m
Chain(                                  # trainable / non-trainable parameters
  Dense(784, 64),                       # 50_240 / 0
  BatchNorm(64, relu),                  # 128 / 128
  Dense(64, 10),                        # 650 / 0
  BatchNorm(10),                        # 20 / 20
  NNlib.softmax,
)                   # Total: 8 trainable arrays, 51_038 parameters,
                    # and 4 non-trainable, 148 parameters, total 200.648 KiB

@mcabbott (Member Author) commented Jul 3, 2021

Could do. Or just tack them on when necessary like ", plus 128 non-trainable" for the BatchNorm layers. The headings will often be out of view in e.g. my resnet example. Are there things more exotic than BatchNorm where it might highlight surprises?

[Screenshot: the per-layer ", plus N non-trainable" variant, 2021-07-03]

@ToucheSir (Member)

BatchNorm, InstanceNorm, and GroupNorm are the only examples in Flux, but one example I found from PyTorch is weighted losses.

CarloLucibello previously approved these changes Jul 3, 2021

@CarloLucibello (Member)

bors r+

@bors bot (Contributor) commented Jul 10, 2021

Build succeeded.

bors bot merged commit 87a0065 into FluxML:master on Jul 10, 2021.
@mcabbott deleted the show2 branch on July 10, 2021.