Support for pytorch 2.0 #94

Merged: 14 commits, Apr 18, 2023

Conversation

RaulPPelaez (Contributor) commented Mar 17, 2023

This PR starts the work of making NNPOps compatible with pytorch 2.0 and torch.compile:

  • Ensure the workflows and tests work for pytorch 2.0 (which requires CUDA 11.8).
  • Make operations compatible with torch.compile() (see the sketch after this list).
    Hopefully this will not require a lot of work, but reducing the number of graph breaks torch.compile introduces will probably be more challenging. I reckon avoiding graph breaks is quite similar to making code compatible with CUDA graphs.
  • Write tests for the compiled versions.
    In principle, compiled functions/models should be completely equivalent to the uncompiled versions, but I have seen cases where they are not (granted, torch 2 was still in beta).
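
As a rough illustration of detecting graph breaks, here is a minimal sketch (not NNPOps code; ToyModule is a hypothetical stand-in). The fullgraph=True flag of torch.compile turns any graph break into a hard error instead of silently splitting execution:

import torch

# Hypothetical stand-in for an NNPOps-backed model.
class ToyModule(torch.nn.Module):
    def forward(self, x):
        return (x * x).sum()

model = ToyModule()
x = torch.randn(10, 3)

# fullgraph=True makes torch.compile raise on any graph break instead of
# silently falling back to eager execution for part of the model.
compiled = torch.compile(model, fullgraph=True)
assert torch.allclose(compiled(x), model(x))  # compiled should match eager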

RaulPPelaez (Contributor, Author) commented Apr 13, 2023

Torchani cannot be installed with pytorch 2, which forces us to skip some tests.
EDIT: Torchani made a new torch2-compatible release.

RaulPPelaez (Contributor, Author) commented:

All tests pass and the CI works for an installation with pytorch 2. I believe this should be merged now, and work on compile() compatibility can be done in another PR.
A new release could be made now so that users can install NNPOps alongside pytorch 2.

RaulPPelaez (Contributor, Author) commented:

In case you have some experience with torch 2's compile: this test fails miserably in CUDA mode:

import pytest
import torch as pt
from NNPOps.neighbors import getNeighborPairs

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_torch_compile_compatible(device, dtype):

    class ForceModule(pt.nn.Module):

        def forward(self, positions):
            # Build the neighbor list and compute a simple energy-like scalar.
            neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
            mask = pt.isnan(distances)  # padded (non-existent) pairs come back as NaN
            distances = distances[~mask]
            return pt.sum(distances**2)

    original_model = ForceModule()
    num_atoms = 10
    positions = (20 * pt.randn((num_atoms, 3), device=device, dtype=dtype)) - 10
    original_model(positions)            # eager execution works
    model = pt.compile(original_model)   # compile with torch 2's default backend
    model(positions)                     # compiled execution fails on CUDA

It yields a really verbose error about something called FakeTensor that makes the most obscure gcc recursive template error look clear and informative:

TestNeighbors.py::test_torch_compile_compatible[dtype1-cuda] FAILED

================================================================================= FAILURES ==================================================================================
________________________________________________________________ test_torch_compile_compatible[dtype0-cuda] _________________________________________________________________
                                                                                                                                                                             
output_graph = <torch._dynamo.output_graph.OutputGraph object at 0x7fc3e7d81fc0>, node = get_neighbor_pairs                                                                  
args = (FakeTensor(FakeTensor(..., device='meta', size=(10, 3)), cuda:0), 1.0, -1, FakeTensor(FakeTensor(..., device='meta', size=(0, 0)), cuda:0)), kwargs = {}             
nnmodule = None                                                                                                                                                              
                                                                                                                                                                             
    def run_node(output_graph, node, args, kwargs, nnmodule):                                                                                                                
        """                                                                                                                                                                  
        Runs a given node, with the given args and kwargs.
     
        Behavior is dicatated by a node's op.
     
        run_node is useful for extracting real values out of nodes.
        See get_real_value for more info on common usage.
     
        Note: The output_graph arg is only used for 'get_attr' ops
        Note: The nnmodule arg is only used for 'call_module' ops
     
        Nodes that are not call_function, call_method, call_module, or get_attr will
        raise an AssertionError.
        """
        op = node.op
        try:
            if op == "call_function":
>               return node.target(*args, **kwargs)

../../../mambaforge/envs/nnpops-torch2-nvidia/lib/python3.10/site-packages/torch/_dynamo/utils.py:1194: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <OpOverloadPacket(op='neighbors.getNeighborPairs')>
args = (FakeTensor(FakeTensor(..., device='meta', size=(10, 3)), cuda:0), 1.0, -1, FakeTensor(FakeTensor(..., device='meta', size=(0, 0)), cuda:0)), kwargs = {}

    def __call__(self, *args, **kwargs):
        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
>       return self._op(*args, **kwargs or {})
E       RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

../../../mambaforge/envs/nnpops-torch2-nvidia/lib/python3.10/site-packages/torch/_ops.py:502: RuntimeError

and it goes on like this for a gazillion lines.

I have not been able to solve this. From what I have gathered, this should not happen and it is a bug in torch (there are a lot of issues describing problems like this: pytorch/pytorch#96742, pytorch/pytorch#95791).
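
For context on what a FakeTensor is, here is a minimal sketch (assuming torch >= 2.0; FakeTensorMode lives in the private torch._subclasses module, so details may vary between versions). torch.compile traces models with fake, meta-device tensors that carry shape and dtype but no data, and a custom op with no fake/meta implementation ends up trying to touch storage that was never allocated:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside FakeTensorMode, factory functions produce FakeTensors: tensors that
# propagate shapes and dtypes but never allocate real storage.
with FakeTensorMode():
    x = torch.empty(10, 3)   # no data is allocated
    y = (x * x).sum()        # built-in ops know how to propagate shapes
    print(y.shape, y.dtype)  # torch.Size([]) torch.float32

# A custom op like neighbors.getNeighborPairs has no such shape-propagation
# rule registered, so tracing it with FakeTensors fails.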

raimis (Contributor) commented Apr 14, 2023

Yes, we can skip the compile feature for now.
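
One way to skip it, as a minimal sketch (standard pytest machinery; the xfail reason string is an assumption), is to mark the CUDA case of the compile test as an expected failure:

import pytest
import torch as pt

@pytest.mark.parametrize('device', [
    'cpu',
    pytest.param('cuda', marks=pytest.mark.xfail(
        reason='torch.compile FakeTensor bug, see pytorch/pytorch#96742')),
])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_torch_compile_compatible(device, dtype):
    ...  # body as in the test above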

raimis requested review from raimis and sef43 and removed the request for sef43, Apr 14, 2023 12:24
RaulPPelaez (Contributor, Author) commented:

OK, I think this is done now.

raimis (Contributor) commented Apr 17, 2023

@RaulPPelaez can I merge?

RaulPPelaez (Contributor, Author) commented:

Yes, thanks. @raimis

raimis merged commit b63fc70 into openmm:master on Apr 18, 2023