Closed
Description
The following code incorrectly sets an unused output element to 0 instead of leaving it at its original value (here -10):
```python
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([-1., -50])
index = torch.tensor([0, 0])
out = torch.full((2,), -10.)
scatter_logsumexp(src=src, index=index, out=out)
# tensor([-0.9999, 0.0000]) instead of tensor([-0.9999, -10.0000])
```
This means that, when an `out` tensor is supplied, `scatter_logsumexp` only gives correct results in the corner case where every output element receives at least one value from the scatter operation.
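To pin down the expected behaviour, here is a minimal pure-Python reference sketch. It assumes, as the `-0.9999` (rather than `-1.0000`) in the example suggests, that the initial contents of `out` take part in the reduction, i.e. `out[i] = log(exp(out[i]) + sum(exp(src[j]) for j with index[j] == i))`. The function name `scatter_logsumexp_ref` is hypothetical and is not part of the `torch_scatter` API. An index that receives no `src` values reduces over `{out[i]}` alone, so it keeps its original value.

```python
import math


def scatter_logsumexp_ref(src, index, out):
    """Hypothetical reference: fold src values into out with a stable logsumexp.

    Assumes out's initial values are included in each reduction, so an
    index that receives no src values keeps its original out value.
    """
    # Every output slot starts with its own initial value.
    groups = [[v] for v in out]
    for value, i in zip(src, index):
        groups[i].append(value)

    result = []
    for values in groups:
        m = max(values)
        if m == -math.inf:
            # All inputs are -inf: logsumexp is -inf (avoids -inf - -inf = nan).
            result.append(-math.inf)
            continue
        # Subtract the max before exponentiating for numerical stability.
        result.append(m + math.log(sum(math.exp(v - m) for v in values)))
    return result


print(scatter_logsumexp_ref([-1.0, -50.0], [0, 0], [-10.0, -10.0]))
```

With the inputs from the report, element 0 comes out close to `-0.9999` and element 1 stays at exactly `-10.0`, which is the result the issue expects from `scatter_logsumexp`.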