show(::Chain) #1467
Conversation
👍 for information density. Is there perhaps a more general implementation lurking in here though? For example, I would've expected the first
My resnet here is from FluxML/model-zoo#221, and it was missing an

They are still printed on one line, because it's not smart enough to recurse into unknown structs. But I agree it ought to be possible to expand everything, by the same functor magic?
A more recursive version:

One catch is that, if this wants to be copy-paste-able, it assumes that the order of children of a layer is the order of arguments of its constructor. Maybe that's OK; maybe Functors even assumes the same thing? (If there are any anonymous functions in the chain, then it won't work when copy-pasted anyway.)
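For what it's worth, here is a minimal sketch of that recursive idea, assuming the container types are listed explicitly (the helper names `_childlayers` and `tree_show` are invented here, not the PR's actual implementation):

```julia
using Flux

# Hypothetical sketch: expand container-like layers one child per line,
# fall back to the ordinary compact `show` for everything else.
_childlayers(c::Chain) = c.layers
_childlayers(p::Parallel) = (p.connection, p.layers...)
_childlayers(s::SkipConnection) = (s.layers, s.connection)
_childlayers(x) = ()

function tree_show(io::IO, x, indent::Int = 0)
    kids = _childlayers(x)
    if isempty(kids)
        println(io, " "^indent, x, ",")          # leaf layer: one compact line
    else
        println(io, " "^indent, nameof(typeof(x)), "(")
        foreach(k -> tree_show(io, k, indent + 2), kids)
        println(io, " "^indent, ")", indent == 0 ? "" : ",")
    end
end

tree_show(stdout, Chain(Dense(2, 3, relu), SkipConnection(Dense(3, 3), +), Dense(3, 1)))
```

The copy-paste concern above is visible here too: the output only reconstructs the model if each constructor accepts its children in exactly this order.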
This looks very nice. Probably the show methods for Grads and Params should go to Zygote to avoid piracy.
My first picture above expanded only things from inside Flux, but I now think it can probably guess well enough to just work elsewhere, although I'd welcome weird examples to test it against. You could certainly break it if you wished to, but then your custom container-like layer will print badly, which isn't the end of the world, and this will only happen when printed inside

It could all be tied to some other function instead of

Maybe the other caveat is that scanning arrays for

Could certainly move
If you are suggesting to overload
Not sure we want to do a very complicated show method for Params; it's not much different from a summary of the elements of the vector.
I hope you guys don't think I'm proposing to overload
So if you define some custom type, you will never see this unless you put it inside

My initial idea (again) was to expand only the Flux-owned containers (like

I am not so sure what we should do for Params. The current printing is 100 pages of stuff you can't learn much from; surely there are many directions in which you could improve it. Ditto
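As one concrete illustration of the "summary" direction for Params (purely a sketch, not what the PR implements):

```julia
using Flux

# One possible terse display: a one-line summary instead of dumping every array.
ps = Flux.params(Chain(Dense(10, 32, relu), Dense(32, 2)))
println("Params(", length(ps), " arrays, ", sum(length, ps), " parameters)")
# prints: Params(4 arrays, 418 parameters)
```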
I feel like the show function is way too large and complex to be considered "simple". I would also recommend removing the changes made to other layers; they seem out of scope for this PR.
I also don't think custom layers or whatever should have to add more code to opt out of these assumptions.
BatchNorm and GroupNorm were for some reason printing as if they took a keyword argument, which they don't. I'm not sure how that got in; it seems like a minor bug-fix. I can move that to a different PR if you feel strongly about it.
I'm curious what remains for this to go from "draft" to "PR"? This seems like a super convenient utility to have.
One concern is whether doing

There are also a lot of broken doctests; I can go through and fix them, but would prefer to do it only once (unless someone knows an automated way).
This seems like it doesn't really add much over printing layers on new lines. I don't think the big_show / compact subsystem (doesn't Julia have one of those as it is? Pretty sure it does) is the right kind of abstraction, since individual composite layers (Parallel, SkipConnection) might want to define how they are represented, and most standard layers are pretty compact as it is. Could we simplify the function substantially by removing that?
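For reference, the existing Julia mechanism being alluded to is the split between 2-argument `show` (compact, parseable) and the 3-argument `MIME"text/plain"` method (rich REPL display), plus the `:compact` flag in an `IOContext`. A minimal example with a made-up `Point` type:

```julia
struct Point
    x::Float64
    y::Float64
end

# compact form, used inside containers and when :compact is requested
Base.show(io::IO, p::Point) = print(io, "Point(", p.x, ", ", p.y, ")")

# rich form, used for a value displayed on its own at the REPL
function Base.show(io::IO, ::MIME"text/plain", p::Point)
    println(io, "Point with fields:")
    println(io, "  x = ", p.x)
    print(io,   "  y = ", p.y)
end

show(stdout, Point(1, 2))                                 # Point(1.0, 2.0)
show(IOContext(stdout, :compact => true), [1.23456789])   # [1.23457]
```

The question is then whether big_show needs its own compact/expanded switch on top of this, or whether the 3-argument method alone is enough.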
```julia
underscorise(n::Integer) =
  join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')
```
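For context, this helper groups digits in threes with underscores, e.g.:

```julia
julia> underscorise(1234567)
"1_234_567"
```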
```julia
function _nan_show(io::IO, x)
```
Why is this needed? Layers typically don't show their arrays, and custom layers should define their own; I don't want to take control of how they show their params.
More love for this PR becoming a reality:

```julia
julia> m = resnet50()
Chain(Conv((7, 7), 3=>64), BatchNorm(64, λ = relu), MaxPool((3, 3), pad=1, stride=2), Parallel(+, Chain(Conv((1, 1), 64=>64), BatchNorm(64, λ = relu), Conv((3, 3), 64=>64), BatchNorm(64, λ = relu), Conv((1, 1), 64=>256), BatchNorm(256, λ = relu)), Chain(Conv((1, 1), 64=>256), BatchNorm(256, λ = relu))), Parallel(+, Chain(Conv((1, 1), 256=>64), BatchNorm(64, λ = relu), Conv((3, 3), 64=>64), BatchNorm(64, λ = relu), Conv((1, 1), 64=>256), BatchNorm(256, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 256=>64), BatchNorm(64, λ = relu), Conv((3, 3), 64=>64), BatchNorm(64, λ = relu), Conv((1, 1), 64=>256), BatchNorm(256, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 256=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), Chain(Conv((1, 1), 256=>512), BatchNorm(512, λ = relu))), Parallel(+, Chain(Conv((1, 1), 512=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 512=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 512=>128), BatchNorm(128, λ = relu), Conv((3, 3), 128=>128), BatchNorm(128, λ = relu), Conv((1, 1), 128=>512), BatchNorm(512, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 512=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), Chain(Conv((1, 1), 512=>1024), BatchNorm(1024, λ = relu))), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>256), BatchNorm(256, λ = relu), Conv((3, 3), 256=>256), BatchNorm(256, λ = relu), Conv((1, 1), 256=>1024), BatchNorm(1024, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 1024=>512), BatchNorm(512, λ = relu), Conv((3, 3), 512=>512), BatchNorm(512, λ = relu), Conv((1, 1), 512=>2048), BatchNorm(2048, λ = relu)), Chain(Conv((1, 1), 1024=>2048), BatchNorm(2048, λ = relu))), Parallel(+, Chain(Conv((1, 1), 2048=>512), BatchNorm(512, λ = relu), Conv((3, 3), 512=>512), BatchNorm(512, λ = relu), Conv((1, 1), 512=>2048), BatchNorm(2048, λ = relu)), identity), Parallel(+, Chain(Conv((1, 1), 2048=>512), BatchNorm(512, λ = relu), Conv((3, 3), 512=>512), BatchNorm(512, λ = relu), Conv((1, 1), 512=>2048), BatchNorm(2048, λ = relu)), identity), AdaptiveMeanPool((1, 1)), flatten, Dense(2048, 1000))
```
This is very useful indeed. I'd like to have a display function handling generic models as well, but this is definitely a great start and I'd love to see it merged anytime soon.
Would it be possible to split out a PR that just does pretty-printing of generic models? That appears to be the most welcome and least controversial part of this.
Force-pushed from 2a1b189 to 9ad56f4
I still think we should do this, FWIW. Once you decide to make

@logankilpatrick's question at https://stackoverflow.com/questions/68143133/model-summary-in-flux-jl links to https://github.com/sksq96/pytorch-summary, which is a more elaborate display, including sizes. We could build something like that using
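The elided details aren't recoverable from the page, but here is a rough sketch of such a per-layer table, assuming `Flux.outputsize` is available for shape inference (the `layer_summary` helper and its formatting are invented here):

```julia
using Flux

# Per-layer table: name, input => output size, parameter count.
function layer_summary(m::Chain, insize::Tuple)
    sz = insize
    for (i, layer) in enumerate(m.layers)
        out = Flux.outputsize(layer, sz)
        nparams = sum(length, Flux.params(layer); init = 0)
        println(rpad("$i: $(nameof(typeof(layer)))", 12),
                rpad("$sz => $out", 24), nparams, " parameters")
        sz = out
    end
end

layer_summary(Chain(Dense(28^2, 64, relu), Dense(64, 10)), (28^2, 1))
# 1: Dense     (784, 1) => (64, 1)     50240 parameters
# 2: Dense     (64, 1) => (10, 1)      650 parameters
```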
I am not sure if one of Flux's design goals is to have parity with PyTorch, but if that is a goal, we should have a summary function.
I agree, but as-is we don't have a good mechanism for actually getting the parameters that should be counted for a given layer. Hence my recommendation to get the tree printing out the door first. People can start benefiting from it immediately while we work out the edge cases of the parameter display (e.g. can we remove "parameters" from each line and put it as a top-level header comment? Lots of bikeshedding potential).
BatchNorm is weird. But perhaps we should just pick some number for it & be done with it? If someone has Python handy, we could see what PyTorch's summary prints.

I'd completely forgotten, but the current state uses
If necessary we could add some overloadable function for this purpose, and special-case BatchNorm. (And friends?) But it would be nice to avoid that at least on the first go.
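Something like the following is what I'd picture for such an opt-in hook; note that `_paramcount` is a name invented here, not existing Flux API, and the BatchNorm method assumes the default `affine=true, track_stats=true`:

```julia
using Flux

# Default: count trainable arrays only (via Flux.params).
_paramcount(layer) = sum(length, Flux.params(layer); init = 0)

# Special case: also count BatchNorm's running statistics.
_paramcount(bn::BatchNorm) = sum(length, (bn.β, bn.γ, bn.μ, bn.σ²))

_paramcount(Dense(10, 10))   # 110
_paramcount(BatchNorm(64))   # 256
```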
PyTorch punts on the question because it doesn't have any summary functionality built-in, but torchinfo calls
Maybe that's an argument for wanting a column for size in memory, in addition to counting
How (tf.)Keras and PyTorch Lightning handle this is by reporting separate totals for trainable and non-trainable parameters. Lightning reports the same as torchinfo because it uses

```python
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(1,)),
    tf.keras.layers.BatchNormalization()
])
model.summary()
```

I think our equivalent would be using

** PyTorch does have a mechanism for registering non-trainable parameters (
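A rough Flux-side analogue of that Keras report, assuming `Functors.fcollect` is available for walking the model (the actual PR may count things differently):

```julia
using Flux
import Functors

m = Chain(Dense(10, 10), BatchNorm(10))

trainable_n = sum(length, Flux.params(m))    # weights, biases, and BatchNorm's β, γ
all_arrays  = filter(a -> a isa AbstractArray, Functors.fcollect(m))
total_n     = sum(length, all_arrays)        # also includes BatchNorm's μ and σ²

println("Total params: ", total_n)                        # 150
println("Trainable params: ", trainable_n)                # 130
println("Non-trainable params: ", total_n - trainable_n)  # 20
```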
How about including non-trainable params in the per-layer counts too? The word "parameters" could be replaced with a header comment.
BatchNorm, InstanceNorm, and GroupNorm are the only such examples in Flux, but one example I found from PyTorch is weighted losses.
bors r+

Build succeeded:
Not quite done, RFC?