RFC: DenseNet rewrite for correctness #241
Conversation
function (m::DenseBlock)(x)
    input = [x]
    for layer in m.layers
        x = layer(input)
        input = vcat(input, [x])
    end
    return cat_channels(input...)
end
This is where I think the biggest change is, and where I think the code could use the most input. Is this the best way of doing this sequence of operations?
maybe
function (m::DenseBlock)(x)
    input = (x,)
    for layer in m.layers
        x = layer(input)
        input = (input..., x)
    end
    return cat_channels(input)
end
to bypass the Zygote vcat problem?
Even better, we can remove cat_channels from the layers and do
function (m::DenseBlock)(x)
    input = x
    for layer in m.layers
        x = layer(input)
        input = cat_channels(input, x)
    end
    return input
end
layers = [dense_bottleneck(inplanes + (i - 1) * growth_rate, growth_rate, bn_size,
                           dropout_prob) for i in 1:num_layers]
Can't you take this vector and build the nested Parallel of a DenseBlock recursively? That's how I would try doing it to avoid defining an extra type.
I'm afraid I can't quite immediately see that implementation 😅 Could you give me an example?
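One possible shape for such a recursive construction (an illustrative sketch only, assuming Flux's Parallel and a variadic cat_channels like Metalhead's; the helper name dense_block_nested is hypothetical):

using Flux

cat_channels(xs...) = cat(xs...; dims = 3)  # stand-in for the Metalhead utility

# Fold the vector of bottleneck layers into nested blocks: each step wraps the
# running block in a Parallel that concatenates its input with the new layer's output.
dense_block_nested(layers) =
    foldl((block, layer) -> Chain(block, Parallel(cat_channels, identity, layer)),
          layers; init = identity)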
@@ -28,126 +40,54 @@ Create a DenseNet transition sequence
  - `inplanes`: number of input feature maps
  - `outplanes`: number of output feature maps
 """
-function transition(inplanes::Int, outplanes::Int)
+function transition(inplanes::Integer, outplanes::Integer)
Just a general note: since in practice allowing an abstract integer doesn't add any value compared to just restricting to Int, I prefer the latter. In this case Integer is better though, for consistency with the other methods.
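For reference, a tiny dispatch example of the distinction (illustrative names only):

f_int(x::Int) = x + 1
f_integer(x::Integer) = x + 1

f_integer(Int32(3))   # works: Integer covers Int32, UInt8, BigInt, ...
# f_int(Int32(3))     # MethodError: only the machine Int is accepted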
-dense_bottleneck(inplanes, outplanes; expansion=4)
+function dense_bottleneck(inplanes::Integer, growth_rate::Integer, bn_size::Integer,
+                          dropout_prob)
+    return Chain(cat_channels,
If the input to this chain is a vector of arrays, then according to https://github.com/FluxML/Metalhead.jl/blob/010d4bc72989e4392d2ce0b89e72b8d640927dd8/src/utilities.jl#L37, shouldn't we have

-return Chain(cat_channels,
+return Chain(x -> cat_channels(x...),

?
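For context, a small illustration of the splatting concern (assuming cat_channels is variadic, roughly cat_channels(xs...) = cat(xs...; dims = 3), as in the linked utilities.jl):

cat_channels(xs...) = cat(xs...; dims = 3)  # stand-in for the Metalhead utility

xs = [rand(Float32, 8, 8, 3, 1), rand(Float32, 8, 8, 2, 1)]

cat_channels(xs...)   # splatted: concatenates along channels, giving an (8, 8, 5, 1) array
# cat_channels(xs)    # unsplatted: the Vector itself is passed as a single argument,
                      # which is not the channel-wise concatenation we want here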
The actual DenseNet paper has this weird way of drawing the operation that suggests we want some recursive structure. But concatenation along an axis is inherently cumulative, so that accumulation is implicit?
Right, the old implementation with the skip connection seems the same.
And this also seems the same as torchvision. So I guess @theabhirath can you roll back the changes introducing DenseBlock?
But still, there should be something different from torchvision, since the two forward passes don't match.
There is a difference, and it's very subtle. It's definitely there though. Notice that in this image, the way the outputs of the layers (and the inputs) are treated is that they are produced, remembered and concatenated at the end. Contrast that to our approach, where we were passing the output of the previous layer directly into the next.
The paper says

x1 = H1([x0])
x2 = H2([x0, x1])
x3 = H3([x0, x1, x2])
...

which is equivalent to

input = [x0]
x = H1(input)
input = [input..., x]
x = H2(input)
input = [input..., x]

Torchvision does the same. So it seems to me that 1) the implementation in this PR is equivalent to the old one with the skip connection; 2) it is also equivalent to the two variants I proposed above; 3) they are all equivalent to the one in torchvision, and the difference in the forwards is not caused by the concat mechanism but is somewhere else.
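A toy check of that equivalence (the stand-in layers H1 and H2 below are hypothetical; any deterministic functions of the full input would do):

cat_channels(xs...) = cat(xs...; dims = 3)

x0 = rand(Float32, 4, 4, 3, 1)
H1 = x -> 2 .* x[:, :, 1:2, :]                    # toy "layer" producing 2 channels
H2 = x -> x[:, :, 1:1, :] .+ x[:, :, end:end, :]  # toy "layer" producing 1 channel

# paper-style: remember every output, concatenate once at the end
x1 = H1(x0)
x2 = H2(cat_channels(x0, x1))
paper_out = cat_channels(x0, x1, x2)

# loop-style: concatenate as you go
function loop_forward(x, layers)
    input = x
    for layer in layers
        input = cat_channels(input, layer(input))
    end
    return input
end

@assert loop_forward(x0, (H1, H2)) == paper_out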
Yeah, agree with Carlo here.

That's because this is a poorly drawn figure. They are re-showing the concatenation of every group of channels into each bottleneck. This is because the output of the concatenation (where the arrows merge) goes into the bottleneck.

We have some other issue with the DenseNets that's causing the accuracy failures.
Okay, I spent a lot of time poring over the paper and figured out that it is indeed what we were doing 😅 Apologies for the noise. This is what I get for trying to map Pythonic programming styles onto Julia 😅. However, maybe I can keep this open until I figure out where it is that we differ from the torchvision implementation. Given that we are doing the correct thing, though, this should not block v0.8 (we can release without DenseNet weights, I think). I will land a final docs PR over this weekend and then we can release that.
Our previous version of DenseNet does not combine features at the block level, which is one of the defining innovations of the paper. This one manages to do that, but at the cost of looking a little clunky (it has an explicitly written out DenseBlock and uses vcat for vector concatenation). It also does not work currently on Zygote 0.6.45+ due to FluxML/Zygote.jl#1417. Any comments on how to make the design work with less "manual" and more built-in layers are appreciated! This version may also not be very AD-friendly, so any help there is also very much appreciated 😅.

cc @ToucheSir @darsnack