Scatter leads to wrong results on gpu with cast to short()-tensor #9886

maichmueller opened this issue Dec 23, 2024 · 1 comment
🐛 Describe the bug

I just discovered that the scatter operation with mul reduction produces wrong values when I cast my input tensor to dtype short. I am unsure whether this is a bug in torch, torch-geometric, torch-scatter, or my brain, so I am posting it here, since I am importing pyg's scatter method. (Apologies for the oddly large demo data; I pulled it straight out of my workflow.)

import torch
from torch_geometric.utils import scatter

done = torch.tensor([False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False,  True, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False,  True,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False,  True, False, False, False, False, False,
    False, False, False,  True, False, False, False, False, False, False,
    False, False, False, False, False, False, False,  True, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False,  True, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False,  True, False, False, False, False,
    False, False, False, False, False,  True, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False,  True, False, False,
    False, False, False, False, False, False, False, False, False, False,
    False, False, False, False, False, False, False, False], dtype=torch.bool)

index = torch.tensor(
    [  0,   0,   0,   1,   1,   1,   1,   2,   2,   2,   3,   3,   3,   3,
      4,   4,   4,   5,   5,   5,   6,   6,   7,   7,   7,   8,   9,   9,
      9,   9,  10,  10,  11,  11,  12,  12,  13,  13,  13,  14,  14,  15,
     15,  15,  16,  16,  16,  16,  17,  17,  18,  18,  18,  19,  19,  19,
     19,  20,  20,  20,  21,  21,  21,  22,  22,  23,  23,  23,  23,  24,
     25,  25,  25,  25,  26,  27,  27,  28,  28,  28,  29,  30,  30,  30,
     30,  31,  31,  31,  32,  32,  32,  33,  33,  34,  34,  34,  34,  35,
     35,  35,  35,  36,  36,  37,  37,  38,  38,  38,  39,  39,  39,  40,
     40,  40,  41,  42,  42,  43,  43,  43,  44,  44,  44,  45,  46,  46,
     46,  47,  47,  47,  48,  48,  48,  48,  49,  49,  49,  50,  51,  51,
     51,  52,  52,  52,  52,  53,  53,  53,  54,  54,  55,  55,  56,  56,
     56,  57,  57,  57,  58,  58,  59,  59,  60,  60,  61,  61,  61,  62,
     63,  63,  63,  64,  64,  64,  65,  65,  66,  66,  67,  67,  67,  68,
     68,  68,  69,  69,  70,  71,  71,  72,  72,  73,  73,  73,  74,  74,
     75,  75,  75,  76,  76,  76,  77,  78,  79,  80,  80,  81,  81,  81,
     82,  82,  82,  82,  83,  83,  84,  85,  85,  85,  86,  86,  87,  87,
     87,  88,  88,  88,  88,  89,  89,  90,  90,  90,  91,  92,  92,  92,
     92,  92,  93,  94,  94,  94,  95,  95,  95,  95,  96,  96,  96,  97,
     97,  98,  98,  99, 100, 100, 101, 102, 102, 102, 103, 103, 103, 104,
    105, 105, 105, 106, 106, 106, 106, 107, 107, 107, 107, 108, 108, 108,
    109, 109, 109, 110, 110, 110, 110, 111, 111, 111, 112, 112, 113, 114,
    114, 115, 115, 115, 115, 116, 116, 116, 116, 117, 118, 118, 118, 119,
    120, 121, 121, 121, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125,
    125, 126, 126, 127, 127, 127], dtype=torch.long)

assert len(done) == len(index)

def doit():
    # Multiplicative reduction: output slot i equals 1 only if every
    # `done` entry scattered into i is True (i.e. 1 after the cast).
    print("Short - Long")
    print(torch.where(
        scatter(done.short(), index=index.long(), dim=0, reduce="mul") == 1))

    print("Long - Long")
    print(torch.where(
        scatter(done.long(), index=index.long(), dim=0, reduce="mul") == 1))

print("-----------------")
print("On CPU")
print("-----------------")
done = done.cpu()
index = index.cpu()
doit()

print("-----------------")
print("On GPU")
print("-----------------")
done = done.cuda()
index = index.cuda()
doit()

This produces the following output:

-----------------
On CPU
-----------------
Short - Long
(tensor([  8,  24,  41,  45,  50,  70,  99, 104, 119]),)
Long - Long
(tensor([  8,  24,  41,  45,  50,  70,  99, 104, 119]),)
-----------------
On GPU
-----------------
Short - Long
(tensor([  1,   3,   5,   7,   8,   9,  11,  13,  15,  17,  19,  21,  23,  24,
         25,  27,  29,  31,  33,  35,  37,  39,  41,  43,  45,  47,  49,  50,
         51,  53,  55,  57,  59,  61,  63,  65,  67,  69,  70,  71,  73,  75,
         77,  79,  81,  83,  85,  87,  89,  91,  93,  95,  97,  99, 101, 103,
        104, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127],
       device='cuda:0'),)
Long - Long
(tensor([  8,  24,  41,  45,  50,  70,  99, 104, 119], device='cuda:0'),)

The CPU values are the correct ones, as I can verify via other facts about this data (and they are also the majority result across the four runs).
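For reference, the semantics I expect from the mul reduction can be sketched in plain Python (a hypothetical reference implementation for illustration only, not PyG's actual code):

```python
def scatter_mul(src, index, dim_size):
    """Reference scatter with multiplicative reduction."""
    # Each output slot starts at the multiplicative identity, 1.
    out = [1] * dim_size
    # Fold every source value into the slot selected by its index.
    for value, slot in zip(src, index):
        out[slot] *= value
    return out

# Mirroring the boolean `done` tensor above: after the cast, a slot is 1
# only if every entry scattered into it was True.
print(scatter_mul([1, 0, 1, 1, 1], [0, 0, 1, 1, 1], dim_size=2))  # [0, 1]
```

Under this definition, only the slots whose entire group is True should equal 1, which is exactly what the CPU runs report.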

Versions

Python: 3.11.9
OS: Ubuntu 22.04

pip list

Package Version


...
torch 2.4.0
torch-geometric 2.6.1
torch_scatter 2.1.2
torchaudio 2.4.0
torchmetrics 1.4.0.post0
torchrl 0.5.0
torchvision 0.19.0
...

@akihironitta akihironitta self-assigned this Dec 23, 2024
@akihironitta
Member

Thanks for reporting this issue! A quick workaround would be to uninstall torch_scatter from your env or to disable the flag in your script by writing:

import torch_geometric.typing

torch_geometric.typing.WITH_TORCH_SCATTER = False
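Another way to sidestep torch_scatter entirely is to compute the same reduction with plain PyTorch's `scatter_reduce` (a minimal sketch, starting from a tensor of ones as the multiplicative identity; the values here are toy data, not the issue's tensors):

```python
import torch

src = torch.tensor([1, 0, 1, 1, 1], dtype=torch.long)
index = torch.tensor([0, 0, 1, 1, 1])
# Start from ones and multiply the scattered values in with reduce="prod".
out = torch.ones(2, dtype=src.dtype).scatter_reduce(
    0, index, src, reduce="prod", include_self=True)
print(out)  # tensor([0, 1])
```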
