
Add support for half #137

Merged
merged 11 commits into from
Jul 19, 2021
Conversation

cmpute
Contributor

@cmpute cmpute commented Apr 24, 2020

Should fix #126.
I haven't tested it yet since my model has other bugs when using half precision...

@rusty1s
Owner

rusty1s commented Jun 9, 2020

Hi, are there any updates on this PR? The last time I worked with fp16, standard operators like + and * needed adaptation, too.

@cmpute
Contributor Author

cmpute commented Jun 10, 2020

In this PR, all functions in pytorch-scatter that needed adaptation have been fixed. The remaining functions are all PyTorch built-ins. I'll try to open a pull request against PyTorch; once those functions are adapted for half precision, this PR should work.

@rusty1s
Owner

rusty1s commented Jun 10, 2020

Cool! Please keep me updated :)

@cmpute
Contributor Author

cmpute commented Jun 10, 2020

I created a PR in PyTorch. Actually, I found that the tests only fail on CPU. On CUDA, the functions that are missing on CPU are already implemented.

@murnanedaniel

murnanedaniel commented Jul 9, 2020

@cmpute, is this PR basically ready to merge? I'm very excited for this functionality to be incorporated! Let me know if there are any remaining jobs I can help with.

@cmpute
Contributor Author

cmpute commented Jul 9, 2020

> @cmpute, is this PR basically ready to merge? I'm very excited for this functionality to be incorporated! Let me know if there are any remaining jobs I can help with.

I created a PR in PyTorch, but it still needs some time to be merged. I just need to write some tests for that PR. After it's merged, I think we're done~

@rusty1s
Owner

rusty1s commented Jul 9, 2020

Can you link to the PR?

@cmpute
Contributor Author

cmpute commented Jul 9, 2020

I guess GitHub already links it here under your last comment? pytorch/pytorch#39788

@speedcell4

speedcell4 commented Apr 30, 2021

I tried this with PyTorch 1.8.1 and it works fine. Does this mean fp16 is now fully supported by torch-scatter?

import torch
from torch_scatter import scatter_add

a = torch.randn((5, 7), dtype=torch.half).cuda()
index = torch.randint(0, 10, (5,)).cuda()
print(scatter_add(a, index=index, dim=0))
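For readers unfamiliar with the semantics being tested above: scatter_add sums the entries of the source tensor that share the same index value. A minimal pure-Python sketch of the 1-D case (illustrative only, not the torch_scatter implementation):

```python
def scatter_add(src, index, dim_size=None):
    """Sum src[i] into out[index[i]] -- 1-D sketch of scatter_add's semantics."""
    if dim_size is None:
        dim_size = max(index) + 1
    out = [0.0] * dim_size
    for i, j in enumerate(index):
        out[j] += src[i]
    return out

print(scatter_add([1.0, 2.0, 3.0], [0, 1, 0]))  # -> [4.0, 2.0]
```

In the CUDA snippet above, `dim_size` is inferred the same way, which is why the output has as many rows as `index.max() + 1`.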

@rusty1s
Owner

rusty1s commented Apr 30, 2021

fp16 should only be supported for scatter_add and scatter_mean so far, not yet for scatter_max and scatter_mul.
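As background on why each reduction needs explicit fp16 handling: binary16 has only a 10-bit mantissa, so naive accumulation silently drops low-order bits once partial sums grow. A sketch using the standard library's binary16 struct format (`'e'`, available since Python 3.6) to round-trip values through half precision:

```python
import struct

def to_half(x):
    # Round a Python float to the nearest IEEE 754 binary16 value and back.
    return struct.unpack('e', struct.pack('e', x))[0]

# Integers are exact in fp16 only up to 2048; above that the spacing is 2,
# so adding 1 to 2048 is lost if the intermediate sum is stored in half.
print(to_half(2048.0))        # 2048.0
print(to_half(2048.0 + 1.0))  # 2048.0 (2049 rounds back down)
```

This is why scatter reductions over many elements typically accumulate in float32 internally even when inputs and outputs are half.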

@rusty1s rusty1s mentioned this pull request May 5, 2021
@rusty1s rusty1s merged commit 1f49a3a into rusty1s:master Jul 19, 2021
@rusty1s
Owner

rusty1s commented Jul 19, 2021

Thank you very much. I fixed the remaining PyTorch compatibility issues and everything should work now.

@HelloWorldLTY

Hi, I wonder if there is any example code we can refer to for AMP training. Thanks a lot.

@rusty1s
Owner

rusty1s commented Jan 25, 2023

@jykr

jykr commented Sep 21, 2024

Hi, thanks for this update. I'm having trouble using half precision with HeteroLinear (via HGTConv), with the following error. Am I doing something wrong, or does torch.ops.pyg not support half precision? Thank you.

  File "/data/pinello/PROJECTS/2022_12_GCPA/dynot/dynot/encoders.py", line 205, in forward
    _x_dict = conv(_x_dict, edge_dict)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/SHARED_SOFTWARE/anaconda_latest/envs/jy_gcpa_torch4/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/SHARED_SOFTWARE/anaconda_latest/envs/jy_gcpa_torch4/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/PROJECTS/2022_12_GCPA/software/pytorch_geometric/torch_geometric/nn/conv/hgt_conv.py", line 194, in forward
    k, v, src_offset = self._construct_src_node_feat(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/PROJECTS/2022_12_GCPA/software/pytorch_geometric/torch_geometric/nn/conv/hgt_conv.py", line 153, in _construct_src_node_feat
    k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1)
        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/SHARED_SOFTWARE/anaconda_latest/envs/jy_gcpa_torch4/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/SHARED_SOFTWARE/anaconda_latest/envs/jy_gcpa_torch4/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/PROJECTS/2022_12_GCPA/software/pytorch_geometric/torch_geometric/nn/dense/linear.py", line 342, in forward
    self._update_timing_cache(x, type_ptr, key)
  File "/data/pinello/SHARED_SOFTWARE/anaconda_latest/envs/jy_gcpa_torch4/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/PROJECTS/2022_12_GCPA/software/pytorch_geometric/torch_geometric/nn/dense/linear.py", line 301, in _update_timing_cache
    _ = self.forward_segmm(x, type_ptr)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/PROJECTS/2022_12_GCPA/software/pytorch_geometric/torch_geometric/nn/dense/linear.py", line 285, in forward_segmm
    return pyg_lib.ops.segment_matmul(x, type_ptr, self.weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/SHARED_SOFTWARE/anaconda_latest/envs/jy_gcpa_torch4/lib/python3.12/site-packages/pyg_lib/ops/__init__.py", line 174, in segment_matmul
    out = torch.ops.pyg.segment_matmul(inputs, ptr, other)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/pinello/SHARED_SOFTWARE/anaconda_latest/envs/jy_gcpa_torch4/lib/python3.12/site-packages/torch/_ops.py", line 1061, in __call__
    return self_._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected scalar type Float but found Half
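A common workaround when an op has no Half kernel (as torch.ops.pyg.segment_matmul appears not to here) is to upcast its inputs to float32 around the call and downcast the result once, e.g. `x.float()` before and `.half()` after. A pure-Python sketch of that pattern, using the struct module's binary16 round-trip in place of real tensors; all function names here are illustrative, not pyg_lib API:

```python
import struct

def as_half(x):
    # Emulate storing a value in IEEE 754 binary16 (fp16).
    return struct.unpack('e', struct.pack('e', x))[0]

def float_only_op(values):
    # Stand-in for a kernel that only accepts fp32 inputs
    # (analogous to segment_matmul raising on Half).
    return sum(values)

def half_safe_op(half_values):
    # Upcast to full precision, run the fp32-only op, downcast the result once.
    out = float_only_op([float(v) for v in half_values])
    return as_half(out)

print(half_safe_op([as_half(2048.0), as_half(1.0), as_half(1.0)]))  # -> 2050.0
```

The accumulation happens at full precision, so only a single rounding to half occurs at the end; chaining the additions in half would have stalled at 2048.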

Successfully merging this pull request may close these issues.

Add support for half
6 participants