
Conflict with torch due to overwritten modules #77

Closed
jannisborn opened this issue Mar 11, 2022 · 4 comments
Labels: help wanted

Comments

@jannisborn
Contributor

I'm interested in understanding why it is necessary to overwrite torch's default nn.Module in patch.py:

nn.Module = PatchedModule

This seems quite invasive, since it alters the behavior of every torch.nn module once torchdrug has been imported.

For example, your implementation of register_buffer in patch.py lacks the persistent keyword argument that is present in native torch: https://github.com/pytorch/pytorch/blob/989b24855efe0a8287954040c89d679625dcabe1/torch/nn/modules/module.py#L277
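
For concreteness, here is a minimal sketch of the breakage (the model and buffer names are made up for illustration, and it assumes a TorchDrug version whose patched register_buffer still lacks persistent):

```python
import torch
import torchdrug  # noqa: F401  -- importing this swaps in the patched nn.Module

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Native PyTorch (>= 1.6.0) accepts `persistent`, but the patched
        # register_buffer does not, so this call raises a TypeError.
        self.register_buffer("running_mean", torch.zeros(8), persistent=False)

model = MyModel()  # TypeError: unexpected keyword argument 'persistent'
```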

I would greatly appreciate it if you could let me know how to fall back to the native torch behavior after torchdrug has been imported earlier in my code.

@KiddoZhu
Member

KiddoZhu commented Mar 11, 2022

Hi! The patch is intended to make nn.Module and nn.DistributedDataParallel accept data.Graph (and its derived classes) in the same way as torch.Tensor. For example, we may want to register a graph as a buffer of an nn.Module, or pass graphs as arguments to some nn.Module.forward in distributed data parallel.

The code was originally developed against PyTorch 1.4.0 -- that's why persistent is missing there. I am aware that such a monkey patch is ugly and not good for maintenance. Any ideas on how we can improve that?
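
To illustrate the mechanism (a simplified sketch, not TorchDrug's actual patch.py -- the real PatchedModule also extends register_buffer and DistributedDataParallel to handle data.Graph):

```python
from torch import nn

class PatchedModule(nn.Module):
    """Stand-in for TorchDrug's patched class; the real one adds
    graph-aware behavior to register_buffer and friends."""

nn._Module = nn.Module   # keep a handle to the original class
nn.Module = PatchedModule

# Any model defined after the swap silently inherits the patched base:
class UserModel(nn.Module):
    pass

print(issubclass(UserModel, PatchedModule))  # True
```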

@KiddoZhu
Member

In the patch, the original nn.Module is saved as nn._Module. To fall back to native PyTorch, you may use

torch.nn.Module = torch.nn._Module

But I guess this will make many TorchDrug classes complain.
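
One way to limit the blast radius (a sketch that relies on the original class being saved as nn._Module, as described above; native_nn_module is a name I'm making up): since a class keeps the base it was defined with, restoring nn.Module only around your own definitions leaves already-defined TorchDrug classes untouched.

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def native_nn_module():
    """Temporarily restore the unpatched nn.Module saved by the patch."""
    patched = nn.Module
    nn.Module = nn._Module
    try:
        yield
    finally:
        nn.Module = patched

with native_nn_module():
    class MyNativeModel(nn.Module):  # built on the original, unpatched class
        def __init__(self):
            super().__init__()
            # `persistent` works here: this is the native register_buffer
            self.register_buffer("stats", torch.zeros(4), persistent=False)
```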

KiddoZhu added the help wanted label Mar 11, 2022
@KiddoZhu
Member

KiddoZhu commented Apr 3, 2022

I noticed that persistent is always an argument of register_buffer for PyTorch >= 1.6.0, so I updated the corresponding patch function in b411837. Feel free to reopen this issue if you have a better solution for such patches.
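
The shape of the updated override is roughly the following (a paraphrase for illustration, not the literal code from b411837; the graph branch is elided):

```python
from torch import nn
from torchdrug import data

class PatchedModule(nn.Module):
    # PyTorch >= 1.6.0 always exposes `persistent`, so the patched signature
    # can mirror the native one and simply forward the argument.
    def register_buffer(self, name, tensor, persistent=True):
        if isinstance(tensor, data.Graph):
            ...  # graph-specific handling, unchanged
        else:
            super().register_buffer(name, tensor, persistent)
```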

KiddoZhu closed this as completed Apr 3, 2022
@jannisborn
Contributor Author

Cool, thanks a lot for looking into it!
