
RFC: DenseNet rewrite for correctness #241

Closed · wants to merge 2 commits

Conversation

@theabhirath (Member) commented May 25, 2023

Our previous version of DenseNet does not combine features at the block level, which is one of the defining innovations of the paper. This one manages to do that, but at the cost of looking a little clunky (it has an explicitly written-out DenseBlock and uses vcat for vector concatenation). It also does not currently work on Zygote 0.6.45+ due to FluxML/Zygote.jl#1417. Any comments on how to make the design work with fewer "manual" and more built-in layers are appreciated! This version may also not be very AD-friendly, so any help there is also very much appreciated 😅.

cc @ToucheSir @darsnack
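
For context, the DenseBlock type itself is not shown in the quoted diff below, only its forward pass. A minimal sketch of what such a wrapper could look like (the exact definition here is an assumption, not the PR's code):

using Flux

# Hypothetical wrapper: holds the bottleneck layers of one dense block and is
# registered with Flux so its parameters are collected for training.
struct DenseBlock{T<:AbstractVector}
    layers::T
end
Flux.@functor DenseBlock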

Comment on lines +23 to 30
function (m::DenseBlock)(x)
    input = [x]
    for layer in m.layers
        x = layer(input)
        input = vcat(input, [x])
    end
    return cat_channels(input...)
end

theabhirath (Member Author):

This is where I think the biggest change is, and where I think the code could use the most input. Is this the best way of doing this sequence of operations?

Member:

maybe

function (m::DenseBlock)(x)
    input = (x,)
    for layer in m.layers
        x = layer(input)
        input = (input..., x)
    end
    return cat_channels(input)
end

to bypass the Zygote vcat problem?

Member:

Even better, we can remove cat_channels from the layers and do

function (m::DenseBlock)(x)
    input = x
    for layer in m.layers
        x = layer(input)
        input = cat_channels(input, x)
    end
    return input
end

Comment on lines +13 to +14
layers = [dense_bottleneck(inplanes + (i - 1) * growth_rate, growth_rate, bn_size,
                           dropout_prob) for i in 1:num_layers]

Member:

Can't you take this vector and build the nested Parallel of a DenseBlock recursively? That's how I would try doing it to avoid defining an extra type.

theabhirath (Member Author):

I'm afraid I can't quite immediately see that implementation 😅 Could you give me an example?
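
One possible reading of that suggestion, sketched here as an assumption rather than the reviewer's intended code: fold the layers vector into nested Parallels, so each bottleneck receives the running channel concatenation of the block input and all earlier outputs. This assumes each layer accepts a single, already concatenated array (i.e. without the leading cat_channels inside dense_bottleneck), and cat_channels below is a stand-in for Metalhead's utility of the same name.

using Flux

# Stand-in for Metalhead's cat_channels: concatenate along the channel
# dimension of WHCN arrays.
cat_channels(x, y) = cat(x, y; dims = 3)

# Fold the vector of bottleneck layers into nested Parallels: `identity`
# forwards the features accumulated so far and each `layer` appends its new
# feature maps, so no custom DenseBlock type is needed.
function build_dense_block(layers::AbstractVector)
    return foldl(layers; init = identity) do block, layer
        Chain(block, Parallel(cat_channels, identity, layer))
    end
end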

@@ -28,126 +40,54 @@ Create a DenseNet transition sequence
   - `inplanes`: number of input feature maps
   - `outplanes`: number of output feature maps
 """
-function transition(inplanes::Int, outplanes::Int)
+function transition(inplanes::Integer, outplanes::Integer)

Member:

Just a general note: since in practice allowing an abstract Integer doesn't add any value compared to just restricting to Int, I prefer the latter. In this case Integer is better, though, for consistency with the other methods.

dense_bottleneck(inplanes, outplanes; expansion=4)
function dense_bottleneck(inplanes::Integer, growth_rate::Integer, bn_size::Integer,
                          dropout_prob)
    return Chain(cat_channels,

Member:

If the input to this chain is a vector of arrays, then according to https://github.com/FluxML/Metalhead.jl/blob/010d4bc72989e4392d2ce0b89e72b8d640927dd8/src/utilities.jl#L37 shouldn't we have

Suggested change
-    return Chain(cat_channels,
+    return Chain(x -> cat_channels(x...),

?


@darsnack (Member) left a comment:

Unless I am missing something, I don't see how the old implementation is functionally different (modulo dropout). I drew a picture. Let's not merge this until we make sure.

[attached diagram]

@darsnack (Member) commented Jun 2, 2023

The actual DenseNet paper has this weird way of drawing the operation that suggests we want some recursive structure. But concatenation along an axis is inherently cumulative, so that accumulation is implicit?

@CarloLucibello (Member):

> I don't see how the old implementation is functionally different

Right, the old implementation with the skip connection seems the same.

@darsnack (Member) commented Jun 2, 2023

And this also seems the same as torchvision. So I guess, @theabhirath, can you roll back the changes introducing DenseBlock and just keep the smaller pieces, like adding dropout?

@CarloLucibello (Member):

But there should still be something different from torchvision, since the two forward passes don't match.

@theabhirath (Member Author):

> The actual DenseNet paper has this weird way of drawing the operation that suggests we want some recursive structure. But concatenation along an axis is inherently cumulative, so that accumulation is implicit?

There is a difference, and it's very subtle. It's definitely there though.

[screenshot of the figure under discussion]

Notice that in this image, the outputs of the layers (and the inputs) are produced, remembered, and concatenated only at the end. Contrast that with our approach, where we were passing the output of the previous SkipConnection, concatenated with the output of the current one, into the next one. This means we were sending some outputs through layers they were not intended to be sent into, which is why we need the explicit DenseBlock structure.

@CarloLucibello (Member):

The paper says

x1 = H1([x0])
x2 = H2([x0, x1])
x3 = H3([x0, x1, x2])
...

which is equivalent to

input = [x0]
x = H1(input)
input = [input..., x]
x = H2(input)
input = [input..., x]

Torchvision does the same. So it seems to me that 1) the implementation in this PR is equivalent to the old one with the skip connection; 2) it is also equivalent to the two variants I proposed above; and 3) they are all equivalent to the one in torchvision, so the difference in the forward passes is not caused by the concat mechanism but lies somewhere else.
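
A quick shape-level check that the unrolled recurrence and the old SkipConnection formulation build up the same set of feature maps, sketched with toy convolutions standing in for H1 and H2 (the layer choices and sizes are purely illustrative, not taken from the PR):

using Flux

cat_channels(xs...) = cat(xs...; dims = 3)  # channel concat for WHCN arrays

# Toy stand-ins for H1 and H2; the second layer expects the concatenated channels.
H1 = Conv((3, 3), 4 => 2; pad = 1)
H2 = Conv((3, 3), 6 => 2; pad = 1)
x0 = rand(Float32, 8, 8, 4, 1)

# Unrolled recurrence, as written in the paper and in torchvision:
x1 = H1(x0)
x2 = H2(cat_channels(x0, x1))
unrolled = cat_channels(x0, x1, x2)

# Old Metalhead formulation: chained SkipConnections accumulating channels.
block = Chain(SkipConnection(H1, cat_channels), SkipConnection(H2, cat_channels))
nested = block(x0)

# Each layer sees the same set of feature maps in both formulations; only the
# channel ordering differs, which a permutation of the weights would absorb.
@assert size(unrolled) == size(nested)  # both are (8, 8, 8, 1)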

@darsnack (Member) commented Jun 2, 2023

Yeah, I agree with Carlo here.

> Notice that in this image, the outputs of the layers (and the inputs) are produced, remembered, and concatenated only at the end.

That's because this is a poorly drawn figure. They are re-showing the concatenation of every group of channels into each bottleneck. This is because the output of the concatenation (where the arrows merge) goes into the i-th bottleneck block but nowhere else, so they have to redraw that concatenation plus the output of the i-th block going into the (i + 1)-th block. They could just as easily have drawn an arrow from where the arrows merge into the i-th block to the (i + 1)-th block, which would have been the same diagram with fewer lines, and a closer match to how you would programmatically implement this sequence of blocks.

We have some other issue with the DenseNets that's causing the accuracy failures.

@theabhirath (Member Author):

Okay, I spent a lot of time poring over the paper and figured out that it is indeed what we were doing 😅 Apologies for the noise. This is what I get for trying to map Pythonic programming styles onto Julia 😅. However, maybe I can keep this open until I figure out where it is that we differ from the torchvision implementation. Given that we are doing the correct thing, though, this should not block v0.8 (we can release without DenseNet weights, I think). I will land a final docs PR over this weekend and then we can release.
