-
-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix UNet implementation with arbitrary channel sizes (#243) #276
base: master
Are you sure you want to change the base?
Conversation
Fix UNet implementation to support input with channel sizes other than 3
Hi Vinayakjeet, thanks for the PR! Unfortunately, I don't think this does what we want yet. The problem is that julia> using Metalhead
julia> model = UNet((128,128),1,3,Metalhead.backbone(DenseNet(121)))
ERROR: DimensionMismatch: layer Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false) expects size(input, 3) == 3, but got 128×128×1×1 Array{Flux.NilNumber.Nil, 4}
Stacktrace:
[1] _size_check(layer::Flux.Conv{2, 2, typeof(identity), Array{…}, Bool}, x::Array{Flux.NilNumber.Nil, 4}, ::Pair{Int64, Int64})
@ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:195
[2] (::Flux.Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})(x::Array{Flux.NilNumber.Nil, 4})
@ Flux ~/.julia/packages/Flux/jgpVj/src/layers/conv.jl:198
[3] #outputsize#340
@ ~/.julia/packages/Flux/jgpVj/src/outputsize.jl:93 [inlined]
[4] outputsize(m::Flux.Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool}, inputsizes::NTuple{4, Int64})
@ Flux ~/.julia/packages/Flux/jgpVj/src/outputsize.jl:91
[5] unetlayers(layers::Vector{…}, sz::NTuple{…}; outplanes::Nothing, skip_upscale::Int64, m_middle::typeof(Metalhead.unet_middle_block))
@ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:34
[6] unet(encoder_backbone::Flux.Chain{…}, imgdims::Tuple{…}, inchannels::Int64, outplanes::Int64, final::typeof(Metalhead.unet_final_block), fdownscale::Int64)
@ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:81
[7] unet
@ ~/Code/Metalhead.jl/src/convnets/unet.jl:76 [inlined]
[8] #UNet#175
@ ~/Code/Metalhead.jl/src/convnets/unet.jl:120 [inlined]
[9] UNet(imsize::Tuple{Int64, Int64}, inchannels::Int64, outplanes::Int64, encoder_backbone::Flux.Chain{Tuple{…}})
@ Metalhead ~/Code/Metalhead.jl/src/convnets/unet.jl:118
[10] top-level scope
@ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types. I would suggest that you try and rewrite the function in such a way that |
src/convnets/unet.jl
Outdated
encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) | ||
layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes) | ||
layers = unet(encoder_backbone, imsize, inchannels, outplanes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inchannels
should somehow be passed in to the encoder backbone here. Of course, we will have to decide how to deal with this in case the user passes in a model with this initialised and also separately inchannels
Modified the first convolutional layer of the encoder backbone to ensure compatibility with the input's channel size and dimension mismatch error is thus prevented #1
skip_upscale = fdownscale) | ||
function unet(encoder_backbone, imgdims, inchannels::Integer, outplanes::Integer, | ||
final::Any = unet_final_block, fdownscale::Integer = 0) | ||
backbonelayers = collect(flatten_chains(encoder_backbone)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please pay attention to the formatting, you lost the indentation here
Indentation issue resolved
2nd try
3rd try
4th try
5th try
6th try
A beginner contributor to the codebase, can you review the logic I have implemented, additionally I have encountered an error MethodError indicating a mismatch in method signatures for the unet function. It appears that there might be an issue with how the encoder_backbone is instantiated or utilized within the unet function. Could you please review the instantiation and usage of the encoder_backbone |
#243
Bug Description:
The current UNet implementation in the Metalhead package has a limitation where it only works with input tensors of channel size 3. This restriction causes compatibility issues when users try to use UNet with input tensors of different channel sizes.
Patch Description:
To address this limitation, I've modified the UNet implementation to support input tensors with arbitrary channel sizes. The UNet model can now handle input with varying dimensions
Test Case:
using Metalhead
UNet((128,128),1,3,Metalhead.backbone(DenseNet(121)))
This UNet model can process without any errors