scatter_logsumexp: NaNs on untouched indices #368

Closed
@aabbas90

Description

Hi,

I am trying to perform scatter_logsumexp on a strict subset of indices of the out tensor. I am getting NaNs at the indices where out is supposed to be untouched. Example:

import torch
import torch_scatter  # needed for torch_scatter.__version__ below
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.0, 1.0, 4.0])
index = torch.tensor([1, 1, 4])
out = torch.zeros((6,), dtype=torch.float32)

scatter_logsumexp(src, index, out=out)
print(out)
# tensor([   nan, 1.5514,    nan,    nan, 4.0181,    nan])  # Only indices 1, 4 should be changed
print(torch_scatter.__version__)
# '2.1.1+pt20cu118'

A second concern, even once the NaN issue is resolved, is efficiency. Ideally the operation would only touch the locations of out that are referred to in index. Otherwise, for a very large out, we do redundant calculations on entries that never change.

Thanks,
Ahmed
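
For reference, here is a minimal sketch of the behaviour asked for above, written in plain PyTorch rather than torch_scatter. The helper name logsumexp_into_touched is hypothetical and not part of any library. It assumes 1-D float tensors and an int64 index, and it treats the existing values of out as extra terms of the sum, which is consistent with the 1.5514 = log(e^0 + e^0 + e^1) and 4.0181 = log(e^0 + e^4) reported above. Only the positions that occur in index are read or written, so the cost depends on the length of index and not on the size of out.

```python
import torch


def logsumexp_into_touched(src, index, out):
    """Hypothetical helper: fold src into out via logsumexp, in place,
    touching only the positions of out that appear in index (1-D only)."""
    # Distinct target positions, plus a map from each src element to its group.
    touched, inverse = torch.unique(index, return_inverse=True)

    # Existing values of out at the touched positions take part in the sum.
    current = out[touched]

    # Per-group maximum (including the existing value) for numerical stability.
    group_max = current.clone()
    group_max.scatter_reduce_(0, inverse, src, reduce="amax", include_self=True)

    # Sum of recentred exponentials: existing value first, then all of src.
    total = torch.exp(current - group_max)
    total.scatter_add_(0, inverse, torch.exp(src - group_max[inverse]))

    # Write back only the touched positions; everything else stays as it was.
    out[touched] = group_max + torch.log(total)
    return out


src = torch.tensor([0.0, 1.0, 4.0])
index = torch.tensor([1, 1, 4])
out = torch.zeros((6,), dtype=torch.float32)
print(logsumexp_into_touched(src, index, out))
```

With the inputs from the example above, positions 0, 2, 3 and 5 keep their initial value of 0 and no NaNs appear. This is a sketch under the stated assumptions, not a drop-in replacement for scatter_logsumexp: it does not handle a dim argument, multi-dimensional inputs, or groups where every term is -inf.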
