Support for pytorch 2.0 #94
Conversation
…der does not include 11.8
TorchANI cannot be installed with PyTorch 2, which forces us to skip some tests.
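As a minimal sketch of how such a skip can be expressed with pytest (the guard itself is an illustration, not necessarily the exact code in this PR):

```python
import pytest

# Skip TorchANI-dependent tests when torchani is not importable,
# e.g. in a PyTorch 2 environment where it cannot be installed.
torchani = pytest.importorskip("torchani")
```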
All tests pass and the CI works for an installation with PyTorch 2. I believe this should be merged now, and work on compile() compatibility can be done in another PR.
In case you have some experience with torch2 compile, this is the test:

```python
import pytest
import torch as pt
from NNPOps.neighbors import getNeighborPairs


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_torch_compile_compatible(device, dtype):
    class ForceModule(pt.nn.Module):
        def forward(self, positions):
            # getNeighborPairs pads its outputs with NaNs, so mask them out
            # before reducing.
            neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
            mask = pt.isnan(distances)
            distances = distances[~mask]
            return pt.sum(distances**2)

    original_model = ForceModule()
    num_atoms = 10
    positions = (20 * pt.randn((num_atoms, 3), device=device, dtype=dtype)) - 10
    original_model(positions)
    model = pt.compile(original_model)
    model(positions)
```

It yields a really verbose error about something called FakeTensor that makes the most obscure gcc recursive template error look clear and informative:

```
TestNeighbors.py::test_torch_compile_compatible[dtype1-cuda] FAILED [600/1860]
================================================================================= FAILURES ==================================================================================
________________________________________________________________ test_torch_compile_compatible[dtype0-cuda] _________________________________________________________________

output_graph = <torch._dynamo.output_graph.OutputGraph object at 0x7fc3e7d81fc0>, node = get_neighbor_pairs
args = (FakeTensor(FakeTensor(..., device='meta', size=(10, 3)), cuda:0), 1.0, -1, FakeTensor(FakeTensor(..., device='meta', size=(0, 0)), cuda:0)), kwargs = {}
nnmodule = None

    def run_node(output_graph, node, args, kwargs, nnmodule):
        """
        Runs a given node, with the given args and kwargs.
        Behavior is dicatated by a node's op.

        run_node is useful for extracting real values out of nodes.
        See get_real_value for more info on common usage.

        Note: The output_graph arg is only used for 'get_attr' ops
        Note: The nnmodule arg is only used for 'call_module' ops

        Nodes that are not call_function, call_method, call_module, or get_attr will
        raise an AssertionError.
        """
        op = node.op
        try:
            if op == "call_function":
>               return node.target(*args, **kwargs)

../../../mambaforge/envs/nnpops-torch2-nvidia/lib/python3.10/site-packages/torch/_dynamo/utils.py:1194:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <OpOverloadPacket(op='neighbors.getNeighborPairs')>
args = (FakeTensor(FakeTensor(..., device='meta', size=(10, 3)), cuda:0), 1.0, -1, FakeTensor(FakeTensor(..., device='meta', size=(0, 0)), cuda:0)), kwargs = {}

    def __call__(self, *args, **kwargs):
        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
>       return self._op(*args, **kwargs or {})
E       RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

../../../mambaforge/envs/nnpops-torch2-nvidia/lib/python3.10/site-packages/torch/_ops.py:502: RuntimeError
```

...and it continues like this for a gazillion lines. I have not been able to solve it; from what I have gathered, this should not happen and it is a bug in torch (there are a lot of issues describing stuff like this: pytorch/pytorch#96742, pytorch/pytorch#95791).
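For anyone hitting the same FakeTensor error with a custom op, one possible workaround (an assumption, not something this PR adopts) is to tell TorchDynamo not to trace the offending call, at the cost of an explicit graph break:

```python
import torch as pt
import torch._dynamo
from NNPOps.neighbors import getNeighborPairs

# Excluding the custom op from tracing makes Dynamo fall back to eager
# mode for this call, which sidesteps the FakeTensor failure. The wrapper
# name is hypothetical; only getNeighborPairs comes from NNPOps.
@torch._dynamo.disable
def neighbor_pairs_eager(positions):
    return getNeighborPairs(positions, cutoff=1.0)
```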
Yes, we can skip the compile feature for now.
OK, I think this is done now.
@RaulPPelaez can I merge?
Yes, thanks. @raimis |
This PR is to start working on making NNPOps compatible with PyTorch 2.0 and torch.compile.
Hopefully this will not require a lot of work, but reducing the number of graph breaks that torch.compile introduces will probably be more challenging. I reckon avoiding graph breaks is very similar to making code compatible with CUDA graphs.
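As a sketch of how graph breaks can be inspected, torch._dynamo.explain reports the captured graphs and the reason for each break (the call convention has shifted between torch 2.x releases; this assumes the early 2.0 form, and the toy function is mine, not from NNPOps):

```python
import torch

def fn(x):
    y = x.sin()
    print("this print forces a graph break")  # side effect Dynamo cannot trace
    return y.cos()

# In torch 2.0, explain(fn, *args) returns a tuple that includes the list of
# break reasons; later releases changed it to explain(fn)(*args).
explanation = torch._dynamo.explain(fn, torch.randn(4))
print(explanation)
```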
In principle, compiled functions/models should be completely equivalent to the uncompiled versions, but I have seen cases where this does not hold (granted, torch2 was still a beta).
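A simple way to check that equivalence claim, as a minimal sketch with a stock module rather than anything NNPOps-specific:

```python
import torch

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model)
x = torch.randn(2, 8)

# Compiled and eager outputs should agree up to floating-point noise;
# the first compiled call also triggers compilation.
assert torch.allclose(model(x), compiled(x), atol=1e-6)
```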