Use arange and repeat for deterministic bincount (#2184)
* Use meshgrid for deterministic bincount

* Update src/torchmetrics/utilities/data.py

Use arange instead of meshgrid

* Update data.py

Update _bincount doc

* chlog

* size

* Update CHANGELOG.md

update changelog

* Update data.py

update comment

* improve text

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
6 people authored Nov 25, 2023
1 parent 1b16341 commit 8ef2a89
Showing 2 changed files with 11 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))


- Use arange and repeat for deterministic bincount ([#2184](https://github.com/Lightning-AI/torchmetrics/pull/2184))


### Deprecated

- Deprecated `metric._update_called` ([#2141](https://github.com/Lightning-AI/torchmetrics/pull/2141))
18 changes: 8 additions & 10 deletions src/torchmetrics/utilities/data.py
@@ -169,12 +169,10 @@ def _squeeze_if_scalar(data: Any) -> Any:
def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""Implement custom bincount.
-    PyTorch currently does not support ``torch.bincount`` for:
-        - deterministic mode on GPU.
-        - MPS devices
-    This implementation fallback to a for-loop counting occurrences in that case.
+    PyTorch currently does not support ``torch.bincount`` when running in deterministic mode on GPU, or when running
+    on MPS or XLA devices. This implementation therefore falls back to a combination of ``torch.arange`` and
+    ``torch.eq`` in these scenarios. A small performance hit and higher memory consumption can be expected, as a
+    ``[batch_size, minlength]`` tensor needs to be initialized compared to native ``torch.bincount``.
Args:
x: tensor to count
@@ -191,11 +189,11 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""
if minlength is None:
minlength = len(torch.unique(x))

if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or _TORCH_GREATER_EQUAL_1_12 and x.is_mps:
-        output = torch.zeros(minlength, device=x.device, dtype=torch.long)
-        for i in range(minlength):
-            output[i] = (x == i).sum()
-        return output
+        mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
+        return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)

return torch.bincount(x, minlength=minlength)


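For readers who want to see the counting trick in isolation, here is a small NumPy sketch of the same arange-and-compare approach (the helper name `bincount_via_eq` is illustrative, not part of torchmetrics; the PyTorch version in the diff above uses `torch.arange`/`repeat`/`torch.eq` identically):

```python
import numpy as np

def bincount_via_eq(x: np.ndarray, minlength: int) -> np.ndarray:
    """Count occurrences of 0..minlength-1 in x without np.bincount.

    Builds a [len(x), minlength] grid where every row holds the bin indices,
    compares each element of x against every bin, and sums matches per bin.
    """
    # One row of bin ids per input element (the "repeat" step).
    mesh = np.tile(np.arange(minlength), (len(x), 1))
    # Broadcast-compare each element against all bins, then count matches.
    return (x.reshape(-1, 1) == mesh).sum(axis=0)

x = np.array([0, 2, 2, 3])
print(bincount_via_eq(x, 5).tolist())  # [1, 0, 2, 1, 0]
```

Every operation here is an elementwise comparison plus a reduction, which is why the result is deterministic; the cost is materializing the `[len(x), minlength]` grid, the memory trade-off the new docstring mentions.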
