Description
🚀 Feature
An efficient implementation for counting nonzero elements
Pitch
In some situations (e.g. MaskRCNN) you don't need the exact positions of the nonzero elements, only how many there are, and the method is called quite frequently. So far, every workaround I tried is faster than retrieving the indices of the nonzero elements and taking their length.
Some users may also want a differentiable count of these values, which effectively rules out the current nonzero method.
Related links:
This was previously mentioned on Discuss [1, 2] and in #14848 and #15190.
Alternatives
import torch
x = torch.randint(2, [20, 1080, 1920]) # e.g. 20 binary mask maps in int64
[len(torch.nonzero(x[i])) for i in range(len(x))] # 311 ms ± 11.5 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
torch.sum(x != 0, dim=[1,2]) # 136 ms ± 13.6 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
torch.sum(x.clamp(0, 1), dim=[1,2]) # 49.9 ms ± 5.34 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)
On a 1080Ti, these times are 10.9 ms, 3.54 ms, and 2.68 ms respectively (with torch.cuda.synchronize called before and after the operation).
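For anyone who wants to reproduce the comparison, here is a minimal sketch of how the three alternatives above could be checked against each other and timed. The helper names (count_via_nonzero, count_via_compare, count_via_clamp, time_op) are made up for illustration, the tensor is much smaller than the one above, and time.perf_counter is only a rough stand-in for the timeit-style measurements quoted above, so the absolute numbers will differ.

```python
import time

import torch


def count_via_nonzero(x):
    # Builds the full index tensor for each mask, then only uses its length.
    return torch.tensor([len(torch.nonzero(m)) for m in x])


def count_via_compare(x):
    # Compares against zero, then sums the resulting mask per map.
    return torch.sum(x != 0, dim=[1, 2])


def count_via_clamp(x):
    # Only a valid count when values are non-negative integers (e.g. binary masks).
    return torch.sum(x.clamp(0, 1), dim=[1, 2])


def time_op(fn, x, n_runs=10):
    """Mean wall-clock seconds per call (hypothetical helper, not a PyTorch API)."""
    if x.is_cuda:
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for the timed kernels to complete
    return (time.perf_counter() - start) / n_runs


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randint(2, [4, 108, 192], device=device)  # small binary masks in int64

methods = (count_via_nonzero, count_via_compare, count_via_clamp)
results = {f.__name__: f(x).cpu() for f in methods}
timings = {f.__name__: time_op(f, x) for f in methods}

for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.3f} ms per call")
```

All three helpers should return the same per-mask counts on binary input; the clamp variant stops being a count as soon as the tensor contains values other than 0 and 1.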
Additional context
A few other non-trivial things that came up while I was digging into which approach is fastest:
- torch.clamp_max() and torch.clamp_min() are roughly 5x slower than torch.clamp(). Times on x: 261 ms ± 78 ms, 202 ms ± 17.8 ms, 47.1 ms ± 437 µs.
- There's no significant difference whether I use the uint8 or int64 dtype.
Thanks @gchanan for asking me to report this issue; I hope it helps others.