Closed
Hi,
I am trying to perform scatter_logsumexp on a strict subset of the indices of the out tensor. I am getting NaNs at the indices where out is supposed to be untouched. Example:
import torch
import torch_scatter  # needed for torch_scatter.__version__ below
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.0, 1.0, 4.0])
index = torch.tensor([1, 1, 4])
out = torch.zeros((6,), dtype=torch.float32)
scatter_logsumexp(src, index, out=out)
print(out)
# tensor([ nan, 1.5514, nan, nan, 4.0181, nan])  <- only indices 1 and 4 should be changed
print(torch_scatter.__version__)
# '2.1.1+pt20cu118'
Another issue, even if the NaN problem is resolved, is efficiency. Ideally the operation would only touch those locations of out that are referred to in index. Otherwise, for a very large out, we are doing redundant calculations.
Thanks,
Ahmed