[DRAFT][DONT MERGE] U-net proposal #6611

Open
wants to merge 1 commit into
base: main

TeodorPoncu
Contributor

This is a draft PR supporting the RFC in issue #6610. It is a rough sketch of how an architecture framework might look if added to torchvision. Suggestions are most welcome, as I believe that specifying the individual configurations of certain layers in such frameworks requires some trade-offs in terms of where the code complexity is offloaded.

@facebook-github-bot

Hi @TeodorPoncu!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@datumbox
Contributor

@TeodorPoncu Thanks for your work on TorchVision over the last few months. It's unclear to me whether you would like to keep working on this PR now, even at a slower pace. It depends on your bandwidth. We can support you on our side with training etc. Let me know your thoughts.

@TeodorPoncu
Contributor Author

@datumbox I'm actively working on Diffusion so this PR is of high interest to me. I would very much like a ton of feedback on it from the community in order for us to reach a version which can be as re-usable and flexible as nn.Transformer.

@datumbox
Contributor

@TeodorPoncu if you are in, we are in. Let's do it. I'll let you progress a bit on it and then we can even schedule a call to discuss the details. Feel free to ping us on Slack if you need anything.

@TeodorPoncu
Contributor Author

@datumbox Well, I guess the first thing I would need insight on, from TorchVision's side at least, is how we would handle a "framework" in which a given module might or might not take an additional input during forward.

For instance one might use some sort of conditional diffusion where they would need to add the conditional information into either the encoder or the bottleneck, or both.

The only thing that comes to mind right now, without adding custom input-types like Hugging Face does, would be for the network modules to expect dicts as inputs. But that feels very not pytorch-y / torchvision-y at the moment.
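To make the dict-based option concrete, here is a minimal sketch of what it could look like. `ConditionalBlock`, its `cond_dim` argument, and the `"x"` / `"cond"` keys are hypothetical names used only for illustration; none of them come from this PR or from torchvision.

```python
from typing import Dict, Optional

import torch
from torch import Tensor, nn


class ConditionalBlock(nn.Module):
    """Hypothetical block: a conv layer that optionally adds a projected conditioning vector."""

    def __init__(self, in_ch: int, out_ch: int, cond_dim: Optional[int] = None) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        # Only blocks built with cond_dim are able to consume the "cond" entry.
        self.cond_proj = nn.Linear(cond_dim, out_ch) if cond_dim is not None else None

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        x = self.conv(inputs["x"])
        if self.cond_proj is not None and "cond" in inputs:
            # Broadcast the (N, C) projection over the spatial dimensions.
            x = x + self.cond_proj(inputs["cond"])[:, :, None, None]
        # Pass every other key through untouched so later blocks can still read it.
        return {**inputs, "x": x}


encoder = ConditionalBlock(3, 8)                  # ignores "cond"
bottleneck = ConditionalBlock(8, 8, cond_dim=16)  # consumes "cond" when present
net = nn.Sequential(encoder, bottleneck)

batch = {"x": torch.randn(2, 3, 32, 32), "cond": torch.randn(2, 16)}
out = net(batch)
```

Because every block reads from and returns the same dict, the conditioning can be injected into the encoder, the bottleneck, or both by changing only the constructor arguments. The cost is that the contract between blocks lives in string keys rather than in the `forward` signature.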

@datumbox
Contributor

It's a good question, and it has come up before in the context of SSL. The issue is that many of these models are very new and the techniques have not yet converged on the input that each specific task takes. This makes it harder to create stable APIs. Note that the detection models do receive a dictionary as input (perhaps not the best example in terms of API design, but still a valid existing case in TorchVision), so that's not out of the question. What makes it harder for the diffusion case is that we are not yet certain what this input would look like 6 months down the line, when 10 more papers are out. So maybe the question here is: can we come up with a U-Net implementation that faithfully implements the original paper but has a code structure that is reusable, or a good starting point, for more exotic tasks?

The issue, as you are very well aware, is that we don't have the headcount at the moment to open such massive discussions. Perhaps that's something we could take on for H1? But I'll be happy to work with you on that and see how we can progress.

@TeodorPoncu
Contributor Author

TeodorPoncu commented Sep 24, 2022

In previous cases, when dealing with multi-modal / multi-branch data, I defaulted to dicts. I feel that's a very intuitive and fast way to set things up and experiment. The only downside I have encountered so far with this paradigm is when you need to keep track of a lot of keys (for instance, 6-7 inputs).

As for faithfully implementing the original, there's one thing I have not followed through on here - namely that the original does not produce the same spatial size in the output as in the input. That decision was most likely 100% domain-driven, since the original paper was aimed at medical imaging.
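The size mismatch follows directly from the unpadded 3x3 convolutions used in the original paper. As a rough sketch (the function name and its parameters are illustrative, not part of this PR), the arithmetic for a U-Net with four down/up stages and two convolutions per stage is:

```python
def unet_output_size(
    size: int,
    depth: int = 4,
    convs_per_stage: int = 2,
    kernel: int = 3,
    padding: int = 0,
) -> int:
    """Spatial output size of a U-Net with 2x2 pooling and 2x up-sampling."""
    # Pixels lost per stage: each conv removes (kernel - 1 - 2 * padding) pixels.
    shrink = convs_per_stage * (kernel - 1 - 2 * padding)

    # Contracting path: convs, then 2x2 max-pooling.
    for _ in range(depth):
        size -= shrink
        if size % 2:
            raise ValueError("feature map size must be even before pooling")
        size //= 2

    # Bottleneck convs.
    size -= shrink

    # Expanding path: 2x up-sampling, then convs.
    for _ in range(depth):
        size = size * 2 - shrink
    return size


original = unet_output_size(572)          # unpadded convs, as in the paper
padded = unet_output_size(576, padding=1)  # "same" padding keeps the size
```

With unpadded convolutions a 572x572 input tile yields a 388x388 output, which matches the figure in the original paper. With `padding=1` the size is preserved, provided the input is divisible by `2 ** depth`.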

Having worked in that field myself, I know that network uncertainty is highest at the edges of the input, which makes learning in those regions tricky. If you evaluate a loss directly in those regions (e.g. a detection box or a segmentation mask), you might be sending a bad signal through the network. Without getting into medical-world details, things can have very different meanings depending on what is located in their vicinity, so crops / random crops can produce inconclusive training samples. Therefore, the easiest thing to do is to take a much larger area around the section you want to run the loss / network predictions on and treat it as context. I can definitely reproduce this with the building blocks I have submitted in this PR.
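A minimal sketch of that idea, assuming a segmentation setup where the labels cover the whole input tile but the prediction only covers its centre. The `center_crop` helper and the tensor shapes are illustrative stand-ins, not code from this PR:

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def center_crop(t: Tensor, size: int) -> Tensor:
    """Keep the central size x size window of the last two dimensions."""
    h, w = t.shape[-2:]
    top, left = (h - size) // 2, (w - size) // 2
    return t[..., top:top + size, left:left + size]


# Hypothetical shapes: a 572x572 input tile whose prediction covers only the central 388x388.
num_classes = 2
logits = torch.randn(1, num_classes, 388, 388)       # stand-in for the network output
mask = torch.randint(0, num_classes, (1, 572, 572))  # labels for the full input tile

# The border of the tile acts purely as context; only the centre contributes to the loss.
target = center_crop(mask, logits.shape[-1])
loss = F.cross_entropy(logits, target)
```

The network still sees the full tile, so pixels near the border of the evaluated region have real context around them, while the uncertain outer ring never receives a direct loss signal.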

But do we want to add this behaviour to torchvision's implementation? In my experience so far, people expect a U-Net to generate an output of the same size as the input.

Successfully merging this pull request may close these issues.

[RFC] U-Net framework
3 participants