I would greatly appreciate it if you could let me know how I can fall back to the native torch behavior after having imported torchdrug somewhere above in my code.
Hi! The patch is intended to make nn.Module and nn.DistributedDataParallel accept data.Graph (and its derived classes) in the same way as torch.Tensor. For example, we may want to register a graph as a buffer of an nn.Module, or pass graphs as arguments to some nn.Module.forward in distributed data parallel.
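To make the idea concrete, here is a plain-Python sketch of the monkey-patch pattern described above. The class names are stand-ins (torchdrug's actual patch operates on torch.nn.Module and torch.Tensor in torchdrug/patch.py); this only illustrates how wrapping register_buffer lets a second type be stored alongside tensor buffers.

```python
class Tensor:
    """Stand-in for torch.Tensor."""

class Graph:
    """Stand-in for torchdrug's data.Graph."""

class Module:
    """Stand-in for torch.nn.Module: natively, only Tensors may be buffers."""
    def __init__(self):
        self._buffers = {}

    def register_buffer(self, name, tensor):
        if tensor is not None and not isinstance(tensor, Tensor):
            raise TypeError("buffer must be a Tensor or None")
        self._buffers[name] = tensor

# The patch: keep a reference to the original method, then wrap it so
# Graph instances are accepted too.
_original_register_buffer = Module.register_buffer

def patched_register_buffer(self, name, value):
    if isinstance(value, Graph):
        self._buffers[name] = value  # store graphs alongside tensor buffers
    else:
        _original_register_buffer(self, name, value)

Module.register_buffer = patched_register_buffer

m = Module()
m.register_buffer("adjacency", Graph())  # now allowed by the patch
m.register_buffer("weights", Tensor())   # tensors still work as before
print(sorted(m._buffers))                # -> ['adjacency', 'weights']
```

The key property of this pattern is that non-Graph arguments are delegated to the saved original method, so the native type checking is preserved for tensors.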
The code was originally developed against PyTorch 1.4.0 -- that's why persistent is missing there. I am aware that such a monkey patch is really ugly and not good for maintenance. Any ideas on how we can improve this?
I noticed that persistent is an argument of register_buffer in every PyTorch release >= 1.6.0, so I updated the corresponding patch function in b411837. Feel free to reopen this issue if you have a better solution for such patches.
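A hedged sketch of what such a version-tolerant fix can look like, again using stand-in classes rather than torchdrug's real code: instead of hard-coding a version check, the patch can inspect the original method's signature and forward persistent only when the installed version supports it.

```python
import inspect

class Module:
    """Stand-in for torch.nn.Module with the PyTorch >= 1.6 signature."""
    def __init__(self):
        self._buffers = {}
        self._non_persistent_buffers_set = set()

    def register_buffer(self, name, tensor, persistent=True):
        self._buffers[name] = tensor
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)

_original = Module.register_buffer
# Detect whether the installed version accepts `persistent` by inspecting
# the original method's signature.
_has_persistent = "persistent" in inspect.signature(_original).parameters

def patched_register_buffer(self, name, value, persistent=True):
    if _has_persistent:
        _original(self, name, value, persistent=persistent)
    else:
        _original(self, name, value)  # old API: every buffer is persistent

Module.register_buffer = patched_register_buffer

m = Module()
m.register_buffer("running_mean", 0.0, persistent=False)
print(m._non_persistent_buffers_set)  # -> {'running_mean'}
```

Because the keyword is forwarded rather than dropped, code that relies on non-persistent buffers (e.g. excluding them from the state dict) keeps working after the patch is applied.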
I'm interested in understanding why it is necessary to overwrite the default nn.Module of torch in patch.py (torchdrug/torchdrug/patch.py, line 125 in eeee191). This seems quite invasive, since it alters the behavior of any torch.nn module after torchdrug has been imported. For example, your implementation of register_buffer in patch.py lacks the keyword argument persistent, which is present in native torch: https://github.com/pytorch/pytorch/blob/989b24855efe0a8287954040c89d679625dcabe1/torch/nn/modules/module.py#L277. I would greatly appreciate it if you could let me know how I can fall back to the native torch behavior after having imported torchdrug somewhere above in my code.
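One possible workaround, sketched here with a stand-in class rather than the real libraries: capture a reference to the original method before the patch is applied, then assign it back to restore native behavior. With the real libraries this would mean saving torch.nn.Module.register_buffer before `import torchdrug` and restoring it afterwards; whether torchdrug itself keeps a usable reference to the original is not shown here, so treat this purely as an illustration of the mechanism.

```python
class Module:
    """Stand-in for torch.nn.Module."""
    def register_buffer(self, name, tensor, persistent=True):
        return "native"

# 1. Save the native method before any patching happens
#    (with torch: do this before `import torchdrug`).
native_register_buffer = Module.register_buffer

# 2. Something (e.g. importing torchdrug) monkey-patches the class.
def patched(self, name, tensor):  # note: no `persistent` keyword
    return "patched"
Module.register_buffer = patched

# 3. Assign the saved method back to restore the native behavior.
Module.register_buffer = native_register_buffer

m = Module()
print(m.register_buffer("b", None, persistent=False))  # -> native
```

Since the patch replaces an attribute on the class object itself, restoring that attribute affects all modules created afterwards, not just a single instance.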